├── src ├── __init__.py ├── fmow │ ├── __init__.py │ ├── fmow_aggregator.py │ ├── fmow_experts.py │ ├── fmow_train.py │ └── fmow_utils.py ├── rxrx1 │ ├── __init__.py │ ├── rxrx1_aggregator.py │ ├── rxrx1_experts.py │ ├── rxrx1_utils.py │ └── rxrx1_train.py ├── camelyon │ ├── __init__.py │ ├── camelyon_aggregator.py │ ├── camelyon_experts.py │ ├── camelyon_train.py │ └── camelyon_utils.py ├── iwildcam │ ├── __init__.py │ ├── iwildcam_aggregator.py │ ├── iwildcam_experts.py │ ├── iwildcam_utils.py │ └── iwildcam_train.py ├── transformer.py └── configs.py ├── assets ├── table1.png └── overview.png ├── requirements.txt ├── LICENSE ├── train_single_expert.py ├── README.md ├── run.py └── environment.yml /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/fmow/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/rxrx1/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/camelyon/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /src/iwildcam/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /assets/table1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/n3il666/Meta-DMoE/HEAD/assets/table1.png -------------------------------------------------------------------------------- /assets/overview.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/n3il666/Meta-DMoE/HEAD/assets/overview.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | --find-links https://download.pytorch.org/whl/torch_stable.html 2 | torch==1.7.1 3 | torchvision==0.8.2 4 | 5 | numpy==1.21.3 6 | tqdm==4.62.3 7 | transformers==4.18.0 8 | learn2learn==0.1.6 9 | wilds==1.2.2 10 | pickle5==0.0.12 11 | scikit-learn==0.24.2 12 | scikit-image==0.18.3 -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Tao Zhong 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /src/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | class PreNorm(nn.Module): 5 | def __init__(self, dim, fn): 6 | super().__init__() 7 | self.norm = nn.LayerNorm(dim) 8 | self.fn = fn 9 | def forward(self, x, **kwargs): 10 | return self.fn(self.norm(x), **kwargs) 11 | 12 | class FeedForward(nn.Module): 13 | def __init__(self, dim, hidden_dim, dropout = 0.): 14 | super().__init__() 15 | self.net = nn.Sequential( 16 | nn.Linear(dim, hidden_dim), 17 | nn.GELU(), 18 | nn.Dropout(dropout), 19 | nn.Linear(hidden_dim, dim), 20 | nn.Dropout(dropout) 21 | ) 22 | def forward(self, x): 23 | return self.net(x) 24 | 25 | class Attention(nn.Module): 26 | def __init__(self, dim, heads = 8, dropout = 0.): 27 | super().__init__() 28 | self.attend = nn.MultiheadAttention(dim, heads, dropout=dropout) 29 | 30 | def forward(self, x): 31 | q = x.permute((1,0,2)) 32 | k = x.permute((1,0,2)) 33 | v = x.permute((1,0,2)) 34 | out, _ = self.attend(q, k, v) 35 | out = out.permute((1,0,2)) 36 | return out 37 | 38 | class Transformer(nn.Module): 39 | def __init__(self, dim, depth, heads, mlp_dim, dropout = 0.): 40 | super().__init__() 41 | self.layers = nn.ModuleList([]) 42 | for _ in range(depth): 43 | self.layers.append(nn.ModuleList([ 44 | PreNorm(dim, Attention(dim, heads, dropout=dropout)), 45 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout)) 46 | ])) 47 | def forward(self, x): 48 | for attn, ff in self.layers: 49 | x = attn(x) + x 50 | x = ff(x) + x 51 | return x 52 | -------------------------------------------------------------------------------- /src/configs.py: -------------------------------------------------------------------------------- 1 | default_param = { 2 | 'iwildcam':{ 3 | 'num_experts': 10, 4 | 'expert_batch_size': 16, 5 | 'expert_lr': 3e-5, 6 | 'expert_l2': 0, 7 | 'expert_epoch': 12, 8 | 'aggregator_depth': 1, 9 | 'aggregator_heads': 16, 10 | 'aggregator_dropout': 0.1, 11 | 'aggregator_pretrain_epoch': 12, 12 | 'student_pretrain_epoch': 12, 13 | 'batch_size': 40, 14 | 'sup_size': 24, 15 | 'tlr': 1e-6, 16 | 'slr': 3e-5, 17 | 'ilr': 3e-4, 18 | 'epoch': 20, 19 | 'seed': 1 20 | }, 21 | 'camelyon':{ 22 | 'num_experts': 5, 23 | 'expert_batch_size': 32, 24 | 'expert_lr': 1e-3, 25 | 'expert_l2': 1e-2, 26 | 'expert_epoch': 5, 27 | 'aggregator_depth': 1, 28 | 'aggregator_heads': 16, 29 | 'aggregator_dropout': 0.1, 30 | 'aggregator_pretrain_epoch': 5, 31 | 'student_pretrain_epoch': 5, 32 | 'batch_size': 96, 33 | 'sup_size': 64, 34 | 'tlr': 1e-6, 35 | 'slr': 1e-4, 36 | 'ilr': 1e-3, 37 | 'epoch': 10, 38 | 'seed': 3407 39 | }, 40 | 'rxrx1':{ 41 | 'num_experts': 3, 42 | 'expert_batch_size': 75, 43 | 'expert_lr': 1e-4, 44 | 'expert_l2': 1e-5, 45 | 'expert_epoch': 90, 46 | 'aggregator_depth': 1, 47 | 'aggregator_heads': 16, 48 | 'aggregator_dropout': 0.1, 49 | 'aggregator_pretrain_epoch': 90, 50 | 'student_pretrain_epoch': 90, 51 | 
'batch_size': 123, 52 | 'sup_size': 75, 53 | 'tlr': 3e-6, 54 | 'slr': 1e-6, 55 | 'ilr': 1e-4, 56 | 'epoch': 10, 57 | 'seed': 1000 58 | }, 59 | 'fmow':{ 60 | 'num_experts': 4, 61 | 'expert_batch_size': 64, 62 | 'expert_lr': 1e-4, 63 | 'expert_l2': 0, 64 | 'expert_epoch': 30, 65 | 'aggregator_depth': 1, 66 | 'aggregator_heads': 16, 67 | 'aggregator_dropout': 0.1, 68 | 'aggregator_pretrain_epoch': 50, 69 | 'student_pretrain_epoch': 20, 70 | 'batch_size': 112, 71 | 'sup_size': 64, 72 | 'tlr': 1e-6, 73 | 'slr': 1e-5, 74 | 'ilr': 1e-4, 75 | 'epoch': 30, 76 | 'seed': 42 77 | } 78 | } -------------------------------------------------------------------------------- /train_single_expert.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import random 5 | import argparse 6 | import pickle 7 | from src.configs import default_param 8 | 9 | def get_parser(): 10 | parser = argparse.ArgumentParser() 11 | 12 | parser.add_argument('--gpu', type=str, default='0') # We only support single gpu training for now 13 | parser.add_argument('--threads', type=int, default=12) 14 | 15 | parser.add_argument('--dataset', type=str, default='iwildcam', 16 | choices=['iwildcam', 'fmow', 'camelyon', 'rxrx1', 'poverty']) 17 | parser.add_argument('--data_dir', type=str, default='data') 18 | 19 | parser.add_argument('--expert_idx', type=int) 20 | 21 | args = parser.parse_args() 22 | args_dict = args.__dict__ 23 | args_dict.update(default_param[args.dataset]) 24 | args = argparse.Namespace(**args_dict) 25 | return args 26 | 27 | def set_seed(seed): 28 | if torch.cuda.is_available(): 29 | torch.cuda.manual_seed(seed) 30 | torch.manual_seed(seed) 31 | np.random.seed(seed) 32 | random.seed(seed) 33 | torch.backends.cudnn.deterministic = True 34 | 35 | def train(args): 36 | 37 | if args.dataset == 'iwildcam': 38 | from src.iwildcam.iwildcam_utils import get_models_list 39 | from src.iwildcam.iwildcam_experts import train_model, get_expert_split 40 | elif args.dataset == 'camelyon': 41 | from src.camelyon.camelyon_utils import get_models_list 42 | from src.camelyon.camelyon_experts import train_model, get_expert_split 43 | elif args.dataset == 'rxrx1': 44 | from src.rxrx1.rxrx1_utils import get_models_list 45 | from src.rxrx1.rxrx1_experts import train_model, get_expert_split 46 | elif args.dataset == 'fmow': 47 | from src.fmow.fmow_utils import get_models_list 48 | from src.fmow.fmow_experts import train_model, get_expert_split 49 | else: 50 | raise NotImplementedError 51 | 52 | name = f"{args.dataset}_{str(args.num_experts)}experts_seed{str(args.seed)}" 53 | 54 | models_list = get_models_list(device=device, num_domains=0) 55 | 56 | try: 57 | with open(f"model/{args.dataset}/domain_split.pkl", "rb") as f: 58 | all_split, split_to_cluster = pickle.load(f) 59 | except FileNotFoundError: 60 | os.makedirs(f"model/{args.dataset}", exist_ok=True) # this script may run before run.py has created the checkpoint folder 61 | all_split, split_to_cluster = get_expert_split(args.num_experts, root_dir=args.data_dir) 62 | with open(f"model/{args.dataset}/domain_split.pkl", "wb") as f: pickle.dump((all_split, split_to_cluster), f) 63 | 64 | print(f"Training model {args.expert_idx} for domain ", *all_split[args.expert_idx]) 65 | train_model(models_list[0], name+'_'+str(args.expert_idx), device=device, 66 | domain=all_split[args.expert_idx], batch_size=args.expert_batch_size, 67 | lr=args.expert_lr, l2=args.expert_l2, num_epochs=args.expert_epoch, 68 | save=True, root_dir=args.data_dir) 69 | 70 | if __name__ == "__main__": 71 | args = get_parser() 72 | 
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 73 | torch.set_num_threads(args.threads) 74 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 75 | set_seed(args.seed) 76 | train(args) 77 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adapting to Domain Shift by Meta-Distillation from Mixture-of-Experts 2 | 3 | This repository is the official implementation of [Meta-DMoE: Adapting to Domain Shift by Meta-Distillation from Mixture-of-Experts](https://arxiv.org/abs/2210.03885). 4 | 5 |
6 | ![Method Overview](assets/overview.png) 7 | 
8 | 9 | ## Requirements 10 | 11 | The code was tested with Python 3.7 and CUDA 10.1. 12 | 13 | We recommend using a conda environment to set up all required dependencies: 14 | 15 | ```setup 16 | conda env create -f environment.yml 17 | conda activate dmoe 18 | ``` 19 | 20 | If you have any problems with the above command, you can also install the dependencies with `pip install -r requirements.txt`. 21 | 22 | Either of these commands will automatically install all required dependencies **except for the `torch-scatter` and `torch-geometric` packages**, which require a [quick manual install](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html#installation-via-binaries). 23 | 24 | ## Training 25 | 26 | We provide training scripts for the following 4 datasets from the WILDS benchmark: `iwildcam`, `camelyon`, `rxrx1`, and `fmow`. To train the models in the paper, run: 27 | 28 | ```Training 29 | python run.py --dataset <dataset> 30 | ``` 31 | 32 | The data will be automatically downloaded to the `data` folder (set a different location with `--data_dir`). 33 | 34 | ### Distributed Training for Expert Models 35 | 36 | Multi-GPU support for meta-training is not available at this point, but you can still train the expert models in a distributed manner by opening multiple terminals and running: 37 | 38 | ```Train Experts 39 | python train_single_expert.py --dataset <dataset> --data_dir <data_dir> --gpu <gpu> --expert_idx <expert_idx> 40 | ``` 41 | 42 | then adding the `--load_trained_experts` flag when running `run.py`. 43 | 44 | ## Evaluation 45 | 46 | To evaluate trained models, run: 47 | 48 | ```eval 49 | python run.py --dataset <dataset> --data_dir <data_dir> --test 50 | ``` 51 | 52 | ## Pre-trained Models 53 | 54 | To reproduce the results reported in Table 1 of our paper, download the pretrained models below and extract them to the `model/` folder. Note that due to the size limit of cloud storage, we only uploaded checkpoints from one random seed per dataset, while the results reported in the table are aggregated across several random seeds. 55 | 56 | - [iwildcam](https://drive.google.com/drive/folders/1mi-xInK5jXplmE4jo8oLSgHAYhsBe5sX?usp=sharing) 57 | - [camelyon](https://drive.google.com/drive/folders/1Wbuzv0DMxtfYjhF51KgdQQGV2_PRNEDw?usp=sharing) 58 | - [rxrx1](https://drive.google.com/drive/folders/1gsSFezrWbgKrU-nChC777DE3HFRnDkRx?usp=sharing) 59 | - [fmow](https://drive.google.com/drive/folders/1PAgzr7e2kcTQaYRq8gZnYtrQGq3PSxCh?usp=sharing) 60 | 61 | 
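For reference, `run.py` resolves checkpoints by a fixed naming scheme (`f"{dataset}_{num_experts}experts_seed{seed}"`), so the extracted files should sit directly under `model/<dataset>/`. Below is a sketch of the expected layout for `iwildcam` with the defaults from `src/configs.py` (10 experts, seed 1); the exact set of files depends on which training stages have been run:

```
model/iwildcam/
├── domain_split.pkl                                       # expert-to-domain assignment
├── iwildcam_10experts_seed1_0_exp_best.pth                # expert 0 (... through expert 9)
├── iwildcam_10experts_seed1_pretrained_selector_best.pth  # pretrained aggregator
├── iwildcam_10experts_seed1_pretrained_exp_best.pth       # pretrained student
├── iwildcam_10experts_seed1_meta_selector_best.pth        # meta-trained aggregator
└── iwildcam_10experts_seed1_meta_student_best.pth         # meta-trained student
```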
62 | ![Table 1](assets/table1.png) 63 | 
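## Aggregator at a Glance

For orientation, the knowledge aggregator (`fa_selector` in `src/<dataset>/<dataset>_utils.py`) consumes a stack of per-expert features: each expert backbone maps a batch of images to a feature vector, the vectors are stacked along a token dimension, and a small transformer attends over the expert tokens before pooling them into logits. A minimal sketch with random tensors standing in for expert features — shapes follow the `iwildcam` defaults in `src/configs.py` and `run.py`; this is an illustration, not a training script:

```python
import torch
from src.iwildcam.iwildcam_utils import fa_selector

batch, num_experts, feat_dim = 4, 10, 2048  # iwildcam: 10 ResNet-50 experts
# Stand-in for torch.stack([expert(x) for expert in models_list], dim=-1).permute(0, 2, 1)
features = torch.randn(batch, num_experts, feat_dim)

selector = fa_selector(dim=feat_dim, depth=1, heads=16,
                       mlp_dim=feat_dim * 2, dropout=0.1, out_dim=182)
logits = selector(features)           # (4, 182): mean-pool over expert tokens, then the MLP head
pooled = selector.get_feat(features)  # (4, 2048): pooled representation without the classification head
```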
64 | 65 | ## Citation 66 | If you find this codebase useful in your research, consider citing: 67 | ``` 68 | @inproceedings{ 69 | zhong2022metadmoe, 70 | title={Meta-{DM}oE: Adapting to Domain Shift by Meta-Distillation from Mixture-of-Experts}, 71 | author={Tao Zhong and Zhixiang Chi and Li Gu and Yang Wang and Yuanhao Yu and Jin Tang}, 72 | booktitle={Thirty-Sixth Conference on Neural Information Processing Systems (NeurIPS)}, 73 | year={2022} 74 | } 75 | ``` 76 | -------------------------------------------------------------------------------- /src/rxrx1/rxrx1_aggregator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import torchvision 7 | from tqdm import tqdm 8 | import os 9 | from src.rxrx1.rxrx1_utils import * 10 | 11 | def get_selector_accuracy(selector, models_list, data_loader, grouper, device, progress=True): 12 | selector.eval() 13 | correct = 0 14 | total = 0 15 | #mean_correct = 0 16 | if progress: 17 | data_loader = tqdm(data_loader) 18 | for x, y_true, metadata in data_loader: 19 | #z = grouper.metadata_to_group(metadata) 20 | #z = set(z.tolist()) 21 | #assert z.issubset(set(meta_indices)) 22 | 23 | x = x.to(device) 24 | y_true = y_true.to(device) 25 | 26 | with torch.no_grad(): 27 | features = torch.stack([model(x).detach() for model in models_list], dim=-1) 28 | features = features.permute((0,2,1)) 29 | out = selector(features) 30 | 31 | pred = out.max(1, keepdim=True)[1] 32 | correct += pred.eq(y_true.view_as(pred)).sum().item() 33 | total += x.shape[0] 34 | 35 | return correct/total 36 | 37 | def train_model_selector(selector, model_name, models_list, device, root_dir='data', 38 | batch_size=75, lr=1e-5, l2=0, 39 | num_epochs=90, decayRate=0.98, save=True, test_way='ood'): 40 | for model in models_list: 41 | model.eval() 42 | 43 | train_loader, val_loader, _, grouper = get_data_loader(root_dir=root_dir, 44 | batch_size=batch_size, domain=None, 45 | test_way=test_way, n_groups_per_batch=0) 46 | 47 | criterion = nn.CrossEntropyLoss() 48 | optimizer = optim.Adam(selector.parameters(), lr=lr, weight_decay=l2) 49 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate) 50 | 51 | i = 0 52 | 53 | losses = [] 54 | acc_best = 0 55 | 56 | tot = len(train_loader) 57 | 58 | for epoch in range(num_epochs): 59 | 60 | print(f"Epoch:{epoch}|| Total:{tot}") 61 | 62 | for x, y_true, metadata in train_loader: 63 | selector.train() 64 | 65 | z = grouper.metadata_to_group(metadata) 66 | z = set(z.tolist()) 67 | #assert z.issubset(set(meta_indices)) 68 | 69 | x = x.to(device) 70 | y_true = y_true.to(device) 71 | 72 | with torch.no_grad(): 73 | features = torch.stack([model(x).detach() for model in models_list], dim=-1) 74 | features = features.permute((0,2,1)) 75 | out = selector(features) 76 | 77 | loss = criterion(out, y_true) 78 | loss.backward() 79 | optimizer.step() 80 | optimizer.zero_grad() 81 | losses.append(loss.item()/batch_size) 82 | 83 | if i % (tot//2) == 0 and i != 0: 84 | losses = np.mean(losses) 85 | acc = get_selector_accuracy(selector, models_list, val_loader, 86 | grouper, device, progress=False) 87 | 88 | print("Iter: {}/{} || Loss: {:.4f} || Acc:{:.4f}".format(i, tot, losses, acc)) 89 | losses = [] 90 | 91 | if acc > acc_best and save: 92 | print("Saving model ...") 93 | save_model(selector, model_name+"_selector", 0, test_way=test_way) 94 | acc_best = acc 95 | 96 | i += 
1 97 | scheduler.step() 98 | -------------------------------------------------------------------------------- /src/camelyon/camelyon_aggregator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import torchvision 7 | from tqdm import tqdm 8 | import os 9 | from src.camelyon.camelyon_utils import * 10 | 11 | def get_selector_accuracy(selector, models_list, data_loader, grouper, device, progress=True): 12 | selector.eval() 13 | correct = 0 14 | total = 0 15 | #mean_correct = 0 16 | if progress: 17 | data_loader = tqdm(data_loader) 18 | for x, y_true, metadata in data_loader: 19 | #z = grouper.metadata_to_group(metadata) 20 | #z = set(z.tolist()) 21 | #assert z.issubset(set(meta_indices)) 22 | 23 | x = x.to(device) 24 | y_true = y_true.to(device) 25 | 26 | with torch.no_grad(): 27 | features = torch.stack([model(x).detach() for model in models_list], dim=-1) 28 | features = features.permute((0,2,1)) 29 | out = selector(features) 30 | 31 | pred = (out > 0.0).squeeze().long() 32 | correct += pred.eq(y_true.view_as(pred)).sum().item() 33 | total += x.shape[0] 34 | 35 | return correct/total 36 | 37 | def train_model_selector(selector, model_name, models_list, device, root_dir='data', 38 | batch_size=32, lr=3e-6, l2=0, 39 | num_epochs=5, decayRate=0.96, save=True, test_way='ood'): 40 | for model in models_list: 41 | model.eval() 42 | 43 | train_loader, val_loader, _, grouper = get_data_loader(root_dir=root_dir, 44 | batch_size=batch_size, domain=None, 45 | test_way=test_way, n_groups_per_batch=0, 46 | return_dataset=False) 47 | 48 | criterion = nn.BCEWithLogitsLoss() 49 | optimizer = optim.Adam(selector.parameters(), lr=lr, weight_decay=l2) 50 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate) 51 | 52 | i = 0 53 | 54 | losses = [] 55 | acc_best = 0 56 | 57 | tot = len(train_loader) 58 | 59 | for epoch in range(num_epochs): 60 | 61 | print(f"Epoch:{epoch}|| Total:{tot}") 62 | 63 | for x, y_true, metadata in train_loader: 64 | selector.train() 65 | 66 | z = grouper.metadata_to_group(metadata) 67 | z = set(z.tolist()) 68 | #assert z.issubset(set(meta_indices)) 69 | 70 | x = x.to(device) 71 | y_true = y_true.to(device) 72 | 73 | with torch.no_grad(): 74 | features = torch.stack([model(x).detach() for model in models_list], dim=-1) 75 | features = features.permute((0,2,1)) 76 | out = selector(features) 77 | 78 | loss = criterion(out, y_true.unsqueeze(-1).float()) 79 | loss.backward() 80 | optimizer.step() 81 | optimizer.zero_grad() 82 | losses.append(loss.item()/batch_size) 83 | 84 | if i % (tot//2) == 0 and i != 0: 85 | losses = np.mean(losses) 86 | acc = get_selector_accuracy(selector, models_list, val_loader, 87 | grouper, device, progress=False) 88 | 89 | print("Iter: {} || Loss: {:.4f} || Acc:{:.4f}".format(i, losses, acc)) 90 | losses = [] 91 | 92 | if acc > acc_best and save: 93 | print("Saving model ...") 94 | save_model(selector, model_name+"_selector", 0, test_way=test_way) 95 | acc_best = acc 96 | 97 | i += 1 98 | scheduler.step() 99 | -------------------------------------------------------------------------------- /src/iwildcam/iwildcam_aggregator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import torchvision 7 
| from tqdm import tqdm 8 | from sklearn.metrics import f1_score 9 | import os 10 | from src.iwildcam.iwildcam_utils import * 11 | 12 | def get_selector_accuracy(selector, models_list, data_loader, grouper, device, progress=True): 13 | selector.eval() 14 | correct = 0 15 | total = 0 16 | #mean_correct = 0 17 | if progress: 18 | data_loader = tqdm(data_loader) 19 | for x, y_true, metadata in data_loader: 20 | #z = grouper.metadata_to_group(metadata) 21 | #z = set(z.tolist()) 22 | #assert z.issubset(set(meta_indices)) 23 | 24 | x = x.to(device) 25 | y_true = y_true.to(device) 26 | 27 | with torch.no_grad(): 28 | features = torch.stack([model(x).detach() for model in models_list], dim=-1) 29 | features = features.permute((0,2,1)) 30 | out = selector(features) 31 | 32 | pred = out.max(1, keepdim=True)[1] 33 | correct += pred.eq(y_true.view_as(pred)).sum().item() 34 | total += x.shape[0] 35 | try: 36 | pred_all = torch.cat((pred_all, pred.view_as(y_true))) 37 | y_all = torch.cat((y_all, y_true)) 38 | except NameError: 39 | pred_all = pred.view_as(y_true) 40 | y_all = y_true 41 | 42 | y_all = y_all.detach().cpu() 43 | pred_all = pred_all.detach().cpu() 44 | f1 = f1_score(y_all,pred_all,average='macro', labels=torch.unique(y_all)) 45 | 46 | return correct/total, f1 47 | 48 | def train_model_selector(selector, model_name, models_list, device, root_dir='data', 49 | batch_size=32, lr=1e-6, l2=0, 50 | num_epochs=12, decayRate=0.96, save=True, test_way='ood'): 51 | for model in models_list: 52 | model.eval() 53 | 54 | train_loader, val_loader, _, grouper = get_data_loader(root_dir=root_dir, 55 | batch_size=batch_size, domain=None, 56 | test_way=test_way, n_groups_per_batch=0) 57 | 58 | criterion = nn.CrossEntropyLoss() 59 | optimizer = optim.Adam(selector.parameters(), lr=lr, weight_decay=l2) 60 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate) 61 | 62 | i = 0 63 | 64 | losses = [] 65 | acc_best = 0 66 | 67 | tot = len(train_loader) 68 | 69 | for epoch in range(num_epochs): 70 | 71 | print(f"Epoch:{epoch}|| Total:{tot}") 72 | 73 | for x, y_true, metadata in train_loader: 74 | selector.train() 75 | 76 | z = grouper.metadata_to_group(metadata) 77 | z = set(z.tolist()) 78 | #assert z.issubset(set(meta_indices)) 79 | 80 | x = x.to(device) 81 | y_true = y_true.to(device) 82 | 83 | with torch.no_grad(): 84 | features = torch.stack([model(x).detach() for model in models_list], dim=-1) 85 | features = features.permute((0,2,1)) 86 | out = selector(features) 87 | 88 | loss = criterion(out, y_true) 89 | loss.backward() 90 | optimizer.step() 91 | optimizer.zero_grad() 92 | losses.append(loss.item()/batch_size) 93 | 94 | if i % (tot//2) == 0 and i != 0: 95 | losses = np.mean(losses) 96 | acc, f1 = get_selector_accuracy(selector, models_list, val_loader, 97 | grouper, device, progress=False) 98 | 99 | print("Iter: {} || Loss: {:.4f} || Acc:{:.4f} || F1:{:.4f}".format(i, losses, acc, f1)) 100 | losses = [] 101 | 102 | if f1 > acc_best and save: 103 | print("Saving model ...") 104 | save_model(selector, model_name+"_selector", 0, test_way=test_way) 105 | acc_best = f1 106 | 107 | i += 1 108 | scheduler.step() 109 | -------------------------------------------------------------------------------- /src/rxrx1/rxrx1_experts.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import torchvision 7 | import os 8 | from 
transformers import get_cosine_schedule_with_warmup 9 | from src.rxrx1.rxrx1_utils import * 10 | 11 | def get_model_accuracy(model, data_loader, grouper, device, domain=None): 12 | model.eval() 13 | correct = 0 14 | total = 0 15 | for x, y_true, metadata in iter(data_loader): 16 | 17 | z = grouper.metadata_to_group(metadata) 18 | z = set(z.tolist()) 19 | if domain is not None: 20 | assert z.issubset(set(domain)) 21 | 22 | x = x.to(device) 23 | y_true = y_true.to(device) 24 | 25 | out = model(x) 26 | pred = out.max(1, keepdim=True)[1] 27 | correct += pred.eq(y_true.view_as(pred)).sum().item() 28 | total += x.shape[0] 29 | return correct/total 30 | 31 | def train_model(model, model_name, device, domain=None, batch_size=75, lr=1e-4, l2=1e-5, 32 | num_epochs=90, decayRate=1., save=False, test_way='ood', root_dir='data'): 33 | 34 | train_loader, val_loader, test_loader, grouper = get_data_loader(root_dir=root_dir, 35 | batch_size=batch_size, domain=domain, 36 | test_way=test_way, n_groups_per_batch=0) 37 | 38 | criterion = nn.CrossEntropyLoss() 39 | optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=l2) 40 | 41 | i = 0 42 | 43 | losses = [] 44 | acc_best = 0 45 | 46 | tot = len(train_loader) 47 | scheduler = get_cosine_schedule_with_warmup( 48 | optimizer, 49 | num_warmup_steps=tot * 10, 50 | num_training_steps=tot*num_epochs) 51 | 52 | for epoch in range(num_epochs): 53 | 54 | print(f"Epoch:{epoch} || Total:{tot}") 55 | 56 | for x, y_true, metadata in iter(train_loader): 57 | model.train() 58 | 59 | z = grouper.metadata_to_group(metadata) 60 | z = set(z.tolist()) 61 | if domain is not None: 62 | assert z.issubset(set(domain)) 63 | 64 | x = x.to(device) 65 | y_true = y_true.to(device) 66 | 67 | pred = model(x) 68 | 69 | loss = criterion(pred, y_true) 70 | loss.backward() 71 | optimizer.step() 72 | optimizer.zero_grad() 73 | losses.append(loss.item()/batch_size) 74 | 75 | if i % (tot//2) == 0 and i != 0: 76 | losses = np.mean(losses) 77 | acc = get_model_accuracy(model, val_loader, grouper, device=device) 78 | 79 | print("Iter: {} || Loss: {:.4f} || Acc:{:.4f}".format(i, losses, acc)) 80 | losses = [] 81 | 82 | if acc > acc_best and save: 83 | print("Saving model ...") 84 | save_model(model, model_name+"_exp", 0, test_way=test_way) 85 | acc_best = acc 86 | 87 | i += 1 88 | scheduler.step() 89 | 90 | def train_exp(models_list, domain_specific_indices, device, batch_size=75, lr=1e-4, l2=1e-5, 91 | num_epochs=90, decayRate=1., save=False, test_way='ood', name="Resnet50_experts", 92 | root_dir='data'): 93 | 94 | assert len(models_list) == len(domain_specific_indices) 95 | for i in range(len(models_list)): 96 | print(f"Training model {i} for domain", *domain_specific_indices[i]) 97 | train_model(models_list[i], name+'_'+str(i), device=device, 98 | domain=domain_specific_indices[i], batch_size=batch_size, 99 | lr=lr, l2=l2, num_epochs=num_epochs, 100 | decayRate=decayRate, save=save, test_way=test_way, 101 | root_dir=root_dir) 102 | 103 | def get_expert_split(num_experts, root_dir='data'): 104 | all_split = split_domains(num_experts=num_experts, root_dir=root_dir) 105 | split_to_cluster = {d:i for i in range(len(all_split)) for d in all_split[i]} 106 | return all_split, split_to_cluster 107 | -------------------------------------------------------------------------------- /src/camelyon/camelyon_experts.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 
import torch.optim as optim 6 | import torchvision 7 | import os 8 | from src.camelyon.camelyon_utils import * 9 | 10 | def get_model_accuracy(model, data_loader, grouper, device, domain=None, progress=False): 11 | model.eval() 12 | correct = 0 13 | total = 0 14 | for x, y_true, metadata in iter(data_loader): 15 | 16 | z = grouper.metadata_to_group(metadata) 17 | z = set(z.tolist()) 18 | if domain is not None: 19 | assert z.issubset(set(domain)) 20 | 21 | x = x.to(device) 22 | y_true = y_true.to(device) 23 | 24 | out = model(x) 25 | pred = (out > 0.0).squeeze().long() 26 | correct += pred.eq(y_true.view_as(pred)).sum().item() 27 | total += x.shape[0] 28 | 29 | return correct/total 30 | 31 | def train_model(model, model_name, device, domain=None, batch_size=32, lr=1e-3, l2=1e-2, 32 | num_epochs=5, decayRate=1., save=False, test_way='ood', root_dir='data'): 33 | 34 | train_loader, val_loader, test_loader, grouper = get_data_loader(root_dir=root_dir, 35 | batch_size=batch_size, domain=domain, 36 | test_way=test_way, n_groups_per_batch=0, 37 | return_dataset=False) 38 | 39 | criterion = nn.BCEWithLogitsLoss() 40 | optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=l2) 41 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate) 42 | 43 | i = 0 44 | 45 | losses = [] 46 | acc_best = 0 47 | 48 | tot = len(train_loader) 49 | 50 | for epoch in range(num_epochs): 51 | 52 | print(f"Epoch:{epoch} || Total:{tot}") 53 | 54 | for x, y_true, metadata in iter(train_loader): 55 | model.train() 56 | 57 | z = grouper.metadata_to_group(metadata) 58 | z = set(z.tolist()) 59 | if domain is not None: 60 | assert z.issubset(set(domain)) 61 | 62 | x = x.to(device) 63 | y_true = y_true.to(device) 64 | 65 | pred = model(x) 66 | 67 | loss = criterion(pred, y_true.unsqueeze(-1).float()) 68 | loss.backward() 69 | optimizer.step() 70 | optimizer.zero_grad() 71 | losses.append(loss.item()/batch_size) 72 | 73 | if i % (tot//2) == 0 and i != 0: 74 | losses = np.mean(losses) 75 | acc = get_model_accuracy(model, val_loader, grouper, device=device) 76 | 77 | print("Iter: {} || Loss: {:.4f} || Acc:{:.4f}".format(i, losses, acc)) 78 | losses = [] 79 | 80 | if acc > acc_best and save: 81 | print("Saving model ...") 82 | save_model(model, model_name+"_exp", 0, test_way=test_way) 83 | acc_best = acc 84 | 85 | 86 | i += 1 87 | scheduler.step() 88 | 89 | def train_exp(models_list, domain_specific_indices, device, batch_size=32, lr=1e-3, l2=1e-2, 90 | num_epochs=5, decayRate=1., save=False, test_way='ood', name="Dense121_experts", 91 | root_dir='data'): 92 | 93 | assert len(models_list) == len(domain_specific_indices) 94 | for i in range(len(models_list)): 95 | print(f"Training model {i} for domain", *domain_specific_indices[i]) 96 | train_model(models_list[i], name+'_'+str(i), device=device, 97 | domain=domain_specific_indices[i], batch_size=batch_size, 98 | lr=lr, l2=l2, num_epochs=num_epochs, 99 | decayRate=decayRate, save=save, test_way=test_way, 100 | root_dir=root_dir) 101 | 102 | def get_expert_split(num_experts, root_dir='data'): 103 | all_split = split_domains(num_experts=num_experts, root_dir=root_dir) 104 | split_to_cluster = {d:i for i in range(len(all_split)) for d in all_split[i]} 105 | return all_split, split_to_cluster 106 | -------------------------------------------------------------------------------- /src/fmow/fmow_aggregator.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 
| import torch.nn.functional as F 5 | import torch.optim as optim 6 | import torchvision 7 | from tqdm import tqdm 8 | import os 9 | from src.fmow.fmow_utils import * 10 | 11 | def get_selector_accuracy(selector, models_list, data_loader, grouper, device, progress=True, dataset=None): 12 | selector.eval() 13 | #mean_correct = 0 14 | if progress: 15 | data_loader = tqdm(data_loader) 16 | for x, y_true, metadata in data_loader: 17 | #z = grouper.metadata_to_group(metadata) 18 | #z = set(z.tolist()) 19 | #assert z.issubset(set(meta_indices)) 20 | 21 | x = x.to(device) 22 | y_true = y_true.to(device) 23 | 24 | with torch.no_grad(): 25 | features = torch.stack([model(x).detach() for model in models_list], dim=-1) 26 | features = features.permute((0,2,1)) 27 | out = selector(features) 28 | 29 | pred = out.max(1, keepdim=True)[1] 30 | try: 31 | pred_all = torch.cat((pred_all, pred.view_as(y_true))) 32 | y_all = torch.cat((y_all, y_true)) 33 | metadata_all = torch.cat((metadata_all, metadata)) 34 | except NameError: 35 | pred_all = pred.view_as(y_true) 36 | y_all = y_true 37 | metadata_all = metadata 38 | acc, worst_acc = get_fmow_metrics(pred_all, y_all, metadata_all, dataset) 39 | return acc, worst_acc 40 | 41 | def train_model_selector(selector, model_name, models_list, device, root_dir='data', 42 | batch_size=64, lr=3e-6, l2=0, 43 | num_epochs=20, decayRate=0.96, save=True, test_way='ood'): 44 | for model in models_list: 45 | model.eval() 46 | 47 | train_loader, val_loader, _, grouper, dataset = get_data_loader(root_dir=root_dir, 48 | batch_size=batch_size, domain=None, 49 | test_way=test_way, n_groups_per_batch=0, 50 | return_dataset=True) 51 | 52 | criterion = nn.CrossEntropyLoss() 53 | optimizer = optim.Adam(selector.parameters(), lr=lr, weight_decay=l2) 54 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate) 55 | 56 | i = 0 57 | 58 | losses = [] 59 | acc_best = 0 60 | 61 | tot = len(train_loader) 62 | 63 | for epoch in range(num_epochs): 64 | 65 | print(f"Epoch:{epoch}|| Total:{tot}") 66 | 67 | for x, y_true, metadata in train_loader: 68 | selector.train() 69 | 70 | z = grouper.metadata_to_group(metadata) 71 | z = set(z.tolist()) 72 | #assert z.issubset(set(meta_indices)) 73 | 74 | x = x.to(device) 75 | y_true = y_true.to(device) 76 | 77 | with torch.no_grad(): 78 | features = torch.stack([model(x).detach() for model in models_list], dim=-1) 79 | features = features.permute((0,2,1)) 80 | out = selector(features) 81 | 82 | loss = criterion(out, y_true) 83 | loss.backward() 84 | optimizer.step() 85 | optimizer.zero_grad() 86 | losses.append(loss.item()/batch_size) 87 | 88 | if i % (tot//2) == 0 and i != 0: 89 | losses = np.mean(losses) 90 | acc, wc_acc = get_selector_accuracy(selector, models_list, val_loader, 91 | grouper, device, progress=False, dataset=dataset) 92 | 93 | print("Iter: {}/{} || Loss: {:.4f} || Acc:{:.4f} || WC Acc:{:.4f} ".format(i, tot, losses, 94 | acc, wc_acc)) 95 | losses = [] 96 | 97 | if wc_acc > acc_best and save: 98 | print("Saving model ...") 99 | save_model(selector, model_name+"_selector", 0, test_way=test_way) 100 | acc_best = wc_acc 101 | 102 | i += 1 103 | scheduler.step() 104 | -------------------------------------------------------------------------------- /src/iwildcam/iwildcam_experts.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import torchvision 7 
| from sklearn.metrics import f1_score 8 | import os 9 | from src.iwildcam.iwildcam_utils import * 10 | 11 | def get_model_accuracy(model, data_loader, grouper, device, domain=None): 12 | model.eval() 13 | correct = 0 14 | total = 0 15 | for x, y_true, metadata in iter(data_loader): 16 | 17 | z = grouper.metadata_to_group(metadata) 18 | z = set(z.tolist()) 19 | if domain is not None: 20 | assert z.issubset(set(domain)) 21 | 22 | x = x.to(device) 23 | y_true = y_true.to(device) 24 | 25 | out = model(x) 26 | pred = out.max(1, keepdim=True)[1] 27 | correct += pred.eq(y_true.view_as(pred)).sum().item() 28 | total += x.shape[0] 29 | try: 30 | pred_all = torch.cat((pred_all, pred.view_as(y_true))) 31 | y_all = torch.cat((y_all, y_true)) 32 | except NameError: 33 | pred_all = pred.view_as(y_true) 34 | y_all = y_true 35 | y_all = y_all.detach().cpu() 36 | pred_all = pred_all.detach().cpu() 37 | f1 = f1_score(y_all,pred_all,average='macro', labels=torch.unique(y_all)) 38 | return correct/total, f1 39 | 40 | def train_model(model, model_name, device, domain=None, batch_size=16, lr=3e-5, l2=0, 41 | num_epochs=12, decayRate=1., save=False, test_way='ood', root_dir='data'): 42 | 43 | train_loader, val_loader, test_loader, grouper = get_data_loader(root_dir=root_dir, 44 | batch_size=batch_size, domain=domain, 45 | test_way=test_way, n_groups_per_batch=0) 46 | 47 | criterion = nn.CrossEntropyLoss() 48 | optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=l2) 49 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate) 50 | 51 | i = 0 52 | 53 | losses = [] 54 | acc_best = 0 55 | 56 | tot = len(train_loader) 57 | 58 | for epoch in range(num_epochs): 59 | 60 | print(f"Epoch:{epoch} || Total:{tot}") 61 | 62 | for x, y_true, metadata in iter(train_loader): 63 | model.train() 64 | 65 | z = grouper.metadata_to_group(metadata) 66 | z = set(z.tolist()) 67 | if domain is not None: 68 | assert z.issubset(set(domain)) 69 | 70 | x = x.to(device) 71 | y_true = y_true.to(device) 72 | 73 | pred = model(x) 74 | 75 | loss = criterion(pred, y_true) 76 | loss.backward() 77 | optimizer.step() 78 | optimizer.zero_grad() 79 | losses.append(loss.item()/batch_size) 80 | 81 | if i % (tot//2) == 0 and i != 0: 82 | losses = np.mean(losses) 83 | acc, f1 = get_model_accuracy(model, val_loader, grouper, device=device) 84 | 85 | print("Iter: {} || Loss: {:.4f} || Acc:{:.4f} || F1:{:.4f} ".format(i, losses, acc, f1)) 86 | losses = [] 87 | 88 | if f1 > acc_best and save: 89 | print("Saving model ...") 90 | save_model(model, model_name+"_exp", 0, test_way=test_way) 91 | acc_best = f1 92 | 93 | 94 | i += 1 95 | scheduler.step() 96 | 97 | def train_exp(models_list, domain_specific_indices, device, batch_size=16, lr=1e-4, l2=0, 98 | num_epochs=30, decayRate=0.96, save=False, test_way='ood', name="Resnet50_experts", 99 | root_dir='data'): 100 | 101 | assert len(models_list) == len(domain_specific_indices) 102 | for i in range(len(models_list)): 103 | print(f"Training model {i} for domain", *domain_specific_indices[i]) 104 | train_model(models_list[i], name+'_'+str(i), device=device, 105 | domain=domain_specific_indices[i], batch_size=batch_size, 106 | lr=lr, l2=l2, num_epochs=num_epochs, 107 | decayRate=decayRate, save=save, test_way=test_way, 108 | root_dir=root_dir) 109 | 110 | def get_expert_split(num_experts, root_dir='data'): 111 | all_split = split_domains(num_experts=num_experts, root_dir=root_dir) 112 | split_to_cluster = {d:i for i in range(len(all_split)) for d in all_split[i]} 113 | return 
all_split, split_to_cluster 114 | -------------------------------------------------------------------------------- /src/fmow/fmow_experts.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import torchvision 7 | import os 8 | from src.fmow.fmow_utils import * 9 | 10 | def get_model_accuracy(model, data_loader, grouper, device, domain=None, dataset=None): 11 | model.eval() 12 | 13 | for x, y_true, metadata in iter(data_loader): 14 | 15 | z = grouper.metadata_to_group(metadata) 16 | z = set(z.tolist()) 17 | if domain is not None: 18 | assert z.issubset(set(domain)) 19 | 20 | x = x.to(device) 21 | y_true = y_true.to(device) 22 | 23 | out = model(x) 24 | pred = out.max(1, keepdim=True)[1] 25 | try: 26 | pred_all = torch.cat((pred_all, pred.view_as(y_true))) 27 | y_all = torch.cat((y_all, y_true)) 28 | metadata_all = torch.cat((metadata_all, metadata)) 29 | except NameError: 30 | pred_all = pred.view_as(y_true) 31 | y_all = y_true 32 | metadata_all = metadata 33 | acc, worst_acc = get_fmow_metrics(pred_all, y_all, metadata_all, dataset) 34 | return acc, worst_acc 35 | 36 | def train_model(model, model_name, device, domain=None, batch_size=64, lr=1e-4, l2=0, 37 | num_epochs=50, decayRate=0.96, save=False, test_way='ood', root_dir='data'): 38 | 39 | train_loader, val_loader, test_loader, grouper, dataset = get_data_loader(root_dir=root_dir, 40 | batch_size=batch_size, domain=domain, 41 | test_way=test_way, n_groups_per_batch=0, 42 | return_dataset=True) 43 | 44 | criterion = nn.CrossEntropyLoss() 45 | optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=l2) 46 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate) 47 | 48 | i = 0 49 | 50 | losses = [] 51 | acc_best = 0 52 | 53 | tot = len(train_loader) 54 | 55 | for epoch in range(num_epochs): 56 | 57 | print(f"Epoch:{epoch} || Total:{tot}") 58 | 59 | for x, y_true, metadata in iter(train_loader): 60 | model.train() 61 | 62 | z = grouper.metadata_to_group(metadata) 63 | z = set(z.tolist()) 64 | if domain is not None: 65 | assert z.issubset(set(domain)) 66 | 67 | x = x.to(device) 68 | y_true = y_true.to(device) 69 | 70 | pred = model(x) 71 | 72 | loss = criterion(pred, y_true) 73 | loss.backward() 74 | optimizer.step() 75 | optimizer.zero_grad() 76 | losses.append(loss.item()/batch_size) 77 | 78 | if i % (tot//2) == 0 and i != 0: 79 | losses = np.mean(losses) 80 | acc, worst_acc = get_model_accuracy(model, val_loader, grouper, device=device, dataset=dataset) 81 | 82 | print("Iter: {} || Loss: {:.4f} || Acc:{:.4f} ||Worst Acc:{:.4f} ".format(i, 83 | losses, acc, worst_acc)) 84 | losses = [] 85 | 86 | if worst_acc > acc_best and save: 87 | print("Saving model ...") 88 | save_model(model, model_name+"_exp", 0, test_way=test_way) 89 | acc_best = worst_acc 90 | 91 | i += 1 92 | scheduler.step() 93 | 94 | def train_exp(models_list, domain_specific_indices, device, batch_size=64, lr=1e-4, l2=0, 95 | num_epochs=50, decayRate=0.96, save=False, test_way='ood', name="Resnet50_experts", 96 | root_dir='data'): 97 | 98 | assert len(models_list) == len(domain_specific_indices) 99 | for i in range(len(models_list)): 100 | print(f"Training model {i} for domain", *domain_specific_indices[i]) 101 | train_model(models_list[i], name+'_'+str(i), device=device, 102 | domain=domain_specific_indices[i], batch_size=batch_size, 103 | lr=lr, l2=l2, 
num_epochs=num_epochs, 104 | decayRate=decayRate, save=save, test_way=test_way, 105 | root_dir=root_dir) 106 | 107 | def get_expert_split(num_experts, root_dir='data'): 108 | all_split = split_domains(num_experts=num_experts, root_dir=root_dir) 109 | split_to_cluster = {d:i for i in range(len(all_split)) for d in all_split[i]} 110 | return all_split, split_to_cluster 111 | -------------------------------------------------------------------------------- /src/iwildcam/iwildcam_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | import math 9 | import random 10 | import pickle5 as pickle 11 | from src.transformer import Transformer 12 | import copy 13 | import time 14 | import os 15 | import wilds 16 | from wilds import get_dataset 17 | from wilds.common.data_loaders import get_train_loader, get_eval_loader 18 | from wilds.common.grouper import CombinatorialGrouper 19 | 20 | # Utils 21 | 22 | def split_domains(num_experts, root_dir='data'): 23 | dataset = get_dataset(dataset='iwildcam', download=True, root_dir=root_dir) 24 | train_data = dataset.get_subset('train') 25 | locations = list(set(train_data.metadata_array[:,0].detach().numpy().tolist())) 26 | random.shuffle(locations) 27 | num_domains_per_super = len(locations) / float(num_experts) 28 | all_split = [[] for _ in range(num_experts)] 29 | for i in range(len(locations)): 30 | all_split[int(i//num_domains_per_super)].append(locations[i]) 31 | return all_split 32 | 33 | def get_subset_with_domain(dataset, split=None, domain=None, transform=None): 34 | if type(dataset) == wilds.datasets.wilds_dataset.WILDSSubset: 35 | subset = copy.deepcopy(dataset) 36 | else: 37 | subset = dataset.get_subset(split, transform=transform) 38 | if domain is not None: 39 | idx = np.argwhere(np.isin(subset.dataset.metadata_array[:,0][subset.indices], domain)).ravel() 40 | subset.indices = subset.indices[idx] 41 | return subset 42 | 43 | def initialize_image_base_transform(dataset): 44 | transform_steps = [] 45 | if dataset.original_resolution is not None and min(dataset.original_resolution)!=max(dataset.original_resolution): 46 | crop_size = min(dataset.original_resolution) 47 | transform_steps.append(transforms.CenterCrop(crop_size)) 48 | transform_steps.append(transforms.Resize((448,448))) 49 | transform_steps += [ 50 | transforms.ToTensor(), 51 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 52 | ] 53 | transform = transforms.Compose(transform_steps) 54 | return transform 55 | 56 | def get_data_loader(batch_size=16, domain=None, test_way='id', 57 | n_groups_per_batch=0, uniform_over_groups=False, root_dir='data'): 58 | dataset = get_dataset(dataset='iwildcam', download=True, root_dir=root_dir) 59 | grouper = CombinatorialGrouper(dataset, ['location']) 60 | 61 | transform = initialize_image_base_transform(dataset) 62 | 63 | train_data = get_subset_with_domain(dataset, 'train', domain=domain, transform=transform) 64 | if test_way == 'ood': 65 | val_data = dataset.get_subset('val', transform=transform) 66 | test_data = dataset.get_subset('test', transform=transform) 67 | else: 68 | val_data = get_subset_with_domain(dataset, 'id_val', domain=domain, transform=transform) 69 | test_data = get_subset_with_domain(dataset, 'id_test', domain=domain, transform=transform) 70 | 71 | 72 | if 
n_groups_per_batch == 0: 73 | #0 identify standard loader 74 | train_loader = get_train_loader('standard', train_data, batch_size=batch_size) 75 | val_loader = get_train_loader('standard', val_data, batch_size=batch_size) 76 | test_loader = get_train_loader('standard', test_data, batch_size=batch_size) 77 | 78 | else: 79 | # All use get_train_loader to enable grouper 80 | train_loader = get_train_loader('group', train_data, grouper=grouper, 81 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size) 82 | val_loader = get_train_loader('group', val_data, grouper=grouper, 83 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size, 84 | uniform_over_groups=uniform_over_groups) 85 | test_loader = get_train_loader('group', test_data, grouper=grouper, 86 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size, 87 | uniform_over_groups=uniform_over_groups) 88 | 89 | return train_loader, val_loader, test_loader, grouper 90 | 91 | def get_test_loader(batch_size=16, split='val', root_dir='data'): 92 | dataset = get_dataset(dataset='iwildcam', download=True, root_dir=root_dir) 93 | grouper = CombinatorialGrouper(dataset, ['location']) 94 | 95 | transform = initialize_image_base_transform(dataset) 96 | 97 | eval_data = get_subset_with_domain(dataset, split, domain=None, transform=transform) 98 | all_domains = list(set(grouper.metadata_to_group(eval_data.dataset.metadata_array[eval_data.indices]).tolist())) 99 | test_loader = [] 100 | 101 | for domain in all_domains: 102 | domain_data = get_subset_with_domain(eval_data, split, domain=domain, transform=transform) 103 | domain_loader = get_eval_loader('standard', domain_data, batch_size=batch_size) 104 | test_loader.append((domain, domain_loader)) 105 | 106 | return test_loader, grouper 107 | 108 | def get_mask_grouper(root_dir='data'): 109 | dataset = get_dataset(dataset='iwildcam', download=True, root_dir=root_dir) 110 | grouper = CombinatorialGrouper(dataset, ['location']) 111 | return grouper 112 | 113 | def save_model(model, name, epoch, test_way='ood'): 114 | if not os.path.exists("model/iwildcam"): 115 | os.makedirs("model/iwildcam") 116 | path = "model/iwildcam/{0}_best.pth".format(name) 117 | torch.save(model.state_dict(), path) 118 | 119 | def mask_feat(feat, mask_index, num_experts, exclude=True): 120 | assert feat.shape[0] == mask_index.shape[0] 121 | if exclude: 122 | new_idx = [list(range(0, int(m.item()))) + list(range(int(m.item())+1, num_experts)) for m in mask_index] 123 | return feat[torch.arange(feat.shape[0]).unsqueeze(-1), :, new_idx] 124 | else: 125 | feat[list(range(feat.shape[0])), :, mask_index] = torch.zeros_like(feat[list(range(feat.shape[0])), :, mask_index]) 126 | feat = feat.permute((0,2,1)) 127 | return feat 128 | 129 | def l2_loss(input, target): 130 | loss = torch.square(target - input) 131 | loss = torch.mean(loss) 132 | return loss 133 | 134 | # Models 135 | class ResNetFeature(nn.Module): 136 | def __init__(self, original_model, layer=-1): 137 | super(ResNetFeature, self).__init__() 138 | self.features = nn.Sequential(*list(original_model.children())[:layer]) 139 | 140 | def forward(self, x): 141 | x = self.features(x) 142 | x = x.view(x.shape[0], -1) 143 | return x 144 | 145 | class fa_selector(nn.Module): 146 | def __init__(self, dim, depth, heads, mlp_dim, dropout = 0., out_dim=182, pool='mean'): 147 | super(fa_selector, self).__init__() 148 | self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout=dropout) 149 | self.pool = pool 150 | self.mlp = nn.Sequential( 151 | 
nn.LayerNorm(dim), 152 | nn.Linear(dim, out_dim) 153 | ) 154 | 155 | def forward(self, x): 156 | x = self.transformer(x) 157 | if self.pool == 'mean': 158 | x = x.mean(dim=1) 159 | elif self.pool == 'max': 160 | x = x.max(dim=1)[0] # Tensor.max(dim=...) returns (values, indices); keep the values 161 | else: 162 | raise NotImplementedError 163 | x = self.mlp(x) 164 | return x 165 | 166 | def get_feat(self, x): 167 | x = self.transformer(x) 168 | if self.pool == 'mean': 169 | x = x.mean(dim=1) 170 | elif self.pool == 'max': 171 | x = x.max(dim=1)[0] 172 | else: 173 | raise NotImplementedError 174 | return x 175 | 176 | class DivideModel(nn.Module): 177 | def __init__(self, original_model, layer=-1): 178 | super(DivideModel, self).__init__() 179 | self.num_ftrs = original_model.fc.in_features 180 | self.num_class = original_model.fc.out_features 181 | self.features = nn.Sequential(*list(original_model.children())[:layer]) 182 | self.classifier = nn.Sequential(*list(original_model.children())[layer:]) 183 | 184 | def forward(self, x): 185 | x = self.features(x) 186 | x = x.view(-1, self.num_ftrs) 187 | x = self.classifier(x) 188 | x = x.view(-1, self.num_class) 189 | return x 190 | 191 | def StudentModel(device, num_classes=182, load_path=None): 192 | model = torchvision.models.resnet50(pretrained=True) 193 | num_ftrs = model.fc.in_features 194 | model.fc = nn.Linear(num_ftrs, num_classes) 195 | if load_path: 196 | model.load_state_dict(torch.load(load_path)) 197 | model = DivideModel(model) 198 | model = model.to(device) 199 | return model 200 | 201 | def get_feature_list(models_list, device): 202 | feature_list = [] 203 | for model in models_list: 204 | feature_list.append(ResNetFeature(model).to(device)) 205 | return feature_list 206 | 207 | def get_models_list(device, num_domains=9, num_classes=182, pretrained=False, bb='res50'): 208 | models_list = [] 209 | for _ in range(num_domains+1): 210 | model = torchvision.models.resnet50(pretrained=True) 211 | num_ftrs = model.fc.in_features 212 | model.fc = nn.Linear(num_ftrs, num_classes) 213 | model = model.to(device) 214 | models_list.append(model) 215 | return models_list 216 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import random 5 | import argparse 6 | import pickle 7 | from src.configs import default_param 8 | 9 | OUT_DIM = {'iwildcam':182, 10 | 'camelyon':1, 11 | 'rxrx1':1139, 12 | 'fmow':62, 13 | 'poverty':1} 14 | 15 | FEAT_DIM = {'iwildcam':2048, 16 | 'camelyon':1024, 17 | 'rxrx1':2048, 18 | 'fmow':1024, 19 | 'poverty':512} 20 | 21 | def get_parser(): 22 | parser = argparse.ArgumentParser() 23 | 24 | parser.add_argument('--gpu', type=str, default='0') # We only support single gpu training for now 25 | parser.add_argument('--threads', type=int, default=12) 26 | 27 | parser.add_argument('--dataset', type=str, default='iwildcam', 28 | choices=['iwildcam', 'fmow', 'camelyon', 'rxrx1', 'poverty']) 29 | parser.add_argument('--data_dir', type=str, default='data') 30 | 31 | parser.add_argument('--load_trained_experts', action='store_true') 32 | parser.add_argument('--load_pretrained_aggregator', action='store_true') 33 | parser.add_argument('--load_pretrained_student', action='store_true') 34 | 35 | parser.add_argument('--test', action='store_true') 36 | 37 | args = parser.parse_args() 38 | args_dict = args.__dict__ 39 | args_dict.update(default_param[args.dataset]) 40 | args = argparse.Namespace(**args_dict) 41 | return args 42 
| 43 | def set_seed(seed): 44 | if torch.cuda.is_available(): 45 | torch.cuda.manual_seed(seed) 46 | torch.manual_seed(seed) 47 | np.random.seed(seed) 48 | random.seed(seed) 49 | torch.backends.cudnn.deterministic = True 50 | 51 | def train(args): 52 | 53 | if args.dataset == 'iwildcam': 54 | from src.iwildcam.iwildcam_utils import fa_selector, StudentModel, get_models_list, get_feature_list 55 | from src.iwildcam.iwildcam_experts import train_exp, train_model, get_expert_split 56 | from src.iwildcam.iwildcam_aggregator import train_model_selector 57 | from src.iwildcam.iwildcam_train import train_kd, eval 58 | elif args.dataset == 'camelyon': 59 | from src.camelyon.camelyon_utils import fa_selector, StudentModel, get_models_list, get_feature_list 60 | from src.camelyon.camelyon_experts import train_exp, train_model, get_expert_split 61 | from src.camelyon.camelyon_aggregator import train_model_selector 62 | from src.camelyon.camelyon_train import train_kd, eval 63 | elif args.dataset == 'rxrx1': 64 | from src.rxrx1.rxrx1_utils import fa_selector, StudentModel, get_models_list, get_feature_list 65 | from src.rxrx1.rxrx1_experts import train_exp, train_model, get_expert_split 66 | from src.rxrx1.rxrx1_aggregator import train_model_selector 67 | from src.rxrx1.rxrx1_train import train_kd, eval 68 | elif args.dataset == 'fmow': 69 | from src.fmow.fmow_utils import fa_selector, StudentModel, get_models_list, get_feature_list 70 | from src.fmow.fmow_experts import train_exp, train_model, get_expert_split 71 | from src.fmow.fmow_aggregator import train_model_selector 72 | from src.fmow.fmow_train import train_kd, eval 73 | else: 74 | raise NotImplementedError 75 | 76 | name = f"{args.dataset}_{str(args.num_experts)}experts_seed{str(args.seed)}" 77 | 78 | models_list = get_models_list(device=device, num_domains=args.num_experts-1) 79 | 80 | try: 81 | with open(f"model/{args.dataset}/domain_split.pkl", "rb") as f: 82 | all_split, split_to_cluster = pickle.load(f) 83 | except FileNotFoundError: 84 | all_split, split_to_cluster = get_expert_split(args.num_experts, root_dir=args.data_dir) 85 | with open(f"model/{args.dataset}/domain_split.pkl", "wb") as f: 86 | pickle.dump((all_split, split_to_cluster), f) 87 | 88 | if args.load_trained_experts: 89 | print("Skip training domain specific experts...") 90 | else: 91 | print("Training domain specific experts...") 92 | train_exp(models_list, all_split, device, batch_size=args.expert_batch_size, 93 | lr=args.expert_lr, l2=args.expert_l2, num_epochs=args.expert_epoch, 94 | save=True, name=name, root_dir=args.data_dir) 95 | 96 | for i,model in enumerate(models_list): 97 | model.load_state_dict(torch.load(f"model/{args.dataset}/{name}_{str(i)}_exp_best.pth")) 98 | models_list = get_feature_list(models_list, device=device) 99 | 100 | selector = fa_selector(dim=FEAT_DIM[args.dataset], depth=args.aggregator_depth, heads=args.aggregator_heads, 101 | mlp_dim=FEAT_DIM[args.dataset]*2, dropout=args.aggregator_dropout, 102 | out_dim=OUT_DIM[args.dataset]).to(device) 103 | if args.load_pretrained_aggregator: 104 | print("Skip pretraining knowledge aggregator...") 105 | else: 106 | print("Pretraining knowledge aggregator...") 107 | train_model_selector(selector, name+'_pretrained', models_list, device, root_dir=args.data_dir, 108 | num_epochs=args.aggregator_pretrain_epoch, save=True) 109 | 110 | selector.load_state_dict(torch.load(f"model/{args.dataset}/{name}_pretrained_selector_best.pth")) 111 | 112 | student = StudentModel(device=device, 
num_classes=OUT_DIM[args.dataset]) 113 | 114 | if args.load_pretrained_student: 115 | print("Skip pretraining student...") 116 | else: 117 | print("Pretraining student...") 118 | train_model(student, name+"_pretrained", device=device, 119 | num_epochs=args.student_pretrain_epoch, save=True, 120 | root_dir=args.data_dir) 121 | 122 | student.load_state_dict(torch.load(f"model/{args.dataset}/{name}_pretrained_exp_best.pth")) 123 | 124 | print("Start meta-training...") 125 | train_kd(selector, name+"_meta", models_list, student, name+"_meta", split_to_cluster, 126 | device=device, batch_size=args.batch_size, sup_size=args.sup_size, 127 | tlr=args.tlr, slr=args.slr, ilr=args.ilr, num_epochs=args.epoch, save=True, test_way='ood', 128 | root_dir=args.data_dir) 129 | 130 | def test(args): 131 | 132 | if args.dataset == 'iwildcam': 133 | from src.iwildcam.iwildcam_utils import fa_selector, StudentModel, get_models_list, get_feature_list 134 | from src.iwildcam.iwildcam_train import eval 135 | elif args.dataset == 'camelyon': 136 | from src.camelyon.camelyon_utils import fa_selector, StudentModel, get_models_list, get_feature_list 137 | from src.camelyon.camelyon_train import eval 138 | elif args.dataset == 'rxrx1': 139 | from src.rxrx1.rxrx1_utils import fa_selector, StudentModel, get_models_list, get_feature_list 140 | from src.rxrx1.rxrx1_train import eval 141 | elif args.dataset == 'fmow': 142 | from src.fmow.fmow_utils import fa_selector, StudentModel, get_models_list, get_feature_list 143 | from src.fmow.fmow_train import eval 144 | 145 | name = f"{args.dataset}_{str(args.num_experts)}experts_seed{str(args.seed)}" 146 | models_list = get_models_list(device=device, num_domains=args.num_experts-1) 147 | for i,model in enumerate(models_list): 148 | model.load_state_dict(torch.load(f"model/{args.dataset}/{name}_{str(i)}_exp_best.pth")) 149 | models_list = get_feature_list(models_list, device=device) 150 | selector = fa_selector(dim=FEAT_DIM[args.dataset], depth=args.aggregator_depth, heads=args.aggregator_heads, 151 | mlp_dim=FEAT_DIM[args.dataset]*2, dropout=args.aggregator_dropout, 152 | out_dim=OUT_DIM[args.dataset]).to(device) 153 | selector.load_state_dict(torch.load(f"model/{args.dataset}/{name}_meta_selector_best.pth")) 154 | 155 | student = StudentModel(device=device, num_classes=OUT_DIM[args.dataset]).to(device) 156 | student.load_state_dict(torch.load(f"model/{args.dataset}/{name}_meta_student_best.pth")) 157 | metrics = eval(selector, models_list, student, batch_size=args.sup_size, 158 | device=device, ilr=args.ilr, test=True, root_dir=args.data_dir) 159 | if args.dataset == 'iwildcam': 160 | print(f"Test Accuracy:{metrics[1]:.4f} Test Macro-F1:{metrics[2]:.4f}") 161 | with open(f'result/{args.dataset}/result.txt', 'a+') as f: 162 | f.write(f"Seed: {args.seed} || Test Accuracy:{metrics[1]:.4f} || Test Macro-F1:{metrics[2]:.4f}") 163 | elif args.dataset in ['camelyon', 'rxrx1']: 164 | print(f"Test Accuracy:{metrics:.4f}") 165 | with open(f'result/{args.dataset}/result.txt', 'a+') as f: 166 | f.write(f"Seed: {args.seed} || Test Accuracy:{metrics:.4f}") 167 | elif args.dataset == 'fmow': 168 | print(f"WC Accuracy:{metrics[0]:.4f} Acc:{metrics[1]:.4f}") 169 | with open(f'result/{args.dataset}/result.txt', 'a+') as f: 170 | f.write(f"Seed: {args.seed} || WC Accuracy:{metrics[0]:.4f} || Acc:{metrics[1]:.4f}") 171 | 172 | if __name__ == "__main__": 173 | args = get_parser() 174 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu 175 | torch.set_num_threads(args.threads) 176 | device = 
torch.device("cuda:0" if torch.cuda.is_available() else "cpu") 177 | set_seed(args.seed) 178 | if not os.path.exists(f"model/{args.dataset}"): 179 | os.makedirs(f"model/{args.dataset}") 180 | if not os.path.exists(f"log/{args.dataset}"): 181 | os.makedirs(f"log/{args.dataset}") 182 | if args.test: 183 | if not os.path.exists(f"result/{args.dataset}"): 184 | os.makedirs(f"result/{args.dataset}") 185 | test(args) 186 | else: 187 | train(args) 188 | -------------------------------------------------------------------------------- /src/rxrx1/rxrx1_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torchvision.transforms.functional as TF 6 | import torch.optim as optim 7 | import torchvision 8 | import torchvision.transforms as transforms 9 | import math 10 | import random 11 | import pickle5 as pickle 12 | from src.transformer import Transformer 13 | import copy 14 | import time 15 | import os 16 | import wilds 17 | from wilds import get_dataset 18 | from wilds.common.data_loaders import get_train_loader, get_eval_loader 19 | from wilds.common.grouper import CombinatorialGrouper 20 | 21 | # Utils 22 | 23 | def split_domains(num_experts, root_dir='data'): 24 | dataset = get_dataset(dataset='rxrx1', download=True, root_dir=root_dir) 25 | train_data = dataset.get_subset('train') 26 | experiment = list(set(train_data.metadata_array[:,1].detach().numpy().tolist())) 27 | random.shuffle(experiment) 28 | num_domains_per_super = len(experiment) / float(num_experts) 29 | all_split = [[] for _ in range(num_experts)] 30 | for i in range(len(experiment)): 31 | all_split[int(i//num_domains_per_super)].append(experiment[i]) 32 | return all_split 33 | 34 | def get_subset_with_domain(dataset, split, domain=None, transform=None): 35 | if type(dataset) == wilds.datasets.wilds_dataset.WILDSSubset: 36 | subset = copy.deepcopy(dataset) 37 | else: 38 | subset = dataset.get_subset(split, transform=transform) 39 | if domain is not None: 40 | idx = np.argwhere(np.isin(subset.dataset.metadata_array[:,1][subset.indices], domain)).ravel() 41 | subset.indices = subset.indices[idx] 42 | return subset 43 | 44 | def initialize_image_base_transform(is_training): 45 | def standardize(x: torch.Tensor) -> torch.Tensor: 46 | mean = x.mean(dim=(1, 2)) 47 | std = x.std(dim=(1, 2)) 48 | std[std == 0.] = 1. 
49 | return TF.normalize(x, mean, std) 50 | t_standardize = transforms.Lambda(lambda x: standardize(x)) 51 | 52 | angles = [0, 90, 180, 270] 53 | def random_rotation(x: torch.Tensor) -> torch.Tensor: 54 | angle = angles[torch.randint(low=0, high=len(angles), size=(1,))] 55 | if angle > 0: 56 | x = TF.rotate(x, angle) 57 | return x 58 | t_random_rotation = transforms.Lambda(lambda x: random_rotation(x)) 59 | 60 | if is_training: 61 | transforms_ls = [ 62 | t_random_rotation, 63 | transforms.RandomHorizontalFlip(), 64 | transforms.ToTensor(), 65 | t_standardize, 66 | ] 67 | else: 68 | transforms_ls = [ 69 | transforms.ToTensor(), 70 | t_standardize, 71 | ] 72 | transform = transforms.Compose(transforms_ls) 73 | return transform 74 | 75 | def get_data_loader(batch_size=16, domain=None, test_way='id', 76 | n_groups_per_batch=0, uniform_over_groups=False, root_dir='data'): 77 | dataset = get_dataset(dataset='rxrx1', download=True, root_dir=root_dir) 78 | grouper = CombinatorialGrouper(dataset, ['experiment']) 79 | 80 | transform_train = initialize_image_base_transform(True) 81 | transform_test = initialize_image_base_transform(False) 82 | 83 | train_data = get_subset_with_domain(dataset, 'train', domain=domain, transform=transform_train) 84 | if test_way == 'ood': 85 | val_data = dataset.get_subset('val', transform=transform_test) 86 | test_data = dataset.get_subset('test', transform=transform_test) 87 | else: 88 | val_data = get_subset_with_domain(dataset, 'id_val', domain=domain, transform=transform_test) 89 | test_data = get_subset_with_domain(dataset, 'id_test', domain=domain, transform=transform_test) 90 | 91 | 92 | if n_groups_per_batch == 0: 93 | #0 identify standard loader 94 | train_loader = get_train_loader('standard', train_data, batch_size=batch_size) 95 | val_loader = get_train_loader('standard', val_data, batch_size=batch_size) 96 | test_loader = get_train_loader('standard', test_data, batch_size=batch_size) 97 | 98 | else: 99 | # All use get_train_loader to enable grouper 100 | train_loader = get_train_loader('group', train_data, grouper=grouper, 101 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size) 102 | val_loader = get_train_loader('group', val_data, grouper=grouper, 103 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size, 104 | uniform_over_groups=uniform_over_groups) 105 | test_loader = get_train_loader('group', test_data, grouper=grouper, 106 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size, 107 | uniform_over_groups=uniform_over_groups) 108 | 109 | return train_loader, val_loader, test_loader, grouper 110 | 111 | def get_mask_grouper(root_dir='data'): 112 | dataset = get_dataset(dataset='rxrx1', download=True, root_dir=root_dir) 113 | grouper = CombinatorialGrouper(dataset, ['experiment']) 114 | return grouper 115 | 116 | def get_test_loader(batch_size=16, split='val', root_dir='data'): 117 | dataset = get_dataset(dataset='rxrx1', download=True, root_dir=root_dir) 118 | grouper = CombinatorialGrouper(dataset, ['experiment']) 119 | 120 | transform = initialize_image_base_transform(False) # bugfix: rxrx1's transform expects an is_training flag, not the dataset; use eval-time transforms here 121 | 122 | eval_data = get_subset_with_domain(dataset, split, domain=None, transform=transform) 123 | all_domains = list(set(grouper.metadata_to_group(eval_data.dataset.metadata_array[eval_data.indices]).tolist())) 124 | test_loader = [] 125 | 126 | for domain in all_domains: 127 | domain_data = get_subset_with_domain(eval_data, split, domain=domain, transform=transform) 128 | domain_loader = get_eval_loader('standard', domain_data, 
batch_size=batch_size) 129 | test_loader.append((domain, domain_loader)) 130 | 131 | return test_loader, grouper 132 | 133 | def save_model(model, name, epoch, test_way='ood'): 134 | if not os.path.exists("model/rxrx1"): 135 | os.makedirs("model/rxrx1") 136 | path = "model/rxrx1/{0}_best.pth".format(name) 137 | torch.save(model.state_dict(), path) 138 | 139 | def mask_feat(feat, mask_index, num_experts, exclude=True): 140 | assert feat.shape[0] == mask_index.shape[0] 141 | if exclude: 142 | new_idx = [list(range(0, int(m.item()))) + list(range(int(m.item())+1, num_experts)) for m in mask_index] 143 | return feat[torch.arange(feat.shape[0]).unsqueeze(-1), :, new_idx] 144 | else: 145 | feat[list(range(feat.shape[0])), :, mask_index] = torch.zeros_like(feat[list(range(feat.shape[0])), :, mask_index]) 146 | feat = feat.permute((0,2,1)) 147 | return feat 148 | 149 | def l2_loss(input, target): 150 | loss = torch.square(target - input) 151 | loss = torch.mean(loss) 152 | return loss 153 | 154 | # Models 155 | class ResNetFeature(nn.Module): 156 | def __init__(self, original_model, layer=-1): 157 | super(ResNetFeature, self).__init__() 158 | self.features = nn.Sequential(*list(original_model.children())[:layer]) 159 | 160 | def forward(self, x): 161 | x = self.features(x) 162 | x = x.view(x.shape[0], -1) 163 | return x 164 | 165 | class fa_selector(nn.Module): 166 | def __init__(self, dim, depth, heads, mlp_dim, dropout = 0., out_dim=1139, pool='mean'): 167 | super(fa_selector, self).__init__() 168 | self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout=dropout) 169 | self.pool = pool 170 | self.mlp = nn.Sequential( 171 | nn.LayerNorm(dim), 172 | nn.Linear(dim, out_dim) 173 | ) 174 | 175 | def forward(self, x): 176 | x = self.transformer(x) 177 | if self.pool == 'mean': 178 | x = x.mean(dim=1) 179 | elif self.pool == 'max': 180 | x = x.max(dim=1)[0] # bugfix: Tensor.max(dim=...) returns (values, indices) 181 | else: 182 | raise NotImplementedError 183 | x = self.mlp(x) 184 | return x 185 | 186 | def get_feat(self, x): 187 | x = self.transformer(x) 188 | if self.pool == 'mean': 189 | x = x.mean(dim=1) 190 | elif self.pool == 'max': 191 | x = x.max(dim=1)[0] 192 | else: 193 | raise NotImplementedError 194 | return x 195 | 196 | class DivideModel(nn.Module): 197 | def __init__(self, original_model, layer=-1): 198 | super(DivideModel, self).__init__() 199 | self.num_ftrs = original_model.fc.in_features 200 | self.num_class = original_model.fc.out_features 201 | self.features = nn.Sequential(*list(original_model.children())[:layer]) 202 | self.classifier = nn.Sequential(*list(original_model.children())[layer:]) 203 | 204 | def forward(self, x): 205 | x = self.features(x) 206 | x = x.view(-1, self.num_ftrs) 207 | x = self.classifier(x) 208 | x = x.view(-1, self.num_class) 209 | return x 210 | 211 | def StudentModel(device, num_classes=1139, load_path=None): 212 | model = torchvision.models.resnet50(pretrained=True) 213 | num_ftrs = model.fc.in_features 214 | model.fc = nn.Linear(num_ftrs, num_classes) 215 | if load_path: 216 | model.load_state_dict(torch.load(load_path)) 217 | model = DivideModel(model) 218 | model = model.to(device) 219 | return model 220 | 221 | def get_feature_list(models_list, device): 222 | feature_list = [] 223 | for model in models_list: 224 | feature_list.append(ResNetFeature(model).to(device)) 225 | return feature_list 226 | 227 | def get_models_list(device, num_domains=2, num_classes=1139, pretrained=False, bb='res50'): 228 | models_list = [] 229 | for _ in range(num_domains+1): 230 | model = 
torchvision.models.resnet50(pretrained=True) 231 | num_ftrs = model.fc.in_features 232 | model.fc = nn.Linear(num_ftrs, num_classes) 233 | model = model.to(device) 234 | models_list.append(model) 235 | return models_list 236 | -------------------------------------------------------------------------------- /src/camelyon/camelyon_train.py: -------------------------------------------------------------------------------- 1 | from scipy.stats import truncnorm 2 | import numpy as np 3 | import time 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | import torch.optim as optim 8 | import torchvision 9 | from tqdm import tqdm 10 | import copy 11 | import learn2learn as l2l 12 | import time 13 | from datetime import datetime 14 | import os 15 | import wilds 16 | from src.camelyon.camelyon_utils import * 17 | 18 | def train_epoch(selector, selector_name, models_list, student, student_name, 19 | train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster, 20 | device, acc_best=0, tlr=1e-4, slr=1e-4, ilr=1e-3, 21 | batch_size=256, sup_size=24, test_way='id', save=False, 22 | root_dir='data'): 23 | for model in models_list: 24 | model.eval() 25 | 26 | student_ce = nn.BCEWithLogitsLoss() 27 | #teacher_ce = nn.CrossEntropyLoss() 28 | 29 | features = student.features 30 | head = student.classifier 31 | features = l2l.algorithms.MAML(features, lr=ilr) 32 | features.to(device) 33 | head.to(device) 34 | 35 | all_params = list(features.parameters()) + list(head.parameters()) 36 | optimizer_s = optim.Adam(all_params, lr=slr) 37 | optimizer_t = optim.Adam(selector.parameters(), lr=tlr) 38 | 39 | i = 0 40 | 41 | losses = [] 42 | 43 | iter_per_epoch = len(train_loader) 44 | 45 | for x, y_true, metadata in train_loader: 46 | selector.eval() 47 | head.eval() 48 | features.eval() 49 | 50 | z = grouper.metadata_to_group(metadata) 51 | z = set(z.tolist()) 52 | assert len(z) == 1 53 | mask = mask_grouper.metadata_to_group(metadata) 54 | mask.apply_(lambda x: split_to_cluster[x]) 55 | 56 | #sup_size = x.shape[0]//2 57 | x_sup = x[:sup_size] 58 | y_sup = y_true[:sup_size] 59 | x_que = x[sup_size:] 60 | y_que = y_true[sup_size:] 61 | mask = mask[:sup_size] 62 | 63 | x_sup = x_sup.to(device) 64 | y_sup = y_sup.to(device) 65 | x_que = x_que.to(device) 66 | y_que = y_que.to(device) 67 | 68 | with torch.no_grad(): 69 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1) 70 | #logits[:, :, split_to_cluster[z]] = torch.zeros_like(logits[:, :, split_to_cluster[z]]) 71 | #logits = logits.permute((0,2,1)) 72 | logits = mask_feat(logits, mask, len(models_list), exclude=True) 73 | 74 | t_out = selector.get_feat(logits) 75 | 76 | task_model = features.clone() 77 | task_model.module.eval() 78 | feat = task_model(x_que) 79 | feat = feat.view(feat.shape[0], -1) 80 | out = head(feat) 81 | with torch.no_grad(): 82 | loss_pre = student_ce(out, y_que.unsqueeze(-1).float()).item()/x_que.shape[0] 83 | 84 | feat = task_model(x_sup) 85 | feat = feat.view_as(t_out) 86 | 87 | inner_loss = l2_loss(feat, t_out) 88 | task_model.adapt(inner_loss) 89 | 90 | x_que = task_model(x_que) 91 | x_que = x_que.view(x_que.shape[0], -1) 92 | s_que_out = head(x_que) 93 | s_que_loss = student_ce(s_que_out, y_que.unsqueeze(-1).float()) 94 | #t_sup_loss = teacher_ce(t_out, y_sup) 95 | 96 | s_que_loss.backward() 97 | 98 | optimizer_s.step() 99 | optimizer_t.step() 100 | optimizer_s.zero_grad() 101 | optimizer_t.zero_grad() 102 | 103 | #print("Step:{}".format(time.time() - t_1)) 104 | #t_1 = 
time.time() 105 | with open('log/camelyon/'+curr+'.txt', 'a+') as f: 106 | f.write('Iter: {}/{}, Loss Before:{:.4f}, Loss After:{:.4f}\r\n'.format(i,iter_per_epoch, 107 | loss_pre, 108 | s_que_loss.item()/x_que.shape[0])) 109 | losses.append(s_que_loss.item()/x_que.shape[0]) 110 | 111 | if i == iter_per_epoch//2: 112 | losses = np.mean(losses) 113 | acc = eval(selector, models_list, student, sup_size, device=device, 114 | ilr=ilr, test=False, progress=False, uniform_over_groups=False, 115 | root_dir=root_dir) 116 | 117 | with open('log/camelyon/'+curr+'.txt', 'a+') as f: 118 | f.write(f'Accuracy: {acc} \r\n') 119 | losses = [] 120 | 121 | if acc > acc_best and save: 122 | with open('log/camelyon/'+curr+'.txt', 'a+') as f: 123 | f.write("Saving model ...\r\n") 124 | save_model(selector, selector_name+"_selector", 0, test_way=test_way) 125 | save_model(student, student_name+"_student", 0, test_way=test_way) 126 | acc_best = acc 127 | 128 | i += 1 129 | return acc_best 130 | 131 | def train_kd(selector, selector_name, models_list, student, student_name, split_to_cluster, device, 132 | batch_size=256, sup_size=24, tlr=1e-4, slr=1e-4, ilr=1e-5, num_epochs=30, 133 | decayRate=0.96, save=False, test_way='ood', root_dir='data'): 134 | 135 | train_loader, _, _, grouper = get_data_loader(root_dir=root_dir, batch_size=batch_size, domain=None, 136 | test_way=test_way, n_groups_per_batch=1) 137 | 138 | mask_grouper = get_mask_grouper(root_dir=root_dir) 139 | 140 | curr = str(datetime.now()) 141 | if not os.path.exists("log/camelyon"): 142 | os.makedirs("log/camelyon") 143 | #print(curr) 144 | print("Training log saved to log/camelyon/"+curr+".txt") 145 | 146 | with open('log/camelyon/'+curr+'.txt', 'a+') as f: 147 | f.write(selector_name+' '+student_name+'\r\n') 148 | f.write(f'tlr={tlr} slr={slr} ilr={ilr}\r\n') 149 | 150 | accu_best = 0 151 | for epoch in range(num_epochs): 152 | with open('log/camelyon/'+curr+'.txt', 'a+') as f: 153 | f.write(f'Epoch: {epoch}\r\n') 154 | accu_epoch = train_epoch(selector, selector_name, models_list, student, student_name, 155 | train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster, 156 | device, acc_best=accu_best, tlr=tlr, slr=slr, ilr=ilr, 157 | batch_size=batch_size, sup_size=sup_size, test_way=test_way, save=save, 158 | root_dir=root_dir) 159 | accu_best = max(accu_best, accu_epoch) 160 | accu = eval(selector, models_list, student, sup_size, device=device, 161 | ilr=ilr, test=False, progress=False, uniform_over_groups=False, 162 | root_dir=root_dir) 163 | 164 | with open('log/camelyon/'+curr+'.txt', 'a+') as f: 165 | f.write(f'Accuracy: {accu} \r\n') 166 | 167 | if accu > accu_best and save: 168 | with open('log/camelyon/'+curr+'.txt', 'a+') as f: 169 | f.write("Saving model ...\r\n") 170 | save_model(selector, selector_name+"_selector", 0, test_way=test_way) 171 | save_model(student, student_name+"_student", 0, test_way=test_way) 172 | accu_best = accu 173 | tlr = tlr*decayRate 174 | slr = slr*decayRate 175 | 176 | def eval(selector, models_list, student, batch_size, device, ilr=1e-5, 177 | test=False, progress=True, uniform_over_groups=False, root_dir='data'): 178 | 179 | if test: 180 | loader, grouper, dataset = get_test_loader(batch_size=batch_size, split='test', root_dir=root_dir, 181 | return_dataset=True) 182 | else: 183 | loader, grouper, dataset = get_test_loader(batch_size=batch_size, split='val', root_dir=root_dir, 184 | return_dataset=True) 185 | '''if test: 186 | loader, que_set, grouper = get_test_loader(batch_size=batch_size, 
test_way='test') 187 | else: 188 | loader, que_set, grouper = get_test_loader(batch_size=batch_size, test_way='val')''' 189 | 190 | features = student.features 191 | head = student.classifier 192 | head.to(device) 193 | 194 | student_maml = l2l.algorithms.MAML(features, lr=ilr) 195 | student_maml.to(device) 196 | 197 | correct = 0 198 | total = 0 199 | 200 | old_domain = {} 201 | if progress: 202 | loader = tqdm(iter(loader), total=len(loader)) 203 | 204 | for domain, domain_loader in loader: 205 | adapted = False 206 | for x_sup, y_sup, metadata in domain_loader: 207 | student_maml.module.eval() 208 | selector.eval() 209 | head.eval() 210 | 211 | z = grouper.metadata_to_group(metadata) 212 | z = set(z.tolist()) 213 | assert list(z)[0] == domain 214 | 215 | x_sup = x_sup.to(device) 216 | y_sup = y_sup.to(device) 217 | 218 | if not adapted: 219 | task_model = student_maml.clone() 220 | task_model.eval() 221 | 222 | with torch.no_grad(): 223 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1) 224 | logits = logits.permute((0,2,1)) 225 | t_out = selector.get_feat(logits) 226 | 227 | feat = task_model(x_sup) 228 | feat = feat.view_as(t_out) 229 | 230 | kl_loss = l2_loss(feat, t_out) 231 | task_model.adapt(kl_loss) 232 | adapted = True 233 | 234 | with torch.no_grad(): 235 | task_model.module.eval() 236 | x_sup = task_model(x_sup) 237 | x_sup = x_sup.view(x_sup.shape[0], -1) 238 | s_que_out = head(x_sup) 239 | pred = (s_que_out > 0.0).squeeze().long() 240 | correct += pred.eq(y_sup.view_as(pred)).sum().item() 241 | total += x_sup.shape[0] 242 | 243 | return correct/total 244 | -------------------------------------------------------------------------------- /src/camelyon/camelyon_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | import math 9 | import random 10 | import pickle5 as pickle 11 | from src.transformer import Transformer 12 | import copy 13 | import time 14 | import os 15 | import wilds 16 | from wilds import get_dataset 17 | from wilds.common.data_loaders import get_train_loader, get_eval_loader 18 | from wilds.common.grouper import CombinatorialGrouper 19 | 20 | # Utils 21 | 22 | def split_domains(num_experts, root_dir='data'): 23 | dataset = get_dataset(dataset='camelyon17', download=True, root_dir=root_dir) 24 | train_data = dataset.get_subset('train') 25 | wsi = list(set(train_data.metadata_array[:,1].detach().numpy().tolist())) 26 | random.shuffle(wsi) 27 | num_domains_per_super = len(wsi) / float(num_experts) 28 | all_split = [[] for _ in range(num_experts)] 29 | for i in range(len(wsi)): 30 | all_split[int(i//num_domains_per_super)].append(wsi[i]) 31 | return all_split 32 | 33 | def get_subset_with_domain(dataset, split=None, domain=None, transform=None, grouper=None): 34 | if type(dataset) == wilds.datasets.wilds_dataset.WILDSSubset: 35 | subset = copy.deepcopy(dataset) 36 | else: 37 | subset = dataset.get_subset(split, transform=transform) 38 | if domain is not None: 39 | if grouper is not None: 40 | z = grouper.metadata_to_group(subset.dataset.metadata_array[subset.indices]) 41 | idx = np.argwhere(np.isin(z, domain)).ravel() 42 | subset.indices = subset.indices[idx] 43 | else: 44 | idx = np.argwhere(np.isin(subset.dataset.metadata_array[:,1][subset.indices], domain)).ravel() 45 | subset.indices = 
subset.indices[idx] 46 | return subset 47 | 48 | def initialize_image_base_transform(dataset): 49 | transform_steps = [] 50 | if dataset.original_resolution is not None and min(dataset.original_resolution)!=max(dataset.original_resolution): 51 | crop_size = min(dataset.original_resolution) 52 | transform_steps.append(transforms.CenterCrop(crop_size)) 53 | transform_steps.append(transforms.Resize((96,96))) 54 | transform_steps += [ 55 | transforms.ToTensor(), 56 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 57 | ] 58 | transform = transforms.Compose(transform_steps) 59 | return transform 60 | 61 | def get_data_loader(batch_size=16, domain=None, test_way='id', n_groups_per_batch=0, return_dataset=False, 62 | uniform_over_groups=True, root_dir='data'): 63 | dataset = get_dataset(dataset='camelyon17', download=True, root_dir=root_dir) 64 | grouper = CombinatorialGrouper(dataset, ['slide']) 65 | 66 | transform = initialize_image_base_transform(dataset) 67 | 68 | train_data = get_subset_with_domain(dataset, 'train', domain=domain, transform=transform) 69 | if test_way == 'ood': 70 | val_data = dataset.get_subset('val', transform=transform) 71 | test_data = dataset.get_subset('test', transform=transform) 72 | else: 73 | val_data = get_subset_with_domain(dataset, 'id_val', domain=domain, transform=transform) 74 | test_data = get_subset_with_domain(dataset, 'id_test', domain=domain, transform=transform) 75 | 76 | 77 | if n_groups_per_batch == 0: 78 | #0 identify standard loader 79 | train_loader = get_train_loader('standard', train_data, batch_size=batch_size) 80 | val_loader = get_train_loader('standard', val_data, batch_size=batch_size) 81 | test_loader = get_train_loader('standard', test_data, batch_size=batch_size) 82 | 83 | else: 84 | # All use get_train_loader to enable grouper 85 | train_loader = get_train_loader('group', train_data, grouper=grouper, 86 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size) 87 | val_loader = get_train_loader('group', val_data, grouper=grouper, 88 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size, 89 | uniform_over_groups=uniform_over_groups) 90 | test_loader = get_train_loader('group', test_data, grouper=grouper, 91 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size, 92 | uniform_over_groups=uniform_over_groups) 93 | 94 | if return_dataset: 95 | return train_loader, val_loader, test_loader, grouper, dataset 96 | return train_loader, val_loader, test_loader, grouper 97 | 98 | def get_mask_grouper(root_dir='data'): 99 | dataset = get_dataset(dataset='camelyon17', download=True, root_dir=root_dir) 100 | grouper = CombinatorialGrouper(dataset, ['slide']) 101 | return grouper 102 | 103 | def get_test_loader(batch_size=16, split='val', root_dir='data', return_dataset=False): 104 | dataset = get_dataset(dataset='camelyon17', download=True, root_dir=root_dir) 105 | grouper = CombinatorialGrouper(dataset, ['slide']) 106 | 107 | transform = initialize_image_base_transform(dataset) 108 | 109 | eval_data = get_subset_with_domain(dataset, split, domain=None, transform=transform) 110 | all_domains = list(set(grouper.metadata_to_group(eval_data.dataset.metadata_array[eval_data.indices]).tolist())) 111 | test_loader = [] 112 | 113 | for domain in all_domains: 114 | domain_data = get_subset_with_domain(eval_data, split, domain=domain, transform=transform, grouper=grouper) 115 | domain_loader = get_eval_loader('standard', domain_data, batch_size=batch_size) 116 | test_loader.append((domain, domain_loader)) 117 | 118 | 
if return_dataset: 119 | return test_loader, grouper, dataset 120 | return test_loader, grouper 121 | 122 | def save_model(model, name, epoch, test_way='ood'): 123 | if not os.path.exists("model/camelyon"): 124 | os.makedirs("model/camelyon") 125 | path = "model/camelyon/{0}_best.pth".format(name) 126 | torch.save(model.state_dict(), path) 127 | 128 | def mask_feat(feat, mask_index, num_experts, exclude=True): 129 | assert feat.shape[0] == mask_index.shape[0] 130 | if exclude: 131 | new_idx = [list(range(0, int(m.item()))) + list(range(int(m.item())+1, num_experts)) for m in mask_index] 132 | return feat[torch.arange(feat.shape[0]).unsqueeze(-1), :, new_idx] 133 | else: 134 | feat[list(range(feat.shape[0])), :, mask_index] = torch.zeros_like(feat[list(range(feat.shape[0])), :, mask_index]) 135 | feat = feat.permute((0,2,1)) 136 | return feat 137 | 138 | def l2_loss(input, target): 139 | loss = torch.square(target - input) 140 | loss = torch.mean(loss) 141 | return loss 142 | 143 | # Models 144 | class ResNetFeature(nn.Module): 145 | def __init__(self, original_model, layer=-1): 146 | super(ResNetFeature, self).__init__() 147 | self.num_ftrs = original_model.classifier.in_features 148 | self.features = nn.Sequential(*list(original_model.children())[:layer]) 149 | self.pool = nn.AdaptiveAvgPool2d((1,1)) 150 | 151 | def forward(self, x): 152 | x = self.features(x) 153 | x = self.pool(x) 154 | x = x.view(-1, self.num_ftrs) 155 | return x 156 | 157 | class fa_selector(nn.Module): 158 | def __init__(self, dim, depth, heads, mlp_dim, dropout = 0., out_dim=1, pool='mean'): 159 | super(fa_selector, self).__init__() 160 | self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout=dropout) 161 | self.pool = pool 162 | self.mlp = nn.Sequential( 163 | nn.LayerNorm(dim), 164 | nn.Linear(dim, out_dim) 165 | ) 166 | 167 | def forward(self, x): 168 | x = self.transformer(x) 169 | if self.pool == 'mean': 170 | x = x.mean(dim=1) 171 | elif self.pool == 'max': 172 | x = x.max(dim=1)[0] # bugfix: Tensor.max(dim=...) returns (values, indices) 173 | else: 174 | raise NotImplementedError 175 | x = self.mlp(x) 176 | return x 177 | 178 | def get_feat(self, x): 179 | x = self.transformer(x) 180 | if self.pool == 'mean': 181 | x = x.mean(dim=1) 182 | elif self.pool == 'max': 183 | x = x.max(dim=1)[0] 184 | else: 185 | raise NotImplementedError 186 | return x 187 | 188 | class DivideModel(nn.Module): 189 | def __init__(self, original_model, layer=-1): 190 | super(DivideModel, self).__init__() 191 | self.num_ftrs = original_model.classifier.in_features 192 | self.num_class = original_model.classifier.out_features 193 | self.features = nn.Sequential(*list(original_model.children())[:layer]) 194 | self.features.add_module("avg_pool", nn.AdaptiveAvgPool2d((1,1))) 195 | self.classifier = nn.Sequential(*list(original_model.children())[layer:]) 196 | 197 | def forward(self, x): 198 | x = self.features(x) 199 | x = x.view(-1, self.num_ftrs) 200 | x = self.classifier(x) 201 | x = x.view(-1, self.num_class) 202 | return x 203 | 204 | def StudentModel(device, num_classes=1, load_path=None): 205 | model = torchvision.models.densenet121(pretrained=True) 206 | num_ftrs = model.classifier.in_features 207 | model.classifier = nn.Linear(num_ftrs, num_classes) 208 | if load_path: 209 | model.load_state_dict(torch.load(load_path)) 210 | model = DivideModel(model) 211 | model = model.to(device) 212 | return model 213 | 214 | def get_feature_list(models_list, device): 215 | feature_list = [] 216 | for model in models_list: 217 | feature_list.append(ResNetFeature(model).to(device)) 218 | 
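# Editor's note: despite the ResNet-flavored name, ResNetFeature here wraps the
# DenseNet-121 experts: it drops the final classifier and adds adaptive average
# pooling, so each expert maps an image batch to flat 1024-d features for the
# knowledge aggregator. A hedged shape sketch (96x96 camelyon inputs;
# illustrative only, not shipped with the repo):
#   >>> m = torchvision.models.densenet121(pretrained=True)
#   >>> m.classifier = nn.Linear(m.classifier.in_features, 1)
#   >>> ResNetFeature(m)(torch.randn(2, 3, 96, 96)).shape   # torch.Size([2, 1024])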
return feature_list 219 | 220 | def get_models_list(device, num_domains=4, num_classes=1, pretrained=False, bb='d121'): 221 | models_list = [] 222 | for _ in range(num_domains+1): 223 | model = torchvision.models.densenet121(pretrained=True) 224 | num_ftrs = model.classifier.in_features 225 | model.classifier = nn.Linear(num_ftrs, num_classes) 226 | model = model.to(device) 227 | models_list.append(model) 228 | return models_list 229 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: dmoe 2 | channels: 3 | - pytorch 4 | - intel 5 | - conda-forge 6 | - defaults 7 | dependencies: 8 | - _ipyw_jlab_nb_ext_conf=0.1.0=py37_0 9 | - _libgcc_mutex=0.1=main 10 | - _openmp_mutex=4.5=1_gnu 11 | - absl-py=0.13.0=py37h06a4308_0 12 | - aiohttp=3.7.4.post0=py37h7f8727e_2 13 | - anaconda-client=1.9.0=py37h06a4308_0 14 | - anaconda-navigator=2.1.0=py37h06a4308_0 15 | - anyio=2.2.0=py37h06a4308_1 16 | - argcomplete=1.12.3=pyhd3eb1b0_0 17 | - argon2-cffi=20.1.0=py37h27cfd23_1 18 | - async-timeout=3.0.1=py37h06a4308_0 19 | - async_generator=1.10=py37h28b3542_0 20 | - attrs=21.2.0=pyhd3eb1b0_0 21 | - babel=2.9.1=pyhd3eb1b0_0 22 | - backcall=0.2.0=pyhd3eb1b0_0 23 | - backports=1.0=pyhd3eb1b0_2 24 | - backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0 25 | - backports.tempfile=1.0=pyhd3eb1b0_1 26 | - backports.weakref=1.0.post1=py_1 27 | - beautifulsoup4=4.10.0=pyh06a4308_0 28 | - blas=1.0=mkl 29 | - bleach=4.0.0=pyhd3eb1b0_0 30 | - blinker=1.4=py37h06a4308_0 31 | - brotlipy=0.7.0=py37h27cfd23_1003 32 | - bzip2=1.0.8=h7b6447c_0 33 | - c-ares=1.17.1=h27cfd23_0 34 | - ca-certificates=2022.4.26=h06a4308_0 35 | - cachetools=4.2.2=pyhd3eb1b0_0 36 | - cairo=1.16.0=hf32fb01_1 37 | - certifi=2022.5.18.1=py37h06a4308_0 38 | - cffi=1.14.6=py37h400218f_0 39 | - chardet=4.0.0=py37h06a4308_1003 40 | - charset-normalizer=2.0.4=pyhd3eb1b0_0 41 | - click=8.0.3=pyhd3eb1b0_0 42 | - clyent=1.2.2=py37_1 43 | - common_cmplr_lib_rt=2021.4.0=intel_3561 44 | - common_cmplr_lic_rt=2021.4.0=intel_3561 45 | - conda=4.10.3=py37hb2ebc77_1 46 | - conda-build=3.21.5=py37h06a4308_0 47 | - conda-content-trust=0.1.1=pyhd3eb1b0_0 48 | - conda-env=2.6.0=1 49 | - conda-package-handling=1.7.3=py37h27cfd23_1 50 | - conda-repo-cli=1.0.4=pyhd3eb1b0_0 51 | - conda-token=0.3.0=pyhd3eb1b0_0 52 | - conda-verify=3.4.2=py_1 53 | - coverage=5.5=py37h27cfd23_2 54 | - cryptography=3.4.8=py37hd23ed53_0 55 | - cudatoolkit=10.1.243=h6bb024c_0 56 | - cudnn=7.6.5=cuda10.1_0 57 | - cupy=8.3.0=py37hcaf9a05_0 58 | - cython=0.29.24=py37h295c915_0 59 | - daal4py=2021.4.0=py37_intel_729 60 | - dal=2021.4.0=intel_729 61 | - dbus=1.13.18=hb2f20db_0 62 | - debugpy=1.4.1=py37h295c915_0 63 | - decorator=5.1.0=pyhd3eb1b0_0 64 | - defusedxml=0.7.1=pyhd3eb1b0_0 65 | - dpcpp_cpp_rt=2021.2.0=intel_610 66 | - easydict=1.9=py_0 67 | - entrypoints=0.3=py37_0 68 | - expat=2.4.1=h2531618_2 69 | - fastrlock=0.6=py37h2531618_0 70 | - ffmpeg=4.0=hcdf2ecd_0 71 | - filelock=3.3.1=pyhd3eb1b0_1 72 | - fontconfig=2.13.1=h6c09931_0 73 | - freeglut=3.0.0=hf484d3e_5 74 | - freetype=2.11.0=h70c0345_0 75 | - future=0.18.2=py37_1 76 | - giflib=5.2.1=h7b6447c_0 77 | - glib=2.69.1=h5202010_0 78 | - glob2=0.7=pyhd3eb1b0_0 79 | - google-auth=1.33.0=pyhd3eb1b0_0 80 | - google-auth-oauthlib=0.4.1=py_2 81 | - graphite2=1.3.14=h23475e2_0 82 | - grpcio=1.36.1=py37h2157cd5_1 83 | - gst-plugins-base=1.14.0=h8213a91_2 84 | - gstreamer=1.14.0=h28cd5cc_2 85 | - 
h5py=2.8.0=py37h989c5e5_3 86 | - harfbuzz=1.8.8=hffaf4a1_0 87 | - hdf5=1.10.2=hba1933b_1 88 | - icc_rt=2021.2.0=intel_610 89 | - icu=58.2=he6710b0_3 90 | - idna=3.2=pyhd3eb1b0_0 91 | - importlib-metadata=4.8.1=py37h06a4308_0 92 | - importlib_metadata=4.8.1=hd3eb1b0_0 93 | - intel-cmplr-lib-rt=2021.4.0=intel_3561 94 | - intel-cmplr-lic-rt=2021.4.0=intel_3561 95 | - intel-opencl-rt=2021.4.0=intel_3561 96 | - intel-openmp=2021.3.0=h06a4308_3350 97 | - intelpython=2021.4.0=0 98 | - ipykernel=6.4.1=py37h06a4308_1 99 | - ipython_genutils=0.2.0=pyhd3eb1b0_1 100 | - ipywidgets=7.6.5=pyhd3eb1b0_1 101 | - jasper=2.0.14=hd8c5072_2 102 | - jedi=0.18.0=py37h06a4308_1 103 | - jinja2=2.11.3=pyhd3eb1b0_0 104 | - joblib=1.0.1=py37h3f38642_0 105 | - jpeg=9d=h7f8727e_0 106 | - json5=0.9.6=pyhd3eb1b0_0 107 | - jsonschema=3.2.0=pyhd3eb1b0_2 108 | - jupyter_client=7.0.1=pyhd3eb1b0_0 109 | - jupyter_core=4.8.1=py37h06a4308_0 110 | - jupyter_server=1.4.1=py37h06a4308_0 111 | - jupyterlab=3.2.1=pyhd3eb1b0_0 112 | - jupyterlab_pygments=0.1.2=py_0 113 | - jupyterlab_server=2.8.2=pyhd3eb1b0_0 114 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1 115 | - lcms2=2.12=h3be6417_0 116 | - ld_impl_linux-64=2.35.1=h7274673_9 117 | - libarchive=3.4.2=h62408e4_0 118 | - libffi=3.3=he6710b0_2 119 | - libgcc-ng=9.3.0=h5101ec6_17 120 | - libgfortran-ng=7.5.0=ha8ba4b0_17 121 | - libgfortran4=7.5.0=ha8ba4b0_17 122 | - libglu=9.0.0=hf484d3e_1 123 | - libgomp=9.3.0=h5101ec6_17 124 | - liblief=0.10.1=he6710b0_0 125 | - libopencv=3.4.2=hb342d67_1 126 | - libopus=1.3.1=h7b6447c_0 127 | - libpng=1.6.37=hbc83047_0 128 | - libprotobuf=3.17.2=h4ff587b_1 129 | - libsodium=1.0.18=h7b6447c_0 130 | - libstdcxx-ng=11.2.0=h1234567_0 131 | - libtiff=4.2.0=h85742a9_0 132 | - libuuid=1.0.3=h7f8727e_2 133 | - libuv=1.40.0=h7b6447c_0 134 | - libvpx=1.7.0=h439df22_0 135 | - libwebp=1.2.0=h89dd481_0 136 | - libwebp-base=1.2.0=h27cfd23_0 137 | - libxcb=1.14=h7b6447c_0 138 | - libxml2=2.9.12=h03d6c58_0 139 | - lz4-c=1.9.3=h295c915_1 140 | - markdown=3.3.4=py37h06a4308_0 141 | - markupsafe=2.0.1=py37h27cfd23_0 142 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2 143 | - mistune=0.8.4=py37h14c3975_1001 144 | - mkl=2021.3.0=h06a4308_520 145 | - mkl-service=2.4.0=py37h7f8727e_0 146 | - mkl_fft=1.3.1=py37hd3c417c_0 147 | - mkl_random=1.2.2=py37h51133e4_0 148 | - multidict=5.1.0=py37h27cfd23_2 149 | - navigator-updater=0.2.1=py37_0 150 | - nbclassic=0.2.6=pyhd3eb1b0_0 151 | - nbclient=0.5.3=pyhd3eb1b0_0 152 | - nbconvert=6.1.0=py37h06a4308_0 153 | - nbformat=5.1.3=pyhd3eb1b0_0 154 | - nccl=2.8.3.1=hcaf9a05_0 155 | - ncurses=6.2=he6710b0_1 156 | - nest-asyncio=1.5.1=pyhd3eb1b0_0 157 | - ninja=1.10.2=hff7bd54_1 158 | - notebook=6.4.5=py37h06a4308_0 159 | - oauthlib=3.1.1=pyhd3eb1b0_0 160 | - olefile=0.46=py37_0 161 | - opencl_rt=2021.4.0=intel_3561 162 | - opencv=3.4.2=py37h6fd60c2_1 163 | - openssl=1.1.1o=h7f8727e_0 164 | - packaging=21.0=pyhd3eb1b0_0 165 | - pandocfilters=1.4.3=py37h06a4308_1 166 | - parso=0.8.2=pyhd3eb1b0_0 167 | - patchelf=0.13=h295c915_0 168 | - pcre=8.45=h295c915_0 169 | - pexpect=4.8.0=pyhd3eb1b0_3 170 | - pickleshare=0.7.5=pyhd3eb1b0_1003 171 | - pixman=0.40.0=h7f8727e_1 172 | - pkginfo=1.7.1=py37h06a4308_0 173 | - prometheus_client=0.11.0=pyhd3eb1b0_0 174 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0 175 | - protobuf=3.17.2=py37h295c915_0 176 | - psutil=5.8.0=py37h27cfd23_1 177 | - ptyprocess=0.7.0=pyhd3eb1b0_2 178 | - py-lief=0.10.1=py37h403a769_0 179 | - py-opencv=3.4.2=py37hb342d67_1 180 | - pyasn1=0.4.8=pyhd3eb1b0_0 181 | - pyasn1-modules=0.2.8=py_0 182 | 
- pycosat=0.6.3=py37h27cfd23_0 183 | - pycparser=2.20=py_2 184 | - pygments=2.10.0=pyhd3eb1b0_0 185 | - pyjwt=2.1.0=py37h06a4308_0 186 | - pyopenssl=21.0.0=pyhd3eb1b0_1 187 | - pyparsing=2.4.7=pyhd3eb1b0_0 188 | - pyqt=5.9.2=py37h05f1152_2 189 | - pyrsistent=0.17.3=py37h7b6447c_0 190 | - pysocks=1.7.1=py37_1 191 | - python=3.7.11=h12debd9_0 192 | - python-dateutil=2.8.2=pyhd3eb1b0_0 193 | - python-libarchive-c=2.9=pyhd3eb1b0_1 194 | - pytorch=1.7.0=py3.7_cuda10.1.243_cudnn7.6.3_0 195 | - pytz=2021.3=pyhd3eb1b0_0 196 | - pyyaml=5.4.1=py37h27cfd23_1 197 | - pyzmq=22.2.1=py37h295c915_1 198 | - qt=5.9.7=h5867ecd_1 199 | - qtpy=1.10.0=pyhd3eb1b0_0 200 | - readline=8.1=h27cfd23_0 201 | - requests=2.26.0=pyhd3eb1b0_0 202 | - requests-oauthlib=1.3.0=py_0 203 | - ripgrep=12.1.1=0 204 | - rsa=4.7.2=pyhd3eb1b0_1 205 | - ruamel_yaml=0.15.100=py37h27cfd23_0 206 | - scikit-learn=0.24.2=py37h8411759_4 207 | - scikit-learn-intelex=2021.4.0=py37_intel_729 208 | - scipy=1.7.1=py37h292c36d_2 209 | - send2trash=1.8.0=pyhd3eb1b0_1 210 | - sip=4.19.8=py37hf484d3e_0 211 | - six=1.16.0=pyhd3eb1b0_0 212 | - sniffio=1.2.0=py37h06a4308_1 213 | - soupsieve=2.2.1=pyhd3eb1b0_0 214 | - sqlite=3.36.0=hc218d9a_0 215 | - tbb=2021.4.0=intel_643 216 | - tensorboard=2.6.0=py_1 217 | - tensorboard-plugin-wit=1.6.0=py_0 218 | - tensorboardx=2.2=pyhd3eb1b0_0 219 | - terminado=0.9.4=py37h06a4308_0 220 | - testpath=0.5.0=pyhd3eb1b0_0 221 | - threadpoolctl=2.1.0=py37h6447541_2 222 | - tk=8.6.11=h1ccaba5_0 223 | - tornado=6.1=py37h27cfd23_0 224 | - tqdm=4.62.3=pyhd3eb1b0_1 225 | - traitlets=5.1.0=pyhd3eb1b0_0 226 | - typing-extensions=3.10.0.2=hd3eb1b0_0 227 | - typing_extensions=3.10.0.2=pyh06a4308_0 228 | - urllib3=1.26.7=pyhd3eb1b0_0 229 | - wcwidth=0.2.5=pyhd3eb1b0_0 230 | - webencodings=0.5.1=py37_1 231 | - werkzeug=2.0.2=pyhd3eb1b0_0 232 | - widgetsnbextension=3.5.1=py37_0 233 | - xmltodict=0.12.0=pyhd3eb1b0_0 234 | - xz=5.2.5=h7b6447c_0 235 | - yaml=0.2.5=h7b6447c_0 236 | - yarl=1.6.3=py37h27cfd23_0 237 | - zeromq=4.3.4=h2531618_0 238 | - zipp=3.6.0=pyhd3eb1b0_0 239 | - zlib=1.2.11=h7b6447c_3 240 | - zstd=1.4.9=haebb681_0 241 | - pip: 242 | - albumentations==1.1.0 243 | - boto==2.49.0 244 | - cloudpickle==2.0.0 245 | - crcmod==1.7 246 | - cycler==0.11.0 247 | - dataclasses==0.6 248 | - dummy-test==0.1.3 249 | - fasteners==0.16.3 250 | - fonttools==4.33.3 251 | - gcs-oauth2-boto-plugin==3.0 252 | - google-apitools==0.5.32 253 | - google-reauth==0.1.1 254 | - gsutil==5.5 255 | - gym==0.21.0 256 | - higher==0.2.1 257 | - hnswlib==0.5.2 258 | - httplib2==0.20.2 259 | - huggingface-hub==0.5.1 260 | - imageio==2.19.2 261 | - ipdb==0.13.9 262 | - ipython==7.25.0 263 | - kiwisolver==1.4.2 264 | - learn2learn==0.1.6 265 | - littleutils==0.2.2 266 | - mat73==0.52 267 | - matplotlib==3.5.2 268 | - monotonic==1.6 269 | - networkx==2.6.3 270 | - numpy==1.21.3 271 | - oauth2client==4.1.3 272 | - ogb==1.3.2 273 | - opencv-python==4.5.4.58 274 | - opencv-python-headless==4.5.4.58 275 | - outdated==0.2.1 276 | - pandas==1.3.4 277 | - pathlib==1.0.1 278 | - pickle5==0.0.12 279 | - pillow==8.4.0 280 | - pip==21.3.1 281 | - ptflops==0.6.6 282 | - pyu2f==0.1.5 283 | - pywavelets==1.1.1 284 | - qpth==0.0.15 285 | - qudida==0.0.4 286 | - regex==2022.3.15 287 | - retry-decorator==1.1.1 288 | - sacremoses==0.0.49 289 | - scikit-image==0.18.3 290 | - setuptools==58.5.0 291 | - tabulate==0.8.9 292 | - tensorboard-data-server==0.6.1 293 | - tifffile==2021.10.12 294 | - timm==0.4.12 295 | - tokenizers==0.11.6 296 | - toml==0.10.2 297 | - 
torchvision==0.8.1 298 | - transformers==4.18.0 299 | - wheel==0.34.0 300 | - wilds==1.2.2 301 | -------------------------------------------------------------------------------- /src/fmow/fmow_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | import torchvision 8 | from tqdm import tqdm 9 | import copy 10 | import learn2learn as l2l 11 | import time 12 | from datetime import datetime 13 | import os 14 | import wilds 15 | from src.fmow.fmow_utils import * 16 | 17 | def train_epoch(selector, selector_name, models_list, student, student_name, 18 | train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster, 19 | device, acc_best=0, tlr=1e-4, slr=1e-4, ilr=1e-3, 20 | batch_size=256, sup_size=24, test_way='id', save=False, 21 | root_dir='data'): 22 | for model in models_list: 23 | model.eval() 24 | 25 | student_ce = nn.CrossEntropyLoss() 26 | #teacher_ce = nn.CrossEntropyLoss() 27 | 28 | features = student.features 29 | head = student.classifier 30 | features = l2l.algorithms.MAML(features, lr=ilr) 31 | features.to(device) 32 | head.to(device) 33 | 34 | all_params = list(features.parameters()) + list(head.parameters()) 35 | optimizer_s = optim.Adam(all_params, lr=slr) 36 | optimizer_t = optim.Adam(selector.parameters(), lr=tlr) 37 | 38 | i = 0 39 | 40 | losses = [] 41 | 42 | iter_per_epoch = len(train_loader) 43 | 44 | for x, y_true, metadata in train_loader: 45 | selector.eval() 46 | head.eval() 47 | features.eval() 48 | 49 | z = grouper.metadata_to_group(metadata) 50 | z = set(z.tolist()) 51 | assert len(z) == 1 52 | mask = mask_grouper.metadata_to_group(metadata) 53 | mask.apply_(lambda x: split_to_cluster[x]) 54 | 55 | #sup_size = x.shape[0]//2 56 | x_sup = x[:sup_size] 57 | y_sup = y_true[:sup_size] 58 | x_que = x[sup_size:] 59 | y_que = y_true[sup_size:] 60 | mask = mask[:sup_size] 61 | 62 | x_sup = x_sup.to(device) 63 | y_sup = y_sup.to(device) 64 | x_que = x_que.to(device) 65 | y_que = y_que.to(device) 66 | 67 | with torch.no_grad(): 68 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1) 69 | #logits[:, :, split_to_cluster[z]] = torch.zeros_like(logits[:, :, split_to_cluster[z]]) 70 | #logits = logits.permute((0,2,1)) 71 | logits = mask_feat(logits, mask, len(models_list), exclude=True) 72 | 73 | t_out = selector.get_feat(logits) 74 | 75 | task_model = features.clone() 76 | task_model.module.eval() 77 | feat = task_model(x_que) 78 | feat = feat.view(feat.shape[0], -1) 79 | out = head(feat) 80 | with torch.no_grad(): 81 | loss_pre = student_ce(out, y_que).item()/x_que.shape[0] 82 | 83 | feat = task_model(x_sup) 84 | feat = feat.view_as(t_out) 85 | 86 | inner_loss = l2_loss(feat, t_out) 87 | task_model.adapt(inner_loss) 88 | 89 | x_que = task_model(x_que) 90 | x_que = x_que.view(x_que.shape[0], -1) 91 | s_que_out = head(x_que) 92 | s_que_loss = student_ce(s_que_out, y_que) 93 | #t_sup_loss = teacher_ce(t_out, y_sup) 94 | 95 | s_que_loss.backward() 96 | 97 | optimizer_s.step() 98 | optimizer_t.step() 99 | optimizer_s.zero_grad() 100 | optimizer_t.zero_grad() 101 | 102 | #print("Step:{}".format(time.time() - t_1)) 103 | #t_1 = time.time() 104 | with open('log/fmow/'+curr+'.txt', 'a+') as f: 105 | f.write('Iter: {}/{}, Loss Before:{:.4f}, Loss After:{:.4f}\r\n'.format(i,iter_per_epoch, 106 | loss_pre, 107 | s_que_loss.item()/x_que.shape[0])) 108 | 
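# Editor's note: the iteration just logged is one bi-level (meta) update. In
# rough pseudo-math (editor's summary of the code above, not from the repo):
#   inner:  theta' = theta - ilr * grad ||f_theta(x_sup) - selector(experts(x_sup))||^2
#   outer:  update (student, selector) with grad of CE(head(f_theta'(x_que)), y_que)
# i.e. the cloned student features adapt on the support set by matching the
# aggregator's output, and the query cross-entropy then backpropagates through
# that adaptation into optimizer_s (lr=slr) and optimizer_t (lr=tlr).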
losses.append(s_que_loss.item()/x_que.shape[0]) 109 | 110 | if i == iter_per_epoch//2: 111 | losses = np.mean(losses) 112 | worst_acc, acc = eval(selector, models_list, student, sup_size, device=device, 113 | ilr=ilr, test=False, progress=False, uniform_over_groups=False, 114 | root_dir=root_dir) 115 | 116 | with open('log/fmow/'+curr+'.txt', 'a+') as f: 117 | f.write(f'Accuracy: {acc}; Worst acc: {worst_acc}\r\n') 118 | losses = [] 119 | 120 | if worst_acc > acc_best and save: 121 | with open('log/fmow/'+curr+'.txt', 'a+') as f: 122 | f.write("Saving model ...\r\n") 123 | save_model(selector, selector_name+"_selector", 0, test_way=test_way) 124 | save_model(student, student_name+"_student", 0, test_way=test_way) 125 | acc_best = worst_acc 126 | 127 | i += 1 128 | return acc_best 129 | 130 | def train_kd(selector, selector_name, models_list, student, student_name, split_to_cluster, device, 131 | batch_size=256, sup_size=24, tlr=1e-4, slr=1e-4, ilr=1e-5, num_epochs=30, 132 | decayRate=0.96, save=False, test_way='ood', root_dir='data'): 133 | 134 | train_loader, _, _, grouper = get_data_loader(root_dir=root_dir, batch_size=batch_size, domain=None, 135 | test_way=test_way, n_groups_per_batch=1, 136 | groupby_fields=['year', 'region']) 137 | 138 | mask_grouper = get_mask_grouper(root_dir=root_dir) 139 | 140 | curr = str(datetime.now()) 141 | if not os.path.exists("log/fmow"): 142 | os.makedirs("log/fmow") 143 | #print(curr) 144 | print("Training log saved to log/fmow/"+curr+".txt") 145 | 146 | with open('log/fmow/'+curr+'.txt', 'a+') as f: 147 | f.write(selector_name+' '+student_name+'\r\n') 148 | f.write(f'tlr={tlr} slr={slr} ilr={ilr}\r\n') 149 | 150 | accu_best = 0 151 | for epoch in range(num_epochs): 152 | with open('log/fmow/'+curr+'.txt', 'a+') as f: 153 | f.write(f'Epoch: {epoch}\r\n') 154 | accu_epoch = train_epoch(selector, selector_name, models_list, student, student_name, 155 | train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster, 156 | device, acc_best=accu_best, tlr=tlr, slr=slr, ilr=ilr, 157 | batch_size=batch_size, sup_size=sup_size, test_way=test_way, save=save, 158 | root_dir=root_dir) 159 | accu_best = max(accu_best, accu_epoch) 160 | worst_acc, accu = eval(selector, models_list, student, sup_size, device=device, 161 | ilr=ilr, test=False, progress=False, uniform_over_groups=False, 162 | root_dir=root_dir) 163 | 164 | with open('log/fmow/'+curr+'.txt', 'a+') as f: 165 | f.write(f'Accuracy: {accu}; Worst acc: {worst_acc}\r\n') 166 | 167 | if worst_acc > accu_best and save: 168 | with open('log/fmow/'+curr+'.txt', 'a+') as f: 169 | f.write("Saving model ...\r\n") 170 | save_model(selector, selector_name+"_selector", 0, test_way=test_way) 171 | save_model(student, student_name+"_student", 0, test_way=test_way) 172 | accu_best = worst_acc 173 | tlr = tlr*decayRate 174 | slr = slr*decayRate 175 | 176 | def eval(selector, models_list, student, batch_size, device, ilr=1e-5, 177 | test=False, progress=True, uniform_over_groups=False, root_dir='data'): 178 | 179 | if test: 180 | loader, grouper, dataset = get_test_loader(batch_size=batch_size, split='test', root_dir=root_dir, 181 | groupby_fields=['year', 'region'], return_dataset=True) 182 | else: 183 | loader, grouper, dataset = get_test_loader(batch_size=batch_size, split='val', root_dir=root_dir, 184 | groupby_fields=['year', 'region'], return_dataset=True) 185 | 186 | '''if test: 187 | loader, que_set, grouper = get_test_loader(batch_size=batch_size, test_way='test') 188 | else: 189 | loader, que_set, grouper = 
get_test_loader(batch_size=batch_size, test_way='val')''' 190 | 191 | features = student.features 192 | head = student.classifier 193 | head.to(device) 194 | 195 | student_maml = l2l.algorithms.MAML(features, lr=ilr) 196 | student_maml.to(device) 197 | 198 | old_domain = {} 199 | if progress: 200 | loader = tqdm(iter(loader), total=len(loader)) 201 | 202 | for domain, domain_loader in loader: 203 | adapted = False 204 | for x_sup, y_sup, metadata in domain_loader: 205 | student_maml.module.eval() 206 | selector.eval() 207 | head.eval() 208 | 209 | z = grouper.metadata_to_group(metadata) 210 | z = set(z.tolist()) 211 | assert list(z)[0] == domain 212 | 213 | x_sup = x_sup.to(device) 214 | y_sup = y_sup.to(device) 215 | 216 | if not adapted: 217 | task_model = student_maml.clone() 218 | task_model.eval() 219 | with torch.no_grad(): 220 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1) 221 | logits = logits.permute((0,2,1)) 222 | t_out = selector.get_feat(logits) 223 | 224 | feat = task_model(x_sup) 225 | feat = feat.view_as(t_out) 226 | 227 | kl_loss = l2_loss(feat, t_out) 228 | torch.cuda.empty_cache() 229 | task_model.adapt(kl_loss) 230 | adapted = True 231 | 232 | with torch.no_grad(): 233 | task_model.module.eval() 234 | x_sup = task_model(x_sup) 235 | x_sup = x_sup.view(x_sup.shape[0], -1) 236 | s_que_out = head(x_sup) 237 | pred = s_que_out.max(1, keepdim=True)[1] 238 | try: 239 | pred_all = torch.cat((pred_all, pred.view_as(y_sup))) 240 | y_all = torch.cat((y_all, y_sup)) 241 | metadata_all = torch.cat((metadata_all, metadata)) 242 | except NameError: 243 | pred_all = pred.view_as(y_sup) 244 | y_all = y_sup 245 | metadata_all = metadata 246 | 247 | acc, worst_acc = get_fmow_metrics(pred_all, y_all, metadata_all, dataset) 248 | return worst_acc, acc 249 | -------------------------------------------------------------------------------- /src/rxrx1/rxrx1_train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.optim as optim 7 | import torchvision 8 | from tqdm import tqdm 9 | import copy 10 | import learn2learn as l2l 11 | import time 12 | from datetime import datetime 13 | from collections import defaultdict 14 | from transformers import get_cosine_schedule_with_warmup 15 | import os 16 | import wilds 17 | from src.rxrx1.rxrx1_utils import * 18 | 19 | def train_epoch(selector, selector_name, models_list, student, student_name, 20 | train_loader, grouper, num_epochs, curr, mask_grouper, split_to_cluster, 21 | device, acc_best=0, tlr=1e-4, slr=1e-4, ilr=1e-3, 22 | batch_size=256, sup_size=24, test_way='id', save=False, 23 | root_dir='data'): 24 | for model in models_list: 25 | model.eval() 26 | 27 | student_ce = nn.CrossEntropyLoss() 28 | #teacher_ce = nn.CrossEntropyLoss() 29 | 30 | features = student.features 31 | head = student.classifier 32 | features = l2l.algorithms.MAML(features, lr=ilr) 33 | features.to(device) 34 | head.to(device) 35 | 36 | all_params = list(features.parameters()) + list(head.parameters()) 37 | optimizer_s = optim.Adam(all_params, lr=slr) 38 | optimizer_t = optim.Adam(selector.parameters(), lr=tlr) 39 | 40 | i = 0 41 | 42 | losses = [] 43 | 44 | iter_per_epoch = len(train_loader) 45 | scheduler = get_cosine_schedule_with_warmup( 46 | optimizer_s, 47 | num_warmup_steps=iter_per_epoch, 48 | num_training_steps=iter_per_epoch*num_epochs) 49 | 50 | for epoch in 
range(num_epochs): 51 | with open('log/rxrx1/'+curr+'.txt', 'a+') as f: 52 | f.write(f'Epoch: {epoch}\r\n') 53 | i = 0 54 | for x, y_true, metadata in train_loader: 55 | selector.eval() 56 | head.eval() 57 | features.eval() 58 | 59 | z = grouper.metadata_to_group(metadata) 60 | z = set(z.tolist()) 61 | assert len(z) == 1 62 | mask = mask_grouper.metadata_to_group(metadata) 63 | mask.apply_(lambda x: split_to_cluster[x]) 64 | 65 | #sup_size = x.shape[0]//2 66 | x_sup = x[:sup_size] 67 | y_sup = y_true[:sup_size] 68 | x_que = x[sup_size:] 69 | y_que = y_true[sup_size:] 70 | mask = mask[:sup_size] 71 | 72 | x_sup = x_sup.to(device) 73 | y_sup = y_sup.to(device) 74 | x_que = x_que.to(device) 75 | y_que = y_que.to(device) 76 | 77 | with torch.no_grad(): 78 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1) 79 | #logits[:, :, split_to_cluster[z]] = torch.zeros_like(logits[:, :, split_to_cluster[z]]) 80 | #logits = logits.permute((0,2,1)) 81 | logits = mask_feat(logits, mask, len(models_list), exclude=True) 82 | 83 | t_out = selector.get_feat(logits) 84 | 85 | task_model = features.clone() 86 | task_model.module.eval() 87 | feat = task_model(x_que) 88 | feat = feat.view(feat.shape[0], -1) 89 | out = head(feat) 90 | with torch.no_grad(): 91 | loss_pre = student_ce(out, y_que).item()/x_que.shape[0] 92 | 93 | feat = task_model(x_sup) 94 | feat = feat.view_as(t_out) 95 | 96 | inner_loss = l2_loss(feat, t_out) 97 | task_model.adapt(inner_loss) 98 | 99 | x_que = task_model(x_que) 100 | x_que = x_que.view(x_que.shape[0], -1) 101 | s_que_out = head(x_que) 102 | s_que_loss = student_ce(s_que_out, y_que) 103 | #t_sup_loss = teacher_ce(t_out, y_sup) 104 | 105 | s_que_loss.backward() 106 | 107 | optimizer_s.step() 108 | optimizer_t.step() 109 | optimizer_s.zero_grad() 110 | optimizer_t.zero_grad() 111 | 112 | #print("Step:{}".format(time.time() - t_1)) 113 | #t_1 = time.time() 114 | with open('log/rxrx1/'+curr+'.txt', 'a+') as f: 115 | f.write('Iter: {}/{}, Loss Before:{:.4f}, Loss After:{:.4f}\r\n'.format(i,iter_per_epoch, 116 | loss_pre, 117 | s_que_loss.item()/x_que.shape[0])) 118 | losses.append(s_que_loss.item()/x_que.shape[0]) 119 | 120 | if i == iter_per_epoch//2 or i == iter_per_epoch-1: 121 | losses = np.mean(losses) 122 | acc = eval(selector, models_list, student, sup_size, device=device, 123 | ilr=ilr, test=False, progress=False, uniform_over_groups=False, 124 | root_dir=root_dir) 125 | 126 | with open('log/rxrx1/'+curr+'.txt', 'a+') as f: 127 | f.write(f'Accuracy: {acc}\r\n') 128 | losses = [] 129 | 130 | if acc > acc_best and save: 131 | with open('log/rxrx1/'+curr+'.txt', 'a+') as f: 132 | f.write("Saving model ...\r\n") 133 | save_model(selector, selector_name+"_selector", 0, test_way=test_way) 134 | save_model(student, student_name+"_student", 0, test_way=test_way) 135 | acc_best = acc 136 | 137 | i += 1 138 | scheduler.step() 139 | return acc_best 140 | 141 | def train_kd(selector, selector_name, models_list, student, student_name, split_to_cluster, device, 142 | batch_size=256, sup_size=24, tlr=1e-4, slr=1e-4, ilr=1e-5, num_epochs=30, 143 | decayRate=0.96, save=False, test_way='ood', root_dir='data'): 144 | 145 | train_loader, _, _, grouper = get_data_loader(root_dir=root_dir, batch_size=batch_size, domain=None, 146 | test_way=test_way, n_groups_per_batch=1) 147 | 148 | mask_grouper = get_mask_grouper(root_dir=root_dir) 149 | 150 | curr = str(datetime.now()) 151 | if not os.path.exists("log/rxrx1"): 152 | os.makedirs("log/rxrx1") 153 | #print(curr) 154 | 
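# Editor's note: `curr` is str(datetime.now()), so log filenames contain spaces
# and colons (e.g. 'log/rxrx1/2022-06-01 12:00:00.123456.txt'); that works on
# Linux but is an illegal filename on Windows. A hedged, filesystem-safe
# alternative (editor's suggestion, not what the repo ships):
#   >>> curr = datetime.now().strftime('%Y%m%d-%H%M%S')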
print("Training log saved to log/rxrx1/"+curr+".txt") 155 | 156 | with open('log/rxrx1/'+curr+'.txt', 'a+') as f: 157 | f.write(selector_name+' '+student_name+'\r\n') 158 | f.write(f'tlr={tlr} slr={slr} ilr={ilr}\r\n') 159 | 160 | accu_best = 0 161 | #for epoch in range(num_epochs): 162 | #with open('log/iwildcam/'+curr+'.txt', 'a+') as f: 163 | #f.write(f'Epoch: {epoch}\r\n') 164 | accu_epoch = train_epoch(selector, selector_name, models_list, student, student_name, 165 | train_loader, grouper, num_epochs, curr, mask_grouper, split_to_cluster, 166 | device, acc_best=accu_best, tlr=tlr, slr=slr, ilr=ilr, 167 | batch_size=batch_size, sup_size=sup_size, test_way=test_way, save=save, 168 | root_dir=root_dir) 169 | #accu_best = max(accu_best, accu_epoch) 170 | #_, accu, f1 = eval(selector, models_list, student, sup_size, device=device, 171 | #ilr=ilr, test=False, progress=False, uniform_over_groups=False, 172 | #root_dir=root_dir) 173 | 174 | #with open('log/iwildcam/'+curr+'.txt', 'a+') as f: 175 | #f.write(f'Accuracy: {accu} || F1:{f1} \r\n') 176 | 177 | #if f1 > accu_best and save: 178 | #with open('log/iwildcam/'+curr+'.txt', 'a+') as f: 179 | #f.write("Saving model ...\r\n") 180 | #save_model(selector, selector_name+"_selector", 0, test_way=test_way) 181 | #save_model(student, student_name+"_student", 0, test_way=test_way) 182 | #accu_best = f1 183 | #tlr = tlr*decayRate 184 | #slr = slr*decayRate 185 | 186 | def eval(selector, models_list, student, batch_size, device, ilr=1e-5, 187 | test=False, progress=True, uniform_over_groups=False, root_dir='data'): 188 | 189 | if test: 190 | loader, grouper = get_test_loader(batch_size=batch_size, split='test', root_dir=root_dir) 191 | else: 192 | loader, grouper = get_test_loader(batch_size=batch_size, split='val', root_dir=root_dir) 193 | '''if test: 194 | loader, que_set, grouper = get_test_loader(batch_size=batch_size, test_way='test') 195 | else: 196 | loader, que_set, grouper = get_test_loader(batch_size=batch_size, test_way='val')''' 197 | 198 | features = student.features 199 | head = student.classifier 200 | head.to(device) 201 | 202 | student_maml = l2l.algorithms.MAML(features, lr=ilr) 203 | student_maml.to(device) 204 | 205 | nor_correct = 0 206 | nor_total = 0 207 | old_domain = {} 208 | if progress: 209 | loader = tqdm(iter(loader), total=len(loader)) 210 | 211 | for domain, domain_loader in loader: 212 | adapted = False 213 | for x_sup, y_sup, metadata in domain_loader: 214 | student_maml.module.eval() 215 | selector.eval() 216 | head.eval() 217 | 218 | z = grouper.metadata_to_group(metadata) 219 | z = set(z.tolist()) 220 | assert list(z)[0] == domain 221 | 222 | x_sup = x_sup.to(device) 223 | y_sup = y_sup.to(device) 224 | 225 | if not adapted: 226 | task_model = student_maml.clone() 227 | task_model.eval() 228 | 229 | with torch.no_grad(): 230 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1) 231 | logits = logits.permute((0,2,1)) 232 | t_out = selector.get_feat(logits) 233 | 234 | feat = task_model(x_sup) 235 | feat = feat.view_as(t_out) 236 | 237 | kl_loss = l2_loss(feat, t_out) 238 | torch.cuda.empty_cache() 239 | task_model.adapt(kl_loss) 240 | adapted = True 241 | 242 | with torch.no_grad(): 243 | task_model.module.eval() 244 | x_sup = task_model(x_sup) 245 | x_sup = x_sup.view(x_sup.shape[0], -1) 246 | s_que_out = head(x_sup) 247 | pred = s_que_out.max(1, keepdim=True)[1] 248 | c = pred.eq(y_sup.view_as(pred)).sum().item() 249 | nor_correct += c 250 | nor_total += x_sup.shape[0] 251 | 252 | 
return nor_correct/nor_total 253 | -------------------------------------------------------------------------------- /src/fmow/fmow_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import torch.optim as optim 6 | import torchvision 7 | import torchvision.transforms as transforms 8 | import math 9 | import random 10 | import pickle5 as pickle 11 | from src.transformer import Transformer 12 | import copy 13 | import time 14 | import os 15 | import wilds 16 | from wilds import get_dataset 17 | from wilds.common.data_loaders import get_train_loader, get_eval_loader 18 | from wilds.common.grouper import CombinatorialGrouper 19 | from wilds.common.utils import split_into_groups 20 | 21 | # Utils 22 | 23 | def split_domains(num_experts, root_dir='data'): 24 | dataset = get_dataset(dataset='fmow', download=True, root_dir=root_dir) 25 | train_data = dataset.get_subset('train') 26 | year = list(set(train_data.metadata_array[:,1].detach().numpy().tolist())) 27 | random.shuffle(year) 28 | num_domains_per_super = len(year) / float(num_experts) 29 | all_split = [[] for _ in range(num_experts)] 30 | for i in range(len(year)): 31 | all_split[int(i//num_domains_per_super)].append(year[i]) 32 | return all_split 33 | 34 | def get_subset_with_domain(dataset, split=None, domain=None, transform=None, grouper=None): 35 | if type(dataset) == wilds.datasets.wilds_dataset.WILDSSubset: 36 | subset = copy.deepcopy(dataset) 37 | else: 38 | subset = dataset.get_subset(split, transform=transform) 39 | if domain is not None: 40 | if grouper is not None: 41 | z = grouper.metadata_to_group(subset.dataset.metadata_array[subset.indices]) 42 | idx = np.argwhere(np.isin(z, domain)).ravel() 43 | subset.indices = subset.indices[idx] 44 | else: 45 | idx = np.argwhere(np.isin(subset.dataset.metadata_array[:,1][subset.indices], domain)).ravel() 46 | subset.indices = subset.indices[idx] 47 | return subset 48 | 49 | def initialize_image_base_transform(dataset): 50 | transform_steps = [] 51 | if dataset.original_resolution is not None and min(dataset.original_resolution)!=max(dataset.original_resolution): 52 | crop_size = min(dataset.original_resolution) 53 | transform_steps.append(transforms.CenterCrop(crop_size)) 54 | transform_steps += [ 55 | transforms.ToTensor(), 56 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 57 | ] 58 | transform = transforms.Compose(transform_steps) 59 | return transform 60 | 61 | def get_data_loader(batch_size=16, domain=None, test_way='id', 62 | n_groups_per_batch=0, uniform_over_groups=False, root_dir='data', 63 | groupby_fields=['year'], return_dataset=False): 64 | dataset = get_dataset(dataset='fmow', download=True, root_dir=root_dir) 65 | grouper = CombinatorialGrouper(dataset, groupby_fields) 66 | 67 | transform = initialize_image_base_transform(dataset) 68 | 69 | train_data = get_subset_with_domain(dataset, 'train', domain=domain, transform=transform) 70 | if test_way == 'ood': 71 | val_data = dataset.get_subset('val', transform=transform) 72 | test_data = dataset.get_subset('test', transform=transform) 73 | else: 74 | val_data = get_subset_with_domain(dataset, 'id_val', domain=domain, transform=transform) 75 | test_data = get_subset_with_domain(dataset, 'id_test', domain=domain, transform=transform) 76 | 77 | 78 | if n_groups_per_batch == 0: 79 | #0 identify standard loader 80 | train_loader = get_train_loader('standard', train_data, 

def get_test_loader(batch_size=16, split='val', root_dir='data', groupby_fields=['year'], return_dataset=False):
    dataset = get_dataset(dataset='fmow', download=True, root_dir=root_dir)
    grouper = CombinatorialGrouper(dataset, groupby_fields)

    transform = initialize_image_base_transform(dataset)

    eval_data = get_subset_with_domain(dataset, split, domain=None, transform=transform)
    all_domains = list(set(grouper.metadata_to_group(eval_data.dataset.metadata_array[eval_data.indices]).tolist()))
    test_loader = []

    # Build one sequential loader per evaluation domain so that test-time
    # adaptation can be performed domain by domain.
    for domain in all_domains:
        domain_data = get_subset_with_domain(eval_data, split, domain=domain, transform=transform, grouper=grouper)
        domain_loader = get_eval_loader('standard', domain_data, batch_size=batch_size)
        test_loader.append((domain, domain_loader))

    if return_dataset:
        return test_loader, grouper, dataset
    return test_loader, grouper

def get_mask_grouper(root_dir='data'):
    dataset = get_dataset(dataset='fmow', download=True, root_dir=root_dir)
    grouper = CombinatorialGrouper(dataset, ['year'])
    return grouper

def save_model(model, name, epoch, test_way='ood'):
    # Note: `epoch` and `test_way` are currently unused; the checkpoint always
    # overwrites model/fmow/<name>_best.pth.
    if not os.path.exists("model/fmow"):
        os.makedirs("model/fmow")
    path = "model/fmow/{0}_best.pth".format(name)
    torch.save(model.state_dict(), path)

def mask_feat(feat, mask_index, num_experts, exclude=True):
    # `feat` has shape (batch, dim, num_experts); `mask_index[i]` is the index
    # of the expert trained on sample i's domain, which must not see its own
    # domain during meta-training.
    assert feat.shape[0] == mask_index.shape[0]
    if exclude:
        # Drop the masked expert entirely; returns (batch, num_experts - 1, dim).
        new_idx = [list(range(0, int(m.item()))) + list(range(int(m.item()) + 1, num_experts)) for m in mask_index]
        return feat[torch.arange(feat.shape[0]).unsqueeze(-1), :, new_idx]
    else:
        # Zero out the masked expert instead; returns (batch, num_experts, dim).
        feat[list(range(feat.shape[0])), :, mask_index] = torch.zeros_like(feat[list(range(feat.shape[0])), :, mask_index])
        feat = feat.permute((0, 2, 1))
        return feat

def l2_loss(input, target):
    # Mean squared error between student features and aggregated teacher features.
    loss = torch.square(target - input)
    loss = torch.mean(loss)
    return loss
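
# Shape walk-through (sketch, not part of the original module): with two
# samples, feature dim 8, and 4 experts, excluding each sample's own expert
# leaves 3 expert tokens per sample.
def _mask_feat_shape_demo():
    feat = torch.randn(2, 8, 4)        # (batch, dim, num_experts)
    mask = torch.tensor([1, 3])        # expert to hide for each sample
    out = mask_feat(feat, mask, num_experts=4, exclude=True)
    assert out.shape == (2, 3, 8)      # (batch, num_experts - 1, dim)
    return out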

# Models
class ResNetFeature(nn.Module):
    # Despite the name, this wraps any backbone with a `.classifier` head
    # (DenseNet-121 here), keeping everything up to `layer` as a feature
    # extractor followed by global average pooling.
    def __init__(self, original_model, layer=-1):
        super(ResNetFeature, self).__init__()
        self.num_ftrs = original_model.classifier.in_features
        self.features = nn.Sequential(*list(original_model.children())[:layer])
        self.pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        x = self.features(x)
        x = self.pool(x)
        x = x.view(-1, self.num_ftrs)
        return x

class fa_selector(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim, dropout=0., out_dim=62, pool='mean'):
        super(fa_selector, self).__init__()
        self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout=dropout)
        self.pool = pool
        self.mlp = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, out_dim)
        )

    def forward(self, x):
        x = self.transformer(x)
        if self.pool == 'mean':
            x = x.mean(dim=1)
        elif self.pool == 'max':
            x = x.max(dim=1)[0]  # max(dim=...) returns (values, indices); keep values
        else:
            raise NotImplementedError
        x = self.mlp(x)
        return x

    def get_feat(self, x):
        x = self.transformer(x)
        if self.pool == 'mean':
            x = x.mean(dim=1)
        elif self.pool == 'max':
            x = x.max(dim=1)[0]
        else:
            raise NotImplementedError
        return x

class DivideModel(nn.Module):
    # Splits a classifier into `features` and `classifier` so the feature
    # extractor can be meta-adapted independently of the head.
    def __init__(self, original_model, layer=-1):
        super(DivideModel, self).__init__()
        self.num_ftrs = original_model.classifier.in_features
        self.num_class = original_model.classifier.out_features
        self.features = nn.Sequential(*list(original_model.children())[:layer])
        self.features.add_module("avg_pool", nn.AdaptiveAvgPool2d((1, 1)))
        self.classifier = nn.Sequential(*list(original_model.children())[layer:])

    def forward(self, x):
        x = self.features(x)
        x = x.view(-1, self.num_ftrs)
        x = self.classifier(x)
        x = x.view(-1, self.num_class)
        return x

def StudentModel(device, num_classes=62, load_path=None):
    model = torchvision.models.densenet121(pretrained=True)
    num_ftrs = model.classifier.in_features
    model.classifier = nn.Linear(num_ftrs, num_classes)
    if load_path:
        model.load_state_dict(torch.load(load_path))
    model = DivideModel(model)
    model = model.to(device)
    return model

def get_feature_list(models_list, device):
    feature_list = []
    for model in models_list:
        feature_list.append(ResNetFeature(model).to(device))
    return feature_list

def get_models_list(device, num_domains=3, num_classes=62, pretrained=False, bb='d121'):
    # `pretrained` and `bb` are accepted for API compatibility but an
    # ImageNet-pretrained DenseNet-121 is always used; num_domains + 1
    # identical classifiers are returned.
    models_list = []
    for _ in range(num_domains + 1):
        model = torchvision.models.densenet121(pretrained=True)
        num_ftrs = model.classifier.in_features
        model.classifier = nn.Linear(num_ftrs, num_classes)
        model = model.to(device)
        models_list.append(model)
    return models_list

def get_fmow_metrics(y_pred, y_true, metadata, dataset):
    # Returns (overall accuracy, worst-region accuracy), grouping by 'region'.
    worst_acc = float('inf')
    grouper = CombinatorialGrouper(dataset, ['region'])
    regions = grouper.metadata_to_group(metadata)
    unique_groups, group_indices, _ = split_into_groups(regions)
    for i_group in group_indices:
        correct = y_pred[i_group].eq(y_true[i_group].view_as(y_pred[i_group])).sum().item()
        total = y_pred[i_group].shape[0]
        worst_acc = min(correct / total, worst_acc)
    correct = y_pred.eq(y_true.view_as(y_pred)).sum().item()
    total = y_pred.shape[0]
    return correct / total, worst_acc
--------------------------------------------------------------------------------
/src/iwildcam/iwildcam_train.py:
--------------------------------------------------------------------------------
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from tqdm import tqdm
import copy
import learn2learn as l2l
from sklearn.metrics import f1_score
from datetime import datetime
from collections import defaultdict
import os
import wilds
from src.iwildcam.iwildcam_utils import *
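
# Minimal sketch of the inner-loop pattern used throughout this file (the toy
# module and sizes are illustrative, not part of the original code): clone the
# meta-learned weights, take one adaptation step on the support loss, then run
# the query pass through the adapted clone.
def _inner_loop_sketch():
    net = l2l.algorithms.MAML(nn.Linear(4, 4), lr=1e-2)
    task_model = net.clone()                    # per-task fast weights
    x_sup, x_que = torch.randn(8, 4), torch.randn(8, 4)
    task_model.adapt(F.mse_loss(task_model(x_sup), torch.zeros(8, 4)))
    return task_model(x_que)                    # query pass through adapted weights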

def train_epoch(selector, selector_name, models_list, student, student_name,
                train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster,
                device, acc_best=0, tlr=1e-4, slr=1e-4, ilr=1e-3,
                batch_size=256, sup_size=24, test_way='id', save=False,
                root_dir='data'):
    for model in models_list:
        model.eval()

    student_ce = nn.CrossEntropyLoss()

    features = student.features
    head = student.classifier
    features = l2l.algorithms.MAML(features, lr=ilr)
    features.to(device)
    head.to(device)

    all_params = list(features.parameters()) + list(head.parameters())
    optimizer_s = optim.Adam(all_params, lr=slr)
    optimizer_t = optim.Adam(selector.parameters(), lr=tlr)

    i = 0
    losses = []
    iter_per_epoch = len(train_loader)

    for x, y_true, metadata in train_loader:
        selector.eval()
        head.eval()
        features.eval()

        # Each batch comes from a single domain; `mask` maps every sample's
        # domain to its expert index so that expert can be excluded.
        z = grouper.metadata_to_group(metadata)
        z = set(z.tolist())
        assert len(z) == 1
        mask = mask_grouper.metadata_to_group(metadata)
        mask.apply_(lambda d: split_to_cluster[d])

        # Split the batch into a support set (for adaptation) and a query set.
        x_sup = x[:sup_size]
        y_sup = y_true[:sup_size]
        x_que = x[sup_size:]
        y_que = y_true[sup_size:]
        mask = mask[:sup_size]

        x_sup = x_sup.to(device)
        y_sup = y_sup.to(device)
        x_que = x_que.to(device)
        y_que = y_que.to(device)

        # Aggregate the masked expert outputs into the teacher feature.
        with torch.no_grad():
            logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1)
            logits = mask_feat(logits, mask, len(models_list), exclude=True)

        t_out = selector.get_feat(logits)

        # Query loss before adaptation, logged for comparison.
        task_model = features.clone()
        task_model.module.eval()
        feat = task_model(x_que)
        feat = feat.view(feat.shape[0], -1)
        out = head(feat)
        with torch.no_grad():
            loss_pre = student_ce(out, y_que).item() / x_que.shape[0]

        # Inner loop: one step of feature distillation on the support set.
        feat = task_model(x_sup)
        feat = feat.view_as(t_out)
        inner_loss = l2_loss(feat, t_out)
        task_model.adapt(inner_loss)

        # Outer loop: classification loss of the adapted student on the query set.
        x_que = task_model(x_que)
        x_que = x_que.view(x_que.shape[0], -1)
        s_que_out = head(x_que)
        s_que_loss = student_ce(s_que_out, y_que)

        s_que_loss.backward()

        optimizer_s.step()
        optimizer_t.step()
        optimizer_s.zero_grad()
        optimizer_t.zero_grad()

        with open('log/iwildcam/' + curr + '.txt', 'a+') as f:
            f.write('Iter: {}/{}, Loss Before:{:.4f}, Loss After:{:.4f}\r\n'.format(
                i, iter_per_epoch, loss_pre, s_que_loss.item() / x_que.shape[0]))
        losses.append(s_que_loss.item() / x_que.shape[0])

        # Mid-epoch validation; the best checkpoint is selected by macro F1.
        if i == iter_per_epoch // 2:
            losses = np.mean(losses)
            _, acc, f1 = eval(selector, models_list, student, sup_size, device=device,
                              ilr=ilr, test=False, progress=False, uniform_over_groups=False,
                              root_dir=root_dir)

            with open('log/iwildcam/' + curr + '.txt', 'a+') as f:
                f.write(f'Accuracy: {acc} || F1:{f1} \r\n')
            losses = []

            if f1 > acc_best and save:
                with open('log/iwildcam/' + curr + '.txt', 'a+') as f:
                    f.write("Saving model ...\r\n")
                save_model(selector, selector_name + "_selector", 0, test_way=test_way)
                save_model(student, student_name + "_student", 0, test_way=test_way)
                acc_best = f1

        i += 1
    return acc_best
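
# Usage sketch (illustrative; the hyperparameter values are the iwildcam
# defaults from src/configs.py, the variable names are placeholders):
#
#   train_kd(selector, 'agg', models_list, student, 'student', split_to_cluster,
#            device, batch_size=40, sup_size=24, tlr=1e-6, slr=3e-5, ilr=3e-4,
#            num_epochs=20, test_way='ood', save=True)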
f.write("Saving model ...\r\n") 125 | save_model(selector, selector_name+"_selector", 0, test_way=test_way) 126 | save_model(student, student_name+"_student", 0, test_way=test_way) 127 | acc_best = f1 128 | 129 | i += 1 130 | return acc_best 131 | 132 | def train_kd(selector, selector_name, models_list, student, student_name, split_to_cluster, device, 133 | batch_size=256, sup_size=24, tlr=1e-4, slr=1e-4, ilr=1e-5, num_epochs=30, 134 | decayRate=0.96, save=False, test_way='ood', root_dir='data'): 135 | 136 | train_loader, _, _, grouper = get_data_loader(root_dir=root_dir, batch_size=batch_size, domain=None, 137 | test_way=test_way, n_groups_per_batch=1) 138 | 139 | mask_grouper = get_mask_grouper(root_dir=root_dir) 140 | 141 | curr = str(datetime.now()) 142 | if not os.path.exists("log/iwildcam"): 143 | os.makedirs("log/iwildcam") 144 | #print(curr) 145 | print("Training log saved to log/iwildcam/"+curr+".txt") 146 | 147 | with open('log/iwildcam/'+curr+'.txt', 'a+') as f: 148 | f.write(selector_name+' '+student_name+'\r\n') 149 | f.write(f'tlr={tlr} slr={slr} ilr={ilr}\r\n') 150 | 151 | accu_best = 0 152 | for epoch in range(num_epochs): 153 | with open('log/iwildcam/'+curr+'.txt', 'a+') as f: 154 | f.write(f'Epoch: {epoch}\r\n') 155 | accu_epoch = train_epoch(selector, selector_name, models_list, student, student_name, 156 | train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster, 157 | device, acc_best=accu_best, tlr=tlr, slr=slr, ilr=ilr, 158 | batch_size=batch_size, sup_size=sup_size, test_way=test_way, save=save, 159 | root_dir=root_dir) 160 | accu_best = max(accu_best, accu_epoch) 161 | _, accu, f1 = eval(selector, models_list, student, sup_size, device=device, 162 | ilr=ilr, test=False, progress=False, uniform_over_groups=False, 163 | root_dir=root_dir) 164 | 165 | with open('log/iwildcam/'+curr+'.txt', 'a+') as f: 166 | f.write(f'Accuracy: {accu} || F1:{f1} \r\n') 167 | 168 | if f1 > accu_best and save: 169 | with open('log/iwildcam/'+curr+'.txt', 'a+') as f: 170 | f.write("Saving model ...\r\n") 171 | save_model(selector, selector_name+"_selector", 0, test_way=test_way) 172 | save_model(student, student_name+"_student", 0, test_way=test_way) 173 | accu_best = f1 174 | tlr = tlr*decayRate 175 | slr = slr*decayRate 176 | 177 | def eval(selector, models_list, student, batch_size, device, ilr=1e-5, 178 | test=False, progress=True, uniform_over_groups=False, root_dir='data'): 179 | 180 | if test: 181 | loader, grouper = get_test_loader(batch_size=batch_size, split='test', root_dir=root_dir) 182 | else: 183 | loader, grouper = get_test_loader(batch_size=batch_size, split='val', root_dir=root_dir) 184 | '''if test: 185 | loader, que_set, grouper = get_test_loader(batch_size=batch_size, test_way='test') 186 | else: 187 | loader, que_set, grouper = get_test_loader(batch_size=batch_size, test_way='val')''' 188 | 189 | 190 | features = student.features 191 | head = student.classifier 192 | head.to(device) 193 | 194 | student_maml = l2l.algorithms.MAML(features, lr=ilr) 195 | student_maml.to(device) 196 | 197 | correct = defaultdict(int) 198 | total = defaultdict(int) 199 | nor_correct = 0 200 | nor_total = 0 201 | old_domain = {} 202 | if progress: 203 | loader = tqdm(iter(loader), total=len(loader)) 204 | 205 | for domain, domain_loader in loader: 206 | adapted = False 207 | for x_sup, y_sup, metadata in domain_loader: 208 | student_maml.module.eval() 209 | selector.eval() 210 | head.eval() 211 | 212 | z = grouper.metadata_to_group(metadata) 213 | z = set(z.tolist()) 214 | 
    for domain, domain_loader in loader:
        adapted = False
        for x_sup, y_sup, metadata in domain_loader:
            student_maml.module.eval()
            selector.eval()
            head.eval()

            z = grouper.metadata_to_group(metadata)
            z = set(z.tolist())
            assert list(z)[0] == domain

            x_sup = x_sup.to(device)
            y_sup = y_sup.to(device)

            if not adapted:
                task_model = student_maml.clone()
                task_model.eval()
                with torch.no_grad():
                    logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1)
                    logits = logits.permute((0, 2, 1))
                    t_out = selector.get_feat(logits)

                feat = task_model(x_sup)
                feat = feat.view_as(t_out)

                # Feature-space distillation loss (L2, despite the original
                # variable name suggesting KL).
                inner_loss = l2_loss(feat, t_out)
                torch.cuda.empty_cache()
                task_model.adapt(inner_loss)
                adapted = True

            with torch.no_grad():
                task_model.module.eval()
                x_sup = task_model(x_sup)
                x_sup = x_sup.view(x_sup.shape[0], -1)
                s_que_out = head(x_sup)
                pred = s_que_out.max(1, keepdim=True)[1]
                c = pred.eq(y_sup.view_as(pred)).sum().item()
                correct[list(z)[0]] += c
                total[list(z)[0]] += x_sup.shape[0]
                nor_correct += c
                nor_total += x_sup.shape[0]
                try:
                    pred_all = torch.cat((pred_all, pred.view_as(y_sup)))
                    y_all = torch.cat((y_all, y_sup))
                except NameError:
                    pred_all = pred.view_as(y_sup)
                    y_all = y_sup

    # Report per-domain mean accuracy, overall accuracy, and macro F1.
    y_all = y_all.detach().cpu()
    pred_all = pred_all.detach().cpu()
    f1 = f1_score(y_all, pred_all, average='macro', labels=torch.unique(y_all))
    mean_acc = []
    for key, value in correct.items():
        mean_acc.append(value / total[key])
    return np.mean(mean_acc).item(), nor_correct / nor_total, f1
--------------------------------------------------------------------------------