├── src
│   ├── __init__.py
│   ├── fmow
│   │   ├── __init__.py
│   │   ├── fmow_aggregator.py
│   │   ├── fmow_experts.py
│   │   ├── fmow_train.py
│   │   └── fmow_utils.py
│   ├── rxrx1
│   │   ├── __init__.py
│   │   ├── rxrx1_aggregator.py
│   │   ├── rxrx1_experts.py
│   │   ├── rxrx1_utils.py
│   │   └── rxrx1_train.py
│   ├── camelyon
│   │   ├── __init__.py
│   │   ├── camelyon_aggregator.py
│   │   ├── camelyon_experts.py
│   │   ├── camelyon_train.py
│   │   └── camelyon_utils.py
│   ├── iwildcam
│   │   ├── __init__.py
│   │   ├── iwildcam_aggregator.py
│   │   ├── iwildcam_experts.py
│   │   ├── iwildcam_utils.py
│   │   └── iwildcam_train.py
│   ├── transformer.py
│   └── configs.py
├── assets
│   ├── table1.png
│   └── overview.png
├── requirements.txt
├── LICENSE
├── train_single_expert.py
├── README.md
├── run.py
└── environment.yml
/src/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/fmow/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/rxrx1/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/camelyon/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/src/iwildcam/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/assets/table1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n3il666/Meta-DMoE/HEAD/assets/table1.png
--------------------------------------------------------------------------------
/assets/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/n3il666/Meta-DMoE/HEAD/assets/overview.png
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | --find-links https://download.pytorch.org/whl/torch_stable.html
2 | torch==1.7.1
3 | torchvision==0.8.2
4 |
5 | numpy==1.21.3
6 | tqdm==4.62.3
7 | transformers==4.18.0
8 | learn2learn==0.1.6
9 | wilds==1.2.2
10 | pickle5==0.0.12
11 | scikit-learn==0.24.2
12 | scikit-image==0.18.3
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Tao Zhong
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/src/transformer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 | class PreNorm(nn.Module):
5 | def __init__(self, dim, fn):
6 | super().__init__()
7 | self.norm = nn.LayerNorm(dim)
8 | self.fn = fn
9 | def forward(self, x, **kwargs):
10 | return self.fn(self.norm(x), **kwargs)
11 |
12 | class FeedForward(nn.Module):
13 | def __init__(self, dim, hidden_dim, dropout = 0.):
14 | super().__init__()
15 | self.net = nn.Sequential(
16 | nn.Linear(dim, hidden_dim),
17 | nn.GELU(),
18 | nn.Dropout(dropout),
19 | nn.Linear(hidden_dim, dim),
20 | nn.Dropout(dropout)
21 | )
22 | def forward(self, x):
23 | return self.net(x)
24 |
25 | class Attention(nn.Module):
26 | def __init__(self, dim, heads = 8, dropout = 0.):
27 | super().__init__()
28 | self.attend = nn.MultiheadAttention(dim, heads, dropout=dropout)
29 |
30 |     def forward(self, x):
31 |         # (batch, seq, dim) -> (seq, batch, dim): torch 1.7's nn.MultiheadAttention
32 |         # has no batch_first flag, so self-attention needs this permute round-trip.
33 |         q = k = v = x.permute((1,0,2))
34 |         out, _ = self.attend(q, k, v)
35 |         out = out.permute((1,0,2))  # back to (batch, seq, dim)
36 |         return out
37 |
38 | class Transformer(nn.Module):
39 | def __init__(self, dim, depth, heads, mlp_dim, dropout = 0.):
40 | super().__init__()
41 | self.layers = nn.ModuleList([])
42 | for _ in range(depth):
43 | self.layers.append(nn.ModuleList([
44 | PreNorm(dim, Attention(dim, heads, dropout=dropout)),
45 | PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))
46 | ]))
47 | def forward(self, x):
48 | for attn, ff in self.layers:
49 | x = attn(x) + x
50 | x = ff(x) + x
51 | return x
52 |
--------------------------------------------------------------------------------
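
A minimal usage sketch of the `Transformer` above (hypothetical sizes; as used by the aggregators, `dim` is the expert feature dimension and the sequence axis indexes experts):

```python
import torch
from src.transformer import Transformer

# Hypothetical sizes: batch of 8, 5 tokens, 2048-dim features.
# dim must be divisible by heads for nn.MultiheadAttention.
model = Transformer(dim=2048, depth=1, heads=16, mlp_dim=4096, dropout=0.1)
x = torch.randn(8, 5, 2048)   # (batch, seq, dim)
out = model(x)                # residual attention + MLP; same shape (8, 5, 2048)
```
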
/src/configs.py:
--------------------------------------------------------------------------------
1 | default_param = {
2 | 'iwildcam':{
3 | 'num_experts': 10,
4 | 'expert_batch_size': 16,
5 | 'expert_lr': 3e-5,
6 | 'expert_l2': 0,
7 | 'expert_epoch': 12,
8 | 'aggregator_depth': 1,
9 | 'aggregator_heads': 16,
10 | 'aggregator_dropout': 0.1,
11 | 'aggregator_pretrain_epoch': 12,
12 | 'student_pretrain_epoch': 12,
13 | 'batch_size': 40,
14 | 'sup_size': 24,
15 | 'tlr': 1e-6,
16 | 'slr': 3e-5,
17 | 'ilr': 3e-4,
18 | 'epoch': 20,
19 | 'seed': 1
20 | },
21 | 'camelyon':{
22 | 'num_experts': 5,
23 | 'expert_batch_size': 32,
24 | 'expert_lr': 1e-3,
25 | 'expert_l2': 1e-2,
26 | 'expert_epoch': 5,
27 | 'aggregator_depth': 1,
28 | 'aggregator_heads': 16,
29 | 'aggregator_dropout': 0.1,
30 | 'aggregator_pretrain_epoch': 5,
31 | 'student_pretrain_epoch': 5,
32 | 'batch_size': 96,
33 | 'sup_size': 64,
34 | 'tlr': 1e-6,
35 | 'slr': 1e-4,
36 | 'ilr': 1e-3,
37 | 'epoch': 10,
38 | 'seed': 3407
39 | },
40 | 'rxrx1':{
41 | 'num_experts': 3,
42 | 'expert_batch_size': 75,
43 | 'expert_lr': 1e-4,
44 | 'expert_l2': 1e-5,
45 | 'expert_epoch': 90,
46 | 'aggregator_depth': 1,
47 | 'aggregator_heads': 16,
48 | 'aggregator_dropout': 0.1,
49 | 'aggregator_pretrain_epoch': 90,
50 | 'student_pretrain_epoch': 90,
51 | 'batch_size': 123,
52 | 'sup_size': 75,
53 | 'tlr': 3e-6,
54 | 'slr': 1e-6,
55 | 'ilr': 1e-4,
56 | 'epoch': 10,
57 | 'seed': 1000
58 | },
59 | 'fmow':{
60 | 'num_experts': 4,
61 | 'expert_batch_size': 64,
62 | 'expert_lr': 1e-4,
63 | 'expert_l2': 0,
64 | 'expert_epoch': 30,
65 | 'aggregator_depth': 1,
66 | 'aggregator_heads': 16,
67 | 'aggregator_dropout': 0.1,
68 | 'aggregator_pretrain_epoch': 50,
69 | 'student_pretrain_epoch': 20,
70 | 'batch_size': 112,
71 | 'sup_size': 64,
72 | 'tlr': 1e-6,
73 | 'slr': 1e-5,
74 | 'ilr': 1e-4,
75 | 'epoch': 30,
76 | 'seed': 42
77 | }
78 | }
--------------------------------------------------------------------------------
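
`run.py` and `train_single_expert.py` overlay these per-dataset defaults onto the parsed CLI arguments; the merge in their `get_parser` functions reduces to:

```python
import argparse
from src.configs import default_param

args = argparse.Namespace(dataset='iwildcam', gpu='0')
args_dict = args.__dict__
args_dict.update(default_param[args.dataset])  # defaults overwrite/extend the namespace
args = argparse.Namespace(**args_dict)
print(args.num_experts, args.expert_lr)        # 10 3e-05
```
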
/train_single_expert.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import random
5 | import argparse
6 | import pickle
7 | from src.configs import default_param
8 |
9 | def get_parser():
10 | parser = argparse.ArgumentParser()
11 |
12 | parser.add_argument('--gpu', type=str, default='0') # We only support single gpu training for now
13 | parser.add_argument('--threads', type=int, default=12)
14 |
15 | parser.add_argument('--dataset', type=str, default='iwildcam',
16 | choices=['iwildcam', 'fmow', 'camelyon', 'rxrx1', 'poverty'])
17 | parser.add_argument('--data_dir', type=str, default='data')
18 |
19 | parser.add_argument('--expert_idx', type=int)
20 |
21 | args = parser.parse_args()
22 | args_dict = args.__dict__
23 | args_dict.update(default_param[args.dataset])
24 | args = argparse.Namespace(**args_dict)
25 | return args
26 |
27 | def set_seed(seed):
28 | if torch.cuda.is_available():
29 | torch.cuda.manual_seed(seed)
30 | torch.manual_seed(seed)
31 | np.random.seed(seed)
32 | random.seed(seed)
33 | torch.backends.cudnn.deterministic = True
34 |
35 | def train(args):
36 |
37 | if args.dataset == 'iwildcam':
38 | from src.iwildcam.iwildcam_utils import get_models_list
39 | from src.iwildcam.iwildcam_experts import train_model, get_expert_split
40 | elif args.dataset == 'camelyon':
41 | from src.camelyon.camelyon_utils import get_models_list
42 | from src.camelyon.camelyon_experts import train_model, get_expert_split
43 | elif args.dataset == 'rxrx1':
44 | from src.rxrx1.rxrx1_utils import get_models_list
45 | from src.rxrx1.rxrx1_experts import train_model, get_expert_split
46 | elif args.dataset == 'fmow':
47 | from src.fmow.fmow_utils import get_models_list
48 | from src.fmow.fmow_experts import train_model, get_expert_split
49 | else:
50 | raise NotImplementedError
51 |
52 | name = f"{args.dataset}_{str(args.num_experts)}experts_seed{str(args.seed)}"
53 |
54 | models_list = get_models_list(device=device, num_domains=0)
55 |
56 | try:
57 | with open(f"model/{args.dataset}/domain_split.pkl", "rb") as f:
58 | all_split, split_to_cluster = pickle.load(f)
59 | except FileNotFoundError:
60 | all_split, split_to_cluster = get_expert_split(args.num_experts, root_dir=args.data_dir)
61 | with open(f"model/{args.dataset}/domain_split.pkl", "wb") as f:
62 | pickle.dump((all_split, split_to_cluster), f)
63 |
64 | print(f"Training model {args.expert_idx} for domain ", *all_split[args.expert_idx])
65 | train_model(models_list[0], name+'_'+str(args.expert_idx), device=device,
66 | domain=all_split[args.expert_idx], batch_size=args.expert_batch_size,
67 | lr=args.expert_lr, l2=args.expert_l2, num_epochs=args.expert_epoch,
68 | save=True, root_dir=args.data_dir)
69 |
70 | if __name__ == "__main__":
71 | args = get_parser()
72 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
73 | torch.set_num_threads(args.threads)
74 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
75 | set_seed(args.seed)
76 | train(args)
77 |
--------------------------------------------------------------------------------
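
Note that, unlike `run.py`, this script never creates the `model/<dataset>` directory, so writing `domain_split.pkl` fails on a fresh checkout. A minimal guard, assuming it is added in `train` before the `try` block:

```python
import os

# Ensure the checkpoint directory exists before pickling the domain split.
os.makedirs(f"model/{args.dataset}", exist_ok=True)
```
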
/README.md:
--------------------------------------------------------------------------------
1 | # Adapting to Domain Shift by Meta-Distillation from Mixture-of-Experts
2 |
3 | This repository is the official implementation of [Meta-DMoE: Adapting to Domain Shift by Meta-Distillation from Mixture-of-Experts](https://arxiv.org/abs/2210.03885).
4 |
5 |
6 |
6 | ![overview](assets/overview.png)
7 |
8 |
9 | ## Requirements
10 |
11 | The code was tested with Python 3.7 and CUDA 10.1.
12 |
13 | We recommend using a conda environment to set up all required dependencies:
14 |
15 | ```setup
16 | conda env create -f environment.yml
17 | conda activate dmoe
18 | ```
19 |
20 | If you run into problems with the above commands, you can instead install the dependencies with `pip install -r requirements.txt`.
21 |
22 | Either of these commands will automatically install all required dependencies **except for the `torch-scatter` and `torch-geometric` packages**, which require a [quick manual install](https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html#installation-via-binaries).
23 |
24 | ## Training
25 |
26 | We provide training scripts for the following four datasets from the WILDS benchmark: `iwildcam`, `camelyon`, `rxrx1`, and `fmow`. To train the models in the paper, run:
27 |
28 | ```Training
29 | python run.py --dataset <dataset>
30 | ```
31 |
32 | The data will be downloaded automatically to the `data` folder (override the location with `--data_dir`).
33 |
34 | ### Distributed Training for Expert Models
35 |
36 | Although meta-training does not support multiple GPUs at this point, you can still train the expert models in a distributed manner by opening multiple terminals and running:
37 |
38 | ```Train Experts
39 | python train_single_expert.py --dataset <dataset> --data_dir <data_dir> --gpu <gpu> --expert_idx <expert_idx>
40 | ```
41 |
42 | Once all experts are trained, add the `--load_trained_experts` flag when running `run.py`.
43 |
44 | ## Evaluation
45 |
46 | To evaluate trained models, run:
47 |
48 | ```eval
49 | python run.py --dataset <dataset> --data_dir <data_dir> --test
50 | ```
51 |
52 | ## Pre-trained Models
53 |
54 | To reproduce the results reported in Table 1 of our paper, download the pretrained models from the links below and extract them to the `model/` folder. Note that due to cloud-storage size limits, we only uploaded checkpoints for one random seed per dataset, while the results reported in the table are aggregated across several random seeds.
55 |
56 | - [iwildcam](https://drive.google.com/drive/folders/1mi-xInK5jXplmE4jo8oLSgHAYhsBe5sX?usp=sharing)
57 | - [camelyon](https://drive.google.com/drive/folders/1Wbuzv0DMxtfYjhF51KgdQQGV2_PRNEDw?usp=sharing)
58 | - [rxrx1](https://drive.google.com/drive/folders/1gsSFezrWbgKrU-nChC777DE3HFRnDkRx?usp=sharing)
59 | - [fmow](https://drive.google.com/drive/folders/1PAgzr7e2kcTQaYRq8gZnYtrQGq3PSxCh?usp=sharing)
60 |
61 |
62 | ![table1](assets/table1.png)
63 |
64 |
65 | ## Citation
66 | If you find this codebase useful in your research, consider citing:
67 | ```
68 | @inproceedings{
69 | zhong2022metadmoe,
70 | title={Meta-{DM}oE: Adapting to Domain Shift by Meta-Distillation from Mixture-of-Experts},
71 | author={Tao Zhong and Zhixiang Chi and Li Gu and Yang Wang and Yuanhao Yu and Jin Tang},
72 | booktitle={Thirty-Sixth Conference on Neural Information Processing Systems (NeurIPS)},
73 | year={2022}
74 | }
75 | ```
76 |
--------------------------------------------------------------------------------
/src/rxrx1/rxrx1_aggregator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torchvision
7 | from tqdm import tqdm
8 | import os
9 | from src.rxrx1.rxrx1_utils import *
10 |
11 | def get_selector_accuracy(selector, models_list, data_loader, grouper, device, progress=True):
12 | selector.eval()
13 | correct = 0
14 | total = 0
15 | #mean_correct = 0
16 | if progress:
17 | data_loader = tqdm(data_loader)
18 | for x, y_true, metadata in data_loader:
19 | #z = grouper.metadata_to_group(metadata)
20 | #z = set(z.tolist())
21 | #assert z.issubset(set(meta_indices))
22 |
23 | x = x.to(device)
24 | y_true = y_true.to(device)
25 |
26 | with torch.no_grad():
27 | features = torch.stack([model(x).detach() for model in models_list], dim=-1)
28 | features = features.permute((0,2,1))
29 | out = selector(features)
30 |
31 | pred = out.max(1, keepdim=True)[1]
32 | correct += pred.eq(y_true.view_as(pred)).sum().item()
33 | total += x.shape[0]
34 |
35 | return correct/total
36 |
37 | def train_model_selector(selector, model_name, models_list, device, root_dir='data',
38 | batch_size=75, lr=1e-5, l2=0,
39 | num_epochs=90, decayRate=0.98, save=True, test_way='ood'):
40 | for model in models_list:
41 | model.eval()
42 |
43 | train_loader, val_loader, _, grouper = get_data_loader(root_dir=root_dir,
44 | batch_size=batch_size, domain=None,
45 | test_way=test_way, n_groups_per_batch=0)
46 |
47 | criterion = nn.CrossEntropyLoss()
48 | optimizer = optim.Adam(selector.parameters(), lr=lr, weight_decay=l2)
49 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)
50 |
51 | i = 0
52 |
53 | losses = []
54 | acc_best = 0
55 |
56 | tot = len(train_loader)
57 |
58 | for epoch in range(num_epochs):
59 |
60 | print(f"Epoch:{epoch}|| Total:{tot}")
61 |
62 | for x, y_true, metadata in train_loader:
63 | selector.train()
64 |
65 | z = grouper.metadata_to_group(metadata)
66 | z = set(z.tolist())
67 | #assert z.issubset(set(meta_indices))
68 |
69 | x = x.to(device)
70 | y_true = y_true.to(device)
71 |
72 | with torch.no_grad():
73 | features = torch.stack([model(x).detach() for model in models_list], dim=-1)
74 | features = features.permute((0,2,1))
75 | out = selector(features)
76 |
77 | loss = criterion(out, y_true)
78 | loss.backward()
79 | optimizer.step()
80 | optimizer.zero_grad()
81 | losses.append(loss.item()/batch_size)
82 |
83 | if i % (tot//2) == 0 and i != 0:
84 | losses = np.mean(losses)
85 | acc = get_selector_accuracy(selector, models_list, val_loader,
86 | grouper, device, progress=False)
87 |
88 | print("Iter: {}/{} || Loss: {:.4f} || Acc:{:.4f}".format(i, tot, losses, acc))
89 | losses = []
90 |
91 | if acc > acc_best and save:
92 | print("Saving model ...")
93 | save_model(selector, model_name+"_selector", 0, test_way=test_way)
94 | acc_best = acc
95 |
96 | i += 1
97 | scheduler.step()
98 |
--------------------------------------------------------------------------------
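
The feature-stacking pattern above is shared by all four aggregators: each frozen expert maps a batch to a `(B, D)` feature vector, stacking along a new trailing axis gives `(B, D, E)`, and the permute turns experts into a token axis for the transformer-based selector. In isolation, with hypothetical sizes:

```python
import torch

B, D, E = 4, 2048, 3                        # batch, feature dim, num experts
expert_feats = [torch.randn(B, D) for _ in range(E)]
tokens = torch.stack(expert_feats, dim=-1)  # (B, D, E)
tokens = tokens.permute((0, 2, 1))          # (B, E, D): one token per expert
print(tokens.shape)                         # torch.Size([4, 3, 2048])
```
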
/src/camelyon/camelyon_aggregator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torchvision
7 | from tqdm import tqdm
8 | import os
9 | from src.camelyon.camelyon_utils import *
10 |
11 | def get_selector_accuracy(selector, models_list, data_loader, grouper, device, progress=True):
12 | selector.eval()
13 | correct = 0
14 | total = 0
15 | #mean_correct = 0
16 | if progress:
17 | data_loader = tqdm(data_loader)
18 | for x, y_true, metadata in data_loader:
19 | #z = grouper.metadata_to_group(metadata)
20 | #z = set(z.tolist())
21 | #assert z.issubset(set(meta_indices))
22 |
23 | x = x.to(device)
24 | y_true = y_true.to(device)
25 |
26 | with torch.no_grad():
27 | features = torch.stack([model(x).detach() for model in models_list], dim=-1)
28 | features = features.permute((0,2,1))
29 | out = selector(features)
30 |
31 | pred = (out > 0.0).squeeze().long()
32 | correct += pred.eq(y_true.view_as(pred)).sum().item()
33 | total += x.shape[0]
34 |
35 | return correct/total
36 |
37 | def train_model_selector(selector, model_name, models_list, device, root_dir='data',
38 | batch_size=32, lr=3e-6, l2=0,
39 | num_epochs=5, decayRate=0.96, save=True, test_way='ood'):
40 | for model in models_list:
41 | model.eval()
42 |
43 | train_loader, val_loader, _, grouper = get_data_loader(root_dir=root_dir,
44 | batch_size=batch_size, domain=None,
45 | test_way=test_way, n_groups_per_batch=0,
46 | return_dataset=False)
47 |
48 | criterion = nn.BCEWithLogitsLoss()
49 | optimizer = optim.Adam(selector.parameters(), lr=lr, weight_decay=l2)
50 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)
51 |
52 | i = 0
53 |
54 | losses = []
55 | acc_best = 0
56 |
57 | tot = len(train_loader)
58 |
59 | for epoch in range(num_epochs):
60 |
61 | print(f"Epoch:{epoch}|| Total:{tot}")
62 |
63 | for x, y_true, metadata in train_loader:
64 | selector.train()
65 |
66 | z = grouper.metadata_to_group(metadata)
67 | z = set(z.tolist())
68 | #assert z.issubset(set(meta_indices))
69 |
70 | x = x.to(device)
71 | y_true = y_true.to(device)
72 |
73 | with torch.no_grad():
74 | features = torch.stack([model(x).detach() for model in models_list], dim=-1)
75 | features = features.permute((0,2,1))
76 | out = selector(features)
77 |
78 | loss = criterion(out, y_true.unsqueeze(-1).float())
79 | loss.backward()
80 | optimizer.step()
81 | optimizer.zero_grad()
82 | losses.append(loss.item()/batch_size)
83 |
84 | if i % (tot//2) == 0 and i != 0:
85 | losses = np.mean(losses)
86 | acc = get_selector_accuracy(selector, models_list, val_loader,
87 | grouper, device, progress=False)
88 |
89 | print("Iter: {} || Loss: {:.4f} || Acc:{:.4f}".format(i, losses, acc))
90 | losses = []
91 |
92 | if acc > acc_best and save:
93 | print("Saving model ...")
94 | save_model(selector, model_name+"_selector", 0, test_way=test_way)
95 | acc_best = acc
96 |
97 |             i += 1
98 |         scheduler.step()
99 |
--------------------------------------------------------------------------------
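
Camelyon17 is a binary task trained with `BCEWithLogitsLoss`, so `get_selector_accuracy` thresholds the raw logit at 0, which is the same decision as thresholding the sigmoid probability at 0.5:

```python
import torch

logits = torch.tensor([[-1.2], [0.3], [2.0]])
pred = (logits > 0.0).squeeze().long()  # tensor([0, 1, 1])
# Logit > 0 is equivalent to sigmoid(logit) > 0.5.
assert torch.equal(pred, (torch.sigmoid(logits).squeeze() > 0.5).long())
```
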
/src/iwildcam/iwildcam_aggregator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torchvision
7 | from tqdm import tqdm
8 | from sklearn.metrics import f1_score
9 | import os
10 | from src.iwildcam.iwildcam_utils import *
11 |
12 | def get_selector_accuracy(selector, models_list, data_loader, grouper, device, progress=True):
13 | selector.eval()
14 | correct = 0
15 | total = 0
16 | #mean_correct = 0
17 | if progress:
18 | data_loader = tqdm(data_loader)
19 | for x, y_true, metadata in data_loader:
20 | #z = grouper.metadata_to_group(metadata)
21 | #z = set(z.tolist())
22 | #assert z.issubset(set(meta_indices))
23 |
24 | x = x.to(device)
25 | y_true = y_true.to(device)
26 |
27 | with torch.no_grad():
28 | features = torch.stack([model(x).detach() for model in models_list], dim=-1)
29 | features = features.permute((0,2,1))
30 | out = selector(features)
31 |
32 | pred = out.max(1, keepdim=True)[1]
33 | correct += pred.eq(y_true.view_as(pred)).sum().item()
34 | total += x.shape[0]
35 | try:
36 | pred_all = torch.cat((pred_all, pred.view_as(y_true)))
37 | y_all = torch.cat((y_all, y_true))
38 | except NameError:
39 | pred_all = pred.view_as(y_true)
40 | y_all = y_true
41 |
42 | y_all = y_all.detach().cpu()
43 | pred_all = pred_all.detach().cpu()
44 | f1 = f1_score(y_all,pred_all,average='macro', labels=torch.unique(y_all))
45 |
46 | return correct/total, f1
47 |
48 | def train_model_selector(selector, model_name, models_list, device, root_dir='data',
49 | batch_size=32, lr=1e-6, l2=0,
50 | num_epochs=12, decayRate=0.96, save=True, test_way='ood'):
51 | for model in models_list:
52 | model.eval()
53 |
54 | train_loader, val_loader, _, grouper = get_data_loader(root_dir=root_dir,
55 | batch_size=batch_size, domain=None,
56 | test_way=test_way, n_groups_per_batch=0)
57 |
58 | criterion = nn.CrossEntropyLoss()
59 | optimizer = optim.Adam(selector.parameters(), lr=lr, weight_decay=l2)
60 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)
61 |
62 | i = 0
63 |
64 | losses = []
65 | acc_best = 0
66 |
67 | tot = len(train_loader)
68 |
69 | for epoch in range(num_epochs):
70 |
71 | print(f"Epoch:{epoch}|| Total:{tot}")
72 |
73 | for x, y_true, metadata in train_loader:
74 | selector.train()
75 |
76 | z = grouper.metadata_to_group(metadata)
77 | z = set(z.tolist())
78 | #assert z.issubset(set(meta_indices))
79 |
80 | x = x.to(device)
81 | y_true = y_true.to(device)
82 |
83 | with torch.no_grad():
84 | features = torch.stack([model(x).detach() for model in models_list], dim=-1)
85 | features = features.permute((0,2,1))
86 | out = selector(features)
87 |
88 | loss = criterion(out, y_true)
89 | loss.backward()
90 | optimizer.step()
91 | optimizer.zero_grad()
92 | losses.append(loss.item()/batch_size)
93 |
94 | if i % (tot//2) == 0 and i != 0:
95 | losses = np.mean(losses)
96 | acc, f1 = get_selector_accuracy(selector, models_list, val_loader,
97 | grouper, device, progress=False)
98 |
99 | print("Iter: {} || Loss: {:.4f} || Acc:{:.4f} || F1:{:.4f}".format(i, losses, acc, f1))
100 | losses = []
101 |
102 | if f1 > acc_best and save:
103 | print("Saving model ...")
104 | save_model(selector, model_name+"_selector", 0, test_way=test_way)
105 | acc_best = f1
106 |
107 |             i += 1
108 |         scheduler.step()
109 |
--------------------------------------------------------------------------------
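
Following the WILDS evaluation protocol, iWildCam checkpoints are selected by macro-F1 (computed over the labels present in the validation set) rather than accuracy; the metric call above reduces to:

```python
import torch
from sklearn.metrics import f1_score

y_all = torch.tensor([0, 1, 1, 2, 2])
pred_all = torch.tensor([0, 1, 0, 2, 2])
# Macro-F1 weights every class equally, regardless of class frequency.
f1 = f1_score(y_all, pred_all, average='macro', labels=torch.unique(y_all))
print(round(float(f1), 4))  # 0.7778
```
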
/src/rxrx1/rxrx1_experts.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torchvision
7 | import os
8 | from transformers import get_cosine_schedule_with_warmup
9 | from src.rxrx1.rxrx1_utils import *
10 |
11 | def get_model_accuracy(model, data_loader, grouper, device, domain=None):
12 | model.eval()
13 | correct = 0
14 | total = 0
15 | for x, y_true, metadata in iter(data_loader):
16 |
17 | z = grouper.metadata_to_group(metadata)
18 | z = set(z.tolist())
19 | if domain is not None:
20 | assert z.issubset(set(domain))
21 |
22 | x = x.to(device)
23 | y_true = y_true.to(device)
24 |
25 | out = model(x)
26 | pred = out.max(1, keepdim=True)[1]
27 | correct += pred.eq(y_true.view_as(pred)).sum().item()
28 | total += x.shape[0]
29 | return correct/total
30 |
31 | def train_model(model, model_name, device, domain=None, batch_size=75, lr=1e-4, l2=1e-5,
32 | num_epochs=90, decayRate=1., save=False, test_way='ood', root_dir='data'):
33 |
34 | train_loader, val_loader, test_loader, grouper = get_data_loader(root_dir=root_dir,
35 | batch_size=batch_size, domain=domain,
36 | test_way=test_way, n_groups_per_batch=0)
37 |
38 | criterion = nn.CrossEntropyLoss()
39 | optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=l2)
40 |
41 | i = 0
42 |
43 | losses = []
44 | acc_best = 0
45 |
46 | tot = len(train_loader)
47 | scheduler = get_cosine_schedule_with_warmup(
48 | optimizer,
49 | num_warmup_steps=tot * 10,
50 | num_training_steps=tot*num_epochs)
51 |
52 | for epoch in range(num_epochs):
53 |
54 | print(f"Epoch:{epoch} || Total:{tot}")
55 |
56 | for x, y_true, metadata in iter(train_loader):
57 | model.train()
58 |
59 | z = grouper.metadata_to_group(metadata)
60 | z = set(z.tolist())
61 | if domain is not None:
62 | assert z.issubset(set(domain))
63 |
64 | x = x.to(device)
65 | y_true = y_true.to(device)
66 |
67 | pred = model(x)
68 |
69 | loss = criterion(pred, y_true)
70 | loss.backward()
71 | optimizer.step()
72 | optimizer.zero_grad()
73 | losses.append(loss.item()/batch_size)
74 |
75 | if i % (tot//2) == 0 and i != 0:
76 | losses = np.mean(losses)
77 | acc = get_model_accuracy(model, val_loader, grouper, device=device)
78 |
79 | print("Iter: {} || Loss: {:.4f} || Acc:{:.4f}".format(i, losses, acc))
80 | losses = []
81 |
82 | if acc > acc_best and save:
83 | print("Saving model ...")
84 | save_model(model, model_name+"_exp", 0, test_way=test_way)
85 | acc_best = acc
86 |
87 | i += 1
88 | scheduler.step()
89 |
90 | def train_exp(models_list, domain_specific_indices, device, batch_size=75, lr=1e-4, l2=1e-5,
91 | num_epochs=90, decayRate=1., save=False, test_way='ood', name="Resnet50_experts",
92 | root_dir='data'):
93 |
94 | assert len(models_list) == len(domain_specific_indices)
95 | for i in range(len(models_list)):
96 | print(f"Training model {i} for domain", *domain_specific_indices[i])
97 | train_model(models_list[i], name+'_'+str(i), device=device,
98 | domain=domain_specific_indices[i], batch_size=batch_size,
99 | lr=lr, l2=l2, num_epochs=num_epochs,
100 | decayRate=decayRate, save=save, test_way=test_way,
101 | root_dir=root_dir)
102 |
103 | def get_expert_split(num_experts, root_dir='data'):
104 | all_split = split_domains(num_experts=num_experts, root_dir=root_dir)
105 | split_to_cluster = {d:i for i in range(len(all_split)) for d in all_split[i]}
106 | return all_split, split_to_cluster
107 |
--------------------------------------------------------------------------------
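
Unlike the other experts, rxrx1 uses linear warmup followed by cosine decay (`num_warmup_steps=tot * 10`, i.e. 10 epochs of warmup), stepped once per batch. A tiny sketch of the schedule shape, with hypothetical step counts:

```python
import torch
from transformers import get_cosine_schedule_with_warmup

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
sched = get_cosine_schedule_with_warmup(opt, num_warmup_steps=10,
                                        num_training_steps=100)
lrs = []
for _ in range(100):
    lrs.append(opt.param_groups[0]['lr'])
    opt.step()
    sched.step()
print(lrs[0], max(lrs), lrs[-1])  # 0.0, 1e-4 at step 10, ~0 at the end
```
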
/src/camelyon/camelyon_experts.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torchvision
7 | import os
8 | from src.camelyon.camelyon_utils import *
9 |
10 | def get_model_accuracy(model, data_loader, grouper, device, domain=None, progress=False):
11 | model.eval()
12 | correct = 0
13 | total = 0
14 | for x, y_true, metadata in iter(data_loader):
15 |
16 | z = grouper.metadata_to_group(metadata)
17 | z = set(z.tolist())
18 | if domain is not None:
19 | assert z.issubset(set(domain))
20 |
21 | x = x.to(device)
22 | y_true = y_true.to(device)
23 |
24 | out = model(x)
25 | pred = (out > 0.0).squeeze().long()
26 | correct += pred.eq(y_true.view_as(pred)).sum().item()
27 | total += x.shape[0]
28 |
29 | return correct/total
30 |
31 | def train_model(model, model_name, device, domain=None, batch_size=32, lr=1e-3, l2=1e-2,
32 | num_epochs=5, decayRate=1., save=False, test_way='ood', root_dir='data'):
33 |
34 | train_loader, val_loader, test_loader, grouper = get_data_loader(root_dir=root_dir,
35 | batch_size=batch_size, domain=domain,
36 | test_way=test_way, n_groups_per_batch=0,
37 | return_dataset=False)
38 |
39 | criterion = nn.BCEWithLogitsLoss()
40 | optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=l2)
41 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)
42 |
43 | i = 0
44 |
45 | losses = []
46 | acc_best = 0
47 |
48 | tot = len(train_loader)
49 |
50 | for epoch in range(num_epochs):
51 |
52 | print(f"Epoch:{epoch} || Total:{tot}")
53 |
54 | for x, y_true, metadata in iter(train_loader):
55 | model.train()
56 |
57 | z = grouper.metadata_to_group(metadata)
58 | z = set(z.tolist())
59 | if domain is not None:
60 | assert z.issubset(set(domain))
61 |
62 | x = x.to(device)
63 | y_true = y_true.to(device)
64 |
65 | pred = model(x)
66 |
67 | loss = criterion(pred, y_true.unsqueeze(-1).float())
68 | loss.backward()
69 | optimizer.step()
70 | optimizer.zero_grad()
71 | losses.append(loss.item()/batch_size)
72 |
73 | if i % (tot//2) == 0 and i != 0:
74 | losses = np.mean(losses)
75 | acc = get_model_accuracy(model, val_loader, grouper, device=device)
76 |
77 | print("Iter: {} || Loss: {:.4f} || Acc:{:.4f}".format(i, losses, acc))
78 | losses = []
79 |
80 | if acc > acc_best and save:
81 | print("Saving model ...")
82 | save_model(model, model_name+"_exp", 0, test_way=test_way)
83 | acc_best = acc
84 |
85 |
86 |             i += 1
87 |         scheduler.step()
88 |
89 | def train_exp(models_list, domain_specific_indices, device, batch_size=32, lr=1e-3, l2=1e-2,
90 | num_epochs=5, decayRate=1., save=False, test_way='ood', name="Dense121_experts",
91 | root_dir='data'):
92 |
93 | assert len(models_list) == len(domain_specific_indices)
94 | for i in range(len(models_list)):
95 | print(f"Training model {i} for domain", *domain_specific_indices[i])
96 | train_model(models_list[i], name+'_'+str(i), device=device,
97 | domain=domain_specific_indices[i], batch_size=batch_size,
98 | lr=lr, l2=l2, num_epochs=num_epochs,
99 | decayRate=decayRate, save=save, test_way=test_way,
100 | root_dir=root_dir)
101 |
102 | def get_expert_split(num_experts, root_dir='data'):
103 | all_split = split_domains(num_experts=num_experts, root_dir=root_dir)
104 | split_to_cluster = {d:i for i in range(len(all_split)) for d in all_split[i]}
105 | return all_split, split_to_cluster
106 |
--------------------------------------------------------------------------------
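
`get_expert_split` (identical across datasets) returns the per-expert domain lists together with the inverse map from domain id to expert index; `run.py` passes the latter to `train_kd` as `split_to_cluster`. The relationship, with hypothetical domain ids:

```python
all_split = [[3, 7], [1, 5], [2, 9]]  # domains assigned to experts 0, 1, 2
split_to_cluster = {d: i for i in range(len(all_split)) for d in all_split[i]}
print(split_to_cluster)  # {3: 0, 7: 0, 1: 1, 5: 1, 2: 2, 9: 2}
```
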
/src/fmow/fmow_aggregator.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torchvision
7 | from tqdm import tqdm
8 | import os
9 | from src.fmow.fmow_utils import *
10 |
11 | def get_selector_accuracy(selector, models_list, data_loader, grouper, device, progress=True, dataset=None):
12 | selector.eval()
13 | #mean_correct = 0
14 | if progress:
15 | data_loader = tqdm(data_loader)
16 | for x, y_true, metadata in data_loader:
17 | #z = grouper.metadata_to_group(metadata)
18 | #z = set(z.tolist())
19 | #assert z.issubset(set(meta_indices))
20 |
21 | x = x.to(device)
22 | y_true = y_true.to(device)
23 |
24 | with torch.no_grad():
25 | features = torch.stack([model(x).detach() for model in models_list], dim=-1)
26 | features = features.permute((0,2,1))
27 | out = selector(features)
28 |
29 | pred = out.max(1, keepdim=True)[1]
30 | try:
31 | pred_all = torch.cat((pred_all, pred.view_as(y_true)))
32 | y_all = torch.cat((y_all, y_true))
33 | metadata_all = torch.cat((metadata_all, metadata))
34 | except NameError:
35 | pred_all = pred.view_as(y_true)
36 | y_all = y_true
37 | metadata_all = metadata
38 | acc, worst_acc = get_fmow_metrics(pred_all, y_all, metadata_all, dataset)
39 | return acc, worst_acc
40 |
41 | def train_model_selector(selector, model_name, models_list, device, root_dir='data',
42 | batch_size=64, lr=3e-6, l2=0,
43 | num_epochs=20, decayRate=0.96, save=True, test_way='ood'):
44 | for model in models_list:
45 | model.eval()
46 |
47 | train_loader, val_loader, _, grouper, dataset = get_data_loader(root_dir=root_dir,
48 | batch_size=batch_size, domain=None,
49 | test_way=test_way, n_groups_per_batch=0,
50 | return_dataset=True)
51 |
52 | criterion = nn.CrossEntropyLoss()
53 | optimizer = optim.Adam(selector.parameters(), lr=lr, weight_decay=l2)
54 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)
55 |
56 | i = 0
57 |
58 | losses = []
59 | acc_best = 0
60 |
61 | tot = len(train_loader)
62 |
63 | for epoch in range(num_epochs):
64 |
65 | print(f"Epoch:{epoch}|| Total:{tot}")
66 |
67 | for x, y_true, metadata in train_loader:
68 | selector.train()
69 |
70 | z = grouper.metadata_to_group(metadata)
71 | z = set(z.tolist())
72 | #assert z.issubset(set(meta_indices))
73 |
74 | x = x.to(device)
75 | y_true = y_true.to(device)
76 |
77 | with torch.no_grad():
78 | features = torch.stack([model(x).detach() for model in models_list], dim=-1)
79 | features = features.permute((0,2,1))
80 | out = selector(features)
81 |
82 | loss = criterion(out, y_true)
83 | loss.backward()
84 | optimizer.step()
85 | optimizer.zero_grad()
86 | losses.append(loss.item()/batch_size)
87 |
88 | if i % (tot//2) == 0 and i != 0:
89 | losses = np.mean(losses)
90 | acc, wc_acc = get_selector_accuracy(selector, models_list, val_loader,
91 | grouper, device, progress=False, dataset=dataset)
92 |
93 | print("Iter: {}/{} || Loss: {:.4f} || Acc:{:.4f} || WC Acc:{:.4f} ".format(i, tot, losses,
94 | acc, wc_acc))
95 | losses = []
96 |
97 | if wc_acc > acc_best and save:
98 | print("Saving model ...")
99 | save_model(selector, model_name+"_selector", 0, test_way=test_way)
100 | acc_best = wc_acc
101 |
102 |             i += 1
103 |         scheduler.step()
104 |
--------------------------------------------------------------------------------
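
FMoW selects checkpoints on worst-case (worst-region) accuracy via `get_fmow_metrics`, which lives in `fmow_utils.py` and is not shown in this section. Conceptually, worst-group accuracy is the minimum per-group accuracy; a toy illustration, not the actual WILDS implementation:

```python
import torch

def worst_group_accuracy(pred, y, group):
    # Minimum accuracy over groups -- the selection criterion above.
    return min((pred[group == g] == y[group == g]).float().mean()
               for g in torch.unique(group))

pred  = torch.tensor([1, 0, 1, 1])
y     = torch.tensor([1, 0, 0, 1])
group = torch.tensor([0, 0, 1, 1])
print(worst_group_accuracy(pred, y, group))  # tensor(0.5000)
```
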
/src/iwildcam/iwildcam_experts.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torchvision
7 | from sklearn.metrics import f1_score
8 | import os
9 | from src.iwildcam.iwildcam_utils import *
10 |
11 | def get_model_accuracy(model, data_loader, grouper, device, domain=None):
12 | model.eval()
13 | correct = 0
14 | total = 0
15 | for x, y_true, metadata in iter(data_loader):
16 |
17 | z = grouper.metadata_to_group(metadata)
18 | z = set(z.tolist())
19 | if domain is not None:
20 | assert z.issubset(set(domain))
21 |
22 | x = x.to(device)
23 | y_true = y_true.to(device)
24 |
25 | out = model(x)
26 | pred = out.max(1, keepdim=True)[1]
27 | correct += pred.eq(y_true.view_as(pred)).sum().item()
28 | total += x.shape[0]
29 | try:
30 | pred_all = torch.cat((pred_all, pred.view_as(y_true)))
31 | y_all = torch.cat((y_all, y_true))
32 | except NameError:
33 | pred_all = pred.view_as(y_true)
34 | y_all = y_true
35 | y_all = y_all.detach().cpu()
36 | pred_all = pred_all.detach().cpu()
37 | f1 = f1_score(y_all,pred_all,average='macro', labels=torch.unique(y_all))
38 | return correct/total, f1
39 |
40 | def train_model(model, model_name, device, domain=None, batch_size=16, lr=3e-5, l2=0,
41 | num_epochs=12, decayRate=1., save=False, test_way='ood', root_dir='data'):
42 |
43 | train_loader, val_loader, test_loader, grouper = get_data_loader(root_dir=root_dir,
44 | batch_size=batch_size, domain=domain,
45 | test_way=test_way, n_groups_per_batch=0)
46 |
47 | criterion = nn.CrossEntropyLoss()
48 | optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=l2)
49 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)
50 |
51 | i = 0
52 |
53 | losses = []
54 | acc_best = 0
55 |
56 | tot = len(train_loader)
57 |
58 | for epoch in range(num_epochs):
59 |
60 | print(f"Epoch:{epoch} || Total:{tot}")
61 |
62 | for x, y_true, metadata in iter(train_loader):
63 | model.train()
64 |
65 | z = grouper.metadata_to_group(metadata)
66 | z = set(z.tolist())
67 | if domain is not None:
68 | assert z.issubset(set(domain))
69 |
70 | x = x.to(device)
71 | y_true = y_true.to(device)
72 |
73 | pred = model(x)
74 |
75 | loss = criterion(pred, y_true)
76 | loss.backward()
77 | optimizer.step()
78 | optimizer.zero_grad()
79 | losses.append(loss.item()/batch_size)
80 |
81 | if i % (tot//2) == 0 and i != 0:
82 | losses = np.mean(losses)
83 | acc, f1 = get_model_accuracy(model, val_loader, grouper, device=device)
84 |
85 | print("Iter: {} || Loss: {:.4f} || Acc:{:.4f} || F1:{:.4f} ".format(i, losses, acc, f1))
86 | losses = []
87 |
88 | if f1 > acc_best and save:
89 | print("Saving model ...")
90 | save_model(model, model_name+"_exp", 0, test_way=test_way)
91 | acc_best = f1
92 |
93 |
94 | i += 1
95 | scheduler.step()
96 |
97 | def train_exp(models_list, domain_specific_indices, device, batch_size=16, lr=1e-4, l2=0,
98 | num_epochs=30, decayRate=0.96, save=False, test_way='ood', name="Resnet50_experts",
99 | root_dir='data'):
100 |
101 | assert len(models_list) == len(domain_specific_indices)
102 | for i in range(len(models_list)):
103 | print(f"Training model {i} for domain", *domain_specific_indices[i])
104 | train_model(models_list[i], name+'_'+str(i), device=device,
105 | domain=domain_specific_indices[i], batch_size=batch_size,
106 | lr=lr, l2=l2, num_epochs=num_epochs,
107 | decayRate=decayRate, save=save, test_way=test_way,
108 | root_dir=root_dir)
109 |
110 | def get_expert_split(num_experts, root_dir='data'):
111 | all_split = split_domains(num_experts=num_experts, root_dir=root_dir)
112 | split_to_cluster = {d:i for i in range(len(all_split)) for d in all_split[i]}
113 | return all_split, split_to_cluster
114 |
--------------------------------------------------------------------------------
/src/fmow/fmow_experts.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torchvision
7 | import os
8 | from src.fmow.fmow_utils import *
9 |
10 | def get_model_accuracy(model, data_loader, grouper, device, domain=None, dataset=None):
11 | model.eval()
12 |
13 | for x, y_true, metadata in iter(data_loader):
14 |
15 | z = grouper.metadata_to_group(metadata)
16 | z = set(z.tolist())
17 | if domain is not None:
18 | assert z.issubset(set(domain))
19 |
20 | x = x.to(device)
21 | y_true = y_true.to(device)
22 |
23 | out = model(x)
24 | pred = out.max(1, keepdim=True)[1]
25 | try:
26 | pred_all = torch.cat((pred_all, pred.view_as(y_true)))
27 | y_all = torch.cat((y_all, y_true))
28 | metadata_all = torch.cat((metadata_all, metadata))
29 | except NameError:
30 | pred_all = pred.view_as(y_true)
31 | y_all = y_true
32 | metadata_all = metadata
33 | acc, worst_acc = get_fmow_metrics(pred_all, y_all, metadata_all, dataset)
34 | return acc, worst_acc
35 |
36 | def train_model(model, model_name, device, domain=None, batch_size=64, lr=1e-4, l2=0,
37 | num_epochs=50, decayRate=0.96, save=False, test_way='ood', root_dir='data'):
38 |
39 | train_loader, val_loader, test_loader, grouper, dataset = get_data_loader(root_dir=root_dir,
40 | batch_size=batch_size, domain=domain,
41 | test_way=test_way, n_groups_per_batch=0,
42 | return_dataset=True)
43 |
44 | criterion = nn.CrossEntropyLoss()
45 | optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=l2)
46 | scheduler = optim.lr_scheduler.ExponentialLR(optimizer=optimizer, gamma=decayRate)
47 |
48 | i = 0
49 |
50 | losses = []
51 | acc_best = 0
52 |
53 | tot = len(train_loader)
54 |
55 | for epoch in range(num_epochs):
56 |
57 | print(f"Epoch:{epoch} || Total:{tot}")
58 |
59 | for x, y_true, metadata in iter(train_loader):
60 | model.train()
61 |
62 | z = grouper.metadata_to_group(metadata)
63 | z = set(z.tolist())
64 | if domain is not None:
65 | assert z.issubset(set(domain))
66 |
67 | x = x.to(device)
68 | y_true = y_true.to(device)
69 |
70 | pred = model(x)
71 |
72 | loss = criterion(pred, y_true)
73 | loss.backward()
74 | optimizer.step()
75 | optimizer.zero_grad()
76 | losses.append(loss.item()/batch_size)
77 |
78 | if i % (tot//2) == 0 and i != 0:
79 | losses = np.mean(losses)
80 | acc, worst_acc = get_model_accuracy(model, val_loader, grouper, device=device, dataset=dataset)
81 |
82 | print("Iter: {} || Loss: {:.4f} || Acc:{:.4f} ||Worst Acc:{:.4f} ".format(i,
83 | losses, acc, worst_acc))
84 | losses = []
85 |
86 | if worst_acc > acc_best and save:
87 | print("Saving model ...")
88 | save_model(model, model_name+"_exp", 0, test_way=test_way)
89 | acc_best = worst_acc
90 |
91 |             i += 1
92 |         scheduler.step()
93 |
94 | def train_exp(models_list, domain_specific_indices, device, batch_size=64, lr=1e-4, l2=0,
95 | num_epochs=50, decayRate=0.96, save=False, test_way='ood', name="Resnet50_experts",
96 | root_dir='data'):
97 |
98 | assert len(models_list) == len(domain_specific_indices)
99 | for i in range(len(models_list)):
100 | print(f"Training model {i} for domain", *domain_specific_indices[i])
101 | train_model(models_list[i], name+'_'+str(i), device=device,
102 | domain=domain_specific_indices[i], batch_size=batch_size,
103 | lr=lr, l2=l2, num_epochs=num_epochs,
104 | decayRate=decayRate, save=save, test_way=test_way,
105 | root_dir=root_dir)
106 |
107 | def get_expert_split(num_experts, root_dir='data'):
108 | all_split = split_domains(num_experts=num_experts, root_dir=root_dir)
109 | split_to_cluster = {d:i for i in range(len(all_split)) for d in all_split[i]}
110 | return all_split, split_to_cluster
111 |
--------------------------------------------------------------------------------
/src/iwildcam/iwildcam_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torchvision
7 | import torchvision.transforms as transforms
8 | import math
9 | import random
10 | import pickle5 as pickle
11 | from src.transformer import Transformer
12 | import copy
13 | import time
14 | import os
15 | import wilds
16 | from wilds import get_dataset
17 | from wilds.common.data_loaders import get_train_loader, get_eval_loader
18 | from wilds.common.grouper import CombinatorialGrouper
19 |
20 | # Utils
21 |
22 | def split_domains(num_experts, root_dir='data'):
23 | dataset = get_dataset(dataset='iwildcam', download=True, root_dir=root_dir)
24 | train_data = dataset.get_subset('train')
25 | locations = list(set(train_data.metadata_array[:,0].detach().numpy().tolist()))
26 | random.shuffle(locations)
27 | num_domains_per_super = len(locations) / float(num_experts)
28 | all_split = [[] for _ in range(num_experts)]
29 | for i in range(len(locations)):
30 | all_split[int(i//num_domains_per_super)].append(locations[i])
31 | return all_split
32 |
33 | def get_subset_with_domain(dataset, split=None, domain=None, transform=None):
34 | if type(dataset) == wilds.datasets.wilds_dataset.WILDSSubset:
35 | subset = copy.deepcopy(dataset)
36 | else:
37 | subset = dataset.get_subset(split, transform=transform)
38 | if domain is not None:
39 | idx = np.argwhere(np.isin(subset.dataset.metadata_array[:,0][subset.indices], domain)).ravel()
40 | subset.indices = subset.indices[idx]
41 | return subset
42 |
43 | def initialize_image_base_transform(dataset):
44 | transform_steps = []
45 | if dataset.original_resolution is not None and min(dataset.original_resolution)!=max(dataset.original_resolution):
46 | crop_size = min(dataset.original_resolution)
47 | transform_steps.append(transforms.CenterCrop(crop_size))
48 | transform_steps.append(transforms.Resize((448,448)))
49 | transform_steps += [
50 | transforms.ToTensor(),
51 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
52 | ]
53 | transform = transforms.Compose(transform_steps)
54 | return transform
55 |
56 | def get_data_loader(batch_size=16, domain=None, test_way='id',
57 | n_groups_per_batch=0, uniform_over_groups=False, root_dir='data'):
58 | dataset = get_dataset(dataset='iwildcam', download=True, root_dir=root_dir)
59 | grouper = CombinatorialGrouper(dataset, ['location'])
60 |
61 | transform = initialize_image_base_transform(dataset)
62 |
63 | train_data = get_subset_with_domain(dataset, 'train', domain=domain, transform=transform)
64 | if test_way == 'ood':
65 | val_data = dataset.get_subset('val', transform=transform)
66 | test_data = dataset.get_subset('test', transform=transform)
67 | else:
68 | val_data = get_subset_with_domain(dataset, 'id_val', domain=domain, transform=transform)
69 | test_data = get_subset_with_domain(dataset, 'id_test', domain=domain, transform=transform)
70 |
71 |
72 | if n_groups_per_batch == 0:
73 |         # n_groups_per_batch == 0 selects the standard (non-grouped) loaders
74 | train_loader = get_train_loader('standard', train_data, batch_size=batch_size)
75 | val_loader = get_train_loader('standard', val_data, batch_size=batch_size)
76 | test_loader = get_train_loader('standard', test_data, batch_size=batch_size)
77 |
78 | else:
79 | # All use get_train_loader to enable grouper
80 | train_loader = get_train_loader('group', train_data, grouper=grouper,
81 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size)
82 | val_loader = get_train_loader('group', val_data, grouper=grouper,
83 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size,
84 | uniform_over_groups=uniform_over_groups)
85 | test_loader = get_train_loader('group', test_data, grouper=grouper,
86 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size,
87 | uniform_over_groups=uniform_over_groups)
88 |
89 | return train_loader, val_loader, test_loader, grouper
90 |
91 | def get_test_loader(batch_size=16, split='val', root_dir='data'):
92 | dataset = get_dataset(dataset='iwildcam', download=True, root_dir=root_dir)
93 | grouper = CombinatorialGrouper(dataset, ['location'])
94 |
95 | transform = initialize_image_base_transform(dataset)
96 |
97 | eval_data = get_subset_with_domain(dataset, split, domain=None, transform=transform)
98 | all_domains = list(set(grouper.metadata_to_group(eval_data.dataset.metadata_array[eval_data.indices]).tolist()))
99 | test_loader = []
100 |
101 | for domain in all_domains:
102 | domain_data = get_subset_with_domain(eval_data, split, domain=domain, transform=transform)
103 | domain_loader = get_eval_loader('standard', domain_data, batch_size=batch_size)
104 | test_loader.append((domain, domain_loader))
105 |
106 | return test_loader, grouper
107 |
108 | def get_mask_grouper(root_dir='data'):
109 | dataset = get_dataset(dataset='iwildcam', download=True, root_dir=root_dir)
110 | grouper = CombinatorialGrouper(dataset, ['location'])
111 | return grouper
112 |
113 | def save_model(model, name, epoch, test_way='ood'):
114 | if not os.path.exists("model/iwildcam"):
115 | os.makedirs("model/iwildcam")
116 | path = "model/iwildcam/{0}_best.pth".format(name)
117 | torch.save(model.state_dict(), path)
118 |
119 | def mask_feat(feat, mask_index, num_experts, exclude=True):
120 | assert feat.shape[0] == mask_index.shape[0]
121 | if exclude:
122 | new_idx = [list(range(0, int(m.item()))) + list(range(int(m.item())+1, num_experts)) for m in mask_index]
123 | return feat[torch.arange(feat.shape[0]).unsqueeze(-1), :, new_idx]
124 | else:
125 | feat[list(range(feat.shape[0])), :, mask_index] = torch.zeros_like(feat[list(range(feat.shape[0])), :, mask_index])
126 | feat = feat.permute((0,2,1))
127 | return feat
128 |
129 | def l2_loss(input, target):
130 | loss = torch.square(target - input)
131 | loss = torch.mean(loss)
132 | return loss
133 |
134 | # Models
135 | class ResNetFeature(nn.Module):
136 | def __init__(self, original_model, layer=-1):
137 | super(ResNetFeature, self).__init__()
138 | self.features = nn.Sequential(*list(original_model.children())[:layer])
139 |
140 | def forward(self, x):
141 | x = self.features(x)
142 | x = x.view(x.shape[0], -1)
143 | return x
144 |
145 | class fa_selector(nn.Module):
146 | def __init__(self, dim, depth, heads, mlp_dim, dropout = 0., out_dim=182, pool='mean'):
147 | super(fa_selector, self).__init__()
148 | self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout=dropout)
149 | self.pool = pool
150 | self.mlp = nn.Sequential(
151 | nn.LayerNorm(dim),
152 | nn.Linear(dim, out_dim)
153 | )
154 |
155 | def forward(self, x):
156 | x = self.transformer(x)
157 | if self.pool == 'mean':
158 | x = x.mean(dim=1)
159 | elif self.pool == 'max':
160 |             x = x.max(dim=1)[0]  # .max(dim) returns (values, indices)
161 | else:
162 | raise NotImplementedError
163 | x = self.mlp(x)
164 | return x
165 |
166 | def get_feat(self, x):
167 | x = self.transformer(x)
168 | if self.pool == 'mean':
169 | x = x.mean(dim=1)
170 | elif self.pool == 'max':
171 |             x = x.max(dim=1)[0]  # .max(dim) returns (values, indices)
172 | else:
173 | raise NotImplementedError
174 | return x
175 |
176 | class DivideModel(nn.Module):
177 | def __init__(self, original_model, layer=-1):
178 | super(DivideModel, self).__init__()
179 | self.num_ftrs = original_model.fc.in_features
180 | self.num_class = original_model.fc.out_features
181 | self.features = nn.Sequential(*list(original_model.children())[:layer])
182 | self.classifier = nn.Sequential(*list(original_model.children())[layer:])
183 |
184 | def forward(self, x):
185 | x = self.features(x)
186 | x = x.view(-1, self.num_ftrs)
187 | x = self.classifier(x)
188 | x = x.view(-1, self.num_class)
189 | return x
190 |
191 | def StudentModel(device, num_classes=182, load_path=None):
192 | model = torchvision.models.resnet50(pretrained=True)
193 | num_ftrs = model.fc.in_features
194 | model.fc = nn.Linear(num_ftrs, num_classes)
195 | if load_path:
196 | model.load_state_dict(torch.load(load_path))
197 | model = DivideModel(model)
198 | model = model.to(device)
199 | return model
200 |
201 | def get_feature_list(models_list, device):
202 | feature_list = []
203 | for model in models_list:
204 | feature_list.append(ResNetFeature(model).to(device))
205 | return feature_list
206 |
207 | def get_models_list(device, num_domains=9, num_classes=182, pretrained=False, bb='res50'):
208 | models_list = []
209 | for _ in range(num_domains+1):
210 | model = torchvision.models.resnet50(pretrained=True)
211 | num_ftrs = model.fc.in_features
212 | model.fc = nn.Linear(num_ftrs, num_classes)
213 | model = model.to(device)
214 | models_list.append(model)
215 | return models_list
216 |
--------------------------------------------------------------------------------
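
Putting the pieces together: `fa_selector` consumes the stacked `(batch, num_experts, feat_dim)` expert tokens, pools them, and emits class logits, while `get_feat` returns the pooled embedding without the classification head. A minimal sketch with hypothetical sizes:

```python
import torch
from src.iwildcam.iwildcam_utils import fa_selector

selector = fa_selector(dim=2048, depth=1, heads=16, mlp_dim=4096,
                       dropout=0.1, out_dim=182)
feats = torch.randn(8, 10, 2048)  # (batch, num_experts, feat_dim)
logits = selector(feats)          # (8, 182) class logits
emb = selector.get_feat(feats)    # (8, 2048) pooled embedding, no classifier head
```
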
/run.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import numpy as np
4 | import random
5 | import argparse
6 | import pickle
7 | from src.configs import default_param
8 |
9 | OUT_DIM = {'iwildcam':182,
10 | 'camelyon':1,
11 | 'rxrx1':1139,
12 | 'fmow':62,
13 | 'poverty':1}
14 |
15 | FEAT_DIM = {'iwildcam':2048,
16 | 'camelyon':1024,
17 | 'rxrx1':2048,
18 | 'fmow':1024,
19 | 'poverty':512}
20 |
21 | def get_parser():
22 | parser = argparse.ArgumentParser()
23 |
24 | parser.add_argument('--gpu', type=str, default='0') # We only support single gpu training for now
25 | parser.add_argument('--threads', type=int, default=12)
26 |
27 | parser.add_argument('--dataset', type=str, default='iwildcam',
28 | choices=['iwildcam', 'fmow', 'camelyon', 'rxrx1', 'poverty'])
29 | parser.add_argument('--data_dir', type=str, default='data')
30 |
31 | parser.add_argument('--load_trained_experts', action='store_true')
32 | parser.add_argument('--load_pretrained_aggregator', action='store_true')
33 | parser.add_argument('--load_pretrained_student', action='store_true')
34 |
35 | parser.add_argument('--test', action='store_true')
36 |
37 | args = parser.parse_args()
38 | args_dict = args.__dict__
39 | args_dict.update(default_param[args.dataset])
40 | args = argparse.Namespace(**args_dict)
41 | return args
42 |
43 | def set_seed(seed):
44 | if torch.cuda.is_available():
45 | torch.cuda.manual_seed(seed)
46 | torch.manual_seed(seed)
47 | np.random.seed(seed)
48 | random.seed(seed)
49 | torch.backends.cudnn.deterministic = True
50 |
51 | def train(args):
52 |
53 | if args.dataset == 'iwildcam':
54 | from src.iwildcam.iwildcam_utils import fa_selector, StudentModel, get_models_list, get_feature_list
55 | from src.iwildcam.iwildcam_experts import train_exp, train_model, get_expert_split
56 | from src.iwildcam.iwildcam_aggregator import train_model_selector
57 | from src.iwildcam.iwildcam_train import train_kd, eval
58 | elif args.dataset == 'camelyon':
59 | from src.camelyon.camelyon_utils import fa_selector, StudentModel, get_models_list, get_feature_list
60 | from src.camelyon.camelyon_experts import train_exp, train_model, get_expert_split
61 | from src.camelyon.camelyon_aggregator import train_model_selector
62 | from src.camelyon.camelyon_train import train_kd, eval
63 | elif args.dataset == 'rxrx1':
64 | from src.rxrx1.rxrx1_utils import fa_selector, StudentModel, get_models_list, get_feature_list
65 | from src.rxrx1.rxrx1_experts import train_exp, train_model, get_expert_split
66 | from src.rxrx1.rxrx1_aggregator import train_model_selector
67 | from src.rxrx1.rxrx1_train import train_kd, eval
68 | elif args.dataset == 'fmow':
69 | from src.fmow.fmow_utils import fa_selector, StudentModel, get_models_list, get_feature_list
70 | from src.fmow.fmow_experts import train_exp, train_model, get_expert_split
71 | from src.fmow.fmow_aggregator import train_model_selector
72 | from src.fmow.fmow_train import train_kd, eval
73 | else:
74 | raise NotImplementedError
75 |
76 | name = f"{args.dataset}_{str(args.num_experts)}experts_seed{str(args.seed)}"
77 |
78 | models_list = get_models_list(device=device, num_domains=args.num_experts-1)
79 |
80 | try:
81 | with open(f"model/{args.dataset}/domain_split.pkl", "rb") as f:
82 | all_split, split_to_cluster = pickle.load(f)
83 | except FileNotFoundError:
84 | all_split, split_to_cluster = get_expert_split(args.num_experts, root_dir=args.data_dir)
85 | with open(f"model/{args.dataset}/domain_split.pkl", "wb") as f:
86 | pickle.dump((all_split, split_to_cluster), f)
87 |
88 | if args.load_trained_experts:
89 | print("Skip training domain specific experts...")
90 | else:
91 | print("Training domain specific experts...")
92 | train_exp(models_list, all_split, device, batch_size=args.expert_batch_size,
93 | lr=args.expert_lr, l2=args.expert_l2, num_epochs=args.expert_epoch,
94 | save=True, name=name, root_dir=args.data_dir)
95 |
96 | for i,model in enumerate(models_list):
97 | model.load_state_dict(torch.load(f"model/{args.dataset}/{name}_{str(i)}_exp_best.pth"))
98 | models_list = get_feature_list(models_list, device=device)
99 |
100 | selector = fa_selector(dim=FEAT_DIM[args.dataset], depth=args.aggregator_depth, heads=args.aggregator_heads,
101 | mlp_dim=FEAT_DIM[args.dataset]*2, dropout=args.aggregator_dropout,
102 | out_dim=OUT_DIM[args.dataset]).to(device)
103 | if args.load_pretrained_aggregator:
104 | print("Skip pretraining knowledge aggregator...")
105 | else:
106 | print("Pretraining knowledge aggregator...")
107 | train_model_selector(selector, name+'_pretrained', models_list, device, root_dir=args.data_dir,
108 | num_epochs=args.aggregator_pretrain_epoch, save=True)
109 |
110 | selector.load_state_dict(torch.load(f"model/{args.dataset}/{name}_pretrained_selector_best.pth"))
111 |
112 | student = StudentModel(device=device, num_classes=OUT_DIM[args.dataset])
113 |
114 | if args.load_pretrained_student:
115 | print("Skip pretraining student...")
116 | else:
117 | print("Pretraining student...")
118 | train_model(student, name+"_pretrained", device=device,
119 | num_epochs=args.student_pretrain_epoch, save=True,
120 | root_dir=args.data_dir)
121 |
122 | student.load_state_dict(torch.load(f"model/{args.dataset}/{name}_pretrained_exp_best.pth"))
123 |
124 | print("Start meta-training...")
125 | train_kd(selector, name+"_meta", models_list, student, name+"_meta", split_to_cluster,
126 | device=device, batch_size=args.batch_size, sup_size=args.sup_size,
127 | tlr=args.tlr, slr=args.slr, ilr=args.ilr, num_epochs=args.epoch, save=True, test_way='ood',
128 | root_dir=args.data_dir)
129 |
130 | def test(args):
131 |
132 | if args.dataset == 'iwildcam':
133 | from src.iwildcam.iwildcam_utils import fa_selector, StudentModel, get_models_list, get_feature_list
134 | from src.iwildcam.iwildcam_train import eval
135 | elif args.dataset == 'camelyon':
136 | from src.camelyon.camelyon_utils import fa_selector, StudentModel, get_models_list, get_feature_list
137 | from src.camelyon.camelyon_train import eval
138 | elif args.dataset == 'rxrx1':
139 | from src.rxrx1.rxrx1_utils import fa_selector, StudentModel, get_models_list, get_feature_list
140 | from src.rxrx1.rxrx1_train import eval
141 | elif args.dataset == 'fmow':
142 | from src.fmow.fmow_utils import fa_selector, StudentModel, get_models_list, get_feature_list
143 | from src.fmow.fmow_train import eval
144 |     else:
145 |         raise NotImplementedError
146 | 
147 |     name = f"{args.dataset}_{args.num_experts}experts_seed{args.seed}"
148 |     models_list = get_models_list(device=device, num_domains=args.num_experts-1)
149 |     for i, model in enumerate(models_list):
150 |         model.load_state_dict(torch.load(f"model/{args.dataset}/{name}_{i}_exp_best.pth", map_location=device))
149 | models_list = get_feature_list(models_list, device=device)
150 | selector = fa_selector(dim=FEAT_DIM[args.dataset], depth=args.aggregator_depth, heads=args.aggregator_heads,
151 | mlp_dim=FEAT_DIM[args.dataset]*2, dropout=args.aggregator_dropout,
152 | out_dim=OUT_DIM[args.dataset]).to(device)
153 |     selector.load_state_dict(torch.load(f"model/{args.dataset}/{name}_meta_selector_best.pth", map_location=device))
154 | 
155 |     student = StudentModel(device=device, num_classes=OUT_DIM[args.dataset]).to(device)
156 |     student.load_state_dict(torch.load(f"model/{args.dataset}/{name}_meta_student_best.pth", map_location=device))
157 | metrics = eval(selector, models_list, student, batch_size=args.sup_size,
158 | device=device, ilr=args.ilr, test=True, root_dir=args.data_dir)
159 |     if args.dataset == 'iwildcam':
160 |         print(f"Test Accuracy:{metrics[1]:.4f} Test Macro-F1:{metrics[2]:.4f}")
161 |         with open(f'result/{args.dataset}/result.txt', 'a+') as f:
162 |             f.write(f"Seed: {args.seed} || Test Accuracy:{metrics[1]:.4f} || Test Macro-F1:{metrics[2]:.4f}\n")
163 |     elif args.dataset in ['camelyon', 'rxrx1']:
164 |         print(f"Test Accuracy:{metrics:.4f}")
165 |         with open(f'result/{args.dataset}/result.txt', 'a+') as f:
166 |             f.write(f"Seed: {args.seed} || Test Accuracy:{metrics:.4f}\n")
167 |     elif args.dataset == 'fmow':
168 |         print(f"WC Accuracy:{metrics[0]:.4f} Acc:{metrics[1]:.4f}")
169 |         with open(f'result/{args.dataset}/result.txt', 'a+') as f:
170 |             f.write(f"Seed: {args.seed} || WC Accuracy:{metrics[0]:.4f} || Acc:{metrics[1]:.4f}\n")
171 |
172 | if __name__ == "__main__":
173 | args = get_parser()
174 | os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
175 | torch.set_num_threads(args.threads)
176 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
177 | set_seed(args.seed)
178 | if not os.path.exists(f"model/{args.dataset}"):
179 | os.makedirs(f"model/{args.dataset}")
180 | if not os.path.exists(f"log/{args.dataset}"):
181 | os.makedirs(f"log/{args.dataset}")
182 | if args.test:
183 | if not os.path.exists(f"result/{args.dataset}"):
184 | os.makedirs(f"result/{args.dataset}")
185 | test(args)
186 | else:
187 | train(args)
188 |
--------------------------------------------------------------------------------
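
run.py above chains four checkpointed stages (domain-specific experts, aggregator pretraining, student pretraining, meta-training) and can resume any finished stage from disk via the load_* flags. A minimal sketch of the checkpoint naming convention it reads and writes under model/<dataset>/ (this helper is hypothetical, not part of the repo; the path patterns mirror the f-strings in train()/test()):

    def checkpoint_paths(dataset, num_experts, seed):
        # Mirrors the f-string patterns used by train()/test() in run.py above.
        name = f"{dataset}_{num_experts}experts_seed{seed}"
        base = f"model/{dataset}"
        return {
            "experts": [f"{base}/{name}_{i}_exp_best.pth" for i in range(num_experts)],
            "aggregator_pretrained": f"{base}/{name}_pretrained_selector_best.pth",
            "student_pretrained": f"{base}/{name}_pretrained_exp_best.pth",
            "meta_selector": f"{base}/{name}_meta_selector_best.pth",
            "meta_student": f"{base}/{name}_meta_student_best.pth",
        }

    # checkpoint_paths("camelyon", 10, 0)["meta_student"]
    # -> "model/camelyon/camelyon_10experts_seed0_meta_student_best.pth"
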
/src/rxrx1/rxrx1_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torchvision.transforms.functional as TF
6 | import torch.optim as optim
7 | import torchvision
8 | import torchvision.transforms as transforms
9 | import math
10 | import random
11 | import pickle5 as pickle
12 | from src.transformer import Transformer
13 | import copy
14 | import time
15 | import os
16 | import wilds
17 | from wilds import get_dataset
18 | from wilds.common.data_loaders import get_train_loader, get_eval_loader
19 | from wilds.common.grouper import CombinatorialGrouper
20 |
21 | # Utils
22 |
23 | def split_domains(num_experts, root_dir='data'):
24 | dataset = get_dataset(dataset='rxrx1', download=True, root_dir=root_dir)
25 | train_data = dataset.get_subset('train')
26 | experiment = list(set(train_data.metadata_array[:,1].detach().numpy().tolist()))
27 | random.shuffle(experiment)
28 | num_domains_per_super = len(experiment) / float(num_experts)
29 | all_split = [[] for _ in range(num_experts)]
30 | for i in range(len(experiment)):
31 | all_split[int(i//num_domains_per_super)].append(experiment[i])
32 | return all_split
33 |
34 | def get_subset_with_domain(dataset, split, domain=None, transform=None):
35 | if type(dataset) == wilds.datasets.wilds_dataset.WILDSSubset:
36 | subset = copy.deepcopy(dataset)
37 | else:
38 | subset = dataset.get_subset(split, transform=transform)
39 | if domain is not None:
40 | idx = np.argwhere(np.isin(subset.dataset.metadata_array[:,1][subset.indices], domain)).ravel()
41 | subset.indices = subset.indices[idx]
42 | return subset
43 |
44 | def initialize_image_base_transform(is_training):
45 | def standardize(x: torch.Tensor) -> torch.Tensor:
46 | mean = x.mean(dim=(1, 2))
47 | std = x.std(dim=(1, 2))
48 | std[std == 0.] = 1.
49 | return TF.normalize(x, mean, std)
50 | t_standardize = transforms.Lambda(lambda x: standardize(x))
51 |
52 | angles = [0, 90, 180, 270]
53 | def random_rotation(x: torch.Tensor) -> torch.Tensor:
54 | angle = angles[torch.randint(low=0, high=len(angles), size=(1,))]
55 | if angle > 0:
56 | x = TF.rotate(x, angle)
57 | return x
58 | t_random_rotation = transforms.Lambda(lambda x: random_rotation(x))
59 |
60 | if is_training:
61 | transforms_ls = [
62 | t_random_rotation,
63 | transforms.RandomHorizontalFlip(),
64 | transforms.ToTensor(),
65 | t_standardize,
66 | ]
67 | else:
68 | transforms_ls = [
69 | transforms.ToTensor(),
70 | t_standardize,
71 | ]
72 | transform = transforms.Compose(transforms_ls)
73 | return transform
74 |
75 | def get_data_loader(batch_size=16, domain=None, test_way='id',
76 | n_groups_per_batch=0, uniform_over_groups=False, root_dir='data'):
77 | dataset = get_dataset(dataset='rxrx1', download=True, root_dir=root_dir)
78 | grouper = CombinatorialGrouper(dataset, ['experiment'])
79 |
80 | transform_train = initialize_image_base_transform(True)
81 | transform_test = initialize_image_base_transform(False)
82 |
83 | train_data = get_subset_with_domain(dataset, 'train', domain=domain, transform=transform_train)
84 | if test_way == 'ood':
85 | val_data = dataset.get_subset('val', transform=transform_test)
86 | test_data = dataset.get_subset('test', transform=transform_test)
87 | else:
88 | val_data = get_subset_with_domain(dataset, 'id_val', domain=domain, transform=transform_test)
89 | test_data = get_subset_with_domain(dataset, 'id_test', domain=domain, transform=transform_test)
90 |
91 |
92 | if n_groups_per_batch == 0:
93 |         # n_groups_per_batch == 0 selects the plain (non-grouped) loader
94 | train_loader = get_train_loader('standard', train_data, batch_size=batch_size)
95 | val_loader = get_train_loader('standard', val_data, batch_size=batch_size)
96 | test_loader = get_train_loader('standard', test_data, batch_size=batch_size)
97 |
98 | else:
99 | # All use get_train_loader to enable grouper
100 | train_loader = get_train_loader('group', train_data, grouper=grouper,
101 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size)
102 | val_loader = get_train_loader('group', val_data, grouper=grouper,
103 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size,
104 | uniform_over_groups=uniform_over_groups)
105 | test_loader = get_train_loader('group', test_data, grouper=grouper,
106 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size,
107 | uniform_over_groups=uniform_over_groups)
108 |
109 | return train_loader, val_loader, test_loader, grouper
110 |
111 | def get_mask_grouper(root_dir='data'):
112 | dataset = get_dataset(dataset='rxrx1', download=True, root_dir=root_dir)
113 | grouper = CombinatorialGrouper(dataset, ['experiment'])
114 | return grouper
115 |
116 | def get_test_loader(batch_size=16, split='val', root_dir='data'):
117 | dataset = get_dataset(dataset='rxrx1', download=True, root_dir=root_dir)
118 | grouper = CombinatorialGrouper(dataset, ['experiment'])
119 |
120 |     transform = initialize_image_base_transform(False)  # eval-time transform; this helper takes is_training, not a dataset
121 |
122 | eval_data = get_subset_with_domain(dataset, split, domain=None, transform=transform)
123 | all_domains = list(set(grouper.metadata_to_group(eval_data.dataset.metadata_array[eval_data.indices]).tolist()))
124 | test_loader = []
125 |
126 | for domain in all_domains:
127 | domain_data = get_subset_with_domain(eval_data, split, domain=domain, transform=transform)
128 | domain_loader = get_eval_loader('standard', domain_data, batch_size=batch_size)
129 | test_loader.append((domain, domain_loader))
130 |
131 | return test_loader, grouper
132 |
133 | def save_model(model, name, epoch, test_way='ood'):
134 | if not os.path.exists("model/rxrx1"):
135 | os.makedirs("model/rxrx1")
136 | path = "model/rxrx1/{0}_best.pth".format(name)
137 | torch.save(model.state_dict(), path)
138 |
139 | def mask_feat(feat, mask_index, num_experts, exclude=True):
140 | assert feat.shape[0] == mask_index.shape[0]
141 | if exclude:
142 | new_idx = [list(range(0, int(m.item()))) + list(range(int(m.item())+1, num_experts)) for m in mask_index]
143 | return feat[torch.arange(feat.shape[0]).unsqueeze(-1), :, new_idx]
144 | else:
145 | feat[list(range(feat.shape[0])), :, mask_index] = torch.zeros_like(feat[list(range(feat.shape[0])), :, mask_index])
146 | feat = feat.permute((0,2,1))
147 | return feat
148 |
149 | def l2_loss(input, target):
150 | loss = torch.square(target - input)
151 | loss = torch.mean(loss)
152 | return loss
153 |
154 | # Models
155 | class ResNetFeature(nn.Module):
156 | def __init__(self, original_model, layer=-1):
157 | super(ResNetFeature, self).__init__()
158 | self.features = nn.Sequential(*list(original_model.children())[:layer])
159 |
160 | def forward(self, x):
161 | x = self.features(x)
162 | x = x.view(x.shape[0], -1)
163 | return x
164 |
165 | class fa_selector(nn.Module):
166 | def __init__(self, dim, depth, heads, mlp_dim, dropout = 0., out_dim=1139, pool='mean'):
167 | super(fa_selector, self).__init__()
168 | self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout=dropout)
169 | self.pool = pool
170 | self.mlp = nn.Sequential(
171 | nn.LayerNorm(dim),
172 | nn.Linear(dim, out_dim)
173 | )
174 |
175 | def forward(self, x):
176 | x = self.transformer(x)
177 | if self.pool == 'mean':
178 | x = x.mean(dim=1)
179 | elif self.pool == 'max':
180 |             x = x.max(dim=1).values  # torch.max(dim=...) returns (values, indices)
181 | else:
182 | raise NotImplementedError
183 | x = self.mlp(x)
184 | return x
185 |
186 | def get_feat(self, x):
187 | x = self.transformer(x)
188 | if self.pool == 'mean':
189 | x = x.mean(dim=1)
190 | elif self.pool == 'max':
191 |             x = x.max(dim=1).values  # keep the values, not the (values, indices) tuple
192 | else:
193 | raise NotImplementedError
194 | return x
195 |
196 | class DivideModel(nn.Module):
197 | def __init__(self, original_model, layer=-1):
198 | super(DivideModel, self).__init__()
199 | self.num_ftrs = original_model.fc.in_features
200 | self.num_class = original_model.fc.out_features
201 | self.features = nn.Sequential(*list(original_model.children())[:layer])
202 | self.classifier = nn.Sequential(*list(original_model.children())[layer:])
203 |
204 | def forward(self, x):
205 | x = self.features(x)
206 | x = x.view(-1, self.num_ftrs)
207 | x = self.classifier(x)
208 | x = x.view(-1, self.num_class)
209 | return x
210 |
211 | def StudentModel(device, num_classes=1139, load_path=None):  # factory function, not a class: returns a DivideModel-wrapped ResNet-50
212 | model = torchvision.models.resnet50(pretrained=True)
213 | num_ftrs = model.fc.in_features
214 | model.fc = nn.Linear(num_ftrs, num_classes)
215 | if load_path:
216 | model.load_state_dict(torch.load(load_path))
217 | model = DivideModel(model)
218 | model = model.to(device)
219 | return model
220 |
221 | def get_feature_list(models_list, device):
222 | feature_list = []
223 | for model in models_list:
224 | feature_list.append(ResNetFeature(model).to(device))
225 | return feature_list
226 |
227 | def get_models_list(device, num_domains=2, num_classes=1139, pretrained=False, bb='res50'):
228 | models_list = []
229 | for _ in range(num_domains+1):
230 | model = torchvision.models.resnet50(pretrained=True)
231 | num_ftrs = model.fc.in_features
232 | model.fc = nn.Linear(num_ftrs, num_classes)
233 | model = model.to(device)
234 | models_list.append(model)
235 | return models_list
236 |
--------------------------------------------------------------------------------
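
mask_feat above implements the expert-exclusion trick used during meta-training: for each support sample, the expert that was trained on that sample's own domain is hidden from the knowledge aggregator, forcing it to generalize from out-of-domain experts. A self-contained toy run (the exclude branch is copied from rxrx1_utils above; shapes are illustrative):

    import torch

    def mask_feat(feat, mask_index, num_experts):
        # feat: (batch, feat_dim, num_experts); mask_index[i] is the expert to drop for sample i.
        assert feat.shape[0] == mask_index.shape[0]
        new_idx = [list(range(0, int(m.item()))) + list(range(int(m.item())+1, num_experts))
                   for m in mask_index]
        # Advanced indexing moves the kept-expert axis forward: (batch, num_experts-1, feat_dim).
        return feat[torch.arange(feat.shape[0]).unsqueeze(-1), :, new_idx]

    B, D, E = 2, 5, 4                      # batch, feature dim, number of experts
    feat = torch.randn(B, D, E)            # stacked expert outputs, as in train_epoch
    mask = torch.tensor([0, 2])            # in-domain expert id for each sample
    print(mask_feat(feat, mask, E).shape)  # torch.Size([2, 3, 5]) == (B, E-1, D)

The resulting (batch, experts-1, dim) layout is exactly what the fa_selector transformer expects as its token sequence.
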
/src/camelyon/camelyon_train.py:
--------------------------------------------------------------------------------
1 | from scipy.stats import truncnorm
2 | import numpy as np
3 | import time
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torch.optim as optim
8 | import torchvision
9 | from tqdm import tqdm
10 | import copy
11 | import learn2learn as l2l
12 | import time
13 | from datetime import datetime
14 | import os
15 | import wilds
16 | from src.camelyon.camelyon_utils import *
17 |
18 | def train_epoch(selector, selector_name, models_list, student, student_name,
19 | train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster,
20 | device, acc_best=0, tlr=1e-4, slr=1e-4, ilr=1e-3,
21 | batch_size=256, sup_size=24, test_way='id', save=False,
22 | root_dir='data'):
23 | for model in models_list:
24 | model.eval()
25 |
26 | student_ce = nn.BCEWithLogitsLoss()
27 | #teacher_ce = nn.CrossEntropyLoss()
28 |
29 | features = student.features
30 | head = student.classifier
31 | features = l2l.algorithms.MAML(features, lr=ilr)
32 | features.to(device)
33 | head.to(device)
34 |
35 | all_params = list(features.parameters()) + list(head.parameters())
36 | optimizer_s = optim.Adam(all_params, lr=slr)
37 | optimizer_t = optim.Adam(selector.parameters(), lr=tlr)
38 |
39 | i = 0
40 |
41 | losses = []
42 |
43 | iter_per_epoch = len(train_loader)
44 |
45 | for x, y_true, metadata in train_loader:
46 | selector.eval()
47 | head.eval()
48 | features.eval()
49 |
50 | z = grouper.metadata_to_group(metadata)
51 | z = set(z.tolist())
52 | assert len(z) == 1
53 | mask = mask_grouper.metadata_to_group(metadata)
54 | mask.apply_(lambda x: split_to_cluster[x])
55 |
56 | #sup_size = x.shape[0]//2
57 | x_sup = x[:sup_size]
58 | y_sup = y_true[:sup_size]
59 | x_que = x[sup_size:]
60 | y_que = y_true[sup_size:]
61 | mask = mask[:sup_size]
62 |
63 | x_sup = x_sup.to(device)
64 | y_sup = y_sup.to(device)
65 | x_que = x_que.to(device)
66 | y_que = y_que.to(device)
67 |
68 | with torch.no_grad():
69 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1)
70 | #logits[:, :, split_to_cluster[z]] = torch.zeros_like(logits[:, :, split_to_cluster[z]])
71 | #logits = logits.permute((0,2,1))
72 | logits = mask_feat(logits, mask, len(models_list), exclude=True)
73 |
74 | t_out = selector.get_feat(logits)
75 |
76 | task_model = features.clone()
77 | task_model.module.eval()
78 | feat = task_model(x_que)
79 | feat = feat.view(feat.shape[0], -1)
80 | out = head(feat)
81 | with torch.no_grad():
82 | loss_pre = student_ce(out, y_que.unsqueeze(-1).float()).item()/x_que.shape[0]
83 |
84 | feat = task_model(x_sup)
85 | feat = feat.view_as(t_out)
86 |
87 | inner_loss = l2_loss(feat, t_out)
88 | task_model.adapt(inner_loss)
89 |
90 | x_que = task_model(x_que)
91 | x_que = x_que.view(x_que.shape[0], -1)
92 | s_que_out = head(x_que)
93 | s_que_loss = student_ce(s_que_out, y_que.unsqueeze(-1).float())
94 | #t_sup_loss = teacher_ce(t_out, y_sup)
95 |
96 | s_que_loss.backward()
97 |
98 | optimizer_s.step()
99 | optimizer_t.step()
100 | optimizer_s.zero_grad()
101 | optimizer_t.zero_grad()
102 |
103 | #print("Step:{}".format(time.time() - t_1))
104 | #t_1 = time.time()
105 | with open('log/camelyon/'+curr+'.txt', 'a+') as f:
106 | f.write('Iter: {}/{}, Loss Before:{:.4f}, Loss After:{:.4f}\r\n'.format(i,iter_per_epoch,
107 | loss_pre,
108 | s_que_loss.item()/x_que.shape[0]))
109 | losses.append(s_que_loss.item()/x_que.shape[0])
110 |
111 | if i == iter_per_epoch//2:
112 | losses = np.mean(losses)
113 | acc = eval(selector, models_list, student, sup_size, device=device,
114 | ilr=ilr, test=False, progress=False, uniform_over_groups=False,
115 | root_dir=root_dir)
116 |
117 | with open('log/camelyon/'+curr+'.txt', 'a+') as f:
118 | f.write(f'Accuracy: {acc} \r\n')
119 | losses = []
120 |
121 | if acc > acc_best and save:
122 | with open('log/camelyon/'+curr+'.txt', 'a+') as f:
123 | f.write("Saving model ...\r\n")
124 | save_model(selector, selector_name+"_selector", 0, test_way=test_way)
125 | save_model(student, student_name+"_student", 0, test_way=test_way)
126 | acc_best = acc
127 |
128 | i += 1
129 | return acc_best
130 |
131 | def train_kd(selector, selector_name, models_list, student, student_name, split_to_cluster, device,
132 | batch_size=256, sup_size=24, tlr=1e-4, slr=1e-4, ilr=1e-5, num_epochs=30,
133 | decayRate=0.96, save=False, test_way='ood', root_dir='data'):
134 |
135 | train_loader, _, _, grouper = get_data_loader(root_dir=root_dir, batch_size=batch_size, domain=None,
136 | test_way=test_way, n_groups_per_batch=1)
137 |
138 | mask_grouper = get_mask_grouper(root_dir=root_dir)
139 |
140 | curr = str(datetime.now())
141 | if not os.path.exists("log/camelyon"):
142 | os.makedirs("log/camelyon")
143 | #print(curr)
144 | print("Training log saved to log/camelyon/"+curr+".txt")
145 |
146 | with open('log/camelyon/'+curr+'.txt', 'a+') as f:
147 | f.write(selector_name+' '+student_name+'\r\n')
148 | f.write(f'tlr={tlr} slr={slr} ilr={ilr}\r\n')
149 |
150 | accu_best = 0
151 | for epoch in range(num_epochs):
152 | with open('log/camelyon/'+curr+'.txt', 'a+') as f:
153 | f.write(f'Epoch: {epoch}\r\n')
154 | accu_epoch = train_epoch(selector, selector_name, models_list, student, student_name,
155 | train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster,
156 | device, acc_best=accu_best, tlr=tlr, slr=slr, ilr=ilr,
157 | batch_size=batch_size, sup_size=sup_size, test_way=test_way, save=save,
158 | root_dir=root_dir)
159 | accu_best = max(accu_best, accu_epoch)
160 | accu = eval(selector, models_list, student, sup_size, device=device,
161 | ilr=ilr, test=False, progress=False, uniform_over_groups=False,
162 | root_dir=root_dir)
163 |
164 | with open('log/camelyon/'+curr+'.txt', 'a+') as f:
165 | f.write(f'Accuracy: {accu} \r\n')
166 |
167 | if accu > accu_best and save:
168 | with open('log/camelyon/'+curr+'.txt', 'a+') as f:
169 | f.write("Saving model ...\r\n")
170 | save_model(selector, selector_name+"_selector", 0, test_way=test_way)
171 | save_model(student, student_name+"_student", 0, test_way=test_way)
172 | accu_best = accu
173 | tlr = tlr*decayRate
174 | slr = slr*decayRate
175 |
176 | def eval(selector, models_list, student, batch_size, device, ilr=1e-5,
177 | test=False, progress=True, uniform_over_groups=False, root_dir='data'):
178 |
179 | if test:
180 | loader, grouper, dataset = get_test_loader(batch_size=batch_size, split='test', root_dir=root_dir,
181 | return_dataset=True)
182 | else:
183 | loader, grouper, dataset = get_test_loader(batch_size=batch_size, split='val', root_dir=root_dir,
184 | return_dataset=True)
189 |
190 | features = student.features
191 | head = student.classifier
192 | head.to(device)
193 |
194 | student_maml = l2l.algorithms.MAML(features, lr=ilr)
195 | student_maml.to(device)
196 |
197 | correct = 0
198 | total = 0
199 |
200 | old_domain = {}
201 | if progress:
202 | loader = tqdm(iter(loader), total=len(loader))
203 |
204 | for domain, domain_loader in loader:
205 | adapted = False
206 | for x_sup, y_sup, metadata in domain_loader:
207 | student_maml.module.eval()
208 | selector.eval()
209 | head.eval()
210 |
211 | z = grouper.metadata_to_group(metadata)
212 | z = set(z.tolist())
213 | assert list(z)[0] == domain
214 |
215 | x_sup = x_sup.to(device)
216 | y_sup = y_sup.to(device)
217 |
218 | if not adapted:
219 | task_model = student_maml.clone()
220 | task_model.eval()
221 |
222 | with torch.no_grad():
223 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1)
224 | logits = logits.permute((0,2,1))
225 | t_out = selector.get_feat(logits)
226 |
227 | feat = task_model(x_sup)
228 | feat = feat.view_as(t_out)
229 |
230 |                 kl_loss = l2_loss(feat, t_out)  # an L2 alignment loss, despite the name
231 | task_model.adapt(kl_loss)
232 | adapted = True
233 |
234 | with torch.no_grad():
235 | task_model.module.eval()
236 | x_sup = task_model(x_sup)
237 | x_sup = x_sup.view(x_sup.shape[0], -1)
238 | s_que_out = head(x_sup)
239 | pred = (s_que_out > 0.0).squeeze().long()
240 | correct += pred.eq(y_sup.view_as(pred)).sum().item()
241 | total += x_sup.shape[0]
242 |
243 | return correct/total
244 |
--------------------------------------------------------------------------------
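
train_epoch and eval above share one learn2learn pattern: clone the student's feature extractor, adapt the clone with a single gradient step on the L2 loss between its support features and the aggregator's get_feat output, then run the query batch through the adapted weights so the outer loss updates both the student and (during training) the selector. A stripped-down sketch on toy modules (sizes and data are illustrative, not the repo's):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import learn2learn as l2l

    features = nn.Linear(8, 4)                     # stand-in for student.features
    head = nn.Linear(4, 2)                         # stand-in for student.classifier
    maml = l2l.algorithms.MAML(features, lr=1e-2)  # lr plays the role of ilr
    opt = torch.optim.Adam(list(maml.parameters()) + list(head.parameters()), lr=1e-3)

    x_sup, x_que = torch.randn(6, 8), torch.randn(6, 8)
    y_que = torch.randint(0, 2, (6,))
    t_out = torch.randn(6, 4)                      # stand-in for selector.get_feat(logits)

    task_model = maml.clone()                      # per-domain copy that keeps the graph
    task_model.adapt(((task_model(x_sup) - t_out) ** 2).mean())  # unsupervised inner step
    que_loss = F.cross_entropy(head(task_model(x_que)), y_que)
    que_loss.backward()                            # outer gradients reach features and head
    opt.step(); opt.zero_grad()
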
/src/camelyon/camelyon_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torchvision
7 | import torchvision.transforms as transforms
8 | import math
9 | import random
10 | import pickle5 as pickle
11 | from src.transformer import Transformer
12 | import copy
13 | import time
14 | import os
15 | import wilds
16 | from wilds import get_dataset
17 | from wilds.common.data_loaders import get_train_loader, get_eval_loader
18 | from wilds.common.grouper import CombinatorialGrouper
19 |
20 | # Utils
21 |
22 | def split_domains(num_experts, root_dir='data'):
23 | dataset = get_dataset(dataset='camelyon17', download=True, root_dir=root_dir)
24 | train_data = dataset.get_subset('train')
25 | wsi = list(set(train_data.metadata_array[:,1].detach().numpy().tolist()))
26 | random.shuffle(wsi)
27 | num_domains_per_super = len(wsi) / float(num_experts)
28 | all_split = [[] for _ in range(num_experts)]
29 | for i in range(len(wsi)):
30 | all_split[int(i//num_domains_per_super)].append(wsi[i])
31 | return all_split
32 |
33 | def get_subset_with_domain(dataset, split=None, domain=None, transform=None, grouper=None):
34 | if type(dataset) == wilds.datasets.wilds_dataset.WILDSSubset:
35 | subset = copy.deepcopy(dataset)
36 | else:
37 | subset = dataset.get_subset(split, transform=transform)
38 | if domain is not None:
39 | if grouper is not None:
40 | z = grouper.metadata_to_group(subset.dataset.metadata_array[subset.indices])
41 | idx = np.argwhere(np.isin(z, domain)).ravel()
42 | subset.indices = subset.indices[idx]
43 | else:
44 | idx = np.argwhere(np.isin(subset.dataset.metadata_array[:,1][subset.indices], domain)).ravel()
45 | subset.indices = subset.indices[idx]
46 | return subset
47 |
48 | def initialize_image_base_transform(dataset):
49 | transform_steps = []
50 | if dataset.original_resolution is not None and min(dataset.original_resolution)!=max(dataset.original_resolution):
51 | crop_size = min(dataset.original_resolution)
52 | transform_steps.append(transforms.CenterCrop(crop_size))
53 | transform_steps.append(transforms.Resize((96,96)))
54 | transform_steps += [
55 | transforms.ToTensor(),
56 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
57 | ]
58 | transform = transforms.Compose(transform_steps)
59 | return transform
60 |
61 | def get_data_loader(batch_size=16, domain=None, test_way='id', n_groups_per_batch=0, return_dataset=False,
62 | uniform_over_groups=True, root_dir='data'):
63 | dataset = get_dataset(dataset='camelyon17', download=True, root_dir=root_dir)
64 | grouper = CombinatorialGrouper(dataset, ['slide'])
65 |
66 | transform = initialize_image_base_transform(dataset)
67 |
68 | train_data = get_subset_with_domain(dataset, 'train', domain=domain, transform=transform)
69 | if test_way == 'ood':
70 | val_data = dataset.get_subset('val', transform=transform)
71 | test_data = dataset.get_subset('test', transform=transform)
72 | else:
73 | val_data = get_subset_with_domain(dataset, 'id_val', domain=domain, transform=transform)
74 | test_data = get_subset_with_domain(dataset, 'id_test', domain=domain, transform=transform)
75 |
76 |
77 | if n_groups_per_batch == 0:
78 |         # n_groups_per_batch == 0 selects the plain (non-grouped) loader
79 | train_loader = get_train_loader('standard', train_data, batch_size=batch_size)
80 | val_loader = get_train_loader('standard', val_data, batch_size=batch_size)
81 | test_loader = get_train_loader('standard', test_data, batch_size=batch_size)
82 |
83 | else:
84 | # All use get_train_loader to enable grouper
85 | train_loader = get_train_loader('group', train_data, grouper=grouper,
86 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size)
87 | val_loader = get_train_loader('group', val_data, grouper=grouper,
88 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size,
89 | uniform_over_groups=uniform_over_groups)
90 | test_loader = get_train_loader('group', test_data, grouper=grouper,
91 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size,
92 | uniform_over_groups=uniform_over_groups)
93 |
94 | if return_dataset:
95 | return train_loader, val_loader, test_loader, grouper, dataset
96 | return train_loader, val_loader, test_loader, grouper
97 |
98 | def get_mask_grouper(root_dir='data'):
99 | dataset = get_dataset(dataset='camelyon17', download=True, root_dir=root_dir)
100 | grouper = CombinatorialGrouper(dataset, ['slide'])
101 | return grouper
102 |
103 | def get_test_loader(batch_size=16, split='val', root_dir='data', return_dataset=False):
104 | dataset = get_dataset(dataset='camelyon17', download=True, root_dir=root_dir)
105 | grouper = CombinatorialGrouper(dataset, ['slide'])
106 |
107 | transform = initialize_image_base_transform(dataset)
108 |
109 | eval_data = get_subset_with_domain(dataset, split, domain=None, transform=transform)
110 | all_domains = list(set(grouper.metadata_to_group(eval_data.dataset.metadata_array[eval_data.indices]).tolist()))
111 | test_loader = []
112 |
113 | for domain in all_domains:
114 | domain_data = get_subset_with_domain(eval_data, split, domain=domain, transform=transform, grouper=grouper)
115 | domain_loader = get_eval_loader('standard', domain_data, batch_size=batch_size)
116 | test_loader.append((domain, domain_loader))
117 |
118 | if return_dataset:
119 | return test_loader, grouper, dataset
120 | return test_loader, grouper
121 |
122 | def save_model(model, name, epoch, test_way='ood'):
123 | if not os.path.exists("model/camelyon"):
124 | os.makedirs("model/camelyon")
125 | path = "model/camelyon/{0}_best.pth".format(name)
126 | torch.save(model.state_dict(), path)
127 |
128 | def mask_feat(feat, mask_index, num_experts, exclude=True):
129 | assert feat.shape[0] == mask_index.shape[0]
130 | if exclude:
131 | new_idx = [list(range(0, int(m.item()))) + list(range(int(m.item())+1, num_experts)) for m in mask_index]
132 | return feat[torch.arange(feat.shape[0]).unsqueeze(-1), :, new_idx]
133 | else:
134 | feat[list(range(feat.shape[0])), :, mask_index] = torch.zeros_like(feat[list(range(feat.shape[0])), :, mask_index])
135 | feat = feat.permute((0,2,1))
136 | return feat
137 |
138 | def l2_loss(input, target):
139 | loss = torch.square(target - input)
140 | loss = torch.mean(loss)
141 | return loss
142 |
143 | # Models
144 | class ResNetFeature(nn.Module):  # name kept from the ResNet variant; here it wraps DenseNet-121 features
145 | def __init__(self, original_model, layer=-1):
146 | super(ResNetFeature, self).__init__()
147 | self.num_ftrs = original_model.classifier.in_features
148 | self.features = nn.Sequential(*list(original_model.children())[:layer])
149 | self.pool = nn.AdaptiveAvgPool2d((1,1))
150 |
151 | def forward(self, x):
152 | x = self.features(x)
153 | x = self.pool(x)
154 | x = x.view(-1, self.num_ftrs)
155 | return x
156 |
157 | class fa_selector(nn.Module):
158 | def __init__(self, dim, depth, heads, mlp_dim, dropout = 0., out_dim=1, pool='mean'):
159 | super(fa_selector, self).__init__()
160 | self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout=dropout)
161 | self.pool = pool
162 | self.mlp = nn.Sequential(
163 | nn.LayerNorm(dim),
164 | nn.Linear(dim, out_dim)
165 | )
166 |
167 | def forward(self, x):
168 | x = self.transformer(x)
169 | if self.pool == 'mean':
170 | x = x.mean(dim=1)
171 | elif self.pool == 'max':
172 |             x = x.max(dim=1).values  # torch.max(dim=...) returns (values, indices)
173 | else:
174 | raise NotImplementedError
175 | x = self.mlp(x)
176 | return x
177 |
178 | def get_feat(self, x):
179 | x = self.transformer(x)
180 | if self.pool == 'mean':
181 | x = x.mean(dim=1)
182 | elif self.pool == 'max':
183 | x = x.max(dim=1)
184 | else:
185 | raise NotImplementedError
186 | return x
187 |
188 | class DivideModel(nn.Module):
189 | def __init__(self, original_model, layer=-1):
190 | super(DivideModel, self).__init__()
191 | self.num_ftrs = original_model.classifier.in_features
192 | self.num_class = original_model.classifier.out_features
193 | self.features = nn.Sequential(*list(original_model.children())[:layer])
194 | self.features.add_module("avg_pool", nn.AdaptiveAvgPool2d((1,1)))
195 | self.classifier = nn.Sequential(*list(original_model.children())[layer:])
196 |
197 | def forward(self, x):
198 | x = self.features(x)
199 | x = x.view(-1, self.num_ftrs)
200 | x = self.classifier(x)
201 | x = x.view(-1, self.num_class)
202 | return x
203 |
204 | def StudentModel(device, num_classes=1, load_path=None):  # factory function, not a class: returns a DivideModel-wrapped DenseNet-121
205 | model = torchvision.models.densenet121(pretrained=True)
206 | num_ftrs = model.classifier.in_features
207 | model.classifier = nn.Linear(num_ftrs, num_classes)
208 | if load_path:
209 | model.load_state_dict(torch.load(load_path))
210 | model = DivideModel(model)
211 | model = model.to(device)
212 | return model
213 |
214 | def get_feature_list(models_list, device):
215 | feature_list = []
216 | for model in models_list:
217 | feature_list.append(ResNetFeature(model).to(device))
218 | return feature_list
219 |
220 | def get_models_list(device, num_domains=4, num_classes=1, pretrained=False, bb='d121'):
221 | models_list = []
222 | for _ in range(num_domains+1):
223 | model = torchvision.models.densenet121(pretrained=True)
224 | num_ftrs = model.classifier.in_features
225 | model.classifier = nn.Linear(num_ftrs, num_classes)
226 | model = model.to(device)
227 | models_list.append(model)
228 | return models_list
229 |
--------------------------------------------------------------------------------
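
The aggregator consumes a (batch, num_experts, feat_dim) stack of per-expert features, mixes them across the expert axis with a Transformer, pools, and projects through the MLP head. A shape walkthrough, assuming the repo is on PYTHONPATH and that FEAT_DIM['camelyon'] in configs.py equals the DenseNet-121 feature width of 1024 (the depth/heads values here are illustrative):

    import torch
    from src.camelyon.camelyon_utils import fa_selector

    sel = fa_selector(dim=1024, depth=2, heads=4, mlp_dim=2048, out_dim=1)

    expert_feats = torch.randn(8, 5, 1024)    # (batch, num_experts, feat_dim)
    print(sel(expert_feats).shape)            # torch.Size([8, 1])    -> pooled logit
    print(sel.get_feat(expert_feats).shape)   # torch.Size([8, 1024]) -> the t_out target
                                              #    that the student's features adapt toward
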
/environment.yml:
--------------------------------------------------------------------------------
1 | name: dmoe
2 | channels:
3 | - pytorch
4 | - intel
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _ipyw_jlab_nb_ext_conf=0.1.0=py37_0
9 | - _libgcc_mutex=0.1=main
10 | - _openmp_mutex=4.5=1_gnu
11 | - absl-py=0.13.0=py37h06a4308_0
12 | - aiohttp=3.7.4.post0=py37h7f8727e_2
13 | - anaconda-client=1.9.0=py37h06a4308_0
14 | - anaconda-navigator=2.1.0=py37h06a4308_0
15 | - anyio=2.2.0=py37h06a4308_1
16 | - argcomplete=1.12.3=pyhd3eb1b0_0
17 | - argon2-cffi=20.1.0=py37h27cfd23_1
18 | - async-timeout=3.0.1=py37h06a4308_0
19 | - async_generator=1.10=py37h28b3542_0
20 | - attrs=21.2.0=pyhd3eb1b0_0
21 | - babel=2.9.1=pyhd3eb1b0_0
22 | - backcall=0.2.0=pyhd3eb1b0_0
23 | - backports=1.0=pyhd3eb1b0_2
24 | - backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0
25 | - backports.tempfile=1.0=pyhd3eb1b0_1
26 | - backports.weakref=1.0.post1=py_1
27 | - beautifulsoup4=4.10.0=pyh06a4308_0
28 | - blas=1.0=mkl
29 | - bleach=4.0.0=pyhd3eb1b0_0
30 | - blinker=1.4=py37h06a4308_0
31 | - brotlipy=0.7.0=py37h27cfd23_1003
32 | - bzip2=1.0.8=h7b6447c_0
33 | - c-ares=1.17.1=h27cfd23_0
34 | - ca-certificates=2022.4.26=h06a4308_0
35 | - cachetools=4.2.2=pyhd3eb1b0_0
36 | - cairo=1.16.0=hf32fb01_1
37 | - certifi=2022.5.18.1=py37h06a4308_0
38 | - cffi=1.14.6=py37h400218f_0
39 | - chardet=4.0.0=py37h06a4308_1003
40 | - charset-normalizer=2.0.4=pyhd3eb1b0_0
41 | - click=8.0.3=pyhd3eb1b0_0
42 | - clyent=1.2.2=py37_1
43 | - common_cmplr_lib_rt=2021.4.0=intel_3561
44 | - common_cmplr_lic_rt=2021.4.0=intel_3561
45 | - conda=4.10.3=py37hb2ebc77_1
46 | - conda-build=3.21.5=py37h06a4308_0
47 | - conda-content-trust=0.1.1=pyhd3eb1b0_0
48 | - conda-env=2.6.0=1
49 | - conda-package-handling=1.7.3=py37h27cfd23_1
50 | - conda-repo-cli=1.0.4=pyhd3eb1b0_0
51 | - conda-token=0.3.0=pyhd3eb1b0_0
52 | - conda-verify=3.4.2=py_1
53 | - coverage=5.5=py37h27cfd23_2
54 | - cryptography=3.4.8=py37hd23ed53_0
55 | - cudatoolkit=10.1.243=h6bb024c_0
56 | - cudnn=7.6.5=cuda10.1_0
57 | - cupy=8.3.0=py37hcaf9a05_0
58 | - cython=0.29.24=py37h295c915_0
59 | - daal4py=2021.4.0=py37_intel_729
60 | - dal=2021.4.0=intel_729
61 | - dbus=1.13.18=hb2f20db_0
62 | - debugpy=1.4.1=py37h295c915_0
63 | - decorator=5.1.0=pyhd3eb1b0_0
64 | - defusedxml=0.7.1=pyhd3eb1b0_0
65 | - dpcpp_cpp_rt=2021.2.0=intel_610
66 | - easydict=1.9=py_0
67 | - entrypoints=0.3=py37_0
68 | - expat=2.4.1=h2531618_2
69 | - fastrlock=0.6=py37h2531618_0
70 | - ffmpeg=4.0=hcdf2ecd_0
71 | - filelock=3.3.1=pyhd3eb1b0_1
72 | - fontconfig=2.13.1=h6c09931_0
73 | - freeglut=3.0.0=hf484d3e_5
74 | - freetype=2.11.0=h70c0345_0
75 | - future=0.18.2=py37_1
76 | - giflib=5.2.1=h7b6447c_0
77 | - glib=2.69.1=h5202010_0
78 | - glob2=0.7=pyhd3eb1b0_0
79 | - google-auth=1.33.0=pyhd3eb1b0_0
80 | - google-auth-oauthlib=0.4.1=py_2
81 | - graphite2=1.3.14=h23475e2_0
82 | - grpcio=1.36.1=py37h2157cd5_1
83 | - gst-plugins-base=1.14.0=h8213a91_2
84 | - gstreamer=1.14.0=h28cd5cc_2
85 | - h5py=2.8.0=py37h989c5e5_3
86 | - harfbuzz=1.8.8=hffaf4a1_0
87 | - hdf5=1.10.2=hba1933b_1
88 | - icc_rt=2021.2.0=intel_610
89 | - icu=58.2=he6710b0_3
90 | - idna=3.2=pyhd3eb1b0_0
91 | - importlib-metadata=4.8.1=py37h06a4308_0
92 | - importlib_metadata=4.8.1=hd3eb1b0_0
93 | - intel-cmplr-lib-rt=2021.4.0=intel_3561
94 | - intel-cmplr-lic-rt=2021.4.0=intel_3561
95 | - intel-opencl-rt=2021.4.0=intel_3561
96 | - intel-openmp=2021.3.0=h06a4308_3350
97 | - intelpython=2021.4.0=0
98 | - ipykernel=6.4.1=py37h06a4308_1
99 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
100 | - ipywidgets=7.6.5=pyhd3eb1b0_1
101 | - jasper=2.0.14=hd8c5072_2
102 | - jedi=0.18.0=py37h06a4308_1
103 | - jinja2=2.11.3=pyhd3eb1b0_0
104 | - joblib=1.0.1=py37h3f38642_0
105 | - jpeg=9d=h7f8727e_0
106 | - json5=0.9.6=pyhd3eb1b0_0
107 | - jsonschema=3.2.0=pyhd3eb1b0_2
108 | - jupyter_client=7.0.1=pyhd3eb1b0_0
109 | - jupyter_core=4.8.1=py37h06a4308_0
110 | - jupyter_server=1.4.1=py37h06a4308_0
111 | - jupyterlab=3.2.1=pyhd3eb1b0_0
112 | - jupyterlab_pygments=0.1.2=py_0
113 | - jupyterlab_server=2.8.2=pyhd3eb1b0_0
114 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
115 | - lcms2=2.12=h3be6417_0
116 | - ld_impl_linux-64=2.35.1=h7274673_9
117 | - libarchive=3.4.2=h62408e4_0
118 | - libffi=3.3=he6710b0_2
119 | - libgcc-ng=9.3.0=h5101ec6_17
120 | - libgfortran-ng=7.5.0=ha8ba4b0_17
121 | - libgfortran4=7.5.0=ha8ba4b0_17
122 | - libglu=9.0.0=hf484d3e_1
123 | - libgomp=9.3.0=h5101ec6_17
124 | - liblief=0.10.1=he6710b0_0
125 | - libopencv=3.4.2=hb342d67_1
126 | - libopus=1.3.1=h7b6447c_0
127 | - libpng=1.6.37=hbc83047_0
128 | - libprotobuf=3.17.2=h4ff587b_1
129 | - libsodium=1.0.18=h7b6447c_0
130 | - libstdcxx-ng=11.2.0=h1234567_0
131 | - libtiff=4.2.0=h85742a9_0
132 | - libuuid=1.0.3=h7f8727e_2
133 | - libuv=1.40.0=h7b6447c_0
134 | - libvpx=1.7.0=h439df22_0
135 | - libwebp=1.2.0=h89dd481_0
136 | - libwebp-base=1.2.0=h27cfd23_0
137 | - libxcb=1.14=h7b6447c_0
138 | - libxml2=2.9.12=h03d6c58_0
139 | - lz4-c=1.9.3=h295c915_1
140 | - markdown=3.3.4=py37h06a4308_0
141 | - markupsafe=2.0.1=py37h27cfd23_0
142 | - matplotlib-inline=0.1.2=pyhd3eb1b0_2
143 | - mistune=0.8.4=py37h14c3975_1001
144 | - mkl=2021.3.0=h06a4308_520
145 | - mkl-service=2.4.0=py37h7f8727e_0
146 | - mkl_fft=1.3.1=py37hd3c417c_0
147 | - mkl_random=1.2.2=py37h51133e4_0
148 | - multidict=5.1.0=py37h27cfd23_2
149 | - navigator-updater=0.2.1=py37_0
150 | - nbclassic=0.2.6=pyhd3eb1b0_0
151 | - nbclient=0.5.3=pyhd3eb1b0_0
152 | - nbconvert=6.1.0=py37h06a4308_0
153 | - nbformat=5.1.3=pyhd3eb1b0_0
154 | - nccl=2.8.3.1=hcaf9a05_0
155 | - ncurses=6.2=he6710b0_1
156 | - nest-asyncio=1.5.1=pyhd3eb1b0_0
157 | - ninja=1.10.2=hff7bd54_1
158 | - notebook=6.4.5=py37h06a4308_0
159 | - oauthlib=3.1.1=pyhd3eb1b0_0
160 | - olefile=0.46=py37_0
161 | - opencl_rt=2021.4.0=intel_3561
162 | - opencv=3.4.2=py37h6fd60c2_1
163 | - openssl=1.1.1o=h7f8727e_0
164 | - packaging=21.0=pyhd3eb1b0_0
165 | - pandocfilters=1.4.3=py37h06a4308_1
166 | - parso=0.8.2=pyhd3eb1b0_0
167 | - patchelf=0.13=h295c915_0
168 | - pcre=8.45=h295c915_0
169 | - pexpect=4.8.0=pyhd3eb1b0_3
170 | - pickleshare=0.7.5=pyhd3eb1b0_1003
171 | - pixman=0.40.0=h7f8727e_1
172 | - pkginfo=1.7.1=py37h06a4308_0
173 | - prometheus_client=0.11.0=pyhd3eb1b0_0
174 | - prompt-toolkit=3.0.20=pyhd3eb1b0_0
175 | - protobuf=3.17.2=py37h295c915_0
176 | - psutil=5.8.0=py37h27cfd23_1
177 | - ptyprocess=0.7.0=pyhd3eb1b0_2
178 | - py-lief=0.10.1=py37h403a769_0
179 | - py-opencv=3.4.2=py37hb342d67_1
180 | - pyasn1=0.4.8=pyhd3eb1b0_0
181 | - pyasn1-modules=0.2.8=py_0
182 | - pycosat=0.6.3=py37h27cfd23_0
183 | - pycparser=2.20=py_2
184 | - pygments=2.10.0=pyhd3eb1b0_0
185 | - pyjwt=2.1.0=py37h06a4308_0
186 | - pyopenssl=21.0.0=pyhd3eb1b0_1
187 | - pyparsing=2.4.7=pyhd3eb1b0_0
188 | - pyqt=5.9.2=py37h05f1152_2
189 | - pyrsistent=0.17.3=py37h7b6447c_0
190 | - pysocks=1.7.1=py37_1
191 | - python=3.7.11=h12debd9_0
192 | - python-dateutil=2.8.2=pyhd3eb1b0_0
193 | - python-libarchive-c=2.9=pyhd3eb1b0_1
194 | - pytorch=1.7.0=py3.7_cuda10.1.243_cudnn7.6.3_0
195 | - pytz=2021.3=pyhd3eb1b0_0
196 | - pyyaml=5.4.1=py37h27cfd23_1
197 | - pyzmq=22.2.1=py37h295c915_1
198 | - qt=5.9.7=h5867ecd_1
199 | - qtpy=1.10.0=pyhd3eb1b0_0
200 | - readline=8.1=h27cfd23_0
201 | - requests=2.26.0=pyhd3eb1b0_0
202 | - requests-oauthlib=1.3.0=py_0
203 | - ripgrep=12.1.1=0
204 | - rsa=4.7.2=pyhd3eb1b0_1
205 | - ruamel_yaml=0.15.100=py37h27cfd23_0
206 | - scikit-learn=0.24.2=py37h8411759_4
207 | - scikit-learn-intelex=2021.4.0=py37_intel_729
208 | - scipy=1.7.1=py37h292c36d_2
209 | - send2trash=1.8.0=pyhd3eb1b0_1
210 | - sip=4.19.8=py37hf484d3e_0
211 | - six=1.16.0=pyhd3eb1b0_0
212 | - sniffio=1.2.0=py37h06a4308_1
213 | - soupsieve=2.2.1=pyhd3eb1b0_0
214 | - sqlite=3.36.0=hc218d9a_0
215 | - tbb=2021.4.0=intel_643
216 | - tensorboard=2.6.0=py_1
217 | - tensorboard-plugin-wit=1.6.0=py_0
218 | - tensorboardx=2.2=pyhd3eb1b0_0
219 | - terminado=0.9.4=py37h06a4308_0
220 | - testpath=0.5.0=pyhd3eb1b0_0
221 | - threadpoolctl=2.1.0=py37h6447541_2
222 | - tk=8.6.11=h1ccaba5_0
223 | - tornado=6.1=py37h27cfd23_0
224 | - tqdm=4.62.3=pyhd3eb1b0_1
225 | - traitlets=5.1.0=pyhd3eb1b0_0
226 | - typing-extensions=3.10.0.2=hd3eb1b0_0
227 | - typing_extensions=3.10.0.2=pyh06a4308_0
228 | - urllib3=1.26.7=pyhd3eb1b0_0
229 | - wcwidth=0.2.5=pyhd3eb1b0_0
230 | - webencodings=0.5.1=py37_1
231 | - werkzeug=2.0.2=pyhd3eb1b0_0
232 | - widgetsnbextension=3.5.1=py37_0
233 | - xmltodict=0.12.0=pyhd3eb1b0_0
234 | - xz=5.2.5=h7b6447c_0
235 | - yaml=0.2.5=h7b6447c_0
236 | - yarl=1.6.3=py37h27cfd23_0
237 | - zeromq=4.3.4=h2531618_0
238 | - zipp=3.6.0=pyhd3eb1b0_0
239 | - zlib=1.2.11=h7b6447c_3
240 | - zstd=1.4.9=haebb681_0
241 | - pip:
242 | - albumentations==1.1.0
243 | - boto==2.49.0
244 | - cloudpickle==2.0.0
245 | - crcmod==1.7
246 | - cycler==0.11.0
247 | - dataclasses==0.6
248 | - dummy-test==0.1.3
249 | - fasteners==0.16.3
250 | - fonttools==4.33.3
251 | - gcs-oauth2-boto-plugin==3.0
252 | - google-apitools==0.5.32
253 | - google-reauth==0.1.1
254 | - gsutil==5.5
255 | - gym==0.21.0
256 | - higher==0.2.1
257 | - hnswlib==0.5.2
258 | - httplib2==0.20.2
259 | - huggingface-hub==0.5.1
260 | - imageio==2.19.2
261 | - ipdb==0.13.9
262 | - ipython==7.25.0
263 | - kiwisolver==1.4.2
264 | - learn2learn==0.1.6
265 | - littleutils==0.2.2
266 | - mat73==0.52
267 | - matplotlib==3.5.2
268 | - monotonic==1.6
269 | - networkx==2.6.3
270 | - numpy==1.21.3
271 | - oauth2client==4.1.3
272 | - ogb==1.3.2
273 | - opencv-python==4.5.4.58
274 | - opencv-python-headless==4.5.4.58
275 | - outdated==0.2.1
276 | - pandas==1.3.4
277 | - pathlib==1.0.1
278 | - pickle5==0.0.12
279 | - pillow==8.4.0
280 | - pip==21.3.1
281 | - ptflops==0.6.6
282 | - pyu2f==0.1.5
283 | - pywavelets==1.1.1
284 | - qpth==0.0.15
285 | - qudida==0.0.4
286 | - regex==2022.3.15
287 | - retry-decorator==1.1.1
288 | - sacremoses==0.0.49
289 | - scikit-image==0.18.3
290 | - setuptools==58.5.0
291 | - tabulate==0.8.9
292 | - tensorboard-data-server==0.6.1
293 | - tifffile==2021.10.12
294 | - timm==0.4.12
295 | - tokenizers==0.11.6
296 | - toml==0.10.2
297 | - torchvision==0.8.1
298 | - transformers==4.18.0
299 | - wheel==0.34.0
300 | - wilds==1.2.2
301 |
--------------------------------------------------------------------------------
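
The environment above pins an older stack (Python 3.7, pytorch 1.7.0 with CUDA 10.1 via conda, plus pip-installed torchvision 0.8.1 and wilds 1.2.2). A quick sanity check before launching long runs, assuming these packages expose the usual __version__ attributes:

    import torch, torchvision, wilds

    print(torch.__version__, torchvision.__version__, wilds.__version__)
    assert torch.__version__.startswith("1.7"), "this codebase targets torch 1.7.x"
    if not torch.cuda.is_available():
        print("warning: run.py will fall back to CPU, which is impractical for these datasets")
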
/src/fmow/fmow_train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | import torchvision
8 | from tqdm import tqdm
9 | import copy
10 | import learn2learn as l2l
11 | import time
12 | from datetime import datetime
13 | import os
14 | import wilds
15 | from src.fmow.fmow_utils import *
16 |
17 | def train_epoch(selector, selector_name, models_list, student, student_name,
18 | train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster,
19 | device, acc_best=0, tlr=1e-4, slr=1e-4, ilr=1e-3,
20 | batch_size=256, sup_size=24, test_way='id', save=False,
21 | root_dir='data'):
22 | for model in models_list:
23 | model.eval()
24 |
25 | student_ce = nn.CrossEntropyLoss()
26 | #teacher_ce = nn.CrossEntropyLoss()
27 |
28 | features = student.features
29 | head = student.classifier
30 | features = l2l.algorithms.MAML(features, lr=ilr)
31 | features.to(device)
32 | head.to(device)
33 |
34 | all_params = list(features.parameters()) + list(head.parameters())
35 | optimizer_s = optim.Adam(all_params, lr=slr)
36 | optimizer_t = optim.Adam(selector.parameters(), lr=tlr)
37 |
38 | i = 0
39 |
40 | losses = []
41 |
42 | iter_per_epoch = len(train_loader)
43 |
44 | for x, y_true, metadata in train_loader:
45 | selector.eval()
46 | head.eval()
47 | features.eval()
48 |
49 | z = grouper.metadata_to_group(metadata)
50 | z = set(z.tolist())
51 | assert len(z) == 1
52 | mask = mask_grouper.metadata_to_group(metadata)
53 | mask.apply_(lambda x: split_to_cluster[x])
54 |
55 | #sup_size = x.shape[0]//2
56 | x_sup = x[:sup_size]
57 | y_sup = y_true[:sup_size]
58 | x_que = x[sup_size:]
59 | y_que = y_true[sup_size:]
60 | mask = mask[:sup_size]
61 |
62 | x_sup = x_sup.to(device)
63 | y_sup = y_sup.to(device)
64 | x_que = x_que.to(device)
65 | y_que = y_que.to(device)
66 |
67 | with torch.no_grad():
68 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1)
69 | #logits[:, :, split_to_cluster[z]] = torch.zeros_like(logits[:, :, split_to_cluster[z]])
70 | #logits = logits.permute((0,2,1))
71 | logits = mask_feat(logits, mask, len(models_list), exclude=True)
72 |
73 | t_out = selector.get_feat(logits)
74 |
75 | task_model = features.clone()
76 | task_model.module.eval()
77 | feat = task_model(x_que)
78 | feat = feat.view(feat.shape[0], -1)
79 | out = head(feat)
80 | with torch.no_grad():
81 | loss_pre = student_ce(out, y_que).item()/x_que.shape[0]
82 |
83 | feat = task_model(x_sup)
84 | feat = feat.view_as(t_out)
85 |
86 | inner_loss = l2_loss(feat, t_out)
87 | task_model.adapt(inner_loss)
88 |
89 | x_que = task_model(x_que)
90 | x_que = x_que.view(x_que.shape[0], -1)
91 | s_que_out = head(x_que)
92 | s_que_loss = student_ce(s_que_out, y_que)
93 | #t_sup_loss = teacher_ce(t_out, y_sup)
94 |
95 | s_que_loss.backward()
96 |
97 | optimizer_s.step()
98 | optimizer_t.step()
99 | optimizer_s.zero_grad()
100 | optimizer_t.zero_grad()
101 |
102 | #print("Step:{}".format(time.time() - t_1))
103 | #t_1 = time.time()
104 | with open('log/fmow/'+curr+'.txt', 'a+') as f:
105 | f.write('Iter: {}/{}, Loss Before:{:.4f}, Loss After:{:.4f}\r\n'.format(i,iter_per_epoch,
106 | loss_pre,
107 | s_que_loss.item()/x_que.shape[0]))
108 | losses.append(s_que_loss.item()/x_que.shape[0])
109 |
110 | if i == iter_per_epoch//2:
111 | losses = np.mean(losses)
112 | worst_acc, acc = eval(selector, models_list, student, sup_size, device=device,
113 | ilr=ilr, test=False, progress=False, uniform_over_groups=False,
114 | root_dir=root_dir)
115 |
116 | with open('log/fmow/'+curr+'.txt', 'a+') as f:
117 | f.write(f'Accuracy: {acc}; Worst acc: {worst_acc}\r\n')
118 | losses = []
119 |
120 | if worst_acc > acc_best and save:
121 | with open('log/fmow/'+curr+'.txt', 'a+') as f:
122 | f.write("Saving model ...\r\n")
123 | save_model(selector, selector_name+"_selector", 0, test_way=test_way)
124 | save_model(student, student_name+"_student", 0, test_way=test_way)
125 | acc_best = worst_acc
126 |
127 | i += 1
128 | return acc_best
129 |
130 | def train_kd(selector, selector_name, models_list, student, student_name, split_to_cluster, device,
131 | batch_size=256, sup_size=24, tlr=1e-4, slr=1e-4, ilr=1e-5, num_epochs=30,
132 | decayRate=0.96, save=False, test_way='ood', root_dir='data'):
133 |
134 | train_loader, _, _, grouper = get_data_loader(root_dir=root_dir, batch_size=batch_size, domain=None,
135 | test_way=test_way, n_groups_per_batch=1,
136 | groupby_fields=['year', 'region'])
137 |
138 | mask_grouper = get_mask_grouper(root_dir=root_dir)
139 |
140 | curr = str(datetime.now())
141 | if not os.path.exists("log/fmow"):
142 | os.makedirs("log/fmow")
143 | #print(curr)
144 | print("Training log saved to log/fmow/"+curr+".txt")
145 |
146 | with open('log/fmow/'+curr+'.txt', 'a+') as f:
147 | f.write(selector_name+' '+student_name+'\r\n')
148 | f.write(f'tlr={tlr} slr={slr} ilr={ilr}\r\n')
149 |
150 | accu_best = 0
151 | for epoch in range(num_epochs):
152 | with open('log/fmow/'+curr+'.txt', 'a+') as f:
153 | f.write(f'Epoch: {epoch}\r\n')
154 | accu_epoch = train_epoch(selector, selector_name, models_list, student, student_name,
155 | train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster,
156 | device, acc_best=accu_best, tlr=tlr, slr=slr, ilr=ilr,
157 | batch_size=batch_size, sup_size=sup_size, test_way=test_way, save=save,
158 | root_dir=root_dir)
159 | accu_best = max(accu_best, accu_epoch)
160 | worst_acc, accu = eval(selector, models_list, student, sup_size, device=device,
161 | ilr=ilr, test=False, progress=False, uniform_over_groups=False,
162 | root_dir=root_dir)
163 |
164 | with open('log/fmow/'+curr+'.txt', 'a+') as f:
165 | f.write(f'Accuracy: {accu}; Worst acc: {worst_acc}\r\n')
166 |
167 | if worst_acc > accu_best and save:
168 | with open('log/fmow/'+curr+'.txt', 'a+') as f:
169 | f.write("Saving model ...\r\n")
170 | save_model(selector, selector_name+"_selector", 0, test_way=test_way)
171 | save_model(student, student_name+"_student", 0, test_way=test_way)
172 | accu_best = worst_acc
173 | tlr = tlr*decayRate
174 | slr = slr*decayRate
175 |
176 | def eval(selector, models_list, student, batch_size, device, ilr=1e-5,
177 | test=False, progress=True, uniform_over_groups=False, root_dir='data'):
178 |
179 | if test:
180 | loader, grouper, dataset = get_test_loader(batch_size=batch_size, split='test', root_dir=root_dir,
181 | groupby_fields=['year', 'region'], return_dataset=True)
182 | else:
183 | loader, grouper, dataset = get_test_loader(batch_size=batch_size, split='val', root_dir=root_dir,
184 | groupby_fields=['year', 'region'], return_dataset=True)
185 |
190 |
191 | features = student.features
192 | head = student.classifier
193 | head.to(device)
194 |
195 | student_maml = l2l.algorithms.MAML(features, lr=ilr)
196 | student_maml.to(device)
197 |
198 | old_domain = {}
199 | if progress:
200 | loader = tqdm(iter(loader), total=len(loader))
201 |
202 | for domain, domain_loader in loader:
203 | adapted = False
204 | for x_sup, y_sup, metadata in domain_loader:
205 | student_maml.module.eval()
206 | selector.eval()
207 | head.eval()
208 |
209 | z = grouper.metadata_to_group(metadata)
210 | z = set(z.tolist())
211 | assert list(z)[0] == domain
212 |
213 | x_sup = x_sup.to(device)
214 | y_sup = y_sup.to(device)
215 |
216 | if not adapted:
217 | task_model = student_maml.clone()
218 | task_model.eval()
219 | with torch.no_grad():
220 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1)
221 | logits = logits.permute((0,2,1))
222 | t_out = selector.get_feat(logits)
223 |
224 | feat = task_model(x_sup)
225 | feat = feat.view_as(t_out)
226 |
227 |                 kl_loss = l2_loss(feat, t_out)  # an L2 alignment loss, despite the name
228 | torch.cuda.empty_cache()
229 | task_model.adapt(kl_loss)
230 | adapted = True
231 |
232 | with torch.no_grad():
233 | task_model.module.eval()
234 | x_sup = task_model(x_sup)
235 | x_sup = x_sup.view(x_sup.shape[0], -1)
236 | s_que_out = head(x_sup)
237 | pred = s_que_out.max(1, keepdim=True)[1]
238 | try:
239 | pred_all = torch.cat((pred_all, pred.view_as(y_sup)))
240 | y_all = torch.cat((y_all, y_sup))
241 | metadata_all = torch.cat((metadata_all, metadata))
242 |             except NameError:  # first batch of the eval: initialize the accumulators
243 | pred_all = pred.view_as(y_sup)
244 | y_all = y_sup
245 | metadata_all = metadata
246 |
247 | acc, worst_acc = get_fmow_metrics(pred_all, y_all, metadata_all, dataset)
248 | return worst_acc, acc
249 |
--------------------------------------------------------------------------------
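
eval above defers its metric to get_fmow_metrics, which is defined further down in fmow_utils and not shown in this excerpt. A plausible sketch of a worst-group accuracy in the WILDS style, grouping by region as the FMoW evaluation does (the function name and field choice here are assumptions, not the repo's implementation):

    import torch
    from wilds.common.grouper import CombinatorialGrouper

    def worst_group_accuracy(pred_all, y_all, metadata_all, dataset, fields=('region',)):
        grouper = CombinatorialGrouper(dataset, list(fields))
        groups = grouper.metadata_to_group(metadata_all)
        correct = pred_all.eq(y_all).float()
        per_group = [correct[groups == g].mean().item() for g in groups.unique()]
        return correct.mean().item(), min(per_group)  # (overall acc, worst-group acc)
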
/src/rxrx1/rxrx1_train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | import torchvision
8 | from tqdm import tqdm
9 | import copy
10 | import learn2learn as l2l
11 | import time
12 | from datetime import datetime
13 | from collections import defaultdict
14 | from transformers import get_cosine_schedule_with_warmup
15 | import os
16 | import wilds
17 | from src.rxrx1.rxrx1_utils import *
18 |
19 | def train_epoch(selector, selector_name, models_list, student, student_name,
20 | train_loader, grouper, num_epochs, curr, mask_grouper, split_to_cluster,
21 | device, acc_best=0, tlr=1e-4, slr=1e-4, ilr=1e-3,
22 | batch_size=256, sup_size=24, test_way='id', save=False,
23 | root_dir='data'):
24 | for model in models_list:
25 | model.eval()
26 |
27 | student_ce = nn.CrossEntropyLoss()
28 | #teacher_ce = nn.CrossEntropyLoss()
29 |
30 | features = student.features
31 | head = student.classifier
32 | features = l2l.algorithms.MAML(features, lr=ilr)
33 | features.to(device)
34 | head.to(device)
35 |
36 | all_params = list(features.parameters()) + list(head.parameters())
37 | optimizer_s = optim.Adam(all_params, lr=slr)
38 | optimizer_t = optim.Adam(selector.parameters(), lr=tlr)
39 |
40 | i = 0
41 |
42 | losses = []
43 |
44 | iter_per_epoch = len(train_loader)
45 | scheduler = get_cosine_schedule_with_warmup(
46 | optimizer_s,
47 | num_warmup_steps=iter_per_epoch,
48 | num_training_steps=iter_per_epoch*num_epochs)
49 |
50 | for epoch in range(num_epochs):
51 | with open('log/rxrx1/'+curr+'.txt', 'a+') as f:
52 | f.write(f'Epoch: {epoch}\r\n')
53 | i = 0
54 | for x, y_true, metadata in train_loader:
55 | selector.eval()
56 | head.eval()
57 | features.eval()
58 |
59 | z = grouper.metadata_to_group(metadata)
60 | z = set(z.tolist())
61 | assert len(z) == 1
62 | mask = mask_grouper.metadata_to_group(metadata)
63 | mask.apply_(lambda x: split_to_cluster[x])
64 |
65 | #sup_size = x.shape[0]//2
66 | x_sup = x[:sup_size]
67 | y_sup = y_true[:sup_size]
68 | x_que = x[sup_size:]
69 | y_que = y_true[sup_size:]
70 | mask = mask[:sup_size]
71 |
72 | x_sup = x_sup.to(device)
73 | y_sup = y_sup.to(device)
74 | x_que = x_que.to(device)
75 | y_que = y_que.to(device)
76 |
77 | with torch.no_grad():
78 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1)
79 | #logits[:, :, split_to_cluster[z]] = torch.zeros_like(logits[:, :, split_to_cluster[z]])
80 | #logits = logits.permute((0,2,1))
81 | logits = mask_feat(logits, mask, len(models_list), exclude=True)
82 |
83 | t_out = selector.get_feat(logits)
84 |
85 | task_model = features.clone()
86 | task_model.module.eval()
87 | feat = task_model(x_que)
88 | feat = feat.view(feat.shape[0], -1)
89 | out = head(feat)
90 | with torch.no_grad():
91 | loss_pre = student_ce(out, y_que).item()/x_que.shape[0]
92 |
93 | feat = task_model(x_sup)
94 | feat = feat.view_as(t_out)
95 |
96 | inner_loss = l2_loss(feat, t_out)
97 | task_model.adapt(inner_loss)
98 |
99 | x_que = task_model(x_que)
100 | x_que = x_que.view(x_que.shape[0], -1)
101 | s_que_out = head(x_que)
102 | s_que_loss = student_ce(s_que_out, y_que)
103 | #t_sup_loss = teacher_ce(t_out, y_sup)
104 |
105 | s_que_loss.backward()
106 |
107 | optimizer_s.step()
108 | optimizer_t.step()
109 | optimizer_s.zero_grad()
110 | optimizer_t.zero_grad()
111 |
112 | #print("Step:{}".format(time.time() - t_1))
113 | #t_1 = time.time()
114 | with open('log/rxrx1/'+curr+'.txt', 'a+') as f:
115 | f.write('Iter: {}/{}, Loss Before:{:.4f}, Loss After:{:.4f}\r\n'.format(i,iter_per_epoch,
116 | loss_pre,
117 | s_que_loss.item()/x_que.shape[0]))
118 | losses.append(s_que_loss.item()/x_que.shape[0])
119 |
120 | if i == iter_per_epoch//2 or i == iter_per_epoch-1:
121 | losses = np.mean(losses)
122 | acc = eval(selector, models_list, student, sup_size, device=device,
123 | ilr=ilr, test=False, progress=False, uniform_over_groups=False,
124 | root_dir=root_dir)
125 |
126 | with open('log/rxrx1/'+curr+'.txt', 'a+') as f:
127 | f.write(f'Accuracy: {acc}\r\n')
128 | losses = []
129 |
130 | if acc > acc_best and save:
131 | with open('log/rxrx1/'+curr+'.txt', 'a+') as f:
132 | f.write("Saving model ...\r\n")
133 | save_model(selector, selector_name+"_selector", 0, test_way=test_way)
134 | save_model(student, student_name+"_student", 0, test_way=test_way)
135 | acc_best = acc
136 |
137 | i += 1
138 |         scheduler.step()  # note: stepped once per epoch, though the warmup/decay horizons above count iterations (see the sketch after this file)
139 | return acc_best
140 |
141 | def train_kd(selector, selector_name, models_list, student, student_name, split_to_cluster, device,
142 | batch_size=256, sup_size=24, tlr=1e-4, slr=1e-4, ilr=1e-5, num_epochs=30,
143 | decayRate=0.96, save=False, test_way='ood', root_dir='data'):
144 |
145 | train_loader, _, _, grouper = get_data_loader(root_dir=root_dir, batch_size=batch_size, domain=None,
146 | test_way=test_way, n_groups_per_batch=1)
147 |
148 | mask_grouper = get_mask_grouper(root_dir=root_dir)
149 |
150 | curr = str(datetime.now())
151 | if not os.path.exists("log/rxrx1"):
152 | os.makedirs("log/rxrx1")
153 | #print(curr)
154 | print("Training log saved to log/rxrx1/"+curr+".txt")
155 |
156 | with open('log/rxrx1/'+curr+'.txt', 'a+') as f:
157 | f.write(selector_name+' '+student_name+'\r\n')
158 | f.write(f'tlr={tlr} slr={slr} ilr={ilr}\r\n')
159 |
160 | accu_best = 0
161 |     # Unlike the camelyon/fmow variants, the epoch loop lives inside train_epoch here,
162 |     # so that the cosine schedule can be sized across all epochs at once.
164 | accu_epoch = train_epoch(selector, selector_name, models_list, student, student_name,
165 | train_loader, grouper, num_epochs, curr, mask_grouper, split_to_cluster,
166 | device, acc_best=accu_best, tlr=tlr, slr=slr, ilr=ilr,
167 | batch_size=batch_size, sup_size=sup_size, test_way=test_way, save=save,
168 | root_dir=root_dir)
185 |
186 | def eval(selector, models_list, student, batch_size, device, ilr=1e-5,
187 | test=False, progress=True, uniform_over_groups=False, root_dir='data'):
188 |
189 | if test:
190 | loader, grouper = get_test_loader(batch_size=batch_size, split='test', root_dir=root_dir)
191 | else:
192 | loader, grouper = get_test_loader(batch_size=batch_size, split='val', root_dir=root_dir)
197 |
198 | features = student.features
199 | head = student.classifier
200 | head.to(device)
201 |
202 | student_maml = l2l.algorithms.MAML(features, lr=ilr)
203 | student_maml.to(device)
204 |
205 | nor_correct = 0
206 | nor_total = 0
208 | if progress:
209 | loader = tqdm(iter(loader), total=len(loader))
210 |
211 | for domain, domain_loader in loader:
212 | adapted = False
213 | for x_sup, y_sup, metadata in domain_loader:
214 | student_maml.module.eval()
215 | selector.eval()
216 | head.eval()
217 |
218 | z = grouper.metadata_to_group(metadata)
219 | z = set(z.tolist())
220 |             assert z == {domain}  # the loader yields batches from a single evaluation domain
221 |
222 | x_sup = x_sup.to(device)
223 | y_sup = y_sup.to(device)
224 |
225 | if not adapted:
226 | task_model = student_maml.clone()
227 | task_model.eval()
228 |
229 | with torch.no_grad():
230 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1)
231 | logits = logits.permute((0,2,1))
232 | t_out = selector.get_feat(logits)
233 |
234 | feat = task_model(x_sup)
235 | feat = feat.view_as(t_out)
236 |
237 |                 inner_loss = l2_loss(feat, t_out)  # unsupervised L2 feature-matching loss (no labels used)
238 |                 torch.cuda.empty_cache()
239 |                 task_model.adapt(inner_loss)
240 | adapted = True
241 |
242 | with torch.no_grad():
243 | task_model.module.eval()
244 |                 sup_feat = task_model(x_sup)  # re-embed the batch with the adapted feature extractor
245 |                 sup_feat = sup_feat.view(sup_feat.shape[0], -1)
246 |                 s_que_out = head(sup_feat)
247 |                 pred = s_que_out.max(1, keepdim=True)[1]
248 |                 c = pred.eq(y_sup.view_as(pred)).sum().item()
249 |                 nor_correct += c
250 |                 nor_total += sup_feat.shape[0]
251 |
252 | return nor_correct/nor_total
253 |
--------------------------------------------------------------------------------
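
Note: at test time, eval() above wraps the student's feature extractor in learn2learn's MAML,
takes one unsupervised gradient step that pulls the student's features toward the aggregator's
output on an unlabeled batch, and only then classifies with the frozen head. The sketch below
condenses that loop for readability; adapt_and_predict is a hypothetical helper written for
this note (assuming a trained selector, expert models_list, and a student split into
features/head), not a function from this repository.

    import torch
    import learn2learn as l2l

    def adapt_and_predict(features, head, selector, models_list, x_unlabeled, ilr=1e-5):
        # Differentiable copy of the feature extractor; the original weights stay untouched.
        task_model = l2l.algorithms.MAML(features, lr=ilr).clone()
        with torch.no_grad():  # experts and aggregator are frozen at test time
            logits = torch.stack([m(x_unlabeled) for m in models_list], dim=-1)
            target = selector.get_feat(logits.permute(0, 2, 1))
        feat = task_model(x_unlabeled).view_as(target)
        task_model.adapt(((target - feat) ** 2).mean())  # one unsupervised L2 step
        with torch.no_grad():
            return head(task_model(x_unlabeled).flatten(1))
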
/src/fmow/fmow_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | import torch.optim as optim
6 | import torchvision
7 | import torchvision.transforms as transforms
8 | import math
9 | import random
10 | import pickle5 as pickle
11 | from src.transformer import Transformer
12 | import copy
13 | import time
14 | import os
15 | import wilds
16 | from wilds import get_dataset
17 | from wilds.common.data_loaders import get_train_loader, get_eval_loader
18 | from wilds.common.grouper import CombinatorialGrouper
19 | from wilds.common.utils import split_into_groups
20 |
21 | # Utils
22 |
23 | def split_domains(num_experts, root_dir='data'):
24 | dataset = get_dataset(dataset='fmow', download=True, root_dir=root_dir)
25 | train_data = dataset.get_subset('train')
26 | year = list(set(train_data.metadata_array[:,1].detach().numpy().tolist()))
27 | random.shuffle(year)
28 | num_domains_per_super = len(year) / float(num_experts)
29 | all_split = [[] for _ in range(num_experts)]
30 | for i in range(len(year)):
31 | all_split[int(i//num_domains_per_super)].append(year[i])
32 | return all_split
33 |
34 | def get_subset_with_domain(dataset, split=None, domain=None, transform=None, grouper=None):
35 |     if isinstance(dataset, wilds.datasets.wilds_dataset.WILDSSubset):  # already a subset: the split/transform arguments are ignored
36 | subset = copy.deepcopy(dataset)
37 | else:
38 | subset = dataset.get_subset(split, transform=transform)
39 | if domain is not None:
40 | if grouper is not None:
41 | z = grouper.metadata_to_group(subset.dataset.metadata_array[subset.indices])
42 | idx = np.argwhere(np.isin(z, domain)).ravel()
43 | subset.indices = subset.indices[idx]
44 | else:
45 | idx = np.argwhere(np.isin(subset.dataset.metadata_array[:,1][subset.indices], domain)).ravel()
46 | subset.indices = subset.indices[idx]
47 | return subset
48 |
49 | def initialize_image_base_transform(dataset):
50 | transform_steps = []
51 | if dataset.original_resolution is not None and min(dataset.original_resolution)!=max(dataset.original_resolution):
52 | crop_size = min(dataset.original_resolution)
53 | transform_steps.append(transforms.CenterCrop(crop_size))
54 | transform_steps += [
55 | transforms.ToTensor(),
56 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
57 | ]
58 | transform = transforms.Compose(transform_steps)
59 | return transform
60 |
61 | def get_data_loader(batch_size=16, domain=None, test_way='id',
62 | n_groups_per_batch=0, uniform_over_groups=False, root_dir='data',
63 | groupby_fields=['year'], return_dataset=False):
64 | dataset = get_dataset(dataset='fmow', download=True, root_dir=root_dir)
65 | grouper = CombinatorialGrouper(dataset, groupby_fields)
66 |
67 | transform = initialize_image_base_transform(dataset)
68 |
69 | train_data = get_subset_with_domain(dataset, 'train', domain=domain, transform=transform)
70 | if test_way == 'ood':
71 | val_data = dataset.get_subset('val', transform=transform)
72 | test_data = dataset.get_subset('test', transform=transform)
73 | else:
74 | val_data = get_subset_with_domain(dataset, 'id_val', domain=domain, transform=transform)
75 | test_data = get_subset_with_domain(dataset, 'id_test', domain=domain, transform=transform)
76 |
77 |
78 | if n_groups_per_batch == 0:
79 |         # n_groups_per_batch == 0: plain shuffled loaders with no group structure
80 | train_loader = get_train_loader('standard', train_data, batch_size=batch_size)
81 | val_loader = get_train_loader('standard', val_data, batch_size=batch_size)
82 | test_loader = get_train_loader('standard', test_data, batch_size=batch_size)
83 |
84 | else:
85 |         # all splits use get_train_loader so the grouper can form domain-pure batches
86 | train_loader = get_train_loader('group', train_data, grouper=grouper,
87 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size)
88 | val_loader = get_train_loader('group', val_data, grouper=grouper,
89 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size,
90 | uniform_over_groups=uniform_over_groups)
91 | test_loader = get_train_loader('group', test_data, grouper=grouper,
92 | n_groups_per_batch=n_groups_per_batch, batch_size=batch_size,
93 | uniform_over_groups=uniform_over_groups)
94 | if return_dataset:
95 | return train_loader, val_loader, test_loader, grouper, dataset
96 | return train_loader, val_loader, test_loader, grouper
97 |
98 | def get_test_loader(batch_size=16, split='val', root_dir='data', groupby_fields=['year'], return_dataset=False):
99 | dataset = get_dataset(dataset='fmow', download=True, root_dir=root_dir)
100 | grouper = CombinatorialGrouper(dataset, groupby_fields)
101 |
102 | transform = initialize_image_base_transform(dataset)
103 |
104 | eval_data = get_subset_with_domain(dataset, split, domain=None, transform=transform)
105 | all_domains = list(set(grouper.metadata_to_group(eval_data.dataset.metadata_array[eval_data.indices]).tolist()))
106 | test_loader = []
107 |
108 | for domain in all_domains:
109 | domain_data = get_subset_with_domain(eval_data, split, domain=domain, transform=transform, grouper=grouper)
110 | domain_loader = get_eval_loader('standard', domain_data, batch_size=batch_size)
111 | test_loader.append((domain, domain_loader))
112 |
113 | if return_dataset:
114 | return test_loader, grouper, dataset
115 | return test_loader, grouper
116 |
117 | def get_mask_grouper(root_dir='data'):
118 | dataset = get_dataset(dataset='fmow', download=True, root_dir=root_dir)
119 | grouper = CombinatorialGrouper(dataset, ['year'])
120 | return grouper
121 |
122 | def save_model(model, name, epoch, test_way='ood'):  # epoch/test_way kept for call-site compatibility; only one "best" checkpoint per name is written
123 | if not os.path.exists("model/fmow"):
124 | os.makedirs("model/fmow")
125 | path = "model/fmow/{0}_best.pth".format(name)
126 | torch.save(model.state_dict(), path)
127 |
128 | def mask_feat(feat, mask_index, num_experts, exclude=True):  # feat: (B, num_classes, num_experts); mask_index: (B,)
129 |     assert feat.shape[0] == mask_index.shape[0]
130 |     if exclude:
131 |         new_idx = [list(range(0, int(m.item()))) + list(range(int(m.item())+1, num_experts)) for m in mask_index]  # per-sample indices of every other expert
132 |         return feat[torch.arange(feat.shape[0]).unsqueeze(-1), :, new_idx]  # advanced indexing -> (B, num_experts-1, num_classes)
133 |     else:
134 |         feat[list(range(feat.shape[0])), :, mask_index] = torch.zeros_like(feat[list(range(feat.shape[0])), :, mask_index])  # zero out the masked expert's logits
135 |         feat = feat.permute((0,2,1))  # -> (B, num_experts, num_classes), same layout as the exclude branch
136 |         return feat
137 |
138 | def l2_loss(input, target):
139 | loss = torch.square(target - input)
140 | loss = torch.mean(loss)
141 | return loss
142 |
143 | # Models
144 | class ResNetFeature(nn.Module):  # despite the name, expects a DenseNet-style model exposing .classifier (torchvision ResNets use .fc)
145 | def __init__(self, original_model, layer=-1):
146 | super(ResNetFeature, self).__init__()
147 | self.num_ftrs = original_model.classifier.in_features
148 | self.features = nn.Sequential(*list(original_model.children())[:layer])
149 | self.pool = nn.AdaptiveAvgPool2d((1,1))
150 |
151 | def forward(self, x):
152 | x = self.features(x)
153 | x = self.pool(x)
154 | x = x.view(-1, self.num_ftrs)
155 | return x
156 |
157 | class fa_selector(nn.Module):
158 | def __init__(self, dim, depth, heads, mlp_dim, dropout = 0., out_dim=62, pool='mean'):
159 | super(fa_selector, self).__init__()
160 | self.transformer = Transformer(dim, depth, heads, mlp_dim, dropout=dropout)
161 | self.pool = pool
162 | self.mlp = nn.Sequential(
163 | nn.LayerNorm(dim),
164 | nn.Linear(dim, out_dim)
165 | )
166 |
167 | def forward(self, x):
168 | x = self.transformer(x)
169 | if self.pool == 'mean':
170 | x = x.mean(dim=1)
171 | elif self.pool == 'max':
172 |             x = x.max(dim=1)[0]  # Tensor.max(dim=...) returns (values, indices); keep the values
173 | else:
174 | raise NotImplementedError
175 | x = self.mlp(x)
176 | return x
177 |
178 | def get_feat(self, x):
179 | x = self.transformer(x)
180 | if self.pool == 'mean':
181 | x = x.mean(dim=1)
182 | elif self.pool == 'max':
183 |             x = x.max(dim=1)[0]  # keep the values from (values, indices)
184 | else:
185 | raise NotImplementedError
186 | return x
187 |
188 | class DivideModel(nn.Module):
189 | def __init__(self, original_model, layer=-1):
190 | super(DivideModel, self).__init__()
191 | self.num_ftrs = original_model.classifier.in_features
192 | self.num_class = original_model.classifier.out_features
193 | self.features = nn.Sequential(*list(original_model.children())[:layer])
194 | self.features.add_module("avg_pool", nn.AdaptiveAvgPool2d((1,1)))
195 | self.classifier = nn.Sequential(*list(original_model.children())[layer:])
196 |
197 | def forward(self, x):
198 | x = self.features(x)
199 | x = x.view(-1, self.num_ftrs)
200 | x = self.classifier(x)
201 | x = x.view(-1, self.num_class)
202 | return x
203 |
204 | def StudentModel(device, num_classes=62, load_path=None):
205 | model = torchvision.models.densenet121(pretrained=True)
206 | num_ftrs = model.classifier.in_features
207 | model.classifier = nn.Linear(num_ftrs, num_classes)
208 | if load_path:
209 | model.load_state_dict(torch.load(load_path))
210 | model = DivideModel(model)
211 | model = model.to(device)
212 | return model
213 |
214 | def get_feature_list(models_list, device):
215 | feature_list = []
216 | for model in models_list:
217 | feature_list.append(ResNetFeature(model).to(device))
218 | return feature_list
219 |
220 | def get_models_list(device, num_domains=3, num_classes=62, pretrained=False, bb='d121'):  # bb: only DenseNet-121 ('d121') is implemented
221 |     models_list = []
222 |     for _ in range(num_domains+1):
223 |         model = torchvision.models.densenet121(pretrained=pretrained)  # honor the flag; expert checkpoints are usually loaded afterwards
224 | num_ftrs = model.classifier.in_features
225 | model.classifier = nn.Linear(num_ftrs, num_classes)
226 | model = model.to(device)
227 | models_list.append(model)
228 | return models_list
229 |
230 | def get_fmow_metrics(y_pred, y_true, metadata, dataset):  # returns (overall accuracy, worst-region accuracy)
231 | worst_acc = float('inf')
232 | grouper = CombinatorialGrouper(dataset, ['region'])
233 | regions = grouper.metadata_to_group(metadata)
234 | unique_groups, group_indices, _ = split_into_groups(regions)
235 | for i_group in group_indices:
236 | correct = y_pred[i_group].eq(y_true[i_group].view_as(y_pred[i_group])).sum().item()
237 | total = y_pred[i_group].shape[0]
238 | worst_acc = min(correct/total, worst_acc)
239 | correct = y_pred.eq(y_true.view_as(y_pred)).sum().item()
240 | total = y_pred.shape[0]
241 | return correct/total, worst_acc
242 |
--------------------------------------------------------------------------------
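
The exclude branch of mask_feat above leans on NumPy-style advanced indexing: the two index
arrays broadcast to (batch, num_experts-1), and because they are separated by a slice, the
broadcast dimensions move to the front of the result, turning (B, C, E) into (B, E-1, C). A
toy demonstration with made-up shapes (illustrative only, not repo code):

    import torch

    B, C, E = 2, 4, 3                      # batch, classes, experts
    feat = torch.arange(B * C * E, dtype=torch.float32).reshape(B, C, E)
    mask_index = torch.tensor([0, 2])      # expert to drop for each sample
    new_idx = [[e for e in range(E) if e != int(m)] for m in mask_index]
    out = feat[torch.arange(B).unsqueeze(-1), :, new_idx]
    print(out.shape)                       # torch.Size([2, 2, 4]) == (B, E-1, C)
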
/src/iwildcam/iwildcam_train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import time
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import torch.optim as optim
7 | import torchvision
8 | from tqdm import tqdm
9 | import copy
10 | import learn2learn as l2l
11 | from sklearn.metrics import f1_score
13 | from datetime import datetime
14 | from collections import defaultdict
15 | import os
16 | import wilds
17 | from src.iwildcam.iwildcam_utils import *
18 |
19 | def train_epoch(selector, selector_name, models_list, student, student_name,
20 | train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster,
21 | device, acc_best=0, tlr=1e-4, slr=1e-4, ilr=1e-3,
22 | batch_size=256, sup_size=24, test_way='id', save=False,
23 | root_dir='data'):
24 | for model in models_list:
25 | model.eval()
26 |
27 | student_ce = nn.CrossEntropyLoss()
28 | #teacher_ce = nn.CrossEntropyLoss()
29 |
30 | features = student.features
31 | head = student.classifier
32 | features = l2l.algorithms.MAML(features, lr=ilr)
33 | features.to(device)
34 | head.to(device)
35 |
36 | all_params = list(features.parameters()) + list(head.parameters())
37 | optimizer_s = optim.Adam(all_params, lr=slr)
38 | optimizer_t = optim.Adam(selector.parameters(), lr=tlr)
39 |
40 | i = 0
41 |
42 | losses = []
43 |
44 | iter_per_epoch = len(train_loader)
45 |
46 | for x, y_true, metadata in train_loader:
47 | selector.eval()
48 | head.eval()
49 | features.eval()
50 |
51 | z = grouper.metadata_to_group(metadata)
52 | z = set(z.tolist())
53 | assert len(z) == 1
54 | mask = mask_grouper.metadata_to_group(metadata)
55 |         mask.apply_(lambda g: split_to_cluster[g])  # map each sample's raw domain id to its expert cluster
56 |
57 | #sup_size = x.shape[0]//2
58 | x_sup = x[:sup_size]
59 | y_sup = y_true[:sup_size]
60 | x_que = x[sup_size:]
61 | y_que = y_true[sup_size:]
62 | mask = mask[:sup_size]
63 |
64 | x_sup = x_sup.to(device)
65 | y_sup = y_sup.to(device)
66 | x_que = x_que.to(device)
67 | y_que = y_que.to(device)
68 |
69 | with torch.no_grad():
70 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1)
71 | #logits[:, :, split_to_cluster[z]] = torch.zeros_like(logits[:, :, split_to_cluster[z]])
72 | #logits = logits.permute((0,2,1))
73 | logits = mask_feat(logits, mask, len(models_list), exclude=True)
74 |
75 | t_out = selector.get_feat(logits)
76 |
77 | task_model = features.clone()
78 | task_model.module.eval()
79 | feat = task_model(x_que)
80 | feat = feat.view(feat.shape[0], -1)
81 | out = head(feat)
82 | with torch.no_grad():
83 |             loss_pre = student_ce(out, y_que).item()  # CrossEntropyLoss already averages over the batch
84 |
85 | feat = task_model(x_sup)
86 | feat = feat.view_as(t_out)
87 |
88 | inner_loss = l2_loss(feat, t_out)
89 | task_model.adapt(inner_loss)
90 |
91 |         que_feat = task_model(x_que)  # query features after the single unsupervised adaptation step
92 |         que_feat = que_feat.view(que_feat.shape[0], -1)
93 |         s_que_out = head(que_feat)
94 | s_que_loss = student_ce(s_que_out, y_que)
95 | #t_sup_loss = teacher_ce(t_out, y_sup)
96 |
97 | s_que_loss.backward()
98 |
99 | optimizer_s.step()
100 | optimizer_t.step()
101 | optimizer_s.zero_grad()
102 | optimizer_t.zero_grad()
103 |
104 | #print("Step:{}".format(time.time() - t_1))
105 | #t_1 = time.time()
106 | with open('log/iwildcam/'+curr+'.txt', 'a+') as f:
107 |             f.write('Iter: {}/{}, Loss Before:{:.4f}, Loss After:{:.4f}\r\n'.format(i, iter_per_epoch,
108 |                                                                                 loss_pre,
109 |                                                                                 s_que_loss.item()))
110 |         losses.append(s_que_loss.item())
111 |
112 | if i == iter_per_epoch//2:
113 |             mean_loss = np.mean(losses)
114 |             _, acc, f1 = eval(selector, models_list, student, sup_size, device=device,
115 |                                 ilr=ilr, test=False, progress=False, uniform_over_groups=False,
116 |                                 root_dir=root_dir)
117 |
118 |             with open('log/iwildcam/'+curr+'.txt', 'a+') as f:
119 |                 f.write(f'Mean loss: {mean_loss:.4f} || Accuracy: {acc} || F1:{f1}\r\n')
120 |             losses = []
121 |
122 | if f1 > acc_best and save:
123 | with open('log/iwildcam/'+curr+'.txt', 'a+') as f:
124 | f.write("Saving model ...\r\n")
125 | save_model(selector, selector_name+"_selector", 0, test_way=test_way)
126 | save_model(student, student_name+"_student", 0, test_way=test_way)
127 | acc_best = f1
128 |
129 | i += 1
130 | return acc_best
131 |
132 | def train_kd(selector, selector_name, models_list, student, student_name, split_to_cluster, device,
133 | batch_size=256, sup_size=24, tlr=1e-4, slr=1e-4, ilr=1e-5, num_epochs=30,
134 | decayRate=0.96, save=False, test_way='ood', root_dir='data'):
135 |
136 | train_loader, _, _, grouper = get_data_loader(root_dir=root_dir, batch_size=batch_size, domain=None,
137 | test_way=test_way, n_groups_per_batch=1)
138 |
139 | mask_grouper = get_mask_grouper(root_dir=root_dir)
140 |
141 | curr = str(datetime.now())
142 | if not os.path.exists("log/iwildcam"):
143 | os.makedirs("log/iwildcam")
144 | #print(curr)
145 | print("Training log saved to log/iwildcam/"+curr+".txt")
146 |
147 | with open('log/iwildcam/'+curr+'.txt', 'a+') as f:
148 | f.write(selector_name+' '+student_name+'\r\n')
149 | f.write(f'tlr={tlr} slr={slr} ilr={ilr}\r\n')
150 |
151 | accu_best = 0
152 | for epoch in range(num_epochs):
153 | with open('log/iwildcam/'+curr+'.txt', 'a+') as f:
154 | f.write(f'Epoch: {epoch}\r\n')
155 | accu_epoch = train_epoch(selector, selector_name, models_list, student, student_name,
156 | train_loader, grouper, epoch, curr, mask_grouper, split_to_cluster,
157 | device, acc_best=accu_best, tlr=tlr, slr=slr, ilr=ilr,
158 | batch_size=batch_size, sup_size=sup_size, test_way=test_way, save=save,
159 | root_dir=root_dir)
160 | accu_best = max(accu_best, accu_epoch)
161 | _, accu, f1 = eval(selector, models_list, student, sup_size, device=device,
162 | ilr=ilr, test=False, progress=False, uniform_over_groups=False,
163 | root_dir=root_dir)
164 |
165 | with open('log/iwildcam/'+curr+'.txt', 'a+') as f:
166 | f.write(f'Accuracy: {accu} || F1:{f1} \r\n')
167 |
168 | if f1 > accu_best and save:
169 | with open('log/iwildcam/'+curr+'.txt', 'a+') as f:
170 | f.write("Saving model ...\r\n")
171 | save_model(selector, selector_name+"_selector", 0, test_way=test_way)
172 | save_model(student, student_name+"_student", 0, test_way=test_way)
173 | accu_best = f1
174 | tlr = tlr*decayRate
175 | slr = slr*decayRate
176 |
177 | def eval(selector, models_list, student, batch_size, device, ilr=1e-5,
178 | test=False, progress=True, uniform_over_groups=False, root_dir='data'):
179 |
180 | if test:
181 | loader, grouper = get_test_loader(batch_size=batch_size, split='test', root_dir=root_dir)
182 | else:
183 | loader, grouper = get_test_loader(batch_size=batch_size, split='val', root_dir=root_dir)
189 |
190 | features = student.features
191 | head = student.classifier
192 | head.to(device)
193 |
194 | student_maml = l2l.algorithms.MAML(features, lr=ilr)
195 | student_maml.to(device)
196 |
197 | correct = defaultdict(int)
198 | total = defaultdict(int)
199 | nor_correct = 0
200 | nor_total = 0
202 | if progress:
203 | loader = tqdm(iter(loader), total=len(loader))
204 |
205 | for domain, domain_loader in loader:
206 | adapted = False
207 | for x_sup, y_sup, metadata in domain_loader:
208 | student_maml.module.eval()
209 | selector.eval()
210 | head.eval()
211 |
212 | z = grouper.metadata_to_group(metadata)
213 | z = set(z.tolist())
214 |             assert z == {domain}  # each batch comes from the single evaluation domain
215 |
216 | x_sup = x_sup.to(device)
217 | y_sup = y_sup.to(device)
218 |
219 | if not adapted:
220 | task_model = student_maml.clone()
221 | task_model.eval()
222 | with torch.no_grad():
223 | logits = torch.stack([model(x_sup).detach() for model in models_list], dim=-1)
224 | logits = logits.permute((0,2,1))
225 | t_out = selector.get_feat(logits)
226 |
227 | feat = task_model(x_sup)
228 | feat = feat.view_as(t_out)
229 |
230 |                 inner_loss = l2_loss(feat, t_out)  # unsupervised L2 feature-matching loss (no labels used)
231 |                 torch.cuda.empty_cache()
232 |                 task_model.adapt(inner_loss)
233 | adapted = True
234 |
235 | with torch.no_grad():
236 | task_model.module.eval()
237 |                 sup_feat = task_model(x_sup)  # re-embed the batch with the adapted feature extractor
238 |                 sup_feat = sup_feat.view(sup_feat.shape[0], -1)
239 |                 s_que_out = head(sup_feat)
240 |                 pred = s_que_out.max(1, keepdim=True)[1]
241 |                 c = pred.eq(y_sup.view_as(pred)).sum().item()
242 |                 correct[domain] += c
243 |                 total[domain] += sup_feat.shape[0]
244 |                 nor_correct += c
245 |                 nor_total += sup_feat.shape[0]
246 | try:
247 | pred_all = torch.cat((pred_all, pred.view_as(y_sup)))
248 | y_all = torch.cat((y_all, y_sup))
249 | except NameError:
250 | pred_all = pred.view_as(y_sup)
251 | y_all = y_sup
252 |
253 | y_all = y_all.detach().cpu()
254 | pred_all = pred_all.detach().cpu()
255 |     f1 = f1_score(y_all, pred_all, average='macro', labels=torch.unique(y_all))  # macro F1 over the labels present in y_true
256 |     mean_acc = []
257 |     for key, value in correct.items():
258 |         mean_acc.append(value/total[key])
259 |     return np.mean(mean_acc).item(), nor_correct/nor_total, f1  # (per-domain mean acc, overall acc, macro F1)
260 |
--------------------------------------------------------------------------------
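
For reference, eval() above reports three numbers: the mean of per-domain accuracies, the
overall accuracy, and a macro F1 restricted to labels that actually occur in the ground truth
(so predicted classes that never appear in the split do not contribute zero-F1 terms). A
self-contained sketch with made-up predictions, mirroring that aggregation:

    import numpy as np
    import torch
    from collections import defaultdict
    from sklearn.metrics import f1_score

    y_true  = torch.tensor([0, 0, 1, 5, 5, 5])
    y_pred  = torch.tensor([0, 1, 1, 5, 5, 0])
    domains = torch.tensor([0, 0, 0, 1, 1, 1])

    correct, total = defaultdict(int), defaultdict(int)
    for d in domains.unique().tolist():
        m = domains == d
        correct[d] += (y_pred[m] == y_true[m]).sum().item()
        total[d] += int(m.sum())

    mean_acc = np.mean([correct[d] / total[d] for d in correct])  # (2/3 + 2/3) / 2
    overall = (y_pred == y_true).sum().item() / len(y_true)       # 4 / 6
    f1 = f1_score(y_true, y_pred, average='macro', labels=torch.unique(y_true))
    print(mean_acc, overall, f1)
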