├── .github ├── pull_request_template.md └── workflows │ └── lint-and-format.yml ├── .gitignore ├── .vscode └── extensions.json ├── README.md ├── mimo ├── __init__.py ├── config.py ├── model.py └── trainer.py ├── pyproject.toml ├── pyrightconfig.json ├── requirements-dev.txt ├── requirements.txt ├── setup.cfg ├── setup.py ├── tox.ini └── train.py /.github/pull_request_template.md: -------------------------------------------------------------------------------- 1 | ## What changed? 🎉 2 | 3 | - Added YY to screen XX 4 | - Added/removed/modified `zz` in the `kk` file 5 | 6 | ## Which issues or PRs are related? 🔍 7 | 8 | Please refer to ["Closing Issues Using Keywords"](https://help.github.com/en/articles/closing-issues-using-keywords) in the GitHub docs! 9 | 10 | ## Anything else we should know? 🥺 (optional) 11 | 12 | Write down any information here that would be useful for other people to know! 13 | -------------------------------------------------------------------------------- /.github/workflows/lint-and-format.yml: -------------------------------------------------------------------------------- 1 | name: Lint and Format Python 2 | 3 | on: 4 | push: 5 | branches: [ master ] 6 | pull_request: 7 | branches: [ master ] 8 | 9 | jobs: 10 | build: 11 | runs-on: ubuntu-latest 12 | 13 | steps: 14 | - uses: actions/checkout@v2 15 | - name: Set up Python 3.7 16 | uses: actions/setup-python@v2 17 | with: 18 | python-version: 3.7 19 | - name: Install dependencies 20 | run: | 21 | python -m pip install --upgrade pip 22 | pip install -r requirements.txt -r requirements-dev.txt 23 | - name: Lint with flake8 24 | run: | 25 | flake8 mimo train.py 26 | - name: Check import order with isort 27 | run: | 28 | isort -c mimo train.py 29 | - name: Check code formatting with black 30 | run: | 31 | black --check mimo train.py 32 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | 106 | # editor specific 107 | .vscode/settings.json 108 | .vscode/sftp.json 109 | -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-pyright.pyright", 4 | "ms-python.python" 5 | ] 6 | } 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MIMO-pytorch (WIP) 2 | 3 | PyTorch implementation of MIMO proposed in [Training independent subnetworks for robust prediction](https://openreview.net/forum?id=OGg9XnKxFAH). 4 | 5 | # Model Training 6 | 7 | ``` sh 8 | python train.py 9 | ``` 10 | 11 | # References 12 | 13 | ``` plain 14 | @inproceedings{havasi2021training, 15 | author = {Marton Havasi and Rodolphe Jenatton and Stanislav Fort and Jeremiah Zhe Liu and Jasper Snoek and Balaji Lakshminarayanan and Andrew M. 
Dai and Dustin Tran}, 16 | title = {Training independent subnetworks for robust prediction}, 17 | booktitle = {International Conference on Learning Representations}, 18 | year = {2021}, 19 | } 20 | ``` 21 | 22 | * https://github.com/google/edward2/tree/master/experimental/mimo 23 | -------------------------------------------------------------------------------- /mimo/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.0.1" 2 | __author__ = "ScatterLab" 3 | -------------------------------------------------------------------------------- /mimo/config.py: -------------------------------------------------------------------------------- 1 | from typing import NamedTuple 2 | 3 | 4 | class Config(NamedTuple): 5 | """ 6 | Hyperparameters 7 | """ 8 | 9 | #: random seed 10 | seed: int = 42 11 | #: training epochs 12 | num_epochs: int = 10 13 | #: batch size 14 | batch_size: int = 64 15 | #: learning rate 16 | learning_rate: float = 1.0 17 | #: learning rate step gamma 18 | gamma: float = 0.7 19 | #: number of dataloader workers 20 | num_workers: int = 10 21 | 22 | train_log_interval: int = 100 23 | valid_log_interval: int = 1000 24 | 25 | """ 26 | MIMO Hyperparameters 27 | """ 28 | ensemble_num: int = 5 29 | -------------------------------------------------------------------------------- /mimo/model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class MIMOModel(nn.Module): 7 | def __init__(self, hidden_dim: int = 784, ensemble_num: int = 3): 8 | super(MIMOModel, self).__init__() 9 | self.input_layer = nn.Linear(hidden_dim, hidden_dim * ensemble_num) 10 | self.backbone_model = BackboneModel(hidden_dim, ensemble_num) 11 | self.ensemble_num = ensemble_num 12 | self.output_layer = nn.Linear(128, 10 * ensemble_num) 13 | 14 | def forward(self, input_tensor: torch.Tensor) -> torch.Tensor: 15 | ensemble_num, batch_size, *_ = list(input_tensor.size()) 16 | input_tensor = input_tensor.transpose(1, 0).view( 17 | batch_size, ensemble_num, -1 18 | ) # (batch_size, ensemble_num, hidden_dim) 19 | input_tensor = self.input_layer(input_tensor) # (batch_size, ensemble_num, hidden_dim * ensemble_num) 20 | 21 | # usual model forward 22 | output = self.backbone_model(input_tensor) # (batch_size, ensemble_num, 128) 23 | output = self.output_layer(output) # (batch_size, ensemble_num, 10 * ensemble_num) 24 | output = output.reshape( 25 | batch_size, ensemble_num, -1, ensemble_num 26 | ) # (batch_size, ensemble_num, 10, ensemble_num) 27 | output = torch.diagonal(output, offset=0, dim1=1, dim2=3).transpose(2, 1) # (batch_size, ensemble_num, 10) 28 | output = F.log_softmax(output, dim=-1) # (batch_size, ensemble_num, 10) 29 | return output 30 | 31 | 32 | class BackboneModel(nn.Module): 33 | def __init__(self, hidden_dim: int, ensemble_num: int): 34 | super(BackboneModel, self).__init__() 35 | self.l1 = nn.Linear(hidden_dim * ensemble_num, 256) 36 | self.l2 = nn.Linear(256, 128) 37 | 38 | def forward(self, x: torch.Tensor) -> torch.Tensor: 39 | x = self.l1(x) 40 | x = F.relu(x) 41 | x = F.dropout(x, p=0.1, training=self.training)  # functional dropout must be told the training mode, otherwise it stays active during eval 42 | x = self.l2(x) 43 | x = F.relu(x) 44 | return x 45 | -------------------------------------------------------------------------------- /mimo/trainer.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 |
import torch.optim as optim 7 | from torch.optim.lr_scheduler import StepLR 8 | from torch.utils.data import DataLoader 9 | 10 | from mimo.config import Config 11 | 12 | 13 | class MIMOTrainer: 14 | def __init__( 15 | self, 16 | config: Config, 17 | model: nn.Module, 18 | train_dataloaders: List[DataLoader], 19 | test_dataloader: DataLoader, 20 | device: torch.device, 21 | ): 22 | self.config = config 23 | self.model = model 24 | self.train_dataloaders: List[DataLoader] = train_dataloaders 25 | self.test_dataloader: DataLoader = test_dataloader 26 | 27 | self.optimizer = optim.Adadelta(self.model.parameters(), lr=config.learning_rate) 28 | self.scheduler = StepLR(self.optimizer, step_size=len(self.train_dataloaders[0]), gamma=config.gamma) 29 | 30 | self.device = device 31 | 32 | def train(self): 33 | self.model.to(self.device) 34 | self.model.train() 35 | global_step = 0 36 | for epoch in range(1, self.config.num_epochs + 1): 37 | for datum in zip(*self.train_dataloaders): 38 | model_inputs = torch.stack([data[0] for data in datum]).to(self.device) # (ensemble_num, batch_size, ...) 39 | targets = torch.stack([data[1] for data in datum]).to(self.device) # (ensemble_num, batch_size) 40 | 41 | ensemble_num, batch_size = list(targets.size()) 42 | self.optimizer.zero_grad() 43 | outputs = self.model(model_inputs) # (batch_size, ensemble_num, 10) 44 | loss = F.nll_loss( # transpose outputs so each subnetwork's predictions line up with the targets from its own dataloader 45 | outputs.transpose(1, 0).reshape(ensemble_num * batch_size, -1), targets.reshape(ensemble_num * batch_size) 46 | ) 47 | loss.backward() 48 | 49 | self.optimizer.step() 50 | self.scheduler.step() 51 | 52 | global_step += 1 53 | if global_step != 0 and global_step % self.config.train_log_interval == 0: 54 | print(f"[Train] epoch:{epoch} \t global step:{global_step} \t loss:{loss.item():.4f}") 55 | if global_step != 0 and global_step % self.config.valid_log_interval == 0: 56 | self.validate() 57 | 58 | def validate(self): 59 | self.model.eval() 60 | test_loss = 0 61 | correct = 0 62 | with torch.no_grad(): 63 | for data in self.test_dataloader: 64 | model_inputs = torch.stack([data[0]] * self.config.ensemble_num).to(self.device) # every ensemble member sees the same test batch 65 | target = data[1].to(self.device) 66 | 67 | outputs = self.model(model_inputs) 68 | output = torch.mean(outputs, dim=1) # average log-probabilities over ensemble members 69 | 70 | test_loss += F.nll_loss(output, target, reduction="sum").item() 71 | pred = output.argmax(dim=-1, keepdim=True) 72 | correct += pred.eq(target.view_as(pred)).sum().item() 73 | 74 | test_loss /= len(self.test_dataloader.dataset) 75 | acc = 100.0 * correct / len(self.test_dataloader.dataset) 76 | print(f"[Valid] Average loss: {test_loss:.4f} \t Accuracy:{acc:2.2f}%") 77 | self.model.train() 78 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 3 | target-version = ['py37'] 4 | include = '\.py$' 5 | -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "include": ["mimo"], 3 | "venvPath": "env", 4 | 5 | "reportUnknownVariableType": true, 6 | "reportUnknownMemberType": true, 7 | "reportUnusedImport": true, 8 | "reportUnusedVariable": true, 9 | "reportUnusedClass": true, 10 | "reportUnusedFunction": true, 11 | "reportImportCycles": true, 12 | "reportTypeshedErrors": true, 13 | "reportOptionalMemberAccess": true, 14 | "reportUntypedBaseClass": true, 15 | "reportPrivateUsage": true, 16 | "reportConstantRedefinition": true, 17 | "reportInvalidStringEscapeSequence": true, 18 |
"reportUnnecessaryIsInstance": true, 19 | "reportUnnecessaryCast": true, 20 | "reportAssertAlwaysTrue": true, 21 | "reportSelfClsParameterName": true, 22 | 23 | "pythonVersion": "3.7", 24 | "pythonPlatform": "Linux" 25 | } 26 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | # dev dependency 2 | isort 3 | black 4 | flake8 -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch 2 | torchvision 3 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | max-line-length = 120 3 | ignore = E203, W503 4 | 5 | [tool:isort] 6 | line_length = 120 7 | multi_line_output = 3 8 | include_trailing_comma = True 9 | 10 | [tool:pytest] 11 | addopts = -ra -v -l 12 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import find_packages, setup 2 | 3 | setup( 4 | name="MIMO-pytorch", 5 | version="0.0.1", 6 | description="mimo in ICLR 2021", 7 | install_requires=[], 8 | url="https://github.com/noowad93/MIMO-pytorch.git", 9 | author="Dawoon Jung", 10 | author_email="dawoon@scatterlab.co.kr", 11 | packages=find_packages(), 12 | ) 13 | -------------------------------------------------------------------------------- /tox.ini: -------------------------------------------------------------------------------- 1 | [tox] 2 | envlist = py37 3 | 4 | [testenv] 5 | deps = 6 | -r requirements.txt 7 | -r requirements-dev.txt 8 | commands = 9 | black --check mimo 10 | flake8 mimo 11 | isort -rc -c mimo 12 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | from argparse import ArgumentParser 2 | 3 | import torch 4 | from torch.utils.data import DataLoader 5 | from torchvision import datasets, transforms 6 | 7 | from mimo.config import Config 8 | from mimo.model import MIMOModel 9 | from mimo.trainer import MIMOTrainer 10 | 11 | parser = ArgumentParser("MIMO Training") 12 | parser.add_argument("--ensemble-num", type=int, default=3) 13 | 14 | 15 | def main(args): 16 | config = Config(ensemble_num=args.ensemble_num) 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | 19 | transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) 20 | train_dataset = datasets.MNIST("../data", train=True, download=True, transform=transform) 21 | test_dataset = datasets.MNIST("../data", train=False, transform=transform) 22 | 23 | train_dataloaders = [ 24 | DataLoader(train_dataset, batch_size=config.batch_size, num_workers=config.num_workers, shuffle=True) 25 | for _ in range(config.ensemble_num) 26 | ] 27 | test_dataloader = DataLoader(test_dataset, batch_size=config.batch_size, num_workers=config.num_workers) 28 | 29 | model = MIMOModel(ensemble_num=config.ensemble_num).to(device) 30 | trainer = MIMOTrainer(config, model, train_dataloaders, test_dataloader, device) 31 | trainer.train() 32 | 33 | 34 | if __name__ == "__main__": 35 | args = parser.parse_args() 36 | main(args) 37 | 
--------------------------------------------------------------------------------
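
Note: the sketch below is not one of the repository files above; it is a minimal usage example of how the trained `MIMOModel` is queried at test time, mirroring what `MIMOTrainer.validate` does: every ensemble member receives the same batch, and the members' log-probabilities are averaged before taking the argmax. The random batch stands in for normalized MNIST images, and in practice you would load trained weights (e.g. via `model.load_state_dict`) before evaluating.

``` python
# Minimal MIMO inference sketch (assumes the mimo package above is importable).
import torch

from mimo.model import MIMOModel

ensemble_num = 3
model = MIMOModel(hidden_dim=784, ensemble_num=ensemble_num)
model.eval()  # in practice, load trained weights first via model.load_state_dict(...)

# Stand-in for a batch of normalized 28x28 MNIST images.
batch = torch.randn(16, 1, 28, 28)

# Repeat the batch once per ensemble member: (ensemble_num, batch_size, 1, 28, 28).
inputs = torch.stack([batch] * ensemble_num)

with torch.no_grad():
    log_probs = model(inputs)              # (batch_size, ensemble_num, 10)
    avg_log_probs = log_probs.mean(dim=1)  # average over ensemble members
    preds = avg_log_probs.argmax(dim=-1)   # (batch_size,)

print(preds.shape)  # torch.Size([16])
```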