├── Dockerfile
├── mtdp
│   ├── models
│   │   ├── __init__.py
│   │   ├── _util.py
│   │   ├── resnet.py
│   │   └── densenet.py
│   ├── __init__.py
│   ├── builder.py
│   ├── components.py
│   ├── networks.py
│   ├── helpers.py
│   └── loader.py
├── hubconf.py
├── setup.py
├── LICENSE
├── .gitignore
├── examples
│   ├── feature_extract.py
│   └── multi_task_train.py
└── README.md
/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia-docker/nvidia-docker -------------------------------------------------------------------------------- /mtdp/models/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from .densenet import build_densenet 3 | from .resnet import build_resnet 4 | -------------------------------------------------------------------------------- /mtdp/__init__.py: -------------------------------------------------------------------------------- 1 | from .helpers import module_freeze, module_unfreeze 2 | from mtdp.builder import build_model 3 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | def main(argv): 2 | pass 3 | 4 | 5 | if __name__ == "__main__": 6 | import sys 7 | 8 | main(sys.argv[1:]) 9 | -------------------------------------------------------------------------------- /mtdp/builder.py: -------------------------------------------------------------------------------- 1 | from mtdp.components import PooledFeatureExtractor 2 | from mtdp.models.densenet import build_densenet 3 | from mtdp.models.resnet import build_resnet 4 | 5 | 6 | def build_model(arch, pool=False, **kwargs): 7 | """Get a network by architecture. 8 | 9 | Parameters 10 | ---------- 11 | arch: str 12 | Architecture name. Supported architectures: 13 | `{'densenet121', 'densenet169', 'densenet201', 'densenet161', 14 | 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152'}` 15 | pool: bool 16 | True for adding a global pooling layer to the model.
17 | kwargs: dict 18 | """ 19 | if "densenet" in arch: 20 | model = build_densenet(arch=arch, **kwargs) 21 | elif "resnet" in arch: 22 | model = build_resnet(arch=arch, **kwargs) 23 | else: 24 | raise ValueError("Unknown architecture") 25 | 26 | if pool: 27 | return PooledFeatureExtractor(model) 28 | else: 29 | return model -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | import os 2 | from setuptools import setup 3 | 4 | __version__ = "0.0.1-alpha" 5 | 6 | long_description = "" 7 | if os.path.isfile("README.md"): 8 | with open("README.md", "r") as fh: 9 | long_description = fh.read() 10 | 11 | setup( 12 | name='mtdp', 13 | version=__version__, 14 | description='Implementation of multi-task trained networks, including models pre-trained on digital pathology data', 15 | long_description=long_description, 16 | long_description_content_type="text/markdown", 17 | packages=['mtdp', 'mtdp.models'], 18 | classifiers=[ 19 | 'Intended Audience :: Science/Research', 20 | 'Intended Audience :: Developers', 21 | 'License :: OSI Approved', 22 | 'Programming Language :: Python', 23 | 'Topic :: Software Development', 24 | 'Topic :: Scientific/Engineering', 25 | 'Programming Language :: Python :: 3.6', 26 | 'Programming Language :: Python :: 3.7' 27 | ], 28 | install_requires=['torch', 'torchvision', 'numpy'], 29 | license='LICENSE' 30 | ) 31 | 32 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Romain Mormont 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
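A quick illustration of `mtdp/builder.py` above: with `pool=True`, `build_model` wraps the backbone in a `PooledFeatureExtractor`, so the network outputs a globally pooled feature map of shape `(N, n_features, 1, 1)`. A minimal sketch — the architecture and input size below are arbitrary choices, not requirements of the library:

```python
import torch

from mtdp import build_model

# randomly initialised backbone with a global average pooling layer on top
model = build_model(arch="resnet50", pretrained=None, pool=True)
model.eval()

with torch.no_grad():
    out = model(torch.randn(2, 3, 224, 224))  # dummy batch of two RGB images

print(out.shape)           # torch.Size([2, 2048, 1, 1]) for resnet50
print(model.n_features())  # 2048: number of feature maps produced by the backbone
```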
22 | -------------------------------------------------------------------------------- /mtdp/models/_util.py: -------------------------------------------------------------------------------- 1 | import hashlib 2 | import os 3 | import re 4 | import shutil 5 | import sys 6 | import tempfile 7 | 8 | import torch 9 | from torch.hub import download_url_to_file 10 | 11 | try: 12 | from requests.utils import urlparse 13 | from requests import get as urlopen 14 | requests_available = True 15 | except ImportError: 16 | requests_available = False 17 | from urllib.request import urlopen 18 | from urllib.parse import urlparse 19 | try: 20 | from tqdm import tqdm 21 | except ImportError: 22 | tqdm = None # defined below 23 | 24 | 25 | def _remove_prefix(s, prefix): 26 | if s.startswith(prefix): 27 | s = s[len(prefix):] 28 | return s 29 | 30 | 31 | def clean_state_dict(state_dict, prefix, filter=None): 32 | if filter is None: 33 | filter = lambda *args: True 34 | return {_remove_prefix(k, prefix): v for k, v in state_dict.items() if filter(k)} 35 | 36 | 37 | def load_dox_url(url, filename, model_dir=None, map_location=None, progress=True): 38 | r"""Adapt to fit format file of mtdp pre-trained models 39 | """ 40 | if model_dir is None: 41 | torch_home = os.path.expanduser(os.getenv('TORCH_HOME', '~/.torch')) 42 | model_dir = os.getenv('TORCH_MODEL_ZOO', os.path.join(torch_home, 'models')) 43 | if not os.path.exists(model_dir): 44 | os.makedirs(model_dir) 45 | cached_file = os.path.join(model_dir, filename) 46 | if not os.path.exists(cached_file): 47 | sys.stderr.write('Downloading: "{}" to {}\n'.format(url, cached_file)) 48 | sys.stderr.flush() 49 | download_url_to_file(url, cached_file, None, progress=progress) 50 | return torch.load(cached_file, map_location=map_location) -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | *.egg-info/ 24 | .installed.cfg 25 | *.egg 26 | MANIFEST 27 | 28 | # PyInstaller 29 | # Usually these files are written by a python script from a template 30 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
31 | *.manifest 32 | *.spec 33 | 34 | # Installer logs 35 | pip-log.txt 36 | pip-delete-this-directory.txt 37 | 38 | # Unit test / coverage reports 39 | htmlcov/ 40 | .tox/ 41 | .coverage 42 | .coverage.* 43 | .cache 44 | nosetests.xml 45 | coverage.xml 46 | *.cover 47 | .hypothesis/ 48 | .pytest_cache/ 49 | 50 | # Translations 51 | *.mo 52 | *.pot 53 | 54 | # Django stuff: 55 | *.log 56 | local_settings.py 57 | db.sqlite3 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # Environments 85 | .env 86 | .venv 87 | env/ 88 | venv/ 89 | ENV/ 90 | env.bak/ 91 | venv.bak/ 92 | 93 | # Spyder project settings 94 | .spyderproject 95 | .spyproject 96 | 97 | # Rope project settings 98 | .ropeproject 99 | 100 | # mkdocs documentation 101 | /site 102 | 103 | # mypy 104 | .mypy_cache/ 105 | -------------------------------------------------------------------------------- /mtdp/components.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch import nn 3 | 4 | 5 | class FeaturesInterface(object): 6 | @abstractmethod 7 | def n_features(self): 8 | pass 9 | 10 | 11 | class Head(nn.Module): 12 | """A head is a simple neural network that can be used as a task-specific predictor 13 | in a multi-task network. It features a global average pooling followed by a linear 14 | layer (no activation). 15 | """ 16 | 17 | def __init__(self, n_features, n_classes=2): 18 | """ 19 | Parameters 20 | ---------- 21 | n_features: int 22 | The number of input features after global average pooling 23 | n_classes: int 24 | The number of classes (i.e. output features) 25 | """ 26 | super().__init__() 27 | self.pool = nn.AdaptiveAvgPool2d(1) 28 | self.linear = nn.Conv2d( 29 | n_features, 30 | out_channels=n_classes, 31 | kernel_size=1 32 | ) 33 | 34 | def forward(self, x): 35 | x = self.pool(x) 36 | x = self.linear(x) 37 | return x.view(x.size(0), -1) 38 | 39 | 40 | class PooledFeatureExtractor(nn.Module, FeaturesInterface): 41 | """This module applies a global average pooling on features produced by a module. 42 | """ 43 | 44 | def __init__(self, features): 45 | """ 46 | Parameters 47 | ---------- 48 | features: nn.Module 49 | A network producing a set of feature maps. `features` should have a `n_features()` method 50 | returning how many features maps it produces. 
51 | """ 52 | super().__init__() 53 | self.features = features 54 | self.pool = nn.AdaptiveAvgPool2d(1) 55 | 56 | def forward(self, x): 57 | return self.pool(self.features(x)) 58 | 59 | def n_features(self): 60 | return self.features.n_features() 61 | -------------------------------------------------------------------------------- /examples/feature_extract.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from sklearn.metrics import accuracy_score 4 | from sklearn.svm import LinearSVC 5 | from torch.utils.data import DataLoader 6 | from torchvision import transforms 7 | from torchvision.datasets import ImageFolder 8 | 9 | from mtdp import build_model 10 | 11 | 12 | if __name__ == "__main__": 13 | """Loading a model and using it as feature extractor 14 | """ 15 | device = torch.device("cpu") 16 | model = build_model(arch="densenet121", pretrained="mtdp", pool=True) 17 | model.to(device) 18 | 19 | input_size = 244 20 | transform = transforms.Compose([ 21 | transforms.Resize(input_size), 22 | transforms.CenterCrop(input_size), 23 | transforms.ToTensor(), 24 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet stats 25 | ]) 26 | train_dataset = ImageFolder("TRAIN_FOLDER", transform=transform) 27 | train_loader = DataLoader(train_dataset, batch_size=16) 28 | test_dataset = ImageFolder("TEST_FOLDER", transform=transform) 29 | test_loader = DataLoader(test_dataset, batch_size=16) 30 | 31 | with torch.no_grad(): 32 | model.eval() 33 | # train 34 | features = list() 35 | classes = list() 36 | for i, (x, y) in enumerate(train_loader): 37 | print("> train iter #{}".format(i + 1)) 38 | out = model.forward(x.to(device)) 39 | features.append(out.detach().cpu().numpy().squeeze()) 40 | classes.append(y.cpu().numpy()) 41 | 42 | features = np.vstack(features) 43 | classes = np.hstack(classes) 44 | 45 | print("Train svm.") 46 | svm = LinearSVC(C=0.01) 47 | svm.fit(features, classes) 48 | 49 | # predict 50 | preds = list() 51 | y_test = list() 52 | for i, (x_test, y) in enumerate(test_loader): 53 | print("> test iter #{}".format(i + 1)) 54 | out = model.forward(x_test.to(device)) 55 | preds.append(svm.predict(out.detach().cpu().numpy().squeeze())) 56 | y_test.append(y.cpu().numpy()) 57 | 58 | preds = np.hstack(preds) 59 | y_test = np.hstack(y_test) 60 | 61 | print("test accuracy:", accuracy_score(y_test, preds)) 62 | 63 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # mtdp 2 | 3 | Library containing implementation related to the [research paper](http://hdl.handle.net/2268/247134) "_Multi-task pre-training of deep neural networks for digital pathology_" (Mormont _et al._). 4 | 5 | It can be used to load our pre-trained models or to build a multi-task classification architecture. 6 | 7 | ## Loading our pre-trained weights. 8 | 9 | > For an example, check the file [`examples/feature_extract.py`](https://github.com/waliens/multitask-dipath/blob/master/examples/feature_extract.py). 10 | 11 | The library provides a `build_model` function to build a model and initialize it with our 12 | pre-trained weights. To load our weights, the parameter `pretrained` should be set to `mtdp`. 
13 | 14 | ```python 15 | from mtdp import build_model 16 | 17 | model = build_model(arch="densenet121", pretrained="mtdp") 18 | ``` 19 | 20 | Alternatively, `pretrained` can be set to `imagenet` to load ImageNet pre-trained weights from PyTorch. 21 | 22 | We currently provide pre-trained weights for the following architectures: 23 | 24 | - `densenet121` 25 | - `resnet50` 26 | 27 | 28 | See an example script performing feature extraction using one of our models in the `examples` folder (file `feature_extract.py`). 29 | 30 | ### Raw model files 31 | 32 | If you want to bypass the library and download the raw PyTorch model files, you can access them at the following URLs: 33 | 34 | - `densenet121`: [https://dox.uliege.be/index.php/s/G72InP4xmJvOrVp/download](https://dox.uliege.be/index.php/s/G72InP4xmJvOrVp/download) 35 | - `resnet50`: [https://dox.uliege.be/index.php/s/kvABLtVuMxW8iJy/download](https://dox.uliege.be/index.php/s/kvABLtVuMxW8iJy/download) 36 | 37 | 38 | ## Building a multi-task architecture 39 | 40 | > For an example, see the [`examples/multi_task_train.py`](https://github.com/waliens/multitask-dipath/blob/master/examples/multi_task_train.py) file. 41 | 42 | Building the architecture involves several steps: 43 | 44 | 1. define a `DatasetFolder`/`ImageFolder` for each of your individual datasets, 45 | 2. instantiate a `MultiImageFolders` object with all your dataset objects, 46 | 3. instantiate a `MultiHead` PyTorch module by passing it the `MultiImageFolders` from step 2. The 47 | module will use information about the tasks to build the multi-task architecture. 48 | 49 | 50 | 51 | -------------------------------------------------------------------------------- /examples/multi_task_train.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader 3 | from torchvision import transforms 4 | from torchvision.datasets import ImageFolder 5 | 6 | from mtdp import build_model 7 | from mtdp.helpers import compute_loss, rescale_head_grads 8 | from mtdp.loader import MultiImageFolders 9 | from mtdp.networks import MultiHead 10 | 11 | 12 | if __name__ == "__main__": 13 | LR = 1e-4 14 | BATCH_SIZE = 8 15 | DEVICE = "cpu" 16 | INPUT_SIZE = 224 17 | 18 | """ 19 | All your tasks should be provided as Datasets (e.g. ImageFolder, or a custom implementation) to the 20 | `MultiImageFolders` class, which provides a unified interface to them (as if they were a single 21 | Dataset). Dataset root folder names should be unique. 22 | """ 23 | transform = transforms.Compose([ 24 | transforms.Resize(INPUT_SIZE), 25 | transforms.CenterCrop(INPUT_SIZE), 26 | transforms.ToTensor(), 27 | transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet stats 28 | ]) 29 | paths = ["path_to_dataset1", "path_to_dataset2"] 30 | dataset = MultiImageFolders([ImageFolder(path, transform) for path in paths]) 31 | loader = DataLoader(dataset, batch_size=BATCH_SIZE) 32 | 33 | """ 34 | Build the backbone (shared) network. Pooling will be added automatically later, so it is disabled. 35 | """ 36 | backbone = build_model(arch="densenet121", pretrained="imagenet", pool=False) 37 | 38 | """ 39 | The `MultiHead` class will build a multi-task network based on the passed `MultiImageFolders` object and the 40 | backbone network.
41 | """ 42 | multihead = MultiHead(dataset, backbone) 43 | device = torch.device(DEVICE) 44 | multihead.to(device) 45 | 46 | # Training 47 | loss_fn = torch.nn.CrossEntropyLoss(reduction="none")  # per-sample losses, aggregated in compute_loss 48 | optimizer = torch.optim.SGD(multihead.parameters(), lr=LR) 49 | multihead.train() 50 | for i, (x, y, sources) in enumerate(loader): 51 | x = x.to(device) 52 | loss = compute_loss(multihead, x, y, sources, loss_fn) 53 | optimizer.zero_grad() 54 | loss.backward() 55 | rescale_head_grads(multihead, sources) 56 | optimizer.step() 57 | print("> train iter #{}: {}".format(i, loss.detach().cpu())) -------------------------------------------------------------------------------- /mtdp/models/resnet.py: -------------------------------------------------------------------------------- 1 | from torch.utils import model_zoo 2 | from torchvision.models.resnet import ResNet, model_urls as resnet_urls, BasicBlock, Bottleneck 3 | from mtdp.components import FeaturesInterface 4 | from mtdp.models._util import load_dox_url, clean_state_dict 5 | 6 | MTDP_URLS = { 7 | "resnet50": ("https://dox.uliege.be/index.php/s/kvABLtVuMxW8iJy/download", "resnet50-mh-best-191205-141200.pth") 8 | } 9 | 10 | 11 | class NoHeadResNet(ResNet, FeaturesInterface): 12 | def forward(self, x): 13 | x = self.conv1(x) 14 | x = self.bn1(x) 15 | x = self.relu(x) 16 | x = self.maxpool(x) 17 | 18 | x = self.layer1(x) 19 | x = self.layer2(x) 20 | x = self.layer3(x) 21 | return self.layer4(x) 22 | 23 | def n_features(self): 24 | return [b for b in list(self.layer4[-1].children()) if hasattr(b, 'num_features')][-1].num_features 25 | 26 | 27 | def build_resnet(pretrained=None, arch="resnet50", model_class=NoHeadResNet, **kwargs): 28 | """Constructs a ResNet model. 29 | 30 | Args: 31 | arch (str): Type of resnet (among: resnet18, resnet34, resnet50, resnet101 and resnet152) 32 | pretrained (str|None): If "imagenet", returns a model pre-trained on ImageNet. If "mtdp" returns a model 33 | pre-trained in multi-task on digital pathology data. Otherwise (None), random weights.
34 | model_class (nn.Module): Actual resnet module class 35 | """ 36 | params = { 37 | "resnet18": [BasicBlock, [2, 2, 2, 2]], 38 | "resnet34": [BasicBlock, [3, 4, 6, 3]], 39 | "resnet50": [Bottleneck, [3, 4, 6, 3]], 40 | "resnet101": [Bottleneck, [3, 4, 23, 3]], 41 | "resnet152": [Bottleneck, [3, 8, 36, 3]] 42 | } 43 | model = model_class(*params[arch], **kwargs) 44 | if isinstance(pretrained, str): 45 | if pretrained == "imagenet": 46 | url = resnet_urls[arch] # default imagenet 47 | state_dict = model_zoo.load_url(url) 48 | elif pretrained == "mtdp": 49 | if arch not in MTDP_URLS: 50 | raise ValueError("No pretrained weights for multi task pretraining with architecture '{}'".format(arch)) 51 | url, filename = MTDP_URLS[arch] 52 | state_dict = load_dox_url(url, filename, map_location="cpu") 53 | state_dict = clean_state_dict(state_dict, prefix="features.", filter=lambda k: not k.startswith("heads.")) 54 | else: 55 | raise ValueError("Unknown pre-training source") 56 | model.load_state_dict(state_dict) 57 | return model -------------------------------------------------------------------------------- /mtdp/models/densenet.py: -------------------------------------------------------------------------------- 1 | import re 2 | 3 | from torch.utils import model_zoo 4 | from torchvision.models.densenet import DenseNet, model_urls as densenet_urls 5 | from mtdp.components import FeaturesInterface 6 | from mtdp.models._util import load_dox_url, clean_state_dict 7 | 8 | MTDP_URLS = { 9 | "densenet121": ("https://dox.uliege.be/index.php/s/G72InP4xmJvOrVp/download", "densenet121-mh-best-191205-141200.pth") 10 | } 11 | 12 | 13 | class NoHeadDenseNet(DenseNet, FeaturesInterface): 14 | def forward(self, x): 15 | return self.features(x) 16 | 17 | def n_features(self): 18 | return self.features[-1].num_features 19 | 20 | 21 | def build_densenet(pretrained=False, arch="densenet201", model_class=NoHeadDenseNet, **kwargs): 22 | r"""Densenet-XXX model from 23 | `"Densely Connected Convolutional Networks" `_ 24 | 25 | Args: 26 | arch (str): Type of densenet (among: densenet121, densenet169, densenet201 and densenet161) 27 | 28 | pretrained (str|None): If "imagenet", returns a model pre-trained on ImageNet. If "mtdp" returns a model pre-trained 29 | in multi-task on digital pathology data. Otherwise (None), random weights. 30 | model_class (nn.Module): Actual densenet module class 31 | """ 32 | params = { 33 | "densenet121": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 24, 16)}, 34 | "densenet169": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 32, 32)}, 35 | "densenet201": {"num_init_features": 64, "growth_rate": 32, "block_config": (6, 12, 48, 32)}, 36 | "densenet161": {"num_init_features": 96, "growth_rate": 48, "block_config": (6, 12, 36, 24)} 37 | } 38 | model = model_class(**(params[arch]), **kwargs) 39 | if isinstance(pretrained, str): 40 | # '.'s are no longer allowed in module names, but pervious _DenseLayer 41 | # has keys 'norm.1', 'relu.1', 'conv.1', 'norm.2', 'relu.2', 'conv.2'. 42 | # They are also in the checkpoints in model_urls. This pattern is used 43 | # to find such keys. 
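# e.g. "features.denseblock1.denselayer1.norm.1.weight" is remapped to "features.denseblock1.denselayer1.norm1.weight"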
44 | if pretrained == "imagenet": 45 | pattern = re.compile( 46 | r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$') 47 | state_dict = model_zoo.load_url(densenet_urls[arch]) 48 | for key in list(state_dict.keys()): 49 | res = pattern.match(key) 50 | if res: 51 | new_key = res.group(1) + res.group(2) 52 | state_dict[new_key] = state_dict[key] 53 | del state_dict[key] 54 | elif pretrained == "mtdp": 55 | if arch not in MTDP_URLS: 56 | raise ValueError("No pretrained weights for multi task pretraining with architecture '{}'".format(arch)) 57 | url, filename = MTDP_URLS[arch] 58 | state_dict = load_dox_url(url, filename, map_location="cpu") 59 | state_dict = clean_state_dict(state_dict, prefix="features.", filter=lambda k: not k.startswith("heads.")) 60 | else: 61 | raise ValueError("Unknown pre-training source") 62 | model.load_state_dict(state_dict) 63 | return model 64 | -------------------------------------------------------------------------------- /mtdp/networks.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from mtdp.components import Head 4 | 5 | 6 | class SingleHead(nn.Module): 7 | """A single task network, similar to usual classification architectures. A network produces 8 | feature maps that are then process by a head composed of global average pooling and a linear 9 | layer producing the logits. 10 | """ 11 | 12 | def __init__(self, features, head): 13 | """ 14 | Parameters 15 | ---------- 16 | features: nn.Module, FeaturesInterface 17 | The module producing the features map 18 | head: nn.Module 19 | The head module 20 | """ 21 | super().__init__() 22 | self.features = features 23 | self.head = head 24 | 25 | def forward(self, x): 26 | x = self.features(x) 27 | return self.head(x) 28 | 29 | 30 | class MultiHead(nn.Module): 31 | """A multi-task network composed of a shared network producing features and several heads, 32 | one per task, producing the task-specific predictions. All samples of a given task are 33 | routed only through this task's specific head. 34 | """ 35 | 36 | def __init__(self, dataset, features): 37 | """ 38 | Parameters 39 | ---------- 40 | dataset: MultiTaskDataset 41 | The dataset for which the multi-head network must be built. 42 | features: nn.Module, FeaturesInterface 43 | The shared network module. Should have a `n_features()` function returning 44 | the number of feature maps it produces. 45 | """ 46 | super().__init__() 47 | self.features = features 48 | self._dataset = dataset 49 | self.heads = nn.ModuleDict({ 50 | name: Head(n_features=features.n_features(), n_classes=n_classes) 51 | for name, n_classes in dataset.n_classes_per_dataset.items() 52 | }) 53 | self._name_to_index = self._dataset.name_to_index 54 | 55 | def forward(self, x, sources): 56 | """ 57 | Parameters 58 | ---------- 59 | x: torch.Tensor 60 | Batch of images. 61 | sources: torch.Tensor 62 | A vector (same size as the batch) where `sources[i]` is the source index 63 | for sample `x[i]` of the dataset. Source index should be a unique identifier 64 | consistent with indexes defined by the dataset. 65 | 66 | Returns 67 | ------- 68 | results: dict 69 | A dictionary mapping task name with another dictionary. `results[task_name]["logits"]` contains 70 | the logits for the all the samples of the task `task_name` contained in `x` (same order as the 71 | order they appear in the batch). 
`results[task_name]["which"]` is a binary mask indicating which 72 | samples of the batch actually belonged to task `task_name` 73 | 74 | """ 75 | f = self.features(x) # extract features for all inputs 76 | results = {} 77 | for name, head in self.heads.items(): 78 | which = sources == self._name_to_index[name] 79 | if which.nonzero().size(0) > 0: 80 | results[name] = { 81 | "logits": head(f[which]), 82 | "which": which 83 | } 84 | return results 85 | 86 | def get_single_head(self, task_name): 87 | """Creates a single-head network by assembling the shared network and the given 88 | task's head. 89 | 90 | Parameters 91 | ---------- 92 | task_name: str 93 | Name of the dataset for which the single-head network should be extracted. 94 | 95 | Returns 96 | ------- 97 | singlehead: SingleHead 98 | """ 99 | return SingleHead(self.features, self.heads[task_name]) 100 | 101 | @property 102 | def dataset(self): 103 | return self._dataset -------------------------------------------------------------------------------- /mtdp/helpers.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from torch.nn.modules.batchnorm import _BatchNorm 7 | 8 | 9 | def module_unfreeze(module): 10 | """Unfreezes the shared network of module. Its trainable parameters 11 | are made trainable (requires_grad=True) and the module is set to `train` mode. 12 | Parameters 13 | ---------- 14 | module: MultiHead 15 | Multi-head network module 16 | """ 17 | module.train() 18 | for param in module.parameters(): 19 | param.requires_grad = True 20 | 21 | 22 | def module_freeze(module): 23 | """Freezes the shared network of module. Its trainable parameters 24 | are fixed (requires_grad=False) and the module is set to `eval` mode. 25 | 26 | Parameters 27 | ---------- 28 | module: MultiHead 29 | Multi-head network module 30 | """ 31 | module.eval() 32 | for param in module.parameters(): 33 | param.requires_grad = False 34 | 35 | 36 | def get_batch_norm_layers(module, current_name): 37 | """Find and return a list of all batch norm modules in the given module""" 38 | bns = list() 39 | for i, (name, m) in enumerate(module.named_children()): 40 | iter_name = current_name + [name] 41 | if isinstance(m, _BatchNorm): 42 | m.bn_name = ".".join(iter_name) 43 | bns.append(m) 44 | else: 45 | bns.extend(get_batch_norm_layers(m, iter_name)) 46 | return bns 47 | 48 | 49 | def adapt_batch_norm(network, loader, device, n_iter=20, forward_params_fn=None): 50 | """ 51 | In place update of the batch norm weights and bias for preparing a pre-trained network for a domain-switch. 52 | 53 | Parameters 54 | ---------- 55 | forward_params_fn: callable 56 | A callable that is passed the batch and return the list of parameters to pass to the forward function of the 57 | 'network' module. 58 | """ 59 | if forward_params_fn is None: 60 | def default_forward_params(batch): 61 | return [batch[0].to(device)] 62 | forward_params_fn = default_forward_params 63 | 64 | bns = {bn.bn_name: bn for bn in get_batch_norm_layers(network, [])} 65 | 66 | if np.any([not bn.affine for bn in bns.values()]): 67 | warnings.warn("some layers have 'affine' disabled. 
They cannot be fixed by this approach, and will therefore be" 68 | "ignored") 69 | bns = {name: bn for name, bn in bns.items() if bn.affine} 70 | 71 | means = defaultdict(list) 72 | vars = defaultdict(list) 73 | 74 | def hook_fn(bn, _in): 75 | size = _in[0].size() 76 | if len(size) != 4: # support for NCHW only 77 | raise ValueError("Invalid shape {}".format(size)) 78 | n_features = size[1] 79 | d_in = _in[0].detach().permute(1, 0, 2, 3).contiguous().view(n_features, -1) 80 | means[bn.bn_name].append(torch.mean(d_in, dim=1, keepdim=False).cpu().numpy()) 81 | vars[bn.bn_name].append(torch.var(d_in, dim=1, keepdim=False, unbiased=True).cpu().numpy()) 82 | 83 | hooks = [bn.register_forward_pre_hook(hook_fn) for bn in bns.values()] 84 | 85 | # forward samples into the network 86 | network.eval() 87 | with torch.no_grad(): 88 | for i, batch in enumerate(loader): 89 | if i >= n_iter: 90 | break 91 | _ = network(*forward_params_fn(batch)) 92 | 93 | def eps_std(var, eps): 94 | return torch.sqrt(var + eps) 95 | 96 | for name, bn in bns.items(): 97 | mu_t = torch.tensor(np.mean(np.array(means[name]), axis=0)) 98 | var_t = torch.tensor(np.mean(np.array(vars[name]), axis=0)) 99 | mu_s = bn.running_mean.detach().cpu() 100 | var_s = bn.running_var.detach().cpu() 101 | gamma_s, beta_s = bn.weight.detach().cpu(), bn.bias.detach().cpu() 102 | 103 | gamma_t = gamma_s * eps_std(var_t, bn.eps) / eps_std(var_s, bn.eps) 104 | beta_t = beta_s + gamma_s * (mu_t - mu_s) / eps_std(var_s, bn.eps) 105 | 106 | # adapt/update old batch norm 107 | bn.weight = torch.nn.Parameter(gamma_t) 108 | bn.bias = torch.nn.Parameter(beta_t) 109 | bn.running_mean = mu_t 110 | bn.running_var = var_t 111 | bn.num_batches_tracked = torch.tensor(n_iter, dtype=torch.long) 112 | 113 | for hook in hooks: 114 | hook.remove() 115 | 116 | network.to(device) 117 | 118 | return network 119 | 120 | 121 | def forward(multihead, x, sources): 122 | """Forward samples through the multihead network""" 123 | results = multihead.forward(x, sources) 124 | return {source_name: source_results for source_name, source_results in results.items()} 125 | 126 | 127 | def compute_loss(multihead, x, y, sources, loss_fn, aggreg_fn=torch.mean, return_losses=False): 128 | """Forward samples into a multihead network and computes the loss 129 | Parameters 130 | ---------- 131 | multihead: MultiHead 132 | x: torch.Tensor 133 | y: torch.Tensor 134 | sources: torch.Tensor 135 | loss_fn: callable 136 | aggreg_fn: callable 137 | return_losses: bool 138 | :return: 139 | """ 140 | losses, losses_per_task = list(), dict() 141 | for source, results in forward(multihead, x, sources).items(): 142 | source_losses = loss_fn(results["logits"], y[results["which"]]) 143 | losses_per_task[source] = source_losses.detach().cpu().numpy() 144 | losses.append(source_losses) 145 | loss = aggreg_fn(torch.cat(losses)) 146 | if return_losses: 147 | return loss, losses_per_task 148 | else: 149 | return loss 150 | 151 | 152 | def rescale_head_grads(multihead, sources): 153 | """ 154 | Rescale the heads gradients based on the number of samples that passed through the 155 | head during this iteration 156 | Parameters 157 | ---------- 158 | multihead: MultiHead 159 | Multihead network 160 | sources: torch.tensor 161 | Batch sources strings. 
162 | """ 163 | sources = sources.numpy() 164 | batch_size = sources.shape[0] 165 | values, counts = np.unique(sources, return_counts=True) 166 | for index, count in zip(values, counts): 167 | head = multihead.heads[multihead.dataset.name(index)] 168 | for p in head.parameters(): 169 | p.grad *= batch_size / count -------------------------------------------------------------------------------- /mtdp/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | from torch.utils.data import Dataset 4 | from torchvision.datasets import ImageFolder 5 | 6 | 7 | def datasets_size_cumsum(datasets): 8 | sizes = np.array([len(d) for d in datasets]) 9 | cumsum = np.concatenate([np.array([0]), np.cumsum(sizes[:-1], dtype=np.int64)]) 10 | return sizes, cumsum 11 | 12 | 13 | def get_sample_indexes(index, cumsum): 14 | dataset_index = np.searchsorted(cumsum, index, side="right") - 1 15 | relative_index = index - cumsum[dataset_index] 16 | return dataset_index, relative_index 17 | 18 | 19 | def merge_dicts(dicts): 20 | out_dict = dict() 21 | for d in dicts: 22 | for k, v in d.items(): 23 | if k in out_dict and v != out_dict[k]: 24 | raise ValueError("Value mismatch for key {} and value {} (!= {})".format(k, v, out_dict[k])) 25 | out_dict[k] = v 26 | return out_dict 27 | 28 | 29 | def get_image_folders_if_not_empty(paths, **kwargs): 30 | folders = list() 31 | for path in paths: 32 | try: 33 | folders.append(ImageFolder(path, **kwargs)) 34 | except RuntimeError as e: 35 | print("skip dataset at '{}' because: \"{}\"".format(path, str(e))) 36 | pass 37 | return folders 38 | 39 | 40 | def add_group(dataset: ImageFolder, index, do_add_group=False): 41 | # this is what ImageFolder normally returns 42 | original_tuple = dataset[index] 43 | if not do_add_group: 44 | return original_tuple 45 | 46 | # the image file path 47 | path = dataset.imgs[index][0] 48 | # make a new tuple that includes the original tuple and the group identifier (filename prefix before the first '_') 49 | group = os.path.basename(path).split("_", 1)[0] 50 | return original_tuple + (group,) 51 | 52 | 53 | class MultiSetImageFolder(Dataset): 54 | """A classification dataset split into several sets structured as follows: {base_path}/{set_name}/{cls}/* 55 | Image filenames can be prefixed with a group identifier which can optionally be returned. 56 | """ 57 | 58 | def __init__(self, base_path, sets, do_add_group=False, **kwargs): 59 | """ 60 | Parameters 61 | ---------- 62 | base_path: str 63 | Base path of the dataset folder 64 | sets: list 65 | List of set folder names as strings. 66 | do_add_group: bool 67 | True to append the group identifier (optional), default: `False`. 68 | kwargs: dict 69 | Parameters to be transferred to the actual `ImageFolder`.
70 | """ 71 | super().__init__() 72 | self._datasets = get_image_folders_if_not_empty([os.path.join(base_path, _set) for _set in sets], **kwargs) 73 | self._sizes, self._cumsum_sizes = datasets_size_cumsum(self._datasets) 74 | self.class_to_idx = merge_dicts([d.class_to_idx for d in self._datasets]) 75 | self.classes = list(self.class_to_idx.keys()) 76 | self.do_add_group = do_add_group 77 | 78 | def __getitem__(self, index): 79 | dataset_index, relative_index = get_sample_indexes(index, self._cumsum_sizes) 80 | return add_group(self._datasets[dataset_index], relative_index, do_add_group=self.do_add_group) 81 | 82 | def __len__(self): 83 | return self._cumsum_sizes[-1] + len(self._datasets[-1]) 84 | 85 | @property 86 | def n_classes(self): 87 | """Total number of classes in the dataset""" 88 | return len(self.classes) 89 | 90 | @property 91 | def root(self): 92 | """The name of the dataset folder""" 93 | return os.path.dirname(self._datasets[0].root) 94 | 95 | 96 | class MultiImageFolders(Dataset): 97 | """Multiple tasks, each being represented by a Dataset. Each dataset can be identified by its name (name of 98 | the dataset folder, should be unique) or an index. 99 | """ 100 | def __init__(self, datasets, indexes=None): 101 | """ 102 | Parameters 103 | ---------- 104 | datasets: iterable 105 | List of datasets, each of which represents a task. 106 | indexes: iterable 107 | List of indexes associated with every dataset (optional). Default: each dataset get its index 108 | in `datasets` as index. 109 | """ 110 | if indexes is not None and len(indexes) != len(datasets): 111 | raise ValueError("indexes should have the same size as datasets") 112 | self._datasets = datasets 113 | # array of actual indexes, maps internal with external ids 114 | self._indexes = list(range(len(datasets))) if indexes is None else indexes 115 | # maps external id with internal id 116 | self._index_to_dataset = {i: d for i, d in zip(self._indexes, self._datasets)} 117 | # maps dataset name with external id 118 | self._dataset_name_to_index = {self.name(i): i for i in self._indexes} 119 | self._sizes, self._cumsum_sizes = datasets_size_cumsum(self._datasets) 120 | self._check_name_unicity() 121 | 122 | def _check_name_unicity(self): 123 | """Check whether all datasets have different names""" 124 | names = set() 125 | for i, dataset in enumerate(self.datasets): 126 | name = self.name(i) 127 | if name in names: 128 | raise ValueError("several datasets in the MultiImageFolders have the same name '{}' (folder name)".format(name)) 129 | names.add(self.name(i)) 130 | 131 | def __getitem__(self, index): 132 | dataset_index, relative_index = get_sample_indexes(index, self._cumsum_sizes) 133 | sample = self._datasets[dataset_index][relative_index] 134 | sample = sample + (self._indexes[dataset_index],) # store the dataset index in the returned data 135 | return sample 136 | 137 | def __len__(self): 138 | return self._cumsum_sizes[-1] + len(self._datasets[-1]) 139 | 140 | @property 141 | def datasets(self): 142 | """List of datasets""" 143 | return self._datasets 144 | 145 | @property 146 | def name_to_index(self): 147 | """Get the map for dataset name to dataset index""" 148 | return self._dataset_name_to_index 149 | 150 | def dataset_by_name(self, name): 151 | """Get the dataset object from the name 152 | Parameters 153 | ---------- 154 | name: str 155 | Name of the dataset 156 | """ 157 | return self._index_to_dataset[self.name_to_index[name]] 158 | 159 | @property 160 | def weights(self): 161 | """Return a weight vector 
for the samples so that each dataset has the same probability '1 / len(datasets)' 162 | of being sampled 163 | 164 | Returns 165 | ------- 166 | weights: ndarray 167 | Dimensions (len(self),). Sample weights. 168 | """ 169 | return np.repeat([1 / (len(self._datasets) * self._sizes)], self._sizes) 170 | 171 | @property 172 | def n_classes_per_dataset(self): 173 | """Return the number of classes for each dataset""" 174 | return {name: len(v) for name, v in self.classes_per_dataset.items()} 175 | 176 | @property 177 | def classes_per_dataset(self): 178 | """Return the classes for each dataset""" 179 | return {self.name(i): list(d.class_to_idx.keys()) for i, d in enumerate(self._datasets)} 180 | 181 | @property 182 | def class_to_idx_per_dataset(self): 183 | """Return class indexes for each dataset""" 184 | return {self.name(i): d.class_to_idx for i, d in enumerate(self._datasets)} 185 | 186 | @property 187 | def n_classes(self): 188 | return sum([len(a) for a in self.classes_per_dataset.values()]) 189 | 190 | def report(self): 191 | print("Multi dataset with {} sub-dataset(s) and {} samples.".format(len(self._datasets), len(self))) 192 | for i, d in self._index_to_dataset.items(): 193 | print("> {} ({}): {} samples, classes {}".format(self.name(i), i, len(d), d.classes)) 194 | 195 | def name(self, index): 196 | return os.path.basename(self._index_to_dataset[index].root) 197 | 198 | @property 199 | def names(self): 200 | return [self.name(i) for i, _ in enumerate(self._datasets)] 201 | --------------------------------------------------------------------------------
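As a complement to `MultiImageFolders.weights` above: the per-sample weights can be fed to a `torch.utils.data.WeightedRandomSampler` so that every task is drawn with equal probability regardless of its size. A minimal sketch — the dataset paths, transform and batch size are placeholders, not part of the library:

```python
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision import transforms
from torchvision.datasets import ImageFolder

from mtdp.loader import MultiImageFolders

transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
])
# hypothetical task folders; replace with your own datasets
dataset = MultiImageFolders([ImageFolder(p, transform) for p in ["path_to_dataset1", "path_to_dataset2"]])
dataset.report()  # prints each sub-dataset with its index, size and classes

# `weights` assigns 1 / (n_datasets * len(dataset_i)) to every sample of dataset i,
# so each task is sampled with probability 1 / n_datasets
sampler = WeightedRandomSampler(dataset.weights, num_samples=len(dataset), replacement=True)
loader = DataLoader(dataset, batch_size=8, sampler=sampler)

for x, y, sources in loader:  # `sources` holds the dataset index of each sample in the batch
    pass  # forward x and sources through a MultiHead network, as in examples/multi_task_train.py
```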