├── README.md
├── fig_1.png
├── fig_2.png
├── routed_vgg.py
└── taskrouting.py

/README.md:
--------------------------------------------------------------------------------
# Many Task Learning With Task Routing - [[ICCV'19](http://iccv2019.thecvf.com/) Oral]

This is the official implementation repo for our 2019 ICCV paper [Many Task Learning With Task Routing](http://openaccess.thecvf.com/content_ICCV_2019/html/Strezoski_Many_Task_Learning_With_Task_Routing_ICCV_2019_paper.html):

**Many Task Learning With Task Routing**
[Gjorgji Strezoski](https://staff.fnwi.uva.nl/g.strezoski/), [Nanne van Noord](https://nanne.github.io/), [Marcel Worring](https://staff.fnwi.uva.nl/m.worring/)
International Conference on Computer Vision ([ICCV](http://iccv2019.thecvf.com/)), 2019 [Oral]
[[CVF](http://openaccess.thecvf.com/content_ICCV_2019/html/Strezoski_Many_Task_Learning_With_Task_Routing_ICCV_2019_paper.html)] [[ArXiv](https://arxiv.org/abs/1903.12117)] [[Web](https://staff.fnwi.uva.nl/g.strezoski/post/iccv/)]

It contains the Task Routing Layer implementation, its integration into existing models, and usage instructions.

---

![Figure 1](https://github.com/gstrezoski/taskrouting/blob/master/fig_2.png)

**Abstract:** Typical multi-task learning (MTL) methods rely on architectural adjustments and a large trainable parameter set to jointly optimize over several tasks. However, when the number of tasks increases, so do the complexity of the architectural adjustments and the resource requirements. In this paper, we introduce a method which applies a conditional feature-wise transformation over the convolutional activations that enables a model to successfully perform a large number of tasks. To distinguish from regular MTL, we introduce Many Task Learning (MaTL) as a special case of MTL where more than 20 tasks are performed by a single model. Our method, dubbed Task Routing (TR), is encapsulated in a layer we call the Task Routing Layer (TRL), which, applied in an MaTL scenario, successfully fits hundreds of classification tasks in one model. We evaluate on 5 datasets and the Visual Decathlon (VD) challenge against strong baselines and state-of-the-art approaches.

---

### Usage

#### Task Routing Layer

In `taskrouting.py` you can find the Task Routing Layer source. It is a standalone file containing the PyTorch layer class. The layer takes three arguments for instantiation:

- `unit_count (int)`: Number of input channels going into the Task Routing Layer (TRL).
- `task_count (int)`: Number of tasks. (In single task learning this corresponds to the number of output classes.)
- `sigma (int)`: Number of units masked out for each task. For a masking ratio `r` (0.5 is the default used by the sample model), pass `int(unit_count * r)`.

#### Sample Model

In `routed_vgg.py` you can find an implementation of the stock PyTorch VGG-11 model (with or without BatchNorm) converted for branched MTL. With:

```python
for ix in range(self.task_count):
    self.add_module("classifier_" + str(ix), nn.Sequential(
        # VGG-11 features end in 512 channels, so the flattened input is
        # 512 * 7 * 7 for the default bottleneck; each task branch is a
        # binary (2-way) classifier.
        nn.Linear(512 * bottleneck_spatial[0] * bottleneck_spatial[1], 2)
    ))
```

we create as many task-specific branches as there are tasks. Additionally, the forward function passes the data through the active task's branch only.
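For instance, building a routed model and routing a batch through a specific task looks like this (a minimal sketch using the constructors from `routed_vgg.py`; the task count, batch size, and task index are arbitrary example values):

```python
import torch

from routed_vgg import vgg11_bn

# A routed VGG-11 (with BatchNorm) for 20 binary tasks; with sigma = 0.5,
# half of the units in every routed layer are masked out per task.
model = vgg11_bn(pretrained=False, task_count=20, sigma=0.5)

x = torch.randn(4, 3, 224, 224)  # dummy input batch

# The VGG trunk and every TaskRouter expose set_active_task, so a single
# model.apply call switches the whole network to task 3.
model.apply(lambda m: m.set_active_task(3) if hasattr(m, "active_task") else None)
out = model(x)  # logits for task 3, shape (4, 2)
```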
In the code snippet below, taken from `make_layers` in `routed_vgg.py`, we add the TRL to the VGG net:

```python
conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
router = TaskRouter(v, task_count, int(v * sigma), "taskrouter_" + str(ix))
if batch_norm:
    layers += [conv2d, nn.BatchNorm2d(v), router, nn.ReLU(inplace=True)]
```

For training a model with the Task Routing Layer, the active task needs to be changed randomly over the training iterations within an epoch. For example:

```python
import time
from random import sample


def change_task(m):
    # model.apply(change_task) visits every submodule; each module that
    # exposes an active_task attribute (the trunk and all TaskRouters) is
    # switched to the globally selected task.
    if hasattr(m, 'active_task'):
        m.set_active_task(active_task)


def train(args, model, task_count, device, train_loader, optimizer, criterion, epoch, total_itts):

    train_start = time.time()
    model.train()

    for batch_idx, (data, *targets) in enumerate(train_loader):

        data = data.to(device)

        # Pick one task at random for this batch and route through it.
        for ix in sample(range(task_count), 1):
            target = targets[ix].to(device)
            global active_task
            active_task = ix

            model = model.apply(change_task)
            out = model(data)
            train_loss = criterion(out, target)
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

    train_end = time.time()
    print("Execution time:", train_end - train_start, "s.")

    return total_itts
```
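The same mechanism works at evaluation time: sweep the active task to collect predictions for every task on a batch. Below is a minimal sketch (the `evaluate` helper is illustrative, not part of the repo; it assumes `torch` is imported and reuses `change_task` and the `active_task` global from above):

```python
def evaluate(model, task_count, device, data):
    global active_task
    model.eval()
    data = data.to(device)

    outputs = []
    with torch.no_grad():
        for ix in range(task_count):
            # Re-route the whole model to task ix before the forward pass.
            active_task = ix
            model.apply(change_task)
            outputs.append(model(data))

    return outputs  # one (batch_size, 2) logit tensor per task
```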
If you find this repository useful, please cite our paper:

```
@inproceedings{strezoski2019taskrouting,
  title={Many Task Learning With Task Routing},
  author={Strezoski, Gjorgji and van Noord, Nanne and Worring, Marcel},
  booktitle={International Conference on Computer Vision (ICCV)},
  organization={IEEE},
  year={2019},
  url={https://arxiv.org/abs/1903.12117}
}
```
--------------------------------------------------------------------------------
/fig_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gstrezoski/TaskRouting/c13a743cf5965ee70f80a81edeaaab0ebabe6aed/fig_1.png
--------------------------------------------------------------------------------
/fig_2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/gstrezoski/TaskRouting/c13a743cf5965ee70f80a81edeaaab0ebabe6aed/fig_2.png
--------------------------------------------------------------------------------
/routed_vgg.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo

from taskrouting import TaskRouter

__all__ = [
    'VGG', 'vgg11', 'vgg11_bn',
]


model_urls = {
    'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
    'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
}


class VGG(nn.Module):

    def __init__(self, features, task_count=10, init_weights=True, active_task=0, bottleneck_spatial=[7, 7]):
        super(VGG, self).__init__()
        self.features = features
        self.avgpool = nn.AdaptiveAvgPool2d(tuple(bottleneck_spatial))
        self.task_count = task_count
        self.active_task = active_task
        # One classification branch per task; VGG-11 features end in 512
        # channels, and every branch is a binary (2-way) classifier.
        for ix in range(self.task_count):
            self.add_module("classifier_" + str(ix), nn.Sequential(
                nn.Linear(512 * bottleneck_spatial[0] * bottleneck_spatial[1], 2)
            ))

        if init_weights:
            self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)

        # Only the branch belonging to the active task produces the output.
        output = self.get_layer("classifier_" + str(self.active_task))(x)

        return output

    def set_active_task(self, active_task):
        self.active_task = active_task
        return active_task

    def get_layer(self, name):
        return getattr(self, name)

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)


def make_layers(cfg, task_count, sigma, batch_norm=False):
    layers = []
    in_channels = 3
    for ix, v in enumerate(cfg):
        if v == 'M':
            layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
        else:
            conv2d = nn.Conv2d(in_channels, v, kernel_size=3, padding=1)
            # int(v * sigma) of this layer's v units are masked out per task.
            router = TaskRouter(v, task_count, int(v * sigma), "taskrouter_" + str(ix))
            if batch_norm:
                layers += [conv2d, nn.BatchNorm2d(v), router, nn.ReLU(inplace=True)]
            else:
                layers += [conv2d, router, nn.ReLU(inplace=True)]
            in_channels = v
    return nn.Sequential(*layers)


cfg = {
    'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
}


def vgg11(pretrained=False, task_count=10, sigma=0.5, **kwargs):
    """VGG 11-layer model (configuration "A")
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['A'], task_count, sigma), task_count, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
    return model


def vgg11_bn(pretrained=False, task_count=10, sigma=0.5, **kwargs):
    """VGG 11-layer model (configuration "A") with batch normalization
    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    if pretrained:
        kwargs['init_weights'] = False
    model = VGG(make_layers(cfg['A'], task_count, sigma, batch_norm=True), task_count, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
    return model
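

# Minimal smoke test (illustrative only; run with `python routed_vgg.py`):
# build a small routed model and verify that the forward pass works after
# switching the active task. Task count and input size are example values.
if __name__ == "__main__":
    model = vgg11(pretrained=False, task_count=5, sigma=0.5)
    x = torch.randn(2, 3, 224, 224)
    for task in range(5):
        # Switch the trunk and all routers to the current task.
        model.apply(lambda m: m.set_active_task(task) if hasattr(m, "active_task") else None)
        print(task, model(x).shape)  # torch.Size([2, 2]) for every task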
--------------------------------------------------------------------------------
/taskrouting.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import numpy as np


class TaskRouter(nn.Module):

    r"""Applies task-specific masking to individual units (channels) in a layer.

    Args:
        unit_count (int): Number of input channels going into the Task Routing layer.
        task_count (int): Number of tasks. (In STL it applies to the number of output classes.)
        sigma (int): Number of units masked out for each task.
    """

    def __init__(self, unit_count, task_count, sigma, name="TaskRouter"):

        super(TaskRouter, self).__init__()

        self.use_routing = True
        self.name = name
        self.unit_count = unit_count
        # Initialized to 0 only as a placeholder: the active task is set right
        # after the model is loaded, so this value is never used.
        self.active_task = 0

        unit_mapping = torch.ones((task_count, unit_count))
        if sigma != 0:
            # For every task (row), zero out `sigma` randomly chosen units.
            unit_mapping[np.arange(task_count)[:, None],
                         np.random.rand(task_count, unit_count).argsort(1)[:, :sigma]] = 0
        else:
            self.use_routing = False
            print("Not using routing! Sigma is 0.")
        # The mask is fixed: register it as a non-trainable parameter so that
        # it moves with the module across devices but is never updated.
        self._unit_mapping = torch.nn.Parameter(unit_mapping, requires_grad=False)

    def get_unit_mapping(self):

        return self._unit_mapping

    def set_active_task(self, active_task):

        self.active_task = active_task
        return active_task

    def forward(self, input):

        if not self.use_routing:
            return input

        # Select the active task's mask row and broadcast it over the batch
        # and spatial dimensions: (1, C, 1, 1). Multiplying through autograd
        # (instead of in-place on input.data) keeps the gradients consistent
        # with the masked forward pass.
        mask = self._unit_mapping[self.active_task].view(1, -1, 1, 1)
        return input * mask
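

# A quick sanity check (illustrative values; run with `python taskrouting.py`):
# every task's mask row should keep exactly unit_count - sigma active units,
# and masked channels should come out of the forward pass as zeros.
if __name__ == "__main__":
    router = TaskRouter(unit_count=8, task_count=3, sigma=4)
    print(router.get_unit_mapping())  # 3 rows of 8 entries, 4 zeros per row
    assert (router.get_unit_mapping().sum(dim=1) == 4).all()

    router.set_active_task(1)
    out = router(torch.ones(2, 8, 5, 5))  # dummy batch: 2 samples, 8 channels
    print(out[0, :, 0, 0])  # zeros exactly where task 1's mask has zeros
--------------------------------------------------------------------------------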