├── .gitignore ├── LICENSE ├── README.md ├── exps ├── cifar100c.py ├── cifar10_1.py ├── cifar10c.py ├── imagenet-c.py ├── officehome.py ├── pacs.py └── yearbook.py ├── figs └── overview of datasets.jpg ├── monitor ├── __init__.py ├── tmux_cluster │ ├── __init__.py │ ├── tmux.py │ └── utils.py └── tools │ ├── __init__.py │ ├── file_io.py │ ├── plot.py │ ├── plot_utils.py │ ├── show_results.py │ └── utils.py ├── notebooks └── example.ipynb ├── parameters.py ├── pretrain ├── ssl_pretrain.py └── third_party │ ├── __init__.py │ ├── augmentations.py │ ├── datasets.py │ └── utils.py ├── run_exp.py ├── run_exps.py ├── run_extract.py └── ttab ├── __init__.py ├── api.py ├── benchmark.py ├── configs ├── __init__.py ├── algorithms.py ├── datasets.py └── utils.py ├── loads ├── .DS_Store ├── __init__.py ├── datasets │ ├── __init__.py │ ├── cifar │ │ ├── __init__.py │ │ ├── data_aug_cifar.py │ │ └── synthetic.py │ ├── dataset_sampling.py │ ├── dataset_shifts.py │ ├── datasets.py │ ├── imagenet │ │ ├── __init__.py │ │ ├── data_aug_imagenet.py │ │ ├── synthetic_224.py │ │ └── synthetic_64.py │ ├── loaders.py │ ├── mnist │ │ ├── __init__.py │ │ └── data_aug_mnist.py │ ├── utils │ │ ├── __init__.py │ │ ├── c_resource │ │ │ ├── frost1.png │ │ │ ├── frost2.png │ │ │ ├── frost3.png │ │ │ ├── frost4.jpg │ │ │ ├── frost5.jpg │ │ │ └── frost6.jpg │ │ ├── lmdb.py │ │ ├── preprocess_toolkit.py │ │ └── serialize.py │ └── yearbook │ │ └── data_aug_yearbook.py ├── define_dataset.py ├── define_model.py └── models │ ├── __init__.py │ ├── cct.py │ ├── resnet.py │ ├── utils │ ├── __init__.py │ ├── embedder.py │ ├── helpers.py │ ├── stochastic_depth.py │ ├── tokenizer.py │ └── transformers.py │ └── wideresnet.py ├── model_adaptation ├── __init__.py ├── base_adaptation.py ├── bn_adapt.py ├── conjugate_pl.py ├── cotta.py ├── eata.py ├── memo.py ├── no_adaptation.py ├── note.py ├── rotta.py ├── sar.py ├── shot.py ├── t3a.py ├── tent.py ├── ttt.py ├── ttt_plus_plus.py └── utils.py ├── model_selection ├── __init__.py ├── base_selection.py ├── group_metrics.py ├── last_iterate.py ├── metrics.py └── oracle_model_selection.py ├── scenarios ├── __init__.py ├── default_scenarios.py └── define_scenario.py └── utils ├── __init__.py ├── auxiliary.py ├── checkpoint.py ├── early_stopping.py ├── file_io.py ├── logging.py ├── mathdict.py ├── stat_tracker.py ├── tensor_buffer.py └── timer.py /.gitignore: -------------------------------------------------------------------------------- 1 | ### vscode ### 2 | .vscode/ 3 | !.vscode/settings.json 4 | *.DS_Store 5 | 6 | ### Python ### 7 | __pycache__/ 8 | 9 | ### generated files ### 10 | ckpt/ 11 | data/ 12 | logs/ 13 | # pretrain/ 14 | runs/ 15 | todo/ 16 | *.out 17 | *.pth 18 | *.sh 19 | *.ipynb 20 | *.txt 21 | 22 | datasets/ 23 | -------------------------------------------------------------------------------- /exps/cifar100c.py: -------------------------------------------------------------------------------- 1 | class NewConf(object): 2 | # create the list of hyper-parameters to be replaced. 3 | to_be_replaced = dict( 4 | # general for world. 
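        # Sketch of the assumed sweep semantics (inferred from the presence of
        # run_exps.py and the `coupled` key in exps/pacs.py; not verified):
        # every key below maps to a list of candidate values, the launcher
        # expands the Cartesian product of these lists into individual runs,
        # and keys named in a `coupled` entry are iterated in lockstep instead
        # (e.g. the i-th data_names paired with the i-th ckpt_path).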
5 | seed=[2022, 2023, 2024], 6 | main_file=[ 7 | "run_exp.py", 8 | ], 9 | job_name=[ 10 | "cifar100c_episodic_oracle_model_selection", 11 | # "cifar100c_online_last_iterate", 12 | ], 13 | base_data_name=[ 14 | "cifar100", 15 | ], 16 | data_names=[ 17 | "cifar100_c_deterministic-gaussian_noise-5", 18 | "cifar100_c_deterministic-shot_noise-5", 19 | "cifar100_c_deterministic-impulse_noise-5", 20 | "cifar100_c_deterministic-defocus_blur-5", 21 | "cifar100_c_deterministic-glass_blur-5", 22 | "cifar100_c_deterministic-motion_blur-5", 23 | "cifar100_c_deterministic-zoom_blur-5", 24 | "cifar100_c_deterministic-snow-5", 25 | "cifar100_c_deterministic-frost-5", 26 | "cifar100_c_deterministic-fog-5", 27 | "cifar100_c_deterministic-brightness-5", 28 | "cifar100_c_deterministic-contrast-5", 29 | "cifar100_c_deterministic-elastic_transform-5", 30 | "cifar100_c_deterministic-pixelate-5", 31 | "cifar100_c_deterministic-jpeg_compression-5", 32 | ], 33 | model_name=[ 34 | "resnet26", 35 | ], 36 | model_adaptation_method=[ 37 | # "no_adaptation", 38 | "tent", 39 | # "bn_adapt", 40 | # "t3a", 41 | # "memo", 42 | # "shot", 43 | # "ttt", 44 | # "note", 45 | # "sar", 46 | # "conjugate_pl", 47 | # "cotta", 48 | # "eata", 49 | ], 50 | model_selection_method=[ 51 | "oracle_model_selection", 52 | # "last_iterate", 53 | ], 54 | offline_pre_adapt=[ 55 | "false", 56 | ], 57 | data_wise=["batch_wise"], 58 | batch_size=[64], 59 | episodic=[ 60 | # "false", 61 | "true", 62 | ], 63 | inter_domain=["HomogeneousNoMixture"], 64 | non_iid_ness=[0.1], 65 | non_iid_pattern=["class_wise_over_domain"], 66 | python_path=["/opt/conda/bin/python"], 67 | data_path=["/run/determined/workdir/data/"], 68 | ckpt_path=[ 69 | "./data/pretrained_ckpts/classification/resnet26_with_head/cifar100/rn26_bn.pth", 70 | ], 71 | # oracle_model_selection 72 | lr_grid=[ 73 | [1e-3], 74 | [5e-4], 75 | [1e-4], 76 | ], 77 | n_train_steps=[50], 78 | # last_iterate 79 | # lr=[ 80 | # 5e-3, 81 | # 1e-3, 82 | # 5e-4, 83 | # ], 84 | # n_train_steps=[ 85 | # 1, 86 | # 2, 87 | # 3, 88 | # ], 89 | intra_domain_shuffle=["true"], 90 | record_preadapted_perf=["true"], 91 | device=[ 92 | "cuda:0", 93 | ], 94 | gradient_checkpoint=["false"], 95 | ) 96 | -------------------------------------------------------------------------------- /exps/cifar10_1.py: -------------------------------------------------------------------------------- 1 | class NewConf(object): 2 | # create the list of hyper-parameters to be replaced. 3 | to_be_replaced = dict( 4 | # general for world. 
5 | seed=[2022, 2023, 2024], 6 | main_file=[ 7 | "run_exp.py", 8 | ], 9 | job_name=[ 10 | "cifar10_1_episodic_oracle_model_selection", 11 | # "cifar10_1_online_last_iterate", 12 | ], 13 | base_data_name=[ 14 | "cifar10", 15 | ], 16 | data_names=[ 17 | "cifar10_1", 18 | ], 19 | model_name=[ 20 | "resnet26", 21 | ], 22 | model_adaptation_method=[ 23 | # "no_adaptation", 24 | "tent", 25 | # "bn_adapt", 26 | # "t3a", 27 | # "memo", 28 | # "shot", 29 | # "ttt", 30 | # "note", 31 | # "sar", 32 | # "conjugate_pl", 33 | # "cotta", 34 | # "eata", 35 | ], 36 | model_selection_method=[ 37 | "oracle_model_selection", 38 | # "last_iterate", 39 | ], 40 | offline_pre_adapt=[ 41 | "false", 42 | ], 43 | data_wise=["batch_wise"], 44 | batch_size=[64], 45 | episodic=[ 46 | # "false", 47 | "true", 48 | ], 49 | inter_domain=["HomogeneousNoMixture"], 50 | non_iid_ness=[0.1], 51 | non_iid_pattern=["class_wise_over_domain"], 52 | python_path=["/opt/conda/bin/python"], 53 | data_path=["/run/determined/workdir/data/"], 54 | ckpt_path=[ 55 | "./data/pretrained_ckpts/classification/resnet26_with_head/cifar10/rn26_bn.pth", 56 | ], 57 | # oracle_model_selection 58 | lr_grid=[ 59 | [1e-3], 60 | [5e-4], 61 | [1e-4], 62 | ], 63 | n_train_steps=[50], 64 | # last_iterate 65 | # lr=[ 66 | # 5e-3, 67 | # 1e-3, 68 | # 5e-4, 69 | # ], 70 | # n_train_steps=[ 71 | # 1, 72 | # 2, 73 | # 3, 74 | # ], 75 | intra_domain_shuffle=["true"], 76 | record_preadapted_perf=["true"], 77 | device=[ 78 | "cuda:0", 79 | ], 80 | gradient_checkpoint=["false"], 81 | ) 82 | -------------------------------------------------------------------------------- /exps/cifar10c.py: -------------------------------------------------------------------------------- 1 | class NewConf(object): 2 | # create the list of hyper-parameters to be replaced. 3 | to_be_replaced = dict( 4 | # general for world. 
5 | seed=[2022, 2023, 2024], 6 | main_file=[ 7 | "run_exp.py", 8 | ], 9 | job_name=[ 10 | "cifar10c_episodic_oracle_model_selection", 11 | # "cifar10c_online_last_iterate", 12 | ], 13 | base_data_name=[ 14 | "cifar10", 15 | ], 16 | data_names=[ 17 | "cifar10_c_deterministic-gaussian_noise-5", 18 | "cifar10_c_deterministic-shot_noise-5", 19 | "cifar10_c_deterministic-impulse_noise-5", 20 | "cifar10_c_deterministic-defocus_blur-5", 21 | "cifar10_c_deterministic-glass_blur-5", 22 | "cifar10_c_deterministic-motion_blur-5", 23 | "cifar10_c_deterministic-zoom_blur-5", 24 | "cifar10_c_deterministic-snow-5", 25 | "cifar10_c_deterministic-frost-5", 26 | "cifar10_c_deterministic-fog-5", 27 | "cifar10_c_deterministic-brightness-5", 28 | "cifar10_c_deterministic-contrast-5", 29 | "cifar10_c_deterministic-elastic_transform-5", 30 | "cifar10_c_deterministic-pixelate-5", 31 | "cifar10_c_deterministic-jpeg_compression-5", 32 | ], 33 | model_name=[ 34 | "resnet26", 35 | ], 36 | model_adaptation_method=[ 37 | # "no_adaptation", 38 | "tent", 39 | # "bn_adapt", 40 | # "t3a", 41 | # "memo", 42 | # "shot", 43 | # "ttt", 44 | # "note", 45 | # "sar", 46 | # "conjugate_pl", 47 | # "cotta", 48 | # "eata", 49 | ], 50 | model_selection_method=[ 51 | "oracle_model_selection", 52 | # "last_iterate", 53 | ], 54 | offline_pre_adapt=[ 55 | "false", 56 | ], 57 | data_wise=["batch_wise"], 58 | batch_size=[64], 59 | episodic=[ 60 | # "false", 61 | "true", 62 | ], 63 | inter_domain=["HomogeneousNoMixture"], 64 | non_iid_ness=[0.1], 65 | non_iid_pattern=["class_wise_over_domain"], 66 | python_path=["/home/ttab/anaconda3/envs/test_algo/bin/python"], 67 | data_path=["./datasets/"], 68 | ckpt_path=[ 69 | "./data/pretrained_ckpts/classification/resnet26_with_head/cifar10/rn26_bn.pth", 70 | ], 71 | # oracle_model_selection 72 | lr_grid=[ 73 | [1e-3], 74 | [5e-4], 75 | [1e-4], 76 | ], 77 | n_train_steps=[10], 78 | # last_iterate 79 | # lr=[ 80 | # 5e-3, 81 | # 1e-3, 82 | # 5e-4, 83 | # ], 84 | # n_train_steps=[ 85 | # 1, 86 | # 2, 87 | # 3, 88 | # ], 89 | intra_domain_shuffle=["true"], 90 | record_preadapted_perf=["true"], 91 | device=[ 92 | "cuda:0", 93 | ], 94 | gradient_checkpoint=["false"], 95 | ) 96 | -------------------------------------------------------------------------------- /exps/imagenet-c.py: -------------------------------------------------------------------------------- 1 | class NewConf(object): 2 | # create the list of hyper-parameters to be replaced. 3 | to_be_replaced = dict( 4 | # general for world. 
5 | seed=[2022, 2023, 2024], 6 | main_file=[ 7 | "run_exp.py", 8 | ], 9 | job_name=[ 10 | "imagenet_c_episodic_oracle_model_selection", 11 | # "imagenet_c_online_last_iterate", 12 | ], 13 | base_data_name=[ 14 | "imagenet", 15 | ], 16 | data_names=[ 17 | "imagenet_c_deterministic-gaussian_noise-5", 18 | "imagenet_c_deterministic-shot_noise-5", 19 | "imagenet_c_deterministic-impulse_noise-5", 20 | "imagenet_c_deterministic-defocus_blur-5", 21 | "imagenet_c_deterministic-glass_blur-5", 22 | "imagenet_c_deterministic-motion_blur-5", 23 | "imagenet_c_deterministic-zoom_blur-5", 24 | "imagenet_c_deterministic-snow-5", 25 | "imagenet_c_deterministic-frost-5", 26 | "imagenet_c_deterministic-fog-5", 27 | "imagenet_c_deterministic-brightness-5", 28 | "imagenet_c_deterministic-contrast-5", 29 | "imagenet_c_deterministic-elastic_transform-5", 30 | "imagenet_c_deterministic-pixelate-5", 31 | "imagenet_c_deterministic-jpeg_compression-5", 32 | ], 33 | model_name=[ 34 | "resnet50", 35 | ], 36 | model_adaptation_method=[ 37 | # "no_adaptation", 38 | "tent", 39 | # "bn_adapt", 40 | # "t3a", 41 | # "memo", 42 | # "shot", 43 | # "ttt", 44 | # "note", 45 | # "sar", 46 | # "conjugate_pl", 47 | # "cotta", 48 | # "eata", 49 | ], 50 | model_selection_method=[ 51 | "oracle_model_selection", 52 | # "last_iterate", 53 | ], 54 | offline_pre_adapt=[ 55 | "false", 56 | ], 57 | data_wise=["batch_wise"], 58 | batch_size=[64], 59 | episodic=[ 60 | # "false", 61 | "true", 62 | ], 63 | inter_domain=["HomogeneousNoMixture"], 64 | non_iid_ness=[0.1], 65 | non_iid_pattern=["class_wise_over_domain"], 66 | python_path=["/opt/conda/bin/python"], 67 | data_path=["/run/determined/workdir/data/"], 68 | ckpt_path=[ 69 | "./data/pretrained_ckpts/classification/resnet26_with_head/cifar10/rn26_bn.pth", # Since ttab will automatically download the pretrained model from torchvision or huggingface, what ckpt_path is here does not matter. 70 | ], 71 | # oracle_model_selection 72 | lr_grid=[ 73 | [1e-3], 74 | [5e-4], 75 | [1e-4], 76 | ], 77 | n_train_steps=[10], 78 | # last_iterate 79 | # lr=[ 80 | # 5e-3, 81 | # 1e-3, 82 | # 5e-4, 83 | # ], 84 | # n_train_steps=[ 85 | # 1, 86 | # 2, 87 | # 3, 88 | # ], 89 | intra_domain_shuffle=["true"], 90 | record_preadapted_perf=["true"], 91 | device=[ 92 | "cuda:0", 93 | "cuda:1", 94 | "cuda:2", 95 | "cuda:3", 96 | "cuda:4", 97 | "cuda:5", 98 | "cuda:6", 99 | "cuda:7", 100 | ], 101 | gradient_checkpoint=["false"], 102 | ) 103 | -------------------------------------------------------------------------------- /exps/officehome.py: -------------------------------------------------------------------------------- 1 | class NewConf(object): 2 | # create the list of hyper-parameters to be replaced. 3 | to_be_replaced = dict( 4 | # general for world. 
5 | seed=[2022, 2023, 2024], 6 | main_file=[ 7 | "run_exp.py", 8 | ], 9 | job_name=[ 10 | "officehome_episodic_oracle_model_selection", 11 | # "officehome_online_last_iterate", 12 | ], 13 | base_data_name=[ 14 | "officehome", 15 | ], 16 | data_names=[ 17 | "officehome_clipart", 18 | "officehome_product", 19 | "officehome_realworld", 20 | "officehome_art", 21 | "officehome_product", 22 | "officehome_realworld", 23 | "officehome_art", 24 | "officehome_clipart", 25 | "officehome_realworld", 26 | "officehome_art", 27 | "officehome_clipart", 28 | "officehome_product", 29 | ], 30 | model_name=[ 31 | "resnet50", 32 | ], 33 | model_adaptation_method=[ 34 | # "no_adaptation", 35 | "tent", 36 | # "bn_adapt", 37 | # "t3a", 38 | # "memo", 39 | # "shot", 40 | # "ttt", 41 | # "sar", 42 | # "conjugate_pl", 43 | # "note", 44 | # "cotta", 45 | # "eata", 46 | ], 47 | model_selection_method=[ 48 | "oracle_model_selection", 49 | # "last_iterate", 50 | ], 51 | offline_pre_adapt=[ 52 | "false", 53 | ], 54 | data_wise=["batch_wise"], 55 | batch_size=[64], 56 | episodic=[ 57 | # "false", 58 | "true", 59 | ], 60 | inter_domain=["HomogeneousNoMixture"], 61 | python_path=["/opt/conda/bin/python"], 62 | data_path=["/run/determined/workdir/data/"], 63 | ckpt_path=[ 64 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_art.pth", 65 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_art.pth", 66 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_art.pth", 67 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_clipart.pth", 68 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_clipart.pth", 69 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_clipart.pth", 70 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_product.pth", 71 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_product.pth", 72 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_product.pth", 73 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_realworld.pth", 74 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_realworld.pth", 75 | "./data/pretrained_ckpts/classification/resnet50_with_head/officehome/rn50_bn_realworld.pth", 76 | ], 77 | # oracle_model_selection 78 | lr_grid=[ 79 | [1e-3], 80 | [5e-4], 81 | [1e-4], 82 | ], 83 | n_train_steps=[25], 84 | # last_iterate 85 | # lr=[ 86 | # 5e-3, 87 | # 1e-3, 88 | # 5e-4, 89 | # ], 90 | # n_train_steps=[ 91 | # 1, 92 | # 2, 93 | # 3, 94 | # ], 95 | entry_of_shared_layers=["layer3"], 96 | intra_domain_shuffle=["true"], 97 | record_preadapted_perf=["true"], 98 | device=[ 99 | "cuda:0", 100 | ], 101 | grad_checkpoint=["false"], 102 | coupled=[ 103 | "data_names", 104 | "ckpt_path", 105 | ], 106 | ) 107 | -------------------------------------------------------------------------------- /exps/pacs.py: -------------------------------------------------------------------------------- 1 | class NewConf(object): 2 | # create the list of hyper-parameters to be replaced. 3 | to_be_replaced = dict( 4 | # general for world. 
5 | seed=[2022, 2023, 2024], 6 | main_file=[ 7 | "run_exp.py", 8 | ], 9 | job_name=[ 10 | # "pacs_online_last_iterate", 11 | "pacs_episodic_oracle_model_selection" 12 | ], 13 | base_data_name=[ 14 | "pacs", 15 | ], 16 | data_names=[ 17 | "pacs_cartoon", 18 | "pacs_photo", 19 | "pacs_sketch", 20 | "pacs_art", 21 | "pacs_photo", 22 | "pacs_sketch", 23 | "pacs_art", 24 | "pacs_cartoon", 25 | "pacs_sketch", 26 | "pacs_art", 27 | "pacs_cartoon", 28 | "pacs_photo", 29 | ], 30 | model_name=[ 31 | "resnet50", 32 | ], 33 | model_adaptation_method=[ 34 | # "no_adaptation", 35 | "tent", 36 | # "bn_adapt", 37 | # "t3a", 38 | # "memo", 39 | # "shot", 40 | # "ttt", 41 | # "sar", 42 | # "conjugate_pl", 43 | # "note", 44 | # "cotta", 45 | # "eata", 46 | ], 47 | model_selection_method=[ 48 | "oracle_model_selection", 49 | # "last_iterate", 50 | ], 51 | offline_pre_adapt=[ 52 | "false", 53 | ], 54 | data_wise=["batch_wise"], 55 | batch_size=[64], 56 | episodic=[ 57 | # "false", 58 | "true", 59 | ], 60 | inter_domain=["HomogeneousNoMixture"], 61 | python_path=["/opt/conda/bin/python"], 62 | data_path=["/run/determined/workdir/data/"], 63 | ckpt_path=[ 64 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_art.pth", 65 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_art.pth", 66 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_art.pth", 67 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_cartoon.pth", 68 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_cartoon.pth", 69 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_cartoon.pth", 70 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_photo.pth", 71 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_photo.pth", 72 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_photo.pth", 73 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_sketch.pth", 74 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_sketch.pth", 75 | "./data/pretrained_ckpts/classification/resnet50_with_head/pacs/rn50_bn_sketch.pth", 76 | ], 77 | # oracle_model_selection 78 | lr_grid=[ 79 | [1e-3], 80 | [5e-4], 81 | [1e-4], 82 | ], 83 | n_train_steps=[25], 84 | # last_iterate 85 | # lr=[ 86 | # 5e-3, 87 | # 1e-3, 88 | # 5e-4, 89 | # ], 90 | # n_train_steps=[ 91 | # 1, 92 | # 2, 93 | # 3, 94 | # ], 95 | entry_of_shared_layers=["layer3"], 96 | intra_domain_shuffle=["true"], 97 | record_preadapted_perf=["true"], 98 | device=[ 99 | "cuda:0", 100 | ], 101 | grad_checkpoint=["false"], 102 | coupled=[ 103 | "data_names", 104 | "ckpt_path", 105 | ], 106 | ) 107 | -------------------------------------------------------------------------------- /exps/yearbook.py: -------------------------------------------------------------------------------- 1 | class NewConf(object): 2 | # create the list of hyper-parameters to be replaced. 3 | to_be_replaced = dict( 4 | # general for world. 
5 | seed=[2022, 2023, 2024], 6 | main_file=[ 7 | "run_exp.py", 8 | ], 9 | job_name=[ 10 | "yearbook_episodic_oracle_model_selection", 11 | # "yearbook_online_last_iterate", 12 | ], 13 | base_data_name=[ 14 | "yearbook", 15 | ], 16 | data_names=[ 17 | "yearbook", 18 | ], 19 | model_name=[ 20 | "resnet18", 21 | ], 22 | model_adaptation_method=[ 23 | # "no_adaptation", 24 | "tent", 25 | # "bn_adapt", 26 | # "t3a", 27 | # "memo", 28 | # "shot", 29 | # "ttt", 30 | # "sar", 31 | # "conjugate_pl", 32 | # "note", 33 | # "cotta", 34 | # "eata", 35 | ], 36 | model_selection_method=[ 37 | "oracle_model_selection", 38 | # "last_iterate", 39 | ], 40 | offline_pre_adapt=[ 41 | "false", 42 | ], 43 | data_wise=["batch_wise"], 44 | batch_size=[64], 45 | episodic=[ 46 | # "false", 47 | "true", 48 | ], 49 | inter_domain=["HomogeneousNoMixture"], 50 | python_path=["/home/ttab/anaconda3/envs/test_algo/bin/python"], 51 | data_path=["./datasets"], 52 | ckpt_path=[ 53 | "./pretrain/checkpoint/resnet18_with_head/yearbook/resnet18_bn.pth", 54 | ], 55 | # oracle_model_selection 56 | lr_grid=[ 57 | [1e-3], 58 | [5e-4], 59 | [1e-4], 60 | ], 61 | n_train_steps=[50], 62 | # last_iterate 63 | # lr=[ 64 | # 5e-3, 65 | # 1e-3, 66 | # 5e-4, 67 | # ], 68 | # n_train_steps=[ 69 | # 1, 70 | # 2, 71 | # 3, 72 | # ], 73 | entry_of_shared_layers=["layer3"], 74 | intra_domain_shuffle=["false"], 75 | record_preadapted_perf=["true"], 76 | device=[ 77 | "cuda:0", 78 | ], 79 | grad_checkpoint=["false"], 80 | ) 81 | -------------------------------------------------------------------------------- /figs/overview of datasets.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/figs/overview of datasets.jpg -------------------------------------------------------------------------------- /monitor/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /monitor/tmux_cluster/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/monitor/tmux_cluster/__init__.py -------------------------------------------------------------------------------- /monitor/tmux_cluster/tmux.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from monitor.tmux_cluster.utils import ossystem 3 | import shlex 4 | 5 | TASKDIR_PREFIX = "/tmp/tasklogs" 6 | 7 | 8 | def exec_on_node(cmds, host="localhost"): 9 | def _decide_node(cmd): 10 | return cmd if host == "localhost" else f"ssh {host} -t {shlex.quote(cmd)}" 11 | 12 | cmds = ( 13 | [_decide_node(cmd) for cmd in cmds] 14 | if isinstance(cmds, list) 15 | else _decide_node(cmds) 16 | ) 17 | ossystem(cmds) 18 | 19 | 20 | class Run(object): 21 | def __init__(self, name, job_node="localhost"): 22 | self.name = name 23 | self.jobs = [] 24 | self.job_node = job_node 25 | 26 | def make_job(self, job_name, task_scripts, run=True, **kwargs): 27 | num_tasks = len(task_scripts) 28 | assert num_tasks > 0 29 | 30 | if kwargs: 31 | print("Warning: unused kwargs", kwargs) 32 | 33 | # Creating cmds 34 | cmds = [] 35 | session_name = self.name + "-" + job_name # tmux can't use . 
in name 36 | cmds.append(f"tmux kill-session -t {session_name}") 37 | 38 | windows = [] 39 | for task_id in range(num_tasks): 40 | if task_id == 0: 41 | cmds.append(f"tmux new-session -s {session_name} -n {task_id} -d") 42 | else: 43 | cmds.append(f"tmux new-window -t {session_name} -n {task_id}") 44 | windows.append(f"{session_name}:{task_id}") 45 | 46 | job = Job(self, job_name, windows, task_scripts, self.job_node) 47 | job.make_tasks() 48 | self.jobs.append(job) 49 | if run: 50 | for job in self.jobs: 51 | cmds += job.cmds 52 | exec_on_node(cmds, self.job_node) 53 | return job 54 | 55 | def attach_job(self): 56 | raise NotImplementedError 57 | 58 | def kill_jobs(self): 59 | cmds = [] 60 | for job in self.jobs: 61 | session_name = self.name + "-" + job.name 62 | cmds.append(f"tmux kill-session -t {session_name}") 63 | exec_on_node(cmds, self.job_node) 64 | 65 | 66 | class Job(object): 67 | def __init__(self, run, name, windows, task_scripts, job_node): 68 | self._run = run 69 | self.name = name 70 | self.job_node = job_node 71 | self.windows = windows 72 | self.task_scripts = task_scripts 73 | self.tasks = [] 74 | 75 | def make_tasks(self): 76 | for task_id, (window, script) in enumerate( 77 | zip(self.windows, self.task_scripts) 78 | ): 79 | self.tasks.append( 80 | Task( 81 | window, 82 | self, 83 | task_id, 84 | install_script=script, 85 | task_node=self.job_node, 86 | ) 87 | ) 88 | 89 | def attach_tasks(self): 90 | raise NotImplementedError 91 | 92 | @property 93 | def cmds(self): 94 | output = [] 95 | for task in self.tasks: 96 | output += task.cmds 97 | return output 98 | 99 | 100 | class Task(object): 101 | """Local tasks interact with tmux session. 102 | 103 | * session name is derived from job name, and window names are task ids. 104 | * no pane is used. 105 | 106 | """ 107 | 108 | def __init__(self, window, job, task_id, install_script, task_node): 109 | self.window = window 110 | self.job = job 111 | self.id = task_id 112 | self.install_script = install_script 113 | self.task_node = task_node 114 | 115 | # Path 116 | self.cmds = [] 117 | self._run_counter = 0 118 | 119 | for line in install_script.split("\n"): 120 | self.run(line) 121 | 122 | def run(self, cmd): 123 | self._run_counter += 1 124 | 125 | cmd = cmd.strip() 126 | if not cmd or cmd.startswith("#"): 127 | # ignore empty command lines 128 | # ignore commented out lines 129 | return 130 | 131 | modified_cmd = cmd 132 | self.cmds.append( 133 | f"tmux send-keys -t {self.window} {shlex.quote(modified_cmd)} Enter" 134 | ) 135 | 136 | def upload(self, source_fn, target_fn="."): 137 | raise NotImplementedError() 138 | 139 | def download(self, source_fn, target_fn="."): 140 | raise NotImplementedError() 141 | -------------------------------------------------------------------------------- /monitor/tmux_cluster/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import yaml 3 | import os 4 | import time 5 | from tqdm import tqdm 6 | 7 | 8 | def ossystem(cmds): 9 | if isinstance(cmds, str): 10 | print(f"\n=> {cmds}") 11 | os.system(cmds) 12 | elif isinstance(cmds, list): 13 | for cmd in tqdm(cmds): 14 | ossystem(cmd) 15 | else: 16 | raise NotImplementedError( 17 | "Cmds should be string or list of str. 
Got {}.".format(cmds) 18 | ) 19 | 20 | 21 | def environ(env): 22 | return os.getenv(env) 23 | 24 | 25 | def load_yaml(file): 26 | with open(file) as f: 27 | return yaml.safe_load(f) 28 | 29 | 30 | def wait_for_file(fn, max_wait_sec=600, check_interval=0.02): 31 | start_time = time.time() 32 | while True: 33 | if time.time() - start_time > max_wait_sec: 34 | assert False, "Timeout %s exceeded" % (max_wait_sec) 35 | if not os.path.exists(fn): 36 | time.sleep(check_interval) 37 | continue 38 | else: 39 | break 40 | 41 | 42 | if __name__ == "__main__": 43 | ossystem("ls") 44 | -------------------------------------------------------------------------------- /monitor/tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/monitor/tools/__init__.py -------------------------------------------------------------------------------- /monitor/tools/file_io.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import shutil 4 | import json 5 | import pickle 6 | 7 | 8 | def read_text_withoutsplit(path): 9 | """read text file from path.""" 10 | with open(path, "r") as f: 11 | return f.read() 12 | 13 | 14 | def read_txt(path): 15 | """read text file from path.""" 16 | with open(path, "r") as f: 17 | return f.read().splitlines() 18 | 19 | 20 | def write_txt(data, out_path, type="w"): 21 | """write the data to the txt file.""" 22 | with open(out_path, type) as f: 23 | f.write(data) 24 | 25 | 26 | def load_pickle(path): 27 | """load data by pickle.""" 28 | with open(path, "rb") as handle: 29 | return pickle.load(handle) 30 | 31 | 32 | def write_pickle(data, path): 33 | """dump file to dir.""" 34 | print("write --> data to path: {}\n".format(path)) 35 | with open(path, "wb") as handle: 36 | pickle.dump(data, handle) 37 | 38 | 39 | """json related.""" 40 | 41 | 42 | def read_json(path): 43 | """read json file from path.""" 44 | with open(path, "r") as f: 45 | return json.load(f) 46 | 47 | 48 | def is_jsonable(x): 49 | try: 50 | json.dumps(x) 51 | return True 52 | except: 53 | return False 54 | 55 | 56 | """operate dir.""" 57 | 58 | 59 | def build_dir(path, force): 60 | """build directory.""" 61 | if os.path.exists(path) and force: 62 | shutil.rmtree(path) 63 | os.mkdir(path) 64 | elif not os.path.exists(path): 65 | os.mkdir(path) 66 | return path 67 | 68 | 69 | def build_dirs(path): 70 | try: 71 | os.makedirs(path) 72 | except Exception as e: 73 | print(" encounter error: {}".format(e)) 74 | 75 | 76 | def remove_folder(path): 77 | try: 78 | shutil.rmtree(path) 79 | except Exception as e: 80 | print(" encounter error: {}".format(e)) 81 | 82 | 83 | def list_files(root_path): 84 | dirs = os.listdir(root_path) 85 | return [os.path.join(root_path, path) for path in dirs] 86 | -------------------------------------------------------------------------------- /monitor/tools/plot.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | 4 | from monitor.tools.show_results import reorder_records 5 | from monitor.tools.plot_utils import ( 6 | determine_color_and_lines, 7 | plot_one_case, 8 | smoothing_func, 9 | configure_figure, 10 | build_legend, 11 | ) 12 | 13 | 14 | """plot the curve in terms of time.""" 15 | 16 | 17 | def plot_curve_wrt_time( 18 | ax, 19 | records, 20 | x_wrt_sth, 21 | y_wrt_sth, 22 | xlabel, 23 | ylabel, 24 | 
24 |     title=None,
25 |     markevery_list=None,
26 |     is_smooth=True,
27 |     smooth_space=100,
28 |     l_subset=0.0,
29 |     r_subset=1.0,
30 |     reorder_record_item=None,
31 |     remove_duplicate=True,
32 |     n_by_line=1,
33 |     n_by_color=1,
34 |     font_shift=0,
35 |     has_legend=True,
36 |     legend=None,
37 |     legend_loc="lower right",
38 |     legend_ncol=2,
39 |     bbox_to_anchor=[0, 0],
40 |     ylimit_bottom=None,
41 |     ylimit_top=None,
42 |     use_log=False,
43 | ):
44 |     """Each info consists of
45 |     ['tr_loss', 'tr_top1', 'tr_time', 'te_top1', 'te_step', 'te_time'].
46 |     """
47 |     # parse a list of records.
48 |     distinct_conf_set = set()
49 | 
50 |     # re-order the records.
51 |     if reorder_record_item is not None:
52 |         records = reorder_records(records, reorder_on=reorder_record_item)
53 | 
54 |     count = 0
55 |     for ind, (args, info) in enumerate(records):
56 |         y_wrt_sth_list = y_wrt_sth.split(",")
57 |         # check.
58 |         if len(y_wrt_sth_list) > 1:
59 |             assert y_wrt_sth_list[-1].endswith("preadapt_accuracy"), "the second item of y_wrt_sth must end with preadapt_accuracy."
60 |             assert len(y_wrt_sth_list) == 2, "only support two arguments in y_wrt_sth."
61 |             assert len(records) <= 4, "only support # records <= 4 when having two arguments in y_wrt_sth"  # otherwise, the color and line_style will be a mess.
62 | 
63 |         for i in range(len(y_wrt_sth_list)):
64 |             y_wrt_sth_i = y_wrt_sth_list[i]
65 |             # build legend.
66 |             _legend = build_legend(args, legend)
67 |             if len(y_wrt_sth_list) > 1 and i == 0:
68 |                 _legend = ", ".join([_legend, "after_adapt"])
69 |             elif i == 1:
70 |                 _legend = ", ".join([_legend, "before_adapt"])
71 | 
72 |             if _legend in distinct_conf_set and remove_duplicate:
73 |                 continue
74 |             else:
75 |                 distinct_conf_set.add(_legend)
76 | 
77 |             # determine the style of line, color and marker.
78 |             line_style, color_style, mark_style = determine_color_and_lines(
79 |                 n_by_line, n_by_color, index=count
80 |             )
81 | 
82 |             if len(y_wrt_sth_list) > 1 and i == 0:
83 |                 line_style = "-"
84 |             elif i == 1:
85 |                 line_style = "--"
86 | 
87 |             if markevery_list is not None:
88 |                 mark_every = markevery_list[ind]
89 |             else:
90 |                 mark_style = None
91 |                 mark_every = None
92 | 
93 |             # determine if we want to smooth the curve.
94 |             if "train-step" in x_wrt_sth or "train-epoch" in x_wrt_sth:
95 |                 info["train-step"] = list(range(1, 1 + len(info["train-loss"])))
96 |             if "train-epoch" == x_wrt_sth:
97 |                 x = info["train-step"]
98 |                 x = [1.0 * _x / args["num_batches_train_per_device_per_epoch"] for _x in x]
99 |             else:
100 |                 x = info[x_wrt_sth]
101 |                 if "time" in x_wrt_sth:
102 |                     x = [(time - x[0]).seconds + 1 for time in x]
103 |             y = info[y_wrt_sth_i]
104 | 
105 |             if is_smooth:
106 |                 x, y = smoothing_func(x, y, smooth_space)
107 | 
108 |             # only plot subset.
109 |             _l_subset, _r_subset = int(len(x) * l_subset), int(len(x) * r_subset)
110 |             _x = x[_l_subset:_r_subset]
111 |             _y = y[_l_subset:_r_subset]
112 | 
113 |             # use log scale for y
114 |             if use_log:
115 |                 _y = np.log10(_y)
116 | 
117 |             # plot
118 |             ax = plot_one_case(
119 |                 ax,
120 |                 x=_x,
121 |                 y=_y,
122 |                 label=_legend,
123 |                 line_style=line_style,
124 |                 color_style=color_style,
125 |                 mark_style=mark_style,
126 |                 mark_every=mark_every,
127 |                 remove_duplicate=remove_duplicate,
128 |             )
129 |             count += 1
130 | 
131 |     ax.set_ylim(bottom=ylimit_bottom, top=ylimit_top)
132 |     ax = configure_figure(
133 |         ax,
134 |         xlabel=xlabel,
135 |         ylabel=ylabel,
136 |         title=title,
137 |         has_legend=has_legend,
138 |         legend_loc=legend_loc,
139 |         legend_ncol=legend_ncol,
140 |         bbox_to_anchor=bbox_to_anchor,
141 |         font_shift=font_shift,
142 |     )
143 |     return ax
144 | 
--------------------------------------------------------------------------------
/monitor/tools/plot_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from itertools import groupby
3 | 
4 | import numpy as np
5 | from matplotlib.lines import Line2D
6 | import seaborn as sns
7 | 
8 | """operate x and y."""
9 | 
10 | 
11 | def smoothing_func(x, y, smooth_length=10):
12 |     def smoothing(end_index):
13 |         # print(end_index)
14 |         if end_index - smooth_length < 0:
15 |             start_index = 0
16 |         else:
17 |             start_index = end_index - smooth_length
18 | 
19 |         data = y[start_index:end_index]
20 |         if len(data) == 0:
21 |             return y[start_index]
22 |         else:
23 |             return 1.0 * sum(data) / len(data)
24 | 
25 |     # smooth curve
26 |     x_, y_ = [], []
27 | 
28 |     for end_ind in range(0, len(x)):
29 |         x_.append(x[end_ind])
30 |         y_.append(smoothing(end_ind))
31 |     return x_, y_
32 | 
33 | 
34 | def reject_outliers(data, threshold=3):
35 |     return data[abs(data - np.mean(data)) < threshold * np.std(data)]
36 | 
37 | 
38 | def groupby_indices(results, grouper):
39 |     """group by indices and select the subset parameters"""
40 |     out = []
41 |     for key, group in groupby(sorted(results, key=grouper), grouper):
42 |         group_item = list(group)
43 |         out += [(key, group_item)]
44 |     return out
45 | 
46 | 
47 | def find_same_num_sync(num_update_steps_and_local_step):
48 |     list_of_num_sync = [
49 |         num_update_steps // local_step
50 |         for num_update_steps, local_step in num_update_steps_and_local_step
51 |     ]
52 |     return min(list_of_num_sync)
53 | 
54 | 
55 | def sample_from_records(x, y, local_step, max_same_num_sync):
56 |     # cut the records.
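    # Illustrative example (hypothetical values): with x = y = list(range(10)),
    # local_step = 2 and max_same_num_sync = 3, the records are first cut to
    # the first 2 * 3 = 6 entries and then thinned to every 2nd one, so the
    # function returns ([0, 2, 4], [0, 2, 4]).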
57 | if max_same_num_sync is not None: 58 | x = x[: local_step * max_same_num_sync] 59 | y = y[: local_step * max_same_num_sync] 60 | return x[::local_step], y[::local_step] 61 | 62 | 63 | def drop_first_few(x, y, num_drop): 64 | return x[num_drop:], y[num_drop:] 65 | 66 | 67 | def rebuild_runtime_record(times): 68 | times = [(time - times[0]).seconds + 1 for time in times] 69 | return times 70 | 71 | 72 | def add_communication_delay(times, local_step, delay_factor): 73 | """add communication delay to original time.""" 74 | return [ 75 | time + delay_factor * ((ind + 1) // local_step) 76 | for ind, time in enumerate(times) 77 | ] 78 | 79 | 80 | """plot style related.""" 81 | 82 | 83 | def determine_color_and_lines(n_by_line, n_by_color, index): 84 | line_styles = ["-", "--", "-.", ":"] 85 | color_styles = ["b", "g", "r", "k"] 86 | 87 | # safety check 88 | total_num_combs = len(line_styles) * len(color_styles) 89 | assert index + 1 <= total_num_combs 90 | 91 | if n_by_line >= n_by_color: 92 | line = index // n_by_line 93 | color = (index % n_by_line) // n_by_color 94 | else: 95 | color = index // n_by_color 96 | line = (index % n_by_color) // n_by_line 97 | return line_styles[line], color_styles[color], Line2D.filled_markers[index] 98 | 99 | 100 | def configure_figure( 101 | ax, 102 | xlabel, 103 | ylabel, 104 | title=None, 105 | has_legend=True, 106 | legend_loc="lower right", 107 | legend_ncol=2, 108 | bbox_to_anchor=[0, 0], 109 | font_shift=0, 110 | ): 111 | if has_legend: 112 | ax.legend( 113 | loc=legend_loc, 114 | bbox_to_anchor=bbox_to_anchor, 115 | ncol=legend_ncol, 116 | shadow=True, 117 | fancybox=True, 118 | fontsize=20 + font_shift, 119 | ) 120 | 121 | ax.set_xlabel(xlabel, fontsize=24 + font_shift, labelpad=18 + font_shift) 122 | ax.set_ylabel(ylabel, fontsize=24 + font_shift, labelpad=18 + font_shift) 123 | 124 | if title is not None: 125 | ax.set_title(title, fontsize=24 + font_shift) 126 | ax.xaxis.set_tick_params(labelsize=22 + font_shift) 127 | ax.yaxis.set_tick_params(labelsize=22 + font_shift) 128 | return ax 129 | 130 | 131 | def plot_one_case( 132 | ax, 133 | label, 134 | line_style, 135 | color_style, 136 | mark_style, 137 | line_width=2.0, 138 | mark_every=5000, 139 | x=None, 140 | y=None, 141 | sns_plot=None, 142 | remove_duplicate=False, 143 | ): 144 | if sns_plot is not None and not remove_duplicate: 145 | ax = sns.lineplot( 146 | x="x", 147 | y="y", 148 | data=sns_plot, 149 | label=label, 150 | linewidth=line_width, 151 | linestyle=line_style, 152 | color=color_style, 153 | marker=mark_style, 154 | markevery=mark_every, 155 | markersize=16, 156 | ax=ax, 157 | ) 158 | elif sns_plot is not None and remove_duplicate: 159 | ax = sns.lineplot( 160 | x="x", 161 | y="y", 162 | data=sns_plot, 163 | label=label, 164 | linewidth=line_width, 165 | linestyle=line_style, 166 | color=color_style, 167 | marker=mark_style, 168 | markevery=mark_every, 169 | markersize=16, 170 | ax=ax, 171 | estimator=None, 172 | ) 173 | else: 174 | ax.plot( 175 | x, 176 | y, 177 | label=label, 178 | linewidth=line_width, 179 | linestyle=line_style, 180 | color=color_style, 181 | marker=mark_style, 182 | markevery=mark_every, 183 | markersize=16, 184 | ) 185 | return ax 186 | 187 | 188 | def build_legend(args, legend): 189 | legend = legend.split(",") 190 | 191 | my_legend = [] 192 | for _legend in legend: 193 | _legend_content = args[_legend] if _legend in args else -1 194 | my_legend += [ 195 | "{}={}".format( 196 | _legend, 197 | list(_legend_content)[0] 198 | if "pandas" in 
str(type(_legend_content)) 199 | else _legend_content, 200 | ) 201 | ] 202 | return ", ".join(my_legend) 203 | -------------------------------------------------------------------------------- /monitor/tools/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from datetime import datetime 3 | 4 | 5 | def str2time(string, pattern): 6 | """convert the string to the datetime.""" 7 | return datetime.strptime(string, pattern) 8 | 9 | 10 | def str2bool(v): 11 | if v.lower() in ("yes", "true", "t", "y", "1"): 12 | return True 13 | elif v.lower() in ("no", "false", "f", "n", "0"): 14 | return False 15 | else: 16 | raise ValueError("Boolean value expected.") 17 | 18 | 19 | def is_float(value): 20 | try: 21 | float(value) 22 | return True 23 | except: 24 | return False 25 | 26 | 27 | def dict_parser(values): 28 | local_dict = {} 29 | if values is None: 30 | return local_dict 31 | for kv in values.split(",,"): 32 | k, v = kv.split("=") 33 | try: 34 | local_dict[k] = float(v) 35 | except ValueError: 36 | try: 37 | local_dict[k] = str2bool(v) 38 | except ValueError: 39 | local_dict[k] = v 40 | return local_dict 41 | -------------------------------------------------------------------------------- /notebooks/example.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "metadata": {}, 7 | "outputs": [], 8 | "source": [ 9 | "# Useful starting lines \n", 10 | "%matplotlib inline\n", 11 | "import numpy as np\n", 12 | "import matplotlib.pyplot as plt\n", 13 | "from matplotlib.lines import Line2D\n", 14 | "%load_ext autoreload\n", 15 | "%autoreload 2" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": null, 21 | "metadata": {}, 22 | "outputs": [], 23 | "source": [ 24 | "import os \n", 25 | "import sys\n", 26 | "import time\n", 27 | "import copy\n", 28 | "from copy import deepcopy\n", 29 | "import pickle\n", 30 | "import math\n", 31 | "import functools \n", 32 | "from IPython.display import display, HTML\n", 33 | "import operator\n", 34 | "from operator import itemgetter\n", 35 | "\n", 36 | "import pandas as pd\n", 37 | "import seaborn as sns\n", 38 | "from matplotlib.lines import Line2D\n", 39 | "\n", 40 | "sns.set(style=\"darkgrid\")\n", 41 | "sns.set_context(\"paper\")" 42 | ] 43 | }, 44 | { 45 | "cell_type": "code", 46 | "execution_count": null, 47 | "metadata": {}, 48 | "outputs": [], 49 | "source": [ 50 | "root_path = '/home/ttab/sp-ttabed/codes/code'\n", 51 | "sys.path.append(root_path)" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "metadata": {}, 58 | "outputs": [], 59 | "source": [ 60 | "from monitor.tools.show_results import extract_list_of_records, reorder_records, get_pickle_info, summarize_info\n", 61 | "from monitor.tools.plot import plot_curve_wrt_time\n", 62 | "import monitor.tools.plot_utils as plot_utils\n", 63 | "\n", 64 | "from monitor.tools.utils import dict_parser\n", 65 | "from monitor.tools.file_io import load_pickle" 66 | ] 67 | }, 68 | { 69 | "cell_type": "code", 70 | "execution_count": null, 71 | "metadata": {}, 72 | "outputs": [], 73 | "source": [ 74 | "root_data_path = os.path.join(root_path, 'data', 'logs', 'resnet26')\n", 75 | "experiments = ['cifar10_label_shift_episodic_oracle_model_selection']\n", 76 | "raw_records = get_pickle_info(root_data_path, experiments)" 77 | ] 78 | }, 79 | { 80 | "cell_type": "code", 81 | "execution_count": 
null, 82 | "metadata": {}, 83 | "outputs": [], 84 | "source": [ 85 | "# Have a glimpse of experimental results.\n", 86 | "conditions = {\n", 87 | " \"model_adaptation_method\": [\"tent\"],\n", 88 | " \"seed\": [2022],\n", 89 | " \"batch_size\": [64],\n", 90 | " \"episodic\": [False],\n", 91 | " \"n_train_steps\": [1],\n", 92 | " # \"lr\": [0.005],\n", 93 | " # \"data_names\": [\"cifar10_c_deterministic-gaussian_noise-5\"],\n", 94 | "}\n", 95 | "attributes = ['model_adaptation_method', 'step_ratio', 'label_shift_param', 'ckpt_path', 'episodic', 'model_selection_method', 'seed', 'data_names', 'status']\n", 96 | "records = extract_list_of_records(list_of_records=raw_records, conditions=conditions)\n", 97 | "aggregated_results, averaged_records_overall = summarize_info(records, attributes, reorder_on='model_adaptation_method', groupby_on='test-overall-accuracy', larger_is_better=True)\n", 98 | "display(HTML(averaged_records_overall.to_html()))\n", 99 | "\n", 100 | "# display test accuracy per test step.\n", 101 | "aggregated_results, averaged_records_step_nonepisodic_optimal = summarize_info(records, attributes, reorder_on='model_adaptation_method', groupby_on='test-step-accuracy', larger_is_better=True)\n", 102 | "\n", 103 | "fig = plt.figure(num=1, figsize=(18, 9))\n", 104 | "ax1 = fig.add_subplot(111)\n", 105 | "plot_curve_wrt_time(\n", 106 | " ax1, records,\n", 107 | " x_wrt_sth='test-step-step', y_wrt_sth='test-step-accuracy', is_smooth=True,\n", 108 | " xlabel='batch index', ylabel='Test accuracy', l_subset=0.0, r_subset=1, markevery_list=None,\n", 109 | " n_by_line=4, has_legend=True, legend='model_selection_method,step_ratio', legend_loc='lower right', legend_ncol=1, bbox_to_anchor=[1, 0],\n", 110 | " ylimit_bottom=0, ylimit_top=100, use_log=False)\n", 111 | "fig.tight_layout()\n", 112 | "plt.show()" 113 | ] 114 | } 115 | ], 116 | "metadata": { 117 | "kernelspec": { 118 | "display_name": "test_algo", 119 | "language": "python", 120 | "name": "python3" 121 | }, 122 | "language_info": { 123 | "codemirror_mode": { 124 | "name": "ipython", 125 | "version": 3 126 | }, 127 | "file_extension": ".py", 128 | "mimetype": "text/x-python", 129 | "name": "python", 130 | "nbconvert_exporter": "python", 131 | "pygments_lexer": "ipython3", 132 | "version": "3.7.11" 133 | }, 134 | "orig_nbformat": 4, 135 | "vscode": { 136 | "interpreter": { 137 | "hash": "dc7a203c487a4c1b41bd3d170020b3757b8af76b16b2c4bd8127396815ac049f" 138 | } 139 | } 140 | }, 141 | "nbformat": 4, 142 | "nbformat_minor": 2 143 | } 144 | -------------------------------------------------------------------------------- /parameters.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import argparse 3 | 4 | 5 | def get_args(): 6 | parser = argparse.ArgumentParser() 7 | # define meta info (this part can be ignored if `monitor` package is unused). 
8 |     parser.add_argument("--job_name", default="tta", type=str)
9 |     parser.add_argument("--job_id", default=None, type=str)
10 |     parser.add_argument("--timestamp", default=None, type=str)
11 |     parser.add_argument("--python_path", default="python", type=str)
12 |     parser.add_argument("--main_file", default="run_exp.py", type=str)
13 | 
14 |     parser.add_argument("--script_path", default=None, type=str)
15 |     parser.add_argument("--script_class_name", default=None, type=str)
16 |     parser.add_argument("--num_jobs_per_node", default=2, type=int)
17 |     parser.add_argument("--num_jobs_per_script", default=1, type=int)
18 |     parser.add_argument("--wait_in_seconds_per_job", default=15, type=float)
19 | 
20 |     # define test evaluation info.
21 |     parser.add_argument("--root_path", default="./logs", type=str)
22 |     parser.add_argument("--data_path", default="./datasets", type=str)
23 |     parser.add_argument(
24 |         "--ckpt_path",
25 |         default="./pretrained_ckpts/classification/resnet26_with_head/cifar10/rn26_bn.pth",
26 |         type=str,
27 |     )
28 |     parser.add_argument("--seed", default=2022, type=int)
29 |     parser.add_argument("--device", default="cuda:0", type=str)
30 |     parser.add_argument("--num_cpus", default=2, type=int)
31 | 
32 |     # define the task & model & adaptation & selection method.
33 |     parser.add_argument("--model_name", default="resnet26", type=str)
34 |     parser.add_argument("--group_norm_num_groups", default=None, type=int)
35 |     parser.add_argument(
36 |         "--model_adaptation_method",
37 |         default="tent",
38 |         choices=[
39 |             "no_adaptation",
40 |             "tent",
41 |             "bn_adapt",
42 |             "memo",
43 |             "shot",
44 |             "t3a",
45 |             "ttt",
46 |             "note",
47 |             "sar",
48 |             "conjugate_pl",
49 |             "cotta",
50 |             "eata",
51 |         ],
52 |         type=str,
53 |     )
54 |     parser.add_argument(
55 |         "--model_selection_method",
56 |         default="last_iterate",
57 |         choices=["last_iterate", "oracle_model_selection"],
58 |         type=str,
59 |     )
60 |     parser.add_argument("--task", default="classification", type=str)
61 | 
62 |     # define the test scenario.
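    # NOTE: corrupted test sets appear to follow the naming pattern
    # "<base>_c_<shift type>-<corruption>-<severity level>", e.g. the default
    # "cifar10_c_deterministic-gaussian_noise-5" (cf. the sweeps in exps/).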
63 |     parser.add_argument("--test_scenario", default=None, type=str)
64 |     parser.add_argument(
65 |         "--base_data_name",
66 |         default="cifar10",
67 |         choices=[
68 |             "cifar10",
69 |             "cifar100",
70 |             "imagenet",
71 |             "officehome",
72 |             "pacs",
73 |             "coloredmnist",
74 |             "waterbirds",
75 |             "yearbook",
76 |         ],
77 |         type=str,
78 |     )
79 |     parser.add_argument("--src_data_name", default="cifar10", type=str)
80 |     parser.add_argument(
81 |         "--data_names", default="cifar10_c_deterministic-gaussian_noise-5", type=str
82 |     )
83 |     parser.add_argument(
84 |         "--data_wise",
85 |         default="batch_wise",
86 |         choices=["batch_wise", "sample_wise"],
87 |         type=str,
88 |     )
89 |     parser.add_argument("--batch_size", default=64, type=int)
90 |     parser.add_argument("--lr", default=1e-3, type=float)
91 |     parser.add_argument("--n_train_steps", default=1, type=int)
92 |     parser.add_argument("--offline_pre_adapt", default=False, type=str2bool)
93 |     parser.add_argument("--episodic", default=False, type=str2bool)
94 |     parser.add_argument("--intra_domain_shuffle", default=True, type=str2bool)
95 |     parser.add_argument(
96 |         "--inter_domain",
97 |         default="HomogeneousNoMixture",
98 |         choices=[
99 |             "HomogeneousNoMixture",
100 |             "HeterogeneousNoMixture",
101 |             "InOutMixture",
102 |             "CrossMixture",
103 |         ],
104 |         type=str,
105 |     )
106 |     # Test domain
107 |     parser.add_argument("--domain_sampling_name", default="uniform", type=str)
108 |     parser.add_argument("--domain_sampling_ratio", default=1.0, type=float)
109 |     # HeterogeneousNoMixture
110 |     parser.add_argument("--non_iid_pattern", default="class_wise_over_domain", type=str)
111 |     parser.add_argument("--non_iid_ness", default=0.1, type=float)
112 |     # for evaluation.
113 |     # label shift
114 |     parser.add_argument(
115 |         "--label_shift_param",
116 |         help="parameter to control the severity of label shift",
117 |         default=None,
118 |         type=float,
119 |     )
120 |     parser.add_argument(
121 |         "--data_size",
122 |         help="parameter to control the size of dataset",
123 |         default=None,
124 |         type=int,
125 |     )
126 |     # optimal model selection
127 |     parser.add_argument(
128 |         "--step_ratios",
129 |         nargs="+",
130 |         default=[0.1, 0.3, 0.5, 0.75],
131 |         help="ratios used to control adaptation step length",
132 |         type=float,
133 |     )
134 |     parser.add_argument("--step_ratio", default=None, type=float)
135 |     # time-varying
136 |     parser.add_argument("--stochastic_restore_model", default=False, type=str2bool)
137 |     parser.add_argument("--restore_prob", default=0.01, type=float)
138 |     parser.add_argument("--fishers", default=False, type=str2bool)
139 |     parser.add_argument(
140 |         "--fisher_size",
141 |         default=5000,
142 |         type=int,
143 |         help="number of samples to compute fisher information matrix.",
144 |     )
145 |     parser.add_argument(
146 |         "--fisher_alpha",
147 |         type=float,
148 |         default=1.5,
149 |         help="the trade-off between entropy and regularization loss",
150 |     )
151 |     # method-wise hparams
152 |     parser.add_argument(
153 |         "--aug_size",
154 |         default=32,
155 |         help="number of per-image augmentation operations in memo and ttt",
156 |         type=int,
157 |     )
158 |     parser.add_argument(
159 |         "--entry_of_shared_layers",
160 |         default=None,
161 |         help="the split position of auxiliary head. Only used in TTT.",
162 |     )
163 |     # metrics
164 |     parser.add_argument(
165 |         "--record_preadapted_perf",
166 |         default=False,
167 |         help="record performance on the local batch prior to implementing test-time adaptation.",
168 |         type=str2bool,
169 |     )
170 |     # misc
171 |     parser.add_argument(
172 |         "--grad_checkpoint",
173 |         default=False,
174 |         help="Trade computation for gpu space.",
175 |         type=str2bool,
176 |     )
177 |     parser.add_argument("--debug", default=False, help="Display logs.", type=str2bool)
178 | 
179 |     # parse conf.
180 |     conf = parser.parse_args()
181 |     return conf
182 | 
183 | 
184 | def str2bool(v):
185 |     if v.lower() in ("yes", "true", "t", "y", "1"):
186 |         return True
187 |     elif v.lower() in ("no", "false", "f", "n", "0"):
188 |         return False
189 |     else:
190 |         raise ValueError("Boolean value expected.")
191 | 
192 | 
193 | if __name__ == "__main__":
194 |     args = get_args()
195 | 
--------------------------------------------------------------------------------
/pretrain/third_party/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/pretrain/third_party/__init__.py
--------------------------------------------------------------------------------
/pretrain/third_party/augmentations.py:
--------------------------------------------------------------------------------
1 | # https://github.com/google-research/augmix/blob/master/augmentations.py
2 | """Base augmentations operators."""
3 | 
4 | import numpy as np
5 | from PIL import Image, ImageOps, ImageEnhance
6 | 
7 | # ImageNet code should change this value
8 | IMAGE_SIZE = 32
9 | 
10 | 
11 | def int_parameter(level, maxval):
12 |     """Helper function to scale `val` between 0 and maxval.
13 | 
14 |     Args:
15 |         level: Level of the operation that will be between [0, `PARAMETER_MAX`].
16 |         maxval: Maximum value that the operation can have. This will be scaled to
17 |             level/PARAMETER_MAX.
18 | 
19 |     Returns:
20 |         An int that results from scaling `maxval` according to `level`.
21 |     """
22 |     return int(level * maxval / 10)
23 | 
24 | 
25 | def float_parameter(level, maxval):
26 |     """Helper function to scale `val` between 0 and maxval.
27 | 
28 |     Args:
29 |         level: Level of the operation that will be between [0, `PARAMETER_MAX`].
30 |         maxval: Maximum value that the operation can have. This will be scaled to
31 |             level/PARAMETER_MAX.
32 | 
33 |     Returns:
34 |         A float that results from scaling `maxval` according to `level`.
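
    For example (an illustrative value, using the PARAMETER_MAX of 10 that
    is hard-coded in the return statement): float_parameter(5, 0.3) returns
    5 * 0.3 / 10.0 = 0.15.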
35 | """ 36 | return float(level) * maxval / 10.0 37 | 38 | 39 | def sample_level(n): 40 | return np.random.uniform(low=0.1, high=n) 41 | 42 | 43 | def autocontrast(pil_img, _): 44 | return ImageOps.autocontrast(pil_img) 45 | 46 | 47 | def equalize(pil_img, _): 48 | return ImageOps.equalize(pil_img) 49 | 50 | 51 | def posterize(pil_img, level): 52 | level = int_parameter(sample_level(level), 4) 53 | return ImageOps.posterize(pil_img, 4 - level) 54 | 55 | 56 | def rotate(pil_img, level): 57 | degrees = int_parameter(sample_level(level), 30) 58 | if np.random.uniform() > 0.5: 59 | degrees = -degrees 60 | return pil_img.rotate(degrees, resample=Image.BILINEAR) 61 | 62 | 63 | def solarize(pil_img, level): 64 | level = int_parameter(sample_level(level), 256) 65 | return ImageOps.solarize(pil_img, 256 - level) 66 | 67 | 68 | def shear_x(pil_img, level): 69 | level = float_parameter(sample_level(level), 0.3) 70 | if np.random.uniform() > 0.5: 71 | level = -level 72 | return pil_img.transform( 73 | (IMAGE_SIZE, IMAGE_SIZE), 74 | Image.AFFINE, 75 | (1, level, 0, 0, 1, 0), 76 | resample=Image.BILINEAR, 77 | ) 78 | 79 | 80 | def shear_y(pil_img, level): 81 | level = float_parameter(sample_level(level), 0.3) 82 | if np.random.uniform() > 0.5: 83 | level = -level 84 | return pil_img.transform( 85 | (IMAGE_SIZE, IMAGE_SIZE), 86 | Image.AFFINE, 87 | (1, 0, 0, level, 1, 0), 88 | resample=Image.BILINEAR, 89 | ) 90 | 91 | 92 | def translate_x(pil_img, level): 93 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3) 94 | if np.random.random() > 0.5: 95 | level = -level 96 | return pil_img.transform( 97 | (IMAGE_SIZE, IMAGE_SIZE), 98 | Image.AFFINE, 99 | (1, 0, level, 0, 1, 0), 100 | resample=Image.BILINEAR, 101 | ) 102 | 103 | 104 | def translate_y(pil_img, level): 105 | level = int_parameter(sample_level(level), IMAGE_SIZE / 3) 106 | if np.random.random() > 0.5: 107 | level = -level 108 | return pil_img.transform( 109 | (IMAGE_SIZE, IMAGE_SIZE), 110 | Image.AFFINE, 111 | (1, 0, 0, 0, 1, level), 112 | resample=Image.BILINEAR, 113 | ) 114 | 115 | 116 | # operation that overlaps with ImageNet-C's test set 117 | def color(pil_img, level): 118 | level = float_parameter(sample_level(level), 1.8) + 0.1 119 | return ImageEnhance.Color(pil_img).enhance(level) 120 | 121 | 122 | # operation that overlaps with ImageNet-C's test set 123 | def contrast(pil_img, level): 124 | level = float_parameter(sample_level(level), 1.8) + 0.1 125 | return ImageEnhance.Contrast(pil_img).enhance(level) 126 | 127 | 128 | # operation that overlaps with ImageNet-C's test set 129 | def brightness(pil_img, level): 130 | level = float_parameter(sample_level(level), 1.8) + 0.1 131 | return ImageEnhance.Brightness(pil_img).enhance(level) 132 | 133 | 134 | # operation that overlaps with ImageNet-C's test set 135 | def sharpness(pil_img, level): 136 | level = float_parameter(sample_level(level), 1.8) + 0.1 137 | return ImageEnhance.Sharpness(pil_img).enhance(level) 138 | 139 | 140 | augmentations = [ 141 | autocontrast, 142 | equalize, 143 | posterize, 144 | rotate, 145 | solarize, 146 | shear_x, 147 | shear_y, 148 | translate_x, 149 | translate_y, 150 | ] 151 | 152 | augmentations_all = [ 153 | autocontrast, 154 | equalize, 155 | posterize, 156 | rotate, 157 | solarize, 158 | shear_x, 159 | shear_y, 160 | translate_x, 161 | translate_y, 162 | color, 163 | contrast, 164 | brightness, 165 | sharpness, 166 | ] 167 | -------------------------------------------------------------------------------- /pretrain/third_party/datasets.py: 
-------------------------------------------------------------------------------- 1 | # A collection of datasets class used in pretraining models. 2 | 3 | # imports. 4 | import os 5 | import functools 6 | import torch 7 | import sys 8 | sys.path.append("..") 9 | 10 | from ttab.loads.datasets.dataset_shifts import NoShiftedData, SyntheticShiftedData 11 | from ttab.loads.datasets.mnist import ColoredSyntheticShift 12 | from ttab.loads.datasets.loaders import BaseLoader 13 | from ttab.loads.datasets.datasets import OfficeHomeDataset, PACSDataset, CIFARDataset, WBirdsDataset, ColoredMNIST 14 | 15 | def get_train_dataset(config) -> BaseLoader: 16 | """Get the training dataset from `config`.""" 17 | data_shift_class = functools.partial(NoShiftedData, data_name=config.data_name) 18 | if "cifar" in config.data_name: 19 | train_dataset = CIFARDataset( 20 | root=os.path.join(config.data_path, config.data_name), 21 | data_name=config.data_name, 22 | split="train", 23 | device=config.device, 24 | data_augment=True, 25 | data_shift_class=data_shift_class, 26 | ) 27 | val_dataset = CIFARDataset( 28 | root=os.path.join(config.data_path, config.data_name), 29 | data_name=config.data_name, 30 | split="test", 31 | device=config.device, 32 | data_augment=False, 33 | data_shift_class=data_shift_class, 34 | ) 35 | elif "officehome" in config.data_name: 36 | _data_names = config.data_name.split("_") 37 | dataset = OfficeHomeDataset( 38 | root=os.path.join(config.data_path, _data_names[0], _data_names[1]), 39 | device=config.device, 40 | data_augment=True, 41 | data_shift_class=data_shift_class, 42 | ).split_data(fractions=[0.9, 0.1], augment=[True, False], seed=config.seed) 43 | train_dataset, val_dataset = dataset[0], dataset[1] 44 | elif "pacs" in config.data_name: 45 | _data_names = config.data_name.split("_") 46 | dataset = PACSDataset( 47 | root=os.path.join(config.data_path, _data_names[0], _data_names[1]), 48 | device=config.device, 49 | data_augment=True, 50 | data_shift_class=data_shift_class, 51 | ).split_data(fractions=[0.9, 0.1], augment=[True, False], seed=config.seed) 52 | train_dataset, val_dataset = dataset[0], dataset[1] 53 | elif config.data_name == "waterbirds": 54 | train_dataset = WBirdsDataset( 55 | root=os.path.join(config.data_path, config.data_name), 56 | split="train", 57 | device=config.device, 58 | data_augment=True, 59 | data_shift_class=data_shift_class, 60 | ) 61 | val_dataset = WBirdsDataset( 62 | root=os.path.join(config.data_path, config.data_name), 63 | split="val", 64 | device=config.device, 65 | data_augment=False, 66 | ) 67 | elif config.data_name == "coloredmnist": 68 | data_shift_class = functools.partial( 69 | SyntheticShiftedData, 70 | data_name=config.data_name, 71 | seed=config.seed, 72 | synthetic_class=ColoredSyntheticShift( 73 | data_name=config.data_name, seed=config.seed 74 | ), 75 | version="stochastic", 76 | ) 77 | train_dataset = ColoredMNIST( 78 | root=os.path.join(config.data_path, "mnist"), 79 | data_name=config.data_name, 80 | split="train", 81 | device=config.device, 82 | data_shift_class=data_shift_class, 83 | ) 84 | val_dataset = ColoredMNIST( 85 | root=os.path.join(config.data_path, "mnist"), 86 | data_name=config.data_name, 87 | split="val", 88 | device=config.device, 89 | data_shift_class=data_shift_class, 90 | ) 91 | else: 92 | raise RuntimeError(f"Unknown dataset: {config.data_name}") 93 | 94 | return BaseLoader(train_dataset), BaseLoader(val_dataset) 95 | 96 | 97 | class AugMixDataset(torch.utils.data.Dataset): 98 | """Dataset wrapper to perform 
AugMix augmentation.""" 99 | 100 | def __init__(self, dataset, preprocess, aug, no_jsd=False): 101 | self.dataset = dataset 102 | self.preprocess = preprocess 103 | self.no_jsd = no_jsd 104 | self.aug = aug 105 | 106 | def __getitem__(self, i): 107 | x, y = self.dataset[i] 108 | if self.no_jsd: 109 | return self.aug(x, self.preprocess), y 110 | else: 111 | im_tuple = ( 112 | self.preprocess(x), 113 | self.aug(x, self.preprocess), 114 | self.aug(x, self.preprocess), 115 | ) 116 | return im_tuple, y 117 | 118 | def __len__(self): 119 | return len(self.dataset) 120 | -------------------------------------------------------------------------------- /pretrain/third_party/utils.py: -------------------------------------------------------------------------------- 1 | # utils functions. 2 | import copy 3 | import numpy as np 4 | import sys 5 | sys.path.append("..") 6 | 7 | import torch 8 | import torch.nn as nn 9 | import torch.nn.functional as F 10 | import torchvision.models as models 11 | 12 | import ttab.model_adaptation.utils as adaptation_utils 13 | from ttab.loads.models import resnet, WideResNet 14 | from ttab.configs.datasets import dataset_defaults 15 | 16 | 17 | def convert_iabn(module: nn.Module, config, **kwargs) -> nn.Module: 18 | module_output = module 19 | if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)): 20 | IABN = ( 21 | adaptation_utils.InstanceAwareBatchNorm2d 22 | if isinstance(module, nn.BatchNorm2d) 23 | else adaptation_utils.InstanceAwareBatchNorm1d 24 | ) 25 | module_output = IABN( 26 | num_channels=module.num_features, 27 | k=config.iabn_k, 28 | eps=module.eps, 29 | momentum=module.momentum, 30 | threshold=config.threshold_note, 31 | affine=module.affine, 32 | ) 33 | 34 | module_output._bn = copy.deepcopy(module) 35 | 36 | for name, child in module.named_children(): 37 | module_output.add_module(name, convert_iabn(child, config, **kwargs)) 38 | del module 39 | return module_output 40 | 41 | def update_pytorch_model_key(state_dict: dict) -> dict: 42 | """Modify the state-dict keys of a pytorch pretrained model to match our naming scheme.""" 43 | new_state_dict = {} 44 | for key, value in state_dict.items(): 45 | if "downsample" in key: 46 | name_split = key.split(".") 47 | if name_split[-2] == "0": 48 | name_split[-2] = "conv" 49 | new_key = ".".join(name_split) 50 | new_state_dict[new_key] = value 51 | elif name_split[-2] == "1": 52 | name_split[-2] = "bn" 53 | new_key = ".".join(name_split) 54 | new_state_dict[new_key] = value 55 | elif "fc" in key: 56 | name_split = key.split(".") 57 | if name_split[0] == "fc": 58 | name_split[0] = "classifier" 59 | new_key = ".".join(name_split) 60 | new_state_dict[new_key] = value 61 | else: 62 | new_state_dict[key] = value 63 | 64 | return new_state_dict 65 | 66 | def build_model(config) -> nn.Module: 67 | """Build model from `config`""" 68 | num_classes = dataset_defaults[config.base_data_name]["statistics"]["n_classes"] 69 | if config.base_data_name in ["officehome", "pacs", "waterbirds"]: 70 | pretrained_model = models.resnet50(pretrained=True) 71 | pretrained_model.fc = nn.Linear(pretrained_model.fc.in_features, out_features=num_classes, bias=False) 72 | pretrained_model.fc.weight.data.normal_(mean=0, std=0.01) 73 | 74 | model = resnet(config.base_data_name, depth=50, split_point=config.entry_of_shared_layers, grad_checkpoint=True).to(config.device) 75 | model_dict = model.state_dict() 76 | pretrained_dict = pretrained_model.state_dict() 77 | new_pretrained_dict = update_pytorch_model_key(pretrained_dict) 78 | 
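# Keep only the (renamed) pretrained entries whose keys exist in the custom
# resnet; note this assumes the filtered dict still covers every key the
# model expects, since `load_state_dict` below is strict by default.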
new_pretrained_dict = {k: v for k, v in new_pretrained_dict.items() if k in model_dict} 79 | model.load_state_dict(new_pretrained_dict) 80 | model.to(config.device) 81 | del pretrained_dict, pretrained_model, new_pretrained_dict 82 | if config.use_iabn: 83 | model = convert_iabn(model, config) 84 | else: 85 | if "wideresnet" in config.model_name: 86 | components = config.model_name.split("_") 87 | depth = int(components[0].replace("wideresnet", "")) 88 | widen_factor = int(components[1]) 89 | 90 | model = WideResNet( 91 | depth, 92 | widen_factor, 93 | num_classes, 94 | split_point=config.entry_of_shared_layers, 95 | dropout_rate=0.3, 96 | ) 97 | elif "resnet" in config.model_name: 98 | depth = int(config.model_name.replace("resnet", "")) 99 | model = resnet( 100 | config.base_data_name, 101 | depth, 102 | split_point=config.entry_of_shared_layers, 103 | group_norm_num_groups=config.group_norm, 104 | ).to(config.device) 105 | if config.use_iabn: 106 | assert config.group_norm is None, "IABN cannot be used with group norm." 107 | model = convert_iabn(model, config) 108 | return model 109 | 110 | def get_train_params(model: nn.Module, config) -> list: 111 | """Define the trainable parameters for a model using `config`""" 112 | if config.base_data_name in ["officehome", "pacs"]: 113 | params = [] 114 | learning_rate = config.lr 115 | 116 | for name_module, module in model.main_model.named_children(): 117 | if name_module != "classifier": 118 | for _, param in module.named_parameters(): 119 | params += [{"params": param, "lr": learning_rate*0.1}] 120 | else: 121 | for _, param in module.named_parameters(): 122 | params += [{"params": param, "lr": learning_rate}] 123 | 124 | for name_module, module in model.ssh.head.named_children(): 125 | if isinstance(module, nn.Linear): 126 | for _, param in module.named_parameters(): 127 | params += [{"params": param, "lr": learning_rate}] 128 | else: 129 | for _, param in module.named_parameters(): 130 | params += [{"params": param, "lr": learning_rate*0.1}] 131 | else: 132 | params = list(model.main_model.parameters()) + list(model.ssh.head.parameters()) 133 | return params 134 | 135 | def get_lr(step, total_steps, lr_max, lr_min): 136 | """Compute learning rate according to cosine annealing schedule.""" 137 | return lr_min + (lr_max - lr_min) * 0.5 * (1 + np.cos(step / total_steps * np.pi)) 138 | 139 | 140 | def mixup_data(x, y, alpha=1.0, device=None): 141 | """Returns mixed inputs, pairs of targets, and lambda""" 142 | if alpha > 0: 143 | lam = np.random.beta(alpha, alpha) 144 | else: 145 | lam = 1 146 | 147 | batch_size = x.size()[0] 148 | if device is None: 149 | index = torch.randperm(batch_size) 150 | else: 151 | index = torch.randperm(batch_size).to(device) 152 | 153 | mixed_x = lam * x + (1 - lam) * x[index, :] 154 | y_a, y_b = y, y[index] 155 | return mixed_x, y_a, y_b, lam 156 | 157 | 158 | def mixup_criterion(criterion, pred, y_a, y_b, lam): 159 | return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b) 160 | -------------------------------------------------------------------------------- /run_exp.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import parameters 3 | import ttab.configs.utils as configs_utils 4 | import ttab.loads.define_dataset as define_dataset 5 | from ttab.benchmark import Benchmark 6 | from ttab.loads.define_model import define_model, load_pretrained_model 7 | from ttab.model_adaptation import get_model_adaptation_method 8 | from 
ttab.model_selection import get_model_selection_method 9 | 10 | 11 | def main(init_config): 12 | # Required arguments. 13 | config, scenario = configs_utils.config_hparams(config=init_config) 14 | 15 | test_data_cls = define_dataset.ConstructTestDataset(config=config) 16 | test_loader = test_data_cls.construct_test_loader(scenario=scenario) 17 | 18 | # Model. 19 | model = define_model(config=config) 20 | load_pretrained_model(config=config, model=model) 21 | 22 | # Algorithms. 23 | model_adaptation_cls = get_model_adaptation_method( 24 | adaptation_name=scenario.model_adaptation_method 25 | )(meta_conf=config, model=model) 26 | model_selection_cls = get_model_selection_method(selection_name=scenario.model_selection_method)( 27 | meta_conf=config, model_adaptation_method=model_adaptation_cls 28 | ) 29 | 30 | # Evaluate. 31 | benchmark = Benchmark( 32 | scenario=scenario, 33 | model_adaptation_cls=model_adaptation_cls, 34 | model_selection_cls=model_selection_cls, 35 | test_loader=test_loader, 36 | meta_conf=config, 37 | ) 38 | benchmark.eval() 39 | 40 | 41 | if __name__ == "__main__": 42 | conf = parameters.get_args() 43 | main(init_config=conf) 44 | -------------------------------------------------------------------------------- /run_extract.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import argparse 4 | 5 | import monitor.tools.file_io as file_io 6 | from monitor.tools.show_results import load_raw_info_from_experiments 7 | from parameters import str2bool 8 | 9 | """parse and define arguments for different tasks.""" 10 | 11 | 12 | def get_args(): 13 | # feed them to the parser. 14 | parser = argparse.ArgumentParser(description="Extract results.") 15 | 16 | # add arguments. 17 | parser.add_argument("--in_dir", type=str) 18 | parser.add_argument("--out_name", type=str, default="summary.pickle") 19 | 20 | # parse args. 21 | args = parser.parse_args() 22 | 23 | # an argument safety check. 24 | check_args(args) 25 | return args 26 | 27 | 28 | def check_args(args): 29 | assert args.in_dir is not None 30 | 31 | # define out path. 32 | args.out_path = os.path.join(args.in_dir, args.out_name) 33 | 34 | 35 | """write the results to path.""" 36 | 37 | 38 | def main(args): 39 | # save the parsed results to path.
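# A typical invocation (with a hypothetical log directory) would be:
#   python run_extract.py --in_dir ./logs/cifar10c_tent --out_name summary.pickle
# which aggregates every experiment record found under `in_dir` and writes
# `./logs/cifar10c_tent/summary.pickle` for the tools in `monitor/tools`.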
40 | file_io.write_pickle( 41 | load_raw_info_from_experiments(args.in_dir), 42 | args.out_path, 43 | ) 44 | 45 | 46 | if __name__ == "__main__": 47 | args = get_args() 48 | 49 | main(args) 50 | -------------------------------------------------------------------------------- /ttab/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/__init__.py -------------------------------------------------------------------------------- /ttab/api.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import itertools 3 | from typing import Any, Callable, Iterable, List, Mapping, Optional, Tuple, Union 4 | 5 | import numpy as np 6 | import torch 7 | from torch.utils.data.dataset import random_split 8 | 9 | State = List[torch.Tensor] 10 | Gradient = List[torch.Tensor] 11 | Parameters = List[torch.Tensor] 12 | Loss = float 13 | Quality = Mapping[str, float] 14 | 15 | 16 | class Batch(object): 17 | def __init__(self, x, y): 18 | self._x = x 19 | self._y = y 20 | 21 | def __len__(self) -> int: 22 | return len(self._x) 23 | 24 | def to(self, device) -> "Batch": 25 | return Batch(self._x.to(device), self._y.to(device)) 26 | 27 | def __getitem__(self, index): 28 | return self._x[index], self._y[index] 29 | 30 | 31 | class GroupBatch(object): 32 | def __init__(self, x, y, g): 33 | self._x = x 34 | self._y = y 35 | self._g = g 36 | 37 | def __len__(self) -> int: 38 | return len(self._x) 39 | 40 | def to(self, device) -> "Batch": 41 | return GroupBatch(self._x.to(device), self._y.to(device), self._g.to(device)) 42 | 43 | def __getitem__(self, index): 44 | return self._x[index], self._y[index], self._g[index] 45 | 46 | 47 | class Dataset: 48 | def random_split(self, fractions: List[float]) -> List["Dataset"]: 49 | pass 50 | 51 | def iterator( 52 | self, batch_size: int, shuffle: bool, repeat=True 53 | ) -> Iterable[Tuple[float, Batch]]: 54 | pass 55 | 56 | def __len__(self) -> int: 57 | pass 58 | 59 | 60 | class PyTorchDataset(object): 61 | def __init__( 62 | self, 63 | dataset: torch.utils.data.Dataset, 64 | device: str, 65 | prepare_batch: Callable, 66 | num_classes: int, 67 | ): 68 | self._set = dataset 69 | self._device = device 70 | self._prepare_batch = prepare_batch 71 | self._num_classes = num_classes 72 | 73 | def __len__(self): 74 | return len(self._set) 75 | 76 | def replace_indices( 77 | self, 78 | indices_pattern: str = "original", 79 | new_indices: List[int] = None, 80 | random_seed: int = None, 81 | ) -> None: 82 | """Change the order of dataset indices in a particular pattern.""" 83 | if indices_pattern == "original": 84 | pass 85 | elif indices_pattern == "random_shuffle": 86 | rng = np.random.default_rng(random_seed) 87 | rng.shuffle(self.dataset.indices) 88 | elif indices_pattern == "new": 89 | if new_indices is None: 90 | raise ValueError("new_indices should be specified.") 91 | self.dataset.update_indices(new_indices=new_indices) 92 | else: 93 | raise NotImplementedError 94 | 95 | def query_dataset_attr(self, attr_name: str) -> Any: 96 | return getattr(self._set, attr_name, None) 97 | 98 | @property 99 | def dataset(self): 100 | return self._set 101 | 102 | @property 103 | def num_classes(self): 104 | return self._num_classes 105 | 106 | def no_split(self) -> List[Dataset]: 107 | return [ 108 | PyTorchDataset( 109 | dataset=self._set, 110 | device=self._device, 111 | 
prepare_batch=self._prepare_batch, 112 | num_classes=self._num_classes, 113 | ) 114 | ] 115 | 116 | def random_split(self, fractions: List[float], seed: int = 0) -> List[Dataset]: 117 | lengths = [int(f * len(self._set)) for f in fractions] 118 | lengths[0] += len(self._set) - sum(lengths) 119 | return [ 120 | PyTorchDataset( 121 | dataset=split, 122 | device=self._device, 123 | prepare_batch=self._prepare_batch, 124 | num_classes=self._num_classes, 125 | ) 126 | for split in random_split( 127 | self._set, lengths, torch.Generator().manual_seed(seed) 128 | ) 129 | ] 130 | 131 | def iterator( 132 | self, 133 | batch_size: int, 134 | shuffle: bool = True, 135 | repeat: bool = False, 136 | ref_num_data: Optional[int] = None, 137 | num_workers: int = 1, 138 | sampler: Optional[torch.utils.data.Sampler] = None, 139 | generator: Optional[torch.Generator] = None, 140 | pin_memory: bool = True, 141 | drop_last: bool = True, 142 | ) -> Iterable[Tuple[int, float, Batch]]: 143 | _num_batch = 1 if not drop_last else 0 144 | if ref_num_data is None: 145 | num_batches = int(len(self) / batch_size + _num_batch) 146 | else: 147 | num_batches = int(ref_num_data / batch_size + _num_batch) 148 | if sampler is not None: 149 | shuffle = False 150 | 151 | loader = torch.utils.data.DataLoader( 152 | self._set, 153 | batch_size=batch_size, 154 | shuffle=shuffle, 155 | pin_memory=pin_memory, 156 | drop_last=drop_last, 157 | num_workers=num_workers, 158 | sampler=sampler, 159 | generator=generator, 160 | ) 161 | 162 | step = 0 163 | for _ in itertools.count() if repeat else [0]: 164 | for i, batch in enumerate(loader): 165 | step += 1 166 | epoch_fractional = float(step) / num_batches 167 | yield step, epoch_fractional, self._prepare_batch(batch, self._device) 168 | 169 | def record_class_distribution( 170 | self, 171 | targets: Union[List, np.ndarray], 172 | indices: Union[List, np.ndarray], 173 | print_fn: Callable = print, 174 | is_train: bool = True, 175 | display: bool = True, 176 | ): 177 | targets_np = np.array(targets) 178 | unique_elements, counts_elements = np.unique( 179 | targets_np[indices] if indices is not None else targets_np, 180 | return_counts=True, 181 | ) 182 | element_counts = list(zip(unique_elements, counts_elements)) 183 | 184 | if display: 185 | print_fn( 186 | f"\tThe histogram of the targets in {'train' if is_train else 'test'}: {element_counts}" 187 | ) 188 | return element_counts 189 | -------------------------------------------------------------------------------- /ttab/configs/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/configs/__init__.py -------------------------------------------------------------------------------- /ttab/configs/algorithms.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # 1. This file collects significant hyperparameters for the configuration of TTA methods. 4 | # 2. We are only concerned about method-related hyperparameters here. 5 | # 3. We provide default hyperparameters from the paper or official repo if users have no idea how to set up reasonable values. 6 | import math 7 | 8 | algorithm_defaults = { 9 | "no_adaptation": {"model_selection_method": "last_iterate"}, 10 | # 11 | "bn_adapt": { 12 | "adapt_prior": 0, # the ratio of training set statistics. 
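# a value p presumably blends statistics as mu = p * mu_train + (1 - p) * mu_test,
# so the default of 0 relies purely on test-batch statistics.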
13 | }, 14 | "shot": { 15 | "optimizer": "SGD", # Adam for officehome 16 | "auxiliary_batch_size": 32, 17 | "threshold_shot": 0.9, # confidence threshold for online shot. 18 | "ent_par": 1.0, 19 | "cls_par": 0.3, # 0.1 for officehome. 20 | "offline_nepoch": 10, 21 | }, 22 | "ttt": { 23 | "optimizer": "SGD", 24 | "entry_of_shared_layers": "layer2", 25 | "aug_size": 32, 26 | "threshold_ttt": 1.0, 27 | "dim_out": 4, # For rotation prediction self-supervision task. 28 | "rotation_type": "rand", 29 | }, 30 | "tent": { 31 | "optimizer": "SGD", 32 | }, 33 | "t3a": {"top_M": 100}, 34 | "cotta": { 35 | "optimizer": "SGD", 36 | "alpha_teacher": 0.999, # weight of moving average for updating the teacher model. 37 | "restore_prob": 0.01, # the probability of restoring model parameters. 38 | "threshold_cotta": 0.92, # Threshold choice discussed in supplementary 39 | "aug_size": 32, 40 | }, 41 | "eata": { 42 | "optimizer": "SGD", 43 | "eata_margin_e0": math.log(1000) 44 | * 0.40, # The threshold for reliable minimization in EATA. 45 | "eata_margin_d0": 0.05, # for filtering redundant samples. 46 | "fishers": True, # whether to use fisher regularizer. 47 | "fisher_size": 2000, # number of samples to compute fisher information matrix. 48 | "fisher_alpha": 50, # the trade-off between entropy and regularization loss. 49 | }, 50 | "memo": { 51 | "optimizer": "SGD", 52 | "aug_size": 32, 53 | "bn_prior_strength": 16, 54 | }, 55 | # "ttt_plus_plus": { 56 | # "optimizer": "SGD", 57 | # "entry_of_shared_layers": None, 58 | # "batch_size_align": 256, 59 | # "queue_size": 256, 60 | # "offline_nepoch": 500, 61 | # "bnepoch": 2, # first few epochs to update bn stat. 62 | # "delayepoch": 0, # In first few epochs after bnepoch, we dont do both ssl and align (only ssl actually). 63 | # "stopepoch": 25, 64 | # "scale_ext": 0.5, 65 | # "scale_ssh": 0.2, 66 | # "align_ext": True, 67 | # "align_ssh": True, 68 | # "fix_ssh": False, 69 | # "method": "align", # choices = ['ssl', 'align', 'both'] 70 | # "divergence": "all", # choices = ['all', 'coral', 'mmd'] 71 | # }, 72 | "note": { 73 | "optimizer": "SGD", # use Adam in the paper 74 | "memory_size": 64, 75 | "update_every_x": 64, # This param may change in our codebase. 76 | "memory_type": "PBRS", 77 | "bn_momentum": 0.01, 78 | "temperature": 1.0, 79 | "iabn": False, # replace bn with iabn layer 80 | "iabn_k": 4, 81 | "threshold_note": 1, # skip threshold to discard adjustment. 82 | "use_learned_stats": True, 83 | }, 84 | "conjugate_pl": { 85 | "optimizer": "SGD", 86 | "temperature_scaling": 1.0, 87 | "model_eps": 0.0, # this should be added for Polyloss model. 88 | }, 89 | "sar": { 90 | "optimizer": "SGD", 91 | "sar_margin_e0": math.log(1000) 92 | * 0.40, # The threshold for reliable minimization in SAR. 93 | "reset_constant_em": 0.2, # threshold e_m for model recovery scheme 94 | }, 95 | "rotta":{ 96 | "optimizer": "Adam", 97 | "nu": 0.001, 98 | "memory_size": 64, 99 | "update_frequency": 64, 100 | "lambda_t": 1.0, 101 | "lambda_u": 1.0, 102 | "alpha": 0.05, 103 | } 104 | } 105 | -------------------------------------------------------------------------------- /ttab/configs/datasets.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # 1. This file collects hyperparameters significant for pretraining and test-time adaptation. 4 | # 2. We are only concerned about dataset-related hyperparameters here, e.g., lr, dataset statistics, and type of corruptions. 5 | # 3. 
We provide default hyperparameters if users have no idea how to set up reasonable values. 6 | 7 | dataset_defaults = { 8 | "cifar10": { 9 | "statistics": { 10 | "mean": (0.4914, 0.4822, 0.4465), 11 | "std": (0.2023, 0.1994, 0.2010), 12 | "n_classes": 10, 13 | }, 14 | "version": "deterministic", 15 | "img_shape": (32, 32, 3), 16 | }, 17 | "cifar100": { 18 | "statistics": { 19 | "mean": (0.5070751592371323, 0.48654887331495095, 0.4409178433670343), 20 | "std": (0.2673342858792401, 0.2564384629170883, 0.27615047132568404), 21 | "n_classes": 100, 22 | }, 23 | "version": "deterministic", 24 | "img_shape": (32, 32, 3), 25 | }, 26 | "officehome": { 27 | "statistics": { 28 | "mean": (0.485, 0.456, 0.406), 29 | "std": (0.229, 0.224, 0.225), 30 | "n_classes": 65, 31 | }, 32 | "img_shape": (224, 224, 3), 33 | }, 34 | "pacs": { 35 | "statistics": { 36 | "mean": (0.485, 0.456, 0.406), 37 | "std": (0.229, 0.224, 0.225), 38 | "n_classes": 7, 39 | }, 40 | "img_shape": (224, 224, 3), 41 | }, 42 | "coloredmnist": { 43 | "statistics": { 44 | "mean": (0.1307, 0.1307, 0.0), 45 | "std": (0.3081, 0.3081, 0.3081), 46 | "n_classes": 2, 47 | }, 48 | "img_shape": (28, 28, 3), 49 | }, 50 | "waterbirds": { 51 | "statistics": { 52 | "mean": (0.485, 0.456, 0.406), 53 | "std": (0.229, 0.224, 0.225), 54 | "n_classes": 2, 55 | }, 56 | "group_counts": [3498, 184, 56, 1057], # used to compute group ratio. 57 | "img_shape": (224, 224, 3), 58 | }, 59 | "imagenet": { 60 | "statistics": { 61 | "mean": (0.485, 0.456, 0.406), 62 | "std": (0.229, 0.224, 0.225), 63 | "n_classes": 1000, 64 | }, 65 | "img_shape": (224, 224, 3), 66 | }, 67 | "yearbook": { 68 | "statistics": {"n_classes": 2,}, 69 | "img_shape": (32, 32, 3), 70 | } 71 | } 72 | -------------------------------------------------------------------------------- /ttab/configs/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | # Functions in this file are used to deal with any work related with the configuration of datasets, models and algorithms. 4 | import ttab.scenarios.define_scenario as define_scenario 5 | from ttab.configs.algorithms import algorithm_defaults 6 | from ttab.configs.datasets import dataset_defaults 7 | 8 | 9 | def config_hparams(config): 10 | """ 11 | Populates hyperparameters with defaults implied by choices of other hyperparameters. 12 | 13 | Args: 14 | - config: namespace 15 | Returns: 16 | - config: namespace 17 | - scenario: NamedTuple 18 | """ 19 | # prior safety check 20 | assert ( 21 | config.model_adaptation_method is not None 22 | ), "model adaptation method must be specified" 23 | 24 | assert ( 25 | config.base_data_name is not None 26 | ), "base_data_name must be specified, either from default scenario, or from user-provided inputs." 27 | 28 | # register default arguments from default scenarios if specified by the user. 29 | scenario = define_scenario.get_scenario(config) 30 | config = define_scenario.scenario_registry(config, scenario) 31 | 32 | # register default arguments based on data_name. 33 | config = defaults_registry(config, template=dataset_defaults[config.base_data_name]) 34 | 35 | # register default arguments based on model_adaptation_method 36 | config = defaults_registry( 37 | config, template=algorithm_defaults[config.model_adaptation_method] 38 | ) 39 | 40 | # TODO: register default arguments for different kinds of base models 41 | # add default path for provided ckpts, otherwise provide a link to other sources. 
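# For example, given the registries above, a config with
# `base_data_name="cifar10"` and `model_adaptation_method="tent"` is filled
# in with `statistics`/`img_shape` from `dataset_defaults["cifar10"]` and
# `optimizer="SGD"` from `algorithm_defaults["tent"]`, unless the user has
# already set those fields.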
42 | 43 | return config, scenario 44 | 45 | 46 | def defaults_registry(config, template: dict, display_compatibility=False): 47 | """ 48 | Populates missing (key, val) pairs in config with (key, val) in template. 49 | 50 | Args: 51 | - config: namespace 52 | - template: dict 53 | - display_compatibility: option to raise errors if config.key != template[key] 54 | """ 55 | if template is None: 56 | return config 57 | 58 | dict_config = vars(config) 59 | for key, val in template.items(): 60 | if not isinstance(val, dict): # template[key] is non-index-able 61 | if key not in dict_config or dict_config[key] is None: 62 | dict_config[key] = val 63 | elif dict_config[key] != val and display_compatibility: 64 | raise ValueError(f"Argument {key} must be set to {val}") 65 | 66 | else: 67 | if key not in dict_config.keys(): 68 | dict_config[key] = {} 69 | for kwargs_key, kwargs_val in val.items(): 70 | if ( 71 | kwargs_key not in dict_config[key] 72 | or dict_config[key][kwargs_key] is None 73 | ): 74 | dict_config[key][kwargs_key] = kwargs_val 75 | elif ( 76 | dict_config[key][kwargs_key] != kwargs_val and display_compatibility 77 | ): 78 | raise ValueError( 79 | f"Argument {key}[{kwargs_key}] must be set to {kwargs_val}" 80 | ) 81 | return config 82 | 83 | 84 | def build_dict_from_config(arg_names, config): 85 | """ 86 | Build a dictionary from config based on arg_names. 87 | 88 | Args: 89 | - arg_names: list of strings 90 | - config: namespace 91 | Returns: 92 | - dict: dictionary 93 | """ 94 | dict_config = vars(config) 95 | return dict( 96 | (arg_name, dict_config[arg_name]) 97 | for arg_name in arg_names 98 | if (arg_name in dict_config) and (dict_config[arg_name] is not None) 99 | ) 100 | -------------------------------------------------------------------------------- /ttab/loads/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/.DS_Store -------------------------------------------------------------------------------- /ttab/loads/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/__init__.py -------------------------------------------------------------------------------- /ttab/loads/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | -------------------------------------------------------------------------------- /ttab/loads/datasets/cifar/data_aug_cifar.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Callable, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image, ImageOps 7 | from torchvision import transforms 8 | from ttab.configs.datasets import dataset_defaults 9 | 10 | """Base augmentations operators.""" 11 | 12 | ## https://github.com/google-research/augmix 13 | 14 | 15 | def _augmix_aug_cifar(x_orig: torch.Tensor, data_name: str) -> torch.Tensor: 16 | input_size = x_orig.shape[-1] 17 | scale_size = 36 if input_size == 32 else 256 # input size is either 32 or 224 18 | padding = int((scale_size - input_size) / 2) 19 | tensor_to_image, preprocess = get_ops(data_name) 20 | 21 | x_orig = tensor_to_image(x_orig.squeeze(0)) 22 | preaugment = transforms.Compose( 23 | [ 24 | transforms.RandomCrop(input_size, padding=padding), 
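# pad-and-crop (32 -> 36 -> 32 for CIFAR, i.e., padding 2) plus the flip
# below mirrors the standard CIFAR training augmentation.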
25 | transforms.RandomHorizontalFlip(), 26 | ] 27 | ) 28 | x_orig = preaugment(x_orig) 29 | x_processed = preprocess(x_orig) 30 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0])) 31 | m = np.float32(np.random.beta(1.0, 1.0)) 32 | 33 | mix = torch.zeros_like(x_processed) 34 | for i in range(3): 35 | x_aug = x_orig.copy() 36 | for _ in range(np.random.randint(1, 4)): 37 | x_aug = np.random.choice(augmentations)(x_aug, input_size) 38 | mix += w[i] * preprocess(x_aug) 39 | mix = m * x_processed + (1 - m) * mix 40 | return mix 41 | 42 | 43 | aug_cifar = _augmix_aug_cifar 44 | 45 | 46 | def autocontrast(pil_img, input_size, level=None): 47 | return ImageOps.autocontrast(pil_img) 48 | 49 | 50 | def equalize(pil_img, input_size, level=None): 51 | return ImageOps.equalize(pil_img) 52 | 53 | 54 | def rotate(pil_img, input_size, level): 55 | degrees = int_parameter(rand_lvl(level), 30) 56 | if np.random.uniform() > 0.5: 57 | degrees = -degrees 58 | return pil_img.rotate(degrees, resample=Image.BILINEAR, fillcolor=128) 59 | 60 | 61 | def solarize(pil_img, input_size, level): 62 | level = int_parameter(rand_lvl(level), 256) 63 | return ImageOps.solarize(pil_img, 256 - level) 64 | 65 | 66 | def shear_x(pil_img, input_size, level): 67 | level = float_parameter(rand_lvl(level), 0.3) 68 | if np.random.uniform() > 0.5: 69 | level = -level 70 | return pil_img.transform( 71 | (input_size, input_size), 72 | Image.AFFINE, 73 | (1, level, 0, 0, 1, 0), 74 | resample=Image.BILINEAR, 75 | fillcolor=128, 76 | ) 77 | 78 | 79 | def shear_y(pil_img, input_size, level): 80 | level = float_parameter(rand_lvl(level), 0.3) 81 | if np.random.uniform() > 0.5: 82 | level = -level 83 | return pil_img.transform( 84 | (input_size, input_size), 85 | Image.AFFINE, 86 | (1, 0, 0, level, 1, 0), 87 | resample=Image.BILINEAR, 88 | fillcolor=128, 89 | ) 90 | 91 | 92 | def translate_x(pil_img, input_size, level): 93 | level = int_parameter(rand_lvl(level), input_size / 3) 94 | if np.random.random() > 0.5: 95 | level = -level 96 | return pil_img.transform( 97 | (input_size, input_size), 98 | Image.AFFINE, 99 | (1, 0, level, 0, 1, 0), 100 | resample=Image.BILINEAR, 101 | fillcolor=128, 102 | ) 103 | 104 | 105 | def translate_y(pil_img, input_size, level): 106 | level = int_parameter(rand_lvl(level), input_size / 3) 107 | if np.random.random() > 0.5: 108 | level = -level 109 | return pil_img.transform( 110 | (input_size, input_size), 111 | Image.AFFINE, 112 | (1, 0, 0, 0, 1, level), 113 | resample=Image.BILINEAR, 114 | fillcolor=128, 115 | ) 116 | 117 | 118 | def posterize(pil_img, input_size, level): 119 | level = int_parameter(rand_lvl(level), 4) 120 | return ImageOps.posterize(pil_img, 4 - level) 121 | 122 | 123 | def int_parameter(level, maxval): 124 | """Helper function to scale `val` between 0 and maxval . 125 | Args: 126 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 127 | maxval: Maximum value that the operation can have. This will be scaled 128 | to level/PARAMETER_MAX. 129 | Returns: 130 | An int that results from scaling `maxval` according to `level`. 131 | """ 132 | return int(level * maxval / 10) 133 | 134 | 135 | def float_parameter(level, maxval): 136 | """Helper function to scale `val` between 0 and maxval . 137 | Args: 138 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 139 | maxval: Maximum value that the operation can have. This will be scaled 140 | to level/PARAMETER_MAX. 141 | Returns: 142 | A float that results from scaling `maxval` according to `level`. 
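For instance, float_parameter(5.0, 0.3) == 0.15, i.e., halfway to `maxval`.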
143 | """ 144 | return float(level) * maxval / 10.0 145 | 146 | 147 | def rand_lvl(n): 148 | return np.random.uniform(low=0.1, high=n) 149 | 150 | 151 | augmentations = [ 152 | autocontrast, 153 | equalize, 154 | lambda x, y: rotate(x, y, 1), 155 | lambda x, y: solarize(x, y, 1), 156 | lambda x, y: shear_x(x, y, 1), 157 | lambda x, y: shear_y(x, y, 1), 158 | lambda x, y: translate_x(x, y, 1), 159 | lambda x, y: translate_y(x, y, 1), 160 | lambda x, y: posterize(x, y, 1), 161 | ] 162 | 163 | def get_ops(data_name: str) -> Tuple[Callable, Callable]: 164 | """Get the operations to be applied when defining transforms.""" 165 | unnormalize = transforms.Compose( 166 | [ 167 | transforms.Normalize( 168 | mean=[0.0, 0.0, 0.0], 169 | std=[1.0 / v for v in dataset_defaults[data_name]["statistics"]["std"]], 170 | ), 171 | transforms.Normalize( 172 | mean=[-v for v in dataset_defaults[data_name]["statistics"]["mean"]], 173 | std=[1.0, 1.0, 1.0], 174 | ), 175 | ] 176 | ) 177 | 178 | tensor_to_image = transforms.Compose([unnormalize, transforms.ToPILImage()]) 179 | preprocess = transforms.Compose( 180 | [ 181 | transforms.ToTensor(), 182 | transforms.Normalize( 183 | dataset_defaults[data_name]["statistics"]["mean"], dataset_defaults[data_name]["statistics"]["std"] 184 | ), 185 | ] 186 | ) 187 | return tensor_to_image, preprocess 188 | 189 | 190 | def tr_transforms_cifar(image: torch.Tensor, data_name: str) -> torch.Tensor: 191 | """ 192 | Data augmentation for input images. 193 | args: 194 | inputs: 195 | image: tensor [n_channel, H, W] 196 | outputs: 197 | augment_image: tensor [1, n_channel, H, W] 198 | """ 199 | input_size = image.shape[-1] 200 | scale_size = 36 if input_size == 32 else 256 # input size is either 32 or 224 201 | padding = int(scale_size - input_size) 202 | tensor_to_image, preprocess = get_ops(data_name) 203 | 204 | image = tensor_to_image(image) 205 | preaugment = transforms.Compose( 206 | [ 207 | transforms.RandomCrop(input_size, padding=padding), 208 | transforms.RandomHorizontalFlip(), 209 | ] 210 | ) 211 | augment_image = preaugment(image) 212 | augment_image = preprocess(augment_image) 213 | 214 | return augment_image 215 | -------------------------------------------------------------------------------- /ttab/loads/datasets/dataset_sampling.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import random 3 | 4 | from ttab.api import PyTorchDataset 5 | from ttab.scenarios import TestDomain 6 | 7 | 8 | class DatasetSampling(object): 9 | def __init__(self, test_domain: TestDomain): 10 | self.domain_sampling_name = test_domain.domain_sampling_name 11 | self.domain_sampling_value = test_domain.domain_sampling_value 12 | self.domain_sampling_ratio = test_domain.domain_sampling_ratio 13 | 14 | def sample( 15 | self, dataset: PyTorchDataset, random_seed: int = None 16 | ) -> PyTorchDataset: 17 | if self.domain_sampling_name == "uniform": 18 | return self.uniform_sample( 19 | dataset=dataset, 20 | ratio=self.domain_sampling_ratio, 21 | random_seed=random_seed, 22 | ) 23 | else: 24 | raise NotImplementedError 25 | 26 | @staticmethod 27 | def uniform_sample( 28 | dataset: PyTorchDataset, ratio: float, random_seed: int = None 29 | ) -> PyTorchDataset: 30 | """This function uniformly samples data from the original dataset without replacement.""" 31 | random.seed(random_seed) 32 | indices = dataset.query_dataset_attr("indices") 33 | sampled_list = random.sample( 34 | indices, 35 | int(ratio * len(indices)), 36 | ) 37 | 
sampled_list.sort() 38 | dataset.replace_indices( 39 | indices_pattern="new", new_indices=sampled_list, random_seed=random_seed 40 | ) 41 | return dataset 42 | -------------------------------------------------------------------------------- /ttab/loads/datasets/dataset_shifts.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Callable, List, NamedTuple, Optional 3 | 4 | import numpy as np 5 | import torch 6 | from PIL import Image 7 | 8 | # map data_name to a distribution shift. 9 | # Let's say the data_name could be in the form of 1) <base_data_name>, 2) <base_data_name>_<shift_name>, and 3) <base_data_name>_<shift_name>_<version> 10 | 11 | data2shift = dict( 12 | cifar10="no_shift", 13 | cifar100="no_shift", 14 | cifar10_c="synthetic", 15 | cifar100_c="synthetic", 16 | cifar10_1="natural", 17 | cifar10_shiftedlabel="natural", 18 | cifar100_shiftedlabel="natural", 19 | cifar10_temporal="temporal", 20 | imagenet="no_shift", 21 | imagenet_c="synthetic", 22 | imagenet_a="natural", 23 | imagenet_r="natural", 24 | # also covers imagenet_v2_<version> variants, e.g., imagenet_v2_matched-frequency. 25 | imagenet_v2="natural", 26 | officehome_art="natural", 27 | officehome_clipart="natural", 28 | officehome_product="natural", 29 | officehome_realworld="natural", 30 | pacs_art="natural", 31 | pacs_cartoon="natural", 32 | pacs_photo="natural", 33 | pacs_sketch="natural", 34 | mnist="no_shift", 35 | coloredmnist="synthetic", 36 | waterbirds="natural", 37 | yearbook="natural", 38 | ) 39 | 40 | 41 | class SyntheticShiftProperty(NamedTuple): 42 | shift_degree: int 43 | shift_name: str 44 | 45 | version: str = "stochastic" 46 | has_shift: bool = True 47 | 48 | 49 | class NaturalShiftProperty(NamedTuple): 50 | version: Optional[str] = None 51 | has_shift: bool = True 52 | 53 | 54 | class NoShiftProperty(NamedTuple): 55 | has_shift: bool = False 56 | 57 | 58 | class ShiftedData(object): 59 | def __init__(self, dataset: torch.utils.data.Dataset): 60 | self.dataset = dataset 61 | 62 | def __getitem__(self, index): 63 | return self.dataset[index] 64 | 65 | def __len__(self): 66 | return len(self.dataset) 67 | 68 | def update_indices(self, new_indices: List[int]) -> None: 69 | """Update the indices of the dataset after applying shift.""" 70 | self.dataset.indices = new_indices 71 | self.dataset.data_size = len(self.indices) 72 | 73 | @property 74 | def data(self): 75 | return self.dataset.data 76 | 77 | @property 78 | def targets(self): 79 | return self.dataset.targets 80 | 81 | @property 82 | def indices(self): 83 | return self.dataset.indices 84 | 85 | @property 86 | def data_size(self): 87 | return self.dataset.data_size 88 | 89 | @property 90 | def classes(self): 91 | return self.dataset.classes 92 | 93 | @property 94 | def class_to_index(self): 95 | return self.dataset.class_to_index 96 | 97 | @property 98 | def group_array(self): 99 | return getattr(self.dataset, "group_array", None) 100 | 101 | 102 | class NoShiftedData(ShiftedData): 103 | """ 104 | Dataset-like object that only accesses a subset of the underlying dataset. 105 | It applies NO shift to the data in __getitem__. 106 | """ 107 | 108 | def __init__(self, data_name, dataset: torch.utils.data.Dataset): 109 | super().__init__(dataset) 110 | # no synthetic ops are needed here; just record the data name. 111 | self.data_name = data_name 112 | 113 | 114 | class NaturalShiftedData(ShiftedData): 115 | """ 116 | Dataset-like object that only accesses a subset of the underlying dataset. 117 | It reloads the data with a naturally shifted counterpart and applies NO extra shift in __getitem__.
118 | """ 119 | 120 | def __init__( 121 | self, 122 | data_name, 123 | dataset: torch.utils.data.Dataset, 124 | new_data: torch.utils.data.Dataset, 125 | ): 126 | super().__init__(dataset) 127 | # replace original data/targets with new data/targets 128 | self.dataset.data = new_data.data 129 | self.dataset.targets = new_data.targets 130 | self.dataset.data_size = len(self.dataset.data) 131 | self.dataset.indices = list([x for x in range(0, self.dataset.data_size)]) 132 | self.data_name = data_name 133 | 134 | 135 | class SyntheticShiftedData(ShiftedData): 136 | """ 137 | Dataset-like object, but only access a subset of it. 138 | And it applies corruptions to the data when __getitem__. 139 | """ 140 | 141 | def __init__( 142 | self, 143 | data_name: str, 144 | dataset: torch.utils.data.Dataset, 145 | seed: int, 146 | synthetic_class: Callable, 147 | version: str, 148 | **kwargs, 149 | ): 150 | super().__init__(dataset) 151 | self.data_name = data_name 152 | self.version = version # either stochastic or determinstic 153 | 154 | # initialize corruption class 155 | if any([name in data_name for name in ["cifar", "imagenet"]]): 156 | self.synthetic_ops = synthetic_class(data_name, seed, kwargs["severity"]) 157 | elif "mnist" in data_name: 158 | self.synthetic_ops = synthetic_class 159 | else: 160 | NotImplementedError( 161 | f"synthetic shift for {data_name} is not supported in TTAB." 162 | ) 163 | 164 | def apply_corruption(self): 165 | """Apply corruption to the clean dataset.""" 166 | corrupted_imgs = [] 167 | 168 | for index in range(self.dataset.data_size): 169 | img_array = self.dataset.data[index] 170 | img = Image.fromarray(img_array) 171 | img = self.synthetic_ops.apply(img) 172 | corrupted_imgs.append(img) 173 | 174 | # replace data. 175 | self.dataset.data = np.stack(corrupted_imgs, axis=0) 176 | 177 | def prepare_colored_mnist( 178 | self, 179 | transform: Optional[Callable] = None, 180 | target_transform: Optional[Callable] = None, 181 | ): 182 | return self.synthetic_ops.apply(self.dataset, transform, target_transform) 183 | -------------------------------------------------------------------------------- /ttab/loads/datasets/imagenet/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import numpy as np 4 | import tarfile 5 | 6 | from torchvision.datasets.utils import download_url 7 | from torchvision.datasets import ImageFolder 8 | 9 | 10 | class ImageNetValNaturalShift(object): 11 | """Borrowed from 12 | (1) https://github.com/hendrycks/imagenet-r/, 13 | (2) https://github.com/hendrycks/natural-adv-examples, 14 | (3) https://github.com/modestyachts/ImageNetV2. 
15 | """ 16 | 17 | stats = { 18 | "imagenet_r": { 19 | "data_and_labels": "https://people.eecs.berkeley.edu/~hendrycks/imagenet-r.tar", 20 | "folder_name": "imagenet-r", 21 | }, 22 | "imagenet_a": { 23 | "data_and_labels": "https://people.eecs.berkeley.edu/~hendrycks/imagenet-a.tar", 24 | "folder_name": "imagenet-a", 25 | }, 26 | "imagenet_v2_matched-frequency": { 27 | "data_and_labels": "https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-matched-frequency.tar.gz", 28 | "folder_name": "imagenetv2-matched-frequency-format-val", 29 | }, 30 | "imagenet_v2_threshold0.7": { 31 | "data_and_labels": "https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-threshold0.7.tar.gz", 32 | "folder_name": "imagenetv2-threshold0.7-format-val", 33 | }, 34 | "imagenet_v2_topimages": { 35 | "data_and_labels": "https://s3-us-west-2.amazonaws.com/imagenetv2public/imagenetv2-top-images.tar.gz", 36 | "folder_name": "imagenetv2-topimages-format-val", 37 | }, 38 | } 39 | 40 | def __init__(self, root, data_name, version=None): 41 | self.data_name = data_name 42 | self.path_data_and_labels_tar = os.path.join( 43 | root, self.stats[data_name]["data_and_labels"].split("/")[-1] 44 | ) 45 | self.path_data_and_labels = os.path.join( 46 | root, self.stats[data_name]["folder_name"] 47 | ) 48 | 49 | self._download(root) 50 | 51 | self.image_folder = ImageFolder(self.path_data_and_labels) 52 | self.data = self.image_folder.samples 53 | self.targets = self.image_folder.targets 54 | 55 | def _download(self, root): 56 | download_url(url=self.stats[self.data_name]["data_and_labels"], root=root) 57 | 58 | if self._check_integrity(): 59 | print("Files already downloaded, verified, and uncompressed.") 60 | return 61 | self._uncompress(root) 62 | 63 | def _uncompress(self, root): 64 | with tarfile.open(self.path_data_and_labels_tar) as file: 65 | file.extractall(root) 66 | 67 | def _check_integrity(self) -> bool: 68 | if os.path.exists(self.path_data_and_labels): 69 | return True 70 | else: 71 | return False 72 | 73 | def __getitem__(self, index): 74 | path, target = self.data[index] 75 | img = self.image_folder.loader(path) 76 | return img, target 77 | 78 | def __len__(self): 79 | return len(self.data) 80 | 81 | 82 | """Some corruptions are referred to https://github.com/hendrycks/robustness/blob/master/ImageNet-C/create_c/make_imagenet_c.py""" 83 | 84 | 85 | class ImageNetSyntheticShift(object): 86 | """The class of synthetic corruptions/shifts introduced in ImageNet_C.""" 87 | 88 | def __init__( 89 | self, data_name, seed, severity=5, corruption_type=None, img_resolution=224 90 | ): 91 | assert "imagenet" in data_name 92 | 93 | if img_resolution == 224: 94 | from ttab.loads.datasets.imagenet.synthetic_224 import ( 95 | gaussian_noise, 96 | shot_noise, 97 | impulse_noise, 98 | defocus_blur, 99 | glass_blur, 100 | motion_blur, 101 | zoom_blur, 102 | snow, 103 | frost, 104 | fog, 105 | brightness, 106 | contrast, 107 | elastic_transform, 108 | pixelate, 109 | jpeg_compression, 110 | # for validation. 111 | speckle_noise, 112 | gaussian_blur, 113 | spatter, 114 | saturate, 115 | ) 116 | elif img_resolution == 64: 117 | from ttab.loads.datasets.imagenet.synthetic_64 import ( 118 | gaussian_noise, 119 | shot_noise, 120 | impulse_noise, 121 | defocus_blur, 122 | glass_blur, 123 | motion_blur, 124 | zoom_blur, 125 | snow, 126 | frost, 127 | fog, 128 | brightness, 129 | contrast, 130 | elastic_transform, 131 | pixelate, 132 | jpeg_compression, 133 | # for validation. 
134 | speckle_noise, 135 | gaussian_blur, 136 | spatter, 137 | saturate, 138 | ) 139 | else: 140 | raise NotImplementedError( 141 | f"Invalid img_resolution for ImageNet: {img_resolution}" 142 | ) 143 | 144 | self.data_name = data_name 145 | self.base_data_name = data_name.split("_")[0] 146 | self.seed = seed 147 | self.severity = severity 148 | self.corruption_type = corruption_type 149 | self.dict_corruption = { 150 | "gaussian_noise": gaussian_noise, 151 | "shot_noise": shot_noise, 152 | "impulse_noise": impulse_noise, 153 | "defocus_blur": defocus_blur, 154 | "glass_blur": glass_blur, 155 | "motion_blur": motion_blur, 156 | "zoom_blur": zoom_blur, 157 | "snow": snow, 158 | "frost": frost, 159 | "fog": fog, 160 | "brightness": brightness, 161 | "contrast": contrast, 162 | "elastic_transform": elastic_transform, 163 | "pixelate": pixelate, 164 | "jpeg_compression": jpeg_compression, 165 | "speckle_noise": speckle_noise, 166 | "gaussian_blur": gaussian_blur, 167 | "spatter": spatter, 168 | "saturate": saturate, 169 | } 170 | if corruption_type is not None: 171 | assert ( 172 | corruption_type in self.dict_corruption.keys() 173 | ), f"{corruption_type} is out of range" 174 | self.random_state = np.random.RandomState(self.seed) 175 | 176 | def _apply_corruption(self, pil_img): 177 | if self.corruption_type is None or self.corruption_type == "all": 178 | corruption = self.random_state.choice(list(self.dict_corruption.values())) 179 | else: 180 | corruption = self.dict_corruption[self.corruption_type] 181 | 182 | return np.uint8( 183 | corruption(pil_img, random_state=self.random_state, severity=self.severity) 184 | ) 185 | 186 | def apply(self, pil_img): 187 | return self._apply_corruption(pil_img) 188 | -------------------------------------------------------------------------------- /ttab/loads/datasets/imagenet/data_aug_imagenet.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image, ImageOps 6 | from torchvision import transforms 7 | from ttab.configs.datasets import dataset_defaults 8 | 9 | ## https://github.com/google-research/augmix 10 | 11 | 12 | def _augmix_aug(x_orig: torch.Tensor, data_name: str) -> torch.Tensor: 13 | tensor_to_image, preprocess = get_ops(data_name) 14 | x_orig = tensor_to_image(x_orig.squeeze(0)) 15 | preaugment = transforms.Compose( 16 | [ 17 | transforms.RandomResizedCrop(224), 18 | transforms.RandomHorizontalFlip(), 19 | ] 20 | ) 21 | x_orig = preaugment(x_orig) 22 | x_processed = preprocess(x_orig) 23 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0])) 24 | m = np.float32(np.random.beta(1.0, 1.0)) 25 | 26 | mix = torch.zeros_like(x_processed) 27 | for i in range(3): 28 | x_aug = x_orig.copy() 29 | for _ in range(np.random.randint(1, 4)): 30 | x_aug = np.random.choice(augmentations)(x_aug) 31 | mix += w[i] * preprocess(x_aug) 32 | mix = m * x_processed + (1 - m) * mix 33 | return mix 34 | 35 | 36 | aug_imagenet = _augmix_aug 37 | 38 | 39 | def autocontrast(pil_img, level=None): 40 | return ImageOps.autocontrast(pil_img) 41 | 42 | 43 | def equalize(pil_img, level=None): 44 | return ImageOps.equalize(pil_img) 45 | 46 | 47 | def rotate(pil_img, level): 48 | degrees = int_parameter(rand_lvl(level), 30) 49 | if np.random.uniform() > 0.5: 50 | degrees = -degrees 51 | return pil_img.rotate(degrees, resample=Image.BILINEAR, fillcolor=128) 52 | 53 | 54 | def solarize(pil_img, level): 55 | level = int_parameter(rand_lvl(level), 256) 56
| return ImageOps.solarize(pil_img, 256 - level) 57 | 58 | 59 | def shear_x(pil_img, level): 60 | level = float_parameter(rand_lvl(level), 0.3) 61 | if np.random.uniform() > 0.5: 62 | level = -level 63 | return pil_img.transform( 64 | (224, 224), 65 | Image.AFFINE, 66 | (1, level, 0, 0, 1, 0), 67 | resample=Image.BILINEAR, 68 | fillcolor=128, 69 | ) 70 | 71 | 72 | def shear_y(pil_img, level): 73 | level = float_parameter(rand_lvl(level), 0.3) 74 | if np.random.uniform() > 0.5: 75 | level = -level 76 | return pil_img.transform( 77 | (224, 224), 78 | Image.AFFINE, 79 | (1, 0, 0, level, 1, 0), 80 | resample=Image.BILINEAR, 81 | fillcolor=128, 82 | ) 83 | 84 | 85 | def translate_x(pil_img, level): 86 | level = int_parameter(rand_lvl(level), 224 / 3) 87 | if np.random.random() > 0.5: 88 | level = -level 89 | return pil_img.transform( 90 | (224, 224), 91 | Image.AFFINE, 92 | (1, 0, level, 0, 1, 0), 93 | resample=Image.BILINEAR, 94 | fillcolor=128, 95 | ) 96 | 97 | 98 | def translate_y(pil_img, level): 99 | level = int_parameter(rand_lvl(level), 224 / 3) 100 | if np.random.random() > 0.5: 101 | level = -level 102 | return pil_img.transform( 103 | (224, 224), 104 | Image.AFFINE, 105 | (1, 0, 0, 0, 1, level), 106 | resample=Image.BILINEAR, 107 | fillcolor=128, 108 | ) 109 | 110 | 111 | def posterize(pil_img, level): 112 | level = int_parameter(rand_lvl(level), 4) 113 | return ImageOps.posterize(pil_img, 4 - level) 114 | 115 | 116 | def int_parameter(level, maxval): 117 | """Helper function to scale `val` between 0 and maxval . 118 | Args: 119 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 120 | maxval: Maximum value that the operation can have. This will be scaled 121 | to level/PARAMETER_MAX. 122 | Returns: 123 | An int that results from scaling `maxval` according to `level`. 124 | """ 125 | return int(level * maxval / 10) 126 | 127 | 128 | def float_parameter(level, maxval): 129 | """Helper function to scale `val` between 0 and maxval . 130 | Args: 131 | level: Level of the operation that will be between [0, `PARAMETER_MAX`]. 132 | maxval: Maximum value that the operation can have. This will be scaled 133 | to level/PARAMETER_MAX. 134 | Returns: 135 | A float that results from scaling `maxval` according to `level`. 
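For instance, float_parameter(10, 0.3) == 0.3, the full `maxval`.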
136 | """ 137 | return float(level) * maxval / 10.0 138 | 139 | 140 | def rand_lvl(n): 141 | return np.random.uniform(low=0.1, high=n) 142 | 143 | 144 | augmentations = [ 145 | autocontrast, 146 | equalize, 147 | lambda x: rotate(x, 1), 148 | lambda x: solarize(x, 1), 149 | lambda x: shear_x(x, 1), 150 | lambda x: shear_y(x, 1), 151 | lambda x: translate_x(x, 1), 152 | lambda x: translate_y(x, 1), 153 | lambda x: posterize(x, 1), 154 | ] 155 | 156 | def get_ops(data_name: str) -> Tuple[Callable, Callable]: 157 | """Get the operations to be applied when defining transforms.""" 158 | unnormalize = transforms.Compose( 159 | [ 160 | transforms.Normalize( 161 | mean=[0.0, 0.0, 0.0], 162 | std=[1.0 / v for v in dataset_defaults[data_name]["statistics"]["std"]], 163 | ), 164 | transforms.Normalize( 165 | mean=[-v for v in dataset_defaults[data_name]["statistics"]["mean"]], 166 | std=[1.0, 1.0, 1.0], 167 | ), 168 | ] 169 | ) 170 | 171 | tensor_to_image = transforms.Compose([unnormalize, transforms.ToPILImage()]) 172 | preprocess = transforms.Compose( 173 | [ 174 | transforms.ToTensor(), 175 | transforms.Normalize( 176 | dataset_defaults[data_name]["statistics"]["mean"], dataset_defaults[data_name]["statistics"]["std"] 177 | ), 178 | ] 179 | ) 180 | return tensor_to_image, preprocess 181 | 182 | 183 | def tr_transforms_imagenet(image: torch.Tensor, data_name: str) -> torch.Tensor: 184 | """ 185 | Data augmentation for input images. 186 | args: 187 | inputs: 188 | image: tensor [n_channel, H, W] 189 | outputs: 190 | augment_image: tensor [1, n_channel, H, W] 191 | """ 192 | tensor_to_image, preprocess = get_ops(data_name) 193 | image = tensor_to_image(image) 194 | 195 | preaugment = transforms.Compose( 196 | [ 197 | transforms.RandomResizedCrop(224), 198 | transforms.RandomHorizontalFlip(), 199 | ] 200 | ) 201 | augment_image = preaugment(image) 202 | augment_image = preprocess(augment_image) 203 | 204 | return augment_image 205 | -------------------------------------------------------------------------------- /ttab/loads/datasets/loaders.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Iterable, Optional, Tuple, Type, Union 3 | 4 | import torch 5 | from ttab.api import Batch, PyTorchDataset 6 | from ttab.loads.datasets.datasets import WrapperDataset 7 | from ttab.scenarios import Scenario 8 | 9 | D = Union[torch.utils.data.Dataset, PyTorchDataset] 10 | 11 | 12 | class BaseLoader(object): 13 | def __init__(self, dataset: PyTorchDataset): 14 | self.dataset = dataset 15 | 16 | def iterator( 17 | self, 18 | batch_size: int, 19 | shuffle: bool = True, 20 | repeat: bool = False, 21 | ref_num_data: Optional[int] = None, 22 | num_workers: int = 1, 23 | sampler: Optional[torch.utils.data.Sampler] = None, 24 | generator: Optional[torch.Generator] = None, 25 | pin_memory: bool = True, 26 | drop_last: bool = True, 27 | ) -> Iterable[Tuple[int, float, Batch]]: 28 | yield from self.dataset.iterator( 29 | batch_size, 30 | shuffle, 31 | repeat, 32 | ref_num_data, 33 | num_workers, 34 | sampler, 35 | generator, 36 | pin_memory, 37 | drop_last, 38 | ) 39 | 40 | 41 | def _init_dataset(dataset: D, device: str) -> PyTorchDataset: 42 | if isinstance(dataset, torch.utils.data.Dataset): 43 | return WrapperDataset(dataset, device) 44 | else: 45 | return dataset 46 | 47 | 48 | def get_test_loader(dataset: D, device: str) -> Type[BaseLoader]: 49 | dataset: PyTorchDataset = _init_dataset(dataset, device) 50 | return BaseLoader(dataset) 51 | 
52 | 53 | def get_auxiliary_loader(dataset: D, device: str) -> Type[BaseLoader]: 54 | dataset: PyTorchDataset = _init_dataset(dataset, device) 55 | return BaseLoader(dataset) 56 | -------------------------------------------------------------------------------- /ttab/loads/datasets/mnist/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import numpy as np 3 | from ttab.loads.datasets.datasets import ImageArrayDataset 4 | 5 | """ColoredMNIST is borrowed from https://github.com/facebookresearch/DomainBed/blob/main/domainbed/datasets.py""" 6 | 7 | 8 | class ColoredSyntheticShift(object): 9 | """The class of synthetic colored shifts introduced in ColoredMNIST.""" 10 | 11 | def __init__(self, data_name, seed, color_flip_prob: float = 0.25) -> None: 12 | self.data_name = data_name 13 | self.base_data_name = data_name.split("_")[0] 14 | self.seed = seed 15 | self.color_flip_prob = color_flip_prob 16 | assert ( 17 | 0 <= self.color_flip_prob <= 1 18 | ), f"{self.color_flip_prob} is out of range." 19 | 20 | def _color_grayscale_arr(self, arr, red=True): 21 | """Converts grayscale image to either red or green""" 22 | assert arr.ndim == 2 23 | dtype = arr.dtype 24 | h, w = arr.shape 25 | arr = np.reshape(arr, [h, w, 1]) 26 | if red: 27 | arr = np.concatenate([arr, np.zeros((h, w, 2), dtype=dtype)], axis=2) 28 | else: 29 | arr = np.concatenate( 30 | [ 31 | np.zeros((h, w, 1), dtype=dtype), 32 | arr, 33 | np.zeros((h, w, 1), dtype=dtype), 34 | ], 35 | axis=2, 36 | ) 37 | return arr 38 | 39 | def apply(self, dataset, transform, target_transform): 40 | return self._apply_color( 41 | dataset, self.color_flip_prob, transform, target_transform 42 | ) 43 | 44 | def _apply_color(self, dataset, color_flip_prob, transform, target_transform): 45 | 46 | train_data = [] 47 | train_targets = [] 48 | val_data = [] 49 | val_targets = [] 50 | test_data = [] 51 | test_targets = [] 52 | for idx, (im, label) in enumerate(dataset): 53 | im_array = np.array(im) 54 | 55 | # Assign a binary label y to the image based on the digit 56 | binary_label = 0 if label < 5 else 1 57 | 58 | # Flip label with probability of `color_flip_prob` 59 | if np.random.uniform() < color_flip_prob: 60 | binary_label = binary_label ^ 1 61 | 62 | # Color the image either red or green according to its possibly flipped label 63 | color_red = binary_label == 0 64 | 65 | # Flip the color with a probability e that depends on the domain 66 | if idx < 30000: 67 | # 10% in the training environment 68 | if np.random.uniform() < 0.1: 69 | color_red = not color_red 70 | elif idx < 40000: 71 | # 10% in the in-distribution eval environment 72 | # val set should have the same distribution as the source domain 73 | if np.random.uniform() < 0.1: 74 | color_red = not color_red 75 | else: 76 | # 90% in the ood test environment 77 | if np.random.uniform() < 0.9: 78 | color_red = not color_red 79 | 80 | colored_arr = self._color_grayscale_arr(im_array, red=color_red) 81 | 82 | if idx < 30000: 83 | train_data.append(colored_arr) 84 | train_targets.append(binary_label) 85 | elif idx < 40000: 86 | val_data.append(colored_arr) 87 | val_targets.append(binary_label) 88 | else: 89 | test_data.append(colored_arr) 90 | test_targets.append(binary_label) 91 | 92 | classes = ["0-4", "5-9"] 93 | class_to_index = {"0-4": 0, "5-9": 1} 94 | train_dataset = ImageArrayDataset( 95 | data=train_data, 96 | targets=train_targets, 97 | classes=classes, 98 | class_to_index=class_to_index, 99 | transform=transform, 100 
| target_transform=target_transform, 101 | ) 102 | val_dataset = ImageArrayDataset( 103 | data=val_data, 104 | targets=val_targets, 105 | classes=classes, 106 | class_to_index=class_to_index, 107 | transform=transform, 108 | target_transform=target_transform, 109 | ) 110 | test_dataset = ImageArrayDataset( 111 | data=test_data, 112 | targets=test_targets, 113 | classes=classes, 114 | class_to_index=class_to_index, 115 | transform=transform, 116 | target_transform=target_transform, 117 | ) 118 | 119 | return { 120 | "train": train_dataset, 121 | "val": val_dataset, 122 | "test": test_dataset, 123 | } 124 | -------------------------------------------------------------------------------- /ttab/loads/datasets/mnist/data_aug_mnist.py: -------------------------------------------------------------------------------- 1 | from typing import Callable, Tuple 2 | 3 | import numpy as np 4 | import torch 5 | from PIL import Image, ImageOps 6 | from torchvision import transforms 7 | from ttab.configs.datasets import dataset_defaults 8 | 9 | ## https://github.com/google-research/augmix 10 | 11 | 12 | def _augmix_aug(x_orig: torch.Tensor, data_name: str) -> torch.Tensor: 13 | input_size = x_orig.shape[-1] 14 | scale_size = 32 if input_size == 28 else 256 # input size is either 28 or 224 15 | padding = int((scale_size - input_size) / 2) 16 | tensor_to_image, preprocess = get_ops(data_name) 17 | 18 | x_orig = tensor_to_image(x_orig.squeeze(0)) 19 | preaugment = transforms.Compose( 20 | [ 21 | transforms.RandomCrop(input_size, padding=padding), 22 | transforms.RandomHorizontalFlip(), 23 | ] 24 | ) 25 | x_orig = preaugment(x_orig) 26 | x_processed = preprocess(x_orig) 27 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0])) 28 | m = np.float32(np.random.beta(1.0, 1.0)) 29 | 30 | mix = torch.zeros_like(x_processed) 31 | for i in range(3): 32 | x_aug = x_orig.copy() 33 | for _ in range(np.random.randint(1, 4)): 34 | x_aug = np.random.choice(augmentations)(x_aug, input_size) 35 | mix += w[i] * preprocess(x_aug) 36 | mix = m * x_processed + (1 - m) * mix 37 | return mix 38 | 39 | 40 | aug_mnist = _augmix_aug 41 | 42 | 43 | def autocontrast(pil_img, input_size, level=None): 44 | return ImageOps.autocontrast(pil_img) 45 | 46 | 47 | def equalize(pil_img, input_size, level=None): 48 | return ImageOps.equalize(pil_img) 49 | 50 | 51 | def rotate(pil_img, input_size, level): 52 | degrees = int_parameter(rand_lvl(level), 30) 53 | if np.random.uniform() > 0.5: 54 | degrees = -degrees 55 | return pil_img.rotate(degrees, resample=Image.BILINEAR, fillcolor=128) 56 | 57 | 58 | def solarize(pil_img, input_size, level): 59 | level = int_parameter(rand_lvl(level), 256) 60 | return ImageOps.solarize(pil_img, 256 - level) 61 | 62 | 63 | def shear_x(pil_img, input_size, level): 64 | level = float_parameter(rand_lvl(level), 0.3) 65 | if np.random.uniform() > 0.5: 66 | level = -level 67 | return pil_img.transform( 68 | (input_size, input_size), 69 | Image.AFFINE, 70 | (1, level, 0, 0, 1, 0), 71 | resample=Image.BILINEAR, 72 | fillcolor=128, 73 | ) 74 | 75 | 76 | def shear_y(pil_img, input_size, level): 77 | level = float_parameter(rand_lvl(level), 0.3) 78 | if np.random.uniform() > 0.5: 79 | level = -level 80 | return pil_img.transform( 81 | (input_size, input_size), 82 | Image.AFFINE, 83 | (1, 0, 0, level, 1, 0), 84 | resample=Image.BILINEAR, 85 | fillcolor=128, 86 | ) 87 | 88 | 89 | def translate_x(pil_img, input_size, level): 90 | level = int_parameter(rand_lvl(level), input_size / 3) 91 | if np.random.random() > 0.5: 92 | level = 
-level 93 | return pil_img.transform( 94 | (input_size, input_size), 95 | Image.AFFINE, 96 | (1, 0, level, 0, 1, 0), 97 | resample=Image.BILINEAR, 98 | fillcolor=128, 99 | ) 100 | 101 | 102 | def translate_y(pil_img, input_size, level): 103 | level = int_parameter(rand_lvl(level), input_size / 3) 104 | if np.random.random() > 0.5: 105 | level = -level 106 | return pil_img.transform( 107 | (input_size, input_size), 108 | Image.AFFINE, 109 | (1, 0, 0, 0, 1, level), 110 | resample=Image.BILINEAR, 111 | fillcolor=128, 112 | ) 113 | 114 | 115 | def posterize(pil_img, input_size, level): 116 | level = int_parameter(rand_lvl(level), 4) 117 | return ImageOps.posterize(pil_img, 4 - level) 118 | 119 | 120 | def int_parameter(level, maxval): 121 | """Helper function to scale a value between 0 and `maxval`. 122 | Args: 123 | level: Level of the operation that will be between [0, `PARAMETER_MAX`] (here, `PARAMETER_MAX` = 10). 124 | maxval: Maximum value that the operation can have. This will be scaled 125 | to level / PARAMETER_MAX. 126 | Returns: 127 | An int that results from scaling `maxval` according to `level`. 128 | """ 129 | return int(level * maxval / 10) 130 | 131 | 132 | def float_parameter(level, maxval): 133 | """Helper function to scale a value between 0 and `maxval`. 134 | Args: 135 | level: Level of the operation that will be between [0, `PARAMETER_MAX`] (here, `PARAMETER_MAX` = 10). 136 | maxval: Maximum value that the operation can have. This will be scaled 137 | to level / PARAMETER_MAX. 138 | Returns: 139 | A float that results from scaling `maxval` according to `level`. 140 | """ 141 | return float(level) * maxval / 10.0 142 | 143 | 144 | def rand_lvl(n): 145 | return np.random.uniform(low=0.1, high=n) 146 | 147 | 148 | augmentations = [ 149 | autocontrast, 150 | equalize, 151 | lambda x, y: rotate(x, y, 1), 152 | lambda x, y: solarize(x, y, 1), 153 | lambda x, y: shear_x(x, y, 1), 154 | lambda x, y: shear_y(x, y, 1), 155 | lambda x, y: translate_x(x, y, 1), 156 | lambda x, y: translate_y(x, y, 1), 157 | lambda x, y: posterize(x, y, 1), 158 | ] 159 | 160 | def get_ops(data_name: str) -> Tuple[Callable, Callable]: 161 | """Get the operations to be applied when defining transforms.""" 162 | unnormalize = transforms.Compose( 163 | [ 164 | transforms.Normalize( 165 | mean=[0.0, 0.0, 0.0], 166 | std=[1.0 / v for v in dataset_defaults[data_name]["statistics"]["std"]], 167 | ), 168 | transforms.Normalize( 169 | mean=[-v for v in dataset_defaults[data_name]["statistics"]["mean"]], 170 | std=[1.0, 1.0, 1.0], 171 | ), 172 | ] 173 | ) 174 | 175 | tensor_to_image = transforms.Compose([unnormalize, transforms.ToPILImage()]) 176 | preprocess = transforms.Compose( 177 | [ 178 | transforms.ToTensor(), 179 | transforms.Normalize( 180 | dataset_defaults[data_name]["statistics"]["mean"], dataset_defaults[data_name]["statistics"]["std"] 181 | ), 182 | ] 183 | ) 184 | return tensor_to_image, preprocess 185 | 186 | 187 | def tr_transforms_mnist(image: torch.Tensor, data_name: str) -> torch.Tensor: 188 | """ 189 | Data augmentation for input images. 
190 | Args: 191 | image: input tensor of shape [n_channel, H, W]. 192 | Returns: 193 | augment_image: augmented tensor of shape [n_channel, H, W] 194 | (random crop + random horizontal flip, then re-normalization). 195 | """ 196 | input_size = image.shape[-1] 197 | scale_size = 32 if input_size == 28 else 256 # input size is either 28 or 224 198 | padding = int((scale_size - input_size) / 2) 199 | tensor_to_image, preprocess = get_ops(data_name) 200 | 201 | image = tensor_to_image(image) 202 | preaugment = transforms.Compose( 203 | [ 204 | transforms.RandomCrop(input_size, padding=padding), 205 | transforms.RandomHorizontalFlip(), 206 | ] 207 | ) 208 | augment_image = preaugment(image) 209 | augment_image = preprocess(augment_image) 210 | 211 | return augment_image 212 | -------------------------------------------------------------------------------- /ttab/loads/datasets/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/__init__.py -------------------------------------------------------------------------------- /ttab/loads/datasets/utils/c_resource/frost1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/c_resource/frost1.png -------------------------------------------------------------------------------- /ttab/loads/datasets/utils/c_resource/frost2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/c_resource/frost2.png -------------------------------------------------------------------------------- /ttab/loads/datasets/utils/c_resource/frost3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/c_resource/frost3.png -------------------------------------------------------------------------------- /ttab/loads/datasets/utils/c_resource/frost4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/c_resource/frost4.jpg -------------------------------------------------------------------------------- /ttab/loads/datasets/utils/c_resource/frost5.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/c_resource/frost5.jpg -------------------------------------------------------------------------------- /ttab/loads/datasets/utils/c_resource/frost6.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/datasets/utils/c_resource/frost6.jpg -------------------------------------------------------------------------------- /ttab/loads/datasets/utils/lmdb.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import sys 4 | 5 | import cv2 6 | import lmdb 7 | import numpy as np 8 | import torch.utils.data as data 9 | from PIL import Image 10 | 11 | from .serialize import loads 12 | 13 | if 
sys.version_info[0] == 2: 14 | import cPickle as pickle 15 | else: 16 | import pickle 17 | 18 | 19 | def be_ncwh_pt(x): 20 | return x.permute(0, 3, 1, 2) # NHWC -> NCHW (pytorch layout) 21 | 22 | 23 | def uint8_to_float(x): 24 | x = x.permute(0, 3, 1, 2) # NHWC -> NCHW (pytorch layout) 25 | return x.float() / 128.0 - 1.0 26 | 27 | 28 | class LMDBPT(data.Dataset): 29 | """A class to load the LMDB file for extremely large datasets. 30 | Args: 31 | root (string): Either root directory for the database files, 32 | or an absolute path pointing to the file. 33 | classes (string or list): One of {'train', 'val', 'test'} or a list of 34 | categories to load. e.g., ['bedroom_train', 'church_train']. 35 | transform (callable, optional): A function/transform that 36 | takes in a PIL image and returns a transformed version. 37 | E.g., ``transforms.RandomCrop`` 38 | target_transform (callable, optional): 39 | A function/transform that takes in the target and transforms it. 40 | """ 41 | 42 | def __init__(self, root, transform=None, target_transform=None, is_image=True): 43 | self.root = os.path.expanduser(root) 44 | self.transform = transform 45 | self.target_transform = target_transform 46 | self.lmdb_files = self._get_valid_lmdb_files() 47 | 48 | # for each class, create an LSUNClassDataset 49 | self.dbs = [] 50 | for lmdb_file in self.lmdb_files: 51 | self.dbs.append( 52 | LMDBPTClass( 53 | root=lmdb_file, 54 | transform=transform, 55 | target_transform=target_transform, 56 | is_image=is_image, 57 | ) 58 | ) 59 | 60 | # build up indices. 61 | self.indices = np.cumsum([len(db) for db in self.dbs]) 62 | self.length = self.indices[-1] 63 | self._build_indices() 64 | self._prepare_target() 65 | 66 | def _get_valid_lmdb_files(self): 67 | """Get the valid lmdb file(s) based on the given root.""" 68 | if not self.root.endswith(".lmdb"): 69 | files = [] 70 | for l in os.listdir(self.root): 71 | if "_" in l and "-lock" not in l: 72 | files.append(os.path.join(self.root, l)) 73 | return files 74 | else: 75 | return [self.root] 76 | 77 | def _build_indices(self): 78 | self.from_to_indices = list(enumerate(zip(np.append(0, self.indices[:-1]), self.indices))) # materialize the list: enumerate() is a one-shot iterator; ranges start at 0. 79 | 80 | def _get_matched_index(self, index): 81 | if len(self.from_to_indices) == 0: 82 | return 0, index 83 | 84 | for ind, (from_index, to_index) in self.from_to_indices: 85 | if from_index <= index < to_index: 86 | return ind, index - from_index 87 | 88 | def __getitem__(self, index, apply_transform=True): 89 | block_index, item_index = self._get_matched_index(index) 90 | image, target = self.dbs[block_index].__getitem__(item_index, apply_transform) 91 | return image, target 92 | 93 | def __len__(self): 94 | return self.length 95 | 96 | def __repr__(self): 97 | fmt_str = "Dataset " + self.__class__.__name__ + "\n" 98 | fmt_str += " Number of datapoints: {}\n".format(self.__len__()) 99 | fmt_str += " Root Location: {}\n".format(self.root) 100 | tmp = " Transforms (if any): " 101 | fmt_str += "{0}{1}\n".format( 102 | tmp, self.transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 103 | ) 104 | tmp = " Target Transforms (if any): " 105 | fmt_str += "{0}{1}".format( 106 | tmp, self.target_transform.__repr__().replace("\n", "\n" + " " * len(tmp)) 107 | ) 108 | return fmt_str 109 | 110 | def _prepare_target(self): 111 | cache_file = self.root + "_targets_cache_" 112 | if os.path.isfile(cache_file): 113 | self.targets = pickle.load(open(cache_file, "rb")) 114 | else: 115 | self.targets = [ 116 | self.__getitem__(idx, apply_transform=False)[1] 117 | for idx in range(self.length) 118 | ] 119 | 
pickle.dump(self.targets, open(cache_file, "wb")) 120 | 121 | 122 | class LMDBPTClass(data.Dataset): 123 | def __init__(self, root, transform=None, target_transform=None, is_image=True): 124 | self.root = os.path.expanduser(root) 125 | self.transform = transform 126 | self.target_transform = target_transform 127 | self.is_image = is_image 128 | 129 | # init the placeholder for env and length. 130 | self.env = None 131 | self.length = self._get_tmp_length() 132 | 133 | def _open_lmdb(self): 134 | return lmdb.open( 135 | self.root, 136 | subdir=os.path.isdir(self.root), 137 | readonly=True, 138 | lock=False, 139 | readahead=False, 140 | # map_size=1099511627776 * 2, 141 | max_readers=1, 142 | meminit=False, 143 | ) 144 | 145 | def _get_tmp_length(self): 146 | env = lmdb.open( 147 | self.root, 148 | subdir=os.path.isdir(self.root), 149 | readonly=True, 150 | lock=False, 151 | readahead=False, 152 | # map_size=1099511627776 * 2, 153 | max_readers=1, 154 | meminit=False, 155 | ) 156 | with env.begin(write=False) as txn: 157 | length = txn.stat()["entries"] 158 | 159 | if txn.get(b"__keys__") is not None: 160 | length -= 1 161 | # clean everything. 162 | del env 163 | return length 164 | 165 | def _get_length(self): 166 | with self.env.begin(write=False) as txn: 167 | self.length = txn.stat()["entries"] 168 | 169 | if txn.get(b"__keys__") is not None: 170 | self.length -= 1 171 | 172 | def _prepare_cache(self): 173 | cache_file = self.root + "_cache_" 174 | if os.path.isfile(cache_file): 175 | self.keys = pickle.load(open(cache_file, "rb")) 176 | else: 177 | with self.env.begin(write=False) as txn: 178 | self.keys = [key for key, _ in txn.cursor() if key != b"__keys__"] 179 | pickle.dump(self.keys, open(cache_file, "wb")) 180 | 181 | def _decode_from_image(self, x): 182 | image = cv2.imdecode(x, cv2.IMREAD_COLOR).astype("uint8") 183 | return Image.fromarray(image, "RGB") 184 | 185 | def _decode_from_array(self, x): 186 | return Image.fromarray(x.reshape(3, 32, 32).transpose((1, 2, 0)), "RGB") 187 | 188 | def __getitem__(self, index, apply_transform=True): 189 | if self.env is None: 190 | # # open lmdb env. 191 | self.env = self._open_lmdb() 192 | # # get file stats. 193 | # self._get_length() 194 | # # prepare cache_file 195 | self._prepare_cache() 196 | 197 | # setup. 198 | env = self.env 199 | with env.begin(write=False) as txn: 200 | bin_file = txn.get(self.keys[index]) 201 | 202 | image, target = loads(bin_file) 203 | 204 | if apply_transform: 205 | if self.is_image: 206 | image = self._decode_from_image(image) 207 | else: 208 | image = self._decode_from_array(image) 209 | 210 | if self.transform is not None: 211 | image = self.transform(image) 212 | if self.target_transform is not None: 213 | target = self.target_transform(target) 214 | return image, target 215 | 216 | def __len__(self): 217 | return self.length 218 | 219 | def __repr__(self): 220 | return self.__class__.__name__ + " (" + self.root + ")" 221 | -------------------------------------------------------------------------------- /ttab/loads/datasets/utils/serialize.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import os 4 | 5 | __all__ = ["loads", "dumps"] 6 | 7 | 8 | def create_dummy_func(func, dependency): 9 | """ 10 | When a dependency of a function is not available, 11 | create a dummy function which throws ImportError when used. 12 | Args: 13 | func (str): name of the function. 14 | dependency (str or list[str]): name(s) of the dependency. 
15 | Returns: 16 | function: a function object 17 | """ 18 | if isinstance(dependency, (list, tuple)): 19 | dependency = ",".join(dependency) 20 | 21 | def _dummy(*args, **kwargs): 22 | raise ImportError( 23 | "Cannot import '{}', therefore '{}' is not available".format( 24 | dependency, func 25 | ) 26 | ) 27 | 28 | return _dummy 29 | 30 | 31 | def dumps_msgpack(obj): 32 | """ 33 | Serialize an object. 34 | Returns: 35 | Implementation-dependent bytes-like object 36 | """ 37 | return msgpack.dumps(obj, use_bin_type=True) 38 | 39 | 40 | def loads_msgpack(buf): 41 | """ 42 | Args: 43 | buf: the output of `dumps`. 44 | """ 45 | return msgpack.loads(buf, raw=False) 46 | 47 | 48 | def dumps_pyarrow(obj): 49 | """ 50 | Serialize an object. 51 | 52 | Returns: 53 | Implementation-dependent bytes-like object 54 | """ 55 | return pa.serialize(obj).to_buffer() 56 | 57 | 58 | def loads_pyarrow(buf): 59 | """ 60 | Args: 61 | buf: the output of `dumps`. 62 | """ 63 | return pa.deserialize(buf) 64 | 65 | 66 | try: 67 | # fixed in pyarrow 0.9: https://github.com/apache/arrow/pull/1223#issuecomment-359895666 68 | import pyarrow as pa 69 | except ImportError: 70 | pa = None 71 | dumps_pyarrow = create_dummy_func("dumps_pyarrow", ["pyarrow"]) # noqa 72 | loads_pyarrow = create_dummy_func("loads_pyarrow", ["pyarrow"]) # noqa 73 | 74 | try: 75 | import msgpack 76 | import msgpack_numpy 77 | 78 | msgpack_numpy.patch() 79 | except ImportError: 80 | assert pa is not None, "pyarrow is a dependency of tensorpack!" 81 | loads_msgpack = create_dummy_func( # noqa 82 | "loads_msgpack", ["msgpack", "msgpack_numpy"] 83 | ) 84 | dumps_msgpack = create_dummy_func( # noqa 85 | "dumps_msgpack", ["msgpack", "msgpack_numpy"] 86 | ) 87 | 88 | if os.environ.get("TENSORPACK_SERIALIZE", "msgpack") == "msgpack": 89 | loads = loads_msgpack 90 | dumps = dumps_msgpack 91 | else: 92 | loads = loads_pyarrow 93 | dumps = dumps_pyarrow 94 | -------------------------------------------------------------------------------- /ttab/loads/datasets/yearbook/data_aug_yearbook.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from PIL import Image, ImageOps 4 | from torchvision import transforms 5 | 6 | from ttab.configs.datasets import dataset_defaults 7 | 8 | ## https://github.com/google-research/augmix 9 | 10 | 11 | def _augmix_aug(x_orig: torch.Tensor, data_name: str) -> torch.Tensor: 12 | input_size = x_orig.shape[-1] 13 | scale_size = 36 if input_size == 32 else 256 # input size is either 32 or 224 14 | padding = int((scale_size - input_size) / 2) 15 | tensor_to_image = transforms.Compose([transforms.ToPILImage()]) 16 | preprocess = transforms.Compose([transforms.ToTensor()]) 17 | 18 | x_orig = tensor_to_image(x_orig.squeeze(0)) 19 | preaugment = transforms.Compose( 20 | [ 21 | transforms.RandomCrop(input_size, padding=padding), 22 | transforms.RandomHorizontalFlip(), 23 | ] 24 | ) 25 | x_orig = preaugment(x_orig) 26 | x_processed = preprocess(x_orig) 27 | w = np.float32(np.random.dirichlet([1.0, 1.0, 1.0])) 28 | m = np.float32(np.random.beta(1.0, 1.0)) 29 | 30 | mix = torch.zeros_like(x_processed) 31 | for i in range(3): 32 | x_aug = x_orig.copy() 33 | for _ in range(np.random.randint(1, 4)): 34 | x_aug = np.random.choice(augmentations)(x_aug, input_size) 35 | mix += w[i] * preprocess(x_aug) 36 | mix = m * x_processed + (1 - m) * mix 37 | return mix 38 | 39 | 40 | aug_yearbook = _augmix_aug 41 | 42 | 43 | def autocontrast(pil_img, input_size, level=None): 44 | 
return ImageOps.autocontrast(pil_img) 45 | 46 | 47 | def equalize(pil_img, input_size, level=None): 48 | return ImageOps.equalize(pil_img) 49 | 50 | 51 | def rotate(pil_img, input_size, level): 52 | degrees = int_parameter(rand_lvl(level), 30) 53 | if np.random.uniform() > 0.5: 54 | degrees = -degrees 55 | return pil_img.rotate(degrees, resample=Image.BILINEAR, fillcolor=128) 56 | 57 | 58 | def solarize(pil_img, input_size, level): 59 | level = int_parameter(rand_lvl(level), 256) 60 | return ImageOps.solarize(pil_img, 256 - level) 61 | 62 | 63 | def shear_x(pil_img, input_size, level): 64 | level = float_parameter(rand_lvl(level), 0.3) 65 | if np.random.uniform() > 0.5: 66 | level = -level 67 | return pil_img.transform( 68 | (input_size, input_size), 69 | Image.AFFINE, 70 | (1, level, 0, 0, 1, 0), 71 | resample=Image.BILINEAR, 72 | fillcolor=128, 73 | ) 74 | 75 | 76 | def shear_y(pil_img, input_size, level): 77 | level = float_parameter(rand_lvl(level), 0.3) 78 | if np.random.uniform() > 0.5: 79 | level = -level 80 | return pil_img.transform( 81 | (input_size, input_size), 82 | Image.AFFINE, 83 | (1, 0, 0, level, 1, 0), 84 | resample=Image.BILINEAR, 85 | fillcolor=128, 86 | ) 87 | 88 | 89 | def translate_x(pil_img, input_size, level): 90 | level = int_parameter(rand_lvl(level), input_size / 3) 91 | if np.random.random() > 0.5: 92 | level = -level 93 | return pil_img.transform( 94 | (input_size, input_size), 95 | Image.AFFINE, 96 | (1, 0, level, 0, 1, 0), 97 | resample=Image.BILINEAR, 98 | fillcolor=128, 99 | ) 100 | 101 | 102 | def translate_y(pil_img, input_size, level): 103 | level = int_parameter(rand_lvl(level), input_size / 3) 104 | if np.random.random() > 0.5: 105 | level = -level 106 | return pil_img.transform( 107 | (input_size, input_size), 108 | Image.AFFINE, 109 | (1, 0, 0, 0, 1, level), 110 | resample=Image.BILINEAR, 111 | fillcolor=128, 112 | ) 113 | 114 | 115 | def posterize(pil_img, input_size, level): 116 | level = int_parameter(rand_lvl(level), 4) 117 | return ImageOps.posterize(pil_img, 4 - level) 118 | 119 | 120 | def int_parameter(level, maxval): 121 | """Helper function to scale a value between 0 and `maxval`. 122 | Args: 123 | level: Level of the operation that will be between [0, `PARAMETER_MAX`] (here, `PARAMETER_MAX` = 10). 124 | maxval: Maximum value that the operation can have. This will be scaled 125 | to level / PARAMETER_MAX. 126 | Returns: 127 | An int that results from scaling `maxval` according to `level`. 128 | """ 129 | return int(level * maxval / 10) 130 | 131 | 132 | def float_parameter(level, maxval): 133 | """Helper function to scale a value between 0 and `maxval`. 134 | Args: 135 | level: Level of the operation that will be between [0, `PARAMETER_MAX`] (here, `PARAMETER_MAX` = 10). 136 | maxval: Maximum value that the operation can have. This will be scaled 137 | to level / PARAMETER_MAX. 138 | Returns: 139 | A float that results from scaling `maxval` according to `level`. 140 | """ 141 | return float(level) * maxval / 10.0 142 | 143 | 144 | def rand_lvl(n): 145 | return np.random.uniform(low=0.1, high=n) 146 | 147 | 148 | augmentations = [ 149 | autocontrast, 150 | equalize, 151 | lambda x, y: rotate(x, y, 1), 152 | lambda x, y: solarize(x, y, 1), 153 | lambda x, y: shear_x(x, y, 1), 154 | lambda x, y: shear_y(x, y, 1), 155 | lambda x, y: translate_x(x, y, 1), 156 | lambda x, y: translate_y(x, y, 1), 157 | lambda x, y: posterize(x, y, 1), 158 | ] 159 | 160 | 161 | def tr_transforms_yearbook(image: torch.Tensor, data_name: str) -> torch.Tensor: 162 | """ 163 | Data augmentation for input images. 
164 | Args: 165 | image: input tensor of shape [n_channel, H, W]. 166 | Returns: 167 | augment_image: augmented tensor of shape [n_channel, H, W] 168 | (random crop + random horizontal flip). 169 | """ 170 | input_size = image.shape[-1] 171 | scale_size = 36 if input_size == 32 else 256 # input size is either 32 or 224 172 | padding = int((scale_size - input_size) / 2) 173 | tensor_to_image = transforms.Compose([transforms.ToPILImage()]) 174 | preprocess = transforms.Compose([transforms.ToTensor()]) 175 | 176 | image = tensor_to_image(image) 177 | preaugment = transforms.Compose( 178 | [ 179 | transforms.RandomCrop(input_size, padding=padding), 180 | transforms.RandomHorizontalFlip(), 181 | ] 182 | ) 183 | augment_image = preaugment(image) 184 | augment_image = preprocess(augment_image) 185 | 186 | return augment_image 187 | -------------------------------------------------------------------------------- /ttab/loads/models/__init__.py: -------------------------------------------------------------------------------- 1 | from .cct import * 2 | from .resnet import * 3 | from .wideresnet import * 4 | -------------------------------------------------------------------------------- /ttab/loads/models/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/loads/models/utils/__init__.py -------------------------------------------------------------------------------- /ttab/loads/models/utils/embedder.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class Embedder(nn.Module): 5 | def __init__( 6 | self, 7 | word_embedding_dim=300, 8 | vocab_size=100000, 9 | padding_idx=1, 10 | pretrained_weight=None, 11 | embed_freeze=False, 12 | *args, 13 | **kwargs 14 | ): 15 | super(Embedder, self).__init__() 16 | self.embeddings = ( 17 | nn.Embedding.from_pretrained(pretrained_weight, freeze=embed_freeze) 18 | if pretrained_weight is not None 19 | else nn.Embedding(vocab_size, word_embedding_dim, padding_idx=padding_idx) 20 | ) 21 | self.embeddings.weight.requires_grad = not embed_freeze 22 | 23 | def forward_mask(self, mask): 24 | bsz, seq_len = mask.shape 25 | new_mask = mask.view(bsz, seq_len, 1) 26 | new_mask = new_mask.sum(-1) 27 | new_mask = new_mask > 0 28 | return new_mask 29 | 30 | def forward(self, x, mask=None): 31 | embed = self.embeddings(x) 32 | embed = ( 33 | embed 34 | if mask is None 35 | else embed * self.forward_mask(mask).unsqueeze(-1).float() 36 | ) 37 | return embed, mask 38 | 39 | @staticmethod 40 | def init_weight(m): 41 | if isinstance(m, nn.Linear): 42 | nn.init.trunc_normal_(m.weight, std=0.02) 43 | if isinstance(m, nn.Linear) and m.bias is not None: 44 | nn.init.constant_(m.bias, 0) 45 | else: 46 | nn.init.normal_(m.weight) 47 | -------------------------------------------------------------------------------- /ttab/loads/models/utils/helpers.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.nn.functional as F 4 | import logging 5 | 6 | _logger = logging.getLogger("train") 7 | 8 | 9 | def resize_pos_embed(posemb, posemb_new, num_tokens=1): 10 | # Copied from `timm` by Ross Wightman: 11 | # github.com/rwightman/pytorch-image-models 12 | # Rescale the grid of position embeddings when loading from state_dict. 
Adapted from 13 | # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224 14 | ntok_new = posemb_new.shape[1] 15 | if num_tokens: 16 | posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:] 17 | ntok_new -= num_tokens 18 | else: 19 | posemb_tok, posemb_grid = posemb[:, :0], posemb[0] 20 | gs_old = int(math.sqrt(len(posemb_grid))) 21 | gs_new = int(math.sqrt(ntok_new)) 22 | posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2) 23 | posemb_grid = F.interpolate(posemb_grid, size=(gs_new, gs_new), mode="bilinear") 24 | posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new * gs_new, -1) 25 | posemb = torch.cat([posemb_tok, posemb_grid], dim=1) 26 | return posemb 27 | 28 | 29 | def pe_check(model, state_dict, pe_key="classifier.positional_emb"): 30 | if ( 31 | pe_key is not None 32 | and pe_key in state_dict.keys() 33 | and pe_key in model.state_dict().keys() 34 | ): 35 | if model.state_dict()[pe_key].shape != state_dict[pe_key].shape: 36 | state_dict[pe_key] = resize_pos_embed( 37 | state_dict[pe_key], 38 | model.state_dict()[pe_key], 39 | num_tokens=model.classifier.num_tokens, 40 | ) 41 | return state_dict 42 | 43 | 44 | def fc_check(model, state_dict, fc_key="classifier.fc"): 45 | for key in [f"{fc_key}.weight", f"{fc_key}.bias"]: 46 | if ( 47 | key is not None 48 | and key in state_dict.keys() 49 | and key in model.state_dict().keys() 50 | ): 51 | if model.state_dict()[key].shape != state_dict[key].shape: 52 | _logger.warning(f"Removing {key}, number of classes has changed.") 53 | state_dict[key] = model.state_dict()[key] 54 | return state_dict 55 | -------------------------------------------------------------------------------- /ttab/loads/models/utils/stochastic_depth.py: -------------------------------------------------------------------------------- 1 | # Thanks to rwightman's timm package 2 | # github.com:rwightman/pytorch-image-models 3 | 4 | import torch 5 | import torch.nn as nn 6 | 7 | 8 | def drop_path(x, drop_prob: float = 0.0, training: bool = False): 9 | """ 10 | Obtained from: github.com:rwightman/pytorch-image-models 11 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 12 | This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, 13 | the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... 14 | See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for 15 | changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 16 | 'survival rate' as the argument. 17 | """ 18 | if drop_prob == 0.0 or not training: 19 | return x 20 | keep_prob = 1 - drop_prob 21 | shape = (x.shape[0],) + (1,) * ( 22 | x.ndim - 1 23 | ) # work with diff dim tensors, not just 2D ConvNets 24 | random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) 25 | random_tensor.floor_() # binarize 26 | output = x.div(keep_prob) * random_tensor 27 | return output 28 | 29 | 30 | class DropPath(nn.Module): 31 | """ 32 | Obtained from: github.com:rwightman/pytorch-image-models 33 | Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
34 | """ 35 | 36 | def __init__(self, drop_prob=None): 37 | super(DropPath, self).__init__() 38 | self.drop_prob = drop_prob 39 | 40 | def forward(self, x): 41 | return drop_path(x, self.drop_prob, self.training) 42 | -------------------------------------------------------------------------------- /ttab/loads/models/utils/tokenizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class Tokenizer(nn.Module): 7 | def __init__( 8 | self, 9 | kernel_size, 10 | stride, 11 | padding, 12 | pooling_kernel_size=3, 13 | pooling_stride=2, 14 | pooling_padding=1, 15 | n_conv_layers=1, 16 | n_input_channels=3, 17 | n_output_channels=64, 18 | in_planes=64, 19 | activation=None, 20 | max_pool=True, 21 | conv_bias=False, 22 | ): 23 | super(Tokenizer, self).__init__() 24 | 25 | n_filter_list = ( 26 | [n_input_channels] 27 | + [in_planes for _ in range(n_conv_layers - 1)] 28 | + [n_output_channels] 29 | ) 30 | 31 | self.conv_layers = nn.Sequential( 32 | *[ 33 | nn.Sequential( 34 | nn.Conv2d( 35 | n_filter_list[i], 36 | n_filter_list[i + 1], 37 | kernel_size=(kernel_size, kernel_size), 38 | stride=(stride, stride), 39 | padding=(padding, padding), 40 | bias=conv_bias, 41 | ), 42 | nn.Identity() if activation is None else activation(), 43 | nn.MaxPool2d( 44 | kernel_size=pooling_kernel_size, 45 | stride=pooling_stride, 46 | padding=pooling_padding, 47 | ) 48 | if max_pool 49 | else nn.Identity(), 50 | ) 51 | for i in range(n_conv_layers) 52 | ] 53 | ) 54 | 55 | self.flattener = nn.Flatten(2, 3) 56 | self.apply(self.init_weight) 57 | 58 | def sequence_length(self, n_channels=3, height=224, width=224): 59 | return self.forward(torch.zeros((1, n_channels, height, width))).shape[1] 60 | 61 | def forward(self, x): 62 | return self.flattener(self.conv_layers(x)).transpose(-2, -1) 63 | 64 | @staticmethod 65 | def init_weight(m): 66 | if isinstance(m, nn.Conv2d): 67 | nn.init.kaiming_normal_(m.weight) 68 | 69 | 70 | class TextTokenizer(nn.Module): 71 | def __init__( 72 | self, 73 | kernel_size, 74 | stride, 75 | padding, 76 | pooling_kernel_size=3, 77 | pooling_stride=2, 78 | pooling_padding=1, 79 | embedding_dim=300, 80 | n_output_channels=128, 81 | activation=None, 82 | max_pool=True, 83 | *args, 84 | **kwargs 85 | ): 86 | super(TextTokenizer, self).__init__() 87 | 88 | self.max_pool = max_pool 89 | self.conv_layers = nn.Sequential( 90 | nn.Conv2d( 91 | 1, 92 | n_output_channels, 93 | kernel_size=(kernel_size, embedding_dim), 94 | stride=(stride, 1), 95 | padding=(padding, 0), 96 | bias=False, 97 | ), 98 | nn.Identity() if activation is None else activation(), 99 | nn.MaxPool2d( 100 | kernel_size=(pooling_kernel_size, 1), 101 | stride=(pooling_stride, 1), 102 | padding=(pooling_padding, 0), 103 | ) 104 | if max_pool 105 | else nn.Identity(), 106 | ) 107 | 108 | self.apply(self.init_weight) 109 | 110 | def seq_len(self, seq_len=32, embed_dim=300): 111 | return self.forward(torch.zeros((1, seq_len, embed_dim)))[0].shape[1] 112 | 113 | def forward_mask(self, mask): 114 | new_mask = mask.unsqueeze(1).float() 115 | cnn_weight = torch.ones( 116 | (1, 1, self.conv_layers[0].kernel_size[0]), 117 | device=mask.device, 118 | dtype=torch.float, 119 | ) 120 | new_mask = F.conv1d( 121 | new_mask, 122 | cnn_weight, 123 | None, 124 | self.conv_layers[0].stride[0], 125 | self.conv_layers[0].padding[0], 126 | 1, 127 | 1, 128 | ) 129 | if self.max_pool: 130 | new_mask = F.max_pool1d( 131 | new_mask, 132 | 
self.conv_layers[2].kernel_size[0], 133 | self.conv_layers[2].stride[0], 134 | self.conv_layers[2].padding[0], 135 | 1, 136 | False, 137 | False, 138 | ) 139 | new_mask = new_mask.squeeze(1) 140 | new_mask = new_mask > 0 141 | return new_mask 142 | 143 | def forward(self, x, mask=None): 144 | x = x.unsqueeze(1) 145 | x = self.conv_layers(x) 146 | x = x.transpose(1, 3).squeeze(1) 147 | if mask is not None: 148 | mask = self.forward_mask(mask).unsqueeze(-1).float() 149 | x = x * mask 150 | return x, mask 151 | 152 | @staticmethod 153 | def init_weight(m): 154 | if isinstance(m, nn.Conv2d): 155 | nn.init.kaiming_normal_(m.weight) 156 | -------------------------------------------------------------------------------- /ttab/loads/models/wideresnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import torch.nn.init as init 5 | 6 | 7 | def conv3x3(in_planes, out_planes, stride=1): 8 | return nn.Conv2d( 9 | in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=True 10 | ) 11 | 12 | 13 | def conv_init(m): 14 | classname = m.__class__.__name__ 15 | if classname.find("Conv") != -1: 16 | init.xavier_uniform_(m.weight, gain=np.sqrt(2)) 17 | init.constant_(m.bias, 0) 18 | elif classname.find("BatchNorm") != -1: 19 | init.constant_(m.weight, 1) 20 | init.constant_(m.bias, 0) 21 | 22 | 23 | class wide_basic(nn.Module): 24 | def __init__(self, in_planes, planes, dropout_rate, stride=1): 25 | super(wide_basic, self).__init__() 26 | self.bn1 = nn.BatchNorm2d(in_planes) 27 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, bias=True) 28 | self.dropout = nn.Dropout(p=dropout_rate) 29 | self.bn2 = nn.BatchNorm2d(planes) 30 | self.conv2 = nn.Conv2d( 31 | planes, planes, kernel_size=3, stride=stride, padding=1, bias=True 32 | ) 33 | 34 | self.shortcut = nn.Sequential() 35 | if stride != 1 or in_planes != planes: 36 | self.shortcut = nn.Sequential( 37 | nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=True), 38 | ) 39 | 40 | def forward(self, x): 41 | out = self.dropout(self.conv1(F.relu(self.bn1(x)))) 42 | out = self.conv2(F.relu(self.bn2(out))) 43 | out += self.shortcut(x) 44 | 45 | return out 46 | 47 | 48 | class WideResNet(nn.Module): 49 | def __init__( 50 | self, depth, widen_factor, num_classes, split_point="layer3", dropout_rate=0.3 51 | ): 52 | super(WideResNet, self).__init__() 53 | self.in_planes = 16 54 | assert split_point in ["layer2", "layer3", None], "invalid split position." 
55 | self.split_point = split_point 56 | 57 | assert (depth - 4) % 6 == 0, "Wide-resnet depth should be 6n+4" 58 | n = (depth - 4) / 6 59 | k = widen_factor 60 | 61 | print("| Wide-Resnet %dx%d" % (depth, k)) 62 | nStages = [16, 16 * k, 32 * k, 64 * k] 63 | 64 | self.conv1 = conv3x3(3, nStages[0]) 65 | self.relu = nn.ReLU(inplace=True) 66 | self.layer1 = self._wide_layer( 67 | wide_basic, nStages[1], n, dropout_rate, stride=1 68 | ) 69 | self.layer2 = self._wide_layer( 70 | wide_basic, nStages[2], n, dropout_rate, stride=2 71 | ) 72 | self.layer3 = self._wide_layer( 73 | wide_basic, nStages[3], n, dropout_rate, stride=2 74 | ) 75 | self.bn1 = nn.BatchNorm2d(nStages[3], momentum=0.9) 76 | self.avgpool = nn.AvgPool2d(kernel_size=8) 77 | self.classifier = nn.Linear(nStages[3], num_classes, bias=False) 78 | 79 | def _wide_layer(self, block, planes, num_blocks, dropout_rate, stride): 80 | strides = [stride] + [1] * (int(num_blocks) - 1) 81 | layers = [] 82 | 83 | for stride in strides: 84 | layers.append(block(self.in_planes, planes, dropout_rate, stride)) 85 | self.in_planes = planes 86 | 87 | return nn.Sequential(*layers) 88 | 89 | def forward_features(self, x): 90 | """Forward pass without the classifier; the split at `split_point` enables gradient checkpointing to save memory.""" 91 | x = self.conv1(x) 92 | x = self.layer1(x) 93 | x = self.layer2(x) 94 | if self.split_point in ["layer3", None]: 95 | x = self.layer3(x) 96 | x = self.bn1(x) 97 | x = self.relu(x) 98 | x = self.avgpool(x) 99 | x = x.view(x.size(0), -1) 100 | 101 | return x 102 | 103 | def forward_head(self, x, pre_logits: bool = False): 104 | """Forward pass through the classifier head; the split at `split_point` enables gradient checkpointing to save memory.""" 105 | if self.split_point == "layer2": 106 | x = self.layer3(x) 107 | x = self.bn1(x) 108 | x = self.relu(x) 109 | x = self.avgpool(x) 110 | x = x.view(x.size(0), -1) 111 | 112 | return x if pre_logits else self.classifier(x) 113 | 114 | def forward(self, x): 115 | x = self.forward_features(x) 116 | x = self.forward_head(x) 117 | return x 118 | -------------------------------------------------------------------------------- /ttab/model_adaptation/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from .bn_adapt import BNAdapt 3 | from .conjugate_pl import ConjugatePL 4 | from .cotta import CoTTA 5 | from .eata import EATA 6 | from .memo import MEMO 7 | from .no_adaptation import NoAdaptation 8 | from .note import NOTE 9 | from .sar import SAR 10 | from .shot import SHOT 11 | from .t3a import T3A 12 | from .tent import TENT 13 | from .ttt import TTT 14 | from .ttt_plus_plus import TTTPlusPlus 15 | from .rotta import Rotta 16 | 17 | 18 | def get_model_adaptation_method(adaptation_name): 19 | return { 20 | "no_adaptation": NoAdaptation, 21 | "tent": TENT, 22 | "bn_adapt": BNAdapt, 23 | "memo": MEMO, 24 | "shot": SHOT, 25 | "t3a": T3A, 26 | "ttt": TTT, 27 | "ttt_plus_plus": TTTPlusPlus, 28 | "note": NOTE, 29 | "sar": SAR, 30 | "conjugate_pl": ConjugatePL, 31 | "cotta": CoTTA, 32 | "eata": EATA, 33 | "rotta": Rotta, 34 | }[adaptation_name] 35 | -------------------------------------------------------------------------------- /ttab/model_adaptation/bn_adapt.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import copy 3 | import functools 4 | from typing import List 5 | 6 | import torch 7 | import torch.nn as nn 8 | import ttab.loads.define_dataset as define_dataset 9 | from torch.nn import functional as F 
10 | from ttab.api import Batch 11 | from ttab.model_adaptation.base_adaptation import BaseAdaptation 12 | from ttab.model_selection.base_selection import BaseSelection 13 | from ttab.model_selection.metrics import Metrics 14 | from ttab.utils.auxiliary import fork_rng_with_seed 15 | from ttab.utils.logging import Logger 16 | from ttab.utils.timer import Timer 17 | 18 | 19 | class BNAdapt(BaseAdaptation): 20 | """ 21 | Improving robustness against common corruptions by covariate shift adaptation, 22 | https://arxiv.org/abs/2006.16971, 23 | https://github.com/bethgelab/robustness 24 | """ 25 | 26 | def __init__(self, meta_conf, model: nn.Module): 27 | super().__init__(meta_conf, model) 28 | 29 | def _prior_safety_check(self): 30 | 31 | assert ( 32 | self._meta_conf.adapt_prior is not None 33 | ), "the ratio of training-set statistics is required" 34 | assert ( 35 | self._meta_conf.debug is not None 36 | ), "The state of debug should be specified" 37 | assert self._meta_conf.n_train_steps > 0, "the number of adaptation steps must be >= 1." 38 | 39 | def _initialize_model(self, model: nn.Module): 40 | """Configure model for use with adaptation method.""" 41 | # disable grad. 42 | model.eval() 43 | model.requires_grad_(False) 44 | 45 | return model.to(self._meta_conf.device) 46 | 47 | def _initialize_trainable_parameters(self): 48 | """Select target params for adaptation methods.""" 49 | self._adapt_module_names = [] 50 | 51 | for name_module, module in self._base_model.named_children(): 52 | if isinstance(module, nn.BatchNorm2d): 53 | self._adapt_module_names.append(name_module) 54 | 55 | def _initialize_optimizer(self, params) -> torch.optim.Optimizer: 56 | """No optimizer is used in BNAdapt.""" 57 | pass 58 | 59 | def _post_safety_check(self): 60 | 61 | param_grads = [p.requires_grad for p in (self._model.parameters())] 62 | has_any_params = any(param_grads) 63 | assert not has_any_params, "BNAdapt doesn't need adaptable params." 64 | 65 | has_bn = any( 66 | [ 67 | (name_module in self._adapt_module_names) 68 | for name_module, _ in (self._model.named_modules()) 69 | ] 70 | ) 71 | assert has_bn, "BNAdapt needs batch normalization layers." 72 | 73 | def initialize(self, seed: int): 74 | """Initialize the algorithm.""" 75 | self._initialize_trainable_parameters() 76 | self._model = self._initialize_model(model=copy.deepcopy(self._base_model)) 77 | self.adapt_setup() 78 | self._auxiliary_data_cls = define_dataset.ConstructAuxiliaryDataset( 79 | config=self._meta_conf 80 | ) 81 | self.criterion = nn.CrossEntropyLoss() 82 | if ( 83 | self._meta_conf.n_train_steps > 1 84 | ): # no need to do multiple-step adaptation in bn_adapt. 85 | self._meta_conf.n_train_steps = 1 86 | 87 | def _bn_swap(self, model: nn.Module, prior: float): 88 | """ 89 | Replace the original BN layers in the model with newly defined BN layers (AdaptiveBatchNorm), 90 | which modify the BN forward pass. 
91 | """ 92 | return AdaptiveBatchNorm.adapt_model( 93 | model, prior=prior, device=self._meta_conf.device 94 | ) 95 | 96 | def one_adapt_step( 97 | self, 98 | model: torch.nn.Module, 99 | timer: Timer, 100 | batch: Batch, 101 | random_seed: int = None, 102 | ): 103 | """adapt the model in one step.""" 104 | with timer("forward"): 105 | with fork_rng_with_seed(random_seed): 106 | y_hat = model(batch._x) 107 | loss = self.criterion(y_hat, batch._y) 108 | 109 | return {"loss": loss.item(), "yhat": y_hat} 110 | 111 | def adapt_setup(self): 112 | """adjust batch normalization layers.""" 113 | self._bn_swap(self._model, self._meta_conf.adapt_prior) 114 | 115 | def adapt_and_eval( 116 | self, 117 | episodic: bool, 118 | metrics: Metrics, 119 | model_selection_method: BaseSelection, 120 | current_batch, 121 | previous_batches: List[Batch], 122 | logger: Logger, 123 | timer: Timer, 124 | ): 125 | """The key entry of test-time adaptation.""" 126 | # some simple initialization. 127 | log = functools.partial(logger.log, display=self._meta_conf.debug) 128 | nbsteps = self._meta_conf.n_train_steps 129 | with timer("test_time_adaptation"): 130 | log(f"\tadapt the model for {nbsteps} steps.") 131 | for _ in range(nbsteps): 132 | adaptation_result = self.one_adapt_step( 133 | self._model, timer, current_batch, random_seed=self._meta_conf.seed 134 | ) 135 | 136 | with timer("evaluate_adaptation_result"): 137 | metrics.eval(current_batch._y, adaptation_result["yhat"]) 138 | if self._meta_conf.base_data_name in ["waterbirds"]: 139 | self.tta_loss_computer.loss( 140 | adaptation_result["yhat"], 141 | current_batch._y, 142 | current_batch._g, 143 | is_training=False, 144 | ) 145 | 146 | @property 147 | def name(self): 148 | return "bn_adapt" 149 | 150 | 151 | class AdaptiveBatchNorm(nn.Module): 152 | """Use the source statistics as a prior on the target statistics""" 153 | 154 | @staticmethod 155 | def find_bns(parent, prior, device): 156 | replace_mods = [] 157 | if parent is None: 158 | return [] 159 | for name, child in parent.named_children(): 160 | child.requires_grad_(False) 161 | if isinstance(child, nn.BatchNorm2d): 162 | module = AdaptiveBatchNorm(child, prior, device) 163 | replace_mods.append((parent, name, module)) 164 | else: 165 | replace_mods.extend(AdaptiveBatchNorm.find_bns(child, prior, device)) 166 | 167 | return replace_mods 168 | 169 | @staticmethod 170 | def adapt_model(model, prior, device): 171 | replace_mods = AdaptiveBatchNorm.find_bns(model, prior, device) 172 | print(f"| Found {len(replace_mods)} modules to be replaced.") 173 | for (parent, name, child) in replace_mods: 174 | setattr(parent, name, child) 175 | return model 176 | 177 | def __init__(self, layer, prior, device): 178 | assert prior >= 0 and prior <= 1 179 | 180 | super().__init__() 181 | self.layer = layer 182 | self.layer.eval() 183 | 184 | self.norm = nn.BatchNorm2d( 185 | self.layer.num_features, affine=False, momentum=1.0 186 | ).to(device) 187 | 188 | self.prior = prior 189 | 190 | def forward(self, input): 191 | self.norm(input) 192 | 193 | running_mean = ( 194 | self.prior * self.layer.running_mean 195 | + (1 - self.prior) * self.norm.running_mean 196 | ) 197 | running_var = ( 198 | self.prior * self.layer.running_var 199 | + (1 - self.prior) * self.norm.running_var 200 | ) 201 | 202 | return F.batch_norm( 203 | input, 204 | running_mean, 205 | running_var, 206 | self.layer.weight, 207 | self.layer.bias, 208 | False, 209 | 0, 210 | self.layer.eps, 211 | ) 212 | 
-------------------------------------------------------------------------------- /ttab/model_adaptation/no_adaptation.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import copy 3 | import warnings 4 | from typing import List 5 | 6 | import torch 7 | import torch.nn as nn 8 | import ttab.model_adaptation.utils as adaptation_utils 9 | from ttab.api import Batch 10 | from ttab.loads.define_model import load_pretrained_model 11 | from ttab.model_adaptation.base_adaptation import BaseAdaptation 12 | from ttab.model_selection.base_selection import BaseSelection 13 | from ttab.model_selection.metrics import Metrics 14 | from ttab.utils.logging import Logger 15 | from ttab.utils.timer import Timer 16 | 17 | 18 | class NoAdaptation(BaseAdaptation): 19 | """Standard test-time evaluation (no adaptation).""" 20 | 21 | def __init__(self, meta_conf, model: nn.Module): 22 | super().__init__(meta_conf, model) 23 | 24 | def convert_iabn(self, module: nn.Module, **kwargs): 25 | """ 26 | Recursively convert all BatchNorm to InstanceAwareBatchNorm. 27 | """ 28 | module_output = module 29 | if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)): 30 | IABN = ( 31 | adaptation_utils.InstanceAwareBatchNorm2d 32 | if isinstance(module, nn.BatchNorm2d) 33 | else adaptation_utils.InstanceAwareBatchNorm1d 34 | ) 35 | module_output = IABN( 36 | num_channels=module.num_features, 37 | k=self._meta_conf.iabn_k, 38 | eps=module.eps, 39 | momentum=module.momentum, 40 | threshold=self._meta_conf.threshold_note, 41 | affine=module.affine, 42 | ) 43 | 44 | module_output._bn = copy.deepcopy(module) 45 | 46 | for name, child in module.named_children(): 47 | module_output.add_module(name, self.convert_iabn(child, **kwargs)) 48 | del module 49 | return module_output 50 | 51 | def _initialize_model(self, model: nn.Module): 52 | """Configure model for adaptation.""" 53 | if hasattr(self._meta_conf, "iabn") and self._meta_conf.iabn: 54 | # check BN layers 55 | bn_flag = False 56 | for name_module, module in model.named_modules(): 57 | if isinstance(module, (nn.BatchNorm2d, nn.BatchNorm1d)): 58 | bn_flag = True 59 | if not bn_flag: 60 | warnings.warn( 61 | "IABN needs bn layers, while there is no bn in the base model." 62 | ) 63 | self.convert_iabn(model) 64 | load_pretrained_model(self._meta_conf, model) 65 | model.eval() 66 | return model.to(self._meta_conf.device) 67 | 68 | def _post_safety_check(self): 69 | pass 70 | 71 | def initialize(self, seed: int): 72 | """Initialize the algorithm.""" 73 | self._model = self._initialize_model(model=copy.deepcopy(self._base_model)) 74 | 75 | def adapt_and_eval( 76 | self, 77 | episodic: bool, 78 | metrics: Metrics, 79 | model_selection_method: BaseSelection, 80 | current_batch: Batch, 81 | previous_batches: List[Batch], 82 | logger: Logger, 83 | timer: Timer, 84 | ): 85 | """The key entry of test-time adaptation.""" 86 | # some simple initialization. 
87 | with timer("test_time_adaptation"): 88 | with torch.no_grad(): 89 | y_hat = self._model(current_batch._x) 90 | 91 | with timer("evaluate_adaptation_result"): 92 | metrics.eval(current_batch._y, y_hat) 93 | if self._meta_conf.base_data_name in ["waterbirds"]: 94 | self.tta_loss_computer.loss( 95 | y_hat, current_batch._y, current_batch._g, is_training=False 96 | ) 97 | 98 | @property 99 | def name(self): 100 | return "no_adaptation" 101 | -------------------------------------------------------------------------------- /ttab/model_selection/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from .last_iterate import LastIterate 4 | from .oracle_model_selection import OracleModelSelection 5 | 6 | 7 | def get_model_selection_method(selection_name): 8 | return { 9 | "last_iterate": LastIterate, 10 | "oracle_model_selection": OracleModelSelection, 11 | }[selection_name] 12 | -------------------------------------------------------------------------------- /ttab/model_selection/base_selection.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import Any, Dict, List 3 | 4 | from ttab.api import Batch 5 | 6 | 7 | class BaseSelection(object): 8 | def __init__(self, meta_conf, model_adaptation_method): 9 | self.meta_conf = meta_conf 10 | self.model = model_adaptation_method.copy_model() 11 | self.model.to(self.meta_conf.device) 12 | 13 | self.initialize() 14 | 15 | def initialize(self): 16 | pass 17 | 18 | def clean_up(self): 19 | pass 20 | 21 | def save_state(self): 22 | pass 23 | 24 | def select_state( 25 | self, 26 | current_batch: Batch, 27 | previous_batches: List[Batch], 28 | ) -> Dict[str, Any]: 29 | pass 30 | -------------------------------------------------------------------------------- /ttab/model_selection/last_iterate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import copy 3 | from typing import Any, Dict 4 | 5 | from ttab.model_selection.base_selection import BaseSelection 6 | 7 | 8 | class LastIterate(BaseSelection): 9 | """Naively return the model generated from the last iterate of adaptation.""" 10 | 11 | def __init__(self, meta_conf, model_adaptation_method): 12 | super().__init__(meta_conf, model_adaptation_method) 13 | 14 | def initialize(self): 15 | if hasattr(self.model, "ssh"): 16 | self.model.ssh.eval() 17 | self.model.main_model.eval() 18 | else: 19 | self.model.eval() 20 | 21 | self.optimal_state = None 22 | 23 | def clean_up(self): 24 | self.optimal_state = None 25 | 26 | def save_state(self, state, current_batch): 27 | self.optimal_state = state 28 | 29 | def select_state(self) -> Dict[str, Any]: 30 | """return the optimal state and sync the model defined in the model selection method.""" 31 | return self.optimal_state 32 | 33 | @property 34 | def name(self): 35 | return "last_iterate" 36 | -------------------------------------------------------------------------------- /ttab/model_selection/metrics.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import torch 3 | from ttab.utils.stat_tracker import RuntimeTracker 4 | 5 | task2metrics = {"classification": ["cross_entropy", "accuracy_top1"]} 6 | auxiliary_metrics_dict = { 7 | "preadapted_cross_entropy": "cross_entropy", 8 | "preadapted_accuracy_top1": "accuracy_top1", 9 | } 10 | 11 | 12 | class Metrics(object): 13 | def 
__init__(self, scenario) -> None: 14 | self._conf = scenario 15 | self._init_metrics() 16 | 17 | def _init_metrics(self) -> None: 18 | self._metrics = task2metrics[self._conf.task] 19 | self.tracker = RuntimeTracker(metrics_to_track=self._metrics) 20 | self._primary_metrics = self._metrics[0] 21 | 22 | def init_auxiliary_metric(self, metric_name: str): 23 | self._metrics.append(metric_name) 24 | self.tracker.add_stat(metric_name) 25 | 26 | @torch.no_grad() 27 | def eval(self, y: torch.Tensor, y_hat: torch.Tensor) -> dict: 28 | results = dict() 29 | for metric_name in self._metrics: 30 | if metric_name not in auxiliary_metrics_dict: 31 | results[metric_name] = eval(metric_name)(y, y_hat) # resolve the metric function by its name. 32 | else: 33 | continue 34 | self.tracker.update_metrics(results, n_samples=y.size(0)) 35 | return results 36 | 37 | @torch.no_grad() 38 | def eval_auxiliary_metric( 39 | self, y: torch.Tensor, y_hat: torch.Tensor, metric_name: str 40 | ): 41 | assert ( 42 | metric_name in self._metrics 43 | ), "The target metric must be in the list of metrics." 44 | results = dict() 45 | results[metric_name] = eval(auxiliary_metrics_dict[metric_name])(y, y_hat) 46 | self.tracker.update_metrics(results, n_samples=y.size(0)) 47 | return results 48 | 49 | 50 | """list some common metrics.""" 51 | 52 | 53 | def _accuracy(target, output, topk): 54 | """Computes the precision@k for the specified values of k""" 55 | batch_size = target.size(0) 56 | 57 | _, pred = output.topk(topk, 1, True, True) 58 | pred = pred.t() 59 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 60 | 61 | correct_k = correct[:topk].reshape(-1).float().sum(0, keepdim=True) 62 | return correct_k.mul_(100.0 / batch_size).item() 63 | 64 | 65 | def accuracy_top1(target, output, topk=1): 66 | """Computes the precision@k for the specified values of k""" 67 | return _accuracy(target, output, topk) 68 | 69 | 70 | def accuracy_top5(target, output, topk=5): 71 | """Computes the precision@k for the specified values of k""" 72 | return _accuracy(target, output, topk) 73 | 74 | 75 | cross_entropy_loss = torch.nn.CrossEntropyLoss() 76 | 77 | 78 | def cross_entropy(target, output): 79 | """Cross entropy loss""" 80 | return cross_entropy_loss(output, target).item() 81 | -------------------------------------------------------------------------------- /ttab/model_selection/oracle_model_selection.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import copy 3 | from typing import Any, Dict 4 | 5 | import torch 6 | from ttab.model_selection.base_selection import BaseSelection 7 | from ttab.model_selection.metrics import accuracy_top1, cross_entropy 8 | 9 | 10 | class OracleModelSelection(BaseSelection): 11 | """Grid-search the best adaptation result per batch: given sufficiently many adaptation 12 | steps and a single learning rate in each step, save the best checkpoint 13 | and its optimizer states after iterating over all adaptation steps.""" 14 | 15 | def __init__(self, meta_conf, model_adaptation_method): 16 | super().__init__(meta_conf, model_adaptation_method) 17 | 18 | def initialize(self): 19 | if hasattr(self.model, "ssh"): 20 | self.model.ssh.eval() 21 | self.model.main_model.eval() 22 | else: 23 | self.model.eval() 24 | 25 | self.optimal_state = None 26 | self.current_batch_best_acc = 0 27 | self.current_batch_coupled_ent = None 28 | 29 | def clean_up(self): 30 | self.optimal_state = None 31 | self.current_batch_best_acc = 0 32 | self.current_batch_coupled_ent = None 33 | 34 | def 
save_state(self, state, current_batch): 35 | """Selectively save state for the current batch of data.""" 36 | batch_best_acc = self.current_batch_best_acc 37 | coupled_ent = self.current_batch_coupled_ent 38 | 39 | if not hasattr(self.model, "ssh"): 40 | self.model.load_state_dict(state["model"]) 41 | with torch.no_grad(): 42 | outputs = self.model(current_batch._x) 43 | else: 44 | self.model.main_model.load_state_dict(state["main_model"]) 45 | with torch.no_grad(): 46 | outputs = self.model.main_model(current_batch._x) 47 | 48 | current_acc = self.cal_acc(current_batch._y, outputs) 49 | if (self.optimal_state is None) or (current_acc > batch_best_acc): 50 | self.current_batch_best_acc = current_acc 51 | self.current_batch_coupled_ent = self.cal_ent(current_batch._y, outputs) 52 | state["yhat"] = outputs 53 | self.optimal_state = state 54 | elif current_acc == batch_best_acc: 55 | # break ties by comparing cross entropy 56 | assert coupled_ent is not None, "Cross entropy value cannot be None." 57 | current_ent = self.cal_ent(current_batch._y, outputs) 58 | if current_ent < coupled_ent: 59 | self.current_batch_coupled_ent = current_ent 60 | state["yhat"] = outputs 61 | self.optimal_state = state 62 | 63 | def cal_acc(self, targets, outputs): 64 | return accuracy_top1(targets, outputs) 65 | 66 | def cal_ent(self, targets, outputs): 67 | return cross_entropy(targets, outputs) 68 | 69 | def select_state(self) -> Dict[str, Any]: 70 | """Return the optimal state and sync the model defined in the model selection method.""" 71 | if not hasattr(self.model, "ssh"): 72 | self.model.load_state_dict(self.optimal_state["model"]) 73 | else: 74 | self.model.main_model.load_state_dict(self.optimal_state["main_model"]) 75 | self.model.ssh.load_state_dict(self.optimal_state["ssh"]) 76 | return self.optimal_state 77 | 78 | @property 79 | def name(self): 80 | return "oracle_model_selection" 81 | -------------------------------------------------------------------------------- /ttab/scenarios/__init__.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | from typing import List, NamedTuple, Union 3 | 4 | from ttab.loads.datasets.dataset_shifts import ( 5 | NaturalShiftProperty, 6 | NoShiftProperty, 7 | SyntheticShiftProperty, 8 | ) 9 | 10 | 11 | class TestDomain(NamedTuple): 12 | """ 13 | The definition of TestDomain: shift property (i.e., P(a^{1:K})) for each domain and its sampling strategy. 14 | 15 | Each data_name follows the pattern of either 1) <base>, or 2) <base>_<shift>, or 3) <base>_<shift>_<version>, 16 | where <base> is the base_data_name, <shift> is the shift_name, and <version> is the shift version. 17 | """ 18 | 19 | base_data_name: str 20 | data_name: str 21 | shift_type: str # shift between data_name and base_data_name 22 | shift_property: Union[SyntheticShiftProperty, NaturalShiftProperty, NoShiftProperty] 23 | 24 | domain_sampling_name: str = "uniform" # ['uniform', 'label_skew'] 25 | domain_sampling_value: float = None # hyper-parameter for domain_sampling_name 26 | domain_sampling_ratio: float = 1.0 27 | 28 | 29 | class HomogeneousNoMixture(NamedTuple): 30 | """ 31 | Only consider the shift on P(a^{1:K}) in Figure 6 of the paper, but no label shift. 32 | """ 33 | 34 | # no mixture 35 | has_mixture: bool = False 36 | 37 | 38 | class HeterogeneousNoMixture(NamedTuple): 39 | """ 40 | Only consider the shift on P(a^{1:K}) with label shift. 41 | 42 | We use this setting to evaluate TTA methods under the continual distribution shift setting in Table 4 of the paper. 
43 | """ 44 | 45 | # no mixture 46 | has_mixture: bool = False 47 | non_iid_pattern: str = "class_wise_over_domain" 48 | non_iid_ness: float = 100 49 | 50 | 51 | class InOutMixture(NamedTuple): 52 | """ 53 | Mix the source domain (left) with the target domain (right). 54 | """ 55 | 56 | # mix one in-domain (left) with one out out-domain (right) 57 | has_mixture: bool = True 58 | ratio: float = 0.5 # for left domain 59 | 60 | 61 | class CrossMixture(NamedTuple): 62 | """ 63 | Mix multiple target domains (right). Consider shuffle data across domains. 64 | """ 65 | 66 | # cross-shuffle test domains. 67 | has_mixture: bool = True 68 | 69 | 70 | class TestCase(NamedTuple): 71 | """ 72 | Defines the interaction across domains and some necessary setups in the test-time. 73 | """ 74 | 75 | inter_domain: Union[ 76 | HomogeneousNoMixture, HeterogeneousNoMixture, InOutMixture, CrossMixture 77 | ] 78 | batch_size: int = 32 79 | data_wise: str = "sample_wise" 80 | offline_pre_adapt: bool = False 81 | episodic: bool = True 82 | intra_domain_shuffle: bool = False 83 | 84 | 85 | class Scenario(NamedTuple): 86 | """ 87 | Defines a distribution shift scenario in practice. More details can be found in Setion 4 of the paper. 88 | """ 89 | 90 | task: str 91 | model_name: str 92 | model_adaptation_method: str 93 | model_selection_method: str 94 | 95 | base_data_name: str # test dataset (base type). 96 | src_data_name: str # name of source domain 97 | test_domains: List[TestDomain] # a list of domain (will be evaluated in order) 98 | test_case: TestCase 99 | -------------------------------------------------------------------------------- /ttab/scenarios/default_scenarios.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from ttab.loads.datasets.dataset_shifts import SyntheticShiftProperty 4 | from ttab.scenarios import HomogeneousNoMixture, Scenario, TestCase, TestDomain 5 | 6 | default_scenarios = { 7 | "S1": Scenario( 8 | task="classification", 9 | model_name="resnet26", 10 | model_adaptation_method="tent", 11 | model_selection_method="last_iterate", 12 | base_data_name="cifar10", 13 | src_data_name="cifar10", 14 | test_domains=[ 15 | TestDomain( 16 | base_data_name="cifar10", 17 | data_name="cifar10_c_deterministic-gaussian_noise-5", 18 | shift_type="synthetic", 19 | shift_property=SyntheticShiftProperty( 20 | shift_degree=5, 21 | shift_name="gaussian_noise", 22 | version="deterministic", 23 | has_shift=True, 24 | ), 25 | domain_sampling_name="uniform", 26 | domain_sampling_value=None, 27 | domain_sampling_ratio=1.0, 28 | ) 29 | ], 30 | test_case=TestCase( 31 | inter_domain=HomogeneousNoMixture(has_mixture=False), 32 | batch_size=64, 33 | data_wise="batch_wise", 34 | offline_pre_adapt=False, 35 | episodic=False, 36 | intra_domain_shuffle=True, 37 | ), 38 | ), 39 | } 40 | -------------------------------------------------------------------------------- /ttab/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LINs-lab/ttab/19cb3ce4e1a64a42f3b79bbce654bd5f8fb12421/ttab/utils/__init__.py -------------------------------------------------------------------------------- /ttab/utils/auxiliary.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import os 3 | import time 4 | import collections 5 | 6 | import contextlib 7 | 8 | import torch 9 | 10 | import ttab.utils.checkpoint as checkpoint 11 | 12 | 13 | 
13 | class dict2obj(object):
14 |     def __init__(self, d):
15 |         for a, b in d.items():
16 |             if isinstance(b, (list, tuple)):
17 |                 setattr(self, a, [dict2obj(x) if isinstance(x, dict) else x for x in b])
18 |             else:
19 |                 setattr(self, a, dict2obj(b) if isinstance(b, dict) else b)
20 | 
21 | 
22 | def flatten_nested_dicts(d, parent_key="", sep="_"):
23 |     """Borrowed from
24 |     https://stackoverflow.com/a/6027615
25 |     """
26 |     items = []
27 |     for k, v in d.items():
28 |         new_key = parent_key + sep + k if parent_key else k
29 |         if isinstance(v, collections.abc.MutableMapping):  # `collections.MutableMapping` was removed in Python 3.10.
30 |             items.extend(flatten_nested_dicts(v, new_key, sep=sep).items())
31 |         else:
32 |             items.append((new_key, v))
33 |     return dict(items)
34 | 
35 | 
36 | @contextlib.contextmanager
37 | def fork_rng_with_seed(seed):
38 |     if seed is None:
39 |         yield
40 |     else:
41 |         with torch.random.fork_rng(devices=[]):
42 |             torch.manual_seed(seed)
43 |             yield
44 | 
45 | 
46 | @contextlib.contextmanager
47 | def evaluation_monitor(conf):
48 |     conf.status = "started"
49 |     checkpoint.save_arguments(conf)
50 | 
51 |     yield
52 | 
53 |     # update the training status.
54 |     job_id = (
55 |         conf.job_id if conf.job_id is not None else f"/tmp/tmp_{str(int(time.time()))}"
56 |     )
57 |     os.system(f"echo {conf.checkpoint_path} >> {job_id}")
58 | 
59 |     # save the updated conf.
60 |     conf.status = "finished"
61 |     checkpoint.save_arguments(conf, force=True)
62 | 
--------------------------------------------------------------------------------
/ttab/utils/checkpoint.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import time
4 | import json
5 | from typing import Any
6 | 
7 | import torch
8 | 
9 | import ttab.utils.file_io as file_io
10 | 
11 | 
12 | def init_checkpoint(conf: Any):
13 |     # init checkpoint dir.
14 |     conf.checkpoint_path = os.path.join(
15 |         conf.root_path,
16 |         conf.model_name,
17 |         conf.job_name,
18 |         # f"{conf.model_name}_{conf.base_data_name}_{conf.model_adaptation_method}_{conf.model_selection_method}_{int(conf.timestamp if conf.timestamp is not None else time.time())}-seed{conf.seed}",
19 |         f"{conf.model_name}_{conf.base_data_name}_{conf.model_adaptation_method}_{conf.model_selection_method}_{str(time.time()).replace('.', '_')}-seed{conf.seed}",
20 |     )
21 | 
22 |     # if the directory does not exist, create it.
23 |     file_io.build_dirs(conf.checkpoint_path)
24 |     return conf.checkpoint_path
25 | 
26 | 
27 | def save_arguments(conf: Any, force: bool = False):
28 |     # save the configuration file to the checkpoint directory.
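    # (Only JSON-serializable, non-tensor entries of `conf.__dict__` survive
    # the filter below, so `json.dump` cannot fail on exotic config values.)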
29 |     path = os.path.join(conf.checkpoint_path, "arguments.json")
30 | 
31 |     if force or not os.path.exists(path):
32 |         with open(path, "w") as fp:
33 |             json.dump(
34 |                 dict(
35 |                     [
36 |                         (k, v)
37 |                         for k, v in conf.__dict__.items()
38 |                         if file_io.is_jsonable(v) and type(v) is not torch.Tensor
39 |                     ]
40 |                 ),
41 |                 fp,
42 |                 indent=" ",
43 |             )
44 | 
--------------------------------------------------------------------------------
/ttab/utils/early_stopping.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | 
4 | class EarlyStoppingTracker(object):
5 |     def __init__(self, patience: int, delta: float = 0, mode: str = "max") -> None:
6 |         self.patience = patience
7 |         self.delta = delta
8 |         self.mode = mode
9 |         self.best_value = None
10 |         self.counter = 0
11 | 
12 |     def __call__(self, value: float) -> bool:
13 |         if self.patience is None or self.patience <= 0:
14 |             return False
15 | 
16 |         if self.best_value is None:
17 |             self.best_value = value
18 |             self.counter = 0
19 |             return False
20 | 
21 |         if self.mode == "max":
22 |             if value > self.best_value + self.delta:
23 |                 return self._positive_update(value)
24 |             else:
25 |                 return self._negative_update(value)
26 |         elif self.mode == "min":
27 |             if value < self.best_value - self.delta:
28 |                 return self._positive_update(value)
29 |             else:
30 |                 return self._negative_update(value)
31 |         else:
32 |             raise ValueError(f"Illegal mode for early stopping: {self.mode}")
33 | 
34 |     def _positive_update(self, value: float) -> bool:
35 |         self.counter = 0
36 |         self.best_value = value
37 |         return False
38 | 
39 |     def _negative_update(self, value: float) -> bool:
40 |         self.counter += 1
41 |         if self.counter > self.patience:
42 |             return True
43 |         else:
44 |             return False
45 | 
--------------------------------------------------------------------------------
/ttab/utils/file_io.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import shutil
4 | import json
5 | 
6 | 
7 | """json related."""
8 | 
9 | 
10 | def read_json(path):
11 |     """read json file from path."""
12 |     with open(path, "r") as f:
13 |         return json.load(f)
14 | 
15 | 
16 | def is_jsonable(x):
17 |     try:
18 |         json.dumps(x)
19 |         return True
20 |     except (TypeError, OverflowError, ValueError):
21 |         return False
22 | 
23 | 
24 | """operate dir."""
25 | 
26 | 
27 | def build_dir(path, force):
28 |     """build directory."""
29 |     if os.path.exists(path) and force:
30 |         shutil.rmtree(path)
31 |         os.mkdir(path)
32 |     elif not os.path.exists(path):
33 |         os.mkdir(path)
34 |     return path
35 | 
36 | 
37 | def build_dirs(path):
38 |     try:
39 |         os.makedirs(path)
40 |     except Exception as e:
41 |         print(" encounter error: {}".format(e))
42 | 
43 | 
44 | def remove_folder(path):
45 |     try:
46 |         shutil.rmtree(path)
47 |     except Exception as e:
48 |         print(" encounter error: {}".format(e))
49 | 
50 | 
51 | def list_files(root_path):
52 |     dirs = os.listdir(root_path)
53 |     return [os.path.join(root_path, path) for path in dirs]
--------------------------------------------------------------------------------
/ttab/utils/logging.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import json
4 | import time
5 | import pprint
6 | from typing import Any, Dict
7 | 
8 | from io import StringIO
9 | import csv
10 | 
11 | 
12 | class Logger(object):
13 |     """
14 |     Very simple prototype logger that will store the values to a JSON file
15 |     """
16 | 
17 |     def __init__(self, folder_path: str) -> None:
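        # (Two artifacts live under `folder_path`: rolling JSON metric files
        # `log-1.json`, `log-2.json`, ..., rotated by `save_json` once more
        # than 1e4 records accumulate in memory, and a plain-text `log.txt`
        # for timestamped messages.)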
""" 19 | :param filename: ending with .json 20 | :param auto_save: save the JSON file after every addition 21 | """ 22 | self.folder_path = folder_path 23 | self.json_file_path = os.path.join(folder_path, "log-1.json") 24 | self.txt_file_path = os.path.join(folder_path, "log.txt") 25 | self.values = [] 26 | self.pp = MyPrettyPrinter(indent=2, depth=3, compact=True) 27 | 28 | def log_metric( 29 | self, 30 | name: str, 31 | values: Dict[str, Any], 32 | tags: Dict[str, Any], 33 | display: bool = False, 34 | ) -> None: 35 | """ 36 | Store a scalar metric 37 | 38 | :param name: measurement, like 'accuracy' 39 | :param values: dictionary, like { epoch: 3, value: 0.23 } 40 | :param tags: dictionary, like { split: train } 41 | """ 42 | self.values.append({"measurement": name, **values, **tags}) 43 | 44 | if display: 45 | print( 46 | "{name}: {values} ({tags})".format(name=name, values=values, tags=tags) 47 | ) 48 | 49 | def pretty_print(self, value: Any) -> None: 50 | self.pp.pprint(value) 51 | 52 | def log(self, value: str, display: bool = True) -> None: 53 | content = time.strftime("%Y-%m-%d %H:%M:%S") + "\t" + value 54 | if display: 55 | print(content) 56 | self.save_txt(content) 57 | 58 | def save_json(self) -> None: 59 | """Save the internal memory to a file.""" 60 | with open(self.json_file_path, "w") as fp: 61 | json.dump(self.values, fp, indent=" ") 62 | 63 | if len(self.values) > 1e4: 64 | # reset 'values' and redirect the json file to a different path. 65 | self.values = [] 66 | self.redirect_new_json() 67 | 68 | def save_txt(self, value: str) -> None: 69 | with open(self.txt_file_path, "a") as f: 70 | f.write(value + "\n") 71 | 72 | def redirect_new_json(self) -> None: 73 | """get the number of existing json files under the current folder.""" 74 | existing_json_files = [ 75 | file for file in os.listdir(self.folder_path) if "json" in file 76 | ] 77 | self.json_file_path = os.path.join( 78 | self.folder_path, "log-{}.json".format(len(existing_json_files) + 1) 79 | ) 80 | 81 | 82 | class MyPrettyPrinter(pprint.PrettyPrinter): 83 | """Borrowed from 84 | https://stackoverflow.com/questions/30062384/pretty-print-namedtuple 85 | """ 86 | 87 | def format_namedtuple(self, object, stream, indent, allowance, context, level): 88 | # Code almost equal to _format_dict, see pprint code 89 | write = stream.write 90 | write(object.__class__.__name__ + "(") 91 | object_dict = object._asdict() 92 | length = len(object_dict) 93 | if length: 94 | # We first try to print inline, and if it is too large then we print it on multiple lines 95 | inline_stream = StringIO() 96 | self.format_namedtuple_items( 97 | object_dict.items(), 98 | inline_stream, 99 | indent, 100 | allowance + 1, 101 | context, 102 | level, 103 | inline=True, 104 | ) 105 | max_width = self._width - indent - allowance 106 | if len(inline_stream.getvalue()) > max_width: 107 | self.format_namedtuple_items( 108 | object_dict.items(), 109 | stream, 110 | indent, 111 | allowance + 1, 112 | context, 113 | level, 114 | inline=False, 115 | ) 116 | else: 117 | stream.write(inline_stream.getvalue()) 118 | write(")") 119 | 120 | def format_namedtuple_items( 121 | self, items, stream, indent, allowance, context, level, inline=False 122 | ): 123 | # Code almost equal to _format_dict_items, see pprint code 124 | indent += self._indent_per_level 125 | write = stream.write 126 | last_index = len(items) - 1 127 | if inline: 128 | delimnl = ", " 129 | else: 130 | delimnl = ",\n" + " " * indent 131 | write("\n" + " " * indent) 132 | for i, (key, ent) in 
132 |         for i, (key, ent) in enumerate(items):
133 |             last = i == last_index
134 |             write(key + "=")
135 |             self._format(
136 |                 ent,
137 |                 stream,
138 |                 indent + len(key) + 2,
139 |                 allowance if last else 1,
140 |                 context,
141 |                 level,
142 |             )
143 |             if not last:
144 |                 write(delimnl)
145 | 
146 |     def _format(self, object, stream, indent, allowance, context, level):
147 |         # We dynamically add the types of our namedtuple and namedtuple like
148 |         # classes to the _dispatch object of pprint that maps classes to
149 |         # formatting methods
150 |         # We use a simple criterion (the presence of an _asdict method) that allows us to use the
151 |         # same formatting on other classes, but a more precise one is possible
152 |         if hasattr(object, "_asdict") and type(object).__repr__ not in self._dispatch:
153 |             self._dispatch[type(object).__repr__] = MyPrettyPrinter.format_namedtuple
154 |         super()._format(object, stream, indent, allowance, context, level)
155 | 
156 | 
157 | class CSVBatchLogger:
158 |     """Borrowed from https://github.com/kohpangwei/group_DRO/blob/master/utils.py#L39"""
159 | 
160 |     def __init__(self, csv_path, n_groups, mode="w"):
161 |         columns = ["epoch", "batch"]
162 |         for idx in range(n_groups):
163 |             columns.append(f"avg_loss_group:{idx}")
164 |             columns.append(f"exp_avg_loss_group:{idx}")
165 |             columns.append(f"avg_acc_group:{idx}")
166 |             columns.append(f"processed_data_count_group:{idx}")
167 |             columns.append(f"update_data_count_group:{idx}")
168 |             columns.append(f"update_batch_count_group:{idx}")
169 |         columns.append("avg_actual_loss")
170 |         columns.append("avg_per_sample_loss")
171 |         columns.append("avg_acc")
172 |         columns.append("model_norm_sq")
173 |         columns.append("reg_loss")
174 | 
175 |         self.path = csv_path
176 |         self.file = open(csv_path, mode)
177 |         self.columns = columns
178 |         self.writer = csv.DictWriter(self.file, fieldnames=columns)
179 |         if mode == "w":
180 |             self.writer.writeheader()
181 | 
182 |     def log(self, epoch, batch, stats_dict):
183 |         stats_dict["epoch"] = epoch
184 |         stats_dict["batch"] = batch
185 |         self.writer.writerow(stats_dict)
186 | 
187 |     def flush(self):
188 |         self.file.flush()
189 | 
190 |     def close(self):
191 |         self.file.close()
192 | 
--------------------------------------------------------------------------------
/ttab/utils/mathdict.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import Any, Dict, List, Union, Callable
3 | from collections.abc import ItemsView, ValuesView
4 | 
5 | C = Union[float, int]
6 | 
7 | 
8 | class MathDict:
9 |     def __init__(self, dictionary: Dict[Any, Union[int, float]]) -> None:
10 |         self.dictionary = dictionary
11 |         self.keys = set(dictionary.keys())
12 | 
13 |     def __str__(self) -> str:
14 |         return f"MathDict({self.dictionary})"
15 | 
16 |     def __repr__(self) -> str:
17 |         return f"MathDict({repr(self.dictionary)})"
18 | 
19 |     def map(self: "MathDict", mapfun: Callable[[C], C]) -> "MathDict":
20 |         new_dict = {}
21 |         for key in self.keys:
22 |             new_dict[key] = mapfun(self.dictionary[key])
23 |         return MathDict(new_dict)
24 | 
25 |     def filter(self: "MathDict", condfun: Callable[[Any], bool]) -> "MathDict":
26 |         new_dict = {}
27 |         for key in self.keys:
28 |             if condfun(key):
29 |                 new_dict[key] = self.dictionary[key]
30 |         return MathDict(new_dict)
31 | 
32 |     def detach(self) -> None:
33 |         for key in self.keys:
34 |             self.dictionary[key] = self.dictionary[key].detach()
35 | 
36 |     def values(self) -> ValuesView:
37 |         return self.dictionary.values()
38 | 
39 |     def items(self) -> ItemsView:
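        # expose the wrapped dict's item view for iteration.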
40 |         return self.dictionary.items()
41 | 
42 | 
43 | def _mathdict_binary_op(operation: Callable[[C, C], C]) -> Callable[..., MathDict]:
44 |     def op(self: MathDict, other: Union[MathDict, Dict]) -> MathDict:
45 |         new_dict = {}
46 |         if isinstance(other, MathDict):
47 |             assert other.keys == self.keys
48 |             for key in self.keys:
49 |                 new_dict[key] = operation(self.dictionary[key], other.dictionary[key])
50 |         else:
51 |             for key in self.keys:
52 |                 new_dict[key] = operation(self.dictionary[key], other)
53 |         return MathDict(new_dict)
54 | 
55 |     return op
56 | 
57 | 
58 | def _mathdict_map_op(
59 |     operation: Callable[[ValuesView, List[Any], List[Any]], Any]
60 | ) -> Callable[..., MathDict]:
61 |     def op(self: MathDict, *args, **kwargs) -> MathDict:
62 |         new_dict = {}
63 |         for key in self.keys:
64 |             new_dict[key] = operation(self.dictionary[key], args, kwargs)
65 |         return MathDict(new_dict)
66 | 
67 |     return op
68 | 
69 | 
70 | def _mathdict_binary_in_place_op(operation: Callable[[Dict, Any, C], None]) -> Callable[..., MathDict]:
71 |     def op(self: MathDict, other: Union[MathDict, Dict]) -> MathDict:
72 |         if isinstance(other, MathDict):
73 |             assert other.keys == self.keys
74 |             for key in self.keys:
75 |                 operation(self.dictionary, key, other.dictionary[key])
76 |         else:
77 |             for key in self.keys:
78 |                 operation(self.dictionary, key, other)
79 |         return self
80 | 
81 |     return op
82 | 
83 | 
84 | def _iadd(d: Dict, key: Any, b: C) -> None:
85 |     d[key] += b
86 | 
87 | 
88 | def _isub(d: Dict, key: Any, b: C) -> None:
89 |     d[key] -= b
90 | 
91 | 
92 | def _imul(d: Dict, key: Any, b: C) -> None:
93 |     d[key] *= b
94 | 
95 | 
96 | def _itruediv(d: Dict, key: Any, b: C) -> None:
97 |     d[key] /= b
98 | 
99 | 
100 | def _ifloordiv(d: Dict, key: Any, b: C) -> None:
101 |     d[key] //= b
102 | 
103 | 
104 | MathDict.__add__ = _mathdict_binary_op(lambda a, b: a + b)
105 | MathDict.__sub__ = _mathdict_binary_op(lambda a, b: a - b)
106 | MathDict.__rsub__ = _mathdict_binary_op(lambda a, b: b - a)
107 | MathDict.__mul__ = _mathdict_binary_op(lambda a, b: a * b)
108 | MathDict.__rmul__ = _mathdict_binary_op(lambda a, b: a * b)
109 | MathDict.__truediv__ = _mathdict_binary_op(lambda a, b: a / b)
110 | MathDict.__floordiv__ = _mathdict_binary_op(lambda a, b: a // b)
111 | MathDict.__getitem__ = _mathdict_map_op(
112 |     lambda x, args, kwargs: x.__getitem__(*args, **kwargs)
113 | )
114 | MathDict.__iadd__ = _mathdict_binary_in_place_op(_iadd)
115 | MathDict.__isub__ = _mathdict_binary_in_place_op(_isub)
116 | MathDict.__imul__ = _mathdict_binary_in_place_op(_imul)
117 | MathDict.__itruediv__ = _mathdict_binary_in_place_op(_itruediv)
118 | MathDict.__ifloordiv__ = _mathdict_binary_in_place_op(_ifloordiv)
--------------------------------------------------------------------------------
/ttab/utils/stat_tracker.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import copy
3 | 
4 | 
5 | class MaxMeter(object):
6 |     """
7 |     Keeps track of the max of all the values that are 'add'ed
8 |     """
9 | 
10 |     def __init__(self):
11 |         self.max = None
12 | 
13 |     def update(self, value):
14 |         """
15 |         Add a value to the accumulator.
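        (Values are snapshotted via `copy.deepcopy`, so later in-place mutation
        of the input cannot corrupt the stored extremum.)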
16 |         :return: `true` if the provided value became the new max
17 |         """
18 |         if self.max is None or value > self.max:
19 |             self.max = copy.deepcopy(value)
20 |             return True
21 |         else:
22 |             return False
23 | 
24 |     def value(self):
25 |         """Access the current max"""
26 |         return self.max
27 | 
28 | 
29 | class MinMeter(object):
30 |     """
31 |     Keeps track of the min of all the values that are 'add'ed
32 |     """
33 | 
34 |     def __init__(self):
35 |         self.min = None
36 | 
37 |     def update(self, value):
38 |         """
39 |         Add a value to the accumulator.
40 |         :return: `true` if the provided value became the new min
41 |         """
42 |         if self.min is None or value < self.min:
43 |             self.min = copy.deepcopy(value)
44 |             return True
45 |         else:
46 |             return False
47 | 
48 |     def value(self):
49 |         """Access the current min"""
50 |         return self.min
51 | 
52 | 
53 | class AverageMeter(object):
54 |     """Computes and stores the average and current value"""
55 | 
56 |     def __init__(self):
57 |         self.reset()
58 | 
59 |     def reset(self):
60 |         self.val = 0
61 |         self.avg = 0
62 |         self.sum = 0
63 |         self.max = -float("inf")
64 |         self.min = float("inf")
65 |         self.count = 0
66 | 
67 |     def update(self, val, n=1):
68 |         self.val = val
69 |         self.sum += val * n
70 |         self.count += n
71 |         self.avg = self.sum / self.count
72 |         self.max = val if val > self.max else self.max
73 |         self.min = val if val < self.min else self.min
74 | 
75 | 
76 | class RuntimeTracker(object):
77 |     """Tracks runtime statistics (e.g., per-metric averages) for local training."""
78 | 
79 |     def __init__(self, metrics_to_track):
80 |         self.metrics_to_track = metrics_to_track
81 |         self.reset()
82 | 
83 |     def reset(self):
84 |         self.stat = dict((name, AverageMeter()) for name in self.metrics_to_track)
85 | 
86 |     def add_stat(self, metric_name: str):
87 |         self.stat[metric_name] = AverageMeter()
88 | 
89 |     def get_metrics_performance(self):
90 |         return [self.stat[metric].avg for metric in self.metrics_to_track]
91 | 
92 |     def update_metrics(self, metric_stat, n_samples):
93 |         for name, value in metric_stat.items():
94 |             self.stat[name].update(value, n_samples)
95 | 
96 |     def __call__(self):
97 |         return dict((name, val.avg) for name, val in self.stat.items())
98 | 
99 |     def get_current_val(self):
100 |         return dict((name, val.val) for name, val in self.stat.items())
101 | 
102 |     def get_val_by_name(self, metric_name):
103 |         return dict([(metric_name, self.stat[metric_name].val)])
104 | 
105 | 
106 | class BestPerf(object):
107 |     def __init__(self, best_perf=None, larger_is_better=True):
108 |         self.best_perf = best_perf
109 |         self.cur_perf = None
110 |         self.best_perf_locs = []
111 |         self.larger_is_better = larger_is_better
112 | 
113 |         # define meter
114 |         self._define_meter()
115 | 
116 |     def _define_meter(self):
117 |         self.meter = MaxMeter() if self.larger_is_better else MinMeter()
118 | 
119 |     def update(self, perf, perf_location):
120 |         self.is_best = self.meter.update(perf)
121 |         self.cur_perf = perf
122 | 
123 |         if self.is_best:
124 |             self.best_perf = perf
125 |             self.best_perf_locs += [perf_location]
126 | 
127 |     def get_best_perf_loc(self):
128 |         return self.best_perf_locs[-1] if len(self.best_perf_locs) != 0 else None
--------------------------------------------------------------------------------
/ttab/utils/tensor_buffer.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | from typing import List, Tuple
3 | 
4 | import torch
5 | 
6 | 
7 | class TensorBuffer(object):
8 |     """Packs multiple tensors into one flat buffer."""
9 | 
10 |     def 
__init__(self, tensors: List[torch.Tensor], use_cuda: bool = True) -> None: 11 | indices: List[int] = [0] 12 | for tensor in tensors: 13 | new_end = indices[-1] + tensor.nelement() 14 | indices.append(new_end) 15 | 16 | self._start_idx = indices[:-1] 17 | self._end_idx = indices[1:] 18 | self._tensors_len = len(tensors) 19 | self._tensors_sizes = [x.size() for x in tensors] 20 | 21 | self.buffer = flatten(tensors, use_cuda=use_cuda) # copies 22 | 23 | def __getitem__(self, index: int) -> torch.Tensor: 24 | return self.buffer[self._start_idx[index] : self._end_idx[index]].view( 25 | self._tensors_sizes[index] 26 | ) 27 | 28 | def __len__(self) -> int: 29 | return self._tensors_len 30 | 31 | def is_cuda(self) -> bool: 32 | return self.buffer.is_cuda 33 | 34 | def nelement(self) -> int: 35 | return self.buffer.nelement() 36 | 37 | def unpack(self, tensors: List[torch.Tensor]) -> None: 38 | for tensor, entry in zip(tensors, self): 39 | tensor.data[:] = entry 40 | 41 | 42 | def flatten( 43 | tensors: List[torch.Tensor], 44 | shapes: List[Tuple[int, int]] = None, 45 | use_cuda: bool = True, 46 | ): 47 | # init and recover the shapes vec. 48 | pointers: List[int] = [0] 49 | if shapes is not None: 50 | for shape in shapes: 51 | pointers.append(pointers[-1] + shape[1]) 52 | else: 53 | for tensor in tensors: 54 | pointers.append(pointers[-1] + tensor.nelement()) 55 | 56 | # flattening. 57 | current_device = tensors[0].device 58 | target_device = tensors[0].device if tensors[0].is_cuda and use_cuda else "cpu" 59 | vec = torch.empty(pointers[-1], device=target_device) 60 | 61 | for tensor, start_idx, end_idx in zip(tensors, pointers[:-1], pointers[1:]): 62 | vec[start_idx:end_idx] = ( 63 | tensor.data.view(-1).to(device=target_device) 64 | if current_device != target_device 65 | else tensor.data.view(-1) 66 | ) 67 | return vec 68 | 69 | 70 | def unflatten( 71 | self_tensors: List[torch.Tensor], 72 | out_tensors: List[torch.Tensor], 73 | shapes: Tuple[int, int], 74 | ): 75 | pointer: int = 0 76 | 77 | for self_tensor, shape in zip(self_tensors, shapes): 78 | param_size, nelement = shape 79 | self_tensor.data[:] = out_tensors[pointer : pointer + nelement].view(param_size) 80 | pointer += nelement 81 | -------------------------------------------------------------------------------- /ttab/utils/timer.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import time 3 | from typing import Any, Dict 4 | from io import StringIO 5 | from contextlib import contextmanager 6 | 7 | import numpy as np 8 | import torch 9 | 10 | 11 | class Timer(object): 12 | """ 13 | Timer for PyTorch code 14 | Comes in the form of a contextmanager: 15 | 16 | Example: 17 | >>> timer = Timer() 18 | ... for i in range(10): 19 | ... with timer("expensive operation"): 20 | ... x = torch.randn(100) 21 | ... 
print(timer.summary())
22 |     """
23 | 
24 |     def __init__(
25 |         self,
26 |         device: str,
27 |         verbosity_level: int = 1,
28 |         log_fn=None,
29 |         skip_first: bool = True,
30 |         on_cuda: bool = True,
31 |     ) -> None:
32 |         self.device = device
33 |         self.verbosity_level = verbosity_level
34 |         self.log_fn = log_fn if log_fn is not None else self._default_log_fn
35 |         self.skip_first = skip_first
36 |         self.cuda_available = torch.cuda.is_available() and on_cuda
37 | 
38 |         self.reset()
39 | 
40 |     def reset(self) -> None:
41 |         """Reset the timer"""
42 |         self.totals = {}  # Total time per label
43 |         self.first_time = {}  # First occurrence of a label (start time)
44 |         self.last_time = {}  # Last occurrence of a label (end time)
45 |         self.call_counts = {}  # Number of times a label occurred
46 | 
47 |     @contextmanager
48 |     def __call__(
49 |         self, label: str, step: int = -1, epoch: float = -1.0, verbosity: int = 1
50 |     ) -> None:
51 |         # Don't measure this if the verbosity level is too high
52 |         if verbosity > self.verbosity_level:
53 |             yield
54 |             return
55 | 
56 |         # Measure the time
57 |         self._cuda_sync()
58 |         start = time.time()
59 |         yield
60 |         self._cuda_sync()
61 |         end = time.time()
62 | 
63 |         # Update first and last occurrence of this label
64 |         if label not in self.first_time:
65 |             self.first_time[label] = start
66 |         self.last_time[label] = end
67 | 
68 |         # Update the totals and call counts
69 |         if label not in self.totals and self.skip_first:
70 |             self.totals[label] = 0.0
71 |             del self.first_time[label]
72 |             self.call_counts[label] = 0
73 |         elif label not in self.totals and not self.skip_first:
74 |             self.totals[label] = end - start
75 |             self.call_counts[label] = 1
76 |         else:
77 |             self.totals[label] += end - start
78 |             self.call_counts[label] += 1
79 | 
80 |         if self.call_counts[label] > 0:
81 |             # We will reduce the probability of logging a timing
82 |             # linearly with the number of times we have seen it.
83 |             # It will always be recorded in the totals, though.
84 |             if np.random.rand() < 1 / self.call_counts[label]:
85 |                 self.log_fn(
86 |                     "timer",
87 |                     {"step": step, "epoch": epoch, "value": end - start},
88 |                     {"event": label},
89 |                 )
90 | 
91 |     def summary(self) -> str:
92 |         """
93 |         Return a summary in string-form of all the timings recorded so far
94 |         """
95 |         if len(self.totals) > 0:
96 |             with StringIO() as buffer:
97 |                 total_avg_time = 0
98 |                 print("--- Timer summary ------------------------", file=buffer)
99 |                 print(" Event | Count | Average time | Frac.", file=buffer)
100 |                 for event_label in sorted(self.totals):
101 |                     total = self.totals[event_label]
102 |                     count = self.call_counts[event_label]
103 |                     if count == 0:
104 |                         continue
105 |                     avg_duration = total / count
106 |                     total_runtime = (
107 |                         self.last_time[event_label] - self.first_time[event_label]
108 |                     )
109 |                     runtime_percentage = 100 * total / total_runtime
110 |                     total_avg_time += avg_duration if "." not in event_label else 0
111 |                     print(
112 |                         f"- {event_label:30s} | {count:6d} | {avg_duration:11.5f}s | {runtime_percentage:5.1f}%",
113 |                         file=buffer,
114 |                     )
115 |                 print("-------------------------------------------", file=buffer)
116 |                 event_label = "total_averaged_time"
117 |                 print(
118 |                     f"- {event_label:30s}| {count:6d} | {total_avg_time:11.5f}s |",
119 |                     file=buffer,
120 |                 )
121 |                 print("-------------------------------------------", file=buffer)
122 |                 return buffer.getvalue()
123 | 
124 |     def _cuda_sync(self) -> None:
125 |         """Finish all asynchronous GPU computations to get correct timings"""
126 |         if self.cuda_available:
127 |             torch.cuda.synchronize(device=self.device)
128 | 
129 |     def _default_log_fn(self, _: Any, values: Dict, tags: Dict) -> None:
130 |         label = tags["event"]  # matches the {"event": label} tag emitted in __call__.
131 |         epoch = values["epoch"]
132 |         duration = values["value"]
133 |         print(f"Timer: {label:30s} @ {epoch:4.1f} - {duration:8.5f}s")
134 | 
--------------------------------------------------------------------------------
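A minimal end-to-end sketch (not a file from the repository) of how the utilities
dumped above compose; the metric values are hypothetical stand-ins for a real
test-time-adaptation loop:

import torch

from ttab.utils.early_stopping import EarlyStoppingTracker
from ttab.utils.stat_tracker import RuntimeTracker
from ttab.utils.timer import Timer

timer = Timer(device="cpu", on_cuda=False)  # skip CUDA syncs on CPU
tracker = RuntimeTracker(metrics_to_track=["accuracy_top1"])
early_stopper = EarlyStoppingTracker(patience=3, mode="max")

for step in range(10):
    with timer("adapt_step", step=step):
        # hypothetical per-batch accuracy standing in for a real adaptation step.
        acc = 100.0 * torch.rand(1).item()
    tracker.update_metrics({"accuracy_top1": acc}, n_samples=64)
    # stop once the running average fails to improve for more than `patience` steps.
    if early_stopper(tracker()["accuracy_top1"]):
        break

print(timer.summary())  # per-label counts, average times, and time fractions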