├── .gitignore ├── LICENSE ├── README.md └── mada.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /LICENSE: 
-------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2018 Arthur Douillard 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mada.pytorch 2 | Implementation of Multi-Adversarial Domain Adaptation (https://arxiv.org/abs/1809.02176) in Pytorch 3 | 4 | 5 | # THIS IMPLEMENTATION IS NOT FINISHED, and PROBABLY WILL NEVER BE. 
"""Multi-Adversarial Domain Adaptation (MADA) building blocks in PyTorch."""
import torch
from torch import nn
from torch.nn import functional as F

try:
    # torchvision is not used by this module; importing it optionally keeps
    # the module importable in environments where it is not installed.
    import torchvision  # noqa: F401
except ImportError:
    torchvision = None


class MADA(nn.Module):
    """Multi-Adversarial Domain Adaptation network.

    A shared feature extractor feeds a label classifier and, through a
    gradient reversal layer, one single-logit domain classifier per class.
    Domain classifier ``k`` sees the features weighted by the softmax
    probability the label classifier assigns to class ``k``.

    Args:
        n_classes: number of label classes; also the number of domain
            classifiers.
        convnet: feature extractor; ``ConvNet()`` when None.
        classifier: label classifier applied to the flattened features;
            ``Classifier(n_classes, feature_dim)`` when None.
        feature_dim: size of the flattened ``convnet`` output. The default
            12544 (= 64 * 14 * 14) matches the default ``ConvNet`` on
            3x32x32 inputs; pass the right value for other inputs/extractors.
    """

    def __init__(self, n_classes, convnet=None, classifier=None,
                 feature_dim=12544):
        super().__init__()

        self._n_classes = n_classes

        # `is None` rather than `or`: some modules (e.g. an empty
        # nn.Sequential) are falsy and would be silently replaced.
        self._convnet = ConvNet() if convnet is None else convnet
        self._classifier = (
            Classifier(n_classes, feature_dim) if classifier is None
            else classifier
        )
        self._grl = GRL(factor=-1)
        # nn.ModuleList (not a plain list) so the domain classifiers'
        # parameters are registered: visible to optimizers, moved by
        # .to()/.cuda(), and included in state_dict().
        self._domain_classifiers = nn.ModuleList(
            Classifier(1, feature_dim) for _ in range(n_classes)
        )

    def forward(self, x):
        """Return ``(logits, domain_logits)``.

        ``logits`` is the label classifier output; with the default classifier
        it has shape (batch, n_classes). ``domain_logits`` is a list of
        ``n_classes`` tensors, each of shape (batch, 1).
        """
        features = self._convnet(x)
        # reshape (not view) so non-contiguous extractor outputs also work.
        features = features.reshape(features.shape[0], -1)

        logits = self._classifier(features)
        predictions = F.softmax(logits, dim=1)

        reversed_features = self._grl(features)
        # NOTE(review): the GRL is applied before the class-probability
        # weighting, so the domain-loss gradient flowing into `predictions`
        # (and thus the label classifier) is NOT reversed. Applying the GRL
        # to the weighted features instead would reverse it as well — confirm
        # which variant the paper's objective requires.
        domain_logits = [
            domain_classifier(
                predictions[:, class_idx].unsqueeze(1) * reversed_features
            )
            for class_idx, domain_classifier in enumerate(
                self._domain_classifiers
            )
        ]

        return logits, domain_logits


class Classifier(nn.Module):
    """Single linear layer mapping ``input_dimension`` features to ``n_classes`` logits."""

    def __init__(self, n_classes, input_dimension):
        super().__init__()

        self._n_classes = n_classes
        self._clf = nn.Linear(input_dimension, n_classes)

    def forward(self, x):
        return self._clf(x)


class ConvNet(nn.Module):
    """Small feature extractor: two 3x3 conv/BatchNorm/ReLU stages, then 2x2 max-pooling.

    Expects 3-channel input. The convolutions are unpadded, so a 32x32 image
    yields a (batch, 64, 14, 14) output, i.e. 12544 features once flattened.
    """

    def __init__(self):
        super().__init__()

        self._convnet = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

    def forward(self, x):
        return self._convnet(x)


class _GradientScale(torch.autograd.Function):
    """Identity in the forward pass; multiplies the gradient by ``factor`` in backward."""

    @staticmethod
    def forward(ctx, x, factor):
        ctx.factor = factor
        # Return a view rather than `x` itself, the usual pattern for
        # identity-forward custom autograd Functions.
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # `factor` is a plain Python number, so it receives no gradient.
        return ctx.factor * grad_output, None


class GRL(nn.Module):
    """Gradient reversal layer.

    Passes its input through unchanged and multiplies the gradient flowing
    back through it by ``factor``; the default ``-1`` reverses it.
    """

    def __init__(self, factor=-1):
        super().__init__()
        self._factor = factor

    def forward(self, x):
        return _GradientScale.apply(x, self._factor)