├── vision ├── requirements.txt ├── figure │ └── csl_imagenet.png ├── README.md ├── script │ ├── baseline_imagenet.sh │ ├── csl_imagenet.sh │ ├── oracle_imagenet.sh │ └── prepare_imagenet.sh ├── SETUP.md ├── download.py ├── helper.py ├── domainbed.py ├── models.py ├── data.py ├── multi_weak_strong_oracle.py ├── run_analysis.py ├── single_weak_strong.py ├── multi_weak_strong_csl.py └── prep_imagenet.py ├── assets └── co-supervised-learning.jpg ├── README.md └── LICENSE.md /vision/requirements.txt: -------------------------------------------------------------------------------- 1 | torchvision 2 | scipy 3 | gdown 4 | seaborn 5 | pandas -------------------------------------------------------------------------------- /vision/figure/csl_imagenet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuejiangLIU/csl/HEAD/vision/figure/csl_imagenet.png -------------------------------------------------------------------------------- /assets/co-supervised-learning.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuejiangLIU/csl/HEAD/assets/co-supervised-learning.jpg -------------------------------------------------------------------------------- /vision/README.md: -------------------------------------------------------------------------------- 1 | # Co-Supervised Learning in Visual Recognition 2 | 3 | ### Configuration 4 | 5 | Please refer to [SETUP](SETUP.md). 6 | 7 | ### Run Baseline 8 | 9 | ```bash 10 | bash script/baseline_imagenet.sh 11 | ``` 12 | 13 | ### Collective Capability 14 | 15 | ```bash 16 | bash script/oracle_imagenet.sh 17 | ``` 18 | 19 | ### Co-Supervised Learning 20 | 21 | ```bash 22 | bash script/csl_imagenet.sh 23 | ``` 24 | 25 | ### Result Comparison 26 | 27 | ``` 28 | python run_analysis.py 29 | ``` 30 | 31 | ### Multi-Domain Dataset 32 | 33 | Coming soon. 
-------------------------------------------------------------------------------- /vision/script/baseline_imagenet.sh: -------------------------------------------------------------------------------- 1 | # directory config 2 | DATADIR=/storage/datasets/imagenet 3 | CKPTDIR=/storage/weak2strong/vision/ckpt/imagenet/alexnet1 4 | EMBEDDIR=/storage/weak2strong/vision/embedding/imagenet 5 | 6 | # supervisor config 7 | EPOCH=0 8 | ITER=1000 9 | SOFT=False 10 | 11 | python single_weak_strong.py \ 12 | --weak_path ${CKPTDIR}/0_1000/head-${EPOCH}-${ITER}.pth.tar \ 13 | --data_path ${DATADIR} \ 14 | --embed_path ${EMBEDDIR} \ 15 | --result_path result/imagenet \ 16 | --soft_teacher ${SOFT} \ 17 | --weak_model_name alexnet1 \ 18 | --strong_model_name vits14_dino \ 19 | --lr 1e-4 --n_epochs 20 20 | -------------------------------------------------------------------------------- /vision/SETUP.md: -------------------------------------------------------------------------------- 1 | ### Environment 2 | 3 | ```bash 4 | pip install -r requirements.txt 5 | ``` 6 | 7 | ### Download Dataset 8 | 9 | Download the ImageNet data (see the torchvision documentation for instructions). The data directory should contain the files `ILSVRC2012_devkit_t12.tar.gz` and `ILSVRC2012_img_val.tar`. 10 | 11 | ### Prepare Supervisor 12 | 13 | ```bash 14 | bash script/prepare_imagenet.sh 15 | ``` 16 | 17 | ### Download Pre-trained 18 | 19 | Instead of preparing the weak supervisors with the script above, you can download our pre-trained supervisors and the corresponding embeddings from [Google Drive](https://drive.google.com/drive/folders/1EA_TCZavnuJK3_NPvmE-23gUK8xzey8Z?usp=drive_link). 
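Before running any of the scripts, it can help to confirm that the data directory really contains the two ImageNet archives named under "Download Dataset". The sketch below is only an illustration: the helper `missing_imagenet_archives` and the example path are assumptions and are not part of this repository.

```python
import os

# Archive names taken from the "Download Dataset" step above.
REQUIRED_ARCHIVES = (
    "ILSVRC2012_devkit_t12.tar.gz",
    "ILSVRC2012_img_val.tar",
)


def missing_imagenet_archives(data_dir):
    """Return the required archive names that are absent from data_dir.

    Hypothetical helper for illustration; it is not part of this repository.
    """
    return [
        name
        for name in REQUIRED_ARCHIVES
        if not os.path.isfile(os.path.join(data_dir, name))
    ]


if __name__ == "__main__":
    # Assumed location; replace it with your own DATADIR.
    missing = missing_imagenet_archives("/datasets/imagenet")
    if missing:
        print("Missing archives:", ", ".join(missing))
    else:
        print("All required archives found.")
```

An empty return value means both archives are present; anything else lists what still needs to be downloaded.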
20 | 21 | ### Configure directories 22 | 23 | Set directories in the scripts under the [script folder](script) 24 | 25 | ```bash 26 | DATADIR=/datasets/imagenet 27 | CKPTDIR=/ckpt/imagenet/alexnet 28 | EMBEDDIR=/embedding/imagenet 29 | ``` 30 | -------------------------------------------------------------------------------- /vision/script/csl_imagenet.sh: -------------------------------------------------------------------------------- 1 | # directory config 2 | DATADIR=/storage/datasets/imagenet 3 | CKPTDIR=/storage/weak2strong/vision/ckpt/imagenet/alexnet1 4 | EMBEDDIR=/storage/weak2strong/vision/embedding/imagenet 5 | 6 | # supervisor config 7 | EPOCH=0 8 | ITER=1000 9 | SOFT=False 10 | 11 | # experiment config 12 | DENOISE=top3 13 | SEED=0 14 | RESULTDIR=result 15 | 16 | python multi_weak_strong_csl.py \ 17 | ${CKPTDIR}/0_1000/head-${EPOCH}-${ITER}.pth.tar \ 18 | ${CKPTDIR}/0_500/head-${EPOCH}-${ITER}.pth.tar \ 19 | ${CKPTDIR}/500_1000/head-${EPOCH}-${ITER}.pth.tar \ 20 | ${CKPTDIR}/0_250/head-${EPOCH}-${ITER}.pth.tar \ 21 | ${CKPTDIR}/250_500/head-${EPOCH}-${ITER}.pth.tar \ 22 | ${CKPTDIR}/500_750/head-${EPOCH}-${ITER}.pth.tar \ 23 | ${CKPTDIR}/750_1000/head-${EPOCH}-${ITER}.pth.tar \ 24 | ${CKPTDIR}/0_125/head-${EPOCH}-${ITER}.pth.tar \ 25 | ${CKPTDIR}/125_250/head-${EPOCH}-${ITER}.pth.tar \ 26 | ${CKPTDIR}/250_375/head-${EPOCH}-${ITER}.pth.tar \ 27 | ${CKPTDIR}/375_500/head-${EPOCH}-${ITER}.pth.tar \ 28 | ${CKPTDIR}/500_625/head-${EPOCH}-${ITER}.pth.tar \ 29 | ${CKPTDIR}/625_750/head-${EPOCH}-${ITER}.pth.tar \ 30 | ${CKPTDIR}/750_875/head-${EPOCH}-${ITER}.pth.tar \ 31 | ${CKPTDIR}/875_1000/head-${EPOCH}-${ITER}.pth.tar \ 32 | --data_path ${DATADIR} \ 33 | --embed_path ${EMBEDDIR} \ 34 | --result_path ${RESULTDIR} \ 35 | --soft_teacher ${SOFT} \ 36 | --denoise_criterion ${DENOISE} \ 37 | --weak_model_name alexnet1 \ 38 | --strong_model_name vits14_dino \ 39 | --n_epochs 20 --lr 1e-4 \ 40 | --seed ${SEED} 41 | 
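The checkpoint list passed to `multi_weak_strong_csl.py` above follows a binary hierarchy over the 1000 ImageNet classes: one generalist (`0_1000`), followed by 2, 4, and 8 specialists. A minimal sketch that enumerates the same directory names, assuming the `{start}_{end}` naming seen in the script; `hierarchy_ranges` is a hypothetical helper and the checkpoint directory below is a placeholder, neither is part of the repository:

```python
def hierarchy_ranges(num_classes=1000, depth=4):
    """List '{start}_{end}' class ranges level by level (1, 2, 4, ... folds).

    Hypothetical helper for illustration; it is not part of this repository.
    """
    names = []
    for level in range(depth):
        n_folds = 2 ** level
        for i in range(n_folds):
            start = i * num_classes // n_folds
            end = (i + 1) * num_classes // n_folds
            names.append(f"{start}_{end}")
    return names


if __name__ == "__main__":
    # These values mirror the shell variables CKPTDIR, EPOCH and ITER;
    # the directory itself is an assumed placeholder.
    ckpt_dir, epoch, n_iter = "/ckpt/imagenet/alexnet1", 0, 1000
    for name in hierarchy_ranges():
        print(f"{ckpt_dir}/{name}/head-{epoch}-{n_iter}.pth.tar")
```

With the defaults this yields 15 ranges in the same order as the script, which makes it easier to change the depth of the hierarchy without editing the path list by hand.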
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Co-Supervised Learning 2 | 3 |

4 | 5 |

6 | 7 | This repository is the official implementation of 8 |
9 | **[Co-Supervised Learning: 10 | Improving Weak-to-Strong Generalization with Hierarchical Mixture of Experts](https://arxiv.org/abs/2402.15505)** 11 |
12 | 13 | > The current codebase is a minimalistic version of [co-supervised learning](https://arxiv.org/abs/2402.15505), built upon the 'vision' directory of [weak-to-strong generalization](https://github.com/openai/weak-to-strong/tree/main/vision). It will be continually maintained and updated in the future. If you have any questions or comments, please feel free to raise issues or email yuejiang.liu[at]{epfl.ch,stanford.edu}. 14 | 15 | ### Getting Started 16 | 17 | Please refer to the [vision directory](vision) 18 | 19 | ### Expected Results 20 | 21 | 22 | 23 | ### Citation 24 | 25 | If you find this code useful for your research, please cite the following: 26 | 27 | ```bibtex 28 | @article{liu2024csl, 29 | author = {Yuejiang Liu and Alexandre Alahi}, 30 | title = {Co-Supervised Learning: Improving Weak-to-Strong Generalization with Hierarchical Mixture of Experts}, 31 | journal = {arXiv preprint 2402.15505}, 32 | year = {2024}, 33 | } 34 | ``` 35 | 36 | ### Acknowledgments 37 | 38 | - weak-to-strong repository from OpenAI 39 | - pre-trained DINOv2 models from Meta AI 40 | - DomainBed repository from Meta AI 41 | -------------------------------------------------------------------------------- /vision/script/oracle_imagenet.sh: -------------------------------------------------------------------------------- 1 | # directory config 2 | DATADIR=/storage/datasets/imagenet 3 | CKPTDIR=/storage/weak2strong/vision/ckpt/imagenet/alexnet1 4 | EMBEDDIR=/storage/weak2strong/vision/embedding/imagenet 5 | 6 | # supervisor config 7 | EPOCH=0 8 | ITER=1000 9 | SOFT=False 10 | RESULTDIR=result 11 | 12 | # 1-fold generalist 13 | # python multi_weak_strong_oracle.py \ 14 | # ${CKPTDIR}/0_1000/head-${EPOCH}-${ITER}.pth.tar \ 15 | # --data_path ${DATADIR} \ 16 | # --embed_path ${EMBEDDIR} \ 17 | # --result_path ${RESULTDIR} \ 18 | # --soft_teacher ${SOFT} \ 19 | # --weak_model_name alexnet1 \ 20 | # --strong_model_name vits14_dino \ 21 | # --n_epochs 20 --lr 1e-4 22 | 23 | # 2-fold 
specialist 24 | python multi_weak_strong_oracle.py \ 25 | ${CKPTDIR}/0_500/head-${EPOCH}-${ITER}.pth.tar \ 26 | ${CKPTDIR}/500_1000/head-${EPOCH}-${ITER}.pth.tar \ 27 | --data_path ${DATADIR} \ 28 | --embed_path ${EMBEDDIR} \ 29 | --result_path ${RESULTDIR} \ 30 | --soft_teacher ${SOFT} \ 31 | --weak_model_name alexnet1 \ 32 | --strong_model_name vits14_dino \ 33 | --n_epochs 20 --lr 1e-4 34 | 35 | # 4-fold specialist 36 | python multi_weak_strong_oracle.py \ 37 | ${CKPTDIR}/0_250/head-${EPOCH}-${ITER}.pth.tar \ 38 | ${CKPTDIR}/250_500/head-${EPOCH}-${ITER}.pth.tar \ 39 | ${CKPTDIR}/500_750/head-${EPOCH}-${ITER}.pth.tar \ 40 | ${CKPTDIR}/750_1000/head-${EPOCH}-${ITER}.pth.tar \ 41 | --data_path ${DATADIR} \ 42 | --embed_path ${EMBEDDIR} \ 43 | --result_path ${RESULTDIR} \ 44 | --soft_teacher ${SOFT} \ 45 | --weak_model_name alexnet1 \ 46 | --strong_model_name vits14_dino \ 47 | --n_epochs 20 --lr 1e-4 48 | 49 | # 8-fold specialist 50 | python multi_weak_strong_oracle.py \ 51 | ${CKPTDIR}/0_125/head-${EPOCH}-${ITER}.pth.tar \ 52 | ${CKPTDIR}/125_250/head-${EPOCH}-${ITER}.pth.tar \ 53 | ${CKPTDIR}/250_375/head-${EPOCH}-${ITER}.pth.tar \ 54 | ${CKPTDIR}/375_500/head-${EPOCH}-${ITER}.pth.tar \ 55 | ${CKPTDIR}/500_625/head-${EPOCH}-${ITER}.pth.tar \ 56 | ${CKPTDIR}/625_750/head-${EPOCH}-${ITER}.pth.tar \ 57 | ${CKPTDIR}/750_875/head-${EPOCH}-${ITER}.pth.tar \ 58 | ${CKPTDIR}/875_1000/head-${EPOCH}-${ITER}.pth.tar \ 59 | --data_path ${DATADIR} \ 60 | --embed_path ${EMBEDDIR} \ 61 | --result_path ${RESULTDIR} \ 62 | --soft_teacher ${SOFT} \ 63 | --weak_model_name alexnet1 \ 64 | --strong_model_name vits14_dino \ 65 | --n_epochs 20 --lr 1e-4 66 | -------------------------------------------------------------------------------- /vision/script/prepare_imagenet.sh: -------------------------------------------------------------------------------- 1 | # directory config 2 | DATADIR=/storage/datasets/imagenet 3 | CKPTDIR=/storage/weak2strong/vision/ckpt/imagenet/alexnet1 4 | 5 | # 
1-fold 6 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 0 --end-imbal 1000 --savedir $CKPTDIR $DATADIR 7 | 8 | # 2-fold 9 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 0 --end-imbal 500 --savedir $CKPTDIR $DATADIR 10 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 500 --end-imbal 1000 --savedir $CKPTDIR $DATADIR 11 | 12 | # 4-fold 13 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 0 --end-imbal 250 --savedir $CKPTDIR $DATADIR 14 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 250 --end-imbal 500 --savedir $CKPTDIR $DATADIR 15 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 500 --end-imbal 750 --savedir $CKPTDIR $DATADIR 16 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 750 --end-imbal 1000 --savedir $CKPTDIR $DATADIR 17 | 18 | # 8-fold 19 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 0 --end-imbal 125 --savedir $CKPTDIR $DATADIR 20 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 125 --end-imbal 250 --savedir $CKPTDIR $DATADIR 21 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 250 --end-imbal 375 --savedir $CKPTDIR $DATADIR 22 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 375 --end-imbal 500 --savedir $CKPTDIR $DATADIR 23 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 500 --end-imbal 625 --savedir $CKPTDIR $DATADIR 24 | python prepare_imagenet.py -a alexnet --lr 1e-4 
--batch-size 256 --workers 12 --epochs 20 --start-imbal 625 --end-imbal 750 --savedir $CKPTDIR $DATADIR 25 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 750 --end-imbal 875 --savedir $CKPTDIR $DATADIR 26 | python prepare_imagenet.py -a alexnet --lr 1e-4 --batch-size 256 --workers 12 --epochs 20 --start-imbal 875 --end-imbal 1000 --savedir $CKPTDIR $DATADIR 27 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | # License for "Co-Supervised Learning" 2 | 3 | This repository includes software originally developed by OpenAI and Meta, and modifications made by Yuejiang Liu at EPFL. 4 | 5 | The MIT License below applies to both the original software and the made modifications. 6 | 7 | ------------------------------------------ 8 | 9 | Copyright 2023 OpenAI 10 | 11 | Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 16 | 17 | ------------------------------------------ 18 | 19 | Copyright (c) Meta Platforms, Inc. and affiliates. 20 | 21 | Permission is hereby granted, free of charge, to any person obtaining a copy 22 | of this software and associated documentation files (the "Software"), to deal 23 | in the Software without restriction, including without limitation the rights 24 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 25 | copies of the Software, and to permit persons to whom the Software is 26 | furnished to do so, subject to the following conditions: 27 | 28 | The above copyright notice and this permission notice shall be included in all 29 | copies or substantial portions of the Software. 30 | 31 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 32 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 33 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 34 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 35 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 36 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 37 | SOFTWARE. -------------------------------------------------------------------------------- /vision/download.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. 
All Rights Reserved 2 | 3 | from zipfile import ZipFile 4 | import argparse 5 | import tarfile 6 | import shutil 7 | import gdown 8 | import uuid 9 | import json 10 | import os 11 | import urllib 12 | 13 | 14 | # utils ####################################################################### 15 | 16 | def stage_path(data_dir, name): 17 | full_path = os.path.join(data_dir, name) 18 | 19 | if not os.path.exists(full_path): 20 | os.makedirs(full_path) 21 | 22 | return full_path 23 | 24 | 25 | def download_and_extract(url, dst, remove=True): 26 | gdown.download(url, dst, quiet=False) 27 | 28 | if dst.endswith(".tar.gz"): 29 | tar = tarfile.open(dst, "r:gz") 30 | tar.extractall(os.path.dirname(dst)) 31 | tar.close() 32 | 33 | if dst.endswith(".tar"): 34 | tar = tarfile.open(dst, "r:") 35 | tar.extractall(os.path.dirname(dst)) 36 | tar.close() 37 | 38 | if dst.endswith(".zip"): 39 | zf = ZipFile(dst, "r") 40 | zf.extractall(os.path.dirname(dst)) 41 | zf.close() 42 | 43 | if remove: 44 | os.remove(dst) 45 | 46 | 47 | # Office-Home ################################################################# 48 | 49 | def download_office_home(data_dir): 50 | # Original URL: http://hemanthdv.org/OfficeHome-Dataset/ 51 | full_path = stage_path(data_dir, "office_home") 52 | 53 | download_and_extract("https://drive.google.com/uc?id=1uY0pj7oFsjMxRwaD3Sxy0jgel0fsYXLC", 54 | os.path.join(data_dir, "office_home.zip")) 55 | 56 | os.rename(os.path.join(data_dir, "OfficeHomeDataset_10072016"), 57 | full_path) 58 | 59 | 60 | # DomainNET ################################################################### 61 | 62 | def download_domain_net(data_dir): 63 | # Original URL: http://ai.bu.edu/M3SDA/ 64 | full_path = stage_path(data_dir, "domain_net") 65 | 66 | urls = [ 67 | "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/clipart.zip", 68 | "http://csr.bu.edu/ftp/visda/2019/multi-source/infograph.zip", 69 | "http://csr.bu.edu/ftp/visda/2019/multi-source/groundtruth/painting.zip", 70 | 
"http://csr.bu.edu/ftp/visda/2019/multi-source/quickdraw.zip", 71 | "http://csr.bu.edu/ftp/visda/2019/multi-source/real.zip", 72 | "http://csr.bu.edu/ftp/visda/2019/multi-source/sketch.zip" 73 | ] 74 | 75 | for url in urls: 76 | download_and_extract(url, os.path.join(full_path, url.split("/")[-1])) 77 | 78 | with open("misc/domain_net_duplicates.txt", "r") as f: 79 | for line in f.readlines(): 80 | try: 81 | os.remove(os.path.join(full_path, line.strip())) 82 | except OSError: 83 | pass 84 | 85 | 86 | if __name__ == "__main__": 87 | parser = argparse.ArgumentParser(description='Download datasets') 88 | parser.add_argument('--data_dir', type=str, required=True) 89 | args = parser.parse_args() 90 | 91 | download_domain_net(args.data_dir) -------------------------------------------------------------------------------- /vision/helper.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import pandas as pd 4 | import pickle 5 | 6 | import pdb 7 | 8 | def load_embedding(fname): 9 | with open(fname, 'rb') as handle: 10 | gt_labels = pickle.load(handle) 11 | weak_labels = pickle.load(handle) 12 | weak_acc = pickle.load(handle) 13 | embeddings = pickle.load(handle) 14 | strong_gt_labels = pickle.load(handle) 15 | print('Load embeddings from', fname) 16 | return gt_labels, weak_labels, weak_acc, embeddings, strong_gt_labels 17 | 18 | 19 | def load_weak_domain_embedding(fname): 20 | with open(fname, 'rb') as handle: 21 | gt_labels = pickle.load(handle) 22 | weak_labels = pickle.load(handle) 23 | domain_labels = pickle.load(handle) 24 | weak_acc = pickle.load(handle) 25 | print('Load weak domain embeddings from', fname) 26 | return gt_labels, weak_labels, domain_labels, weak_acc 27 | 28 | 29 | def save_weak_domain_embedding(gt_labels, weak_labels, domain_labels, weak_acc, fname): 30 | with open(fname, 'wb') as handle: 31 | pickle.dump(gt_labels, handle, protocol=pickle.HIGHEST_PROTOCOL) 32 | 
pickle.dump(weak_labels, handle, protocol=pickle.HIGHEST_PROTOCOL) 33 | pickle.dump(domain_labels, handle, protocol=pickle.HIGHEST_PROTOCOL) 34 | pickle.dump(weak_acc, handle, protocol=pickle.HIGHEST_PROTOCOL) 35 | print('Save weak domain embeddings to', fname) 36 | 37 | 38 | def load_weak_embedding(fname): 39 | with open(fname, 'rb') as handle: 40 | gt_labels = pickle.load(handle) 41 | weak_labels = pickle.load(handle) 42 | weak_acc = pickle.load(handle) 43 | print('Load weak embeddings from', fname) 44 | return gt_labels, weak_labels, weak_acc 45 | 46 | 47 | def save_weak_embedding(gt_labels, weak_labels, weak_acc, fname): 48 | with open(fname, 'wb') as handle: 49 | pickle.dump(gt_labels, handle, protocol=pickle.HIGHEST_PROTOCOL) 50 | pickle.dump(weak_labels, handle, protocol=pickle.HIGHEST_PROTOCOL) 51 | pickle.dump(weak_acc, handle, protocol=pickle.HIGHEST_PROTOCOL) 52 | print('Save weak embeddings to', fname) 53 | 54 | 55 | def load_strong_embedding(fname): 56 | with open(fname, 'rb') as handle: 57 | gt_labels = pickle.load(handle) 58 | embeddings = pickle.load(handle) 59 | print('Load strong embeddings from', fname) 60 | return embeddings, gt_labels 61 | 62 | 63 | def save_strong_embedding(embeddings, gt_labels, fname): 64 | with open(fname, 'wb') as handle: 65 | pickle.dump(gt_labels, handle, protocol=pickle.HIGHEST_PROTOCOL) 66 | pickle.dump(embeddings, handle, protocol=pickle.HIGHEST_PROTOCOL) 67 | print('Save strong embeddings to', fname) 68 | 69 | 70 | def load_result(fname): 71 | df = pd.read_csv(fname, index_col=0) 72 | return df 73 | 74 | 75 | def print_param(model, key=None): 76 | for name, param in model.named_parameters(): 77 | if key is None: 78 | print(name, param.data) 79 | else: 80 | print(name, param.data[key]) 81 | -------------------------------------------------------------------------------- /vision/domainbed.py: -------------------------------------------------------------------------------- 1 | # Adapted from DomainBed 2 | 3 | import os 
4 | import torch 5 | from PIL import Image, ImageFile 6 | from torchvision import transforms 7 | import torchvision.datasets.folder 8 | from torch.utils.data import TensorDataset, Subset, ConcatDataset, Dataset 9 | from torchvision.datasets import MNIST, ImageFolder 10 | from torchvision.transforms.functional import rotate 11 | 12 | ImageFile.LOAD_TRUNCATED_IMAGES = True 13 | 14 | DATASETS = [ 15 | # Big images 16 | "VLCS", 17 | "PACS", 18 | "OfficeHome", 19 | "DomainNet", 20 | ] 21 | 22 | def get_dataset_class(dataset_name): 23 | """Return the dataset class with the given name.""" 24 | if dataset_name not in globals(): 25 | raise NotImplementedError("Dataset not found: {}".format(dataset_name)) 26 | return globals()[dataset_name] 27 | 28 | def num_environments(dataset_name): 29 | return len(get_dataset_class(dataset_name).ENVIRONMENTS) 30 | 31 | class MultipleDomainDataset: 32 | N_STEPS = 5001 # Default, subclasses may override 33 | CHECKPOINT_FREQ = 100 # Default, subclasses may override 34 | N_WORKERS = 8 # Default, subclasses may override 35 | ENVIRONMENTS = None # Subclasses should override 36 | INPUT_SHAPE = None # Subclasses should override 37 | 38 | def __getitem__(self, index): 39 | return self.dataset[index] 40 | 41 | def __len__(self): 42 | return len(self.dataset) 43 | 44 | class DomainLabelDataset(Dataset): 45 | """A wrapper for ImageFolder to include domain labels.""" 46 | def __init__(self, dataset, domain_label): 47 | self.dataset = dataset 48 | self.domain_label = domain_label 49 | 50 | def __len__(self): 51 | return len(self.dataset) 52 | 53 | def __getitem__(self, idx): 54 | image, label = self.dataset[idx] 55 | return image, label, self.domain_label 56 | 57 | class MultipleEnvironmentImageFolder(MultipleDomainDataset): 58 | def __init__(self, root, envs, split): 59 | super().__init__() 60 | environments = [f.name for f in os.scandir(root) if f.is_dir()] 61 | environments = sorted(environments) 62 | 63 | transform = transforms.Compose([ 64 | 
transforms.Resize((224,224)), 65 | transforms.ToTensor(), 66 | transforms.Normalize( 67 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 68 | ]) 69 | 70 | self.datasets = [] 71 | for i, environment in enumerate(environments): 72 | if i in envs: 73 | path = os.path.join(root, environment, split) 74 | env_dataset = ImageFolder(path, transform) 75 | domain_dataset = DomainLabelDataset(env_dataset, domain_label=i) 76 | self.datasets.append(domain_dataset) 77 | print("Loaded data from ", path) 78 | 79 | self.dataset = ConcatDataset(self.datasets) 80 | 81 | self.input_shape = (3, 224, 224,) 82 | self.num_classes = len(self.datasets[-1].dataset.classes) 83 | 84 | class DomainNet(MultipleEnvironmentImageFolder): 85 | CHECKPOINT_FREQ = 1000 86 | ENVIRONMENTS = ["clip", "info", "paint", "quick", "real", "sketch"] 87 | def __init__(self, root, envs, split): 88 | self.dir = os.path.join(root, "domainnet/") 89 | super().__init__(self.dir, envs, split) 90 | -------------------------------------------------------------------------------- /vision/models.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from torch import nn 4 | import torchvision 5 | 6 | import pdb 7 | 8 | class HeadAndEmbedding(nn.Module): 9 | def __init__(self, head): 10 | super(HeadAndEmbedding, self).__init__() 11 | self.head = head 12 | 13 | def forward(self, x): 14 | return x, self.head(x) 15 | 16 | 17 | def _alexnet_replace_fc(model): 18 | model.classifier = HeadAndEmbedding(model.classifier) 19 | return model 20 | 21 | 22 | def resnet50_dino(): 23 | model = torch.hub.load("facebookresearch/dino:main", "dino_resnet50") 24 | return model 25 | 26 | 27 | def vitb8_dino(): 28 | model = torch.hub.load("facebookresearch/dino:main", "dino_vitb8") 29 | return model 30 | 31 | 32 | def vits14_dino(): 33 | model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14_reg") 34 | return model 35 | 36 | 37 | def vitb14_dino(): 38 | model = 
 torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg") 39 | return model 40 | 41 | 42 | def alexnet(ckpt=None, num_outputs=1000): 43 | model = torchvision.models.alexnet(pretrained=True) 44 | if num_outputs != 1000: 45 | model.classifier[6] = torch.nn.Linear(in_features=model.classifier[6].in_features, out_features=num_outputs) 46 | if ckpt is not None and os.path.isfile(ckpt): # guard against the default ckpt=None 47 | print("=> loading checkpoint from '{}'".format(ckpt)) 48 | checkpoint = torch.load(ckpt, map_location="cpu") 49 | try: 50 | state_dict = checkpoint['state_dict'] 51 | # Remove all instances of 'module.' from the key 52 | basic_state_dict = {} 53 | for k, v in state_dict.items(): 54 | new_key = k.replace('module.', '') 55 | basic_state_dict[new_key] = v 56 | msg = model.load_state_dict(basic_state_dict, strict=True) 57 | except Exception: 58 | msg = model.load_state_dict(checkpoint, strict=False) 59 | print(msg) 60 | else: 61 | print("=> Load the official Pytorch pretrained checkpoint") 62 | return _alexnet_replace_fc(model) 63 | 64 | 65 | def alexnet1(ckpt=None, num_outputs=1000): 66 | model = torchvision.models.alexnet(pretrained=True) 67 | num_features = model.classifier[1].in_features 68 | model.classifier = nn.Sequential( 69 | nn.Linear(num_features, num_outputs) 70 | ) 71 | if ckpt is not None and os.path.isfile(ckpt): # guard against the default ckpt=None 72 | print("=> Load checkpoint from '{}'".format(ckpt)) 73 | checkpoint = torch.load(ckpt, map_location="cpu") 74 | msg = model.load_state_dict(checkpoint, strict=False) 75 | print(msg) 76 | else: 77 | print(f"=> Cannot find {ckpt}, load the official Pytorch pretrained checkpoint instead") 78 | return _alexnet_replace_fc(model) 79 | 80 | 81 | def probe(d=2048, n_classes=1000, n_hidden=0): 82 | if n_hidden == 0: 83 | model = nn.Linear(d, n_classes).cuda() 84 | elif n_hidden == 1: 85 | h = int(d/2) 86 | model = nn.Sequential( 87 | nn.Linear(d, h), 88 | nn.ReLU(inplace=True), 89 | nn.Linear(h, n_classes) 90 | ).cuda() 91 | elif n_hidden == 2: 92 | h = int(d/2) 93 | model = nn.Sequential( 94 | 
nn.Linear(d, h), 95 | nn.ReLU(inplace=True), 96 | nn.Linear(h, h), 97 | nn.ReLU(inplace=True), 98 | nn.Linear(h, n_classes) 99 | ).cuda() 100 | else: 101 | raise NotImplementedError 102 | return model 103 | -------------------------------------------------------------------------------- /vision/data.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torchvision 3 | from torch.utils.data import Subset, ConcatDataset 4 | from math import ceil 5 | import tqdm 6 | 7 | import warnings 8 | warnings.filterwarnings("ignore") 9 | 10 | from domainbed import DomainNet 11 | 12 | import pdb 13 | 14 | RESIZE, CROP = 256, 224 15 | TRANSFORM = torchvision.transforms.Compose( 16 | [ 17 | torchvision.transforms.Resize(RESIZE), 18 | torchvision.transforms.CenterCrop(CROP), 19 | torchvision.transforms.ToTensor(), 20 | torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), 21 | ] 22 | ) 23 | 24 | 25 | def get_imagenet(datapath, split, batch_size, shuffle, transform=TRANSFORM): 26 | ds = torchvision.datasets.ImageNet(root=datapath, split=split, transform=transform) 27 | loader = torch.utils.data.DataLoader(ds, shuffle=shuffle, batch_size=batch_size, num_workers=min(batch_size//16, 8)) # <-- add num_workers=min(batch_size//16, 8) 28 | return ds, loader 29 | 30 | 31 | def split_imbal_dataset(dataset, n_imbal=4, p_imbal=0.0, num_classes=1000): 32 | images_per_class = len(dataset) // num_classes 33 | subset_size = len(dataset) // (n_imbal+1) 34 | 35 | count_per_class_imbal = images_per_class * n_imbal // (n_imbal+1) 36 | count_per_class_last = images_per_class // (n_imbal+1) 37 | subsets = [] 38 | 39 | imbal_indices = [[] for _ in range(n_imbal)] 40 | for i in range(n_imbal): 41 | class_start = i * num_classes // n_imbal 42 | class_end = (i + 1) * num_classes // n_imbal 43 | 44 | for j in range(num_classes): 45 | if class_start <= j < class_end: 46 | start_idx = j * images_per_class 47 | 
imbal_indices[i].extend(range(start_idx, start_idx + count_per_class_imbal)) 48 | 49 | for i in range(n_imbal): 50 | indices = imbal_indices[i] 51 | # Ensure the subset size is exactly 1/5th of the dataset 52 | assert len(indices) == subset_size 53 | subsets.append(Subset(dataset, indices)) 54 | 55 | # Last subset with equal representation 56 | equal_indices = [] 57 | for j in range(num_classes): 58 | start_idx = j * images_per_class + count_per_class_imbal 59 | equal_indices.extend(range(start_idx, start_idx + count_per_class_last)) 60 | 61 | subsets.append(Subset(dataset, equal_indices)) 62 | 63 | return subsets 64 | 65 | 66 | def get_imbal_loader(datapath, split, n_train, batch_size, n_imbal, transform=TRANSFORM): 67 | dataset = torchvision.datasets.ImageNet(root=datapath, split=split, transform=transform) 68 | subsets = split_imbal_dataset(dataset, n_imbal) 69 | 70 | combined_indices = [] 71 | for i in range(n_imbal): 72 | combined_indices.extend(subsets[i].indices) 73 | original_dataset = subsets[0].dataset 74 | supset = Subset(original_dataset, combined_indices) 75 | loaders_tr = [torch.utils.data.DataLoader(supset, shuffle=True, batch_size=batch_size, num_workers=min(batch_size//16, 8))] 76 | 77 | loader_te = torch.utils.data.DataLoader(subsets[-1], shuffle=False, batch_size=batch_size, num_workers=min(batch_size//16, 8)) 78 | return loaders_tr, loader_te 79 | 80 | 81 | def get_imbal_sampler(dataset, start_class, end_class, ratio=0.0, num_classes=1000): 82 | # Assign weights to each class 83 | class_weights = [1.0 if start_class <= i < end_class else ratio for i in range(num_classes)] 84 | 85 | # Assign weights to each sample 86 | sample_weights = [class_weights[dataset.targets[i]] for i in range(len(dataset))] 87 | 88 | # Create a WeightedRandomSampler 89 | sampler = torch.utils.data.WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True) 90 | 91 | return sampler 92 | 93 | 94 | def get_domain_lables(loader): 95 | all_z = [] 
96 | 97 | for x, y, z in tqdm.tqdm(loader): 98 | all_z.append(z) 99 | 100 | all_z = torch.cat(all_z, axis=0) 101 | return all_z 102 | 103 | 104 | def get_domainnet(datapath, split, batch_size, shuffle, envs=[0, 1, 2, 3, 4, 5]): 105 | ds = DomainNet(datapath, envs, split) 106 | loader = torch.utils.data.DataLoader(ds, shuffle=shuffle, batch_size=batch_size, num_workers=min(batch_size//16, 8)) # <-- add num_workers=min(batch_size//16, 8) 107 | return ds, loader -------------------------------------------------------------------------------- /vision/multi_weak_strong_oracle.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import numpy as np 3 | import torch 4 | import tqdm 5 | import os 6 | import json 7 | 8 | from data import get_imagenet 9 | from models import alexnet, resnet50_dino, vitb8_dino, vits14_dino, vitb14_dino, probe 10 | from torch import nn 11 | from helper import save_embedding, save_result, load_weak_embedding, save_weak_embedding, load_strong_embedding, save_strong_embedding 12 | from single_weak_strong import get_model, get_embeddings, train_logreg 13 | 14 | torch.set_printoptions(precision=4, sci_mode=False) 15 | 16 | import pdb 17 | 18 | def main( 19 | *weak_path, 20 | soft_teacher: bool = True, 21 | batch_size: int = 128, 22 | weak_model_name: str = "alexnet", 23 | strong_model_name: str = "resnet50_dino", 24 | n_train: int = 40000, 25 | n_hidden: int = 0, 26 | seed: int = 0, 27 | data_path: str = "/root/", 28 | embed_path: str = "embedding/", 29 | result_path: str = "result/", 30 | ckpt_path: str = "ckpt/", 31 | save_every: int = 0, 32 | n_epochs: int = 10, 33 | lr: float = 1e-3, 34 | ): 35 | _, loader = get_imagenet(data_path, split="val", batch_size=batch_size, shuffle=False) 36 | 37 | num_teacher = len(weak_path) 38 | label_teachers = [] 39 | acc_teachers = [] 40 | for i in range(num_teacher): 41 | teacher_path = weak_path[i] 42 | weak_model = get_model(weak_model_name, teacher_path) 43 | 
44 | category = weak_path[i].split('/')[-2] 45 | stage = 'epoch_' + '_'.join(weak_path[i].split('/')[-1].split('.')[0].split('-')[1:]) 46 | fname = os.path.join(embed_path, weak_model_name, f'data_{weak_model_name}_{category}_{stage}.pkl') 47 | 48 | if os.path.exists(fname): 49 | try: 50 | gt_labels, weak_labels, weak_acc = load_weak_embedding(fname) 51 | except Exception as e: 52 | _, gt_labels, weak_labels, weak_acc = get_embeddings(weak_model, loader) 53 | save_weak_embedding(gt_labels, weak_labels, weak_acc, fname) 54 | else: 55 | _, gt_labels, weak_labels, weak_acc = get_embeddings(weak_model, loader) 56 | save_weak_embedding(gt_labels, weak_labels, weak_acc, fname) 57 | 58 | if not soft_teacher: 59 | weak_labels = nn.functional.one_hot(torch.argmax(weak_labels, dim=1), num_classes=1000).float() 60 | print('Convert teacher outputs to hard class labels') 61 | label_teachers.append(weak_labels) 62 | acc_teachers.append(weak_acc.item()) 63 | print(f"Weak teacher accuracy: {[acc for acc in acc_teachers]}") 64 | 65 | strong_model = get_model(strong_model_name) 66 | fname = os.path.join(embed_path, f'data_{strong_model_name}.pkl') 67 | if os.path.exists(fname): 68 | embeddings, strong_gt_labels = load_strong_embedding(fname) 69 | else: 70 | embeddings, strong_gt_labels, _, _ = get_embeddings(strong_model, loader) 71 | save_strong_embedding(embeddings, strong_gt_labels, fname) 72 | 73 | assert torch.all(gt_labels == strong_gt_labels) 74 | del strong_gt_labels 75 | 76 | if not os.path.exists(result_path): 77 | os.makedirs(result_path) 78 | type_teacher = 'soft' if soft_teacher else 'hard' 79 | prefix = os.path.join(result_path, f'result_{weak_model_name}_{strong_model_name}_{type_teacher}_{num_teacher}_{stage}_student_{n_hidden}_{lr:.6f}') 80 | 81 | order = np.arange(len(embeddings)) 82 | rng = np.random.default_rng(seed) 83 | rng.shuffle(order) 84 | x = embeddings[order] 85 | x_train, x_test = x[:n_train], x[n_train:] 86 | y = gt_labels[order] 87 | y_train, y_test = 
y[:n_train], y[n_train:] 88 | eval_datasets = {"test": (x_test, y_test)} 89 | print("# examples: ", x_train.shape[0], x_test.shape[0]) 90 | 91 | # multi teacher selectively (oracle) 92 | label_oracle = torch.mean(torch.stack(label_teachers), dim=0) 93 | for i in range(num_teacher): 94 | str_start_end = weak_path[i].split('/')[-2].split('_') 95 | teacher_start = int(str_start_end[0]) 96 | teacher_end = int(str_start_end[1]) 97 | idx_select = (teacher_start <= gt_labels) & (gt_labels < teacher_end) 98 | label_oracle[idx_select] = label_teachers[i][idx_select] 99 | acc_ora = (label_oracle.argmax(dim=1) == gt_labels).float().sum() / gt_labels.shape[0] 100 | yw = label_oracle[order] 101 | yw_train, yw_test = yw[:n_train], yw[n_train:] 102 | 103 | # all data 104 | print("Training logreg on oracle labels...") 105 | results_ora = train_logreg(x_train, yw_train, eval_datasets, n_epochs=n_epochs, lr=lr) 106 | print(f"Teacher accuracy: {acc_ora:.3f}") 107 | print(f"Student accuracy: {results_ora['test']:.3f}") 108 | print(f"Accuracy by epoch: {[acc.item() for acc in results_ora['test_all']]}") 109 | summary = { 110 | 'type': type_teacher, 111 | 'number': num_teacher, 112 | 'stage': stage, 113 | 'teacher': acc_ora.item(), 114 | 'student': results_ora['test'].item(), 115 | } 116 | with open(prefix + '_oracle.json', "w") as outfile: 117 | json.dump(summary, outfile) 118 | 119 | if __name__ == "__main__": 120 | fire.Fire(main) 121 | -------------------------------------------------------------------------------- /vision/run_analysis.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import json 4 | import numpy as np 5 | import pandas as pd 6 | import seaborn as sns 7 | import matplotlib 8 | matplotlib.use('Agg') 9 | import matplotlib.pyplot as plt 10 | 11 | import pdb 12 | 13 | sns.set_style('darkgrid') 14 | 15 | def extract_unique_parts(s, pattern): 16 | # Search for the pattern in the string 17 | match = 
re.search(pattern, s) 18 | 19 | # Return the matched groups if found 20 | if match: 21 | return match.groups() 22 | else: 23 | return None 24 | 25 | # Set the directory path 26 | def load_data(foldername): 27 | flist = os.listdir(foldername) 28 | 29 | # Initialize a list to store the data 30 | ret = dict() 31 | 32 | # Iterate over each file in the directory 33 | for filename in flist: 34 | if filename.endswith('.json'): 35 | fullname = os.path.join(foldername, filename) 36 | with open(fullname, 'r') as file: 37 | # Load the JSON data from the file and append it to the list 38 | data = json.load(file) 39 | ret[fullname] = data 40 | return ret 41 | 42 | def get_df_csl(data, pattern): 43 | summary = [] 44 | for key, val in data.items(): 45 | match = extract_unique_parts(key, pattern) 46 | if match: 47 | seed = int(key.split('_')[-2]) 48 | criterion = key.split('_')[-3] 49 | setting = val['stage'].split('_') 50 | epoch, iter = int(setting[1]), int(setting[2]) 51 | soft = 'Soft' if (val['type'] == 'soft') else 'Hard' 52 | number = val['number'] 53 | student = val['student'] 54 | teacher = val['teacher'] 55 | rate = val['rate'] 56 | precision = val['precision'] 57 | recall = val['recall'] 58 | for i in range(len(number)): 59 | summary.append([epoch, iter, soft, seed, int(number[i]), criterion, 60 | float(student[i]), float(teacher[i]), 61 | float(rate[i]), float(precision[i]), float(recall[i])]) 62 | df = pd.DataFrame(data = np.array(summary), columns= ["Epoch", "Iter", "Type", "Seed", "Number", "Criterion", 63 | "Student", "Teacher", "Rate", "Precision", "Recall"]) 64 | df['Epoch'] = df['Epoch'].astype(int) 65 | df['Iter'] = df['Iter'].astype(int) 66 | df['Seed'] = df['Seed'].astype(int) 67 | df['Number'] = df['Number'].astype(int) 68 | df['Student'] = df['Student'].astype(float) 69 | df['Teacher'] = df['Teacher'].astype(float) 70 | df['Rate'] = df['Rate'].astype(float) 71 | df['Precision'] = df['Precision'].astype(float) 72 | df['Recall'] = df['Recall'].astype(float) 
73 | df = df.sort_values(by=['Epoch', 'Iter', 'Type', 'Number', "Seed", "Criterion"], ascending=[True, True, True, True, True, True]) 74 | df.reset_index(drop=True, inplace=True) 75 | return df 76 | 77 | def main(df_epoch, df_iter, df_denoise): 78 | # load result 79 | data_all = load_data('result') 80 | pattern = r'_1_2_4_8_(.*?)_csl.json' 81 | df = get_df_csl(data_all, pattern) 82 | 83 | # compute Recovery 84 | ceil = 0.74 85 | df['Base'] = df['Rate'] * 0 86 | 87 | for epoch in df['Epoch'].unique(): 88 | for iter in df['Iter'].unique(): 89 | for type in df['Type'].unique(): 90 | for seed in df['Seed'].unique(): 91 | cond = (df['Epoch'] == epoch) & (df['Iter'] == iter) & (df['Type'] == type) & (df['Seed'] == seed) 92 | if cond.any(): 93 | df.loc[cond, 'Base'] = df.loc[cond & (df['Number'] == 1), 'Teacher'].iloc[0] 94 | df['Recovery'] = (df['Student'] - df['Base']) / (ceil - df['Base']) * 100 95 | 96 | # Dataframe condition 97 | dft = df[(df['Epoch'] == df_epoch) & (df['Iter'] == df_iter) & (df['Criterion'] == df_denoise)] 98 | 99 | # Calculate the mean and standard deviation for each group 100 | grouped = dft.groupby(['Number', 'Type'])['Recovery'] 101 | means = grouped.mean().reset_index(name='Recovery_mean') 102 | stds = grouped.std().reset_index(name='Recovery_std') 103 | 104 | # Merge the standard deviations with the means 105 | merged_df = pd.merge(means, stds, on=['Number', 'Type']) 106 | 107 | plt.figure(figsize=(5, 4)) 108 | # Use the combined column for hue 109 | bar_plot = sns.barplot(data=merged_df, x='Number', y='Recovery_mean', hue='Type', errorbar=None) 110 | 111 | # Iterate over the bars to add error bars 112 | for i, bar in enumerate(bar_plot.patches): 113 | # Calculate index for merged_df to get the std deviation 114 | # Considering multiple bars for each 'Number', adjust index calculation if necessary 115 | data_index = i % len(merged_df['Type'].unique()) 116 | std = merged_df.iloc[data_index]['Recovery_std'] # Adjust this line if necessary 117 | 
118 | # Position of the error bar 119 | x = bar.get_x() + bar.get_width() / 2 120 | y = bar.get_height() 121 | 122 | # Add error bar 123 | plt.errorbar(x, y, yerr=std, fmt='none', color='black', capsize=5) 124 | 125 | # Define new legend labels (ensure the order and number of labels match your plot) 126 | new_labels = ['Hard', 'Soft'] 127 | # Set the legend on the plot to use the new labels 128 | plt.legend(title=None, labels=new_labels, bbox_to_anchor=(0.5, 1.15), loc='upper center', ncol=2) # Adjust ncol as needed 129 | 130 | vmin = dft['Recovery'].min() * 0.95 131 | vmax = dft['Recovery'].max() * 1.05 132 | 133 | plt.xlabel('Number of Weak Supervisors') 134 | plt.ylabel('PGR (%)') 135 | plt.ylim([vmin, vmax]) # Adjust ylim according to your data 136 | plt.tight_layout() 137 | 138 | # Get the current figure using plt.gcf() after plotting 139 | fig = plt.gcf() 140 | fig.set_size_inches(5, 4) 141 | figname = 'figure/csl_imagenet.png' # Updated figname for clarity 142 | 143 | # Save the figure 144 | fig.savefig(figname, dpi=300, bbox_inches='tight', pad_inches=0.0) 145 | 146 | print(f'Save figure to {figname}') 147 | 148 | if __name__ == "__main__": 149 | iteration = 1000 150 | epoch = 0 151 | denoise = 'top3' 152 | main(epoch, iteration, denoise) -------------------------------------------------------------------------------- /vision/single_weak_strong.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import numpy as np 3 | import torch 4 | import tqdm 5 | import os 6 | import json 7 | 8 | from data import get_imagenet, get_domainnet 9 | from models import alexnet, resnet50_dino, vitb8_dino, vits14_dino, vitb14_dino, probe, alexnet1 10 | from torch import nn 11 | from helper import save_embedding, load_embedding, save_result, load_weak_embedding, load_weak_domain_embedding, save_weak_embedding, load_strong_embedding, save_strong_embedding 12 | 13 | import pdb 14 | 15 | 16 | def get_model(name, path=None, 
num_outputs=1000): 17 | if name == "alexnet": 18 | model = alexnet(path, num_outputs) 19 | elif name == "alexnet1": 20 | model = alexnet1(path, num_outputs) 21 | elif name == "resnet50_dino": 22 | model = resnet50_dino() 23 | elif name == "vitb8_dino": 24 | model = vitb8_dino() 25 | elif name == "vits14_dino": 26 | model = vits14_dino() 27 | elif name == "vitb14_dino": 28 | model = vitb14_dino() 29 | else: 30 | raise ValueError(f"Unknown model {name}") 31 | model.cuda() 32 | model.eval() 33 | model = nn.DataParallel(model) 34 | return model 35 | 36 | 37 | def get_embeddings(model, loader): 38 | all_embeddings, all_y, all_probs = [], [], [] 39 | 40 | for x, y in tqdm.tqdm(loader): 41 | output = model(x.cuda()) 42 | if len(output) == 2: 43 | embeddings, logits = output 44 | probs = torch.nn.functional.softmax(logits, dim=-1).detach().cpu() 45 | all_probs.append(probs) 46 | else: 47 | embeddings = output 48 | 49 | all_embeddings.append(embeddings.detach().cpu()) 50 | all_y.append(y) 51 | 52 | all_embeddings = torch.cat(all_embeddings, axis=0) 53 | all_y = torch.cat(all_y, axis=0) 54 | if len(all_probs) > 0: 55 | all_probs = torch.cat(all_probs, axis=0) 56 | acc = (torch.argmax(all_probs, dim=1) == all_y).float().mean() 57 | else: 58 | all_probs = None 59 | acc = None 60 | return all_embeddings, all_y, all_probs, acc 61 | 62 | 63 | def train_logreg( 64 | x_train, 65 | y_train, 66 | eval_datasets, 67 | n_epochs=10, 68 | weight_decay=0.0, 69 | lr=1.0e-3, 70 | batch_size=100, 71 | n_classes=1000, 72 | ckpt_path=None, 73 | save_every=0, 74 | lw=None, 75 | model=None, 76 | verbose=True, 77 | ): 78 | x_train = x_train.float() 79 | train_ds = torch.utils.data.TensorDataset(x_train, y_train) 80 | train_loader = torch.utils.data.DataLoader(train_ds, shuffle=True, batch_size=batch_size, num_workers=min(batch_size//16, 8)) # <-- add num_workers=min(batch_size//16, 8) 81 | 82 | d = x_train.shape[1] 83 | if model is None: 84 | model = probe(d, n_classes) 85 | if verbose: 
print('Initialize model') 86 | else: 87 | if verbose: print('Warm-start model') 88 | 89 | criterion = torch.nn.CrossEntropyLoss() 90 | optimizer = torch.optim.Adam(model.parameters(), weight_decay=weight_decay, lr=lr) 91 | n_batches = len(train_loader) 92 | n_iter = n_batches * n_epochs 93 | schedule = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_iter) 94 | 95 | results = {f"{key}_all": [] for key in eval_datasets.keys()} 96 | results["train_all"] = [] 97 | 98 | # if lw is not None: 99 | # nsample = 1000 100 | # ntotal = y_train.shape[0] 101 | # order = torch.argsort(lw) 102 | # x_sorted = x_train[order] 103 | # y_sorted = y_train[order] 104 | # x_train_clean = x_sorted[:nsample] 105 | # y_train_clean = y_sorted[:nsample].argmax(dim=1) 106 | # x_train_noise = x_sorted[-nsample:] 107 | # y_train_noise = y_sorted[-nsample:].argmax(dim=1) 108 | 109 | # half = int(nsample/2) 110 | # low = int(ntotal/10*4) 111 | # x_train_low = x_sorted[low-half:low+half] 112 | # y_train_low = y_sorted[low-half:low+half].argmax(dim=1) 113 | # high = int(ntotal/10*6) 114 | # x_train_high = x_sorted[high-half:high+half] 115 | # y_train_high = y_sorted[high-half:high+half].argmax(dim=1) 116 | 117 | # results["train_clean"] = [] 118 | # results["train_noise"] = [] 119 | # results["train_low"] = [] 120 | # results["train_high"] = [] 121 | 122 | if verbose: 123 | pbar = tqdm.tqdm(range(n_epochs), desc="Epoch 0") 124 | else: 125 | pbar = range(n_epochs) 126 | for epoch in pbar: 127 | correct, total = 0, 0 128 | for x, y in train_loader: 129 | x, y = x.cuda(), y.cuda() 130 | optimizer.zero_grad() 131 | pred = model(x) 132 | loss = criterion(pred, y) 133 | loss.backward() 134 | optimizer.step() 135 | schedule.step() 136 | if len(y.shape) > 1: 137 | y = torch.argmax(y, dim=1) 138 | correct += (torch.argmax(pred, -1) == y).detach().float().sum().item() 139 | total += len(y) 140 | if verbose: 141 | pbar.set_description(f"Epoch {epoch}, Train Acc {correct / total:.3f}") 
142 | results["train_all"].append(correct / total) 143 | 144 | for key, (x_test, y_test) in eval_datasets.items(): 145 | x_test = x_test.float().cuda() 146 | pred = torch.argmax(model(x_test), axis=-1).detach().cpu() 147 | acc = (pred == y_test).float().mean() 148 | results[f"{key}_all"].append(acc) 149 | 150 | # if lw is not None: 151 | # pred = torch.argmax(model(x_train_clean.float().cuda()), axis=-1).detach().cpu() 152 | # acc_clean = (pred == y_train_clean).float().mean() 153 | # results["train_clean"].append(acc_clean) 154 | 155 | # pred = torch.argmax(model(x_train_noise.float().cuda()), axis=-1).detach().cpu() 156 | # acc_noise = (pred == y_train_noise).float().mean() 157 | # results["train_noise"].append(acc_noise) 158 | 159 | # pred = torch.argmax(model(x_train_low.float().cuda()), axis=-1).detach().cpu() 160 | # acc_noise = (pred == y_train_low).float().mean() 161 | # results["train_low"].append(acc_noise) 162 | 163 | # pred = torch.argmax(model(x_train_high.float().cuda()), axis=-1).detach().cpu() 164 | # acc_noise = (pred == y_train_high).float().mean() 165 | # results["train_high"].append(acc_noise) 166 | 167 | for key in eval_datasets.keys(): 168 | results[key] = results[f"{key}_all"][-1] 169 | return results 170 | 171 | 172 | def main( 173 | batch_size: int = 128, 174 | soft_teacher: bool = True, 175 | weak_model_name: str = "alexnet", 176 | strong_model_name: str = "resnet50_dino", 177 | n_train: int = 40000, 178 | seed: int = 0, 179 | data_name: str = "imagenet", # [imagenet, domainnet] 180 | data_path: str = "/root/", 181 | embed_path: str = "embedding/", 182 | result_path: str = "result/", 183 | ckpt_path: str = "ckpt/", 184 | weak_path: str = "", 185 | save_every: int = 0, 186 | n_epochs: int = 10, 187 | lr: float = 1e-3, 188 | ): 189 | if data_name == "imagenet": 190 | _, loader = get_imagenet(data_path, split="val", batch_size=batch_size, shuffle=False) 191 | num_classes = 1000 192 | elif data_name == "domainnet": 193 | _, loader = 
get_domainnet(data_path, split="val", batch_size=batch_size, shuffle=False) 194 | num_classes = 345 195 | else: 196 | raise NotImplementedError 197 | 198 | weak_model = get_model(weak_model_name, weak_path, num_classes) 199 | category = weak_path.split('/')[-2] 200 | stage = 'epoch_' + '_'.join(weak_path.split('/')[-1].split('.')[0].split('-')[1:]) 201 | fname = os.path.join(embed_path, f'data_{weak_model_name}_{category}_{stage}.pkl') 202 | if os.path.exists(fname): 203 | if data_name == "domainnet": 204 | gt_labels, weak_labels, domain_labels, weak_acc = load_weak_domain_embedding(fname) 205 | else: 206 | gt_labels, weak_labels, weak_acc = load_weak_embedding(fname) 207 | else: 208 | print(f"Cannot load weak embeddings from {fname}, generating them from weak model") 209 | _, gt_labels, weak_labels, weak_acc = get_embeddings(weak_model, loader) 210 | save_weak_embedding(gt_labels, weak_labels, weak_acc, fname) 211 | if not soft_teacher: 212 | weak_labels = nn.functional.one_hot(torch.argmax(weak_labels, dim=1), num_classes=num_classes).float() 213 | print('Convert teacher outputs to hard class labels') 214 | print(f"Weak label accuracy: {weak_acc:.3f}") 215 | 216 | strong_model = get_model(strong_model_name, num_outputs=num_classes) 217 | fname = os.path.join(embed_path, f'data_{strong_model_name}.pkl') 218 | if os.path.exists(fname): 219 | embeddings, strong_gt_labels = load_strong_embedding(fname) 220 | else: 221 | embeddings, strong_gt_labels, _, _ = get_embeddings(strong_model, loader) 222 | save_strong_embedding(embeddings, strong_gt_labels, fname) 223 | 224 | assert torch.all(gt_labels == strong_gt_labels) 225 | del strong_gt_labels 226 | 227 | order = np.arange(len(embeddings)) 228 | rng = np.random.default_rng(seed) 229 | rng.shuffle(order) 230 | x = embeddings[order] 231 | y = gt_labels[order] 232 | yw = weak_labels[order] 233 | x_train, x_test = x[:n_train], x[n_train:] 234 | y_train, y_test = y[:n_train], y[n_train:] 235 | yw_train, yw_test = 
yw[:n_train], yw[n_train:] 236 | yw_test = torch.argmax(yw_test, dim=1) 237 | eval_datasets = {"test": (x_test, y_test), "test_weak": (x_test, yw_test)} 238 | print("# examples: ", x_train.shape[0], x_test.shape[0]) 239 | 240 | print("Training logreg on weak labels...") 241 | results_weak = train_logreg(x_train, yw_train, eval_datasets, n_epochs=n_epochs, lr=lr, n_classes=num_classes) 242 | print(f"Final accuracy: {results_weak['test']:.3f}") 243 | print(f"Final supervisor-student agreement: {results_weak['test_weak']:.3f}") 244 | print(f"Accuracy by epoch: {[acc.item() for acc in results_weak['test_all']]}") 245 | print( 246 | f"Supervisor-student agreement by epoch: {[acc.item() for acc in results_weak['test_weak_all']]}" 247 | ) 248 | 249 | print("Training logreg on ground truth labels...") 250 | results_gt = train_logreg(x_train, y_train, eval_datasets, n_epochs=n_epochs, lr=lr, n_classes=num_classes) 251 | print(f"Final accuracy: {results_gt['test']:.3f}") 252 | print(f"Accuracy by epoch: {[acc.item() for acc in results_gt['test_all']]}") 253 | 254 | print("\n\n" + "=" * 100) 255 | print(f"Weak label accuracy: {weak_acc:.3f}") 256 | print(f"Weak→Strong accuracy: {results_weak['test']:.3f}") 257 | print(f"Strong accuracy: {results_gt['test']:.3f}") 258 | print(f"Accuracy recovery: {(results_weak['test'] - weak_acc) / (results_gt['test'] - weak_acc):.3f}") 259 | print("=" * 100) 260 | 261 | type_teacher = 'soft' if soft_teacher else 'hard' 262 | fname = os.path.join(result_path, f'result_{weak_model_name}_{strong_model_name}_{type_teacher}_{stage}_student_{lr:.6f}.json') 263 | summary = { 264 | 'teacher': weak_acc.item(), 265 | 'baseline': results_weak['test'].item(), 266 | 'groundtruth': results_gt['test'].item(), 267 | } 268 | with open(fname, "w") as outfile: 269 | json.dump(summary, outfile) 270 | 271 | if __name__ == "__main__": 272 | fire.Fire(main) 273 | -------------------------------------------------------------------------------- 
/vision/multi_weak_strong_csl.py: -------------------------------------------------------------------------------- 1 | import fire 2 | import numpy as np 3 | import torch 4 | import tqdm 5 | import os 6 | import json 7 | 8 | from data import get_imagenet 9 | from models import alexnet, resnet50_dino, vitb8_dino, vits14_dino, vitb14_dino, probe 10 | from torch import nn 11 | from helper import save_embedding, save_result, load_weak_embedding, save_weak_embedding, load_strong_embedding, save_strong_embedding 12 | from single_weak_strong import get_model, get_embeddings, train_logreg 13 | 14 | torch.set_printoptions(precision=4, sci_mode=False) 15 | 16 | 17 | import pdb 18 | 19 | 20 | def get_conservative_estimate(y_next, y_prev, maxk=2): 21 | _, pred = y_prev.topk(maxk, 1, True, True) 22 | pred = pred.t() 23 | correct = pred.eq(y_next.argmax(dim=1).view(1, -1).expand_as(pred)) 24 | correct = correct.t() 25 | consensus = correct.any(dim=1) 26 | return consensus 27 | 28 | 29 | def get_oracle_rank(y_next, y_oracle): 30 | lw_train = nn.functional.cross_entropy(y_next.log(), y_oracle, reduction='none') 31 | lw_order = torch.argsort(lw_train) # ascending order 32 | lw_ranks = torch.argsort(lw_order) 33 | return lw_ranks 34 | 35 | 36 | def get_student_rank(logit_student, y_next): 37 | lw_train = nn.functional.cross_entropy(logit_student, y_next, reduction='none') 38 | lw_order = torch.argsort(lw_train) # ascending order 39 | lw_ranks = torch.argsort(lw_order) 40 | return lw_ranks 41 | 42 | 43 | def get_output(model, x_train, num_classes=1000, batch_size=1024): 44 | dataset = torch.utils.data.TensorDataset(x_train) 45 | data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False) 46 | 47 | num_sample = x_train.shape[0] 48 | output = torch.zeros((num_sample, num_classes)) 49 | 50 | with torch.no_grad(): 51 | for batch_idx, x_batch in enumerate(data_loader): 52 | x = x_batch[0].cuda() 53 | pred = model(x) 54 | output[batch_idx * batch_size: 
batch_idx * batch_size + x.size(0)] = pred.cpu() 55 | 56 | return output 57 | 58 | 59 | def get_precision_recall(pred, target): 60 | # Calculate TP, FP, FN 61 | TP = torch.sum((pred == 1) & (target == 1)) 62 | FP = torch.sum((pred == 1) & (target == 0)) 63 | FN = torch.sum((pred == 0) & (target == 1)) 64 | 65 | # Calculate Precision and Recall as Python floats (0.0 when the denominator is zero) 66 | precision = (TP.float() / (TP + FP)).item() if (TP + FP) > 0 else 0.0 67 | recall = (TP.float() / (TP + FN)).item() if (TP + FN) > 0 else 0.0 68 | 69 | return precision, recall 70 | 71 | 72 | def normalize_topk_confidence(pred_prev, pred_curr, k=5): 73 | topk_values, topk_indices = torch.topk(pred_curr, k, dim=1) 74 | 75 | logit_prev = torch.log(pred_prev) 76 | logit_curr = torch.log(pred_curr) 77 | normalized_logit_prev = torch.gather(logit_prev, 1, topk_indices) 78 | normalized_logit_curr = torch.gather(logit_curr, 1, topk_indices) 79 | 80 | normalized_confidence_prev = torch.nn.functional.softmax(normalized_logit_prev, dim=1) 81 | normalized_confidence_curr = torch.nn.functional.softmax(normalized_logit_curr, dim=1) 82 | 83 | return normalized_confidence_prev.max(dim=1)[0], normalized_confidence_curr.max(dim=1)[0] 84 | 85 | 86 | def get_consensus_rate(teacher_prev, teacher_curr, ground_truth=None): 87 | consensus_tt = (teacher_prev.argmax(dim=1) == teacher_curr.argmax(dim=1)) 88 | rate_top1_consensus = (consensus_tt).sum() / consensus_tt.shape[0] 89 | # print(f'teacher-teacher top1 consensus rate: {rate_top1_consensus:.2f}') 90 | 91 | consensus_top2 = get_conservative_estimate(teacher_curr, teacher_prev, 2) 92 | rate_top2_consensus = (consensus_top2).sum() / consensus_top2.shape[0] 93 | # print(f'teacher-teacher top2 consensus rate: {rate_top2_consensus:.2f}') 94 | 95 | consensus_top3 = get_conservative_estimate(teacher_curr, teacher_prev, 3) 96 | rate_top3_consensus = (consensus_top3).sum() / consensus_top3.shape[0] 97 | # print(f'teacher-teacher top3 consensus rate: {rate_top3_consensus:.2f}') 98 | 99 | if ground_truth is 
not None: 100 | consensus_prev = (teacher_prev.argmax(dim=1) == ground_truth) 101 | rate_prev = (consensus_prev).sum() / consensus_prev.shape[0] 102 | print(f'previous teacher accuracy: {rate_prev:.2f}') 103 | 104 | consensus_curr = (teacher_curr.argmax(dim=1) == ground_truth) 105 | rate_curr = (consensus_curr).sum() / consensus_curr.shape[0] 106 | print(f'current teacher accuracy: {rate_curr:.2f}') 107 | 108 | precision_prev, recall_prev = get_precision_recall(consensus_tt, consensus_prev) 109 | print(f'consensus for previous teacher: precision = {precision_prev:.2f}, recall = {recall_prev:.2f}') 110 | precision_curr, recall_curr = get_precision_recall(consensus_tt, consensus_curr) 111 | print(f'consensus for current teacher: precision = {precision_curr:.2f}, recall = {recall_curr:.2f}') 112 | 113 | # confidence consistent 114 | p_prev, y_prev = teacher_prev.max(dim=1) 115 | p_curr, y_curr = teacher_curr.max(dim=1) 116 | consistent_tt = consensus_tt & (p_curr >= p_prev) 117 | rate_tt_consistent = (consistent_tt).sum() / consistent_tt.shape[0] 118 | print(f'teacher-teacher consistent rate: {rate_tt_consistent:.2f}') 119 | 120 | precision_prev, recall_prev = get_precision_recall(consistent_tt, consensus_prev) 121 | print(f'consistent for previous teacher: precision = {precision_prev:.2f}, recall = {recall_prev:.2f}') 122 | precision_curr, recall_curr = get_precision_recall(consistent_tt, consensus_curr) 123 | print(f'consistent for current teacher: precision = {precision_curr:.2f}, recall = {recall_curr:.2f}') 124 | 125 | return rate_top1_consensus.item(), rate_top2_consensus.item(), rate_top3_consensus.item() 126 | 127 | 128 | def main( 129 | *weak_path, 130 | soft_teacher: bool = True, 131 | batch_size: int = 128, 132 | weak_model_name: str = "alexnet", 133 | strong_model_name: str = "resnet50_dino", 134 | denoise_criterion: str = "top3", 135 | n_train: int = 40000, 136 | seed: int = 0, 137 | data_path: str = "/root/", 138 | embed_path: str = "embedding/", 139 | 
result_path: str = "result/", 140 | ckpt_path: str = "ckpt/", 141 | save_every: int = 0, 142 | n_epochs: int = 10, 143 | lr: float = 1e-3, 144 | num_classes: int = 1000, 145 | ): 146 | label_layers = [] 147 | label_teachers = [] 148 | acc_teachers = [] 149 | for teacher_path in weak_path: 150 | category = teacher_path.split('/')[-2] 151 | stage = 'epoch_' + '_'.join(teacher_path.split('/')[-1].split('.')[0].split('-')[1:]) 152 | fname = os.path.join(embed_path, weak_model_name, f'data_{weak_model_name}_{category}_{stage}.pkl') 153 | 154 | str_start_end = category.split('_') 155 | teacher_start, teacher_end = int(str_start_end[0]), int(str_start_end[1]) 156 | 157 | if os.path.exists(fname): 158 | try: 159 | gt_labels, weak_labels, weak_acc = load_weak_embedding(fname) 160 | except Exception as e: 161 | weak_model = get_model(weak_model_name, teacher_path) 162 | _, loader = get_imagenet(data_path, split="val", batch_size=batch_size, shuffle=False) 163 | _, gt_labels, weak_labels, weak_acc = get_embeddings(weak_model, loader) 164 | save_weak_embedding(gt_labels, weak_labels, weak_acc, fname) 165 | else: 166 | weak_model = get_model(weak_model_name, teacher_path) 167 | _, loader = get_imagenet(data_path, split="val", batch_size=batch_size, shuffle=False) 168 | _, gt_labels, weak_labels, weak_acc = get_embeddings(weak_model, loader) 169 | save_weak_embedding(gt_labels, weak_labels, weak_acc, fname) 170 | 171 | if not soft_teacher: 172 | weak_labels = nn.functional.one_hot(torch.argmax(weak_labels, dim=1), num_classes=num_classes).float() 173 | print('Convert teacher outputs to hard class labels') 174 | 175 | label_teachers.append(weak_labels) 176 | acc_teachers.append(weak_acc.item()) 177 | 178 | if teacher_end == num_classes: 179 | label_layers.append(label_teachers) 180 | print(f"Weak teacher accuracy: {[acc for acc in acc_teachers]}") 181 | label_teachers = [] 182 | acc_teachers = [] 183 | 184 | num_teacher_all = [len(layer) for layer in label_layers] 185 | 186 | 
print('Teacher #:', num_teacher_all) 187 | 188 | strong_model = get_model(strong_model_name) 189 | fname = os.path.join(embed_path, f'data_{strong_model_name}.pkl') 190 | if os.path.exists(fname): 191 | embeddings, strong_gt_labels = load_strong_embedding(fname) 192 | else: 193 | _, loader = get_imagenet(data_path, split="val", batch_size=batch_size, shuffle=False) 194 | embeddings, strong_gt_labels, _, _ = get_embeddings(strong_model, loader) 195 | save_strong_embedding(embeddings, strong_gt_labels, fname) 196 | 197 | assert torch.all(gt_labels == strong_gt_labels) 198 | del strong_gt_labels 199 | 200 | if not os.path.exists(result_path): 201 | os.makedirs(result_path) 202 | type_teacher = 'soft' if soft_teacher else 'hard' 203 | num_teacher_str = '_'.join(map(str, num_teacher_all)) 204 | prefix = os.path.join(result_path, f'result_{weak_model_name}_{strong_model_name}_{type_teacher}_{num_teacher_str}_{stage}_student_{lr:.6f}_{denoise_criterion}_{seed}') 205 | 206 | order = np.arange(len(embeddings)) 207 | rng = np.random.default_rng(seed) 208 | rng.shuffle(order) 209 | x = embeddings[order] 210 | x_train, x_test = x[:n_train], x[n_train:] 211 | y = gt_labels[order] 212 | y_train, y_test = y[:n_train], y[n_train:] 213 | eval_datasets = {"test": (x_test, y_test)} 214 | print("# examples: ", x_train.shape[0], x_test.shape[0]) 215 | 216 | # main loop 217 | result_teacher = list() 218 | result_student = list() 219 | result_rate = list() 220 | result_precision = list() 221 | result_recall = list() 222 | yw_prev = None 223 | logit_prev = None 224 | 225 | for label_teachers in label_layers: 226 | 227 | # teacher assignment 228 | num_teacher = len(label_teachers) 229 | if num_teacher > 1: 230 | weak_stack = torch.stack(label_teachers, dim=1) 231 | gt_assign = torch.zeros((weak_stack.shape[0], num_teacher)) 232 | 233 | # class mapping 234 | mapping_class_teacher = torch.zeros((num_classes, num_teacher)) 235 | num_classes_per_teacher = num_classes // num_teacher 236 | for i 
in range(num_teacher): 237 | teacher_start = i * num_classes_per_teacher 238 | teacher_end = (i+1) * num_classes_per_teacher 239 | mapping_class_teacher[teacher_start:teacher_end, i] = 1.0 240 | idx_select = (teacher_start <= gt_labels) & (gt_labels < teacher_end) 241 | gt_assign[idx_select, i] = 1.0 242 | gt_labels_full = nn.functional.one_hot(gt_labels, num_classes=num_classes).float() 243 | gt_assign_full = torch.matmul(gt_labels_full, mapping_class_teacher) 244 | 245 | yws = weak_stack[order] 246 | zw = gt_assign_full[order] 247 | 248 | yws_train, zw_train = yws[:n_train], zw[:n_train] 249 | 250 | pred_output = logit_prev # assignment by previous student 251 | # pred_output = yw_prev # assignment by previous teacher 252 | hard_assign = torch.matmul(nn.functional.one_hot(pred_output.argmax(dim=1), num_classes=num_classes).float(), mapping_class_teacher) 253 | acc_hard = (hard_assign.argmax(dim=1) == zw_train.argmax(dim=1)).sum() / hard_assign.shape[0] 254 | yw_train = (yws_train * hard_assign.unsqueeze(2)).sum(dim=1) 255 | 256 | # oracle assignment 257 | # yw_train = (yws_train * zw_train.unsqueeze(2)).sum(dim=1) 258 | 259 | else: 260 | yw = label_teachers[0][order] 261 | yw_train = yw[:n_train] 262 | 263 | acc_teacher = ((yw_train.argmax(dim=1) == y_train).sum() / n_train).item() 264 | result_teacher.append(acc_teacher) 265 | print(f"teacher x{num_teacher}: collective accuracy = {acc_teacher:.3f}") 266 | 267 | # Sample selection 268 | if num_teacher > 1 and denoise_criterion != 'all': 269 | rate_top1, rate_top2, rate_top3 = get_consensus_rate(torch.nn.functional.softmax(logit_prev, dim=1), yw_train) 270 | if denoise_criterion == 'top1': 271 | rate_keep = rate_top1 272 | elif denoise_criterion == 'top2': 273 | rate_keep = rate_top2 274 | elif denoise_criterion == 'top3': 275 | rate_keep = rate_top3 276 | elif denoise_criterion == 'oracle': 277 | raise NotImplementedError 278 | else: 279 | raise NotImplementedError 280 | else: 281 | rate_keep = 1.0 282 | 283 | if rate_keep < 
1.0: 284 | rank_train = get_student_rank(logit_prev, yw_train.argmax(dim=1)) 285 | keep_train = rank_train < (n_train * rate_keep) 286 | else: 287 | keep_train = torch.ones(n_train, dtype=torch.bool) 288 | 289 | correct_train = (yw_train.argmax(dim=1) == y_train) 290 | precision, recall = get_precision_recall(keep_train, correct_train) 291 | print(f"Select {rate_keep*100:.1f}% samples: precision = {precision:.3f}, recall = {recall:.3f}") 292 | result_rate.append(rate_keep) 293 | result_precision.append(precision) 294 | result_recall.append(recall) 295 | 296 | # Train 297 | print("Training logreg on the selected weak labels...") 298 | x_conserve, yw_conserve = x_train[keep_train], yw_train[keep_train] 299 | epochs_conserve = int(n_epochs / rate_keep) # keep the same number of iterations 300 | model = probe(x_train.shape[1], num_classes) 301 | results_weak = train_logreg(x_conserve, yw_conserve, eval_datasets, n_epochs=epochs_conserve, lr=lr, ckpt_path=ckpt_path, save_every=save_every, model=model) 302 | print(f"Final accuracy: {results_weak['test']:.3f}") 303 | result_student.append(max(results_weak['test_all']).item()) 304 | 305 | # Update 306 | yw_prev = yw_train 307 | logit_prev = get_output(model, x_train, num_classes=num_classes) 308 | 309 | summary = { 310 | 'type': type_teacher, 311 | 'number': num_teacher_all, 312 | 'stage': stage, 313 | 'rate': result_rate, 314 | 'precision': result_precision, 315 | 'recall': result_recall, 316 | 'teacher': result_teacher, 317 | 'student': result_student, 318 | } 319 | with open(prefix + '_csl.json', "w") as outfile: 320 | json.dump(summary, outfile) 321 | 322 | print('Teacher accuracy: \n', result_teacher) 323 | print('Student accuracy: \n', result_student) 324 | 325 | 326 | if __name__ == "__main__": 327 | fire.Fire(main) 328 | -------------------------------------------------------------------------------- /vision/prep_imagenet.py: -------------------------------------------------------------------------------- 1 | # Adapted from 
https://github.com/facebookresearch/moco/blob/main/main_lincls.py 2 | 3 | import argparse 4 | import builtins 5 | import os 6 | import random 7 | import shutil 8 | import time 9 | import warnings 10 | 11 | import torch 12 | import torch.backends.cudnn as cudnn 13 | import torch.distributed as dist 14 | import torch.multiprocessing as mp 15 | import torch.nn as nn 16 | import torch.nn.parallel 17 | import torch.optim 18 | import torch.utils.data 19 | import torch.utils.data.distributed 20 | import torchvision.datasets as datasets 21 | import torchvision.models as models 22 | import torchvision.transforms as transforms 23 | 24 | import pdb 25 | from helper import print_param 26 | from data import get_imbal_sampler 27 | 28 | model_names = sorted( 29 | name 30 | for name in models.__dict__ 31 | if name.islower() and not name.startswith("__") and callable(models.__dict__[name]) 32 | ) 33 | 34 | parser = argparse.ArgumentParser(description="PyTorch ImageNet Training") 35 | parser.add_argument("data", metavar="DIR", help="path to dataset") 36 | parser.add_argument( 37 | "-a", 38 | "--arch", 39 | metavar="ARCH", 40 | default="resnet50", 41 | choices=model_names, 42 | help="model architecture: " + " | ".join(model_names) + " (default: resnet50)", 43 | ) 44 | parser.add_argument( 45 | "-j", 46 | "--workers", 47 | default=32, 48 | type=int, 49 | metavar="N", 50 | help="number of data loading workers (default: 32)", 51 | ) 52 | parser.add_argument( 53 | "--epochs", default=100, type=int, metavar="N", help="number of total epochs to run" 54 | ) 55 | parser.add_argument( 56 | "--start-epoch", 57 | default=0, 58 | type=int, 59 | metavar="N", 60 | help="manual epoch number (useful on restarts)", 61 | ) 62 | parser.add_argument( 63 | "-b", 64 | "--batch-size", 65 | default=256, 66 | type=int, 67 | metavar="N", 68 | help="mini-batch size (default: 256), this is the total " 69 | "batch size of all GPUs on the current node when " 70 | "using Data Parallel or Distributed Data 
Parallel", 71 | ) 72 | parser.add_argument( 73 | "--lr", 74 | "--learning-rate", 75 | default=30.0, 76 | type=float, 77 | metavar="LR", 78 | help="initial learning rate", 79 | dest="lr", 80 | ) 81 | parser.add_argument( 82 | "--schedule", 83 | default=[60, 80], 84 | nargs="*", 85 | type=int, 86 | help="learning rate schedule (when to drop lr by a ratio)", 87 | ) 88 | parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum") 89 | parser.add_argument( 90 | "--wd", 91 | "--weight-decay", 92 | default=0.0, 93 | type=float, 94 | metavar="W", 95 | help="weight decay (default: 0.)", 96 | dest="weight_decay", 97 | ) 98 | parser.add_argument( 99 | "-p", 100 | "--print-freq", 101 | default=10, 102 | type=int, 103 | metavar="N", 104 | help="print frequency (default: 10)", 105 | ) 106 | parser.add_argument( 107 | "--resume", 108 | default="", 109 | type=str, 110 | metavar="PATH", 111 | help="path to latest checkpoint (default: none)", 112 | ) 113 | parser.add_argument( 114 | "-e", 115 | "--evaluate", 116 | dest="evaluate", 117 | action="store_true", 118 | help="evaluate model on validation set", 119 | ) 120 | parser.add_argument( 121 | "--world-size", 122 | default=-1, 123 | type=int, 124 | help="number of nodes for distributed training", 125 | ) 126 | parser.add_argument( 127 | "--rank", default=-1, type=int, help="node rank for distributed training" 128 | ) 129 | parser.add_argument( 130 | "--dist-url", 131 | default="tcp://224.66.41.62:23456", 132 | type=str, 133 | help="url used to set up distributed training", 134 | ) 135 | parser.add_argument( 136 | "--dist-backend", default="nccl", type=str, help="distributed backend" 137 | ) 138 | parser.add_argument( 139 | "--seed", default=None, type=int, help="seed for initializing training. 
" 140 | ) 141 | parser.add_argument("--gpu", default=None, type=int, help="GPU id to use.") 142 | parser.add_argument( 143 | "--multiprocessing-distributed", 144 | action="store_true", 145 | help="Use multi-processing distributed training to launch " 146 | "N processes per node, which has N GPUs. This is the " 147 | "fastest way to use PyTorch for either single node or " 148 | "multi node data parallel training", 149 | ) 150 | parser.add_argument( 151 | "--pretrained", default="", type=str, help="path to pretrained checkpoint" 152 | ) 153 | parser.add_argument( 154 | "--start-imbal", 155 | default=0, 156 | type=int, 157 | metavar="N", 158 | ) 159 | parser.add_argument( 160 | "--end-imbal", 161 | default=1000, 162 | type=int, 163 | metavar="N", 164 | ) 165 | parser.add_argument( 166 | "--savedir", 167 | default='/ckpt/path/', 168 | type=str, 169 | metavar="PATH", 170 | help="path to save checkpoint", 171 | ) 172 | 173 | best_acc1 = 0 174 | 175 | 176 | def main(): 177 | args = parser.parse_args() 178 | 179 | if args.seed is not None: 180 | random.seed(args.seed) 181 | torch.manual_seed(args.seed) 182 | cudnn.deterministic = True 183 | warnings.warn( 184 | "You have chosen to seed training. " 185 | "This will turn on the CUDNN deterministic setting, " 186 | "which can slow down your training considerably! " 187 | "You may see unexpected behavior when restarting " 188 | "from checkpoints." 189 | ) 190 | 191 | if args.gpu is not None: 192 | warnings.warn( 193 | "You have chosen a specific GPU. This will completely " 194 | "disable data parallelism." 
195 | ) 196 | 197 | if args.dist_url == "env://" and args.world_size == -1: 198 | args.world_size = int(os.environ["WORLD_SIZE"]) 199 | 200 | args.distributed = args.world_size > 1 or args.multiprocessing_distributed 201 | 202 | ngpus_per_node = torch.cuda.device_count() 203 | if args.multiprocessing_distributed: 204 | # Since we have ngpus_per_node processes per node, the total world_size 205 | # needs to be adjusted accordingly 206 | args.world_size = ngpus_per_node * args.world_size 207 | # Use torch.multiprocessing.spawn to launch distributed processes: the 208 | # main_worker process function 209 | mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args)) 210 | else: 211 | # Simply call main_worker function 212 | main_worker(args.gpu, ngpus_per_node, args) 213 | 214 | 215 | def main_worker(gpu, ngpus_per_node, args): 216 | global best_acc1 217 | args.gpu = gpu 218 | 219 | # suppress printing if not master 220 | if args.multiprocessing_distributed and args.gpu != 0: 221 | 222 | def print_pass(*args): 223 | pass 224 | 225 | builtins.print = print_pass 226 | 227 | if args.gpu is not None: 228 | print("Use GPU: {} for training".format(args.gpu)) 229 | 230 | if args.distributed: 231 | if args.dist_url == "env://" and args.rank == -1: 232 | args.rank = int(os.environ["RANK"]) 233 | if args.multiprocessing_distributed: 234 | # For multiprocessing distributed training, rank needs to be the 235 | # global rank among all the processes 236 | args.rank = args.rank * ngpus_per_node + gpu 237 | dist.init_process_group( 238 | backend=args.dist_backend, 239 | init_method=args.dist_url, 240 | world_size=args.world_size, 241 | rank=args.rank, 242 | ) 243 | # create model 244 | print("=> creating model '{}'".format(args.arch)) 245 | 246 | if args.arch.startswith('vit'): 247 | model = vits.__dict__[args.arch](pretrained=True) 248 | linear_keyword = 'head' 249 | elif args.arch.startswith('alexnet'): 250 | model = models.__dict__['alexnet'](pretrained=True) 251 | 
linear_keyword = 'classifier' 252 | else: 253 | model = models.__dict__[args.arch](pretrained=True) 254 | linear_keyword = 'fc' 255 | print("=> loading default '{}'".format(args.arch)) 256 | 257 | # Warning: simplify alexnet for weak supervisor 258 | if args.arch.startswith('alexnet'): 259 | num_features = model.classifier[1].in_features 260 | num_outputs = model.classifier[6].out_features 261 | model.classifier = nn.Sequential( 262 | nn.Linear(num_features, num_outputs) 263 | ) 264 | print('Replace the classifier head with a new single layer') 265 | 266 | if args.pretrained: 267 | if os.path.isfile(args.pretrained): 268 | print("=> loading checkpoint from '{}'".format(args.pretrained)) 269 | checkpoint = torch.load(args.pretrained, map_location="cpu") 270 | state_dict = checkpoint['state_dict'] 271 | 272 | # Remove all instances of 'module.' from the key 273 | basic_state_dict = {} 274 | for k, v in state_dict.items(): 275 | new_key = k.replace('module.', '') 276 | basic_state_dict[new_key] = v 277 | 278 | msg = model.load_state_dict(basic_state_dict, strict=True) 279 | print(msg) 280 | else: 281 | print("=> no checkpoint found at '{}'".format(args.pretrained)) 282 | 283 | # freeze all layers but the last fc 284 | for name, param in model.named_parameters(): 285 | # if name not in ['%s.weight' % linear_keyword, '%s.bias' % linear_keyword]: 286 | if not name.startswith(linear_keyword): 287 | param.requires_grad = False 288 | else: 289 | print('Unfrozen', name) 290 | if name.split('.')[-1] == 'weight': 291 | param.data.normal_(mean=0.0, std=0.01) 292 | print(f'Init {name} from the normal distribution') 293 | if name.split('.')[-1] == 'bias': 294 | param.data.zero_() 295 | print(f'Init {name} to zero') 296 | 297 | if args.distributed: 298 | # For multiprocessing distributed, DistributedDataParallel constructor 299 | # should always set the single device scope, otherwise, 300 | # DistributedDataParallel will use all available devices. 
301 | if args.gpu is not None: 302 | torch.cuda.set_device(args.gpu) 303 | model.cuda(args.gpu) 304 | # When using a single GPU per process and per 305 | # DistributedDataParallel, we need to divide the batch size 306 | # ourselves based on the total number of GPUs we have 307 | args.batch_size = int(args.batch_size / ngpus_per_node) 308 | args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node) 309 | model = torch.nn.parallel.DistributedDataParallel( 310 | model, device_ids=[args.gpu] 311 | ) 312 | else: 313 | model.cuda() 314 | # DistributedDataParallel will divide and allocate batch_size to all 315 | # available GPUs if device_ids are not set 316 | model = torch.nn.parallel.DistributedDataParallel(model) 317 | elif args.gpu is not None: 318 | torch.cuda.set_device(args.gpu) 319 | model = model.cuda(args.gpu) 320 | else: 321 | # DataParallel will divide and allocate batch_size to all available GPUs 322 | if args.arch.startswith("alexnet") or args.arch.startswith("vgg"): 323 | model.features = torch.nn.DataParallel(model.features) 324 | model.cuda() 325 | else: 326 | model = torch.nn.DataParallel(model).cuda() 327 | 328 | # define loss function (criterion) and optimizer 329 | criterion = nn.CrossEntropyLoss().cuda(args.gpu) 330 | 331 | # optimize only the linear classifier 332 | parameters = list(filter(lambda p: p.requires_grad, model.parameters())) 333 | 334 | # if args.arch.startswith('alexnet'): 335 | # assert len(parameters) == 6 # fc.weight, fc.bias 336 | # else: 337 | # assert len(parameters) == 2 # fc.weight, fc.bias 338 | 339 | # optimizer = torch.optim.SGD( 340 | # parameters, args.lr, momentum=args.momentum, weight_decay=args.weight_decay 341 | # ) 342 | 343 | # optionally resume from a checkpoint 344 | if args.resume: 345 | if os.path.isfile(args.resume): 346 | print("=> loading checkpoint '{}'".format(args.resume)) 347 | if args.gpu is None: 348 | checkpoint = torch.load(args.resume) 349 | else: 350 | # Map model to be loaded to 
specified single gpu. 351 | loc = "cuda:{}".format(args.gpu) 352 | checkpoint = torch.load(args.resume, map_location=loc) 353 | args.start_epoch = checkpoint["epoch"] 354 | best_acc1 = checkpoint["best_acc1"] 355 | if args.gpu is not None: 356 | # best_acc1 may be from a checkpoint from a different GPU 357 | best_acc1 = best_acc1.to(args.gpu) 358 | model.load_state_dict(checkpoint["state_dict"]) 359 | optimizer.load_state_dict(checkpoint["optimizer"]) 360 | print( 361 | "=> loaded checkpoint '{}' (epoch {})".format( 362 | args.resume, checkpoint["epoch"] 363 | ) 364 | ) 365 | else: 366 | print("=> no checkpoint found at '{}'".format(args.resume)) 367 | 368 | cudnn.benchmark = True 369 | 370 | # Data loading code 371 | traindir = os.path.join(args.data, "train") 372 | valdir = os.path.join(args.data, "val") 373 | normalize = transforms.Normalize( 374 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 375 | ) 376 | 377 | train_dataset = datasets.ImageFolder( 378 | traindir, 379 | # valdir, 380 | transforms.Compose( 381 | [ 382 | transforms.RandomResizedCrop(224), 383 | transforms.RandomHorizontalFlip(), 384 | transforms.ToTensor(), 385 | normalize, 386 | ] 387 | ), 388 | ) 389 | 390 | # if args.distributed: 391 | # train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) 392 | # else: 393 | # train_sampler = None 394 | 395 | train_sampler = get_imbal_sampler(train_dataset, start_class=args.start_imbal, end_class=args.end_imbal) 396 | # train_sampler = None 397 | 398 | train_loader = torch.utils.data.DataLoader( 399 | train_dataset, 400 | batch_size=args.batch_size, 401 | shuffle=(train_sampler is None), 402 | num_workers=args.workers, 403 | pin_memory=True, 404 | sampler=train_sampler, 405 | ) 406 | 407 | val_loader = torch.utils.data.DataLoader( 408 | datasets.ImageFolder( 409 | valdir, 410 | transforms.Compose( 411 | [ 412 | transforms.Resize(256), 413 | transforms.CenterCrop(224), 414 | transforms.ToTensor(), 415 | normalize, 416 | ] 417 
| ), 418 | ), 419 | batch_size=args.batch_size, 420 | shuffle=False, 421 | num_workers=args.workers, 422 | pin_memory=True, 423 | ) 424 | 425 | if args.evaluate: 426 | validate(val_loader, model, criterion, args) 427 | return 428 | 429 | 430 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr) 431 | n_iter = len(train_loader) * (args.epochs - args.start_epoch) 432 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=n_iter) 433 | 434 | for epoch in range(args.start_epoch, args.epochs): 435 | if args.distributed: 436 | train_sampler.set_epoch(epoch) 437 | adjust_learning_rate(optimizer, epoch, args) 438 | 439 | # train for one epoch 440 | train(train_loader, model, criterion, optimizer, scheduler, epoch, args) 441 | 442 | # evaluate on validation set 443 | acc1 = validate(val_loader, model, criterion, args) 444 | 445 | # remember best acc@1 and save checkpoint 446 | is_best = acc1 > best_acc1 447 | best_acc1 = max(acc1, best_acc1) 448 | 449 | # if not args.multiprocessing_distributed or ( 450 | # args.multiprocessing_distributed and args.rank % ngpus_per_node == 0 451 | # ): 452 | # # foldername = f'ckpt/train/{args.start_imbal}_{args.end_imbal}' 453 | # foldername = f'ckpt/val/{args.start_imbal}_{args.end_imbal}' 454 | # if not os.path.exists(foldername): 455 | # os.makedirs(foldername) 456 | # save_checkpoint( 457 | # { 458 | # "epoch": epoch + 1, 459 | # "arch": args.arch, 460 | # "state_dict": model.state_dict(), 461 | # "best_acc1": best_acc1, 462 | # "optimizer": optimizer.state_dict(), 463 | # }, 464 | # fname=os.path.join(foldername,f'ckpt-{epoch}.pth.tar'), 465 | # ) 466 | # if epoch == args.start_epoch: 467 | # sanity_check(model.state_dict(), args.pretrained) 468 | 469 | 470 | def train(train_loader, model, criterion, optimizer, scheduler, epoch, args): 471 | batch_time = AverageMeter("Time", ":6.3f") 472 | data_time = AverageMeter("Data", ":6.3f") 473 | losses = AverageMeter("Loss", ":.4e") 474 | top1 = 
AverageMeter("Acc@1", ":6.2f") 475 | top5 = AverageMeter("Acc@5", ":6.2f") 476 | progress = ProgressMeter( 477 | len(train_loader), 478 | [batch_time, data_time, losses, top1, top5], 479 | prefix="Epoch: [{}]".format(epoch), 480 | ) 481 | 482 | trainable_param_names = {name for name, param in model.named_parameters() if param.requires_grad} 483 | 484 | """ 485 | Switch to eval mode: 486 | Under the protocol of linear classification on frozen features/models, 487 | it is not legitimate to change any part of the pre-trained model. 488 | BatchNorm in train mode may revise running mean/std (even if it receives 489 | no gradient), which are part of the model parameters too. 490 | """ 491 | model.eval() 492 | 493 | end = time.time() 494 | for i, (images, target) in enumerate(train_loader): 495 | # measure data loading time 496 | data_time.update(time.time() - end) 497 | 498 | if args.gpu is not None: 499 | images = images.cuda(args.gpu, non_blocking=True) 500 | target = target.cuda(args.gpu, non_blocking=True) 501 | 502 | # compute output 503 | output = model(images) 504 | loss = criterion(output, target) 505 | 506 | # measure accuracy and record loss 507 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 508 | losses.update(loss.item(), images.size(0)) 509 | top1.update(acc1[0], images.size(0)) 510 | top5.update(acc5[0], images.size(0)) 511 | 512 | # compute gradient, then step the optimizer and the LR scheduler 513 | optimizer.zero_grad() 514 | loss.backward() 515 | optimizer.step() 516 | scheduler.step() 517 | 518 | # measure elapsed time 519 | batch_time.update(time.time() - end) 520 | end = time.time() 521 | 522 | if i % args.print_freq == 0: 523 | progress.display(i) 524 | 525 | if (i % (args.print_freq * 1000) == 0) or (epoch < 1 and i < 1000 and i % (args.print_freq * 10) == 0) or (epoch < 1 and i % (args.print_freq * 100) == 0): 526 | foldername = args.savedir + f'/{args.start_imbal}_{args.end_imbal}' 527 | if not os.path.exists(foldername): 528 | os.makedirs(foldername) 529 | 530 | # Get 
names of trainable parameters 531 | trainable_state_dict = {name: param for name, param in model.state_dict().items() if name in trainable_param_names} 532 | # pdb.set_trace() 533 | ckptname = os.path.join(foldername, f'head-{epoch}-{i}.pth.tar') 534 | torch.save(trainable_state_dict, ckptname) 535 | # save_checkpoint( 536 | # { 537 | # "epoch": epoch + 1, 538 | # "arch": args.arch, 539 | # "state_dict": model.state_dict(), 540 | # "best_acc1": best_acc1, 541 | # "optimizer": optimizer.state_dict(), 542 | # }, 543 | # fname = os.path.join(foldername, f'ckpt-{epoch}-{i}.pth.tar'), 544 | # ) 545 | 546 | def validate(val_loader, model, criterion, args): 547 | batch_time = AverageMeter("Time", ":6.3f") 548 | losses = AverageMeter("Loss", ":.4e") 549 | top1 = AverageMeter("Acc@1", ":6.2f") 550 | top5 = AverageMeter("Acc@5", ":6.2f") 551 | progress = ProgressMeter( 552 | len(val_loader), [batch_time, losses, top1, top5], prefix="Test: " 553 | ) 554 | 555 | # switch to evaluate mode 556 | model.eval() 557 | 558 | with torch.no_grad(): 559 | end = time.time() 560 | for i, (images, target) in enumerate(val_loader): 561 | if args.gpu is not None: 562 | images = images.cuda(args.gpu, non_blocking=True) 563 | target = target.cuda(args.gpu, non_blocking=True) 564 | 565 | # compute output 566 | output = model(images) 567 | loss = criterion(output, target) 568 | 569 | # measure accuracy and record loss 570 | acc1, acc5 = accuracy(output, target, topk=(1, 5)) 571 | losses.update(loss.item(), images.size(0)) 572 | top1.update(acc1[0], images.size(0)) 573 | top5.update(acc5[0], images.size(0)) 574 | 575 | # measure elapsed time 576 | batch_time.update(time.time() - end) 577 | end = time.time() 578 | 579 | if i % args.print_freq == 0: 580 | progress.display(i) 581 | 582 | # TODO: this should also be done with the ProgressMeter 583 | print( 584 | " * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}".format(top1=top1, top5=top5) 585 | ) 586 | 587 | return top1.avg 588 | 589 | def 
save_checkpoint(state, fname='ckpt.pth.tar'): 590 | torch.save(state, fname) 591 | 592 | # def sanity_check(state_dict, pretrained_weights): 593 | # """ 594 | # Linear classifier should not change any weights other than the linear layer. 595 | # This sanity check asserts nothing wrong happens (e.g., BN stats updated). 596 | # """ 597 | # print("=> loading '{}' for sanity check".format(pretrained_weights)) 598 | # checkpoint = torch.load(pretrained_weights, map_location="cpu") 599 | # state_dict_pre = checkpoint["state_dict"] 600 | 601 | # for k in list(state_dict.keys()): 602 | # # only ignore fc layer 603 | # if "fc.weight" in k or "fc.bias" in k: 604 | # continue 605 | 606 | # # name in pretrained model 607 | # k_pre = ( 608 | # "module.encoder_q." + k[len("module.") :] 609 | # if k.startswith("module.") 610 | # else "module.encoder_q." + k 611 | # ) 612 | 613 | # assert ( 614 | # state_dict[k].cpu() == state_dict_pre[k_pre] 615 | # ).all(), "{} is changed in linear classifier training.".format(k) 616 | 617 | # print("=> sanity check passed.") 618 | 619 | 620 | class AverageMeter: 621 | """Computes and stores the average and current value""" 622 | 623 | def __init__(self, name, fmt=":f"): 624 | self.name = name 625 | self.fmt = fmt 626 | self.reset() 627 | 628 | def reset(self): 629 | self.val = 0 630 | self.avg = 0 631 | self.sum = 0 632 | self.count = 0 633 | 634 | def update(self, val, n=1): 635 | self.val = val 636 | self.sum += val * n 637 | self.count += n 638 | self.avg = self.sum / self.count 639 | 640 | def __str__(self): 641 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 642 | return fmtstr.format(**self.__dict__) 643 | 644 | 645 | class ProgressMeter: 646 | def __init__(self, num_batches, meters, prefix=""): 647 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 648 | self.meters = meters 649 | self.prefix = prefix 650 | 651 | def display(self, batch): 652 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 653 | 
entries += [str(meter) for meter in self.meters] 654 | print("\t".join(entries)) 655 | 656 | def _get_batch_fmtstr(self, num_batches): 657 | num_digits = len(str(num_batches // 1)) 658 | fmt = "{:" + str(num_digits) + "d}" 659 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 660 | 661 | 662 | def adjust_learning_rate(optimizer, epoch, args): 663 | """Decay the learning rate based on schedule""" 664 | lr = args.lr 665 | for milestone in args.schedule: 666 | lr *= 0.1 if epoch >= milestone else 1.0 667 | for param_group in optimizer.param_groups: 668 | param_group["lr"] = lr 669 | 670 | 671 | def accuracy(output, target, topk=(1,)): 672 | """Computes the accuracy over the k top predictions for the specified values of k""" 673 | with torch.no_grad(): 674 | maxk = max(topk) 675 | batch_size = target.size(0) 676 | 677 | _, pred = output.topk(maxk, 1, True, True) 678 | pred = pred.t() 679 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 680 | 681 | res = [] 682 | for k in topk: 683 | correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) 684 | res.append(correct_k.mul_(100.0 / batch_size)) 685 | return res 686 | 687 | 688 | if __name__ == "__main__": 689 | main() 690 | --------------------------------------------------------------------------------
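The checkpoint condition inside `train()` in `prep_imagenet.py` is dense, so here is a minimal sketch of which iterations end up writing a `head-{epoch}-{i}.pth.tar` file. The helper names `should_save_head` and `saved_iterations` are hypothetical and not part of the repository. The default `print_freq=10` is taken from the `--print-freq` argument, and the boolean expression is copied from the training loop.

```python
def should_save_head(epoch: int, i: int, print_freq: int = 10) -> bool:
    """Mirror the save condition used in train() (hypothetical helper)."""
    # Coarse cadence that applies in every epoch.
    coarse = i % (print_freq * 1000) == 0
    # Dense cadence during the first 1000 iterations of epoch 0.
    early_dense = epoch < 1 and i < 1000 and i % (print_freq * 10) == 0
    # Medium cadence for the rest of epoch 0.
    early_medium = epoch < 1 and i % (print_freq * 100) == 0
    return coarse or early_dense or early_medium


def saved_iterations(epoch: int, n_iters: int, print_freq: int = 10) -> list:
    """List the iteration indices in [0, n_iters) that would be checkpointed."""
    return [i for i in range(n_iters) if should_save_head(epoch, i, print_freq)]


if __name__ == "__main__":
    print(saved_iterations(0, 3001))   # epoch 0: every 100 up to 900, then every 1000
    print(saved_iterations(1, 20001))  # later epochs: every 10000 only
```

Under these defaults, epoch 0 and iteration 1000 satisfy the condition, which matches the `head-${EPOCH}-${ITER}.pth.tar` checkpoint with `EPOCH=0` and `ITER=1000` that `script/baseline_imagenet.sh` loads.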