├── .gitignore
├── LICENSE
├── README.md
├── exps
│   ├── cifar100c.py
│   ├── cifar10_1.py
│   ├── cifar10c.py
│   ├── imagenet-c.py
│   ├── officehome.py
│   ├── pacs.py
│   └── yearbook.py
├── figs
│   └── overview of datasets.jpg
├── monitor
│   ├── __init__.py
│   ├── tmux_cluster
│   │   ├── __init__.py
│   │   ├── tmux.py
│   │   └── utils.py
│   └── tools
│       ├── __init__.py
│       ├── file_io.py
│       ├── plot.py
│       ├── plot_utils.py
│       ├── show_results.py
│       └── utils.py
├── notebooks
│   └── example.ipynb
├── parameters.py
├── pretrain
│   ├── ssl_pretrain.py
│   └── third_party
│       ├── __init__.py
│       ├── augmentations.py
│       ├── datasets.py
│       └── utils.py
├── run_exp.py
├── run_exps.py
├── run_extract.py
└── ttab
    ├── __init__.py
    ├── api.py
    ├── benchmark.py
    ├── configs
    │   ├── __init__.py
    │   ├── algorithms.py
    │   ├── datasets.py
    │   └── utils.py
    ├── loads
    │   ├── .DS_Store
    │   ├── __init__.py
    │   ├── datasets
    │   │   ├── __init__.py
    │   │   ├── cifar
    │   │   │   ├── __init__.py
    │   │   │   ├── data_aug_cifar.py
    │   │   │   └── synthetic.py
    │   │   ├── dataset_sampling.py
    │   │   ├── dataset_shifts.py
    │   │   ├── datasets.py
    │   │   ├── imagenet
    │   │   │   ├── __init__.py
    │   │   │   ├── data_aug_imagenet.py
    │   │   │   ├── synthetic_224.py
    │   │   │   └── synthetic_64.py
    │   │   ├── loaders.py
    │   │   ├── mnist
    │   │   │   ├── __init__.py
    │   │   │   └── data_aug_mnist.py
    │   │   ├── utils
    │   │   │   ├── __init__.py
    │   │   │   ├── c_resource
    │   │   │   │   ├── frost1.png
    │   │   │   │   ├── frost2.png
    │   │   │   │   ├── frost3.png
    │   │   │   │   ├── frost4.jpg
    │   │   │   │   ├── frost5.jpg
    │   │   │   │   └── frost6.jpg
    │   │   │   ├── lmdb.py
    │   │   │   ├── preprocess_toolkit.py
    │   │   │   └── serialize.py
    │   │   └── yearbook
    │   │       └── data_aug_yearbook.py
    │   ├── define_dataset.py
    │   ├── define_model.py
    │   └── models
    │       ├── __init__.py
    │       ├── cct.py
    │       ├── resnet.py
    │       ├── utils
    │       │   ├── __init__.py
    │       │   ├── embedder.py
    │       │   ├── helpers.py
    │       │   ├── stochastic_depth.py
    │       │   ├── tokenizer.py
    │       │   └── transformers.py
    │       └── wideresnet.py
    ├── model_adaptation
    │   ├── __init__.py
    │   ├── base_adaptation.py
    │   ├── bn_adapt.py
    │   ├── conjugate_pl.py
    │   ├── cotta.py
    │   ├── eata.py
    │   ├── memo.py
    │   ├── no_adaptation.py
    │   ├── note.py
    │   ├── rotta.py
    │   ├── sar.py
    │   ├── shot.py
    │   ├── t3a.py
    │   ├── tent.py
    │   ├── ttt.py
    │   ├── ttt_plus_plus.py
    │   └── utils.py
    ├── model_selection
    │   ├── __init__.py
    │   ├── base_selection.py
    │   ├── group_metrics.py
    │   ├── last_iterate.py
    │   ├── metrics.py
    │   └── oracle_model_selection.py
    ├── scenarios
    │   ├── __init__.py
    │   ├── default_scenarios.py
    │   └── define_scenario.py
    └── utils
        ├── __init__.py
        ├── auxiliary.py
        ├── checkpoint.py
        ├── early_stopping.py
        ├── file_io.py
        ├── logging.py
        ├── mathdict.py
        ├── stat_tracker.py
        └── timer.py
/.gitignore:
--------------------------------------------------------------------------------
1 | ### vscode ###
2 | .vscode/
3 | !.vscode/settings.json
4 | *.DS_Store
5 |
6 | ### Python ###
7 | __pycache__/
8 |
9 | ### generated files ###
10 | ckpt/
11 | data/
12 | logs/
13 | # pretrain/
14 | runs/
15 | todo/
16 | *.out
17 | *.pth
18 | *.sh
19 | *.ipynb
20 | *.txt
21 |
22 | datasets/
23 |
--------------------------------------------------------------------------------
/exps/cifar100c.py:
--------------------------------------------------------------------------------
1 | class NewConf(object):
2 | # create the list of hyper-parameters to be replaced.
3 | to_be_replaced = dict(
4 | # general for world.
5 | seed=[2022, 2023, 2024],
6 | main_file=[
7 | "run_exp.py",
8 | ],
9 | job_name=[
10 | "cifar100c_episodic_oracle_model_selection",
11 | # "cifar100c_online_last_iterate",
12 | ],
13 | base_data_name=[
14 | "cifar100",
15 | ],
16 | data_names=[
17 | "cifar100_c_deterministic-gaussian_noise-5",
18 | "cifar100_c_deterministic-shot_noise-5",
19 | "cifar100_c_deterministic-impulse_noise-5",
20 | "cifar100_c_deterministic-defocus_blur-5",
21 | "cifar100_c_deterministic-glass_blur-5",
22 | "cifar100_c_deterministic-motion_blur-5",
23 | "cifar100_c_deterministic-zoom_blur-5",
24 | "cifar100_c_deterministic-snow-5",
25 | "cifar100_c_deterministic-frost-5",
26 | "cifar100_c_deterministic-fog-5",
27 | "cifar100_c_deterministic-brightness-5",
28 | "cifar100_c_deterministic-contrast-5",
29 | "cifar100_c_deterministic-elastic_transform-5",
30 | "cifar100_c_deterministic-pixelate-5",
31 | "cifar100_c_deterministic-jpeg_compression-5",
32 | ],
33 | model_name=[
34 | "resnet26",
35 | ],
36 | model_adaptation_method=[
37 | # "no_adaptation",
38 | "tent",
39 | # "bn_adapt",
40 | # "t3a",
41 | # "memo",
42 | # "shot",
43 | # "ttt",
44 | # "note",
45 | # "sar",
46 | # "conjugate_pl",
47 | # "cotta",
48 | # "eata",
49 | ],
50 | model_selection_method=[
51 | "oracle_model_selection",
52 | # "last_iterate",
53 | ],
54 | offline_pre_adapt=[
55 | "false",
56 | ],
57 | data_wise=["batch_wise"],
58 | batch_size=[64],
59 | episodic=[
60 | # "false",
61 | "true",
62 | ],
63 | inter_domain=["HomogeneousNoMixture"],
64 | non_iid_ness=[0.1],
65 | non_iid_pattern=["class_wise_over_domain"],
66 | python_path=["/opt/conda/bin/python"],
67 | data_path=["/run/determined/workdir/data/"],
68 | ckpt_path=[
69 | "./data/pretrained_ckpts/classification/resnet26_with_head/cifar100/rn26_bn.pth",
70 | ],
71 | # oracle_model_selection
72 | lr_grid=[
73 | [1e-3],
74 | [5e-4],
75 | [1e-4],
76 | ],
77 | n_train_steps=[50],
78 | # last_iterate
79 | # lr=[
80 | # 5e-3,
81 | # 1e-3,
82 | # 5e-4,
83 | # ],
84 | # n_train_steps=[
85 | # 1,
86 | # 2,
87 | # 3,
88 | # ],
89 | intra_domain_shuffle=["true"],
90 | record_preadapted_perf=["true"],
91 | device=[
92 | "cuda:0",
93 | ],
94 |         grad_checkpoint=["false"],
95 | )
96 |
--------------------------------------------------------------------------------
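
All `exps/*.py` configs in this dump follow the same pattern: `NewConf.to_be_replaced` maps each hyper-parameter to a list of candidate values, and a launcher (`run_exps.py`, whose body is not reproduced here) expands them into one job per combination. A minimal sketch of that expansion, assuming a plain Cartesian product over the value lists; the real launcher may handle devices and coupled keys differently:

```python
import itertools

from exps.cifar100c import NewConf

def expand(grid: dict):
    """Yield one flat config dict per combination of the value lists."""
    keys = list(grid)
    for values in itertools.product(*(grid[k] for k in keys)):
        yield dict(zip(keys, values))

jobs = list(expand(NewConf.to_be_replaced))
# 3 seeds x 15 corruptions x 3 lr grids = 135 combinations for cifar100c.
print(len(jobs), jobs[0]["data_names"])
```
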
/exps/cifar10_1.py:
--------------------------------------------------------------------------------
1 | class NewConf(object):
2 | # create the list of hyper-parameters to be replaced.
3 | to_be_replaced = dict(
4 | # general for world.
5 | seed=[2022, 2023, 2024],
6 | main_file=[
7 | "run_exp.py",
8 | ],
9 | job_name=[
10 | "cifar10_1_episodic_oracle_model_selection",
11 | # "cifar10_1_online_last_iterate",
12 | ],
13 | base_data_name=[
14 | "cifar10",
15 | ],
16 | data_names=[
17 | "cifar10_1",
18 | ],
19 | model_name=[
20 | "resnet26",
21 | ],
22 | model_adaptation_method=[
23 | # "no_adaptation",
24 | "tent",
25 | # "bn_adapt",
26 | # "t3a",
27 | # "memo",
28 | # "shot",
29 | # "ttt",
30 | # "note",
31 | # "sar",
32 | # "conjugate_pl",
33 | # "cotta",
34 | # "eata",
35 | ],
36 | model_selection_method=[
37 | "oracle_model_selection",
38 | # "last_iterate",
39 | ],
40 | offline_pre_adapt=[
41 | "false",
42 | ],
43 | data_wise=["batch_wise"],
44 | batch_size=[64],
45 | episodic=[
46 | # "false",
47 | "true",
48 | ],
49 | inter_domain=["HomogeneousNoMixture"],
50 | non_iid_ness=[0.1],
51 | non_iid_pattern=["class_wise_over_domain"],
52 | python_path=["/opt/conda/bin/python"],
53 | data_path=["/run/determined/workdir/data/"],
54 | ckpt_path=[
55 | "./data/pretrained_ckpts/classification/resnet26_with_head/cifar10/rn26_bn.pth",
56 | ],
57 | # oracle_model_selection
58 | lr_grid=[
59 | [1e-3],
60 | [5e-4],
61 | [1e-4],
62 | ],
63 | n_train_steps=[50],
64 | # last_iterate
65 | # lr=[
66 | # 5e-3,
67 | # 1e-3,
68 | # 5e-4,
69 | # ],
70 | # n_train_steps=[
71 | # 1,
72 | # 2,
73 | # 3,
74 | # ],
75 | intra_domain_shuffle=["true"],
76 | record_preadapted_perf=["true"],
77 | device=[
78 | "cuda:0",
79 | ],
80 |         grad_checkpoint=["false"],
81 | )
82 |
--------------------------------------------------------------------------------
/exps/cifar10c.py:
--------------------------------------------------------------------------------
1 | class NewConf(object):
2 | # create the list of hyper-parameters to be replaced.
3 | to_be_replaced = dict(
4 | # general for world.
5 | seed=[2022, 2023, 2024],
6 | main_file=[
7 | "run_exp.py",
8 | ],
9 | job_name=[
10 | "cifar10c_episodic_oracle_model_selection",
11 | # "cifar10c_online_last_iterate",
12 | ],
13 | base_data_name=[
14 | "cifar10",
15 | ],
16 | data_names=[
17 | "cifar10_c_deterministic-gaussian_noise-5",
18 | "cifar10_c_deterministic-shot_noise-5",
19 | "cifar10_c_deterministic-impulse_noise-5",
20 | "cifar10_c_deterministic-defocus_blur-5",
21 | "cifar10_c_deterministic-glass_blur-5",
22 | "cifar10_c_deterministic-motion_blur-5",
23 | "cifar10_c_deterministic-zoom_blur-5",
24 | "cifar10_c_deterministic-snow-5",
25 | "cifar10_c_deterministic-frost-5",
26 | "cifar10_c_deterministic-fog-5",
27 | "cifar10_c_deterministic-brightness-5",
28 | "cifar10_c_deterministic-contrast-5",
29 | "cifar10_c_deterministic-elastic_transform-5",
30 | "cifar10_c_deterministic-pixelate-5",
31 | "cifar10_c_deterministic-jpeg_compression-5",
32 | ],
33 | model_name=[
34 | "resnet26",
35 | ],
36 | model_adaptation_method=[
37 | # "no_adaptation",
38 | "tent",
39 | # "bn_adapt",
40 | # "t3a",
41 | # "memo",
42 | # "shot",
43 | # "ttt",
44 | # "note",
45 | # "sar",
46 | # "conjugate_pl",
47 | # "cotta",
48 | # "eata",
49 | ],
50 | model_selection_method=[
51 | "oracle_model_selection",
52 | # "last_iterate",
53 | ],
54 | offline_pre_adapt=[
55 | "false",
56 | ],
57 | data_wise=["batch_wise"],
58 | batch_size=[64],
59 | episodic=[
60 | # "false",
61 | "true",
62 | ],
63 | inter_domain=["HomogeneousNoMixture"],
64 | non_iid_ness=[0.1],
65 | non_iid_pattern=["class_wise_over_domain"],
66 | python_path=["/home/ttab/anaconda3/envs/test_algo/bin/python"],
67 | data_path=["./datasets/"],
68 | ckpt_path=[
69 | "./data/pretrained_ckpts/classification/resnet26_with_head/cifar10/rn26_bn.pth",
70 | ],
71 | # oracle_model_selection
72 | lr_grid=[
73 | [1e-3],
74 | [5e-4],
75 | [1e-4],
76 | ],
77 | n_train_steps=[10],
78 | # last_iterate
79 | # lr=[
80 | # 5e-3,
81 | # 1e-3,
82 | # 5e-4,
83 | # ],
84 | # n_train_steps=[
85 | # 1,
86 | # 2,
87 | # 3,
88 | # ],
89 | intra_domain_shuffle=["true"],
90 | record_preadapted_perf=["true"],
91 | device=[
92 | "cuda:0",
93 | ],
94 |         grad_checkpoint=["false"],
95 | )
96 |
--------------------------------------------------------------------------------
/exps/imagenet-c.py:
--------------------------------------------------------------------------------
1 | class NewConf(object):
2 | # create the list of hyper-parameters to be replaced.
3 | to_be_replaced = dict(
4 | # general for world.
5 | seed=[2022, 2023, 2024],
6 | main_file=[
7 | "run_exp.py",
8 | ],
9 | job_name=[
10 | "imagenet_c_episodic_oracle_model_selection",
11 | # "imagenet_c_online_last_iterate",
12 | ],
13 | base_data_name=[
14 | "imagenet",
15 | ],
16 | data_names=[
17 | "imagenet_c_deterministic-gaussian_noise-5",
18 | "imagenet_c_deterministic-shot_noise-5",
19 | "imagenet_c_deterministic-impulse_noise-5",
20 | "imagenet_c_deterministic-defocus_blur-5",
21 | "imagenet_c_deterministic-glass_blur-5",
22 | "imagenet_c_deterministic-motion_blur-5",
23 | "imagenet_c_deterministic-zoom_blur-5",
24 | "imagenet_c_deterministic-snow-5",
25 | "imagenet_c_deterministic-frost-5",
26 | "imagenet_c_deterministic-fog-5",
27 | "imagenet_c_deterministic-brightness-5",
28 | "imagenet_c_deterministic-contrast-5",
29 | "imagenet_c_deterministic-elastic_transform-5",
30 | "imagenet_c_deterministic-pixelate-5",
31 | "imagenet_c_deterministic-jpeg_compression-5",
32 | ],
33 | model_name=[
34 | "resnet50",
35 | ],
36 | model_adaptation_method=[
37 | # "no_adaptation",
38 | "tent",
39 | # "bn_adapt",
40 | # "t3a",
41 | # "memo",
42 | # "shot",
43 | # "ttt",
44 | # "note",
45 | # "sar",
46 | # "conjugate_pl",
47 | # "cotta",
48 | # "eata",
49 | ],
50 | model_selection_method=[
51 | "oracle_model_selection",
52 | # "last_iterate",
53 | ],
54 | offline_pre_adapt=[
55 | "false",
56 | ],
57 | data_wise=["batch_wise"],
58 | batch_size=[64],
59 | episodic=[
60 | # "false",
61 | "true",
62 | ],
63 | inter_domain=["HomogeneousNoMixture"],
64 | non_iid_ness=[0.1],
65 | non_iid_pattern=["class_wise_over_domain"],
66 | python_path=["/opt/conda/bin/python"],
67 | data_path=["/run/determined/workdir/data/"],
68 | ckpt_path=[
69 |             "./data/pretrained_ckpts/classification/resnet26_with_head/cifar10/rn26_bn.pth",  # Since ttab automatically downloads the pretrained ImageNet model from torchvision or huggingface, the value of ckpt_path does not matter here.
70 | ],
71 | # oracle_model_selection
72 | lr_grid=[
73 | [1e-3],
74 | [5e-4],
75 | [1e-4],
76 | ],
77 | n_train_steps=[10],
78 | # last_iterate
79 | # lr=[
80 | # 5e-3,
81 | # 1e-3,
82 | # 5e-4,
83 | # ],
84 | # n_train_steps=[
85 | # 1,
86 | # 2,
87 | # 3,
88 | # ],
89 | intra_domain_shuffle=["true"],
90 | record_preadapted_perf=["true"],
91 | device=[
92 | "cuda:0",
93 | "cuda:1",
94 | "cuda:2",
95 | "cuda:3",
96 | "cuda:4",
97 | "cuda:5",
98 | "cuda:6",
99 | "cuda:7",
100 | ],
101 |         grad_checkpoint=["false"],
102 | )
103 |
--------------------------------------------------------------------------------
/exps/officehome.py:
--------------------------------------------------------------------------------
1 | class NewConf(object):
2 | # create the list of hyper-parameters to be replaced.
3 | to_be_replaced = dict(
4 | # general for world.
5 | seed=[2022, 2023, 2024],
6 | main_file=[
7 | "run_exp.py",
8 | ],
9 | job_name=[
10 | "officehome_episodic_oracle_model_selection",
11 | # "officehome_online_last_iterate",
12 | ],
13 | base_data_name=[
14 | "officehome",
15 | ],
16 | data_names=[
17 | "officehome_clipart",
18 | "officehome_product",
19 | "officehome_realworld",
20 | "officehome_art",
21 | "officehome_product",
22 | "officehome_realworld",
23 | "officehome_art",
24 | "officehome_clipart",
25 | "officehome_realworld",
26 | "officehome_art",
27 | "officehome_clipart",
28 | "officehome_product",
29 | ],
30 | model_name=[
31 | "resnet50",
32 | ],
33 | model_adaptation_method=[
34 | # "no_adaptation",
35 | "tent",
36 | # "bn_adapt",
37 | # "t3a",
38 | # "memo",
39 | # "shot",
40 | # "ttt",
41 | # "sar",
42 | # "conjugate_pl",
43 | # "note",
44 | # "cotta",
45 | # "eata",
46 | ],
47 | model_selection_method=[
48 | "oracle_model_selection",
49 | # "last_iterate",
50 | ],
51 | offline_pre_adapt=[
52 | "false",
53 | ],
54 | data_wise=["batch_wise"],
55 | batch_size=[64],
56 | episodic=[
57 | # "false",
58 | "true",
59 | ],
60 | inter_domain=["HomogeneousNoMixture"],
61 | python_path=["/opt/conda/bin/python"],
62 | data_path=["/run/determined/workdir/data/"],
63 | ckpt_path=[
64 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_art.pth",
65 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_art.pth",
66 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_art.pth",
67 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_clipart.pth",
68 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_clipart.pth",
69 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_clipart.pth",
70 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_product.pth",
71 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_product.pth",
72 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_product.pth",
73 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_realworld.pth",
74 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_realworld.pth",
75 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_realworld.pth",
76 | ],
77 | # oracle_model_selection
78 | lr_grid=[
79 | [1e-3],
80 | [5e-4],
81 | [1e-4],
82 | ],
83 | n_train_steps=[25],
84 | # last_iterate
85 | # lr=[
86 | # 5e-3,
87 | # 1e-3,
88 | # 5e-4,
89 | # ],
90 | # n_train_steps=[
91 | # 1,
92 | # 2,
93 | # 3,
94 | # ],
95 | entry_of_shared_layers=["layer3"],
96 | intra_domain_shuffle=["true"],
97 | record_preadapted_perf=["true"],
98 | device=[
99 | "cuda:0",
100 | ],
101 | grad_checkpoint=["false"],
102 | coupled=[
103 | "data_names",
104 | "ckpt_path",
105 | ],
106 | )
107 |
--------------------------------------------------------------------------------
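
`officehome.py` above (and `pacs.py` below) adds a `coupled` entry: `data_names` and `ckpt_path` are meant to be zipped pairwise (the art checkpoint is evaluated on clipart/product/realworld, and so on) rather than crossed. A sketch of how a launcher might honor this, under the same assumptions as the expansion sketch after `cifar100c.py`:

```python
import itertools

def expand_with_coupling(grid: dict):
    """Cartesian product over free keys; coupled keys are zipped pairwise."""
    grid = dict(grid)
    coupled_keys = grid.pop("coupled", [])
    coupled_lists = [grid.pop(k) for k in coupled_keys]
    free_keys = list(grid)
    pairs = list(zip(*coupled_lists)) if coupled_lists else [()]
    for free_values in itertools.product(*(grid[k] for k in free_keys)):
        base = dict(zip(free_keys, free_values))
        for pair in pairs:  # e.g. (data_names[i], ckpt_path[i])
            yield {**base, **dict(zip(coupled_keys, pair))}

# For officehome: 3 seeds x 3 lr grids x 12 (domain, checkpoint) pairs = 108 jobs.
```
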
/exps/pacs.py:
--------------------------------------------------------------------------------
1 | class NewConf(object):
2 | # create the list of hyper-parameters to be replaced.
3 | to_be_replaced = dict(
4 | # general for world.
5 | seed=[2022, 2023, 2024],
6 | main_file=[
7 | "run_exp.py",
8 | ],
9 | job_name=[
10 | # "pacs_online_last_iterate",
11 | "pacs_episodic_oracle_model_selection"
12 | ],
13 | base_data_name=[
14 | "pacs",
15 | ],
16 | data_names=[
17 | "pacs_cartoon",
18 | "pacs_photo",
19 | "pacs_sketch",
20 | "pacs_art",
21 | "pacs_photo",
22 | "pacs_sketch",
23 | "pacs_art",
24 | "pacs_cartoon",
25 | "pacs_sketch",
26 | "pacs_art",
27 | "pacs_cartoon",
28 | "pacs_photo",
29 | ],
30 | model_name=[
31 | "resnet50",
32 | ],
33 | model_adaptation_method=[
34 | # "no_adaptation",
35 | "tent",
36 | # "bn_adapt",
37 | # "t3a",
38 | # "memo",
39 | # "shot",
40 | # "ttt",
41 | # "sar",
42 | # "conjugate_pl",
43 | # "note",
44 | # "cotta",
45 | # "eata",
46 | ],
47 | model_selection_method=[
48 | "oracle_model_selection",
49 | # "last_iterate",
50 | ],
51 | offline_pre_adapt=[
52 | "false",
53 | ],
54 | data_wise=["batch_wise"],
55 | batch_size=[64],
56 | episodic=[
57 | # "false",
58 | "true",
59 | ],
60 | inter_domain=["HomogeneousNoMixture"],
61 | python_path=["/opt/conda/bin/python"],
62 | data_path=["/run/determined/workdir/data/"],
63 | ckpt_path=[
64 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_art.pth",
65 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_art.pth",
66 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_art.pth",
67 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_cartoon.pth",
68 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_cartoon.pth",
69 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_cartoon.pth",
70 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_photo.pth",
71 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_photo.pth",
72 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_photo.pth",
73 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_sketch.pth",
74 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_sketch.pth",
75 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_sketch.pth",
76 | ],
77 | # oracle_model_selection
78 | lr_grid=[
79 | [1e-3],
80 | [5e-4],
81 | [1e-4],
82 | ],
83 | n_train_steps=[25],
84 | # last_iterate
85 | # lr=[
86 | # 5e-3,
87 | # 1e-3,
88 | # 5e-4,
89 | # ],
90 | # n_train_steps=[
91 | # 1,
92 | # 2,
93 | # 3,
94 | # ],
95 | entry_of_shared_layers=["layer3"],
96 | intra_domain_shuffle=["true"],
97 | record_preadapted_perf=["true"],
98 | device=[
99 | "cuda:0",
100 | ],
101 | grad_checkpoint=["false"],
102 | coupled=[
103 | "data_names",
104 | "ckpt_path",
105 | ],
106 | )
107 |
--------------------------------------------------------------------------------
/exps/yearbook.py:
--------------------------------------------------------------------------------
1 | class NewConf(object):
2 | # create the list of hyper-parameters to be replaced.
3 | to_be_replaced = dict(
4 | # general for world.
5 | seed=[2022, 2023, 2024],
6 | main_file=[
7 | "run_exp.py",
8 | ],
9 | job_name=[
10 | "yearbook_episodic_oracle_model_selection",
11 | # "yearbook_online_last_iterate",
12 | ],
13 | base_data_name=[
14 | "yearbook",
15 | ],
16 | data_names=[
17 | "yearbook",
18 | ],
19 | model_name=[
20 | "resnet18",
21 | ],
22 | model_adaptation_method=[
23 | # "no_adaptation",
24 | "tent",
25 | # "bn_adapt",
26 | # "t3a",
27 | # "memo",
28 | # "shot",
29 | # "ttt",
30 | # "sar",
31 | # "conjugate_pl",
32 | # "note",
33 | # "cotta",
34 | # "eata",
35 | ],
36 | model_selection_method=[
37 | "oracle_model_selection",
38 | # "last_iterate",
39 | ],
40 | offline_pre_adapt=[
41 | "false",
42 | ],
43 | data_wise=["batch_wise"],
44 | batch_size=[64],
45 | episodic=[
46 | # "false",
47 | "true",
48 | ],
49 | inter_domain=["HomogeneousNoMixture"],
50 | python_path=["/home/ttab/anaconda3/envs/test_algo/bin/python"],
51 | data_path=["./datasets"],
52 | ckpt_path=[
53 | "./pretrain/checkpoint/resnet18_with_head/yearbook/resnet18_bn.pth",
54 | ],
55 | # oracle_model_selection
56 | lr_grid=[
57 | [1e-3],
58 | [5e-4],
59 | [1e-4],
60 | ],
61 | n_train_steps=[50],
62 | # last_iterate
63 | # lr=[
64 | # 5e-3,
65 | # 1e-3,
66 | # 5e-4,
67 | # ],
68 | # n_train_steps=[
69 | # 1,
70 | # 2,
71 | # 3,
72 | # ],
73 | entry_of_shared_layers=["layer3"],
74 | intra_domain_shuffle=["false"],
75 | record_preadapted_perf=["true"],
76 | device=[
77 | "cuda:0",
78 | ],
79 | grad_checkpoint=["false"],
80 | )
81 |
--------------------------------------------------------------------------------
/figs/overview of datasets.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/figs/overview of datasets.jpg
--------------------------------------------------------------------------------
/monitor/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/monitor/tmux_cluster/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/monitor/tmux_cluster/__init__.py
--------------------------------------------------------------------------------
/monitor/tmux_cluster/tmux.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from monitor.tmux_cluster.utils import ossystem
3 | import shlex
4 |
5 | TASKDIR_PREFIX = "/tmp/tasklogs"
6 |
7 |
8 | def exec_on_node(cmds, host="localhost"):
9 | def _decide_node(cmd):
10 | return cmd if host == "localhost" else f"ssh {host} -t {shlex.quote(cmd)}"
11 |
12 | cmds = (
13 | [_decide_node(cmd) for cmd in cmds]
14 | if isinstance(cmds, list)
15 | else _decide_node(cmds)
16 | )
17 | ossystem(cmds)
18 |
19 |
20 | class Run(object):
21 | def __init__(self, name, job_node="localhost"):
22 | self.name = name
23 | self.jobs = []
24 | self.job_node = job_node
25 |
26 | def make_job(self, job_name, task_scripts, run=True, **kwargs):
27 | num_tasks = len(task_scripts)
28 | assert num_tasks > 0
29 |
30 | if kwargs:
31 | print("Warning: unused kwargs", kwargs)
32 |
33 | # Creating cmds
34 | cmds = []
35 |         session_name = self.name + "-" + job_name  # tmux session names cannot contain "."
36 | cmds.append(f"tmux kill-session -t {session_name}")
37 |
38 | windows = []
39 | for task_id in range(num_tasks):
40 | if task_id == 0:
41 | cmds.append(f"tmux new-session -s {session_name} -n {task_id} -d")
42 | else:
43 | cmds.append(f"tmux new-window -t {session_name} -n {task_id}")
44 | windows.append(f"{session_name}:{task_id}")
45 |
46 | job = Job(self, job_name, windows, task_scripts, self.job_node)
47 | job.make_tasks()
48 | self.jobs.append(job)
49 | if run:
50 |             # only the new job's cmds; looping over self.jobs would replay earlier jobs.
51 |             cmds += job.cmds
52 | exec_on_node(cmds, self.job_node)
53 | return job
54 |
55 | def attach_job(self):
56 | raise NotImplementedError
57 |
58 | def kill_jobs(self):
59 | cmds = []
60 | for job in self.jobs:
61 | session_name = self.name + "-" + job.name
62 | cmds.append(f"tmux kill-session -t {session_name}")
63 | exec_on_node(cmds, self.job_node)
64 |
65 |
66 | class Job(object):
67 | def __init__(self, run, name, windows, task_scripts, job_node):
68 | self._run = run
69 | self.name = name
70 | self.job_node = job_node
71 | self.windows = windows
72 | self.task_scripts = task_scripts
73 | self.tasks = []
74 |
75 | def make_tasks(self):
76 | for task_id, (window, script) in enumerate(
77 | zip(self.windows, self.task_scripts)
78 | ):
79 | self.tasks.append(
80 | Task(
81 | window,
82 | self,
83 | task_id,
84 | install_script=script,
85 | task_node=self.job_node,
86 | )
87 | )
88 |
89 | def attach_tasks(self):
90 | raise NotImplementedError
91 |
92 | @property
93 | def cmds(self):
94 | output = []
95 | for task in self.tasks:
96 | output += task.cmds
97 | return output
98 |
99 |
100 | class Task(object):
101 | """Local tasks interact with tmux session.
102 |
103 | * session name is derived from job name, and window names are task ids.
104 | * no pane is used.
105 |
106 | """
107 |
108 | def __init__(self, window, job, task_id, install_script, task_node):
109 | self.window = window
110 | self.job = job
111 | self.id = task_id
112 | self.install_script = install_script
113 | self.task_node = task_node
114 |
115 | # Path
116 | self.cmds = []
117 | self._run_counter = 0
118 |
119 | for line in install_script.split("\n"):
120 | self.run(line)
121 |
122 | def run(self, cmd):
123 | self._run_counter += 1
124 |
125 | cmd = cmd.strip()
126 | if not cmd or cmd.startswith("#"):
127 | # ignore empty command lines
128 | # ignore commented out lines
129 | return
130 |
131 | modified_cmd = cmd
132 | self.cmds.append(
133 | f"tmux send-keys -t {self.window} {shlex.quote(modified_cmd)} Enter"
134 | )
135 |
136 | def upload(self, source_fn, target_fn="."):
137 | raise NotImplementedError()
138 |
139 | def download(self, source_fn, target_fn="."):
140 | raise NotImplementedError()
141 |
--------------------------------------------------------------------------------
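
A usage sketch for the `Run`/`Job`/`Task` machinery above: each job becomes a tmux session named `<run>-<job>`, each task script becomes a window, and the script's lines are replayed into that window via `send-keys`. The paths and scripts below are illustrative:

```python
from monitor.tmux_cluster.tmux import Run

run = Run(name="ttab", job_node="localhost")
run.make_job(
    "cifar10c",  # creates the tmux session "ttab-cifar10c"
    task_scripts=[
        "cd ~/ttab && python run_exp.py --device cuda:0",  # window 0
        "cd ~/ttab && python run_exp.py --device cuda:1",  # window 1
    ],
)
# run.kill_jobs()  # tears down every session created by this Run.
```
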
/monitor/tmux_cluster/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import yaml
3 | import os
4 | import time
5 | from tqdm import tqdm
6 |
7 |
8 | def ossystem(cmds):
9 | if isinstance(cmds, str):
10 | print(f"\n=> {cmds}")
11 | os.system(cmds)
12 | elif isinstance(cmds, list):
13 | for cmd in tqdm(cmds):
14 | ossystem(cmd)
15 | else:
16 | raise NotImplementedError(
17 | "Cmds should be string or list of str. Got {}.".format(cmds)
18 | )
19 |
20 |
21 | def environ(env):
22 | return os.getenv(env)
23 |
24 |
25 | def load_yaml(file):
26 | with open(file) as f:
27 | return yaml.safe_load(f)
28 |
29 |
30 | def wait_for_file(fn, max_wait_sec=600, check_interval=0.02):
31 | start_time = time.time()
32 | while True:
33 | if time.time() - start_time > max_wait_sec:
34 |             raise TimeoutError("Timeout %s exceeded" % max_wait_sec)
35 | if not os.path.exists(fn):
36 | time.sleep(check_interval)
37 | continue
38 | else:
39 | break
40 |
41 |
42 | if __name__ == "__main__":
43 | ossystem("ls")
44 |
--------------------------------------------------------------------------------
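
For illustration, `ossystem` accepts either a single command or a list of commands, and `wait_for_file` polls until a path appears or the timeout expires, which is handy for blocking on a log file written by a freshly launched tmux task:

```python
from monitor.tmux_cluster.utils import ossystem, wait_for_file

ossystem(["mkdir -p /tmp/tasklogs", "touch /tmp/tasklogs/task0.out"])
wait_for_file("/tmp/tasklogs/task0.out", max_wait_sec=5)
```
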
/monitor/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/monitor/tools/__init__.py
--------------------------------------------------------------------------------
/monitor/tools/file_io.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import shutil
4 | import json
5 | import pickle
6 |
7 |
8 | def read_text_withoutsplit(path):
9 | """read text file from path."""
10 | with open(path, "r") as f:
11 | return f.read()
12 |
13 |
14 | def read_txt(path):
15 | """read text file from path."""
16 | with open(path, "r") as f:
17 | return f.read().splitlines()
18 |
19 |
20 | def write_txt(data, out_path, type="w"):
21 | """write the data to the txt file."""
22 | with open(out_path, type) as f:
23 | f.write(data)
24 |
25 |
26 | def load_pickle(path):
27 | """load data by pickle."""
28 | with open(path, "rb") as handle:
29 | return pickle.load(handle)
30 |
31 |
32 | def write_pickle(data, path):
33 | """dump file to dir."""
34 | print("write --> data to path: {}\n".format(path))
35 | with open(path, "wb") as handle:
36 | pickle.dump(data, handle)
37 |
38 |
39 | """json related."""
40 |
41 |
42 | def read_json(path):
43 | """read json file from path."""
44 | with open(path, "r") as f:
45 | return json.load(f)
46 |
47 |
48 | def is_jsonable(x):
49 | try:
50 | json.dumps(x)
51 | return True
52 |     except (TypeError, ValueError):  # json.dumps raises these for non-serializable or circular data.
53 | return False
54 |
55 |
56 | """operate dir."""
57 |
58 |
59 | def build_dir(path, force):
60 | """build directory."""
61 | if os.path.exists(path) and force:
62 | shutil.rmtree(path)
63 | os.mkdir(path)
64 | elif not os.path.exists(path):
65 | os.mkdir(path)
66 | return path
67 |
68 |
69 | def build_dirs(path):
70 | try:
71 | os.makedirs(path)
72 | except Exception as e:
73 |         print(" encountered error: {}".format(e))
74 |
75 |
76 | def remove_folder(path):
77 | try:
78 | shutil.rmtree(path)
79 | except Exception as e:
80 |         print(" encountered error: {}".format(e))
81 |
82 |
83 | def list_files(root_path):
84 | dirs = os.listdir(root_path)
85 | return [os.path.join(root_path, path) for path in dirs]
86 |
--------------------------------------------------------------------------------
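
A quick round-trip through the pickle helpers above; the path is illustrative:

```python
from monitor.tools.file_io import build_dirs, load_pickle, write_pickle

build_dirs("/tmp/ttab_demo")
write_pickle({"test-overall-accuracy": 91.2}, "/tmp/ttab_demo/record.pkl")
assert load_pickle("/tmp/ttab_demo/record.pkl")["test-overall-accuracy"] == 91.2
```
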
/monitor/tools/plot.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 |
4 | from monitor.tools.show_results import reorder_records
5 | from monitor.tools.plot_utils import (
6 | determine_color_and_lines,
7 | plot_one_case,
8 | smoothing_func,
9 | configure_figure,
10 | build_legend,
11 | )
12 |
13 |
14 | """plot the curve in terms of time."""
15 |
16 |
17 | def plot_curve_wrt_time(
18 | ax,
19 | records,
20 | x_wrt_sth,
21 | y_wrt_sth,
22 | xlabel,
23 | ylabel,
24 | title=None,
25 | markevery_list=None,
26 | is_smooth=True,
27 | smooth_space=100,
28 | l_subset=0.0,
29 | r_subset=1.0,
30 | reorder_record_item=None,
31 | remove_duplicate=True,
32 | n_by_line=1,
33 | n_by_color=1,
34 | font_shift=0,
35 | has_legend=True,
36 | legend=None,
37 | legend_loc="lower right",
38 | legend_ncol=2,
39 | bbox_to_anchor=[0, 0],
40 | ylimit_bottom=None,
41 | ylimit_top=None,
42 | use_log=False,
43 | ):
44 |     """Plot curves from a list of records; each record is an (args, info) pair,
45 |     where `info` maps metric names (e.g. 'train-loss', 'test-step-accuracy') to logged values.
46 |     """
47 | # parse a list of records.
48 | distinct_conf_set = set()
49 |
50 | # re-order the records.
51 | if reorder_record_item is not None:
52 | records = reorder_records(records, reorder_on=reorder_record_item)
53 |
54 | count = 0
55 | for ind, (args, info) in enumerate(records):
56 | y_wrt_sth_list = y_wrt_sth.split(",")
57 | # check.
58 | if len(y_wrt_sth_list) > 1:
59 |             assert len(y_wrt_sth_list) == 2, "y_wrt_sth supports at most two comma-separated keys."
60 |             assert y_wrt_sth_list[1].endswith("preadapt_accuracy"), "the second key of y_wrt_sth must end with preadapt_accuracy."
61 |             assert len(records) <= 4, "only <= 4 records are supported with two keys; otherwise colors and line styles clash."
62 |
63 | for i in range(len(y_wrt_sth_list)):
64 | y_wrt_sth_i = y_wrt_sth_list[i]
65 | # build legend.
66 | _legend = build_legend(args, legend)
67 |             if len(y_wrt_sth_list) > 1 and i == 0:
68 |                 _legend = ", ".join([_legend, "after_adapt"])
69 |             elif i == 1:
70 |                 _legend = ", ".join([_legend, "before_adapt"])
71 |
72 | if _legend in distinct_conf_set and remove_duplicate:
73 | continue
74 | else:
75 | distinct_conf_set.add(_legend)
76 |
77 | # determine the style of line, color and marker.
78 | line_style, color_style, mark_style = determine_color_and_lines(
79 | n_by_line, n_by_color, index=count
80 | )
81 |
82 |             if len(y_wrt_sth_list) > 1 and i == 0:
83 |                 line_style = "-"
84 |             elif i == 1:
85 |                 line_style = "--"
86 |
87 | if markevery_list is not None:
88 | mark_every = markevery_list[ind]
89 | else:
90 | mark_style = None
91 | mark_every = None
92 |
93 | # determine if we want to smooth the curve.
94 | if "train-step" in x_wrt_sth or "train-epoch" in x_wrt_sth:
95 | info["train-step"] = list(range(1, 1 + len(info["train-loss"])))
96 | if "train-epoch" == x_wrt_sth:
97 | x = info["train-step"]
98 | x = [1.0 * _x / args["num_batches_train_per_device_per_epoch"] for _x in x]
99 | else:
100 | x = info[x_wrt_sth]
101 | if "time" in x_wrt_sth:
102 | x = [(time - x[0]).seconds + 1 for time in x]
103 | y = info[y_wrt_sth_i]
104 |
105 | if is_smooth:
106 | x, y = smoothing_func(x, y, smooth_space)
107 |
108 |             # only plot a subset of the curve.
109 | _l_subset, _r_subset = int(len(x) * l_subset), int(len(x) * r_subset)
110 | _x = x[_l_subset:_r_subset]
111 | _y = y[_l_subset:_r_subset]
112 |
113 | # use log scale for y
114 | if use_log:
115 | _y = np.log10(_y)
116 |
117 | # plot
118 | ax = plot_one_case(
119 | ax,
120 | x=_x,
121 | y=_y,
122 | label=_legend,
123 | line_style=line_style,
124 | color_style=color_style,
125 | mark_style=mark_style,
126 | mark_every=mark_every,
127 | remove_duplicate=remove_duplicate,
128 | )
129 | count += 1
130 |
131 | ax.set_ylim(bottom=ylimit_bottom, top=ylimit_top)
132 | ax = configure_figure(
133 | ax,
134 | xlabel=xlabel,
135 | ylabel=ylabel,
136 | title=title,
137 | has_legend=has_legend,
138 | legend_loc=legend_loc,
139 | legend_ncol=legend_ncol,
140 | bbox_to_anchor=bbox_to_anchor,
141 | font_shift=font_shift,
142 | )
143 | return ax
144 |
--------------------------------------------------------------------------------
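
When `y_wrt_sth` contains a comma, `plot_curve_wrt_time` draws two curves per record: the first key (post-adaptation accuracy) with a solid line and the `*preadapt_accuracy` key with a dashed line. An illustrative call, assuming `records` was produced by `extract_list_of_records` as in `notebooks/example.ipynb`:

```python
import matplotlib.pyplot as plt

from monitor.tools.plot import plot_curve_wrt_time

fig, ax = plt.subplots(figsize=(18, 9))
plot_curve_wrt_time(
    ax,
    records,  # as extracted in notebooks/example.ipynb
    x_wrt_sth="test-step-step",
    y_wrt_sth="test-step-accuracy,test-step-preadapt_accuracy",
    xlabel="batch index",
    ylabel="Test accuracy",
    legend="model_adaptation_method",
    ylimit_bottom=0,
    ylimit_top=100,
)
plt.show()
```
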
/monitor/tools/plot_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from itertools import groupby
3 |
4 | import numpy as np
5 | from matplotlib.lines import Line2D
6 | import seaborn as sns
7 |
8 | """operate x and y."""
9 |
10 |
11 | def smoothing_func(x, y, smooth_length=10):
12 | def smoothing(end_index):
13 | # print(end_index)
14 | if end_index - smooth_length < 0:
15 | start_index = 0
16 | else:
17 | start_index = end_index - smooth_length
18 |
19 | data = y[start_index:end_index]
20 | if len(data) == 0:
21 | return y[start_index]
22 | else:
23 | return 1.0 * sum(data) / len(data)
24 |
25 | # smooth curve
26 | x_, y_ = [], []
27 |
28 | for end_ind in range(0, len(x)):
29 | x_.append(x[end_ind])
30 | y_.append(smoothing(end_ind))
31 | return x_, y_
32 |
33 |
34 | def reject_outliers(data, threshold=3):
35 | return data[abs(data - np.mean(data)) < threshold * np.std(data)]
36 |
37 |
38 | def groupby_indices(results, grouper):
39 | """group by indices and select the subset parameters"""
40 | out = []
41 | for key, group in groupby(sorted(results, key=grouper), grouper):
42 | group_item = list(group)
43 | out += [(key, group_item)]
44 | return out
45 |
46 |
47 | def find_same_num_sync(num_update_steps_and_local_step):
48 | list_of_num_sync = [
49 | num_update_steps // local_step
50 | for num_update_steps, local_step in num_update_steps_and_local_step
51 | ]
52 | return min(list_of_num_sync)
53 |
54 |
55 | def sample_from_records(x, y, local_step, max_same_num_sync):
56 | # cut the records.
57 | if max_same_num_sync is not None:
58 | x = x[: local_step * max_same_num_sync]
59 | y = y[: local_step * max_same_num_sync]
60 | return x[::local_step], y[::local_step]
61 |
62 |
63 | def drop_first_few(x, y, num_drop):
64 | return x[num_drop:], y[num_drop:]
65 |
66 |
67 | def rebuild_runtime_record(times):
68 | times = [(time - times[0]).seconds + 1 for time in times]
69 | return times
70 |
71 |
72 | def add_communication_delay(times, local_step, delay_factor):
73 | """add communication delay to original time."""
74 | return [
75 | time + delay_factor * ((ind + 1) // local_step)
76 | for ind, time in enumerate(times)
77 | ]
78 |
79 |
80 | """plot style related."""
81 |
82 |
83 | def determine_color_and_lines(n_by_line, n_by_color, index):
84 | line_styles = ["-", "--", "-.", ":"]
85 | color_styles = ["b", "g", "r", "k"]
86 |
87 | # safety check
88 | total_num_combs = len(line_styles) * len(color_styles)
89 | assert index + 1 <= total_num_combs
90 |
91 | if n_by_line >= n_by_color:
92 | line = index // n_by_line
93 | color = (index % n_by_line) // n_by_color
94 | else:
95 | color = index // n_by_color
96 | line = (index % n_by_color) // n_by_line
97 | return line_styles[line], color_styles[color], Line2D.filled_markers[index]
98 |
99 |
100 | def configure_figure(
101 | ax,
102 | xlabel,
103 | ylabel,
104 | title=None,
105 | has_legend=True,
106 | legend_loc="lower right",
107 | legend_ncol=2,
108 | bbox_to_anchor=[0, 0],
109 | font_shift=0,
110 | ):
111 | if has_legend:
112 | ax.legend(
113 | loc=legend_loc,
114 | bbox_to_anchor=bbox_to_anchor,
115 | ncol=legend_ncol,
116 | shadow=True,
117 | fancybox=True,
118 | fontsize=20 + font_shift,
119 | )
120 |
121 | ax.set_xlabel(xlabel, fontsize=24 + font_shift, labelpad=18 + font_shift)
122 | ax.set_ylabel(ylabel, fontsize=24 + font_shift, labelpad=18 + font_shift)
123 |
124 | if title is not None:
125 | ax.set_title(title, fontsize=24 + font_shift)
126 | ax.xaxis.set_tick_params(labelsize=22 + font_shift)
127 | ax.yaxis.set_tick_params(labelsize=22 + font_shift)
128 | return ax
129 |
130 |
131 | def plot_one_case(
132 | ax,
133 | label,
134 | line_style,
135 | color_style,
136 | mark_style,
137 | line_width=2.0,
138 | mark_every=5000,
139 | x=None,
140 | y=None,
141 | sns_plot=None,
142 | remove_duplicate=False,
143 | ):
144 | if sns_plot is not None and not remove_duplicate:
145 | ax = sns.lineplot(
146 | x="x",
147 | y="y",
148 | data=sns_plot,
149 | label=label,
150 | linewidth=line_width,
151 | linestyle=line_style,
152 | color=color_style,
153 | marker=mark_style,
154 | markevery=mark_every,
155 | markersize=16,
156 | ax=ax,
157 | )
158 | elif sns_plot is not None and remove_duplicate:
159 | ax = sns.lineplot(
160 | x="x",
161 | y="y",
162 | data=sns_plot,
163 | label=label,
164 | linewidth=line_width,
165 | linestyle=line_style,
166 | color=color_style,
167 | marker=mark_style,
168 | markevery=mark_every,
169 | markersize=16,
170 | ax=ax,
171 | estimator=None,
172 | )
173 | else:
174 | ax.plot(
175 | x,
176 | y,
177 | label=label,
178 | linewidth=line_width,
179 | linestyle=line_style,
180 | color=color_style,
181 | marker=mark_style,
182 | markevery=mark_every,
183 | markersize=16,
184 | )
185 | return ax
186 |
187 |
188 | def build_legend(args, legend):
189 | legend = legend.split(",")
190 |
191 | my_legend = []
192 | for _legend in legend:
193 | _legend_content = args[_legend] if _legend in args else -1
194 | my_legend += [
195 | "{}={}".format(
196 | _legend,
197 | list(_legend_content)[0]
198 | if "pandas" in str(type(_legend_content))
199 | else _legend_content,
200 | )
201 | ]
202 | return ", ".join(my_legend)
203 |
--------------------------------------------------------------------------------
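
Two small worked examples for the helpers above: `build_legend` renders selected config entries as a plot label, and `smoothing_func` replaces each point with the mean of the preceding `smooth_length` values:

```python
from monitor.tools.plot_utils import build_legend, smoothing_func

label = build_legend(
    {"model_adaptation_method": "tent", "lr": 0.001}, "model_adaptation_method,lr"
)
# label == "model_adaptation_method=tent, lr=0.001"

_, y_smooth = smoothing_func(list(range(6)), [0, 10, 0, 10, 0, 10], smooth_length=2)
# y_smooth == [0, 0.0, 5.0, 5.0, 5.0, 5.0]
```
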
/monitor/tools/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from datetime import datetime
3 |
4 |
5 | def str2time(string, pattern):
6 | """convert the string to the datetime."""
7 | return datetime.strptime(string, pattern)
8 |
9 |
10 | def str2bool(v):
11 | if v.lower() in ("yes", "true", "t", "y", "1"):
12 | return True
13 | elif v.lower() in ("no", "false", "f", "n", "0"):
14 | return False
15 | else:
16 | raise ValueError("Boolean value expected.")
17 |
18 |
19 | def is_float(value):
20 | try:
21 | float(value)
22 | return True
23 | except:
24 | return False
25 |
26 |
27 | def dict_parser(values):
28 | local_dict = {}
29 | if values is None:
30 | return local_dict
31 | for kv in values.split(",,"):
32 | k, v = kv.split("=")
33 | try:
34 | local_dict[k] = float(v)
35 | except ValueError:
36 | try:
37 | local_dict[k] = str2bool(v)
38 | except ValueError:
39 | local_dict[k] = v
40 | return local_dict
41 |
--------------------------------------------------------------------------------
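
A worked example for `dict_parser`, which splits on `",,"` and coerces each value through float, then bool, then falls back to the raw string:

```python
from monitor.tools.utils import dict_parser

parsed = dict_parser("lr=0.001,,episodic=true,,job_name=cifar10c")
assert parsed == {"lr": 0.001, "episodic": True, "job_name": "cifar10c"}
```
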
/notebooks/example.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "metadata": {},
7 | "outputs": [],
8 | "source": [
9 | "# Useful starting lines \n",
10 | "%matplotlib inline\n",
11 | "import numpy as np\n",
12 | "import matplotlib.pyplot as plt\n",
13 | "from matplotlib.lines import Line2D\n",
14 | "%load_ext autoreload\n",
15 | "%autoreload 2"
16 | ]
17 | },
18 | {
19 | "cell_type": "code",
20 | "execution_count": null,
21 | "metadata": {},
22 | "outputs": [],
23 | "source": [
24 | "import os \n",
25 | "import sys\n",
26 | "import time\n",
27 | "import copy\n",
28 | "from copy import deepcopy\n",
29 | "import pickle\n",
30 | "import math\n",
31 | "import functools \n",
32 | "from IPython.display import display, HTML\n",
33 | "import operator\n",
34 | "from operator import itemgetter\n",
35 | "\n",
36 | "import pandas as pd\n",
37 | "import seaborn as sns\n",
38 | "from matplotlib.lines import Line2D\n",
39 | "\n",
40 | "sns.set(style=\"darkgrid\")\n",
41 | "sns.set_context(\"paper\")"
42 | ]
43 | },
44 | {
45 | "cell_type": "code",
46 | "execution_count": null,
47 | "metadata": {},
48 | "outputs": [],
49 | "source": [
50 | "root_path = '/home/ttab/sp-ttabed/codes/code'\n",
51 | "sys.path.append(root_path)"
52 | ]
53 | },
54 | {
55 | "cell_type": "code",
56 | "execution_count": null,
57 | "metadata": {},
58 | "outputs": [],
59 | "source": [
60 | "from monitor.tools.show_results import extract_list_of_records, reorder_records, get_pickle_info, summarize_info\n",
61 | "from monitor.tools.plot import plot_curve_wrt_time\n",
62 | "import monitor.tools.plot_utils as plot_utils\n",
63 | "\n",
64 | "from monitor.tools.utils import dict_parser\n",
65 | "from monitor.tools.file_io import load_pickle"
66 | ]
67 | },
68 | {
69 | "cell_type": "code",
70 | "execution_count": null,
71 | "metadata": {},
72 | "outputs": [],
73 | "source": [
74 | "root_data_path = os.path.join(root_path, 'data', 'logs', 'resnet26')\n",
75 | "experiments = ['cifar10_label_shift_episodic_oracle_model_selection']\n",
76 | "raw_records = get_pickle_info(root_data_path, experiments)"
77 | ]
78 | },
79 | {
80 | "cell_type": "code",
81 | "execution_count": null,
82 | "metadata": {},
83 | "outputs": [],
84 | "source": [
85 | "# Have a glimpse of experimental results.\n",
86 | "conditions = {\n",
87 | " \"model_adaptation_method\": [\"tent\"],\n",
88 | " \"seed\": [2022],\n",
89 | " \"batch_size\": [64],\n",
90 | " \"episodic\": [False],\n",
91 | " \"n_train_steps\": [1],\n",
92 | " # \"lr\": [0.005],\n",
93 | " # \"data_names\": [\"cifar10_c_deterministic-gaussian_noise-5\"],\n",
94 | "}\n",
95 | "attributes = ['model_adaptation_method', 'step_ratio', 'label_shift_param', 'ckpt_path', 'episodic', 'model_selection_method', 'seed', 'data_names', 'status']\n",
96 | "records = extract_list_of_records(list_of_records=raw_records, conditions=conditions)\n",
97 | "aggregated_results, averaged_records_overall = summarize_info(records, attributes, reorder_on='model_adaptation_method', groupby_on='test-overall-accuracy', larger_is_better=True)\n",
98 | "display(HTML(averaged_records_overall.to_html()))\n",
99 | "\n",
100 | "# display test accuracy per test step.\n",
101 | "aggregated_results, averaged_records_step_nonepisodic_optimal = summarize_info(records, attributes, reorder_on='model_adaptation_method', groupby_on='test-step-accuracy', larger_is_better=True)\n",
102 | "\n",
103 | "fig = plt.figure(num=1, figsize=(18, 9))\n",
104 | "ax1 = fig.add_subplot(111)\n",
105 | "plot_curve_wrt_time(\n",
106 | " ax1, records,\n",
107 | " x_wrt_sth='test-step-step', y_wrt_sth='test-step-accuracy', is_smooth=True,\n",
108 | " xlabel='batch index', ylabel='Test accuracy', l_subset=0.0, r_subset=1, markevery_list=None,\n",
109 | " n_by_line=4, has_legend=True, legend='model_selection_method,step_ratio', legend_loc='lower right', legend_ncol=1, bbox_to_anchor=[1, 0],\n",
110 | " ylimit_bottom=0, ylimit_top=100, use_log=False)\n",
111 | "fig.tight_layout()\n",
112 | "plt.show()"
113 | ]
114 | }
115 | ],
116 | "metadata": {
117 | "kernelspec": {
118 | "display_name": "test_algo",
119 | "language": "python",
120 | "name": "python3"
121 | },
122 | "language_info": {
123 | "codemirror_mode": {
124 | "name": "ipython",
125 | "version": 3
126 | },
127 | "file_extension": ".py",
128 | "mimetype": "text/x-python",
129 | "name": "python",
130 | "nbconvert_exporter": "python",
131 | "pygments_lexer": "ipython3",
132 | "version": "3.7.11"
133 | },
134 | "orig_nbformat": 4,
135 | "vscode": {
136 | "interpreter": {
137 | "hash": "dc7a203c487a4c1b41bd3d170020b3757b8af76b16b2c4bd8127396815ac049f"
138 | }
139 | }
140 | },
141 | "nbformat": 4,
142 | "nbformat_minor": 2
143 | }
144 |
--------------------------------------------------------------------------------
/parameters.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import argparse
3 |
4 |
5 | def get_args():
6 | parser = argparse.ArgumentParser()
7 | # define meta info (this part can be ignored if `monitor` package is unused).
8 | parser.add_argument("--job_name", default="tta", type=str)
9 | parser.add_argument("--job_id", default=None, type=str)
10 | parser.add_argument("--timestamp", default=None, type=str)
11 | parser.add_argument("--python_path", default="python", type=str)
12 | parser.add_argument("--main_file", default="run_exp.py", type=str)
13 |
14 | parser.add_argument("--script_path", default=None, type=str)
15 | parser.add_argument("--script_class_name", default=None, type=str)
16 | parser.add_argument("--num_jobs_per_node", default=2, type=int)
17 | parser.add_argument("--num_jobs_per_script", default=1, type=int)
18 | parser.add_argument("--wait_in_seconds_per_job", default=15, type=float)
19 |
20 | # define test evaluation info.
21 | parser.add_argument("--root_path", default="./logs", type=str)
22 | parser.add_argument("--data_path", default="./datasets", type=str)
23 | parser.add_argument(
24 | "--ckpt_path",
25 | default="./pretrained_ckpts/classification/resnet26_with_head/cifar10/rn26_bn.pth",
26 | type=str,
27 | )
28 | parser.add_argument("--seed", default=2022, type=int)
29 | parser.add_argument("--device", default="cuda:0", type=str)
30 | parser.add_argument("--num_cpus", default=2, type=int)
31 |
32 | # define the task & model & adaptation & selection method.
33 | parser.add_argument("--model_name", default="resnet26", type=str)
34 | parser.add_argument("--group_norm_num_groups", default=None, type=int)
35 | parser.add_argument(
36 | "--model_adaptation_method",
37 | default="tent",
38 | choices=[
39 | "no_adaptation",
40 | "tent",
41 | "bn_adapt",
42 | "memo",
43 | "shot",
44 | "t3a",
45 | "ttt",
46 | "note",
47 | "sar",
48 | "conjugate_pl",
49 | "cotta",
50 | "eata",
51 | ],
52 | type=str,
53 | )
54 | parser.add_argument(
55 | "--model_selection_method",
56 | default="last_iterate",
57 | choices=["last_iterate", "oracle_model_selection"],
58 | type=str,
59 | )
60 | parser.add_argument("--task", default="classification", type=str)
61 |
62 | # define the test scenario.
63 | parser.add_argument("--test_scenario", default=None, type=str)
64 | parser.add_argument(
65 | "--base_data_name",
66 | default="cifar10",
67 | choices=[
68 | "cifar10",
69 | "cifar100",
70 | "imagenet",
71 | "officehome",
72 | "pacs",
73 | "coloredmnist",
74 | "waterbirds",
75 | "yearbook",
76 | ],
77 | type=str,
78 | )
79 | parser.add_argument("--src_data_name", default="cifar10", type=str)
80 | parser.add_argument(
81 | "--data_names", default="cifar10_c_deterministic-gaussian_noise-5", type=str
82 | )
83 | parser.add_argument(
84 | "--data_wise",
85 | default="batch_wise",
86 | choices=["batch_wise", "sample_wise"],
87 | type=str,
88 | )
89 | parser.add_argument("--batch_size", default=64, type=int)
90 | parser.add_argument("--lr", default=1e-3, type=float)
91 | parser.add_argument("--n_train_steps", default=1, type=int)
92 | parser.add_argument("--offline_pre_adapt", default=False, type=str2bool)
93 | parser.add_argument("--episodic", default=False, type=str2bool)
94 | parser.add_argument("--intra_domain_shuffle", default=True, type=str2bool)
95 | parser.add_argument(
96 | "--inter_domain",
97 | default="HomogeneousNoMixture",
98 | choices=[
99 | "HomogeneousNoMixture",
100 | "HeterogeneousNoMixture",
101 | "InOutMixture",
102 | "CrossMixture",
103 | ],
104 | type=str,
105 | )
106 | # Test domain
107 | parser.add_argument("--domain_sampling_name", default="uniform", type=str)
108 | parser.add_argument("--domain_sampling_ratio", default=1.0, type=float)
109 | # HeterogeneousNoMixture
110 | parser.add_argument("--non_iid_pattern", default="class_wise_over_domain", type=str)
111 | parser.add_argument("--non_iid_ness", default=0.1, type=float)
112 | # for evaluation.
113 | # label shift
114 | parser.add_argument(
115 | "--label_shift_param",
116 | help="parameter to control the severity of label shift",
117 | default=None,
118 | type=float,
119 | )
120 | parser.add_argument(
121 | "--data_size",
122 | help="parameter to control the size of dataset",
123 | default=None,
124 | type=int,
125 | )
126 | # optimal model selection
127 | parser.add_argument(
128 | "--step_ratios",
129 | nargs="+",
130 | default=[0.1, 0.3, 0.5, 0.75],
131 | help="ratios used to control adaptation step length",
132 | type=float,
133 | )
134 | parser.add_argument("--step_ratio", default=None, type=float)
135 | # time-varying
136 | parser.add_argument("--stochastic_restore_model", default=False, type=str2bool)
137 | parser.add_argument("--restore_prob", default=0.01, type=float)
138 | parser.add_argument("--fishers", default=False, type=str2bool)
139 | parser.add_argument(
140 | "--fisher_size",
141 | default=5000,
142 | type=int,
143 | help="number of samples to compute fisher information matrix.",
144 | )
145 | parser.add_argument(
146 | "--fisher_alpha",
147 | type=float,
148 | default=1.5,
149 | help="the trade-off between entropy and regularization loss",
150 | )
151 | # method-wise hparams
152 | parser.add_argument(
153 | "--aug_size",
154 | default=32,
155 | help="number of per-image augmentation operations in memo and ttt",
156 | type=int,
157 | )
158 | parser.add_argument(
159 | "--entry_of_shared_layers",
160 | default=None,
161 | help="the split position of auxiliary head. Only used in TTT.",
162 | )
163 | # metrics
164 | parser.add_argument(
165 | "--record_preadapted_perf",
166 | default=False,
167 |         help="record performance on the current batch before applying test-time adaptation.",
168 | type=str2bool,
169 | )
170 | # misc
171 | parser.add_argument(
172 | "--grad_checkpoint",
173 | default=False,
174 |         help="Trade computation for GPU memory.",
175 | type=str2bool,
176 | )
177 | parser.add_argument("--debug", default=False, help="Display logs.", type=str2bool)
178 |
179 | # parse conf.
180 | conf = parser.parse_args()
181 | return conf
182 |
183 |
184 | def str2bool(v):
185 | if v.lower() in ("yes", "true", "t", "y", "1"):
186 | return True
187 | elif v.lower() in ("no", "false", "f", "n", "0"):
188 | return False
189 | else:
190 | raise ValueError("Boolean value expected.")
191 |
192 |
193 | if __name__ == "__main__":
194 | args = get_args()
195 |
--------------------------------------------------------------------------------
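
An illustrative invocation of `get_args`, mirroring the flag values used by `exps/cifar10c.py`; patching `sys.argv` here merely stands in for a real command line:

```python
import sys

from parameters import get_args

sys.argv = [
    "run_exp.py",
    "--model_adaptation_method", "tent",
    "--model_selection_method", "oracle_model_selection",
    "--data_names", "cifar10_c_deterministic-gaussian_noise-5",
    "--episodic", "true",
    "--batch_size", "64",
]
conf = get_args()
assert conf.episodic is True  # str2bool coerced the flag.
```
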
/pretrain/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/pretrain/third_party/__init__.py
--------------------------------------------------------------------------------
/pretrain/third_party/augmentations.py:
--------------------------------------------------------------------------------
1 | # https://github.com/google-research/augmix/blob/master/augmentations.py
2 | """Base augmentations operators."""
3 |
4 | import numpy as np
5 | from PIL import Image, ImageOps, ImageEnhance
6 |
7 | # ImageNet code should change this value
8 | IMAGE_SIZE = 32
9 |
10 |
11 | def int_parameter(level, maxval):
12 |     """Helper function to scale `level` between 0 and maxval.
13 |
14 | Args:
15 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
16 | maxval: Maximum value that the operation can have. This will be scaled to
17 | level/PARAMETER_MAX.
18 |
19 | Returns:
20 | An int that results from scaling `maxval` according to `level`.
21 | """
22 | return int(level * maxval / 10)
23 |
24 |
25 | def float_parameter(level, maxval):
26 | """Helper function to scale `val` between 0 and maxval.
27 |
28 | Args:
29 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
30 | maxval: Maximum value that the operation can have. This will be scaled to
31 | level/PARAMETER_MAX.
32 |
33 | Returns:
34 | A float that results from scaling `maxval` according to `level`.
35 | """
36 | return float(level) * maxval / 10.0
37 |
38 |
39 | def sample_level(n):
40 | return np.random.uniform(low=0.1, high=n)
41 |
42 |
43 | def autocontrast(pil_img, _):
44 | return ImageOps.autocontrast(pil_img)
45 |
46 |
47 | def equalize(pil_img, _):
48 | return ImageOps.equalize(pil_img)
49 |
50 |
51 | def posterize(pil_img, level):
52 | level = int_parameter(sample_level(level), 4)
53 | return ImageOps.posterize(pil_img, 4 - level)
54 |
55 |
56 | def rotate(pil_img, level):
57 | degrees = int_parameter(sample_level(level), 30)
58 | if np.random.uniform() > 0.5:
59 | degrees = -degrees
60 | return pil_img.rotate(degrees, resample=Image.BILINEAR)
61 |
62 |
63 | def solarize(pil_img, level):
64 | level = int_parameter(sample_level(level), 256)
65 | return ImageOps.solarize(pil_img, 256 - level)
66 |
67 |
68 | def shear_x(pil_img, level):
69 | level = float_parameter(sample_level(level), 0.3)
70 | if np.random.uniform() > 0.5:
71 | level = -level
72 | return pil_img.transform(
73 | (IMAGE_SIZE, IMAGE_SIZE),
74 | Image.AFFINE,
75 | (1, level, 0, 0, 1, 0),
76 | resample=Image.BILINEAR,
77 | )
78 |
79 |
80 | def shear_y(pil_img, level):
81 | level = float_parameter(sample_level(level), 0.3)
82 | if np.random.uniform() > 0.5:
83 | level = -level
84 | return pil_img.transform(
85 | (IMAGE_SIZE, IMAGE_SIZE),
86 | Image.AFFINE,
87 | (1, 0, 0, level, 1, 0),
88 | resample=Image.BILINEAR,
89 | )
90 |
91 |
92 | def translate_x(pil_img, level):
93 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
94 | if np.random.random() > 0.5:
95 | level = -level
96 | return pil_img.transform(
97 | (IMAGE_SIZE, IMAGE_SIZE),
98 | Image.AFFINE,
99 | (1, 0, level, 0, 1, 0),
100 | resample=Image.BILINEAR,
101 | )
102 |
103 |
104 | def translate_y(pil_img, level):
105 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3)
106 | if np.random.random() > 0.5:
107 | level = -level
108 | return pil_img.transform(
109 | (IMAGE_SIZE, IMAGE_SIZE),
110 | Image.AFFINE,
111 | (1, 0, 0, 0, 1, level),
112 | resample=Image.BILINEAR,
113 | )
114 |
115 |
116 | # operation that overlaps with ImageNet-C's test set
117 | def color(pil_img, level):
118 | level = float_parameter(sample_level(level), 1.8) + 0.1
119 | return ImageEnhance.Color(pil_img).enhance(level)
120 |
121 |
122 | # operation that overlaps with ImageNet-C's test set
123 | def contrast(pil_img, level):
124 | level = float_parameter(sample_level(level), 1.8) + 0.1
125 | return ImageEnhance.Contrast(pil_img).enhance(level)
126 |
127 |
128 | # operation that overlaps with ImageNet-C's test set
129 | def brightness(pil_img, level):
130 | level = float_parameter(sample_level(level), 1.8) + 0.1
131 | return ImageEnhance.Brightness(pil_img).enhance(level)
132 |
133 |
134 | # operation that overlaps with ImageNet-C's test set
135 | def sharpness(pil_img, level):
136 | level = float_parameter(sample_level(level), 1.8) + 0.1
137 | return ImageEnhance.Sharpness(pil_img).enhance(level)
138 |
139 |
140 | augmentations = [
141 | autocontrast,
142 | equalize,
143 | posterize,
144 | rotate,
145 | solarize,
146 | shear_x,
147 | shear_y,
148 | translate_x,
149 | translate_y,
150 | ]
151 |
152 | augmentations_all = [
153 | autocontrast,
154 | equalize,
155 | posterize,
156 | rotate,
157 | solarize,
158 | shear_x,
159 | shear_y,
160 | translate_x,
161 | translate_y,
162 | color,
163 | contrast,
164 | brightness,
165 | sharpness,
166 | ]
167 |
--------------------------------------------------------------------------------
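
A minimal demo of the operation lists above; the flat gray test image is illustrative. Each op takes a PIL image plus a severity `level`, which `sample_level` jitters before scaling:

```python
import numpy as np
from PIL import Image

from pretrain.third_party.augmentations import augmentations

img = Image.new("RGB", (32, 32), color=(128, 128, 128))  # matches IMAGE_SIZE = 32
op = np.random.choice(augmentations)  # pick one operation at random
augmented = op(img, 3)  # severity 3; geometric ops resample back to 32x32
```
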
/pretrain/third_party/datasets.py:
--------------------------------------------------------------------------------
1 | # A collection of datasets class used in pretraining models.
2 |
3 | # imports.
4 | import os
5 | import functools
6 | import torch
7 | import sys
8 | sys.path.append("..")
9 |
10 | from ttab.loads.datasets.dataset_shifts import NoShiftedData, SyntheticShiftedData
11 | from ttab.loads.datasets.mnist import ColoredSyntheticShift
12 | from ttab.loads.datasets.loaders import BaseLoader
13 | from ttab.loads.datasets.datasets import OfficeHomeDataset, PACSDataset, CIFARDataset, WBirdsDataset, ColoredMNIST
14 |
15 | def get_train_dataset(config):
16 |     """Build the training and validation `BaseLoader`s from `config`."""
17 | data_shift_class = functools.partial(NoShiftedData, data_name=config.data_name)
18 | if "cifar" in config.data_name:
19 | train_dataset = CIFARDataset(
20 | root=os.path.join(config.data_path, config.data_name),
21 | data_name=config.data_name,
22 | split="train",
23 | device=config.device,
24 | data_augment=True,
25 | data_shift_class=data_shift_class,
26 | )
27 | val_dataset = CIFARDataset(
28 | root=os.path.join(config.data_path, config.data_name),
29 | data_name=config.data_name,
30 | split="test",
31 | device=config.device,
32 | data_augment=False,
33 | data_shift_class=data_shift_class,
34 | )
35 | elif "officehome" in config.data_name:
36 | _data_names = config.data_name.split("_")
37 | dataset = OfficeHomeDataset(
38 | root=os.path.join(config.data_path, _data_names[0], _data_names[1]),
39 | device=config.device,
40 | data_augment=True,
41 | data_shift_class=data_shift_class,
42 | ).split_data(fractions=[0.9, 0.1], augment=[True, False], seed=config.seed)
43 | train_dataset, val_dataset = dataset[0], dataset[1]
44 | elif "pacs" in config.data_name:
45 | _data_names = config.data_name.split("_")
46 | dataset = PACSDataset(
47 | root=os.path.join(config.data_path, _data_names[0], _data_names[1]),
48 | device=config.device,
49 | data_augment=True,
50 | data_shift_class=data_shift_class,
51 | ).split_data(fractions=[0.9, 0.1], augment=[True, False], seed=config.seed)
52 | train_dataset, val_dataset = dataset[0], dataset[1]
53 | elif config.data_name == "waterbirds":
54 | train_dataset = WBirdsDataset(
55 | root=os.path.join(config.data_path, config.data_name),
56 | split="train",
57 | device=config.device,
58 | data_augment=True,
59 | data_shift_class=data_shift_class,
60 | )
61 | val_dataset = WBirdsDataset(
62 | root=os.path.join(config.data_path, config.data_name),
63 | split="val",
64 | device=config.device,
65 | data_augment=False,
66 | )
67 | elif config.data_name == "coloredmnist":
68 | data_shift_class = functools.partial(
69 | SyntheticShiftedData,
70 | data_name=config.data_name,
71 | seed=config.seed,
72 | synthetic_class=ColoredSyntheticShift(
73 | data_name=config.data_name, seed=config.seed
74 | ),
75 | version="stochastic",
76 | )
77 | train_dataset = ColoredMNIST(
78 | root=os.path.join(config.data_path, "mnist"),
79 | data_name=config.data_name,
80 | split="train",
81 | device=config.device,
82 | data_shift_class=data_shift_class,
83 | )
84 | val_dataset = ColoredMNIST(
85 | root=os.path.join(config.data_path, "mnist"),
86 | data_name=config.data_name,
87 | split="val",
88 | device=config.device,
89 | data_shift_class=data_shift_class,
90 | )
91 | else:
92 | raise RuntimeError(f"Unknown dataset: {config.data_name}")
93 |
94 | return BaseLoader(train_dataset), BaseLoader(val_dataset)
95 |
96 |
97 | class AugMixDataset(torch.utils.data.Dataset):
98 | """Dataset wrapper to perform AugMix augmentation."""
99 |
100 | def __init__(self, dataset, preprocess, aug, no_jsd=False):
101 | self.dataset = dataset
102 | self.preprocess = preprocess
103 | self.no_jsd = no_jsd
104 | self.aug = aug
105 |
106 | def __getitem__(self, i):
107 | x, y = self.dataset[i]
108 | if self.no_jsd:
109 | return self.aug(x, self.preprocess), y
110 | else:
111 | im_tuple = (
112 | self.preprocess(x),
113 | self.aug(x, self.preprocess),
114 | self.aug(x, self.preprocess),
115 | )
116 | return im_tuple, y
117 |
118 | def __len__(self):
119 | return len(self.dataset)
120 |
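121 | # Illustrative sketch (not part of the original file): wrapping a train set for
122 | # JSD-style AugMix training. `train_set`, `preprocess`, and `aug` are assumed to
123 | # be provided by the caller.
124 | # train_loader = torch.utils.data.DataLoader(
125 | #     AugMixDataset(train_set, preprocess, aug, no_jsd=False),
126 | #     batch_size=128,
127 | #     shuffle=True,
128 | # )
129 | # for (x_clean, x_aug1, x_aug2), y in train_loader:
130 | #     pass  # compute the JSD consistency loss across the three views.
131 |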
--------------------------------------------------------------------------------
/pretrain/third_party/utils.py:
--------------------------------------------------------------------------------
1 | # utility functions.
2 | import copy
3 | import numpy as np
4 | import sys
5 | sys.path.append("..")
6 |
7 | import torch
8 | import torch.nn as nn
9 | import torch.nn.functional as F
10 | import torchvision.models as models
11 |
12 | import ttab.model_adaptation.utils as adaptation_utils
13 | from ttab.loads.models import resnet, WideResNet
14 | from ttab.configs.datasets import dataset_defaults
15 |
16 |
17 | def convert_iabn(module: nn.Module, config, **kwargs) -> nn.Module:
18 | module_output = module
19 | if isinstance(module, nn.BatchNorm2d) or isinstance(module, nn.BatchNorm1d):
20 | IABN = (
21 | adaptation_utils.InstanceAwareBatchNorm2d
22 | if isinstance(module, nn.BatchNorm2d)
23 | else adaptation_utils.InstanceAwareBatchNorm1d
24 | )
25 | module_output = IABN(
26 | num_channels=module.num_features,
27 | k=config.iabn_k,
28 | eps=module.eps,
29 | momentum=module.momentum,
30 | threshold=config.threshold_note,
31 | affine=module.affine,
32 | )
33 |
34 | module_output._bn = copy.deepcopy(module)
35 |
36 | for name, child in module.named_children():
37 | module_output.add_module(name, convert_iabn(child, config, **kwargs))
38 | del module
39 | return module_output
40 |
41 | def update_pytorch_model_key(state_dict: dict) -> dict:
42 | """This function is used to modify the state dict key of pretrained model from pytorch."""
43 | new_state_dict = {}
44 | for key, value in state_dict.items():
45 | if "downsample" in key:
46 | name_split = key.split(".")
47 | if name_split[-2] == "0":
48 | name_split[-2] = "conv"
49 | new_key = ".".join(name_split)
50 | new_state_dict[new_key] = value
51 | elif name_split[-2] == "1":
52 | name_split[-2] = "bn"
53 | new_key = ".".join(name_split)
54 | new_state_dict[new_key] = value
55 | elif "fc" in key:
56 | name_split = key.split(".")
57 | if name_split[0] == "fc":
58 | name_split[0] = "classifier"
59 | new_key = ".".join(name_split)
60 | new_state_dict[new_key] = value
61 | else:
62 | new_state_dict[key] = value
63 |
64 | return new_state_dict
65 |
66 | def build_model(config) -> nn.Module:
67 | """Build model from `config`"""
68 | num_classes = dataset_defaults[config.base_data_name]["statistics"]["n_classes"]
69 | if config.base_data_name in ["officehome", "pacs", "waterbirds"]:
70 | pretrained_model = models.resnet50(pretrained=True)
71 | pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, out_features=num_classes, bias=False)
72 | pretrained_model.fc.weight.data.normal_(mean=0, std=0.01)
73 |
74 | model = resnet(config.base_data_name, depth=50, split_point=config.entry_of_shared_layers, grad_checkpoint=True).to(config.device)
75 | model_dict = model.state_dict()
76 | pretrained_dict = pretrained_model.state_dict()
77 | new_pretrained_dict = update_pytorch_model_key(pretrained_dict)
78 | new_pretrained_dict = {k: v for k, v in new_pretrained_dict.items() if k in model_dict}
79 | model.load_state_dict(new_pretrained_dict, strict=False)  # the filtered dict may not cover every key.
80 | model.to(config.device)
81 | del pretrained_dict, pretrained_model, new_pretrained_dict
82 | if config.use_iabn:
83 | model = convert_iabn(model, config)
84 | else:
85 | if "wideresnet" in config.model_name:
86 | components = config.model_name.split("_")
87 | depth = int(components[0].replace("wideresnet", ""))
88 | widen_factor = int(components[1])
89 |
90 | model = WideResNet(
91 | depth,
92 | widen_factor,
93 | num_classes,
94 | split_point=config.entry_of_shared_layers,
95 | dropout_rate=0.3,
96 | )
97 | elif "resnet" in config.model_name:
98 | depth = int(config.model_name.replace("resnet", ""))
99 | model = resnet(
100 | config.base_data_name,
101 | depth,
102 | split_point=config.entry_of_shared_layers,
103 | group_norm_num_groups=config.group_norm,
104 | ).to(config.device)
105 | if config.use_iabn:
106 | assert config.group_norm is None, "IABN cannot be used with group norm."
107 | model = convert_iabn(model, config)
108 | return model
109 |
110 | def get_train_params(model: nn.Module, config) -> list:
111 | """Define the trainable parameters for a model using `config`"""
112 | if config.base_data_name in ["officehome", "pacs"]:
113 | params = []
114 | learning_rate = config.lr
115 |
116 | for name_module, module in model.main_model.named_children():
117 | if name_module != "classifier":
118 | for _, param in module.named_parameters():
119 | params += [{"params": param, "lr": learning_rate*0.1}]
120 | else:
121 | for _, param in module.named_parameters():
122 | params += [{"params": param, "lr": learning_rate}]
123 |
124 | for name_module, module in model.ssh.head.named_children():
125 | if isinstance(module, nn.Linear):
126 | for _, param in module.named_parameters():
127 | params += [{"params": param, "lr": learning_rate}]
128 | else:
129 | for _, param in module.named_parameters():
130 | params += [{"params": param, "lr": learning_rate*0.1}]
131 | else:
132 | params = list(model.main_model.parameters()) + list(model.ssh.head.parameters())
133 | return params
134 |
135 | def get_lr(step, total_steps, lr_max, lr_min):
136 | """Compute learning rate according to cosine annealing schedule."""
137 | return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi))
138 |
139 |
140 | def mixup_data(x, y, alpha=1.0, device=None):
141 | """Returns mixed inputs, pairs of targets, and lambda"""
142 | if alpha > 0:
143 | lam = np.random.beta(alpha, alpha)
144 | else:
145 | lam = 1
146 |
147 | batch_size = x.size()[0]
148 | if device is None:
149 | index = torch.randperm(batch_size)
150 | else:
151 | index = torch.randperm(batch_size).to(device)
152 |
153 | mixed_x = lam * x + (1 - lam) * x[index, :]
154 | y_a, y_b = y, y[index]
155 | return mixed_x, y_a, y_b, lam
156 |
157 |
158 | def mixup_criterion(criterion, pred, y_a, y_b, lam):
159 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)
160 |
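161 | # Illustrative sketch (not part of the original file): one mixup training step
162 | # on random tensors, using the two helpers above. `model` is assumed.
163 | # criterion = nn.CrossEntropyLoss()
164 | # x, y = torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,))
165 | # mixed_x, y_a, y_b, lam = mixup_data(x, y, alpha=1.0)
166 | # loss = mixup_criterion(criterion, model(mixed_x), y_a, y_b, lam)
167 |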
--------------------------------------------------------------------------------
/run_exp.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import parameters
3 | import ttab.configs.utils as configs_utils
4 | import ttab.loads.define_dataset as define_dataset
5 | from ttab.benchmark import Benchmark
6 | from ttab.loads.define_model import define_model, load_pretrained_model
7 | from ttab.model_adaptation import get_model_adaptation_method
8 | from ttab.model_selection import get_model_selection_method
9 |
10 |
11 | def main(init_config):
12 | # Required arguments.
13 | config, scenario = configs_utils.config_hparams(config=init_config)
14 |
15 | test_data_cls = define_dataset.ConstructTestDataset(config=config)
16 | test_loader = test_data_cls.construct_test_loader(scenario=scenario)
17 |
18 | # Model.
19 | model = define_model(config=config)
20 | load_pretrained_model(config=config, model=model)
21 |
22 | # Algorithms.
23 | model_adaptation_cls = get_model_adaptation_method(
24 | adaptation_name=scenario.model_adaptation_method
25 | )(meta_conf=config, model=model)
26 | model_selection_cls = get_model_selection_method(selection_name=scenario.model_selection_method)(
27 | meta_conf=config, model_adaptation_method=model_adaptation_cls
28 | )
29 |
30 | # Evaluate.
31 | benchmark = Benchmark(
32 | scenario=scenario,
33 | model_adaptation_cls=model_adaptation_cls,
34 | model_selection_cls=model_selection_cls,
35 | test_loader=test_loader,
36 | meta_conf=config,
37 | )
38 | benchmark.eval()
39 |
40 |
41 | if __name__ == "__main__":
42 | conf = parameters.get_args()
43 | main(init_config=conf)
44 |
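45 | # Example invocation (a sketch; the exact flag names are defined in
46 | # parameters.py, so treat the ones below as illustrative):
47 | # python run_exp.py --model_adaptation_method tent \
48 | #     --model_selection_method last_iterate --base_data_name cifar10 --seed 2022
49 |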
--------------------------------------------------------------------------------
/run_extract.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import argparse
4 |
5 | import monitor.tools.file_io as file_io
6 | from monitor.tools.show_results import load_raw_info_from_experiments
7 | from parameters import str2bool
8 |
9 | """parse and define arguments for different tasks."""
10 |
11 |
12 | def get_args():
13 | # define the parser.
14 | parser = argparse.ArgumentParser(description="Extract results.")
15 |
16 | # add arguments.
17 | parser.add_argument("--in_dir", type=str)
18 | parser.add_argument("--out_name", type=str, default="summary.pickle")
19 |
20 | # parse args.
21 | args = parser.parse_args()
22 |
23 | # an argument safety check.
24 | check_args(args)
25 | return args
26 |
27 |
28 | def check_args(args):
29 | assert args.in_dir is not None
30 |
31 | # define out path.
32 | args.out_path = os.path.join(args.in_dir, args.out_name)
33 |
34 |
35 | """write the results to path."""
36 |
37 |
38 | def main(args):
39 | # save the parsed results to path.
40 | file_io.write_pickle(
41 | load_raw_info_from_experiments(args.in_dir),
42 | args.out_path,
43 | )
44 |
45 |
46 | if __name__ == "__main__":
47 | args = get_args()
48 |
49 | main(args)
50 |
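51 | # Example invocation: collect raw results under a log directory into a single
52 | # pickle file (`logs/my_exp` is an illustrative path):
53 | # python run_extract.py --in_dir logs/my_exp --out_name summary.pickle
54 |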
--------------------------------------------------------------------------------
/ttab/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/__init__.py
--------------------------------------------------------------------------------
/ttab/api.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import itertools
3 | from typing import Any, Callable, Iterable, List, Mapping, Optional, Tuple, Union
4 |
5 | import numpy as np
6 | import torch
7 | from torch.utils.data.dataset import random_split
8 |
9 | State = List[torch.Tensor]
10 | Gradient = List[torch.Tensor]
11 | Parameters = List[torch.Tensor]
12 | Loss = float
13 | Quality = Mapping[str, float]
14 |
15 |
16 | class Batch(object):
17 | def __init__(self, x, y):
18 | self._x = x
19 | self._y = y
20 |
21 | def __len__(self) -> int:
22 | return len(self._x)
23 |
24 | def to(self, device) -> "Batch":
25 | return Batch(self._x.to(device), self._y.to(device))
26 |
27 | def __getitem__(self, index):
28 | return self._x[index], self._y[index]
29 |
30 |
31 | class GroupBatch(object):
32 | def __init__(self, x, y, g):
33 | self._x = x
34 | self._y = y
35 | self._g = g
36 |
37 | def __len__(self) -> int:
38 | return len(self._x)
39 |
40 | def to(self, device) -> "GroupBatch":
41 | return GroupBatch(self._x.to(device), self._y.to(device), self._g.to(device))
42 |
43 | def __getitem__(self, index):
44 | return self._x[index], self._y[index], self._g[index]
45 |
46 |
47 | class Dataset:
48 | def random_split(self, fractions: List[float]) -> List["Dataset"]:
49 | pass
50 |
51 | def iterator(
52 | self, batch_size: int, shuffle: bool, repeat=True
53 | ) -> Iterable[Tuple[int, float, Batch]]:
54 | pass
55 |
56 | def __len__(self) -> int:
57 | pass
58 |
59 |
60 | class PyTorchDataset(object):
61 | def __init__(
62 | self,
63 | dataset: torch.utils.data.Dataset,
64 | device: str,
65 | prepare_batch: Callable,
66 | num_classes: int,
67 | ):
68 | self._set = dataset
69 | self._device = device
70 | self._prepare_batch = prepare_batch
71 | self._num_classes = num_classes
72 |
73 | def __len__(self):
74 | return len(self._set)
75 |
76 | def replace_indices(
77 | self,
78 | indices_pattern: str = "original",
79 | new_indices: List[int] = None,
80 | random_seed: int = None,
81 | ) -> None:
82 | """Change the order of dataset indices in a particular pattern."""
83 | if indices_pattern == "original":
84 | pass
85 | elif indices_pattern == "random_shuffle":
86 | rng = np.random.default_rng(random_seed)
87 | rng.shuffle(self.dataset.indices)
88 | elif indices_pattern == "new":
89 | if new_indices is None:
90 | raise ValueError("new_indices should be specified.")
91 | self.dataset.update_indices(new_indices=new_indices)
92 | else:
93 | raise NotImplementedError
94 |
95 | def query_dataset_attr(self, attr_name: str) -> Any:
96 | return getattr(self._set, attr_name, None)
97 |
98 | @property
99 | def dataset(self):
100 | return self._set
101 |
102 | @property
103 | def num_classes(self):
104 | return self._num_classes
105 |
106 | def no_split(self) -> List[Dataset]:
107 | return [
108 | PyTorchDataset(
109 | dataset=self._set,
110 | device=self._device,
111 | prepare_batch=self._prepare_batch,
112 | num_classes=self._num_classes,
113 | )
114 | ]
115 |
116 | def random_split(self, fractions: List[float], seed: int = 0) -> List[Dataset]:
117 | lengths = [int(f * len(self._set)) for f in fractions]
118 | lengths[0] += len(self._set) - sum(lengths)
119 | return [
120 | PyTorchDataset(
121 | dataset=split,
122 | device=self._device,
123 | prepare_batch=self._prepare_batch,
124 | num_classes=self._num_classes,
125 | )
126 | for split in random_split(
127 | self._set, lengths, torch.Generator().manual_seed(seed)
128 | )
129 | ]
130 |
131 | def iterator(
132 | self,
133 | batch_size: int,
134 | shuffle: bool = True,
135 | repeat: bool = False,
136 | ref_num_data: Optional[int] = None,
137 | num_workers: int = 1,
138 | sampler: Optional[torch.utils.data.Sampler] = None,
139 | generator: Optional[torch.Generator] = None,
140 | pin_memory: bool = True,
141 | drop_last: bool = True,
142 | ) -> Iterable[Tuple[int, float, Batch]]:
143 | _num_batch = 1 if not drop_last else 0
144 | if ref_num_data is None:
145 | num_batches = int(len(self) / batch_size + _num_batch)
146 | else:
147 | num_batches = int(ref_num_data / batch_size + _num_batch)
148 | if sampler is not None:
149 | shuffle = False
150 |
151 | loader = torch.utils.data.DataLoader(
152 | self._set,
153 | batch_size=batch_size,
154 | shuffle=shuffle,
155 | pin_memory=pin_memory,
156 | drop_last=drop_last,
157 | num_workers=num_workers,
158 | sampler=sampler,
159 | generator=generator,
160 | )
161 |
162 | step = 0
163 | for _ in itertools.count() if repeat else [0]:
164 | for i, batch in enumerate(loader):
165 | step += 1
166 | epoch_fractional = float(step) / num_batches
167 | yield step, epoch_fractional, self._prepare_batch(batch, self._device)
168 |
169 | def record_class_distribution(
170 | self,
171 | targets: Union[List, np.ndarray],
172 | indices: Union[List, np.ndarray],
173 | print_fn: Callable = print,
174 | is_train: bool = True,
175 | display: bool = True,
176 | ):
177 | targets_np = np.array(targets)
178 | unique_elements, counts_elements = np.unique(
179 | targets_np[indices] if indices is not None else targets_np,
180 | return_counts=True,
181 | )
182 | element_counts = list(zip(unique_elements, counts_elements))
183 |
184 | if display:
185 | print_fn(
186 | f"\tThe histogram of the targets in {'train' if is_train else 'test'}: {element_counts}"
187 | )
188 | return element_counts
189 |
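190 | # Illustrative sketch (not part of the original file): splitting a wrapped
191 | # dataset 80/20 and iterating over the first split. `torch_set` and
192 | # `prepare_batch` are assumed to be supplied by the caller.
193 | # wrapped = PyTorchDataset(torch_set, device="cuda", prepare_batch=prepare_batch, num_classes=10)
194 | # train, val = wrapped.random_split(fractions=[0.8, 0.2], seed=0)
195 | # for step, epoch_fraction, batch in train.iterator(batch_size=64, shuffle=True):
196 | #     pass  # `batch` is a Batch already placed on the target device.
197 |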
--------------------------------------------------------------------------------
/ttab/configs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/configs/__init__.py
--------------------------------------------------------------------------------
/ttab/configs/algorithms.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # 1. This file collects significant hyperparameters for the configuration of TTA methods.
4 | # 2. We are only concerned about method-related hyperparameters here.
5 | # 3. We provide default hyperparameters from the paper or official repo in case users are unsure of reasonable values.
6 | import math
7 |
8 | algorithm_defaults = {
9 | "no_adaptation": {"model_selection_method": "last_iterate"},
10 | #
11 | "bn_adapt": {
12 | "adapt_prior": 0, # the ratio of training set statistics.
13 | },
14 | "shot": {
15 | "optimizer": "SGD", # Adam for officehome
16 | "auxiliary_batch_size": 32,
17 | "threshold_shot": 0.9, # confidence threshold for online shot.
18 | "ent_par": 1.0,
19 | "cls_par": 0.3, # 0.1 for officehome.
20 | "offline_nepoch": 10,
21 | },
22 | "ttt": {
23 | "optimizer": "SGD",
24 | "entry_of_shared_layers": "layer2",
25 | "aug_size": 32,
26 | "threshold_ttt": 1.0,
27 | "dim_out": 4, # For rotation prediction self-supervision task.
28 | "rotation_type": "rand",
29 | },
30 | "tent": {
31 | "optimizer": "SGD",
32 | },
33 | "t3a": {"top_M": 100},
34 | "cotta": {
35 | "optimizer": "SGD",
36 | "alpha_teacher": 0.999, # weight of moving average for updating the teacher model.
37 | "restore_prob": 0.01, # the probability of restoring model parameters.
38 | "threshold_cotta": 0.92, # Threshold choice discussed in supplementary
39 | "aug_size": 32,
40 | },
41 | "eata": {
42 | "optimizer": "SGD",
43 | "eata_margin_e0": math.log(1000)
44 | * 0.40, # The threshold for reliable minimization in EATA.
45 | "eata_margin_d0": 0.05, # for filtering redundant samples.
46 | "fishers": True, # whether to use fisher regularizer.
47 | "fisher_size": 2000, # number of samples to compute fisher information matrix.
48 | "fisher_alpha": 50, # the trade-off between entropy and regularization loss.
49 | },
50 | "memo": {
51 | "optimizer": "SGD",
52 | "aug_size": 32,
53 | "bn_prior_strength": 16,
54 | },
55 | # "ttt_plus_plus": {
56 | # "optimizer": "SGD",
57 | # "entry_of_shared_layers": None,
58 | # "batch_size_align": 256,
59 | # "queue_size": 256,
60 | # "offline_nepoch": 500,
61 | # "bnepoch": 2, # first few epochs to update bn stat.
62 | # "delayepoch": 0, # In first few epochs after bnepoch, we dont do both ssl and align (only ssl actually).
63 | # "stopepoch": 25,
64 | # "scale_ext": 0.5,
65 | # "scale_ssh": 0.2,
66 | # "align_ext": True,
67 | # "align_ssh": True,
68 | # "fix_ssh": False,
69 | # "method": "align", # choices = ['ssl', 'align', 'both']
70 | # "divergence": "all", # choices = ['all', 'coral', 'mmd']
71 | # },
72 | "note": {
73 | "optimizer": "SGD", # use Adam in the paper
74 | "memory_size": 64,
75 | "update_every_x": 64, # This param may change in our codebase.
76 | "memory_type": "PBRS",
77 | "bn_momentum": 0.01,
78 | "temperature": 1.0,
79 | "iabn": False, # replace bn with iabn layer
80 | "iabn_k": 4,
81 | "threshold_note": 1, # skip threshold to discard adjustment.
82 | "use_learned_stats": True,
83 | },
84 | "conjugate_pl": {
85 | "optimizer": "SGD",
86 | "temperature_scaling": 1.0,
87 | "model_eps": 0.0, # this should be added for Polyloss model.
88 | },
89 | "sar": {
90 | "optimizer": "SGD",
91 | "sar_margin_e0": math.log(1000)
92 | * 0.40, # The threshold for reliable minimization in SAR.
93 | "reset_constant_em": 0.2, # threshold e_m for model recovery scheme
94 | },
95 | "rotta":{
96 | "optimizer": "Adam",
97 | "nu": 0.001,
98 | "memory_size": 64,
99 | "update_frequency": 64,
100 | "lambda_t": 1.0,
101 | "lambda_u": 1.0,
102 | "alpha": 0.05,
103 | }
104 | }
105 |
--------------------------------------------------------------------------------
/ttab/configs/datasets.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # 1. This file collects hyperparameters significant for pretraining and test-time adaptation.
4 | # 2. We are only concerned about dataset-related hyperparameters here, e.g., lr, dataset statistics, and type of corruptions.
5 | # 3. We provide default hyperparameters in case users are unsure of reasonable values.
6 |
7 | dataset_defaults = {
8 | "cifar10": {
9 | "statistics": {
10 | "mean": (0.4914, 0.4822, 0.4465),
11 | "std": (0.2023, 0.1994, 0.2010),
12 | "n_classes": 10,
13 | },
14 | "version": "deterministic",
15 | "img_shape": (32, 32, 3),
16 | },
17 | "cifar100": {
18 | "statistics": {
19 | "mean": (0.5070751592371323, 0.48654887331495095, 0.4409178433670343),
20 | "std": (0.2673342858792401, 0.2564384629170883, 0.27615047132568404),
21 | "n_classes": 100,
22 | },
23 | "version": "deterministic",
24 | "img_shape": (32, 32, 3),
25 | },
26 | "officehome": {
27 | "statistics": {
28 | "mean": (0.485, 0.456, 0.406),
29 | "std": (0.229, 0.224, 0.225),
30 | "n_classes": 65,
31 | },
32 | "img_shape": (224, 224, 3),
33 | },
34 | "pacs": {
35 | "statistics": {
36 | "mean": (0.485, 0.456, 0.406),
37 | "std": (0.229, 0.224, 0.225),
38 | "n_classes": 7,
39 | },
40 | "img_shape": (224, 224, 3),
41 | },
42 | "coloredmnist": {
43 | "statistics": {
44 | "mean": (0.1307, 0.1307, 0.0),
45 | "std": (0.3081, 0.3081, 0.3081),
46 | "n_classes": 2,
47 | },
48 | "img_shape": (28, 28, 3),
49 | },
50 | "waterbirds": {
51 | "statistics": {
52 | "mean": (0.485, 0.456, 0.406),
53 | "std": (0.229, 0.224, 0.225),
54 | "n_classes": 2,
55 | },
56 | "group_counts": [3498, 184, 56, 1057], # used to compute group ratio.
57 | "img_shape": (224, 224, 3),
58 | },
59 | "imagenet": {
60 | "statistics": {
61 | "mean": (0.485, 0.456, 0.406),
62 | "std": (0.229, 0.224, 0.225),
63 | "n_classes": 1000,
64 | },
65 | "img_shape": (224, 224, 3),
66 | },
67 | "yearbook": {
68 | "statistics": {"n_classes": 2,},
69 | "img_shape": (32, 32, 3),
70 | }
71 | }
72 |
--------------------------------------------------------------------------------
/ttab/configs/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | # Functions in this file are used to deal with any work related with the configuration of datasets, models and algorithms.
4 | import ttab.scenarios.define_scenario as define_scenario
5 | from ttab.configs.algorithms import algorithm_defaults
6 | from ttab.configs.datasets import dataset_defaults
7 |
8 |
9 | def config_hparams(config):
10 | """
11 | Populates hyperparameters with defaults implied by choices of other hyperparameters.
12 |
13 | Args:
14 | - config: namespace
15 | Returns:
16 | - config: namespace
17 | - scenario: NamedTuple
18 | """
19 | # prior safety check
20 | assert (
21 | config.model_adaptation_method is not None
22 | ), "model adaptation method must be specified"
23 |
24 | assert (
25 | config.base_data_name is not None
26 | ), "base_data_name must be specified, either from default scenario, or from user-provided inputs."
27 |
28 | # register default arguments from default scenarios if specified by the user.
29 | scenario = define_scenario.get_scenario(config)
30 | config = define_scenario.scenario_registry(config, scenario)
31 |
32 | # register default arguments based on data_name.
33 | config = defaults_registry(config, template=dataset_defaults[config.base_data_name])
34 |
35 | # register default arguments based on model_adaptation_method
36 | config = defaults_registry(
37 | config, template=algorithm_defaults[config.model_adaptation_method]
38 | )
39 |
40 | # TODO: register default arguments for different kinds of base models
41 | # add default path for provided ckpts, otherwise provide a link to other sources.
42 |
43 | return config, scenario
44 |
45 |
46 | def defaults_registry(config, template: dict, display_compatibility=False):
47 | """
48 | Populates missing (key, val) pairs in config with (key, val) in template.
49 |
50 | Args:
51 | - config: namespace
52 | - template: dict
53 | - display_compatibility: option to raise errors if config.key != template[key]
54 | """
55 | if template is None:
56 | return config
57 |
58 | dict_config = vars(config)
59 | for key, val in template.items():
60 | if not isinstance(val, dict): # template[key] is non-index-able
61 | if key not in dict_config or dict_config[key] is None:
62 | dict_config[key] = val
63 | elif dict_config[key] != val and display_compatibility:
64 | raise ValueError(f"Argument {key} must be set to {val}")
65 |
66 | else:
67 | if key not in dict_config.keys():
68 | dict_config[key] = {}
69 | for kwargs_key, kwargs_val in val.items():
70 | if (
71 | kwargs_key not in dict_config[key]
72 | or dict_config[key][kwargs_key] is None
73 | ):
74 | dict_config[key][kwargs_key] = kwargs_val
75 | elif (
76 | dict_config[key][kwargs_key] != kwargs_val and display_compatibility
77 | ):
78 | raise ValueError(
79 | f"Argument {key}[{kwargs_key}] must be set to {kwargs_val}"
80 | )
81 | return config
82 |
83 |
84 | def build_dict_from_config(arg_names, config):
85 | """
86 | Build a dictionary from config based on arg_names.
87 |
88 | Args:
89 | - arg_names: list of strings
90 | - config: namespace
91 | Returns:
92 | - dict: dictionary
93 | """
94 | dict_config = vars(config)
95 | return dict(
96 | (arg_name, dict_config[arg_name])
97 | for arg_name in arg_names
98 | if (arg_name in dict_config) and (dict_config[arg_name] is not None)
99 | )
100 |
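101 | # Illustrative sketch (not part of the original file): filling a bare namespace
102 | # with the defaults of one TTA method via `defaults_registry`.
103 | # from argparse import Namespace
104 | # cfg = Namespace(model_adaptation_method="tent", optimizer=None)
105 | # cfg = defaults_registry(cfg, template=algorithm_defaults["tent"])
106 | # assert cfg.optimizer == "SGD"  # missing/None keys are populated from the template.
107 |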
--------------------------------------------------------------------------------
/ttab/loads/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/.DS_Store
--------------------------------------------------------------------------------
/ttab/loads/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/__init__.py
--------------------------------------------------------------------------------
/ttab/loads/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
--------------------------------------------------------------------------------
/ttab/loads/datasets/cifar/data_aug_cifar.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Callable, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | from PIL import Image, ImageOps
7 | from torchvision import transforms
8 | from ttab.configs.datasets import dataset_defaults
9 |
10 | """Base augmentations operators."""
11 |
12 | ## https://github.com/google-research/augmix
13 |
14 |
15 | def _augmix_aug_cifar(x_orig: torch.Tensor, data_name: str) -> torch.Tensor:
16 | input_size = x_orig.shape[-1]
17 | scale_size = 36 if input_size == 32 else 256 # input size is either 32 or 224
18 | padding = int((scale_size - input_size) / 2)
19 | tensor_to_image, preprocess = get_ops(data_name)
20 |
21 | x_orig = tensor_to_image(x_orig.squeeze(0))
22 | preaugment = transforms.Compose(
23 | [
24 | transforms.RandomCrop(input_size, padding=padding),
25 | transforms.RandomHorizontalFlip(),
26 | ]
27 | )
28 | x_orig = preaugment(x_orig)
29 | x_processed = preprocess(x_orig)
30 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0]))
31 | m = np.float32(np.random.beta(1.0, 1.0))
32 |
33 | mix = torch.zeros_like(x_processed)
34 | for i in range(3):
35 | x_aug = x_orig.copy()
36 | for _ in range(np.random.randint(1, 4)):
37 | x_aug = np.random.choice(augmentations)(x_aug, input_size)
38 | mix += w[i] * preprocess(x_aug)
39 | mix = m * x_processed + (1 - m) * mix
40 | return mix
41 |
42 |
43 | aug_cifar = _augmix_aug_cifar
44 |
45 |
46 | def autocontrast(pil_img, input_size, level=None):
47 | return ImageOps.autocontrast(pil_img)
48 |
49 |
50 | def equalize(pil_img, input_size, level=None):
51 | return ImageOps.equalize(pil_img)
52 |
53 |
54 | def rotate(pil_img, input_size, level):
55 | degrees = int_parameter(rand_lvl(level), 30)
56 | if np.random.uniform() > 0.5:
57 | degrees = -degrees
58 | return pil_img.rotate(degrees, resample=Image.BILINEAR, fillcolor=128)
59 |
60 |
61 | def solarize(pil_img, input_size, level):
62 | level = int_parameter(rand_lvl(level), 256)
63 | return ImageOps.solarize(pil_img, 256 - level)
64 |
65 |
66 | def shear_x(pil_img, input_size, level):
67 | level = float_parameter(rand_lvl(level), 0.3)
68 | if np.random.uniform() > 0.5:
69 | level = -level
70 | return pil_img.transform(
71 | (input_size, input_size),
72 | Image.AFFINE,
73 | (1, level, 0, 0, 1, 0),
74 | resample=Image.BILINEAR,
75 | fillcolor=128,
76 | )
77 |
78 |
79 | def shear_y(pil_img, input_size, level):
80 | level = float_parameter(rand_lvl(level), 0.3)
81 | if np.random.uniform() > 0.5:
82 | level = -level
83 | return pil_img.transform(
84 | (input_size, input_size),
85 | Image.AFFINE,
86 | (1, 0, 0, level, 1, 0),
87 | resample=Image.BILINEAR,
88 | fillcolor=128,
89 | )
90 |
91 |
92 | def translate_x(pil_img, input_size, level):
93 | level = int_parameter(rand_lvl(level), input_size / 3)
94 | if np.random.random() > 0.5:
95 | level = -level
96 | return pil_img.transform(
97 | (input_size, input_size),
98 | Image.AFFINE,
99 | (1, 0, level, 0, 1, 0),
100 | resample=Image.BILINEAR,
101 | fillcolor=128,
102 | )
103 |
104 |
105 | def translate_y(pil_img, input_size, level):
106 | level = int_parameter(rand_lvl(level), input_size / 3)
107 | if np.random.random() > 0.5:
108 | level = -level
109 | return pil_img.transform(
110 | (input_size, input_size),
111 | Image.AFFINE,
112 | (1, 0, 0, 0, 1, level),
113 | resample=Image.BILINEAR,
114 | fillcolor=128,
115 | )
116 |
117 |
118 | def posterize(pil_img, input_size, level):
119 | level = int_parameter(rand_lvl(level), 4)
120 | return ImageOps.posterize(pil_img, 4 - level)
121 |
122 |
123 | def int_parameter(level, maxval):
124 | """Helper function to scale `val` between 0 and maxval .
125 | Args:
126 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
127 | maxval: Maximum value that the operation can have. This will be scaled
128 | to level/PARAMETER_MAX.
129 | Returns:
130 | An int that results from scaling `maxval` according to `level`.
131 | """
132 | return int(level * maxval / 10)
133 |
134 |
135 | def float_parameter(level, maxval):
136 | """Helper function to scale `val` between 0 and maxval .
137 | Args:
138 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
139 | maxval: Maximum value that the operation can have. This will be scaled
140 | to level/PARAMETER_MAX.
141 | Returns:
142 | A float that results from scaling `maxval` according to `level`.
143 | """
144 | return float(level) * maxval / 10.0
145 |
146 |
147 | def rand_lvl(n):
148 | return np.random.uniform(low=0.1, high=n)
149 |
150 |
151 | augmentations = [
152 | autocontrast,
153 | equalize,
154 | lambda x, y: rotate(x, y, 1),
155 | lambda x, y: solarize(x, y, 1),
156 | lambda x, y: shear_x(x, y, 1),
157 | lambda x, y: shear_y(x, y, 1),
158 | lambda x, y: translate_x(x, y, 1),
159 | lambda x, y: translate_y(x, y, 1),
160 | lambda x, y: posterize(x, y, 1),
161 | ]
162 |
163 | def get_ops(data_name: str) -> Tuple[Callable, Callable]:
164 | """Get the operations to be applied when defining transforms."""
165 | unnormalize = transforms.Compose(
166 | [
167 | transforms.Normalize(
168 | mean=[0.0, 0.0, 0.0],
169 | std=[1.0 / v for v in dataset_defaults[data_name]["statistics"]["std"]],
170 | ),
171 | transforms.Normalize(
172 | mean=[-v for v in dataset_defaults[data_name]["statistics"]["mean"]],
173 | std=[1.0, 1.0, 1.0],
174 | ),
175 | ]
176 | )
177 |
178 | tensor_to_image = transforms.Compose([unnormalize, transforms.ToPILImage()])
179 | preprocess = transforms.Compose(
180 | [
181 | transforms.ToTensor(),
182 | transforms.Normalize(
183 | dataset_defaults[data_name]["statistics"]["mean"], dataset_defaults[data_name]["statistics"]["std"]
184 | ),
185 | ]
186 | )
187 | return tensor_to_image, preprocess
188 |
189 |
190 | def tr_transforms_cifar(image: torch.Tensor, data_name: str) -> torch.Tensor:
191 | """
192 | Data augmentation for input images.
193 | args:
194 | inputs:
195 | image: tensor [n_channel, H, W]
196 | outputs:
197 | augment_image: tensor [1, n_channel, H, W]
198 | """
199 | input_size = image.shape[-1]
200 | scale_size = 36 if input_size == 32 else 256 # input size is either 32 or 224
201 | padding = int(scale_size - input_size)
202 | tensor_to_image, preprocess = get_ops(data_name)
203 |
204 | image = tensor_to_image(image)
205 | preaugment = transforms.Compose(
206 | [
207 | transforms.RandomCrop(input_size, padding=padding),
208 | transforms.RandomHorizontalFlip(),
209 | ]
210 | )
211 | augment_image = preaugment(image)
212 | augment_image = preprocess(augment_image)
213 |
214 | return augment_image
215 |
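216 | # Illustrative sketch (not part of the original file): applying AugMix to one
217 | # CIFAR-10 tensor. The random tensor stands in for a real, already-normalized image.
218 | # x = torch.randn(3, 32, 32)
219 | # x_mixed = aug_cifar(x, data_name="cifar10")  # tensor of shape [3, 32, 32]
220 |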
--------------------------------------------------------------------------------
/ttab/loads/datasets/dataset_sampling.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import random
3 |
4 | from ttab.api import PyTorchDataset
5 | from ttab.scenarios import TestDomain
6 |
7 |
8 | class DatasetSampling(object):
9 | def __init__(self, test_domain: TestDomain):
10 | self.domain_sampling_name = test_domain.domain_sampling_name
11 | self.domain_sampling_value = test_domain.domain_sampling_value
12 | self.domain_sampling_ratio = test_domain.domain_sampling_ratio
13 |
14 | def sample(
15 | self, dataset: PyTorchDataset, random_seed: int = None
16 | ) -> PyTorchDataset:
17 | if self.domain_sampling_name == "uniform":
18 | return self.uniform_sample(
19 | dataset=dataset,
20 | ratio=self.domain_sampling_ratio,
21 | random_seed=random_seed,
22 | )
23 | else:
24 | raise NotImplementedError
25 |
26 | @staticmethod
27 | def uniform_sample(
28 | dataset: PyTorchDataset, ratio: float, random_seed: int = None
29 | ) -> PyTorchDataset:
30 | """This function uniformly samples data from the original dataset without replacement."""
31 | random.seed(random_seed)
32 | indices = dataset.query_dataset_attr("indices")
33 | sampled_list = random.sample(
34 | indices,
35 | int(ratio * len(indices)),
36 | )
37 | sampled_list.sort()
38 | dataset.replace_indices(
39 | indices_pattern="new", new_indices=sampled_list, random_seed=random_seed
40 | )
41 | return dataset
42 |
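43 | # Illustrative sketch (not part of the original file): keeping a random 10% of a
44 | # test domain. `test_dataset` is assumed to be a PyTorchDataset whose underlying
45 | # set exposes `indices`/`update_indices` (as the shifted datasets do).
46 | # subset = DatasetSampling.uniform_sample(test_dataset, ratio=0.1, random_seed=2022)
47 |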
--------------------------------------------------------------------------------
/ttab/loads/datasets/dataset_shifts.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Callable, List, NamedTuple, Optional
3 |
4 | import numpy as np
5 | import torch
6 | from PIL import Image
7 |
8 | # map data_name to a distribution shift.
9 | # The data_name takes one of three forms: 1) <base_data_name>, 2) <base_data_name>_<shift_name>, and 3) <base_data_name>_<shift_name>_<extra spec>.
10 |
11 | data2shift = dict(
12 | cifar10="no_shift",
13 | cifar100="no_shift",
14 | cifar10_c="synthetic",
15 | cifar100_c="synthetic",
16 | cifar10_1="natural",
17 | cifar10_shiftedlabel="natural",
18 | cifar100_shiftedlabel="natural",
19 | cifar10_temporal="temporal",
20 | imagenet="no_shift",
21 | imagenet_c="synthetic",
22 | imagenet_a="natural",
23 | imagenet_r="natural",
24 | # imagenet_v2_
25 | imagenet_v2="natural",
26 | officehome_art="natural",
27 | officehome_clipart="natural",
28 | officehome_product="natural",
29 | officehome_realworld="natural",
30 | pacs_art="natural",
31 | pacs_cartoon="natural",
32 | pacs_photo="natural",
33 | pacs_sketch="natural",
34 | mnist="no_shift",
35 | coloredmnist="synthetic",
36 | waterbirds="natural",
37 | yearbook="natural",
38 | )
39 |
40 |
41 | class SyntheticShiftProperty(NamedTuple):
42 | shift_degree: int
43 | shift_name: str
44 |
45 | version: str = "stochastic"
46 | has_shift: bool = True
47 |
48 |
49 | class NaturalShiftProperty(NamedTuple):
50 | version: str = None
51 | has_shift: bool = True
52 |
53 |
54 | class NoShiftProperty(NamedTuple):
55 | has_shift: bool = False
56 |
57 |
58 | class ShiftedData(object):
59 | def __init__(self, dataset: torch.utils.data.Dataset):
60 | self.dataset = dataset
61 |
62 | def __getitem__(self, index):
63 | return self.dataset[index]
64 |
65 | def __len__(self):
66 | return len(self.dataset)
67 |
68 | def update_indices(self, new_indices: List[int]) -> None:
69 | """Update the indices of the dataset after applying shift."""
70 | self.dataset.indices = new_indices
71 | self.dataset.data_size = len(self.indices)
72 |
73 | @property
74 | def data(self):
75 | return self.dataset.data
76 |
77 | @property
78 | def targets(self):
79 | return self.dataset.targets
80 |
81 | @property
82 | def indices(self):
83 | return self.dataset.indices
84 |
85 | @property
86 | def data_size(self):
87 | return self.dataset.data_size
88 |
89 | @property
90 | def classes(self):
91 | return self.dataset.classes
92 |
93 | @property
94 | def class_to_index(self):
95 | return self.dataset.class_to_index
96 |
97 | @property
98 | def group_array(self):
99 | return getattr(self.dataset, "group_array", None)
100 |
101 |
102 | class NoShiftedData(ShiftedData):
103 | """
104 | Dataset-like object, but only access a subset of it.
105 | And it applies NO shift to the data when __getitem__.
106 | """
107 |
108 | def __init__(self, data_name, dataset: torch.utils.data.Dataset):
109 | super().__init__(dataset)
110 | # initialize corruption class
111 | self.data_name = data_name
112 |
113 |
114 | class NaturalShiftedData(ShiftedData):
115 | """
116 | Dataset-like object, but only access a subset of it.
117 | And it will reload the data with natural shift. It will apply NO shift to the data when __getitem__.
118 | """
119 |
120 | def __init__(
121 | self,
122 | data_name,
123 | dataset: torch.utils.data.Dataset,
124 | new_data: torch.utils.data.Dataset,
125 | ):
126 | super().__init__(dataset)
127 | # replace original data/targets with new data/targets
128 | self.dataset.data = new_data.data
129 | self.dataset.targets = new_data.targets
130 | self.dataset.data_size = len(self.dataset.data)
131 | self.dataset.indices = list(range(self.dataset.data_size))
132 | self.data_name = data_name
133 |
134 |
135 | class SyntheticShiftedData(ShiftedData):
136 | """
137 | Dataset-like object, but only access a subset of it.
138 | And it applies corruptions to the data when __getitem__.
139 | """
140 |
141 | def __init__(
142 | self,
143 | data_name: str,
144 | dataset: torch.utils.data.Dataset,
145 | seed: int,
146 | synthetic_class: Callable,
147 | version: str,
148 | **kwargs,
149 | ):
150 | super().__init__(dataset)
151 | self.data_name = data_name
152 | self.version = version # either stochastic or deterministic
153 |
154 | # initialize corruption class
155 | if any([name in data_name for name in ["cifar", "imagenet"]]):
156 | self.synthetic_ops = synthetic_class(data_name, seed, kwargs["severity"])
157 | elif "mnist" in data_name:
158 | self.synthetic_ops = synthetic_class
159 | else:
160 | raise NotImplementedError(
161 | f"synthetic shift for {data_name} is not supported in TTAB."
162 | )
163 |
164 | def apply_corruption(self):
165 | """Apply corruption to the clean dataset."""
166 | corrupted_imgs = []
167 |
168 | for index in range(self.dataset.data_size):
169 | img_array = self.dataset.data[index]
170 | img = Image.fromarray(img_array)
171 | img = self.synthetic_ops.apply(img)
172 | corrupted_imgs.append(img)
173 |
174 | # replace data.
175 | self.dataset.data = np.stack(corrupted_imgs, axis=0)
176 |
177 | def prepare_colored_mnist(
178 | self,
179 | transform: Optional[Callable] = None,
180 | target_transform: Optional[Callable] = None,
181 | ):
182 | return self.synthetic_ops.apply(self.dataset, transform, target_transform)
183 |
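184 | # Illustrative sketch (not part of the original file): resolving the shift type
185 | # registered for a dataset name via the mapping above.
186 | # assert data2shift["cifar10_c"] == "synthetic"
187 | # assert data2shift["officehome_art"] == "natural"
188 |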
--------------------------------------------------------------------------------
/ttab/loads/datasets/imagenet/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import numpy as np
4 | import tarfile
5 |
6 | from torchvision.datasets.utils import download_url
7 | from torchvision.datasets import ImageFolder
8 |
9 |
10 | class ImageNetValNaturalShift(object):
11 | """Borrowed from
12 | (1) https://github.com/hendrycks/imagenet-r/,
13 | (2) https://github.com/hendrycks/natural-adv-examples,
14 | (3) https://github.com/modestyachts/ImageNetV2.
15 | """
16 |
17 | stats = {
18 | "imagenet_r": {
19 | "data_and_labels": "https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar",
20 | "folder_name": "imagenet-r",
21 | },
22 | "imagenet_a": {
23 | "data_and_labels": "https://people.eecs.berkeley.edu/~hendrycks/imagenet-a.tar",
24 | "folder_name": "imagenet-a",
25 | },
26 | "imagenet_v2_matched-frequency": {
27 | "data_and_labels": "https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz",
28 | "folder_name": "imagenetv2-matched-frequency-format-val",
29 | },
30 | "imagenet_v2_threshold0.7": {
31 | "data_and_labels": "https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-threshold0.7.tar.gz",
32 | "folder_name": "imagenetv2-threshold0.7-format-val",
33 | },
34 | "imagenet_v2_topimages": {
35 | "data_and_labels": "https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-top-images.tar.gz",
36 | "folder_name": "imagenetv2-topimages-format-val",
37 | },
38 | }
39 |
40 | def __init__(self, root, data_name, version=None):
41 | self.data_name = data_name
42 | self.path_data_and_labels_tar = os.path.join(
43 | root, self.stats[data_name]["data_and_labels"].split("/")[-1]
44 | )
45 | self.path_data_and_labels = os.path.join(
46 | root, self.stats[data_name]["folder_name"]
47 | )
48 |
49 | self._download(root)
50 |
51 | self.image_folder = ImageFolder(self.path_data_and_labels)
52 | self.data = self.image_folder.samples
53 | self.targets = self.image_folder.targets
54 |
55 | def _download(self, root):
56 | download_url(url=self.stats[self.data_name]["data_and_labels"], root=root)
57 |
58 | if self._check_integrity():
59 | print("Files already downloaded, verified, and uncompressed.")
60 | return
61 | self._uncompress(root)
62 |
63 | def _uncompress(self, root):
64 | with tarfile.open(self.path_data_and_labels_tar) as file:
65 | file.extractall(root)
66 |
67 | def _check_integrity(self) -> bool:
68 | if os.path.exists(self.path_data_and_labels):
69 | return True
70 | else:
71 | return False
72 |
73 | def __getitem__(self, index):
74 | path, target = self.data[index]
75 | img = self.image_folder.loader(path)
76 | return img, target
77 |
78 | def __len__(self):
79 | return len(self.data)
80 |
81 |
82 | """Some corruptions are referred to https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_imagenet_c.py"""
83 |
84 |
85 | class ImageNetSyntheticShift(object):
86 | """The class of synthetic corruptions/shifts introduced in ImageNet_C."""
87 |
88 | def __init__(
89 | self, data_name, seed, severity=5, corruption_type=None, img_resolution=224
90 | ):
91 | assert "imagenet" in data_name
92 |
93 | if img_resolution == 224:
94 | from ttab.loads.datasets.imagenet.synthetic_224 import (
95 | gaussian_noise,
96 | shot_noise,
97 | impulse_noise,
98 | defocus_blur,
99 | glass_blur,
100 | motion_blur,
101 | zoom_blur,
102 | snow,
103 | frost,
104 | fog,
105 | brightness,
106 | contrast,
107 | elastic_transform,
108 | pixelate,
109 | jpeg_compression,
110 | # for validation.
111 | speckle_noise,
112 | gaussian_blur,
113 | spatter,
114 | saturate,
115 | )
116 | elif img_resolution == 64:
117 | from ttab.loads.datasets.imagenet.synthetic_64 import (
118 | gaussian_noise,
119 | shot_noise,
120 | impulse_noise,
121 | defocus_blur,
122 | glass_blur,
123 | motion_blur,
124 | zoom_blur,
125 | snow,
126 | frost,
127 | fog,
128 | brightness,
129 | contrast,
130 | elastic_transform,
131 | pixelate,
132 | jpeg_compression,
133 | # for validation.
134 | speckle_noise,
135 | gaussian_blur,
136 | spatter,
137 | saturate,
138 | )
139 | else:
140 | raise NotImplementedError(
141 | f"Invalid img_resolution for ImageNet: {img_resolution}"
142 | )
143 |
144 | self.data_name = data_name
145 | self.base_data_name = data_name.split("_")[0]
146 | self.seed = seed
147 | self.severity = severity
148 | self.corruption_type = corruption_type
149 | self.dict_corruption = {
150 | "gaussian_noise": gaussian_noise,
151 | "shot_noise": shot_noise,
152 | "impulse_noise": impulse_noise,
153 | "defocus_blur": defocus_blur,
154 | "glass_blur": glass_blur,
155 | "motion_blur": motion_blur,
156 | "zoom_blur": zoom_blur,
157 | "snow": snow,
158 | "frost": frost,
159 | "fog": fog,
160 | "brightness": brightness,
161 | "contrast": contrast,
162 | "elastic_transform": elastic_transform,
163 | "pixelate": pixelate,
164 | "jpeg_compression": jpeg_compression,
165 | "speckle_noise": speckle_noise,
166 | "gaussian_blur": gaussian_blur,
167 | "spatter": spatter,
168 | "saturate": saturate,
169 | }
170 | if corruption_type is not None:
171 | assert (
172 | corruption_type in self.dict_corruption.keys()
173 | ), f"{corruption_type} is out of range"
174 | self.random_state = np.random.RandomState(self.seed)
175 |
176 | def _apply_corruption(self, pil_img):
177 | if self.corruption_type is None or self.corruption_type == "all":
178 | corruption = self.random_state.choice(list(self.dict_corruption.values()))
179 | else:
180 | corruption = self.dict_corruption[self.corruption_type]
181 |
182 | return np.uint8(
183 | corruption(pil_img, random_state=self.random_state, severity=self.severity)
184 | )
185 |
186 | def apply(self, pil_img):
187 | return self._apply_corruption(pil_img)
188 |
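189 | # Illustrative sketch (not part of the original file): corrupting a single PIL
190 | # image with a fixed corruption type. `img` is assumed to be a 224x224 PIL image.
191 | # ops = ImageNetSyntheticShift("imagenet_c", seed=2022, severity=3, corruption_type="fog")
192 | # corrupted = ops.apply(img)  # returns a uint8 numpy array.
193 |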
--------------------------------------------------------------------------------
/ttab/loads/datasets/imagenet/data_aug_imagenet.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple
2 |
3 | import numpy as np
4 | import torch
5 | from PIL import Image, ImageOps
6 | from torchvision import transforms
7 | from ttab.configs.datasets import dataset_defaults
8 |
9 | ## https://github.com/google-research/augmix
10 |
11 |
12 | def _augmix_aug(x_orig: torch.Tensor, data_name: str) -> torch.Tensor:
13 | tensor_to_image, preprocess = get_ops(data_name)
14 | x_orig = tensor_to_image(x_orig.squeeze(0))
15 | preaugment = transforms.Compose(
16 | [
17 | transforms.RandomResizedCrop(224),
18 | transforms.RandomHorizontalFlip(),
19 | ]
20 | )
21 | x_orig = preaugment(x_orig)
22 | x_processed = preprocess(x_orig)
23 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0]))
24 | m = np.float32(np.random.beta(1.0, 1.0))
25 |
26 | mix = torch.zeros_like(x_processed)
27 | for i in range(3):
28 | x_aug = x_orig.copy()
29 | for _ in range(np.random.randint(1, 4)):
30 | x_aug = np.random.choice(augmentations)(x_aug)
31 | mix += w[i] * preprocess(x_aug)
32 | mix = m * x_processed + (1 - m) * mix
33 | return mix
34 |
35 |
36 | aug_imagenet = _augmix_aug
37 |
38 |
39 | def autocontrast(pil_img, level=None):
40 | return ImageOps.autocontrast(pil_img)
41 |
42 |
43 | def equalize(pil_img, level=None):
44 | return ImageOps.equalize(pil_img)
45 |
46 |
47 | def rotate(pil_img, level):
48 | degrees = int_parameter(rand_lvl(level), 30)
49 | if np.random.uniform() > 0.5:
50 | degrees = -degrees
51 | return pil_img.rotate(degrees, resample=Image.BILINEAR, fillcolor=128)
52 |
53 |
54 | def solarize(pil_img, level):
55 | level = int_parameter(rand_lvl(level), 256)
56 | return ImageOps.solarize(pil_img, 256 - level)
57 |
58 |
59 | def shear_x(pil_img, level):
60 | level = float_parameter(rand_lvl(level), 0.3)
61 | if np.random.uniform() > 0.5:
62 | level = -level
63 | return pil_img.transform(
64 | (224, 224),
65 | Image.AFFINE,
66 | (1, level, 0, 0, 1, 0),
67 | resample=Image.BILINEAR,
68 | fillcolor=128,
69 | )
70 |
71 |
72 | def shear_y(pil_img, level):
73 | level = float_parameter(rand_lvl(level), 0.3)
74 | if np.random.uniform() > 0.5:
75 | level = -level
76 | return pil_img.transform(
77 | (224, 224),
78 | Image.AFFINE,
79 | (1, 0, 0, level, 1, 0),
80 | resample=Image.BILINEAR,
81 | fillcolor=128,
82 | )
83 |
84 |
85 | def translate_x(pil_img, level):
86 | level = int_parameter(rand_lvl(level), 224 / 3)
87 | if np.random.random() > 0.5:
88 | level = -level
89 | return pil_img.transform(
90 | (224, 224),
91 | Image.AFFINE,
92 | (1, 0, level, 0, 1, 0),
93 | resample=Image.BILINEAR,
94 | fillcolor=128,
95 | )
96 |
97 |
98 | def translate_y(pil_img, level):
99 | level = int_parameter(rand_lvl(level), 224 / 3)
100 | if np.random.random() > 0.5:
101 | level = -level
102 | return pil_img.transform(
103 | (224, 224),
104 | Image.AFFINE,
105 | (1, 0, 0, 0, 1, level),
106 | resample=Image.BILINEAR,
107 | fillcolor=128,
108 | )
109 |
110 |
111 | def posterize(pil_img, level):
112 | level = int_parameter(rand_lvl(level), 4)
113 | return ImageOps.posterize(pil_img, 4 - level)
114 |
115 |
116 | def int_parameter(level, maxval):
117 | """Helper function to scale `val` between 0 and maxval .
118 | Args:
119 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
120 | maxval: Maximum value that the operation can have. This will be scaled
121 | to level/PARAMETER_MAX.
122 | Returns:
123 | An int that results from scaling `maxval` according to `level`.
124 | """
125 | return int(level * maxval / 10)
126 |
127 |
128 | def float_parameter(level, maxval):
129 | """Helper function to scale `val` between 0 and maxval .
130 | Args:
131 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
132 | maxval: Maximum value that the operation can have. This will be scaled
133 | to level/PARAMETER_MAX.
134 | Returns:
135 | A float that results from scaling `maxval` according to `level`.
136 | """
137 | return float(level) * maxval / 10.0
138 |
139 |
140 | def rand_lvl(n):
141 | return np.random.uniform(low=0.1, high=n)
142 |
143 |
144 | augmentations = [
145 | autocontrast,
146 | equalize,
147 | lambda x: rotate(x, 1),
148 | lambda x: solarize(x, 1),
149 | lambda x: shear_x(x, 1),
150 | lambda x: shear_y(x, 1),
151 | lambda x: translate_x(x, 1),
152 | lambda x: translate_y(x, 1),
153 | lambda x: posterize(x, 1),
154 | ]
155 |
156 | def get_ops(data_name: str) -> Tuple[Callable, Callable]:
157 | """Get the operations to be applied when defining transforms."""
158 | unnormalize = transforms.Compose(
159 | [
160 | transforms.Normalize(
161 | mean=[0.0, 0.0, 0.0],
162 | std=[1.0 / v for v in dataset_defaults[data_name]["statistics"]["std"]],
163 | ),
164 | transforms.Normalize(
165 | mean=[-v for v in dataset_defaults[data_name]["statistics"]["mean"]],
166 | std=[1.0, 1.0, 1.0],
167 | ),
168 | ]
169 | )
170 |
171 | tensor_to_image = transforms.Compose([unnormalize, transforms.ToPILImage()])
172 | preprocess = transforms.Compose(
173 | [
174 | transforms.ToTensor(),
175 | transforms.Normalize(
176 | dataset_defaults[data_name]["statistics"]["mean"], dataset_defaults[data_name]["statistics"]["std"]
177 | ),
178 | ]
179 | )
180 | return tensor_to_image, preprocess
181 |
182 |
183 | def tr_transforms_imagenet(image: torch.Tensor, data_name: str) -> torch.Tensor:
184 | """
185 | Data augmentation for input images.
186 | args:
187 | inputs:
188 | image: tensor [n_channel, H, W]
189 | outputs:
190 | augment_image: tensor [1, n_channel, H, W]
191 | """
192 | tensor_to_image, preprocess = get_ops(data_name)
193 | image = tensor_to_image(image)
194 |
195 | preaugment = transforms.Compose(
196 | [
197 | transforms.RandomResizedCrop(224),
198 | transforms.RandomHorizontalFlip(),
199 | ]
200 | )
201 | augment_image = preaugment(image)
202 | augment_image = preprocess(augment_image)
203 |
204 | return augment_image
205 |
--------------------------------------------------------------------------------
/ttab/loads/datasets/loaders.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Iterable, Optional, Tuple, Union
3 |
4 | import torch
5 | from ttab.api import Batch, PyTorchDataset
6 | from ttab.loads.datasets.datasets import WrapperDataset
7 | from ttab.scenarios import Scenario
8 |
9 | D = Union[torch.utils.data.Dataset, PyTorchDataset]
10 |
11 |
12 | class BaseLoader(object):
13 | def __init__(self, dataset: PyTorchDataset):
14 | self.dataset = dataset
15 |
16 | def iterator(
17 | self,
18 | batch_size: int,
19 | shuffle: bool = True,
20 | repeat: bool = False,
21 | ref_num_data: Optional[int] = None,
22 | num_workers: int = 1,
23 | sampler: Optional[torch.utils.data.Sampler] = None,
24 | generator: Optional[torch.Generator] = None,
25 | pin_memory: bool = True,
26 | drop_last: bool = True,
27 | ) -> Iterable[Tuple[int, float, Batch]]:
28 | yield from self.dataset.iterator(
29 | batch_size,
30 | shuffle,
31 | repeat,
32 | ref_num_data,
33 | num_workers,
34 | sampler,
35 | generator,
36 | pin_memory,
37 | drop_last,
38 | )
39 |
40 |
41 | def _init_dataset(dataset: D, device: str) -> PyTorchDataset:
42 | if isinstance(dataset, torch.utils.data.Dataset):
43 | return WrapperDataset(dataset, device)
44 | else:
45 | return dataset
46 |
47 |
48 | def get_test_loader(dataset: D, device: str) -> BaseLoader:
49 | dataset: PyTorchDataset = _init_dataset(dataset, device)
50 | return BaseLoader(dataset)
51 |
52 |
53 | def get_auxiliary_loader(dataset: D, device: str) -> BaseLoader:
54 | dataset: PyTorchDataset = _init_dataset(dataset, device)
55 | return BaseLoader(dataset)
56 |
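57 | # Illustrative sketch (not part of the original file): building a test loader
58 | # from a plain torch Dataset and consuming it. `torch_set` is assumed.
59 | # loader = get_test_loader(torch_set, device="cuda")
60 | # for step, epoch_fraction, batch in loader.iterator(batch_size=32, shuffle=False):
61 | #     pass  # `batch` is a Batch placed on the device.
62 |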
--------------------------------------------------------------------------------
/ttab/loads/datasets/mnist/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import numpy as np
3 | from ttab.loads.datasets.datasets import ImageArrayDataset
4 |
5 | """ColoredMNIST is borrowed from https://github.com/facebookresearch/DomainBed/blob/main/domainbed/datasets.py"""
6 |
7 |
8 | class ColoredSyntheticShift(object):
9 | """The class of synthetic colored shifts introduced in ColoredMNIST."""
10 |
11 | def __init__(self, data_name, seed, color_flip_prob: float = 0.25) -> None:
12 | self.data_name = data_name
13 | self.base_data_name = data_name.split("_")[0]
14 | self.seed = seed
15 | self.color_flip_prob = color_flip_prob
16 | assert (
17 | 0 <= self.color_flip_prob <= 1
18 | ), f"{self.color_flip_prob} is out of range."
19 |
20 | def _color_grayscale_arr(self, arr, red=True):
21 | """Converts grayscale image to either red or green"""
22 | assert arr.ndim == 2
23 | dtype = arr.dtype
24 | h, w = arr.shape
25 | arr = np.reshape(arr, [h, w, 1])
26 | if red:
27 | arr = np.concatenate([arr, np.zeros((h, w, 2), dtype=dtype)], axis=2)
28 | else:
29 | arr = np.concatenate(
30 | [
31 | np.zeros((h, w, 1), dtype=dtype),
32 | arr,
33 | np.zeros((h, w, 1), dtype=dtype),
34 | ],
35 | axis=2,
36 | )
37 | return arr
38 |
39 | def apply(self, dataset, transform, target_transform):
40 | return self._apply_color(
41 | dataset, self.color_flip_prob, transform, target_transform
42 | )
43 |
44 | def _apply_color(self, dataset, color_flip_prob, transform, target_transform):
45 |
46 | train_data = []
47 | train_targets = []
48 | val_data = []
49 | val_targets = []
50 | test_data = []
51 | test_targets = []
52 | for idx, (im, label) in enumerate(dataset):
53 | im_array = np.array(im)
54 |
55 | # Assign a binary label y to the image based on the digit
56 | binary_label = 0 if label < 5 else 1
57 |
58 | # Flip label with probability of `color_flip_prob`
59 | if np.random.uniform() < color_flip_prob:
60 | binary_label = binary_label ^ 1
61 |
62 | # Color the image either red or green according to its possibly flipped label
63 | color_red = binary_label == 0
64 |
65 | # Flip the color with a probability e that depends on the domain
66 | if idx < 30000:
67 | # 10% in the training environment
68 | if np.random.uniform() < 0.1:
69 | color_red = not color_red
70 | elif idx < 40000:
71 | # 10% in the in-distribution eval environment
72 | # val set should have the same distribution as the source domain
73 | if np.random.uniform() < 0.1:
74 | color_red = not color_red
75 | else:
76 | # 90% in the ood test environment
77 | if np.random.uniform() < 0.9:
78 | color_red = not color_red
79 |
80 | colored_arr = self._color_grayscale_arr(im_array, red=color_red)
81 |
82 | if idx < 30000:
83 | train_data.append(colored_arr)
84 | train_targets.append(binary_label)
85 | elif idx < 40000:
86 | val_data.append(colored_arr)
87 | val_targets.append(binary_label)
88 | else:
89 | test_data.append(colored_arr)
90 | test_targets.append(binary_label)
91 |
92 | classes = ["0-4", "5-9"]
93 | class_to_index = {"0-4": 0, "5-9": 1}
94 | train_dataset = ImageArrayDataset(
95 | data=train_data,
96 | targets=train_targets,
97 | classes=classes,
98 | class_to_index=class_to_index,
99 | transform=transform,
100 | target_transform=target_transform,
101 | )
102 | val_dataset = ImageArrayDataset(
103 | data=val_data,
104 | targets=val_targets,
105 | classes=classes,
106 | class_to_index=class_to_index,
107 | transform=transform,
108 | target_transform=target_transform,
109 | )
110 | test_dataset = ImageArrayDataset(
111 | data=test_data,
112 | targets=test_targets,
113 | classes=classes,
114 | class_to_index=class_to_index,
115 | transform=transform,
116 | target_transform=target_transform,
117 | )
118 |
119 | return {
120 | "train": train_dataset,
121 | "val": val_dataset,
122 | "test": test_dataset,
123 | }
124 |
--------------------------------------------------------------------------------
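A hedged usage sketch: applying the shift to torchvision's MNIST train split (60k samples -> 30k train / 10k val / 20k OOD test, per the index thresholds above; the data_name "coloredmnist" is illustrative):

from torchvision.datasets import MNIST
from ttab.loads.datasets.mnist import ColoredSyntheticShift

shift = ColoredSyntheticShift(data_name="coloredmnist", seed=2022)
splits = shift.apply(
    MNIST(root="./data", train=True, download=True),  # yields (PIL image, int label)
    transform=None,
    target_transform=None,
)
train_set, val_set, ood_test_set = splits["train"], splits["val"], splits["test"]
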
/ttab/loads/datasets/mnist/data_aug_mnist.py:
--------------------------------------------------------------------------------
1 | from typing import Callable, Tuple
2 |
3 | import numpy as np
4 | import torch
5 | from PIL import Image, ImageOps
6 | from torchvision import transforms
7 | from ttab.configs.datasets import dataset_defaults
8 |
9 | ## https://github.com/google-research/augmix
10 |
11 |
12 | def _augmix_aug(x_orig: torch.Tensor, data_name: str) -> torch.Tensor:
13 | input_size = x_orig.shape[-1]
14 | scale_size = 32 if input_size == 28 else 256 # input size is either 28 or 224
15 | padding = int((scale_size - input_size) / 2)
16 | tensor_to_image, preprocess = get_ops(data_name)
17 |
18 | x_orig = tensor_to_image(x_orig.squeeze(0))
19 | preaugment = transforms.Compose(
20 | [
21 | transforms.RandomCrop(input_size, padding=padding),
22 | transforms.RandomHorizontalFlip(),
23 | ]
24 | )
25 | x_orig = preaugment(x_orig)
26 | x_processed = preprocess(x_orig)
27 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0]))
28 | m = np.float32(np.random.beta(1.0, 1.0))
29 |
30 | mix = torch.zeros_like(x_processed)
31 | for i in range(3):
32 | x_aug = x_orig.copy()
33 | for _ in range(np.random.randint(1, 4)):
34 | x_aug = np.random.choice(augmentations)(x_aug, input_size)
35 | mix += w[i] * preprocess(x_aug)
36 | mix = m * x_processed + (1 - m) * mix
37 | return mix
38 |
39 |
40 | aug_mnist = _augmix_aug
41 |
42 |
43 | def autocontrast(pil_img, input_size, level=None):
44 | return ImageOps.autocontrast(pil_img)
45 |
46 |
47 | def equalize(pil_img, input_size, level=None):
48 | return ImageOps.equalize(pil_img)
49 |
50 |
51 | def rotate(pil_img, input_size, level):
52 | degrees = int_parameter(rand_lvl(level), 30)
53 | if np.random.uniform() > 0.5:
54 | degrees = -degrees
55 | return pil_img.rotate(degrees, resample=Image.BILINEAR, fillcolor=128)
56 |
57 |
58 | def solarize(pil_img, input_size, level):
59 | level = int_parameter(rand_lvl(level), 256)
60 | return ImageOps.solarize(pil_img, 256 - level)
61 |
62 |
63 | def shear_x(pil_img, input_size, level):
64 | level = float_parameter(rand_lvl(level), 0.3)
65 | if np.random.uniform() > 0.5:
66 | level = -level
67 | return pil_img.transform(
68 | (input_size, input_size),
69 | Image.AFFINE,
70 | (1, level, 0, 0, 1, 0),
71 | resample=Image.BILINEAR,
72 | fillcolor=128,
73 | )
74 |
75 |
76 | def shear_y(pil_img, input_size, level):
77 | level = float_parameter(rand_lvl(level), 0.3)
78 | if np.random.uniform() > 0.5:
79 | level = -level
80 | return pil_img.transform(
81 | (input_size, input_size),
82 | Image.AFFINE,
83 | (1, 0, 0, level, 1, 0),
84 | resample=Image.BILINEAR,
85 | fillcolor=128,
86 | )
87 |
88 |
89 | def translate_x(pil_img, input_size, level):
90 | level = int_parameter(rand_lvl(level), input_size / 3)
91 | if np.random.random() > 0.5:
92 | level = -level
93 | return pil_img.transform(
94 | (input_size, input_size),
95 | Image.AFFINE,
96 | (1, 0, level, 0, 1, 0),
97 | resample=Image.BILINEAR,
98 | fillcolor=128,
99 | )
100 |
101 |
102 | def translate_y(pil_img, input_size, level):
103 | level = int_parameter(rand_lvl(level), input_size / 3)
104 | if np.random.random() > 0.5:
105 | level = -level
106 | return pil_img.transform(
107 | (input_size, input_size),
108 | Image.AFFINE,
109 | (1, 0, 0, 0, 1, level),
110 | resample=Image.BILINEAR,
111 | fillcolor=128,
112 | )
113 |
114 |
115 | def posterize(pil_img, input_size, level):
116 | level = int_parameter(rand_lvl(level), 4)
117 | return ImageOps.posterize(pil_img, 4 - level)
118 |
119 |
120 | def int_parameter(level, maxval):
121 | """Helper function to scale `val` between 0 and maxval .
122 | Args:
123 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
124 | maxval: Maximum value that the operation can have. This will be scaled
125 | to level/PARAMETER_MAX.
126 | Returns:
127 | An int that results from scaling `maxval` according to `level`.
128 | """
129 | return int(level * maxval / 10)
130 |
131 |
132 | def float_parameter(level, maxval):
133 | """Helper function to scale `val` between 0 and maxval .
134 | Args:
135 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
136 | maxval: Maximum value that the operation can have. This will be scaled
137 | to level/PARAMETER_MAX.
138 | Returns:
139 | A float that results from scaling `maxval` according to `level`.
140 | """
141 | return float(level) * maxval / 10.0
142 |
143 |
144 | def rand_lvl(n):
145 | return np.random.uniform(low=0.1, high=n)
146 |
147 |
148 | augmentations = [
149 | autocontrast,
150 | equalize,
151 | lambda x, y: rotate(x, y, 1),
152 | lambda x, y: solarize(x, y, 1),
153 | lambda x, y: shear_x(x, y, 1),
154 | lambda x, y: shear_y(x, y, 1),
155 | lambda x, y: translate_x(x, y, 1),
156 | lambda x, y: translate_y(x, y, 1),
157 | lambda x, y: posterize(x, y, 1),
158 | ]
159 |
160 | def get_ops(data_name: str) -> Tuple[Callable, Callable]:
161 | """Get the operations to be applied when defining transforms."""
162 | unnormalize = transforms.Compose(
163 | [
164 | transforms.Normalize(
165 | mean=[0.0, 0.0, 0.0],
166 | std=[1.0 / v for v in dataset_defaults[data_name]["statistics"]["std"]],
167 | ),
168 | transforms.Normalize(
169 | mean=[-v for v in dataset_defaults[data_name]["statistics"]["mean"]],
170 | std=[1.0, 1.0, 1.0],
171 | ),
172 | ]
173 | )
174 |
175 | tensor_to_image = transforms.Compose([unnormalize, transforms.ToPILImage()])
176 | preprocess = transforms.Compose(
177 | [
178 | transforms.ToTensor(),
179 | transforms.Normalize(
180 | dataset_defaults[data_name]["statistics"]["mean"], dataset_defaults[data_name]["statistics"]["std"]
181 | ),
182 | ]
183 | )
184 | return tensor_to_image, preprocess
185 |
186 |
187 | def tr_transforms_mnist(image: torch.Tensor, data_name: str) -> torch.Tensor:
188 | """
189 | Data augmentation for input images.
190 |     Args:
191 |         image: tensor [n_channel, H, W]
192 | 
193 |     Returns:
194 |         augment_image: tensor [n_channel, H, W]
195 |     """
196 | input_size = image.shape[-1]
197 | scale_size = 32 if input_size == 28 else 256 # input size is either 28 or 224
198 | padding = int((scale_size - input_size) / 2)
199 | tensor_to_image, preprocess = get_ops(data_name)
200 |
201 | image = tensor_to_image(image)
202 | preaugment = transforms.Compose(
203 | [
204 | transforms.RandomCrop(input_size, padding=padding),
205 | transforms.RandomHorizontalFlip(),
206 | ]
207 | )
208 | augment_image = preaugment(image)
209 | augment_image = preprocess(augment_image)
210 |
211 | return augment_image
212 |
--------------------------------------------------------------------------------
/ttab/loads/datasets/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/__init__.py
--------------------------------------------------------------------------------
/ttab/loads/datasets/utils/c_resource/frost1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/c_resource/frost1.png
--------------------------------------------------------------------------------
/ttab/loads/datasets/utils/c_resource/frost2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/c_resource/frost2.png
--------------------------------------------------------------------------------
/ttab/loads/datasets/utils/c_resource/frost3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/c_resource/frost3.png
--------------------------------------------------------------------------------
/ttab/loads/datasets/utils/c_resource/frost4.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/c_resource/frost4.jpg
--------------------------------------------------------------------------------
/ttab/loads/datasets/utils/c_resource/frost5.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/c_resource/frost5.jpg
--------------------------------------------------------------------------------
/ttab/loads/datasets/utils/c_resource/frost6.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/c_resource/frost6.jpg
--------------------------------------------------------------------------------
/ttab/loads/datasets/utils/lmdb.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import sys
4 |
5 | import cv2
6 | import lmdb
7 | import numpy as np
8 | import torch.utils.data as data
9 | from PIL import Image
10 |
11 | from .serialize import loads
12 |
13 | if sys.version_info[0] == 2:
14 | import cPickle as pickle
15 | else:
16 | import pickle
17 |
18 |
19 | def be_ncwh_pt(x):
20 |     return x.permute(0, 3, 1, 2)  # convert (n, h, w, c) to pytorch's (n, c, h, w)
21 |
22 |
23 | def uint8_to_float(x):
24 |     x = x.permute(0, 3, 1, 2)  # convert (n, h, w, c) to pytorch's (n, c, h, w)
25 | return x.float() / 128.0 - 1.0
26 |
27 |
28 | class LMDBPT(data.Dataset):
29 | """A class to load the LMDB file for extreme large datasets.
30 | Args:
31 |         root (string): Either the root directory for the database files,
32 |             or an absolute path pointing to the file.
33 |         is_image (bool): Whether each stored record holds an encoded image
34 |             (decoded with cv2) or a raw array.
35 | transform (callable, optional): A function/transform that
36 | takes in an PIL image and returns a transformed version.
37 | E.g, ``transforms.RandomCrop``
38 | target_transform (callable, optional):
39 | A function/transform that takes in the target and transforms it.
40 | """
41 |
42 | def __init__(self, root, transform=None, target_transform=None, is_image=True):
43 | self.root = os.path.expanduser(root)
44 | self.transform = transform
45 | self.target_transform = target_transform
46 | self.lmdb_files = self._get_valid_lmdb_files()
47 |
48 | # for each class, create an LSUNClassDataset
49 | self.dbs = []
50 | for lmdb_file in self.lmdb_files:
51 | self.dbs.append(
52 | LMDBPTClass(
53 | root=lmdb_file,
54 | transform=transform,
55 | target_transform=target_transform,
56 | is_image=is_image,
57 | )
58 | )
59 |
60 | # build up indices.
61 | self.indices = np.cumsum([len(db) for db in self.dbs])
62 | self.length = self.indices[-1]
63 | self._build_indices()
64 | self._prepare_target()
65 |
66 | def _get_valid_lmdb_files(self):
67 | """get valid lmdb based on given root."""
68 | if not self.root.endswith(".lmdb"):
69 | files = []
70 | for l in os.listdir(self.root):
71 | if "_" in l and "-lock" not in l:
72 | files.append(os.path.join(self.root, l))
73 | return files
74 | else:
75 | return [self.root]
76 |
77 |     def _build_indices(self):
78 |         # materialize the block boundaries as a list (an iterator would be
79 |         # exhausted after the first lookup), with a leading 0 for the first block.
80 |         boundaries = np.concatenate([[0], self.indices])
81 |         self.from_to_indices = list(enumerate(zip(boundaries[:-1], boundaries[1:])))
82 | 
83 |     def _get_matched_index(self, index):
84 |         for ind, (from_index, to_index) in self.from_to_indices:
85 |             if from_index <= index < to_index:
86 |                 return ind, index - from_index
87 |         raise IndexError(f"index {index} is out of range.")
88 | def __getitem__(self, index, apply_transform=True):
89 | block_index, item_index = self._get_matched_index(index)
90 | image, target = self.dbs[block_index].__getitem__(item_index, apply_transform)
91 | return image, target
92 |
93 | def __len__(self):
94 | return self.length
95 |
96 | def __repr__(self):
97 | fmt_str = "Dataset " + self.__class__.__name__ + "\n"
98 | fmt_str += " Number of datapoints: {}\n".format(self.__len__())
99 | fmt_str += " Root Location: {}\n".format(self.root)
100 | tmp = " Transforms (if any): "
101 | fmt_str += "{0}{1}\n".format(
102 | tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp))
103 | )
104 | tmp = " Target Transforms (if any): "
105 | fmt_str += "{0}{1}".format(
106 | tmp, self.target_transform.__repr__().replace("\n", "\n" + " " * len(tmp))
107 | )
108 | return fmt_str
109 |
110 | def _prepare_target(self):
111 | cache_file = self.root + "_targets_cache_"
112 | if os.path.isfile(cache_file):
113 | self.targets = pickle.load(open(cache_file, "rb"))
114 | else:
115 | self.targets = [
116 | self.__getitem__(idx, apply_transform=False)[1]
117 | for idx in range(self.length)
118 | ]
119 | pickle.dump(self.targets, open(cache_file, "wb"))
120 |
121 |
122 | class LMDBPTClass(data.Dataset):
123 | def __init__(self, root, transform=None, target_transform=None, is_image=True):
124 | self.root = os.path.expanduser(root)
125 | self.transform = transform
126 | self.target_transform = target_transform
127 | self.is_image = is_image
128 |
129 | # init the placeholder for env and length.
130 | self.env = None
131 | self.length = self._get_tmp_length()
132 |
133 | def _open_lmdb(self):
134 | return lmdb.open(
135 | self.root,
136 | subdir=os.path.isdir(self.root),
137 | readonly=True,
138 | lock=False,
139 | readahead=False,
140 | # map_size=1099511627776 * 2,
141 | max_readers=1,
142 | meminit=False,
143 | )
144 |
145 | def _get_tmp_length(self):
146 | env = lmdb.open(
147 | self.root,
148 | subdir=os.path.isdir(self.root),
149 | readonly=True,
150 | lock=False,
151 | readahead=False,
152 | # map_size=1099511627776 * 2,
153 | max_readers=1,
154 | meminit=False,
155 | )
156 | with env.begin(write=False) as txn:
157 | length = txn.stat()["entries"]
158 |
159 | if txn.get(b"__keys__") is not None:
160 | length -= 1
161 | # clean everything.
162 | del env
163 | return length
164 |
165 | def _get_length(self):
166 | with self.env.begin(write=False) as txn:
167 | self.length = txn.stat()["entries"]
168 |
169 | if txn.get(b"__keys__") is not None:
170 | self.length -= 1
171 |
172 | def _prepare_cache(self):
173 | cache_file = self.root + "_cache_"
174 | if os.path.isfile(cache_file):
175 | self.keys = pickle.load(open(cache_file, "rb"))
176 | else:
177 | with self.env.begin(write=False) as txn:
178 | self.keys = [key for key, _ in txn.cursor() if key != b"__keys__"]
179 | pickle.dump(self.keys, open(cache_file, "wb"))
180 |
181 | def _decode_from_image(self, x):
182 | image = cv2.imdecode(x, cv2.IMREAD_COLOR).astype("uint8")
183 | return Image.fromarray(image, "RGB")
184 |
185 | def _decode_from_array(self, x):
186 | return Image.fromarray(x.reshape(3, 32, 32).transpose((1, 2, 0)), "RGB")
187 |
188 | def __getitem__(self, index, apply_transform=True):
189 | if self.env is None:
190 |             # lazily open the lmdb env on first access.
191 |             self.env = self._open_lmdb()
192 |             # file stats were already computed in `_get_tmp_length`:
193 |             # self._get_length()
194 |             # prepare the cache of keys.
195 |             self._prepare_cache()
196 |
197 | # setup.
198 | env = self.env
199 | with env.begin(write=False) as txn:
200 | bin_file = txn.get(self.keys[index])
201 |
202 | image, target = loads(bin_file)
203 |
204 | if apply_transform:
205 | if self.is_image:
206 | image = self._decode_from_image(image)
207 | else:
208 | image = self._decode_from_array(image)
209 |
210 | if self.transform is not None:
211 | image = self.transform(image)
212 | if self.target_transform is not None:
213 | target = self.target_transform(target)
214 | return image, target
215 |
216 | def __len__(self):
217 | return self.length
218 |
219 | def __repr__(self):
220 | return self.__class__.__name__ + " (" + self.root + ")"
221 |
--------------------------------------------------------------------------------
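A usage sketch (the lmdb path is hypothetical): LMDBPT is a standard map-style Dataset, so it plugs directly into a DataLoader. Note that the targets are built once and cached to `<root>_targets_cache_` during construction:

from torch.utils.data import DataLoader
from torchvision import transforms
from ttab.loads.datasets.utils.lmdb import LMDBPT

dataset = LMDBPT(
    root="/path/to/train.lmdb",       # hypothetical path
    transform=transforms.ToTensor(),  # decode -> tensor so the default collate can batch
    is_image=True,
)
loader = DataLoader(dataset, batch_size=128, shuffle=True, num_workers=4)
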
/ttab/loads/datasets/utils/serialize.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os
4 |
5 | __all__ = ["loads", "dumps"]
6 |
7 |
8 | def create_dummy_func(func, dependency):
9 | """
10 | When a dependency of a function is not available,
11 | create a dummy function which throws ImportError when used.
12 | Args:
13 | func (str): name of the function.
14 | dependency (str or list[str]): name(s) of the dependency.
15 | Returns:
16 | function: a function object
17 | """
18 | if isinstance(dependency, (list, tuple)):
19 | dependency = ",".join(dependency)
20 |
21 | def _dummy(*args, **kwargs):
22 | raise ImportError(
23 | "Cannot import '{}', therefore '{}' is not available".format(
24 | dependency, func
25 | )
26 | )
27 |
28 | return _dummy
29 |
30 |
31 | def dumps_msgpack(obj):
32 | """
33 | Serialize an object.
34 | Returns:
35 | Implementation-dependent bytes-like object
36 | """
37 | return msgpack.dumps(obj, use_bin_type=True)
38 |
39 |
40 | def loads_msgpack(buf):
41 | """
42 | Args:
43 | buf: the output of `dumps`.
44 | """
45 | return msgpack.loads(buf, raw=False)
46 |
47 |
48 | def dumps_pyarrow(obj):
49 | """
50 | Serialize an object.
51 |
52 | Returns:
53 | Implementation-dependent bytes-like object
54 | """
55 | return pa.serialize(obj).to_buffer()
56 |
57 |
58 | def loads_pyarrow(buf):
59 | """
60 | Args:
61 | buf: the output of `dumps`.
62 | """
63 | return pa.deserialize(buf)
64 |
65 |
66 | try:
67 | # fixed in pyarrow 0.9: https://github.com/apache/arrow/pull/1223#issuecomment-359895666
68 | import pyarrow as pa
69 | except ImportError:
70 | pa = None
71 | dumps_pyarrow = create_dummy_func("dumps_pyarrow", ["pyarrow"]) # noqa
72 | loads_pyarrow = create_dummy_func("loads_pyarrow", ["pyarrow"]) # noqa
73 |
74 | try:
75 | import msgpack
76 | import msgpack_numpy
77 |
78 | msgpack_numpy.patch()
79 | except ImportError:
80 | assert pa is not None, "pyarrow is a dependency of tensorpack!"
81 | loads_msgpack = create_dummy_func( # noqa
82 | "loads_msgpack", ["msgpack", "msgpack_numpy"]
83 | )
84 | dumps_msgpack = create_dummy_func( # noqa
85 | "dumps_msgpack", ["msgpack", "msgpack_numpy"]
86 | )
87 |
88 | if os.environ.get("TENSORPACK_SERIALIZE", "msgpack") == "msgpack":
89 | loads = loads_msgpack
90 | dumps = dumps_msgpack
91 | else:
92 | loads = loads_pyarrow
93 | dumps = dumps_pyarrow
94 |
--------------------------------------------------------------------------------
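A round-trip sketch (assuming msgpack and msgpack_numpy are installed, the default backend): `dumps`/`loads` transparently handle numpy payloads, and setting TENSORPACK_SERIALIZE to anything other than "msgpack" switches to pyarrow:

import numpy as np
from ttab.loads.datasets.utils.serialize import dumps, loads

record = {"image": np.zeros((3, 32, 32), dtype=np.uint8), "label": 7}
buf = dumps(record)              # implementation-dependent bytes-like object
assert loads(buf)["label"] == 7  # numpy arrays round-trip via msgpack_numpy
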
/ttab/loads/datasets/yearbook/data_aug_yearbook.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from PIL import Image, ImageOps
4 | from torchvision import transforms
5 |
6 | from ttab.configs.datasets import dataset_defaults
7 |
8 | ## https://github.com/google-research/augmix
9 |
10 |
11 | def _augmix_aug(x_orig: torch.Tensor, data_name: str) -> torch.Tensor:
12 | input_size = x_orig.shape[-1]
13 | scale_size = 36 if input_size == 32 else 256 # input size is either 32 or 224
14 | padding = int((scale_size - input_size) / 2)
15 | tensor_to_image = transforms.Compose([transforms.ToPILImage()])
16 | preprocess = transforms.Compose([transforms.ToTensor()])
17 |
18 | x_orig = tensor_to_image(x_orig.squeeze(0))
19 | preaugment = transforms.Compose(
20 | [
21 | transforms.RandomCrop(input_size, padding=padding),
22 | transforms.RandomHorizontalFlip(),
23 | ]
24 | )
25 | x_orig = preaugment(x_orig)
26 | x_processed = preprocess(x_orig)
27 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0]))
28 | m = np.float32(np.random.beta(1.0, 1.0))
29 |
30 | mix = torch.zeros_like(x_processed)
31 | for i in range(3):
32 | x_aug = x_orig.copy()
33 | for _ in range(np.random.randint(1, 4)):
34 | x_aug = np.random.choice(augmentations)(x_aug, input_size)
35 | mix += w[i] * preprocess(x_aug)
36 | mix = m * x_processed + (1 - m) * mix
37 | return mix
38 |
39 |
40 | aug_yearbook = _augmix_aug
41 |
42 |
43 | def autocontrast(pil_img, input_size, level=None):
44 | return ImageOps.autocontrast(pil_img)
45 |
46 |
47 | def equalize(pil_img, input_size, level=None):
48 | return ImageOps.equalize(pil_img)
49 |
50 |
51 | def rotate(pil_img, input_size, level):
52 | degrees = int_parameter(rand_lvl(level), 30)
53 | if np.random.uniform() > 0.5:
54 | degrees = -degrees
55 | return pil_img.rotate(degrees, resample=Image.BILINEAR, fillcolor=128)
56 |
57 |
58 | def solarize(pil_img, input_size, level):
59 | level = int_parameter(rand_lvl(level), 256)
60 | return ImageOps.solarize(pil_img, 256 - level)
61 |
62 |
63 | def shear_x(pil_img, input_size, level):
64 | level = float_parameter(rand_lvl(level), 0.3)
65 | if np.random.uniform() > 0.5:
66 | level = -level
67 | return pil_img.transform(
68 | (input_size, input_size),
69 | Image.AFFINE,
70 | (1, level, 0, 0, 1, 0),
71 | resample=Image.BILINEAR,
72 | fillcolor=128,
73 | )
74 |
75 |
76 | def shear_y(pil_img, input_size, level):
77 | level = float_parameter(rand_lvl(level), 0.3)
78 | if np.random.uniform() > 0.5:
79 | level = -level
80 | return pil_img.transform(
81 | (input_size, input_size),
82 | Image.AFFINE,
83 | (1, 0, 0, level, 1, 0),
84 | resample=Image.BILINEAR,
85 | fillcolor=128,
86 | )
87 |
88 |
89 | def translate_x(pil_img, input_size, level):
90 | level = int_parameter(rand_lvl(level), input_size / 3)
91 | if np.random.random() > 0.5:
92 | level = -level
93 | return pil_img.transform(
94 | (input_size, input_size),
95 | Image.AFFINE,
96 | (1, 0, level, 0, 1, 0),
97 | resample=Image.BILINEAR,
98 | fillcolor=128,
99 | )
100 |
101 |
102 | def translate_y(pil_img, input_size, level):
103 | level = int_parameter(rand_lvl(level), input_size / 3)
104 | if np.random.random() > 0.5:
105 | level = -level
106 | return pil_img.transform(
107 | (input_size, input_size),
108 | Image.AFFINE,
109 | (1, 0, 0, 0, 1, level),
110 | resample=Image.BILINEAR,
111 | fillcolor=128,
112 | )
113 |
114 |
115 | def posterize(pil_img, input_size, level):
116 | level = int_parameter(rand_lvl(level), 4)
117 | return ImageOps.posterize(pil_img, 4 - level)
118 |
119 |
120 | def int_parameter(level, maxval):
121 | """Helper function to scale `val` between 0 and maxval .
122 | Args:
123 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
124 | maxval: Maximum value that the operation can have. This will be scaled
125 | to level/PARAMETER_MAX.
126 | Returns:
127 | An int that results from scaling `maxval` according to `level`.
128 | """
129 | return int(level * maxval / 10)
130 |
131 |
132 | def float_parameter(level, maxval):
133 | """Helper function to scale `val` between 0 and maxval .
134 | Args:
135 | level: Level of the operation that will be between [0, `PARAMETER_MAX`].
136 | maxval: Maximum value that the operation can have. This will be scaled
137 | to level/PARAMETER_MAX.
138 | Returns:
139 | A float that results from scaling `maxval` according to `level`.
140 | """
141 | return float(level) * maxval / 10.0
142 |
143 |
144 | def rand_lvl(n):
145 | return np.random.uniform(low=0.1, high=n)
146 |
147 |
148 | augmentations = [
149 | autocontrast,
150 | equalize,
151 | lambda x, y: rotate(x, y, 1),
152 | lambda x, y: solarize(x, y, 1),
153 | lambda x, y: shear_x(x, y, 1),
154 | lambda x, y: shear_y(x, y, 1),
155 | lambda x, y: translate_x(x, y, 1),
156 | lambda x, y: translate_y(x, y, 1),
157 | lambda x, y: posterize(x, y, 1),
158 | ]
159 |
160 |
161 | def tr_transforms_yearbook(image: torch.Tensor, data_name: str) -> torch.Tensor:
162 | """
163 | Data augmentation for input images.
164 |     Args:
165 |         image: tensor [n_channel, H, W]
166 | 
167 |     Returns:
168 |         augment_image: tensor [n_channel, H, W]
169 |     """
170 | input_size = image.shape[-1]
171 |     scale_size = 36 if input_size == 32 else 256  # input size is either 32 or 224
172 | padding = int((scale_size - input_size) / 2)
173 | tensor_to_image = transforms.Compose([transforms.ToPILImage()])
174 | preprocess = transforms.Compose([transforms.ToTensor()])
175 |
176 | image = tensor_to_image(image)
177 | preaugment = transforms.Compose(
178 | [
179 | transforms.RandomCrop(input_size, padding=padding),
180 | transforms.RandomHorizontalFlip(),
181 | ]
182 | )
183 | augment_image = preaugment(image)
184 | augment_image = preprocess(augment_image)
185 |
186 | return augment_image
187 |
--------------------------------------------------------------------------------
/ttab/loads/models/__init__.py:
--------------------------------------------------------------------------------
1 | from .cct import *
2 | from .resnet import *
3 | from .wideresnet import *
4 |
--------------------------------------------------------------------------------
/ttab/loads/models/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/models/utils/__init__.py
--------------------------------------------------------------------------------
/ttab/loads/models/utils/embedder.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 |
3 |
4 | class Embedder(nn.Module):
5 | def __init__(
6 | self,
7 | word_embedding_dim=300,
8 | vocab_size=100000,
9 | padding_idx=1,
10 | pretrained_weight=None,
11 | embed_freeze=False,
12 | *args,
13 | **kwargs
14 | ):
15 | super(Embedder, self).__init__()
16 | self.embeddings = (
17 | nn.Embedding.from_pretrained(pretrained_weight, freeze=embed_freeze)
18 | if pretrained_weight is not None
19 | else nn.Embedding(vocab_size, word_embedding_dim, padding_idx=padding_idx)
20 | )
21 | self.embeddings.weight.requires_grad = not embed_freeze
22 |
23 | def forward_mask(self, mask):
24 | bsz, seq_len = mask.shape
25 | new_mask = mask.view(bsz, seq_len, 1)
26 | new_mask = new_mask.sum(-1)
27 | new_mask = new_mask > 0
28 | return new_mask
29 |
30 | def forward(self, x, mask=None):
31 | embed = self.embeddings(x)
32 | embed = (
33 | embed
34 | if mask is None
35 | else embed * self.forward_mask(mask).unsqueeze(-1).float()
36 | )
37 | return embed, mask
38 |
39 | @staticmethod
40 | def init_weight(m):
41 | if isinstance(m, nn.Linear):
42 | nn.init.trunc_normal_(m.weight, std=0.02)
43 |             if m.bias is not None:
44 | nn.init.constant_(m.bias, 0)
45 | else:
46 | nn.init.normal_(m.weight)
47 |
--------------------------------------------------------------------------------
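A minimal sketch: token ids to embeddings, with padded positions zeroed out through the optional mask (padding_idx=1 here, matching the default):

import torch
from ttab.loads.models.utils.embedder import Embedder

embedder = Embedder(word_embedding_dim=300, vocab_size=1000, padding_idx=1)
tokens = torch.randint(0, 1000, (2, 16))  # (batch, seq_len)
mask = (tokens != 1).long()               # 1 for real tokens, 0 for padding
embed, mask = embedder(tokens, mask)      # embed: (2, 16, 300)
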
/ttab/loads/models/utils/helpers.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.nn.functional as F
4 | import logging
5 |
6 | _logger = logging.getLogger("train")
7 |
8 |
9 | def resize_pos_embed(posemb, posemb_new, num_tokens=1):
10 | # Copied from `timm` by Ross Wightman:
11 | # github.com/rwightman/pytorch-image-models
12 | # Rescale the grid of position embeddings when loading from state_dict. Adapted from
13 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
14 | ntok_new = posemb_new.shape[1]
15 | if num_tokens:
16 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
17 | ntok_new -= num_tokens
18 | else:
19 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
20 | gs_old = int(math.sqrt(len(posemb_grid)))
21 | gs_new = int(math.sqrt(ntok_new))
22 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
23 | posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode="bilinear")
24 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1)
25 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
26 | return posemb
27 |
28 |
29 | def pe_check(model, state_dict, pe_key="classifier.positional_emb"):
30 | if (
31 | pe_key is not None
32 | and pe_key in state_dict.keys()
33 | and pe_key in model.state_dict().keys()
34 | ):
35 | if model.state_dict()[pe_key].shape != state_dict[pe_key].shape:
36 | state_dict[pe_key] = resize_pos_embed(
37 | state_dict[pe_key],
38 | model.state_dict()[pe_key],
39 | num_tokens=model.classifier.num_tokens,
40 | )
41 | return state_dict
42 |
43 |
44 | def fc_check(model, state_dict, fc_key="classifier.fc"):
45 | for key in [f"{fc_key}.weight", f"{fc_key}.bias"]:
46 | if (
47 | key is not None
48 | and key in state_dict.keys()
49 | and key in model.state_dict().keys()
50 | ):
51 | if model.state_dict()[key].shape != state_dict[key].shape:
52 | _logger.warning(f"Removing {key}, number of classes has changed.")
53 | state_dict[key] = model.state_dict()[key]
54 | return state_dict
55 |
--------------------------------------------------------------------------------
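A sketch of `resize_pos_embed` (shapes illustrative): interpolating a 14x14 position grid to 16x16 while keeping the leading class token untouched:

import torch
from ttab.loads.models.utils.helpers import resize_pos_embed

posemb = torch.randn(1, 1 + 14 * 14, 384)      # pretrained: class token + 14x14 grid
posemb_new = torch.zeros(1, 1 + 16 * 16, 384)  # only its shape is used
resized = resize_pos_embed(posemb, posemb_new, num_tokens=1)
assert resized.shape == (1, 1 + 16 * 16, 384)
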
/ttab/loads/models/utils/stochastic_depth.py:
--------------------------------------------------------------------------------
1 | # Thanks to rwightman's timm package
2 | # github.com:rwightman/pytorch-image-models
3 |
4 | import torch
5 | import torch.nn as nn
6 |
7 |
8 | def drop_path(x, drop_prob: float = 0.0, training: bool = False):
9 | """
10 | Obtained from: github.com:rwightman/pytorch-image-models
11 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
12 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
13 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
14 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
15 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
16 | 'survival rate' as the argument.
17 | """
18 | if drop_prob == 0.0 or not training:
19 | return x
20 | keep_prob = 1 - drop_prob
21 | shape = (x.shape[0],) + (1,) * (
22 | x.ndim - 1
23 | ) # work with diff dim tensors, not just 2D ConvNets
24 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
25 | random_tensor.floor_() # binarize
26 | output = x.div(keep_prob) * random_tensor
27 | return output
28 |
29 |
30 | class DropPath(nn.Module):
31 | """
32 | Obtained from: github.com:rwightman/pytorch-image-models
33 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
34 | """
35 |
36 | def __init__(self, drop_prob=None):
37 | super(DropPath, self).__init__()
38 | self.drop_prob = drop_prob
39 |
40 | def forward(self, x):
41 | return drop_path(x, self.drop_prob, self.training)
42 |
--------------------------------------------------------------------------------
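Typical residual usage (a sketch; the Linear block is a stand-in): DropPath stochastically skips the block's contribution per sample during training and is the identity at eval time:

import torch
import torch.nn as nn
from ttab.loads.models.utils.stochastic_depth import DropPath

block, drop = nn.Linear(64, 64), DropPath(drop_prob=0.1)
x = torch.randn(8, 64)
out = x + drop(block(x))  # each sample's branch survives w.p. 0.9, rescaled by 1/keep_prob
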
/ttab/loads/models/utils/tokenizer.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class Tokenizer(nn.Module):
7 | def __init__(
8 | self,
9 | kernel_size,
10 | stride,
11 | padding,
12 | pooling_kernel_size=3,
13 | pooling_stride=2,
14 | pooling_padding=1,
15 | n_conv_layers=1,
16 | n_input_channels=3,
17 | n_output_channels=64,
18 | in_planes=64,
19 | activation=None,
20 | max_pool=True,
21 | conv_bias=False,
22 | ):
23 | super(Tokenizer, self).__init__()
24 |
25 | n_filter_list = (
26 | [n_input_channels]
27 | + [in_planes for _ in range(n_conv_layers - 1)]
28 | + [n_output_channels]
29 | )
30 |
31 | self.conv_layers = nn.Sequential(
32 | *[
33 | nn.Sequential(
34 | nn.Conv2d(
35 | n_filter_list[i],
36 | n_filter_list[i + 1],
37 | kernel_size=(kernel_size, kernel_size),
38 | stride=(stride, stride),
39 | padding=(padding, padding),
40 | bias=conv_bias,
41 | ),
42 | nn.Identity() if activation is None else activation(),
43 | nn.MaxPool2d(
44 | kernel_size=pooling_kernel_size,
45 | stride=pooling_stride,
46 | padding=pooling_padding,
47 | )
48 | if max_pool
49 | else nn.Identity(),
50 | )
51 | for i in range(n_conv_layers)
52 | ]
53 | )
54 |
55 | self.flattener = nn.Flatten(2, 3)
56 | self.apply(self.init_weight)
57 |
58 | def sequence_length(self, n_channels=3, height=224, width=224):
59 | return self.forward(torch.zeros((1, n_channels, height, width))).shape[1]
60 |
61 | def forward(self, x):
62 | return self.flattener(self.conv_layers(x)).transpose(-2, -1)
63 |
64 | @staticmethod
65 | def init_weight(m):
66 | if isinstance(m, nn.Conv2d):
67 | nn.init.kaiming_normal_(m.weight)
68 |
69 |
70 | class TextTokenizer(nn.Module):
71 | def __init__(
72 | self,
73 | kernel_size,
74 | stride,
75 | padding,
76 | pooling_kernel_size=3,
77 | pooling_stride=2,
78 | pooling_padding=1,
79 | embedding_dim=300,
80 | n_output_channels=128,
81 | activation=None,
82 | max_pool=True,
83 | *args,
84 | **kwargs
85 | ):
86 | super(TextTokenizer, self).__init__()
87 |
88 | self.max_pool = max_pool
89 | self.conv_layers = nn.Sequential(
90 | nn.Conv2d(
91 | 1,
92 | n_output_channels,
93 | kernel_size=(kernel_size, embedding_dim),
94 | stride=(stride, 1),
95 | padding=(padding, 0),
96 | bias=False,
97 | ),
98 | nn.Identity() if activation is None else activation(),
99 | nn.MaxPool2d(
100 | kernel_size=(pooling_kernel_size, 1),
101 | stride=(pooling_stride, 1),
102 | padding=(pooling_padding, 0),
103 | )
104 | if max_pool
105 | else nn.Identity(),
106 | )
107 |
108 | self.apply(self.init_weight)
109 |
110 | def seq_len(self, seq_len=32, embed_dim=300):
111 | return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1]
112 |
113 | def forward_mask(self, mask):
114 | new_mask = mask.unsqueeze(1).float()
115 | cnn_weight = torch.ones(
116 | (1, 1, self.conv_layers[0].kernel_size[0]),
117 | device=mask.device,
118 | dtype=torch.float,
119 | )
120 | new_mask = F.conv1d(
121 | new_mask,
122 | cnn_weight,
123 | None,
124 | self.conv_layers[0].stride[0],
125 | self.conv_layers[0].padding[0],
126 | 1,
127 | 1,
128 | )
129 | if self.max_pool:
130 | new_mask = F.max_pool1d(
131 | new_mask,
132 | self.conv_layers[2].kernel_size[0],
133 | self.conv_layers[2].stride[0],
134 | self.conv_layers[2].padding[0],
135 | 1,
136 | False,
137 | False,
138 | )
139 | new_mask = new_mask.squeeze(1)
140 | new_mask = new_mask > 0
141 | return new_mask
142 |
143 | def forward(self, x, mask=None):
144 | x = x.unsqueeze(1)
145 | x = self.conv_layers(x)
146 | x = x.transpose(1, 3).squeeze(1)
147 | if mask is not None:
148 | mask = self.forward_mask(mask).unsqueeze(-1).float()
149 | x = x * mask
150 | return x, mask
151 |
152 | @staticmethod
153 | def init_weight(m):
154 | if isinstance(m, nn.Conv2d):
155 | nn.init.kaiming_normal_(m.weight)
156 |
--------------------------------------------------------------------------------
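A sketch of the convolutional tokenizer: an image becomes a token sequence, and `sequence_length` reports the token count for a given input size (32 -> conv stride 2 -> 16 -> max-pool stride 2 -> 8, i.e. 8*8 = 64 tokens):

import torch
import torch.nn as nn
from ttab.loads.models.utils.tokenizer import Tokenizer

tokenizer = Tokenizer(kernel_size=3, stride=2, padding=1, n_conv_layers=1,
                      n_output_channels=64, activation=nn.ReLU)
assert tokenizer.sequence_length(n_channels=3, height=32, width=32) == 64
tokens = tokenizer(torch.randn(1, 3, 32, 32))  # (1, 64, 64) = (batch, seq_len, channels)
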
/ttab/loads/models/wideresnet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import torch.nn.init as init
5 |
6 |
7 | def conv3x3(in_planes, out_planes, stride=1):
8 | return nn.Conv2d(
9 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True
10 | )
11 |
12 |
13 | def conv_init(m):
14 | classname = m.__class__.__name__
15 | if classname.find("Conv") != -1:
16 | init.xavier_uniform_(m.weight, gain=np.sqrt(2))
17 | init.constant_(m.bias, 0)
18 | elif classname.find("BatchNorm") != -1:
19 | init.constant_(m.weight, 1)
20 | init.constant_(m.bias, 0)
21 |
22 |
23 | class wide_basic(nn.Module):
24 | def __init__(self, in_planes, planes, dropout_rate, stride=1):
25 | super(wide_basic, self).__init__()
26 | self.bn1 = nn.BatchNorm2d(in_planes)
27 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True)
28 | self.dropout = nn.Dropout(p=dropout_rate)
29 | self.bn2 = nn.BatchNorm2d(planes)
30 | self.conv2 = nn.Conv2d(
31 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=True
32 | )
33 |
34 | self.shortcut = nn.Sequential()
35 | if stride != 1 or in_planes != planes:
36 | self.shortcut = nn.Sequential(
37 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True),
38 | )
39 |
40 | def forward(self, x):
41 | out = self.dropout(self.conv1(F.relu(self.bn1(x))))
42 | out = self.conv2(F.relu(self.bn2(out)))
43 | out += self.shortcut(x)
44 |
45 | return out
46 |
47 |
48 | class WideResNet(nn.Module):
49 | def __init__(
50 | self, depth, widen_factor, num_classes, split_point="layer3", dropout_rate=0.3
51 | ):
52 | super(WideResNet, self).__init__()
53 | self.in_planes = 16
54 | assert split_point in ["layer2", "layer3", None], "invalid split position."
55 | self.split_point = split_point
56 |
57 | assert (depth - 4) % 6 == 0, "Wide-resnet depth should be 6n+4"
58 |         n = (depth - 4) // 6  # number of blocks per stage
59 | k = widen_factor
60 |
61 | print("| Wide-Resnet %dx%d" % (depth, k))
62 | nStages = [16, 16 * k, 32 * k, 64 * k]
63 |
64 | self.conv1 = conv3x3(3, nStages[0])
65 | self.relu = nn.ReLU(inplace=True)
66 | self.layer1 = self._wide_layer(
67 | wide_basic, nStages[1], n, dropout_rate, stride=1
68 | )
69 | self.layer2 = self._wide_layer(
70 | wide_basic, nStages[2], n, dropout_rate, stride=2
71 | )
72 | self.layer3 = self._wide_layer(
73 | wide_basic, nStages[3], n, dropout_rate, stride=2
74 | )
75 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9)
76 | self.avgpool = nn.AvgPool2d(kernel_size=8)
77 | self.classifier = nn.Linear(nStages[3], num_classes, bias=False)
78 |
79 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride):
80 | strides = [stride] + [1] * (int(num_blocks) - 1)
81 | layers = []
82 |
83 | for stride in strides:
84 | layers.append(block(self.in_planes, planes, dropout_rate, stride))
85 | self.in_planes = planes
86 |
87 | return nn.Sequential(*layers)
88 |
89 | def forward_features(self, x):
90 | """Forward function without classifier. Use gradient checkpointing to save memory."""
91 | x = self.conv1(x)
92 | x = self.layer1(x)
93 | x = self.layer2(x)
94 | if self.split_point in ["layer3", None]:
95 | x = self.layer3(x)
96 | x = self.bn1(x)
97 | x = self.relu(x)
98 | x = self.avgpool(x)
99 | x = x.view(x.size(0), -1)
100 |
101 | return x
102 |
103 | def forward_head(self, x, pre_logits: bool = False):
104 | """Forward function for classifier. Use gridient checkpointing to save memory."""
105 | if self.split_point == "layer2":
106 | x = self.layer3(x)
107 | x = self.bn1(x)
108 | x = self.relu(x)
109 | x = self.avgpool(x)
110 | x = x.view(x.size(0), -1)
111 |
112 | return x if pre_logits else self.classifier(x)
113 |
114 | def forward(self, x):
115 | x = self.forward_features(x)
116 | x = self.forward_head(x)
117 | return x
118 |
--------------------------------------------------------------------------------
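An instantiation sketch with the standard WRN-28-10 configuration on CIFAR-sized inputs ((depth - 4) must be divisible by 6, and the fixed 8x8 average pool assumes 32x32 inputs, which shrink 32 -> 16 -> 8 across the strided stages):

import torch
from ttab.loads.models.wideresnet import WideResNet

model = WideResNet(depth=28, widen_factor=10, num_classes=10)
logits = model(torch.randn(4, 3, 32, 32))                     # (4, 10)
features = model.forward_features(torch.randn(4, 3, 32, 32))  # (4, 640), pre-classifier
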
/ttab/model_adaptation/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from .bn_adapt import BNAdapt
3 | from .conjugate_pl import ConjugatePL
4 | from .cotta import CoTTA
5 | from .eata import EATA
6 | from .memo import MEMO
7 | from .no_adaptation import NoAdaptation
8 | from .note import NOTE
9 | from .sar import SAR
10 | from .shot import SHOT
11 | from .t3a import T3A
12 | from .tent import TENT
13 | from .ttt import TTT
14 | from .ttt_plus_plus import TTTPlusPlus
15 | from .rotta import Rotta
16 |
17 |
18 | def get_model_adaptation_method(adaptation_name):
19 | return {
20 | "no_adaptation": NoAdaptation,
21 | "tent": TENT,
22 | "bn_adapt": BNAdapt,
23 | "memo": MEMO,
24 | "shot": SHOT,
25 | "t3a": T3A,
26 | "ttt": TTT,
27 | "ttt_plus_plus": TTTPlusPlus,
28 | "note": NOTE,
29 | "sar": SAR,
30 | "conjugate_pl": ConjugatePL,
31 | "cotta": CoTTA,
32 | "eata": EATA,
33 | "rotta": Rotta,
34 | }[adaptation_name]
35 |
--------------------------------------------------------------------------------
/ttab/model_adaptation/bn_adapt.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import copy
3 | import functools
4 | from typing import List
5 |
6 | import torch
7 | import torch.nn as nn
8 | import ttab.loads.define_dataset as define_dataset
9 | from torch.nn import functional as F
10 | from ttab.api import Batch
11 | from ttab.model_adaptation.base_adaptation import BaseAdaptation
12 | from ttab.model_selection.base_selection import BaseSelection
13 | from ttab.model_selection.metrics import Metrics
14 | from ttab.utils.auxiliary import fork_rng_with_seed
15 | from ttab.utils.logging import Logger
16 | from ttab.utils.timer import Timer
17 |
18 |
19 | class BNAdapt(BaseAdaptation):
20 | """
21 | Improving robustness against common corruptions by covariate shift adaptation,
22 | https://arxiv.org/abs/2006.16971,
23 | https://github.com/bethgelab/robustness
24 | """
25 |
26 | def __init__(self, meta_conf, model: nn.Module):
27 | super().__init__(meta_conf, model)
28 |
29 | def _prior_safety_check(self):
30 |
31 | assert (
32 | self._meta_conf.adapt_prior is not None
33 | ), "the ratio of training set statistics is required"
34 | assert (
35 | self._meta_conf.debug is not None
36 | ), "The state of debug should be specified"
37 |         assert self._meta_conf.n_train_steps > 0, "n_train_steps must be >= 1."
38 |
39 | def _initialize_model(self, model: nn.Module):
40 | """Configure model for use with adaptation method."""
41 | # disable grad.
42 | model.eval()
43 | model.requires_grad_(False)
44 |
45 | return model.to(self._meta_conf.device)
46 |
47 | def _initialize_trainable_parameters(self):
48 | """select target params for adaptation methods."""
49 | self._adapt_module_names = []
50 |
51 | for name_module, module in self._base_model.named_children():
52 | if isinstance(module, nn.BatchNorm2d):
53 | self._adapt_module_names.append(name_module)
54 |
55 | def _initialize_optimizer(self, params) -> torch.optim.Optimizer:
56 | """No optimizer is used in BNAdapt."""
57 | pass
58 |
59 | def _post_safety_check(self):
60 |
61 | param_grads = [p.requires_grad for p in (self._model.parameters())]
62 | has_any_params = any(param_grads)
63 | assert not has_any_params, "BNAdapt doesn't need adaptable params."
64 |
65 | has_bn = any(
66 | [
67 | (name_module in self._adapt_module_names)
68 | for name_module, _ in (self._model.named_modules())
69 | ]
70 | )
71 | assert has_bn, "BNAdapt needs batch normalization layers."
72 |
73 | def initialize(self, seed: int):
74 | """Initialize the algorithm."""
75 | self._initialize_trainable_parameters()
76 | self._model = self._initialize_model(model=copy.deepcopy(self._base_model))
77 | self.adapt_setup()
78 | self._auxiliary_data_cls = define_dataset.ConstructAuxiliaryDataset(
79 | config=self._meta_conf
80 | )
81 | self.criterion = nn.CrossEntropyLoss()
82 | if (
83 | self._meta_conf.n_train_steps > 1
84 | ): # no need to do multiple-step adaptation in bn_adapt.
85 | self._meta_conf.n_train_steps = 1
86 |
87 | def _bn_swap(self, model: nn.Module, prior: float):
88 | """
89 |         Replace the original BN layers in the model with newly defined
90 |         AdaptiveBatchNorm layers, modifying the BN forward pass.
91 | """
92 | return AdaptiveBatchNorm.adapt_model(
93 | model, prior=prior, device=self._meta_conf.device
94 | )
95 |
96 | def one_adapt_step(
97 | self,
98 | model: torch.nn.Module,
99 | timer: Timer,
100 | batch: Batch,
101 | random_seed: int = None,
102 | ):
103 | """adapt the model in one step."""
104 | with timer("forward"):
105 | with fork_rng_with_seed(random_seed):
106 | y_hat = model(batch._x)
107 | loss = self.criterion(y_hat, batch._y)
108 |
109 | return {"loss": loss.item(), "yhat": y_hat}
110 |
111 | def adapt_setup(self):
112 | """adjust batch normalization layers."""
113 | self._bn_swap(self._model, self._meta_conf.adapt_prior)
114 |
115 | def adapt_and_eval(
116 | self,
117 | episodic: bool,
118 | metrics: Metrics,
119 | model_selection_method: BaseSelection,
120 | current_batch,
121 | previous_batches: List[Batch],
122 | logger: Logger,
123 | timer: Timer,
124 | ):
125 | """The key entry of test-time adaptation."""
126 | # some simple initialization.
127 | log = functools.partial(logger.log, display=self._meta_conf.debug)
128 | nbsteps = self._meta_conf.n_train_steps
129 | with timer("test_time_adaptation"):
130 | log(f"\tadapt the model for {nbsteps} steps.")
131 | for _ in range(nbsteps):
132 | adaptation_result = self.one_adapt_step(
133 | self._model, timer, current_batch, random_seed=self._meta_conf.seed
134 | )
135 |
136 | with timer("evaluate_adaptation_result"):
137 | metrics.eval(current_batch._y, adaptation_result["yhat"])
138 | if self._meta_conf.base_data_name in ["waterbirds"]:
139 | self.tta_loss_computer.loss(
140 | adaptation_result["yhat"],
141 | current_batch._y,
142 | current_batch._g,
143 | is_training=False,
144 | )
145 |
146 | @property
147 | def name(self):
148 | return "bn_adapt"
149 |
150 |
151 | class AdaptiveBatchNorm(nn.Module):
152 | """Use the source statistics as a prior on the target statistics"""
153 |
154 | @staticmethod
155 | def find_bns(parent, prior, device):
156 | replace_mods = []
157 | if parent is None:
158 | return []
159 | for name, child in parent.named_children():
160 | child.requires_grad_(False)
161 | if isinstance(child, nn.BatchNorm2d):
162 | module = AdaptiveBatchNorm(child, prior, device)
163 | replace_mods.append((parent, name, module))
164 | else:
165 | replace_mods.extend(AdaptiveBatchNorm.find_bns(child, prior, device))
166 |
167 | return replace_mods
168 |
169 | @staticmethod
170 | def adapt_model(model, prior, device):
171 | replace_mods = AdaptiveBatchNorm.find_bns(model, prior, device)
172 | print(f"| Found {len(replace_mods)} modules to be replaced.")
173 | for (parent, name, child) in replace_mods:
174 | setattr(parent, name, child)
175 | return model
176 |
177 | def __init__(self, layer, prior, device):
178 |         assert 0 <= prior <= 1, "prior must be in [0, 1]."
179 |
180 | super().__init__()
181 | self.layer = layer
182 | self.layer.eval()
183 |
184 | self.norm = nn.BatchNorm2d(
185 | self.layer.num_features, affine=False, momentum=1.0
186 | ).to(device)
187 |
188 | self.prior = prior
189 |
190 | def forward(self, input):
191 | self.norm(input)
192 |
193 | running_mean = (
194 | self.prior * self.layer.running_mean
195 | + (1 - self.prior) * self.norm.running_mean
196 | )
197 | running_var = (
198 | self.prior * self.layer.running_var
199 | + (1 - self.prior) * self.norm.running_var
200 | )
201 |
202 | return F.batch_norm(
203 | input,
204 | running_mean,
205 | running_var,
206 | self.layer.weight,
207 | self.layer.bias,
208 | False,
209 | 0,
210 | self.layer.eps,
211 | )
212 |
--------------------------------------------------------------------------------
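A sketch of the BN swap in isolation (torchvision's resnet18 as a stand-in base model): every BatchNorm2d forward then uses the blended statistics stat = prior * source_stat + (1 - prior) * batch_stat:

import torch
from torchvision.models import resnet18
from ttab.model_adaptation.bn_adapt import AdaptiveBatchNorm

model = resnet18(num_classes=10).eval()
model = AdaptiveBatchNorm.adapt_model(model, prior=0.8, device="cpu")
_ = model(torch.randn(16, 3, 224, 224))  # inference with blended BN statistics
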
/ttab/model_adaptation/no_adaptation.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import copy
3 | import warnings
4 | from typing import List
5 |
6 | import torch
7 | import torch.nn as nn
8 | import ttab.model_adaptation.utils as adaptation_utils
9 | from ttab.api import Batch
10 | from ttab.loads.define_model import load_pretrained_model
11 | from ttab.model_adaptation.base_adaptation import BaseAdaptation
12 | from ttab.model_selection.base_selection import BaseSelection
13 | from ttab.model_selection.metrics import Metrics
14 | from ttab.utils.logging import Logger
15 | from ttab.utils.timer import Timer
16 |
17 |
18 | class NoAdaptation(BaseAdaptation):
19 | """Standard test-time evaluation (no adaptation)."""
20 |
21 | def __init__(self, meta_conf, model: nn.Module):
22 | super().__init__(meta_conf, model)
23 |
24 | def convert_iabn(self, module: nn.Module, **kwargs):
25 | """
26 | Recursively convert all BatchNorm to InstanceAwareBatchNorm.
27 | """
28 | module_output = module
29 | if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
30 | IABN = (
31 | adaptation_utils.InstanceAwareBatchNorm2d
32 | if isinstance(module, nn.BatchNorm2d)
33 | else adaptation_utils.InstanceAwareBatchNorm1d
34 | )
35 | module_output = IABN(
36 | num_channels=module.num_features,
37 | k=self._meta_conf.iabn_k,
38 | eps=module.eps,
39 | momentum=module.momentum,
40 | threshold=self._meta_conf.threshold_note,
41 | affine=module.affine,
42 | )
43 |
44 | module_output._bn = copy.deepcopy(module)
45 |
46 | for name, child in module.named_children():
47 | module_output.add_module(name, self.convert_iabn(child, **kwargs))
48 | del module
49 | return module_output
50 |
51 | def _initialize_model(self, model: nn.Module):
52 | """Configure model for adaptation."""
53 | if hasattr(self._meta_conf, "iabn") and self._meta_conf.iabn:
54 | # check BN layers
55 | bn_flag = False
56 | for name_module, module in model.named_modules():
57 | if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)):
58 | bn_flag = True
59 | if not bn_flag:
60 | warnings.warn(
61 | "IABN needs bn layers, while there is no bn in the base model."
62 | )
63 | self.convert_iabn(model)
64 | load_pretrained_model(self._meta_conf, model)
65 | model.eval()
66 | return model.to(self._meta_conf.device)
67 |
68 | def _post_safety_check(self):
69 | pass
70 |
71 | def initialize(self, seed: int):
72 | """Initialize the algorithm."""
73 | self._model = self._initialize_model(model=copy.deepcopy(self._base_model))
74 |
75 | def adapt_and_eval(
76 | self,
77 | episodic: bool,
78 | metrics: Metrics,
79 | model_selection_method: BaseSelection,
80 | current_batch: Batch,
81 | previous_batches: List[Batch],
82 | logger: Logger,
83 | timer: Timer,
84 | ):
85 | """The key entry of test-time adaptation."""
86 | # some simple initialization.
87 | with timer("test_time_adaptation"):
88 | with torch.no_grad():
89 | y_hat = self._model(current_batch._x)
90 |
91 | with timer("evaluate_adaptation_result"):
92 | metrics.eval(current_batch._y, y_hat)
93 | if self._meta_conf.base_data_name in ["waterbirds"]:
94 | self.tta_loss_computer.loss(
95 | y_hat, current_batch._y, current_batch._g, is_training=False
96 | )
97 |
98 | @property
99 | def name(self):
100 | return "no_adaptation"
101 |
--------------------------------------------------------------------------------
/ttab/model_selection/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from .last_iterate import LastIterate
4 | from .oracle_model_selection import OracleModelSelection
5 |
6 |
7 | def get_model_selection_method(selection_name):
8 | return {
9 | "last_iterate": LastIterate,
10 | "oracle_model_selection": OracleModelSelection,
11 | }[selection_name]
12 |
--------------------------------------------------------------------------------
/ttab/model_selection/base_selection.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Any, Dict, List
3 |
4 | from ttab.api import Batch
5 |
6 |
7 | class BaseSelection(object):
8 | def __init__(self, meta_conf, model_adaptation_method):
9 | self.meta_conf = meta_conf
10 | self.model = model_adaptation_method.copy_model()
11 | self.model.to(self.meta_conf.device)
12 |
13 | self.initialize()
14 |
15 | def initialize(self):
16 | pass
17 |
18 | def clean_up(self):
19 | pass
20 |
21 | def save_state(self):
22 | pass
23 |
24 | def select_state(
25 | self,
26 | current_batch: Batch,
27 | previous_batches: List[Batch],
28 | ) -> Dict[str, Any]:
29 | pass
30 |
--------------------------------------------------------------------------------
/ttab/model_selection/last_iterate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import copy
3 | from typing import Any, Dict
4 |
5 | from ttab.model_selection.base_selection import BaseSelection
6 |
7 |
8 | class LastIterate(BaseSelection):
9 | """Naively return the model generated from the last iterate of adaptation."""
10 |
11 | def __init__(self, meta_conf, model_adaptation_method):
12 | super().__init__(meta_conf, model_adaptation_method)
13 |
14 | def initialize(self):
15 | if hasattr(self.model, "ssh"):
16 | self.model.ssh.eval()
17 | self.model.main_model.eval()
18 | else:
19 | self.model.eval()
20 |
21 | self.optimal_state = None
22 |
23 | def clean_up(self):
24 | self.optimal_state = None
25 |
26 | def save_state(self, state, current_batch):
27 | self.optimal_state = state
28 |
29 | def select_state(self) -> Dict[str, Any]:
30 | """return the optimal state and sync the model defined in the model selection method."""
31 | return self.optimal_state
32 |
33 | @property
34 | def name(self):
35 | return "last_iterate"
36 |
--------------------------------------------------------------------------------
/ttab/model_selection/metrics.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import torch
3 | from ttab.utils.stat_tracker import RuntimeTracker
4 |
5 | task2metrics = {"classification": ["cross_entropy", "accuracy_top1"]}
6 | auxiliary_metrics_dict = {
7 | "preadapted_cross_entropy": "cross_entropy",
8 | "preadapted_accuracy_top1": "accuracy_top1",
9 | }
10 |
11 |
12 | class Metrics(object):
13 | def __init__(self, scenario) -> None:
14 | self._conf = scenario
15 | self._init_metrics()
16 |
17 | def _init_metrics(self) -> None:
18 | self._metrics = task2metrics[self._conf.task]
19 | self.tracker = RuntimeTracker(metrics_to_track=self._metrics)
20 | self._primary_metrics = self._metrics[0]
21 |
22 | def init_auxiliary_metric(self, metric_name: str):
23 | self._metrics.append(metric_name)
24 | self.tracker.add_stat(metric_name)
25 |
26 | @torch.no_grad()
27 | def eval(self, y: torch.Tensor, y_hat: torch.Tensor) -> None:
28 | results = dict()
29 | for metric_name in self._metrics:
30 |             # auxiliary metrics are evaluated via `eval_auxiliary_metric` instead.
31 |             if metric_name not in auxiliary_metrics_dict:
32 |                 results[metric_name] = eval(metric_name)(y, y_hat)
33 | 
34 | self.tracker.update_metrics(results, n_samples=y.size(0))
35 | return results
36 |
37 | @torch.no_grad()
38 | def eval_auxiliary_metric(
39 | self, y: torch.Tensor, y_hat: torch.Tensor, metric_name: str
40 | ):
41 | assert (
42 | metric_name in self._metrics
43 | ), "The target metric must be in the list of metrics."
44 | results = dict()
45 | results[metric_name] = eval(auxiliary_metrics_dict[metric_name])(y, y_hat)
46 | self.tracker.update_metrics(results, n_samples=y.size(0))
47 | return results
48 |
49 |
50 | """list some common metrics."""
51 |
52 |
53 | def _accuracy(target, output, topk):
54 | """Computes the precision@k for the specified values of k"""
55 | batch_size = target.size(0)
56 |
57 | _, pred = output.topk(topk, 1, True, True)
58 | pred = pred.t()
59 | correct = pred.eq(target.view(1, -1).expand_as(pred))
60 |
61 | correct_k = correct[:topk].reshape(-1).float().sum(0, keepdim=True)
62 | return correct_k.mul_(100.0 / batch_size).item()
63 |
64 |
65 | def accuracy_top1(target, output, topk=1):
66 | """Computes the precision@k for the specified values of k"""
67 | return _accuracy(target, output, topk)
68 |
69 |
70 | def accuracy_top5(target, output, topk=5):
71 | """Computes the precision@k for the specified values of k"""
72 | return _accuracy(target, output, topk)
73 |
74 |
75 | cross_entropy_loss = torch.nn.CrossEntropyLoss()
76 |
77 |
78 | def cross_entropy(target, output):
79 | """Cross entropy loss"""
80 | return cross_entropy_loss(output, target).item()
81 |
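82 | # Usage sketch (illustrative; `scenario` is any object with task="classification"):
83 | #   metrics = Metrics(scenario)
84 | #   metrics.eval(y, y_hat)    # updates the running tracker with batch-level results
85 | #   print(metrics.tracker())  # running averages, e.g. {"cross_entropy": ..., "accuracy_top1": ...}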
--------------------------------------------------------------------------------
/ttab/model_selection/oracle_model_selection.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import copy
3 | from typing import Any, Dict
4 |
5 | import torch
6 | from ttab.model_selection.base_selection import BaseSelection
7 | from ttab.model_selection.metrics import accuracy_top1, cross_entropy
8 |
9 |
10 | class OracleModelSelection(BaseSelection):
11 | """grid-search the best adaptation result per batch (given a sufficiently long adaptation
12 | steps and a single learning rate in each step, save the best checkpoint
13 | and its optimizer states after iterating over all adaptation steps)"""
14 |
15 | def __init__(self, meta_conf, model_adaptation_method):
16 | super().__init__(meta_conf, model_adaptation_method)
17 |
18 | def initialize(self):
19 | if hasattr(self.model, "ssh"):
20 | self.model.ssh.eval()
21 | self.model.main_model.eval()
22 | else:
23 | self.model.eval()
24 |
25 | self.optimal_state = None
26 | self.current_batch_best_acc = 0
27 | self.current_batch_coupled_ent = None
28 |
29 | def clean_up(self):
30 | self.optimal_state = None
31 | self.current_batch_best_acc = 0
32 | self.current_batch_coupled_ent = None
33 |
34 | def save_state(self, state, current_batch):
35 | """Selectively save state for current batch of data."""
36 | batch_best_acc = self.current_batch_best_acc
37 | coupled_ent = self.current_batch_coupled_ent
38 |
39 | if not hasattr(self.model, "ssh"):
40 | self.model.load_state_dict(state["model"])
41 | with torch.no_grad():
42 | outputs = self.model(current_batch._x)
43 | else:
44 | self.model.main_model.load_state_dict(state["main_model"])
45 | with torch.no_grad():
46 | outputs = self.model.main_model(current_batch._x)
47 |
48 | current_acc = self.cal_acc(current_batch._y, outputs)
49 | if (self.optimal_state is None) or (current_acc > batch_best_acc):
50 | self.current_batch_best_acc = current_acc
51 | self.current_batch_coupled_ent = self.cal_ent(current_batch._y, outputs)
52 | state["yhat"] = outputs
53 | self.optimal_state = state
54 | elif current_acc == batch_best_acc:
55 | # compare cross entropy
56 | assert coupled_ent is not None, "Cross entropy value cannot be none."
57 | current_ent = self.cal_ent(current_batch._y, outputs)
58 | if current_ent < coupled_ent:
59 | self.current_batch_coupled_ent = current_ent
60 | state["yhat"] = outputs
61 | self.optimal_state = state
62 |
63 | def cal_acc(self, targets, outputs):
64 | return accuracy_top1(targets, outputs)
65 |
66 | def cal_ent(self, targets, outputs):
67 | return cross_entropy(targets, outputs)
68 |
69 | def select_state(self) -> Dict[str, Any]:
70 | """return the optimal state and sync the model defined in the model selection method."""
71 | if not hasattr(self.model, "ssh"):
72 | self.model.load_state_dict(self.optimal_state["model"])
73 | else:
74 | self.model.main_model.load_state_dict(self.optimal_state["main_model"])
75 | self.model.ssh.load_state_dict(self.optimal_state["ssh"])
76 | return self.optimal_state
77 |
78 | @property
79 | def name(self):
80 | return "oracle_model_selection"
81 |
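82 | # Interplay with adaptation (sketch): after each candidate adaptation step the
83 | # adaptation method calls save_state(state, current_batch); once all steps for a
84 | # batch are done, select_state() reloads and returns the best checkpoint, and
85 | # clean_up() resets the per-batch statistics before the next batch.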
--------------------------------------------------------------------------------
/ttab/scenarios/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import List, NamedTuple, Optional, Union
3 |
4 | from ttab.loads.datasets.dataset_shifts import (
5 | NaturalShiftProperty,
6 | NoShiftProperty,
7 | SyntheticShiftProperty,
8 | )
9 |
10 |
11 | class TestDomain(NamedTuple):
12 | """
13 | The definition of TestDomain: shift property (i.e., P(a^{1:K})) for each domain and its sampling strategy.
14 |
15 |     Each data_name follows the pattern of either 1) <base_data_name>, or 2) <base_data_name>_<shift_name>,
16 |     or 3) <base_data_name>_<shift_name>_<shift_version>, e.g., "cifar10_c_deterministic-gaussian_noise-5".
17 | """
18 |
19 | base_data_name: str
20 | data_name: str
21 | shift_type: str # shift between data_name and base_data_name
22 | shift_property: Union[SyntheticShiftProperty, NaturalShiftProperty, NoShiftProperty]
23 |
24 | domain_sampling_name: str = "uniform" # ['uniform', 'label_skew']
25 |     domain_sampling_value: Optional[float] = None  # hyper-parameter for domain_sampling_name
26 | domain_sampling_ratio: float = 1.0
27 |
28 |
29 | class HomogeneousNoMixture(NamedTuple):
30 | """
31 | Only consider the shift on P(a^{1:K}) in Figure 6 of the paper, but no label shift.
32 | """
33 |
34 | # no mixture
35 | has_mixture: bool = False
36 |
37 |
38 | class HeterogeneousNoMixture(NamedTuple):
39 | """
40 | Only consider the shift on P(a^{1:K}) with label shift.
41 |
42 | We use this setting to evaluate TTA methods under the continual distribution shift setting in Table 4 of the paper.
43 | """
44 |
45 | # no mixture
46 | has_mixture: bool = False
47 | non_iid_pattern: str = "class_wise_over_domain"
48 | non_iid_ness: float = 100
49 |
50 |
51 | class InOutMixture(NamedTuple):
52 | """
53 | Mix the source domain (left) with the target domain (right).
54 | """
55 |
56 |     # mix one in-domain (left) with one out-of-domain (right)
57 | has_mixture: bool = True
58 | ratio: float = 0.5 # for left domain
59 |
60 |
61 | class CrossMixture(NamedTuple):
62 | """
63 |     Mix multiple target domains (right) and shuffle data across domains.
64 | """
65 |
66 | # cross-shuffle test domains.
67 | has_mixture: bool = True
68 |
69 |
70 | class TestCase(NamedTuple):
71 | """
72 | Defines the interaction across domains and some necessary setups in the test-time.
73 | """
74 |
75 | inter_domain: Union[
76 | HomogeneousNoMixture, HeterogeneousNoMixture, InOutMixture, CrossMixture
77 | ]
78 | batch_size: int = 32
79 | data_wise: str = "sample_wise"
80 | offline_pre_adapt: bool = False
81 | episodic: bool = True
82 | intra_domain_shuffle: bool = False
83 |
84 |
85 | class Scenario(NamedTuple):
86 | """
87 |     Defines a distribution shift scenario in practice. More details can be found in Section 4 of the paper.
88 | """
89 |
90 | task: str
91 | model_name: str
92 | model_adaptation_method: str
93 | model_selection_method: str
94 |
95 | base_data_name: str # test dataset (base type).
96 | src_data_name: str # name of source domain
97 | test_domains: List[TestDomain] # a list of domain (will be evaluated in order)
98 | test_case: TestCase
99 |
--------------------------------------------------------------------------------
/ttab/scenarios/default_scenarios.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | from ttab.loads.datasets.dataset_shifts import SyntheticShiftProperty
4 | from ttab.scenarios import HomogeneousNoMixture, Scenario, TestCase, TestDomain
5 |
6 | default_scenarios = {
7 | "S1": Scenario(
8 | task="classification",
9 | model_name="resnet26",
10 | model_adaptation_method="tent",
11 | model_selection_method="last_iterate",
12 | base_data_name="cifar10",
13 | src_data_name="cifar10",
14 | test_domains=[
15 | TestDomain(
16 | base_data_name="cifar10",
17 | data_name="cifar10_c_deterministic-gaussian_noise-5",
18 | shift_type="synthetic",
19 | shift_property=SyntheticShiftProperty(
20 | shift_degree=5,
21 | shift_name="gaussian_noise",
22 | version="deterministic",
23 | has_shift=True,
24 | ),
25 | domain_sampling_name="uniform",
26 | domain_sampling_value=None,
27 | domain_sampling_ratio=1.0,
28 | )
29 | ],
30 | test_case=TestCase(
31 | inter_domain=HomogeneousNoMixture(has_mixture=False),
32 | batch_size=64,
33 | data_wise="batch_wise",
34 | offline_pre_adapt=False,
35 | episodic=False,
36 | intra_domain_shuffle=True,
37 | ),
38 | ),
39 | }
40 |
--------------------------------------------------------------------------------
/ttab/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/utils/__init__.py
--------------------------------------------------------------------------------
/ttab/utils/auxiliary.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import time
4 | import collections
5 |
6 | import contextlib
7 |
8 | import torch
9 |
10 | import ttab.utils.checkpoint as checkpoint
11 |
12 |
13 | class dict2obj(object):
14 | def __init__(self, d):
15 | for a, b in d.items():
16 | if isinstance(b, (list, tuple)):
17 | setattr(self, a, [dict2obj(x) if isinstance(x, dict) else x for x in b])
18 | else:
19 | setattr(self, a, dict2obj(b) if isinstance(b, dict) else b)
20 |
21 |
22 | def flatten_nested_dicts(d, parent_key="", sep="_"):
23 | """Borrowed from
24 | https://stackoverflow.com/a/6027615
25 | """
26 | items = []
27 | for k, v in d.items():
28 | new_key = parent_key + sep + k if parent_key else k
29 |         if isinstance(v, collections.abc.MutableMapping):
30 | items.extend(flatten_nested_dicts(v, new_key, sep=sep).items())
31 | else:
32 | items.append((new_key, v))
33 | return dict(items)
34 |
35 |
36 | @contextlib.contextmanager
37 | def fork_rng_with_seed(seed):
38 | if seed is None:
39 | yield
40 | else:
41 | with torch.random.fork_rng(devices=[]):
42 | torch.manual_seed(seed)
43 | yield
44 |
45 |
46 | @contextlib.contextmanager
47 | def evaluation_monitor(conf):
48 | conf.status = "started"
49 | checkpoint.save_arguments(conf)
50 |
51 | yield
52 |
53 | # update the training status.
54 | job_id = (
55 | conf.job_id if conf.job_id is not None else f"/tmp/tmp_{str(int(time.time()))}"
56 | )
57 | os.system(f"echo {conf.checkpoint_path} >> {job_id}")
58 |
59 | # get updated conf
60 | conf.status = "finished"
61 | checkpoint.save_arguments(conf, force=True)
62 |
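63 | # Usage sketches (illustrative):
64 | #   flatten_nested_dicts({"a": {"b": 1}})  # -> {"a_b": 1}
65 | #   with fork_rng_with_seed(2022):
66 | #       noise = torch.randn(3)  # deterministic; the CPU RNG state is restored on exit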
--------------------------------------------------------------------------------
/ttab/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import time
4 | import json
5 | from typing import Any
6 |
7 | import torch
8 |
9 | import ttab.utils.file_io as file_io
10 |
11 |
12 | def init_checkpoint(conf: Any):
13 | # init checkpoint dir.
14 | conf.checkpoint_path = os.path.join(
15 | conf.root_path,
16 | conf.model_name,
17 | conf.job_name,
18 | # f"{conf.model_name}_{conf.base_data_name}_{conf.model_adaptation_method}_{conf.model_selection_method}_{int(conf.timestamp if conf.timestamp is not None else time.time())}-seed{conf.seed}",
19 | f"{conf.model_name}_{conf.base_data_name}_{conf.model_adaptation_method}_{conf.model_selection_method}_{str(time.time()).replace('.', '_')}-seed{conf.seed}",
20 | )
21 |
22 |     # if the directory does not exist, create it.
23 | file_io.build_dirs(conf.checkpoint_path)
24 | return conf.checkpoint_path
25 |
26 |
27 | def save_arguments(conf: Any, force: bool = False):
28 | # save the configure file to the checkpoint.
29 | path = os.path.join(conf.checkpoint_path, "arguments.json")
30 |
31 | if force or not os.path.exists(path):
32 | with open(path, "w") as fp:
33 | json.dump(
34 | dict(
35 | [
36 | (k, v)
37 | for k, v in conf.__dict__.items()
38 | if file_io.is_jsonable(v) and type(v) is not torch.Tensor
39 | ]
40 | ),
41 | fp,
42 | indent=" ",
43 | )
44 |
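45 | # Usage sketch (illustrative; `conf` needs root_path, model_name, job_name, etc.):
46 | #   init_checkpoint(conf)  # creates a unique, timestamped checkpoint directory
47 | #   save_arguments(conf)   # dumps the JSON-serializable fields of conf to arguments.json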
--------------------------------------------------------------------------------
/ttab/utils/early_stopping.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | class EarlyStoppingTracker(object):
5 |     def __init__(self, patience: int, delta: float = 0.0, mode: str = "max") -> None:
6 | self.patience = patience
7 | self.delta = delta
8 | self.mode = mode
9 | self.best_value = None
10 | self.counter = 0
11 |
12 | def __call__(self, value: float) -> bool:
13 | if self.patience is None or self.patience <= 0:
14 | return False
15 |
16 | if self.best_value is None:
17 | self.best_value = value
18 | self.counter = 0
19 | return False
20 |
21 | if self.mode == "max":
22 | if value > self.best_value + self.delta:
23 | return self._positive_update(value)
24 | else:
25 | return self._negative_update(value)
26 | elif self.mode == "min":
27 | if value < self.best_value - self.delta:
28 | return self._positive_update(value)
29 | else:
30 | return self._negative_update(value)
31 | else:
32 | raise ValueError(f"Illegal mode for early stopping: {self.mode}")
33 |
34 | def _positive_update(self, value: float) -> bool:
35 | self.counter = 0
36 | self.best_value = value
37 | return False
38 |
39 | def _negative_update(self, value: float) -> bool:
40 | self.counter += 1
41 | if self.counter > self.patience:
42 | return True
43 | else:
44 | return False
45 |
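46 | # Usage sketch (illustrative): with mode="max", a stop is signalled once the value
47 | # has failed to improve (by more than `delta`) for more than `patience` calls in a row.
48 | #   tracker = EarlyStoppingTracker(patience=3, mode="max")
49 | #   for accuracy in validation_curve:  # `validation_curve` is an assumed iterable
50 | #       if tracker(accuracy):
51 | #           break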
--------------------------------------------------------------------------------
/ttab/utils/file_io.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import shutil
4 | import json
5 |
6 |
7 | """json related."""
8 |
9 |
10 | def read_json(path):
11 | """read json file from path."""
12 | with open(path, "r") as f:
13 | return json.load(f)
14 |
15 |
16 | def is_jsonable(x):
17 | try:
18 | json.dumps(x)
19 | return True
20 |     except Exception:
21 | return False
22 |
23 |
24 | """operate dir."""
25 |
26 |
27 | def build_dir(path, force):
28 | """build directory."""
29 | if os.path.exists(path) and force:
30 | shutil.rmtree(path)
31 | os.mkdir(path)
32 | elif not os.path.exists(path):
33 | os.mkdir(path)
34 | return path
35 |
36 |
37 | def build_dirs(path):
38 | try:
39 | os.makedirs(path)
40 | except Exception as e:
41 | print(" encounter error: {}".format(e))
42 |
43 |
44 | def remove_folder(path):
45 | try:
46 | shutil.rmtree(path)
47 | except Exception as e:
48 | print(" encounter error: {}".format(e))
49 |
50 |
51 | def list_files(root_path):
52 | dirs = os.listdir(root_path)
53 | return [os.path.join(root_path, path) for path in dirs]
54 |
--------------------------------------------------------------------------------
/ttab/utils/logging.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import json
4 | import time
5 | import pprint
6 | from typing import Any, Dict
7 |
8 | from io import StringIO
9 | import csv
10 |
11 |
12 | class Logger(object):
13 | """
14 | Very simple prototype logger that will store the values to a JSON file
15 | """
16 |
17 | def __init__(self, folder_path: str) -> None:
18 | """
19 |         :param folder_path: directory in which the JSON and txt log files are stored;
20 |             the internal buffer is written to disk on each call to `save_json`.
21 | """
22 | self.folder_path = folder_path
23 | self.json_file_path = os.path.join(folder_path, "log-1.json")
24 | self.txt_file_path = os.path.join(folder_path, "log.txt")
25 | self.values = []
26 | self.pp = MyPrettyPrinter(indent=2, depth=3, compact=True)
27 |
28 | def log_metric(
29 | self,
30 | name: str,
31 | values: Dict[str, Any],
32 | tags: Dict[str, Any],
33 | display: bool = False,
34 | ) -> None:
35 | """
36 | Store a scalar metric
37 |
38 | :param name: measurement, like 'accuracy'
39 | :param values: dictionary, like { epoch: 3, value: 0.23 }
40 | :param tags: dictionary, like { split: train }
41 | """
42 | self.values.append({"measurement": name, **values, **tags})
43 |
44 | if display:
45 | print(
46 | "{name}: {values} ({tags})".format(name=name, values=values, tags=tags)
47 | )
48 |
49 | def pretty_print(self, value: Any) -> None:
50 | self.pp.pprint(value)
51 |
52 | def log(self, value: str, display: bool = True) -> None:
53 | content = time.strftime("%Y-%m-%d %H:%M:%S") + "\t" + value
54 | if display:
55 | print(content)
56 | self.save_txt(content)
57 |
58 | def save_json(self) -> None:
59 | """Save the internal memory to a file."""
60 | with open(self.json_file_path, "w") as fp:
61 | json.dump(self.values, fp, indent=" ")
62 |
63 | if len(self.values) > 1e4:
64 | # reset 'values' and redirect the json file to a different path.
65 | self.values = []
66 | self.redirect_new_json()
67 |
68 | def save_txt(self, value: str) -> None:
69 | with open(self.txt_file_path, "a") as f:
70 | f.write(value + "\n")
71 |
72 | def redirect_new_json(self) -> None:
73 | """get the number of existing json files under the current folder."""
74 | existing_json_files = [
75 | file for file in os.listdir(self.folder_path) if "json" in file
76 | ]
77 | self.json_file_path = os.path.join(
78 | self.folder_path, "log-{}.json".format(len(existing_json_files) + 1)
79 | )
80 |
81 |
82 | class MyPrettyPrinter(pprint.PrettyPrinter):
83 | """Borrowed from
84 | https://stackoverflow.com/questions/30062384/pretty-print-namedtuple
85 | """
86 |
87 | def format_namedtuple(self, object, stream, indent, allowance, context, level):
88 | # Code almost equal to _format_dict, see pprint code
89 | write = stream.write
90 | write(object.__class__.__name__ + "(")
91 | object_dict = object._asdict()
92 | length = len(object_dict)
93 | if length:
94 | # We first try to print inline, and if it is too large then we print it on multiple lines
95 | inline_stream = StringIO()
96 | self.format_namedtuple_items(
97 | object_dict.items(),
98 | inline_stream,
99 | indent,
100 | allowance + 1,
101 | context,
102 | level,
103 | inline=True,
104 | )
105 | max_width = self._width - indent - allowance
106 | if len(inline_stream.getvalue()) > max_width:
107 | self.format_namedtuple_items(
108 | object_dict.items(),
109 | stream,
110 | indent,
111 | allowance + 1,
112 | context,
113 | level,
114 | inline=False,
115 | )
116 | else:
117 | stream.write(inline_stream.getvalue())
118 | write(")")
119 |
120 | def format_namedtuple_items(
121 | self, items, stream, indent, allowance, context, level, inline=False
122 | ):
123 | # Code almost equal to _format_dict_items, see pprint code
124 | indent += self._indent_per_level
125 | write = stream.write
126 | last_index = len(items) - 1
127 | if inline:
128 | delimnl = ", "
129 | else:
130 | delimnl = ",\n" + " " * indent
131 | write("\n" + " " * indent)
132 | for i, (key, ent) in enumerate(items):
133 | last = i == last_index
134 | write(key + "=")
135 | self._format(
136 | ent,
137 | stream,
138 | indent + len(key) + 2,
139 | allowance if last else 1,
140 | context,
141 | level,
142 | )
143 | if not last:
144 | write(delimnl)
145 |
146 | def _format(self, object, stream, indent, allowance, context, level):
147 | # We dynamically add the types of our namedtuple and namedtuple like
148 | # classes to the _dispatch object of pprint that maps classes to
149 | # formatting methods
150 |         # We use a simple criterion (the _asdict method) that allows us to use the
151 | # same formatting on other classes but a more precise one is possible
152 | if hasattr(object, "_asdict") and type(object).__repr__ not in self._dispatch:
153 | self._dispatch[type(object).__repr__] = MyPrettyPrinter.format_namedtuple
154 | super()._format(object, stream, indent, allowance, context, level)
155 |
156 |
157 | class CSVBatchLogger:
158 | """Borrowed from https://github.com/kohpangwei/group_DRO/blob/master/utils.py#L39"""
159 |
160 | def __init__(self, csv_path, n_groups, mode="w"):
161 | columns = ["epoch", "batch"]
162 | for idx in range(n_groups):
163 | columns.append(f"avg_loss_group:{idx}")
164 | columns.append(f"exp_avg_loss_group:{idx}")
165 | columns.append(f"avg_acc_group:{idx}")
166 | columns.append(f"processed_data_count_group:{idx}")
167 | columns.append(f"update_data_count_group:{idx}")
168 | columns.append(f"update_batch_count_group:{idx}")
169 | columns.append("avg_actual_loss")
170 | columns.append("avg_per_sample_loss")
171 | columns.append("avg_acc")
172 | columns.append("model_norm_sq")
173 | columns.append("reg_loss")
174 |
175 | self.path = csv_path
176 | self.file = open(csv_path, mode)
177 | self.columns = columns
178 | self.writer = csv.DictWriter(self.file, fieldnames=columns)
179 | if mode == "w":
180 | self.writer.writeheader()
181 |
182 | def log(self, epoch, batch, stats_dict):
183 | stats_dict["epoch"] = epoch
184 | stats_dict["batch"] = batch
185 | self.writer.writerow(stats_dict)
186 |
187 | def flush(self):
188 | self.file.flush()
189 |
190 | def close(self):
191 | self.file.close()
192 |
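193 | # Usage sketch for Logger (illustrative; `folder` is an assumed existing directory):
194 | #   logger = Logger(folder)
195 | #   logger.log("evaluation started")  # timestamped line appended to log.txt
196 | #   logger.log_metric("accuracy", {"step": 1, "value": 0.92}, {"split": "test"})
197 | #   logger.save_json()  # flush buffered metrics to log-1.json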
--------------------------------------------------------------------------------
/ttab/utils/mathdict.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Any, Dict, List, Union, Callable
3 | from collections.abc import ItemsView, ValuesView
4 |
5 | C = Union[float, int]
6 |
7 |
8 | class MathDict:
9 | def __init__(self, dictionary: Dict[Any, Union[int, float]]) -> None:
10 | self.dictionary = dictionary
11 | self.keys = set(dictionary.keys())
12 |
13 | def __str__(self) -> str:
14 | return f"MathDict({self.dictionary})"
15 |
16 | def __repr__(self) -> str:
17 | return f"MathDict({repr(self.dictionary)})"
18 |
19 |     def map(self: "MathDict", mapfun: Callable[[C], C]) -> "MathDict":
20 | new_dict = {}
21 | for key in self.keys:
22 | new_dict[key] = mapfun(self.dictionary[key])
23 | return MathDict(new_dict)
24 |
25 |     def filter(self: "MathDict", condfun: Callable[[Any], bool]) -> "MathDict":
26 | new_dict = {}
27 | for key in self.keys:
28 | if condfun(key):
29 | new_dict[key] = self.dictionary[key]
30 | return MathDict(new_dict)
31 |
32 | def detach(self) -> None:
33 | for key in self.keys:
34 | self.dictionary[key] = self.dictionary[key].detach()
35 |
36 | def values(self) -> ValuesView:
37 | return self.dictionary.values()
38 |
39 | def items(self) -> ItemsView:
40 | return self.dictionary.items()
41 |
42 |
43 | def _mathdict_binary_op(operation: Callable[[C, C], C]) -> Callable:
44 | def op(self: MathDict, other: Union[MathDict, Dict]) -> MathDict:
45 | new_dict = {}
46 | if isinstance(other, MathDict):
47 | assert other.keys == self.keys
48 | for key in self.keys:
49 | new_dict[key] = operation(self.dictionary[key], other.dictionary[key])
50 | else:
51 | for key in self.keys:
52 | new_dict[key] = operation(self.dictionary[key], other)
53 | return MathDict(new_dict)
54 |
55 | return op
56 |
57 |
58 | def _mathdict_map_op(
59 | operation: Callable[[ValuesView, List[Any], List[Any]], Any]
60 | ) -> Callable:
61 | def op(self: MathDict, *args, **kwargs) -> MathDict:
62 | new_dict = {}
63 | for key in self.keys:
64 | new_dict[key] = operation(self.dictionary[key], args, kwargs)
65 | return MathDict(new_dict)
66 |
67 | return op
68 |
69 |
70 | def _mathdict_binary_in_place_op(operation: Callable[[Dict, Any, C], None]) -> Callable:
71 | def op(self: MathDict, other: Union[MathDict, Dict]) -> MathDict:
72 | if isinstance(other, MathDict):
73 | assert other.keys == self.keys
74 | for key in self.keys:
75 | operation(self.dictionary, key, other.dictionary[key])
76 | else:
77 | for key in self.keys:
78 | operation(self.dictionary, key, other)
79 | return self
80 |
81 | return op
82 |
83 |
84 | def _iadd(dict: Dict, key: Any, b: C) -> None:
85 | dict[key] += b
86 |
87 |
88 | def _isub(dict: Dict, key: Any, b: C) -> None:
89 | dict[key] -= b
90 |
91 |
92 | def _imul(dict: Dict, key: Any, b: C) -> None:
93 | dict[key] *= b
94 |
95 |
96 | def _itruediv(dict: Dict, key: Any, b: C) -> None:
97 | dict[key] /= b
98 |
99 |
100 | def _ifloordiv(dict: Dict, key: Any, b: C) -> None:
101 | dict[key] //= b
102 |
103 |
104 | MathDict.__add__ = _mathdict_binary_op(lambda a, b: a + b)
105 | MathDict.__sub__ = _mathdict_binary_op(lambda a, b: a - b)
106 | MathDict.__rsub__ = _mathdict_binary_op(lambda a, b: b - a)
107 | MathDict.__mul__ = _mathdict_binary_op(lambda a, b: a * b)
108 | MathDict.__rmul__ = _mathdict_binary_op(lambda a, b: a * b)
109 | MathDict.__truediv__ = _mathdict_binary_op(lambda a, b: a / b)
110 | MathDict.__floordiv__ = _mathdict_binary_op(lambda a, b: a // b)
111 | MathDict.__getitem__ = _mathdict_map_op(
112 | lambda x, args, kwargs: x.__getitem__(*args, **kwargs)
113 | )
114 | MathDict.__iadd__ = _mathdict_binary_in_place_op(_iadd)
115 | MathDict.__isub__ = _mathdict_binary_in_place_op(_isub)
116 | MathDict.__imul__ = _mathdict_binary_in_place_op(_imul)
117 | MathDict.__itruediv__ = _mathdict_binary_in_place_op(_itruediv)
118 | MathDict.__ifloordiv__ = _mathdict_binary_in_place_op(_ifloordiv)
119 |
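120 | # Usage sketch (illustrative): element-wise arithmetic over the dict values.
121 | #   d = MathDict({"a": 1.0, "b": 2.0})
122 | #   d * 2 + 1                            # MathDict({"a": 3.0, "b": 5.0})
123 | #   d += MathDict({"a": 1.0, "b": 1.0})  # in-place; both operands must share the same keys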
--------------------------------------------------------------------------------
/ttab/utils/stat_tracker.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import copy
3 |
4 |
5 | class MaxMeter(object):
6 | """
7 | Keeps track of the max of all the values that are 'add'ed
8 | """
9 |
10 | def __init__(self):
11 | self.max = None
12 |
13 | def update(self, value):
14 | """
15 | Add a value to the accumulator.
16 | :return: `true` if the provided value became the new max
17 | """
18 | if self.max is None or value > self.max:
19 | self.max = copy.deepcopy(value)
20 | return True
21 | else:
22 | return False
23 |
24 | def value(self):
25 | """Access the current running average"""
26 | return self.max
27 |
28 |
29 | class MinMeter(object):
30 | """
31 |     Keeps track of the min of all the values that are 'add'ed
32 | """
33 |
34 | def __init__(self):
35 | self.min = None
36 |
37 | def update(self, value):
38 | """
39 | Add a value to the accumulator.
40 |         :return: `true` if the provided value became the new min
41 | """
42 | if self.min is None or value < self.min:
43 | self.min = copy.deepcopy(value)
44 | return True
45 | else:
46 | return False
47 |
48 | def value(self):
49 | """Access the current running average"""
50 | return self.min
51 |
52 |
53 | class AverageMeter(object):
54 | """Computes and stores the average and current value"""
55 |
56 | def __init__(self):
57 | self.reset()
58 |
59 | def reset(self):
60 | self.val = 0
61 | self.avg = 0
62 | self.sum = 0
63 | self.max = -float("inf")
64 | self.min = float("inf")
65 | self.count = 0
66 |
67 | def update(self, val, n=1):
68 | self.val = val
69 | self.sum += val * n
70 | self.count += n
71 | self.avg = self.sum / self.count
72 | self.max = val if val > self.max else self.max
73 | self.min = val if val < self.min else self.min
74 |
75 |
76 | class RuntimeTracker(object):
77 | """Tracking the runtime stat for local training."""
78 |
79 | def __init__(self, metrics_to_track):
80 | self.metrics_to_track = metrics_to_track
81 | self.reset()
82 |
83 | def reset(self):
84 | self.stat = dict((name, AverageMeter()) for name in self.metrics_to_track)
85 |
86 | def add_stat(self, metric_name: str):
87 | self.stat[metric_name] = AverageMeter()
88 |
89 | def get_metrics_performance(self):
90 | return [self.stat[metric].avg for metric in self.metrics_to_track]
91 |
92 | def update_metrics(self, metric_stat, n_samples):
93 | for name, value in metric_stat.items():
94 | self.stat[name].update(value, n_samples)
95 |
96 | def __call__(self):
97 | return dict((name, val.avg) for name, val in self.stat.items())
98 |
99 | def get_current_val(self):
100 | return dict((name, val.val) for name, val in self.stat.items())
101 |
102 | def get_val_by_name(self, metric_name):
103 | return dict([(metric_name, self.stat[metric_name].val)])
104 |
105 |
106 | class BestPerf(object):
107 | def __init__(self, best_perf=None, larger_is_better=True):
108 | self.best_perf = best_perf
109 | self.cur_perf = None
110 | self.best_perf_locs = []
111 | self.larger_is_better = larger_is_better
112 |
113 | # define meter
114 | self._define_meter()
115 |
116 | def _define_meter(self):
117 | self.meter = MaxMeter() if self.larger_is_better else MinMeter()
118 |
119 | def update(self, perf, perf_location):
120 | self.is_best = self.meter.update(perf)
121 | self.cur_perf = perf
122 |
123 | if self.is_best:
124 | self.best_perf = perf
125 | self.best_perf_locs += [perf_location]
126 |
127 | def get_best_perf_loc(self):
128 | return self.best_perf_locs[-1] if len(self.best_perf_locs) != 0 else None
129 |
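130 | # Usage sketch for RuntimeTracker (illustrative):
131 | #   tracker = RuntimeTracker(metrics_to_track=["cross_entropy", "accuracy_top1"])
132 | #   tracker.update_metrics({"cross_entropy": 0.7, "accuracy_top1": 75.0}, n_samples=64)
133 | #   tracker()  # -> sample-weighted running averages, keyed by metric name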
--------------------------------------------------------------------------------
/ttab/utils/tensor_buffer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import List, Optional, Tuple
3 |
4 | import torch
5 |
6 |
7 | class TensorBuffer(object):
8 | """Packs multiple tensors into one flat buffer."""
9 |
10 | def __init__(self, tensors: List[torch.Tensor], use_cuda: bool = True) -> None:
11 | indices: List[int] = [0]
12 | for tensor in tensors:
13 | new_end = indices[-1] + tensor.nelement()
14 | indices.append(new_end)
15 |
16 | self._start_idx = indices[:-1]
17 | self._end_idx = indices[1:]
18 | self._tensors_len = len(tensors)
19 | self._tensors_sizes = [x.size() for x in tensors]
20 |
21 | self.buffer = flatten(tensors, use_cuda=use_cuda) # copies
22 |
23 | def __getitem__(self, index: int) -> torch.Tensor:
24 | return self.buffer[self._start_idx[index] : self._end_idx[index]].view(
25 | self._tensors_sizes[index]
26 | )
27 |
28 | def __len__(self) -> int:
29 | return self._tensors_len
30 |
31 | def is_cuda(self) -> bool:
32 | return self.buffer.is_cuda
33 |
34 | def nelement(self) -> int:
35 | return self.buffer.nelement()
36 |
37 | def unpack(self, tensors: List[torch.Tensor]) -> None:
38 | for tensor, entry in zip(tensors, self):
39 | tensor.data[:] = entry
40 |
41 |
42 | def flatten(
43 | tensors: List[torch.Tensor],
44 |     shapes: Optional[List[Tuple[int, int]]] = None,
45 | use_cuda: bool = True,
46 | ):
47 | # init and recover the shapes vec.
48 | pointers: List[int] = [0]
49 | if shapes is not None:
50 | for shape in shapes:
51 | pointers.append(pointers[-1] + shape[1])
52 | else:
53 | for tensor in tensors:
54 | pointers.append(pointers[-1] + tensor.nelement())
55 |
56 | # flattening.
57 | current_device = tensors[0].device
58 | target_device = tensors[0].device if tensors[0].is_cuda and use_cuda else "cpu"
59 | vec = torch.empty(pointers[-1], device=target_device)
60 |
61 | for tensor, start_idx, end_idx in zip(tensors, pointers[:-1], pointers[1:]):
62 | vec[start_idx:end_idx] = (
63 | tensor.data.view(-1).to(device=target_device)
64 | if current_device != target_device
65 | else tensor.data.view(-1)
66 | )
67 | return vec
68 |
69 |
70 | def unflatten(
71 | self_tensors: List[torch.Tensor],
72 | out_tensors: List[torch.Tensor],
73 |     shapes: List[Tuple[int, int]],
74 | ):
75 | pointer: int = 0
76 |
77 | for self_tensor, shape in zip(self_tensors, shapes):
78 | param_size, nelement = shape
79 | self_tensor.data[:] = out_tensors[pointer : pointer + nelement].view(param_size)
80 | pointer += nelement
81 |
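82 | # Usage sketch (illustrative; `model` is an assumed torch.nn.Module):
83 | #   params = [p.data for p in model.parameters()]
84 | #   buf = TensorBuffer(params, use_cuda=False)  # one flat copy of all parameters
85 | #   buf.buffer.mul_(0.5)  # operate on the flat view
86 | #   buf.unpack(params)    # write the modified values back into the original tensors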
--------------------------------------------------------------------------------
/ttab/utils/timer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import time
3 | from typing import Any, Dict
4 | from io import StringIO
5 | from contextlib import contextmanager
6 |
7 | import numpy as np
8 | import torch
9 |
10 |
11 | class Timer(object):
12 | """
13 | Timer for PyTorch code
14 | Comes in the form of a contextmanager:
15 |
16 | Example:
17 |     >>> timer = Timer(device="cpu", on_cuda=False)
18 | ... for i in range(10):
19 | ... with timer("expensive operation"):
20 | ... x = torch.randn(100)
21 | ... print(timer.summary())
22 | """
23 |
24 | def __init__(
25 | self,
26 | device: str,
27 | verbosity_level: int = 1,
28 | log_fn=None,
29 | skip_first: bool = True,
30 | on_cuda: bool = True,
31 | ) -> None:
32 | self.device = device
33 | self.verbosity_level = verbosity_level
34 | self.log_fn = log_fn if log_fn is not None else self._default_log_fn
35 | self.skip_first = skip_first
36 | self.cuda_available = torch.cuda.is_available() and on_cuda
37 |
38 | self.reset()
39 |
40 | def reset(self) -> None:
41 | """Reset the timer"""
42 | self.totals = {} # Total time per label
43 | self.first_time = {} # First occurrence of a label (start time)
44 |         self.last_time = {}  # Last occurrence of a label (end time)
45 | self.call_counts = {} # Number of times a label occurred
46 |
47 | @contextmanager
48 | def __call__(
49 |         self, label: str, step: int = -1, epoch: float = -1.0, verbosity: int = 1
50 | ) -> None:
51 | # Don't measure this if the verbosity level is too high
52 | if verbosity > self.verbosity_level:
53 | yield
54 | return
55 |
56 | # Measure the time
57 | self._cuda_sync()
58 | start = time.time()
59 | yield
60 | self._cuda_sync()
61 | end = time.time()
62 |
63 | # Update first and last occurrence of this label
64 | if label not in self.first_time:
65 | self.first_time[label] = start
66 | self.last_time[label] = end
67 |
68 | # Update the totals and call counts
69 | if label not in self.totals and self.skip_first:
70 | self.totals[label] = 0.0
71 | del self.first_time[label]
72 | self.call_counts[label] = 0
73 | elif label not in self.totals and not self.skip_first:
74 | self.totals[label] = end - start
75 | self.call_counts[label] = 1
76 | else:
77 | self.totals[label] += end - start
78 | self.call_counts[label] += 1
79 |
80 | if self.call_counts[label] > 0:
81 | # We will reduce the probability of logging a timing
82 | # linearly with the number of time we have seen it.
83 | # It will always be recorded in the totals, though.
84 | if np.random.rand() < 1 / self.call_counts[label]:
85 | self.log_fn(
86 | "timer",
87 | {"step": step, "epoch": epoch, "value": end - start},
88 | {"event": label},
89 | )
90 |
91 |     def summary(self) -> str:
92 | """
93 | Return a summary in string-form of all the timings recorded so far
94 | """
95 | if len(self.totals) > 0:
96 | with StringIO() as buffer:
97 | total_avg_time = 0
98 | print("--- Timer summary ------------------------", file=buffer)
99 | print(" Event | Count | Average time | Frac.", file=buffer)
100 | for event_label in sorted(self.totals):
101 | total = self.totals[event_label]
102 | count = self.call_counts[event_label]
103 | if count == 0:
104 | continue
105 | avg_duration = total / count
106 | total_runtime = (
107 | self.last_time[event_label] - self.first_time[event_label]
108 | )
109 | runtime_percentage = 100 * total / total_runtime
110 | total_avg_time += avg_duration if "." not in event_label else 0
111 | print(
112 | f"- {event_label:30s} | {count:6d} | {avg_duration:11.5f}s | {runtime_percentage:5.1f}%",
113 | file=buffer,
114 | )
115 | print("-------------------------------------------", file=buffer)
116 | event_label = "total_averaged_time"
117 | print(
118 | f"- {event_label:30s}| {count:6d} | {total_avg_time:11.5f}s |",
119 | file=buffer,
120 | )
121 | print("-------------------------------------------", file=buffer)
122 | return buffer.getvalue()
123 |
124 | def _cuda_sync(self) -> None:
125 | """Finish all asynchronous GPU computations to get correct timings"""
126 | if self.cuda_available:
127 | torch.cuda.synchronize(device=self.device)
128 |
129 | def _default_log_fn(self, _: Any, values: Dict, tags: Dict) -> None:
130 |         label = tags["event"]
131 | epoch = values["epoch"]
132 | duration = values["value"]
133 | print(f"Timer: {label:30s} @ {epoch:4.1f} - {duration:8.5f}s")
134 |
--------------------------------------------------------------------------------