├── GRAPH-ATTENTION
│   ├── README_simclr.md
│   ├── configs
│   │   ├── cifar_eval.yaml
│   │   ├── cifar_train_epochs200_bs512.yaml
│   │   ├── imagenet_eval.yaml
│   │   └── imagenet_train_epochs100_bs512.yaml
│   ├── environment.yml
│   ├── linear.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── encoder.py
│   │   ├── losses.py
│   │   ├── resnet.py
│   │   └── ssl.py
│   ├── myexman
│   │   ├── __init__.py
│   │   ├── index.py
│   │   └── parser.py
│   ├── train.py
│   └── utils
│       ├── datautils.py
│       ├── lars_optimizer.py
│       ├── logger.py
│       └── utils.py
├── MULTI-STAGE-GRAPH-AGGREGATION
│   ├── =0.5.0
│   ├── README_simsiam.md
│   ├── main_pretrain.py
│   ├── requirements.txt
│   ├── scripts
│   │   ├── linear
│   │   │   └── simsiam_linear.sh
│   │   ├── pretrain
│   │   │   ├── cifar
│   │   │   │   └── simsiam.sh
│   │   │   └── imagenet100
│   │   │       └── simsiam.sh
│   │   └── utils
│   │       └── convert_imgfolder_to_h5.py
│   ├── setup.py
│   ├── solo
│   │   ├── __init__.py
│   │   ├── args
│   │   │   ├── __init__.py
│   │   │   ├── dataset.py
│   │   │   ├── setup.py
│   │   │   └── utils.py
│   │   ├── backbones
│   │   │   ├── __init__.py
│   │   │   └── resnet
│   │   │       ├── __init__.py
│   │   │       └── resnet.py
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── classification_dataloader.py
│   │   │   ├── dali_dataloader.py
│   │   │   ├── dataset_subset
│   │   │   │   └── imagenet100_classes.txt
│   │   │   ├── h5_dataset.py
│   │   │   └── pretrain_dataloader.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   └── simsiam.py
│   │   ├── methods
│   │   │   ├── __init__.py
│   │   │   ├── base.py
│   │   │   ├── linear.py
│   │   │   └── simsiam.py
│   │   └── utils
│   │       ├── __init__.py
│   │       ├── auto_resumer.py
│   │       ├── auto_umap.py
│   │       ├── checkpointer.py
│   │       ├── kmeans.py
│   │       ├── knn.py
│   │       ├── lars.py
│   │       ├── metrics.py
│   │       ├── misc.py
│   │       ├── momentum.py
│   │       ├── sinkhorn_knopp.py
│   │       └── whitening.py
│   └── zoo
│       ├── cifar10.sh
│       ├── cifar100.sh
│       ├── imagenet.sh
│       └── imagenet100.sh
└── README.md

/GRAPH-ATTENTION/README_simclr.md:
--------------------------------------------------------------------------------
# SimCLR

## Environment Setup

Create a Python environment with the provided config file and [miniconda](https://docs.conda.io/en/latest/miniconda.html):

```(bash)
conda env create -f environment.yml
conda activate simclr_pytorch

export IMAGENET_PATH=... # With enough RAM, keeping the data in /dev/shm usually speeds up data loading
export EXMAN_PATH=... # Directory where experiment logs are written
```

## Training
Model training consists of two steps: (1) self-supervised encoder pretraining and (2) classifier learning with the encoder representations. Both steps are done with the `train.py` script. To see the options for the `sim-clr` or `eval` problem, run `python train.py --help --problem sim-clr` (or `--problem eval`).

### Self-supervised pretraining

#### CIFAR-10
The config `cifar_train_epochs200_bs512.yaml` contains the parameters to reproduce the results for the CIFAR-10 dataset.
The pretraining command is:

```(bash)
python train.py --config configs/cifar_train_epochs200_bs512.yaml
```

#### ImageNet-100
The config `imagenet_train_epochs100_bs512.yaml` contains the parameters to reproduce the results for the ImageNet-100 dataset. The single-node (4 V100 GPUs) pretraining command is:

```(bash)
python train.py --config configs/imagenet_train_epochs100_bs512.yaml
```

#### Logs
The logs and the model are stored at `./logs/exman-train.py/runs//` (under `$EXMAN_PATH/exman-train.py/` instead if `EXMAN_PATH` is set). You can access all the experiments from Python with `myexman.Index('./logs/exman-train.py').info()`.

### Linear Evaluation
To train a linear classifier on top of the pretrained encoder, run the following command (`configs/imagenet_eval.yaml` is the ImageNet counterpart):

```(bash)
python train.py --config configs/cifar_eval.yaml --encoder_ckpt 
```

### Pretraining with `DistributedDataParallel`
To train with a larger batch size on several nodes, set the `--dist ddp` flag and specify the following parameters:
- `--dist_address`: the address and port of the main node, joined by a colon (the configs default to `127.0.0.1:1234`).
- `--node_rank`: 0 for the main node and 1, 2, ... for the others.
- `--world_size`: the number of nodes.

For example, to train on two nodes, run the following command on the main node:
```(bash)
python train.py --config configs/cifar_train_epochs200_bs512.yaml --dist ddp --dist_address : --node_rank 0 --world_size 2
```
and on the second node:
```(bash)
python train.py --config configs/cifar_train_epochs200_bs512.yaml --dist ddp --dist_address : --node_rank 1 --world_size 2
```

ImageNet pretraining on 4 nodes with 4 GPUs each looks as follows:
```
node1: python train.py --config configs/imagenet_train_epochs100_bs512.yaml --dist ddp --world_size 4 --dist_address : --node_rank 0
node2: python train.py --config configs/imagenet_train_epochs100_bs512.yaml --dist ddp --world_size 4 --dist_address : --node_rank 1
node3: python train.py --config configs/imagenet_train_epochs100_bs512.yaml --dist ddp --world_size 4 --dist_address : --node_rank 2
node4: python train.py --config configs/imagenet_train_epochs100_bs512.yaml --dist ddp --world_size 4 --dist_address : --node_rank 3
```
--------------------------------------------------------------------------------

/GRAPH-ATTENTION/configs/cifar_eval.yaml:
--------------------------------------------------------------------------------
arch: linear
aug: true
augmentation: RandomCrop
batch_size: 1024
ckpt: ''
ckpt_iter: -1
config_file: null
data: cifar
dist: dp
dist_address: 127.0.0.1:1234
encode_layer: h
encoder_ckpt: ''
eval_freq: 1000
finetune: false
iters: 8000
log_freq: 100
lr: 0.1
lr_schedule: linear
model_id: -1
n_augs_test: 50
n_augs_train: 10
name: ''
node_rank: 0
opt: sgd
precompute_emb_bs: -1
problem: eval
root: ''
save_freq: 100000000
scale_lower: 0.08
seed: -1
status: fail
test_bs: 1024
tmp: false
warmup: 0.0
weight_decay: 0.0001
workers: 2
world_size: 1

time: '2020-07-18T00:55:23'
id: 3549
--------------------------------------------------------------------------------

/GRAPH-ATTENTION/configs/cifar_train_epochs200_bs512.yaml:
--------------------------------------------------------------------------------
arch: resnet18
aug: true
batch_size: 512
ckpt: ''
color_dist_s: 0.5
config_file: null
data: cifar
dist: dp
dist_address: '127.0.0.1:1234'
eval_freq: 4800
iters: 19200
log_freq: 48
lr: 1.0
lr_schedule: warmup-anneal
multiplier: 2
name: 'reproduce-cifar10'
node_rank: 0
opt: lars
problem: sim-clr
root: 'none'
save_freq: 4800
scale_lower: 0.08
seed: -1
sync_bn: true
tmp: false
verbose: true
warmup: 0.01
weight_decay: 1.0e-06
workers: 2
world_size: 1
--------------------------------------------------------------------------------

/GRAPH-ATTENTION/configs/imagenet_eval.yaml:
--------------------------------------------------------------------------------
arch: linear
aug: true
augmentation: RandomResizedCrop
batch_size: 4096
ckpt: ''
ckpt_iter: -1
config_file: null
data: imagenet
dist: dp
dist_address: '127.0.0.1:1234'
encode_layer: h
encoder_ckpt: ''
eval_freq: 100
finetune: false
iters: 28080
log_freq: 1000
lr: 1.6
lr_schedule: linear
model_id: -1
name: eval_imagenet_newmodels
node_rank: 0
opt: sgd
precompute_emb_bs: -1
problem: eval
save_freq: 10000000000000000
scale_lower: 0.08
seed: -1
test_bs: 4096
tmp: false
warmup: 0.0
weight_decay: 0.0
workers: 20
world_size: 1
--------------------------------------------------------------------------------

/GRAPH-ATTENTION/configs/imagenet_train_epochs100_bs512.yaml:
--------------------------------------------------------------------------------
arch: ResNet50
aug: true
batch_size: 512
ckpt: ''
color_dist_s: 1.0
config_file: ''
data: imagenet
dist: dp
dist_address: '127.0.0.1:1234'
eval_freq: 50040
gpu: 0
iters: 125100
log_freq: 100
lr: 1.2
lr_schedule: warmup-anneal
multiplier: 2
name: imagenet-reproduce
node_rank: 0
opt: lars
problem: sim-clr
root: ''
save_freq: 12510
scale_lower: 0.08
seed: -1
sync_bn: true
tmp: false
verbose: true
temperature: 0.1
warmup: 0.1
weight_decay: 1.0e-06
workers: 8
world_size: 1
d_ratio: 1
--------------------------------------------------------------------------------

/GRAPH-ATTENTION/environment.yml:
--------------------------------------------------------------------------------
name: simclr_pytorch
channels:
  - pytorch
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - backcall=0.2.0=py_0
  - blas=1.0=mkl
  - ca-certificates=2020.10.14=0
  - certifi=2020.12.5=py36h06a4308_0
  - configargparse=1.2.3=py_0
  - cudatoolkit=10.1.243=h6bb024c_0
  - dataclasses=0.7=py36_0
  - decorator=4.4.2=py_0
  - filelock=3.0.12=py_0
  - freetype=2.10.4=h5ab3b9f_0
  - intel-openmp=2020.2=254
  - ipython=7.16.1=py36h5ca1d4c_0
  - ipython_genutils=0.2.0=pyhd3eb1b0_1
  - jedi=0.17.2=py36h06a4308_1
  - joblib=0.17.0=py_0
  - jpeg=9b=h024ee3a_2
  - lcms2=2.11=h396b838_0
  - ld_impl_linux-64=2.33.1=h53a641e_7
  - libedit=3.1.20191231=h14c3975_1
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.1.0=hdf63c60_0
  - libgfortran-ng=7.3.0=hdf63c60_0
  - libpng=1.6.37=hbc83047_0
  - libstdcxx-ng=9.1.0=hdf63c60_0
  - libtiff=4.1.0=h2733197_1
  - libuv=1.40.0=h7b6447c_0
  - lz4-c=1.9.2=heb0550a_3
  - mkl=2020.2=256
  - mkl-service=2.3.0=py36he8ac12f_0
  - mkl_fft=1.2.0=py36h23d657b_0
  - mkl_random=1.1.1=py36h0573a6f_0
  - ncurses=6.2=he6710b0_1
  - ninja=1.10.2=py36hff7bd54_0
  - numpy=1.19.2=py36h54aff64_0
  - numpy-base=1.19.2=py36hfa32c7d_0
  - olefile=0.46=py36_0
  - openssl=1.1.1h=h7b6447c_0
  - pandas=1.1.3=py36he6710b0_0
  - parso=0.7.0=py_0
  - pexpect=4.8.0=pyhd3eb1b0_3
  - pickleshare=0.7.5=pyhd3eb1b0_1003
  - pillow=8.0.1=py36he98fc37_0
  - pip=20.3.1=py36h06a4308_0
  - prompt-toolkit=3.0.8=py_0
  - ptyprocess=0.6.0=pyhd3eb1b0_2
  - pygments=2.7.3=pyhd3eb1b0_0
  - python=3.6.12=hcff3b4d_2
  - python-dateutil=2.8.1=py_0
  - pytorch=1.7.0=py3.6_cuda10.1.243_cudnn7.6.3_0
  - pytz=2020.4=pyhd3eb1b0_0
  - pyyaml=5.3.1=py36h7b6447c_1
  - readline=8.0=h7b6447c_0
  - scikit-learn=0.23.2=py36h0573a6f_0
  - scipy=1.5.2=py36h0b6359f_0
  - setuptools=51.0.0=py36h06a4308_2
  - six=1.15.0=py36h06a4308_0
  - sqlite=3.33.0=h62c20be_0
  - tabulate=0.8.7=py36_0
  - threadpoolctl=2.1.0=pyh5ca1d4c_0
  - tk=8.6.10=hbc83047_0
  - torchaudio=0.7.0=py36
  - torchvision=0.8.1=py36_cu101
  - tqdm=4.54.1=pyhd3eb1b0_0
  - traitlets=4.3.3=py36_0
  - typing_extensions=3.7.4.3=py_0
  - wcwidth=0.2.5=py_0
  - wheel=0.36.1=pyhd3eb1b0_0
  - xz=5.2.5=h7b6447c_0
  - yaml=0.2.5=h7b6447c_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.4.5=h9ceee32_0
  - pip:
    - diffdist==0.1
    - strconv==0.4.2
prefix: /home/aashukha/miniconda3/envs/simclr_pytorch
--------------------------------------------------------------------------------

/GRAPH-ATTENTION/models/__init__.py:
--------------------------------------------------------------------------------
from models import encoder
from models import losses
from models import resnet
from models import ssl

REGISTERED_MODELS = {
    'sim-clr': ssl.SimCLR,
    'eval': ssl.SSLEval,
    'semi-supervised-eval': ssl.SemiSupervisedEval,
}
--------------------------------------------------------------------------------
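`REGISTERED_MODELS` maps the `problem` key of a config (`sim-clr` in the `*_train_*.yaml` files, `eval` in the `*_eval.yaml` files) to a class in `models/ssl.py`. The lookup itself happens outside this excerpt, so the following is only a hypothetical sketch of the dispatch pattern: `SimCLRStub`, `SSLEvalStub`, `REGISTRY` and `build_model` are illustrative stand-ins, not names from the repository.

```python
from argparse import Namespace


class SimCLRStub:
    """Hypothetical stand-in for models.ssl.SimCLR."""
    def __init__(self, hparams):
        self.hparams = hparams


class SSLEvalStub:
    """Hypothetical stand-in for models.ssl.SSLEval."""
    def __init__(self, hparams):
        self.hparams = hparams


# Same shape as REGISTERED_MODELS: problem name -> model class.
REGISTRY = {
    'sim-clr': SimCLRStub,
    'eval': SSLEvalStub,
}


def build_model(hparams):
    """Instantiate the class registered for hparams.problem."""
    try:
        model_cls = REGISTRY[hparams.problem]
    except KeyError:
        raise ValueError(
            f"unknown problem {hparams.problem!r}; expected one of {sorted(REGISTRY)}"
        ) from None
    return model_cls(hparams)


# A few values from configs/cifar_train_epochs200_bs512.yaml.
hparams = Namespace(problem='sim-clr', arch='resnet18', data='cifar', batch_size=512)
model = build_model(hparams)
print(type(model).__name__)  # SimCLRStub
```

Keeping the name-to-class mapping in a single dict means supporting a new `problem` only takes a class and one registry entry; translating the `KeyError` into a `ValueError` that lists the valid keys makes a mistyped `problem` value easy to spot.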
/GRAPH-ATTENTION/models/encoder.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import models
from collections import OrderedDict
from argparse import Namespace
import yaml
import os


class BatchNorm1dNoBias(nn.BatchNorm1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.bias.requires_grad = False


class EncodeProject(nn.Module):
    def __init__(self, hparams):
        super().__init__()

        if hparams.arch == 'ResNet50':
            cifar_head = (hparams.data == 'cifar')
            self.convnet = models.resnet.ResNet50(cifar_head=cifar_head, hparams=hparams)
            self.encoder_dim = 2048
        elif hparams.arch == 'resnet18':
            self.convnet = models.resnet.ResNet18(cifar_head=(hparams.data == 'cifar'))
            self.encoder_dim = 512
        else:
            raise NotImplementedError

        num_params = sum(p.numel() for p in self.convnet.parameters() if p.requires_grad)

        print(f'======> Encoder: output dim {self.encoder_dim} | {num_params/1e6:.3f}M parameters')

        self.proj_dim = 128
        projection_layers = [
            ('fc1', nn.Linear(self.encoder_dim, self.encoder_dim, bias=False)),
            ('bn1', nn.BatchNorm1d(self.encoder_dim)),
            ('relu1', nn.ReLU()),
            ('fc2', nn.Linear(self.encoder_dim, 128, bias=False)),
            ('bn2', BatchNorm1dNoBias(128)),
        ]

        self.projection = nn.Sequential(OrderedDict(projection_layers))

    def forward(self, x, out='z'):
        # out='h': backbone features only; otherwise (projection z, backbone features h)
        h = self.convnet(x)
        if out == 'h':
            return h
        return self.projection(h), h
--------------------------------------------------------------------------------

/GRAPH-ATTENTION/models/losses.py:
--------------------------------------------------------------------------------
import torch
from torch import nn
import numpy as np
import torch.nn.functional as F
import diffdist
import torch.distributed as dist


def gather(z):
    gather_z = [torch.zeros_like(z) for _ in range(torch.distributed.get_world_size())]
    gather_z = diffdist.functional.all_gather(gather_z, z)
    gather_z = torch.cat(gather_z)

    return gather_z


def accuracy(logits, labels, k):
    topk = torch.sort(logits.topk(k, dim=1)[1], 1)[0]
    labels = torch.sort(labels, 1)[0]
    acc = (topk == labels).all(1).float()
    return acc


def mean_cumulative_gain(logits, labels, k):
    topk = torch.sort(logits.topk(k, dim=1)[1], 1)[0]
    labels = torch.sort(labels, 1)[0]
    mcg = (topk == labels).float().mean(1)
    return mcg


def mean_average_precision(logits, labels, k):
    # TODO: not the fastest solution but looks fine
    argsort = torch.argsort(logits, dim=1, descending=True)
    labels_to_sorted_idx = torch.sort(torch.gather(torch.argsort(argsort, dim=1), 1, labels), dim=1)[0] + 1
    precision = (1 + torch.arange(k, device=logits.device).float()) / labels_to_sorted_idx
    return precision.sum(1) / k


class NTXent(nn.Module):
    """
    Contrastive loss with distributed data parallel support
    """
    LARGE_NUMBER = float(1e9)

    def __init__(self, tau=1., gpu=None, multiplier=2, distributed=False):
        super().__init__()
        self.tau = tau
        self.multiplier = multiplier
        self.distributed = distributed
        self.norm = 1.

    def forward(self, z, z_2, cur_iter, get_map=False):
        n = z.shape[0]
        assert n % self.multiplier == 0

        z_2 = F.normalize(z_2, p=2, dim=1)
        z = F.normalize(z, p=2, dim=1) / np.sqrt(self.tau)

        if self.distributed:
            z_list = [torch.zeros_like(z) for _ in range(dist.get_world_size())]
            # all_gather fills the list as [, , ...]
            # TODO: try to rewrite it with pytorch official tools
            z_list = diffdist.functional.all_gather(z_list, z)
            # split it into [, , ..., , , ...]
            z_list = [chunk for x in z_list for chunk in x.chunk(self.multiplier)]
            # sort it to [, , ...] that simply means [, , ...] as expected below
            z_sorted = []
            for m in range(self.multiplier):
                for i in range(dist.get_world_size()):
                    z_sorted.append(z_list[i * self.multiplier + m])
            z = torch.cat(z_sorted, dim=0)
            n = z.shape[0]

        logits = z @ z.t()
        logits[np.arange(n), np.arange(n)] = -self.LARGE_NUMBER

        # choose all positive objects for an example, for i it would be (i + k * n/m), where k=0...(m-1)
        m = self.multiplier
        labels = (np.repeat(np.arange(n), m) + np.tile(np.arange(m) * n//m, n)) % n
        # remove labels pointing to itself, i.e. (i, i)
        labels = labels.reshape(n, m)[:, 1:].reshape(-1)

        logprob = F.log_softmax(logits, dim=1)

        # note: unlike z, z_2 is not all-gathered, so with distributed=True and
        # world_size > 1 the shape of sims no longer matches logprob
        sims = z_2 @ z_2.t()
        sims = sims + 1
        sims = torch.exp(sims)

        weight_sum = n
        if cur_iter >= 0:
            # re-weight every log-probability by exp(cos(z_2_i, z_2_j) + 1); the weights carry no gradient
            logprob = torch.mul(logprob, sims.detach().to(logprob.device))
            weight_sum = sims[np.repeat(np.arange(n), m-1), labels].sum().item()

        # TODO: maybe different terms for each process should only be computed here...
        loss = -logprob[np.repeat(np.arange(n), m-1), labels].sum() / weight_sum / (m-1) / self.norm

        # zero the probability of identical pairs
        pred = logprob.data.clone()
        pred[np.arange(n), np.arange(n)] = -self.LARGE_NUMBER
        acc = accuracy(pred, torch.LongTensor(labels.reshape(n, m-1)).to(logprob.device), m-1)

        if get_map:
            _map = mean_average_precision(pred, torch.LongTensor(labels.reshape(n, m-1)).to(logprob.device), m-1)
            return loss, acc, _map

        return loss, acc
--------------------------------------------------------------------------------

/GRAPH-ATTENTION/models/resnet.py:
--------------------------------------------------------------------------------
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
#

import torch.nn as nn
import torchvision.models as models
import torch


class Flatten(nn.Module):
    def __init__(self, dim=-1):
        super(Flatten, self).__init__()
        self.dim = dim

    def forward(self, feat):
        return torch.flatten(feat, start_dim=self.dim)


class ResNetEncoder(models.resnet.ResNet):
    """Wrapper for TorchVision ResNet Model
    This was needed to remove the final FC Layer from the ResNet Model"""
    def __init__(self, block, layers, cifar_head=False, hparams=None):
        super().__init__(block, layers)
        self.cifar_head = cifar_head
        if cifar_head:
            self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
            self.bn1 = self._norm_layer(64)
            self.relu = nn.ReLU(inplace=True)
        self.hparams = hparams

        print('** Using avgpool **')

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if not self.cifar_head:
            x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return x

class ResNet18(ResNetEncoder):
    def __init__(self, cifar_head=True):
        super().__init__(models.resnet.BasicBlock, [2, 2, 2, 2], cifar_head=cifar_head)


class ResNet50(ResNetEncoder):
    def __init__(self, cifar_head=True, hparams=None):
        super().__init__(models.resnet.Bottleneck, [3, 4, 6, 3], cifar_head=cifar_head, hparams=hparams)
--------------------------------------------------------------------------------

/GRAPH-ATTENTION/myexman/__init__.py:
--------------------------------------------------------------------------------
from .parser import (
    ExParser,
    simpleroot
)
from .index import (
    Index
)
from . import index
from . import parser
__version__ = '0.0.2'
--------------------------------------------------------------------------------
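To make the index arithmetic and the re-weighting in `NTXent.forward` (in `models/losses.py` above) concrete, here is a NumPy re-implementation of its non-distributed path. It is an illustrative sketch, not the repository's code: the `diffdist` all-gather, the accuracy/mAP outputs and `self.norm` (fixed to `1.` in the class) are left out, the `weighted` flag stands in for the `cur_iter >= 0` branch, the `tau` default is arbitrary, and `l2_normalize`, `positive_labels` and `weighted_ntxent` are hypothetical helper names. It assumes the batch is laid out as all first views followed by all second views, which is what the label formula `(i + k * n/m) % n` implies.

```python
import numpy as np


def l2_normalize(x):
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def positive_labels(n, m=2):
    """Column index of every positive for rows 0..n-1 (m views per image)."""
    # Row i's candidates are (i + k * n // m) % n for k = 0..m-1 ...
    labels = (np.repeat(np.arange(n), m) + np.tile(np.arange(m) * (n // m), n)) % n
    # ... and k = 0 (the row itself) is dropped.
    return labels.reshape(n, m)[:, 1:].reshape(-1)


def weighted_ntxent(z, z_2, tau=0.5, m=2, weighted=True):
    n = z.shape[0]
    assert n % m == 0
    # Dividing unit vectors by sqrt(tau) makes z @ z.T equal cosine similarity / tau.
    z = l2_normalize(z) / np.sqrt(tau)
    z_2 = l2_normalize(z_2)

    logits = z @ z.T
    np.fill_diagonal(logits, -1e9)  # a sample never scores against itself

    # Row-wise log-softmax, shifted by the row max for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    logprob = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))

    rows = np.repeat(np.arange(n), m - 1)
    cols = positive_labels(n, m)
    if weighted:
        # Same weights as the original: exp(cos(z_2_i, z_2_j) + 1), which lies in [1, e^2].
        weights = np.exp(z_2 @ z_2.T + 1.0)
        logprob = logprob * weights
        weight_sum = weights[rows, cols].sum()
    else:
        weight_sum = n
    return -logprob[rows, cols].sum() / weight_sum / (m - 1)


print(positive_labels(6, m=2))  # [3 4 5 0 1 2]
```

With `weighted=False` this reduces to the plain NT-Xent (SimCLR) loss averaged over all `n * (m - 1)` positive pairs. With `weighted=True` and `m = 2` it becomes a weighted mean in which positive pairs whose `z_2` embeddings agree more contribute more, normalised by the sum of the positive-pair weights instead of `n`.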

/GRAPH-ATTENTION/myexman/index.py:
--------------------------------------------------------------------------------
import configargparse
import pandas as pd
import pathlib
import strconv
import json
import functools
import datetime
from . import parser
import yaml
from argparse import Namespace
__all__ = [
    'Index'
]


def only_value_error(conv):
    @functools.wraps(conv)
    def new_conv(value):
        try:
            return conv(value)
        except Exception as e:
            raise ValueError from e
    return new_conv


def none2none(none):
    if none is None:
        return None
    else:
        raise ValueError


converter = strconv.Strconv(converters=[
    ('int', strconv.convert_int),
    ('float', strconv.convert_float),
    ('bool', only_value_error(parser.str2bool)),
    ('time', strconv.convert_time),
    ('datetime', strconv.convert_datetime),
    ('datetime1', lambda time: datetime.datetime.strptime(time, parser.TIME_FORMAT)),
    ('date', strconv.convert_date),
    ('json', only_value_error(json.loads)),
])


def get_args(path):
    with open(path, 'rb') as f:
        # PyYAML >= 6 requires an explicit Loader; FullLoader is what yaml.load(f)
        # used implicitly in the 5.x releases pinned in environment.yml
        return Namespace(**yaml.load(f, Loader=yaml.FullLoader))


class Index(object):
    def __init__(self, root):
        self.root = pathlib.Path(root)

    @property
    def index(self):
        return self.root / 'index'

    @property
    def marked(self):
        return self.root / 'marked'

    def info(self, source=None, nlast=None):
        if source is None:
            source = self.index
            files = source.iterdir()
            if nlast is not None:
                files = sorted(list(files))[-nlast:]
        else:
            source = self.marked / source
            files = source.glob('**/*/'+parser.PARAMS_FILE)

        def get_dict(cfg):
            return configargparse.YAMLConfigFileParser().parse(cfg.open('r'))

        def convert_column(col):
            if any(isinstance(v, str) for v in converter.convert_series(col)):
                return col
            else:
                return pd.Series(converter.convert_series(col), name=col.name, index=col.index)
        try:
            df = (pd.DataFrame
                  .from_records((get_dict(c) for c in files))
                  .apply(lambda s: convert_column(s))
                  .sort_values('id')
                  .assign(root=lambda _: _.root.apply(self.root.__truediv__))
                  .reset_index(drop=True))
            cols = df.columns.tolist()
            cols.insert(0, cols.pop(cols.index('id')))
            return df.reindex(columns=cols)
        except FileNotFoundError as e:
            raise KeyError(source.name) from e
--------------------------------------------------------------------------------

/GRAPH-ATTENTION/myexman/parser.py:
--------------------------------------------------------------------------------
import configargparse
import argparse
import pathlib
import datetime
import yaml
import yaml.representer
import os
import functools
import itertools
from filelock import FileLock
__all__ = [
    'ExParser',
    'simpleroot',
]


TIME_FORMAT_DIR = '%Y-%m-%d-%H-%M-%S'
TIME_FORMAT = '%Y-%m-%dT%H:%M:%S'
DIR_FORMAT = '{num}'
EXT = 'yaml'
PARAMS_FILE = 'params.'+EXT
FOLDER_DEFAULT = 'exman'
RESERVED_DIRECTORIES = {
    'runs', 'index',
    'tmp', 'marked'
}


def yaml_file(name):
30 | return name + '.' + EXT 31 | 32 | 33 | def simpleroot(__file__): 34 | return pathlib.Path(os.path.dirname(os.path.abspath(__file__)))/FOLDER_DEFAULT 35 | 36 | 37 | def represent_as_str(self, data, tostr=str): 38 | return yaml.representer.Representer.represent_str(self, tostr(data)) 39 | 40 | 41 | def register_str_converter(*types, tostr=str): 42 | for T in types: 43 | yaml.add_representer(T, functools.partial(represent_as_str, tostr=tostr)) 44 | 45 | 46 | register_str_converter(pathlib.PosixPath, pathlib.WindowsPath) 47 | 48 | 49 | def str2bool(s): 50 | true = ('true', 't', 'yes', 'y', 'on', '1') 51 | false = ('false', 'f', 'no', 'n', 'off', '0') 52 | 53 | if s.lower() in true: 54 | return True 55 | elif s.lower() in false: 56 | return False 57 | else: 58 | raise argparse.ArgumentTypeError(s, 'bool argument should be one of {}'.format(str(true + false))) 59 | 60 | 61 | class ParserWithRoot(configargparse.ArgumentParser): 62 | def __init__(self, *args, root=None, zfill=6, 63 | **kwargs): 64 | super().__init__(*args, **kwargs) 65 | if root is None: 66 | raise ValueError('Root directory is not specified') 67 | root = pathlib.Path(root) 68 | if not root.is_absolute(): 69 | raise ValueError(root, 'Root directory is not absolute path') 70 | if not root.exists(): 71 | raise ValueError(root, 'Root directory does not exist') 72 | self.root = pathlib.Path(root) 73 | self.zfill = zfill 74 | self.register('type', bool, str2bool) 75 | for directory in RESERVED_DIRECTORIES: 76 | getattr(self, directory).mkdir(exist_ok=True) 77 | self.lock = FileLock(str(self.root/'lock')) 78 | 79 | @property 80 | def runs(self): 81 | return self.root / 'runs' 82 | 83 | @property 84 | def marked(self): 85 | return self.root / 'marked' 86 | 87 | @property 88 | def index(self): 89 | return self.root / 'index' 90 | 91 | @property 92 | def tmp(self): 93 | return self.root / 'tmp' 94 | 95 | def max_ex(self): 96 | max_num = 0 97 | for directory in itertools.chain(self.runs.iterdir(), 
self.tmp.iterdir()): 98 | num = int(directory.name.split('-', 1)[0]) 99 | if num > max_num: 100 | max_num = num 101 | return max_num 102 | 103 | def num_ex(self): 104 | return len(list(self.runs.iterdir())) 105 | 106 | def next_ex(self): 107 | return self.max_ex() + 1 108 | 109 | def next_ex_str(self): 110 | return str(self.next_ex()).zfill(self.zfill) 111 | 112 | 113 | class ExParser(ParserWithRoot): 114 | """ 115 | Parser responsible for creating the following structure of experiments 116 | ``` 117 | root 118 | |-- runs 119 | | `-- xxxxxx-YYYY-mm-dd-HH-MM-SS 120 | | |-- params.yaml 121 | | `-- ... 122 | |-- index 123 | | `-- xxxxxx-YYYY-mm-dd-HH-MM-SS.yaml (symlink) 124 | |-- marked 125 | | `-- 126 | | `-- xxxxxx-YYYY-mm-dd-HH-MM-SS (symlink) 127 | | |-- params.yaml 128 | | `-- ... 129 | `-- tmp 130 | `-- xxxxxx-YYYY-mm-dd-HH-MM-SS 131 | |-- params.yaml 132 | `-- ... 133 | ``` 134 | """ 135 | def __init__(self, *args, zfill=6, file=None, 136 | args_for_setting_config_path=('--config', ), 137 | automark=(), 138 | parents=[], 139 | **kwargs): 140 | 141 | root = os.path.join(os.path.abspath(os.environ.get('EXMAN_PATH', './logs')), ('exman-' + str(file))) 142 | if not os.path.exists(root): 143 | os.makedirs(root) 144 | 145 | if len(parents) == 1: 146 | self.yaml_params_path = parents[0].yaml_params_path 147 | root = parents[0].root 148 | 149 | super().__init__(*args, root=root, zfill=zfill, 150 | args_for_setting_config_path=args_for_setting_config_path, 151 | config_file_parser_class=configargparse.YAMLConfigFileParser, 152 | ignore_unknown_config_file_keys=True, 153 | parents=parents, 154 | formatter_class=argparse.ArgumentDefaultsHelpFormatter, 155 | **kwargs) 156 | self.automark = automark 157 | if len(parents) == 0: 158 | self.add_argument('--tmp', action='store_true') 159 | 160 | def _initialize_dir(self, tmp): 161 | try: 162 | # with self.lock: # concurrent processes may pick the same directory name; the lock would avoid this collision 163 | time = datetime.datetime.now() 
164 | num = self.next_ex_str() 165 | name = DIR_FORMAT.format(num=num, time=time.strftime(TIME_FORMAT_DIR)) 166 | if tmp: 167 | absroot = self.tmp / name 168 | relroot = pathlib.Path('tmp') / name 169 | else: 170 | absroot = self.runs / name 171 | relroot = pathlib.Path('runs') / name 172 | # this process now safely owns the root directory 173 | # mkdir() raises FileExistsError if the directory already exists 174 | absroot.mkdir() 175 | except FileExistsError: # another process claimed this number first; retry with the next one 176 | return self._initialize_dir(tmp) 177 | return absroot, relroot, name, time, num 178 | 179 | def parse_known_args(self, *args, log_params=True, **kwargs): 180 | args, argv = super().parse_known_args(*args, **kwargs) 181 | if not log_params: 182 | return args, argv 183 | 184 | if hasattr(self, 'yaml_params_path'): 185 | with self.yaml_params_path.open('w') as f: 186 | self.dumpd = args.__dict__.copy() 187 | yaml.dump(self.dumpd, f, default_flow_style=False) 188 | print("\ntime: '{}'".format(self.time.strftime(TIME_FORMAT)), file=f) 189 | print("id:", int(self.num), file=f) 190 | print(self.yaml_params_path.read_text()) 191 | return args, argv 192 | 193 | absroot, relroot, name, time, num = self._initialize_dir(args.tmp) 194 | self.time = time 195 | self.num = num 196 | args.root = absroot 197 | self.yaml_params_path = args.root / PARAMS_FILE 198 | rel_yaml_params_path = pathlib.Path('..', 'runs', name, PARAMS_FILE) 199 | with self.yaml_params_path.open('a') as f: 200 | self.dumpd = args.__dict__.copy() 201 | # dumpd['root'] = relroot 202 | yaml.dump(self.dumpd, f, default_flow_style=False) 203 | print("\ntime: '{}'".format(time.strftime(TIME_FORMAT)), file=f) 204 | print("id:", int(num), file=f) 205 | print(self.yaml_params_path.read_text()) 206 | symlink = self.index / yaml_file(name) 207 | if not args.tmp: 208 | symlink.symlink_to(rel_yaml_params_path) 209 | print('Created symlink from', symlink, '->', rel_yaml_params_path) 210 | if self.automark and not args.tmp: 211 | automark_path_part = 
pathlib.Path(*itertools.chain.from_iterable( 212 | (mark, str(getattr(args, mark, ''))) 213 | for mark in self.automark)) 214 | markpath = pathlib.Path(self.marked, automark_path_part) 215 | markpath.mkdir(exist_ok=True, parents=True) 216 | relpathmark = pathlib.Path('..', *(['..']*len(automark_path_part.parts))) / 'runs' / name 217 | (markpath / name).symlink_to(relpathmark, target_is_directory=True) 218 | print('Created symlink from', markpath / name, '->', relpathmark) 219 | return args, argv 220 | 221 | def done(self): 222 | print('Success.') 223 | self.dumpd['status'] = 'done' 224 | with self.yaml_params_path.open('a') as f: 225 | yaml.dump(self.dumpd, f, default_flow_style=False) 226 | 227 | def update_params_file(self, args): 228 | dumpd = args.__dict__.copy() 229 | with self.yaml_params_path.open('w') as f: 230 | yaml.dump(dumpd, f, default_flow_style=False) 231 | print("\ntime: '{}'".format(self.time.strftime(TIME_FORMAT)), file=f) 232 | print("id:", int(self.num), file=f) 233 | -------------------------------------------------------------------------------- /GRAPH-ATTENTION/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import time 4 | 5 | import torch 6 | import torch.backends.cudnn as cudnn 7 | import models 8 | from utils.logger import Logger 9 | import myexman 10 | from utils import utils 11 | import sys 12 | import torch.multiprocessing as mp 13 | import torch.distributed as dist 14 | import socket 15 | 16 | 17 | def add_learner_params(parser): 18 | parser.add_argument('--problem', default='sim-clr', 19 | help='The problem to train', 20 | choices=models.REGISTERED_MODELS, 21 | ) 22 | parser.add_argument('--name', default='', 23 | help='Name for the experiment', 24 | ) 25 | parser.add_argument('--ckpt', default='', 26 | help='Optional checkpoint to init the model.' 
27 | ) 28 | parser.add_argument('--verbose', default=False, type=bool) 29 | # optimizer params 30 | parser.add_argument('--lr_schedule', default='warmup-anneal') 31 | parser.add_argument('--opt', default='lars', help='Optimizer to use', choices=['sgd', 'adam', 'lars']) 32 | parser.add_argument('--iters', default=-1, type=int, help='The number of optimizer updates') 33 | parser.add_argument('--warmup', default=0, type=float, help='The number of warmup iterations as a fraction of \'iters\'') 34 | parser.add_argument('--lr', default=0.1, type=float, help='Base learning rate') 35 | parser.add_argument('--wd', '--weight_decay', default=1e-4, type=float, dest='weight_decay') 36 | # trainer params 37 | parser.add_argument('--save_freq', default=10000000000000000, type=int, help='Frequency to save the model') 38 | parser.add_argument('--log_freq', default=100, type=int, help='Logging frequency') 39 | parser.add_argument('--eval_freq', default=10000000000000000, type=int, help='Evaluation frequency') 40 | parser.add_argument('-j', '--workers', default=4, type=int, help='The number of data loader workers') 41 | parser.add_argument('--eval_only', default=False, type=bool, help='Skips the training step if True') 42 | parser.add_argument('--seed', default=-1, type=int, help='Random seed') 43 | # parallelism params: 44 | parser.add_argument('--dist', default='dp', type=str, 45 | help='dp: DataParallel, ddp: DistributedDataParallel', 46 | choices=['dp', 'ddp'], 47 | ) 48 | parser.add_argument('--dist_address', default='127.0.0.1:1234', type=str, 49 | help='the address and a port of the main node in the
<address>:<port> format' 50 | ) 51 | parser.add_argument('--node_rank', default=0, type=int, 52 | help='Rank of the node (script launched): 0 for the main node and 1,... for the others', 53 | ) 54 | parser.add_argument('--world_size', default=1, type=int, 55 | help='the number of nodes (scripts launched)', 56 | ) 57 | parser.add_argument('--resume_path', default='./latest.pt', type=str, 58 | ) 59 | parser.add_argument('--d_ratio', default=-1, type=float, help='use projected features') 60 | 61 | def main(): 62 | parser = myexman.ExParser(file=os.path.basename(__file__)) 63 | add_learner_params(parser) 64 | 65 | is_help = False 66 | if '--help' in sys.argv or '-h' in sys.argv: 67 | sys.argv.pop(sys.argv.index('--help' if '--help' in sys.argv else '-h')) 68 | is_help = True 69 | 70 | args, _ = parser.parse_known_args(log_params=False) 71 | 72 | models.REGISTERED_MODELS[args.problem].add_model_hparams(parser) 73 | 74 | if is_help: 75 | sys.argv.append('--help') 76 | 77 | args = parser.parse_args(namespace=args) 78 | 79 | if args.data == 'imagenet' and args.aug == False: 80 | raise Exception('ImageNet models should be evaluated with aug=True!') 81 | 82 | if args.seed != -1: 83 | random.seed(args.seed) 84 | torch.manual_seed(args.seed) 85 | cudnn.deterministic = True 86 | 87 | args.gpu = 0 88 | ngpus = torch.cuda.device_count() 89 | args.number_of_processes = 1 90 | if args.dist == 'ddp': 91 | # add additional argument to be able to retrieve # of processes from logs 92 | # and don't change initial arguments to reproduce the experiment 93 | args.number_of_processes = args.world_size * ngpus 94 | parser.update_params_file(args) 95 | 96 | args.world_size *= ngpus 97 | mp.spawn( 98 | main_worker, 99 | nprocs=ngpus, 100 | args=(ngpus, args), 101 | ) 102 | else: 103 | parser.update_params_file(args) 104 | main_worker(args.gpu, -1, args) 105 | 106 | 107 | def main_worker(gpu, ngpus, args): 108 | fmt = { 109 | 'train_time': '.3f', 110 | 'val_time': '.3f', 111 | 'lr': '.1e', 112 | } 113 | logger = 
Logger('logs', base=args.root, fmt=fmt) 114 | 115 | args.gpu = gpu 116 | torch.cuda.set_device(gpu) 117 | args.rank = args.node_rank * ngpus + gpu 118 | 119 | device = torch.device('cuda:%d' % args.gpu) 120 | 121 | if args.dist == 'ddp': 122 | dist.init_process_group( 123 | backend='nccl', 124 | init_method='tcp://%s' % args.dist_address, 125 | world_size=args.world_size, 126 | rank=args.rank, 127 | ) 128 | 129 | n_gpus_total = dist.get_world_size() 130 | assert args.batch_size % n_gpus_total == 0 131 | args.batch_size //= n_gpus_total 132 | 133 | # create model 134 | model = models.REGISTERED_MODELS[args.problem](args, device=device) 135 | if args.ckpt != '': 136 | ckpt = torch.load(args.ckpt, map_location=device) 137 | model.load_state_dict(ckpt['state_dict']) 138 | 139 | # Data loading code 140 | model.prepare_data() 141 | train_loader, val_loader = model.dataloaders(iters=args.iters) 142 | 143 | # define optimizer 144 | cur_iter = 0 145 | optimizer, scheduler = models.ssl.configure_optimizers(args, model, cur_iter - 1) 146 | # optionally resume from a checkpoint 147 | if args.ckpt and not args.eval_only: 148 | optimizer.load_state_dict(ckpt['opt_state_dict']) 149 | 150 | cudnn.benchmark = True 151 | 152 | continue_training = args.iters != 0 153 | data_time, it_time = 0, 0 154 | print(cur_iter) 155 | while continue_training: 156 | train_logs = [] 157 | model.train() 158 | 159 | start_time = time.time() 160 | for _, batch in enumerate(train_loader): 161 | cur_iter += 1 162 | 163 | batch = [x.to(device) for x in batch] 164 | data_time += time.time() - start_time 165 | 166 | logs = {} 167 | if not args.eval_only: 168 | # forward pass and compute loss 169 | logs = model.train_step(batch, cur_iter,cur_iter,args.d_ratio) 170 | loss = logs['loss'] 171 | 172 | # gradient step 173 | optimizer.zero_grad() 174 | loss.backward() 175 | optimizer.step() 176 | # save logs for the batch 177 | train_logs.append({k: utils.tonp(v) for k, v in logs.items()}) 178 | 179 | if cur_iter 
% args.save_freq == 0 and args.rank == 0: 180 | save_checkpoint(args.root, model, optimizer, cur_iter) 181 | 182 | if cur_iter % args.eval_freq == 0 or cur_iter >= args.iters: 183 | # TODO: aggregate metrics over all processes 184 | test_logs = [] 185 | model.eval() 186 | with torch.no_grad(): 187 | for batch in val_loader: 188 | batch = [x.to(device) for x in batch] 189 | # forward pass 190 | logs = model.test_step(batch) 191 | # save logs for the batch 192 | test_logs.append(logs) 193 | model.train() 194 | 195 | test_logs = utils.agg_all_metrics(test_logs) 196 | logger.add_logs(cur_iter, test_logs, pref='test_') 197 | 198 | it_time += time.time() - start_time 199 | 200 | if (cur_iter % args.log_freq == 0 or cur_iter >= args.iters) and args.rank == 0: 201 | save_checkpoint(args.root, model, optimizer) 202 | train_logs = utils.agg_all_metrics(train_logs) 203 | 204 | logger.add_logs(cur_iter, train_logs, pref='train_') 205 | logger.add_scalar(cur_iter, 'lr', optimizer.param_groups[0]['lr']) 206 | logger.add_scalar(cur_iter, 'data_time', data_time) 207 | logger.add_scalar(cur_iter, 'it_time', it_time) 208 | logger.iter_info() 209 | logger.save() 210 | 211 | data_time, it_time = 0, 0 212 | train_logs = [] 213 | 214 | if scheduler is not None: 215 | scheduler.step() 216 | 217 | if cur_iter >= args.iters: 218 | continue_training = False 219 | break 220 | 221 | start_time = time.time() 222 | 223 | save_checkpoint(args.root, model, optimizer) 224 | 225 | if args.dist == 'ddp': 226 | dist.destroy_process_group() 227 | 228 | 229 | def save_checkpoint(path, model, optimizer, cur_iter=None): 230 | if cur_iter is None: 231 | fname = os.path.join(path, 'checkpoint.pth.tar') 232 | else: 233 | fname = os.path.join(path, 'checkpoint-%d.pth.tar' % cur_iter) 234 | 235 | ckpt = model.get_ckpt() 236 | ckpt.update( 237 | { 238 | 'opt_state_dict': optimizer.state_dict(), 239 | 'iter': cur_iter, 240 | } 241 | ) 242 | 243 | torch.save(ckpt, fname) 244 | 245 | 246 | if __name__ == 
'__main__': 247 | main() 248 | -------------------------------------------------------------------------------- /GRAPH-ATTENTION/utils/__pycache__/datautils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/GRAPH-ATTENTION/utils/__pycache__/datautils.cpython-36.pyc -------------------------------------------------------------------------------- /GRAPH-ATTENTION/utils/__pycache__/datautils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/GRAPH-ATTENTION/utils/__pycache__/datautils.cpython-38.pyc -------------------------------------------------------------------------------- /GRAPH-ATTENTION/utils/__pycache__/lars_optimizer.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/GRAPH-ATTENTION/utils/__pycache__/lars_optimizer.cpython-36.pyc -------------------------------------------------------------------------------- /GRAPH-ATTENTION/utils/__pycache__/lars_optimizer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/GRAPH-ATTENTION/utils/__pycache__/lars_optimizer.cpython-38.pyc -------------------------------------------------------------------------------- /GRAPH-ATTENTION/utils/__pycache__/logger.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/GRAPH-ATTENTION/utils/__pycache__/logger.cpython-36.pyc -------------------------------------------------------------------------------- /GRAPH-ATTENTION/utils/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/GRAPH-ATTENTION/utils/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /GRAPH-ATTENTION/utils/__pycache__/utils.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/GRAPH-ATTENTION/utils/__pycache__/utils.cpython-36.pyc -------------------------------------------------------------------------------- /GRAPH-ATTENTION/utils/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/GRAPH-ATTENTION/utils/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /GRAPH-ATTENTION/utils/datautils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import os 4 | from torchvision import transforms 5 | import torch.utils.data 6 | import PIL 7 | import torchvision.transforms.functional as FT 8 | from PIL import Image 9 | 10 | 11 | if 'DATA_ROOT' in os.environ: 12 | DATA_ROOT = os.environ['DATA_ROOT'] 13 | else: 14 | DATA_ROOT = './data' 15 | 16 | IMAGENET_PATH = './data/imagenet/raw-data' 17 | 18 | 19 | def pad(img, size, mode): 20 | if 
isinstance(img, PIL.Image.Image): 21 | img = np.array(img) 22 | return np.pad(img, [(size, size), (size, size), (0, 0)], mode) 23 | 24 | 25 | mean = { 26 | 'mnist': (0.1307,), 27 | 'cifar10': (0.4914, 0.4822, 0.4465) 28 | } 29 | 30 | std = { 31 | 'mnist': (0.3081,), 32 | 'cifar10': (0.2470, 0.2435, 0.2616) 33 | } 34 | 35 | 36 | class GaussianBlur(object): 37 | """ 38 | PyTorch version of 39 | https://github.com/google-research/simclr/blob/244e7128004c5fd3c7805cf3135c79baa6c3bb96/data_util.py#L311 40 | """ 41 | def gaussian_blur(self, image, sigma): 42 | image = image.unsqueeze(0)  # add a batch dimension: (3, H, W) -> (1, 3, H, W) 43 | radius = int(self.kernel_size / 2)  # np.int was removed in NumPy 1.24 44 | kernel_size = radius * 2 + 1 45 | x = np.arange(-radius, radius + 1) 46 | 47 | blur_filter = np.exp( 48 | -np.power(x, 2.0) / (2.0 * np.power(float(sigma), 2.0))) 49 | blur_filter /= np.sum(blur_filter) 50 | 51 | conv1 = torch.nn.Conv2d(3, 3, kernel_size=(kernel_size, 1), groups=3, padding=[kernel_size//2, 0], bias=False) 52 | conv1.weight = torch.nn.Parameter( 53 | torch.Tensor(np.tile(blur_filter.reshape(kernel_size, 1, 1, 1), 3).transpose([3, 2, 0, 1]))) 54 | 55 | conv2 = torch.nn.Conv2d(3, 3, kernel_size=(1, kernel_size), groups=3, padding=[0, kernel_size//2], bias=False) 56 | conv2.weight = torch.nn.Parameter( 57 | torch.Tensor(np.tile(blur_filter.reshape(kernel_size, 1, 1, 1), 3).transpose([3, 2, 1, 0]))) 58 | 59 | res = conv2(conv1(image)) 60 | assert res.shape == image.shape 61 | return res[0] 62 | 63 | def __init__(self, kernel_size, p=0.5): 64 | self.kernel_size = kernel_size 65 | self.p = p 66 | 67 | def __call__(self, img): 68 | with torch.no_grad(): 69 | assert isinstance(img, torch.Tensor) 70 | if np.random.uniform() < self.p: 71 | return self.gaussian_blur(img, sigma=np.random.uniform(0.2, 2)) 72 | return img 73 | 74 | def __repr__(self): 75 | return self.__class__.__name__ + '(kernel_size={0}, p={1})'.format(self.kernel_size, self.p) 76 | 77 | class CenterCropAndResize(object): 78 | """Crops the given PIL Image at
the center and resizes the crop to a square. 79 | 80 | Args: 81 | proportion (float): Fraction of the image width and height kept by the crop. 82 | size (int): Side length of the square output; the crop is resized to 83 | (size, size) with bicubic interpolation. 84 | """ 85 | 86 | def __init__(self, proportion, size): 87 | self.proportion = proportion 88 | self.size = size 89 | 90 | def __call__(self, img): 91 | """ 92 | Args: 93 | img (PIL Image): Image to be cropped and resized. 94 | 95 | Returns: 96 | PIL Image: Cropped and resized image. 97 | """ 98 | w, h = (np.array(img.size) * self.proportion).astype(int) 99 | img = FT.resize( 100 | FT.center_crop(img, (h, w)), 101 | (self.size, self.size), 102 | interpolation=PIL.Image.BICUBIC 103 | ) 104 | return img 105 | 106 | def __repr__(self): 107 | return self.__class__.__name__ + '(proportion={0}, size={1})'.format(self.proportion, self.size) 108 | 109 | 110 | class Clip(object): 111 | def __call__(self, x): 112 | return torch.clamp(x, 0, 1) 113 | 114 | 115 | class MultiplyBatchSampler(torch.utils.data.sampler.BatchSampler): 116 | MULTILPLIER = 2 117 | 118 | def __iter__(self): 119 | for batch in super().__iter__(): 120 | yield batch * self.MULTILPLIER 121 | 122 | 123 | class ContinousSampler(torch.utils.data.sampler.Sampler): 124 | def __init__(self, sampler, n_iterations): 125 | self.base_sampler = sampler 126 | self.n_iterations = n_iterations 127 | 128 | def __iter__(self): 129 | cur_iter = 0 130 | while cur_iter < self.n_iterations: 131 | for batch in self.base_sampler: 132 | yield batch 133 | cur_iter += 1 134 | if cur_iter >= self.n_iterations: return 135 | 136 | def __len__(self): 137 | return self.n_iterations 138 | 139 | def set_epoch(self, epoch): 140 | self.base_sampler.set_epoch(epoch) 141 | 142 | 143 | def get_color_distortion(s=1.0): 144 | # s is the strength of color distortion.
145 | # given from https://arxiv.org/pdf/2002.05709.pdf 146 | color_jitter = transforms.ColorJitter(0.8*s, 0.8*s, 0.8*s, 0.2*s) 147 | rnd_color_jitter = transforms.RandomApply([color_jitter], p=0.8) 148 | rnd_gray = transforms.RandomGrayscale(p=0.2) 149 | color_distort = transforms.Compose([ 150 | rnd_color_jitter, 151 | rnd_gray]) 152 | return color_distort 153 | 154 | 155 | class DummyOutputWrapper(torch.utils.data.dataset.Dataset): 156 | def __init__(self, dataset, dummy): 157 | self.dummy = dummy 158 | self.dataset = dataset 159 | 160 | def __getitem__(self, index): 161 | return (*self.dataset[index], self.dummy) 162 | 163 | def __len__(self): 164 | return len(self.dataset) 165 | -------------------------------------------------------------------------------- /GRAPH-ATTENTION/utils/lars_optimizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch.autograd import Variable 4 | from torch.nn.parameter import Parameter 5 | 6 | 7 | class LARS(object): 8 | """ 9 | Slight modification of LARC optimizer from https://github.com/NVIDIA/apex/blob/d74fda260c403f775817470d87f810f816f3d615/apex/parallel/LARC.py 10 | Matches one from SimCLR implementation https://github.com/google-research/simclr/blob/master/lars_optimizer.py 11 | 12 | Args: 13 | optimizer: Pytorch optimizer to wrap and modify learning rate for. 14 | trust_coefficient: Trust coefficient for calculating the adaptive lr. 
See https://arxiv.org/abs/1708.03888 15 | """ 16 | 17 | def __init__(self, 18 | optimizer, 19 | trust_coefficient=0.001, 20 | ): 21 | self.param_groups = optimizer.param_groups 22 | self.optim = optimizer 23 | self.trust_coefficient = trust_coefficient 24 | 25 | def __getstate__(self): 26 | return self.optim.__getstate__() 27 | 28 | def __setstate__(self, state): 29 | self.optim.__setstate__(state) 30 | 31 | def __repr__(self): 32 | return self.optim.__repr__() 33 | 34 | def state_dict(self): 35 | return self.optim.state_dict() 36 | 37 | def load_state_dict(self, state_dict): 38 | self.optim.load_state_dict(state_dict) 39 | 40 | def zero_grad(self): 41 | self.optim.zero_grad() 42 | 43 | def add_param_group(self, param_group): 44 | self.optim.add_param_group(param_group) 45 | 46 | def step(self): 47 | with torch.no_grad(): 48 | weight_decays = [] 49 | for group in self.optim.param_groups: 50 | # absorb weight decay control from optimizer 51 | weight_decay = group['weight_decay'] if 'weight_decay' in group else 0 52 | weight_decays.append(weight_decay) 53 | group['weight_decay'] = 0 54 | for p in group['params']: 55 | if p.grad is None: 56 | continue 57 | 58 | if weight_decay != 0: 59 | p.grad.data += weight_decay * p.data 60 | 61 | param_norm = torch.norm(p.data) 62 | grad_norm = torch.norm(p.grad.data) 63 | adaptive_lr = 1. 
64 | 65 | if param_norm != 0 and grad_norm != 0 and group['layer_adaptation']: 66 | adaptive_lr = self.trust_coefficient * param_norm / grad_norm 67 | 68 | p.grad.data *= adaptive_lr 69 | 70 | self.optim.step() 71 | # return weight decay control to optimizer 72 | for i, group in enumerate(self.optim.param_groups): 73 | group['weight_decay'] = weight_decays[i] 74 | -------------------------------------------------------------------------------- /GRAPH-ATTENTION/utils/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import random 4 | import numpy as np 5 | 6 | from collections import OrderedDict 7 | from tabulate import tabulate 8 | from pandas import DataFrame 9 | from time import gmtime, strftime 10 | import time 11 | 12 | 13 | class Logger: 14 | def __init__(self, name='name', fmt=None, base='./logs'): 15 | self.handler = True 16 | self.scalar_metrics = OrderedDict() 17 | self.fmt = fmt if fmt else dict() 18 | 19 | if not os.path.exists(base): 20 | os.makedirs(base) 21 | 22 | time = gmtime() 23 | hash = ''.join([chr(random.randint(97, 122)) for _ in range(3)]) 24 | fname = '-'.join(sys.argv[0].split('/')[-3:]) 25 | # self.path = '%s/%s-%s-%s-%s' % (base, fname, name, hash, strftime('%m-%d-%H:%M', time)) 26 | # self.path = '%s/%s-%s' % (base, fname, name) 27 | self.path = os.path.join(base, name) 28 | 29 | self.logs = self.path + '.csv' 30 | self.output = self.path + '.out' 31 | self.iters_since_last_header = 0 32 | 33 | def prin(*args): 34 | str_to_write = ' '.join(map(str, args)) 35 | with open(self.output, 'a') as f: 36 | f.write(str_to_write + '\n') 37 | f.flush() 38 | 39 | print(str_to_write) 40 | sys.stdout.flush() 41 | 42 | self.print = prin 43 | 44 | def add_scalar(self, t, key, value): 45 | if key not in self.scalar_metrics: 46 | self.scalar_metrics[key] = [] 47 | self.scalar_metrics[key] += [(t, value)] 48 | 49 | def add_logs(self, t, logs, pref=''): 50 | for k, v in logs.items(): 
51 | self.add_scalar(t, pref + k, v) 52 | 53 | def iter_info(self, order=None): 54 | self.iters_since_last_header += 1 55 | if self.iters_since_last_header > 40: 56 | self.handler = True 57 | 58 | names = list(self.scalar_metrics.keys()) 59 | if order: 60 | names = order 61 | values = [self.scalar_metrics[name][-1][1] for name in names] 62 | t = int(np.max([self.scalar_metrics[name][-1][0] for name in names])) 63 | fmt = ['%s'] + [self.fmt[name] if name in self.fmt else '.3f' for name in names] 64 | 65 | if self.handler: 66 | self.handler = False 67 | self.iters_since_last_header = 0 68 | self.print(tabulate([[t] + values], ['t'] + names, floatfmt=fmt)) 69 | else: 70 | self.print(tabulate([[t] + values], ['t'] + names, tablefmt='plain', floatfmt=fmt).split('\n')[1]) 71 | 72 | def save(self): 73 | result = None 74 | for key in self.scalar_metrics.keys(): 75 | if result is None: 76 | result = DataFrame(self.scalar_metrics[key], columns=['t', key]).set_index('t') 77 | else: 78 | df = DataFrame(self.scalar_metrics[key], columns=['t', key]).set_index('t') 79 | result = result.join(df, how='outer') 80 | result.to_csv(self.logs) 81 | -------------------------------------------------------------------------------- /GRAPH-ATTENTION/utils/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import warnings 4 | import time 5 | import torch.distributed as dist 6 | 7 | 8 | def timing(f): 9 | def wrap(*args, **kwargs): 10 | time1 = time.time() 11 | ret = f(*args, **kwargs) 12 | time2 = time.time() 13 | print('{:s} function took {:.3f} ms'.format(f.__name__, (time2-time1)*1000.0)) 14 | 15 | return ret 16 | return wrap 17 | 18 | 19 | def agg_all_metrics(outputs): 20 | if len(outputs) == 0: 21 | return outputs 22 | res = {} 23 | keys = [k for k in outputs[0].keys() if not isinstance(outputs[0][k], dict)] 24 | for k in keys: 25 | all_logs = np.concatenate([tonp(x[k]).reshape(-1) for x in outputs]) 26 | 
if k != 'epoch': 27 | res[k] = np.mean(all_logs) 28 | else: 29 | res[k] = all_logs[-1] 30 | return res 31 | 32 | 33 | def gather_metrics(metrics): 34 | for k, v in metrics.items(): 35 | if v.dim() == 0: 36 | v = v[None] 37 | v_all = [torch.zeros_like(v) for _ in range(dist.get_world_size())] 38 | dist.all_gather(v_all, v) 39 | v_all = torch.cat(v_all) 40 | metrics[k] = v_all 41 | 42 | 43 | def viz_array_grid(array, rows, cols, padding=0, channels_last=False, normalize=False, **kwargs): 44 | # normalization 45 | ''' 46 | Args: 47 | array: (N_images, N_channels, H, W) or (N_images, H, W, N_channels) 48 | rows, cols: number of rows and columns in the grid; rows * cols must equal array.shape[0] 49 | padding: padding (in pixels) between grid cells 50 | channels_last: True for TensorFlow (NHWC) layout, False for PyTorch (NCHW) layout 51 | normalize: `False`, `mean_std`, or `min_max` 52 | Kwargs: 53 | if normalize == 'mean_std': 54 | mean: mean of the distribution. Default 0.5 55 | std: std of the distribution. Default 0.5 56 | if normalize == 'min_max': 57 | min: min of the distribution. Default array.min() 58 | max: max of the distribution.
Default array.max() 59 | ''' 60 | array = tonp(array) 61 | if not channels_last: 62 | array = np.transpose(array, (0, 2, 3, 1)) 63 | 64 | array = array.astype('float32') 65 | 66 | if normalize: 67 | if normalize == 'mean_std': 68 | mean = kwargs.get('mean', 0.5) 69 | mean = np.array(mean).reshape((1, 1, 1, -1)) 70 | std = kwargs.get('std', 0.5) 71 | std = np.array(std).reshape((1, 1, 1, -1)) 72 | array = array * std + mean 73 | elif normalize == 'min_max': 74 | min_ = kwargs.get('min', array.min()) 75 | min_ = np.array(min_).reshape((1, 1, 1, -1)) 76 | max_ = kwargs.get('max', array.max()) 77 | max_ = np.array(max_).reshape((1, 1, 1, -1)) 78 | array -= min_ 79 | array /= max_ - min_ + 1e-9  # rescale to [0, 1] by the value range 80 | 81 | batch_size, H, W, channels = array.shape 82 | assert rows * cols == batch_size 83 | 84 | if channels == 1: 85 | canvas = np.ones((H * rows + padding * (rows - 1), 86 | W * cols + padding * (cols - 1))) 87 | array = array[:, :, :, 0] 88 | elif channels == 3: 89 | canvas = np.ones((H * rows + padding * (rows - 1), 90 | W * cols + padding * (cols - 1), 91 | 3)) 92 | else: 93 | raise TypeError('number of channels must be either 1 or 3') 94 | 95 | for i in range(rows): 96 | for j in range(cols): 97 | img = array[i * cols + j] 98 | start_h = i * padding + i * H 99 | start_w = j * padding + j * W 100 | canvas[start_h: start_h + H, start_w: start_w + W] = img 101 | 102 | canvas = np.clip(canvas, 0, 1) 103 | canvas *= 255.0 104 | canvas = canvas.astype('uint8') 105 | return canvas 106 | 107 | 108 | def tonp(x): 109 | if isinstance(x, (np.ndarray, float, int)): 110 | return np.array(x) 111 | return x.detach().cpu().numpy() 112 | 113 | 114 | class LinearLR(torch.optim.lr_scheduler._LRScheduler): 115 | def __init__(self, optimizer, num_epochs, last_epoch=-1): 116 | self.num_epochs = max(num_epochs, 1) 117 | super().__init__(optimizer, last_epoch) 118 | 119 | def get_lr(self): 120 | res = [] 121 | for lr in self.base_lrs: 122 | res.append(np.maximum(lr * np.minimum(-self.last_epoch * 1.
/ self.num_epochs + 1., 1.), 0.)) 123 | return res 124 | 125 | 126 | class LinearWarmupAndCosineAnneal(torch.optim.lr_scheduler._LRScheduler): 127 | def __init__(self, optimizer, warm_up, T_max, last_epoch=-1): 128 | self.warm_up = int(warm_up * T_max) 129 | self.T_max = T_max - self.warm_up 130 | super().__init__(optimizer, last_epoch=last_epoch) 131 | 132 | def get_lr(self): 133 | if not self._get_lr_called_within_step: 134 | warnings.warn("To get the last learning rate computed by the scheduler, " 135 | "please use `get_last_lr()`.") 136 | 137 | if self.last_epoch == 0: 138 | return [lr / (self.warm_up + 1) for lr in self.base_lrs] 139 | elif self.last_epoch <= self.warm_up: 140 | c = (self.last_epoch + 1) / self.last_epoch 141 | return [group['lr'] * c for group in self.optimizer.param_groups] 142 | else: 143 | # ref: https://github.com/pytorch/pytorch/blob/2de4f245c6b1e1c294a8b2a9d7f916d43380af4b/torch/optim/lr_scheduler.py#L493 144 | le = self.last_epoch - self.warm_up 145 | return [(1 + np.cos(np.pi * le / self.T_max)) / 146 | (1 + np.cos(np.pi * (le - 1) / self.T_max)) * 147 | group['lr'] 148 | for group in self.optimizer.param_groups] 149 | 150 | 151 | class BaseLR(torch.optim.lr_scheduler._LRScheduler): 152 | def get_lr(self): 153 | return [group['lr'] for group in self.optimizer.param_groups] 154 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/=0.5.0: -------------------------------------------------------------------------------- 1 | Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple 2 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/README_simsiam.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # SimSiam 4 | 5 | 6 | 7 | 8 | 9 | ## Requirements 10 | * torch 11 | * torchvision 12 | * tqdm 13 | * einops 14 | * wandb 15 | * pytorch-lightning 16 | * lightning-bolts 17 | 
* torchmetrics 18 | * scipy 19 | * timm 20 | 21 | **Optional**: 22 | * nvidia-dali 23 | * matplotlib 24 | * seaborn 25 | * pandas 26 | * umap-learn 27 | 28 | --- 29 | 30 | ## Environment Setup 31 | 32 | First clone the repo. 33 | 34 | Then, to install solo-learn with [Dali](https://github.com/NVIDIA/DALI), UMAP and/or H5 support, use: 35 | ``` 36 | pip3 install .[dali,umap,h5] --extra-index-url https://developer.download.nvidia.com/compute/redist 37 | ``` 38 | 39 | If no Dali/UMAP/H5 support is needed, the repository can be installed as: 40 | ``` 41 | pip3 install . 42 | ``` 43 | 44 | For local development: 45 | ``` 46 | pip3 install -e .[umap,h5] 47 | # Make sure you have pre-commit hooks installed 48 | pre-commit install 49 | ``` 50 | 51 | **NOTE:** if you are having trouble with Dali, install it following their [guide](https://github.com/NVIDIA/DALI). 52 | 53 | **NOTE 2:** consider installing [Pillow-SIMD](https://github.com/uploadcare/pillow-simd) for better loading times when not using Dali. 54 | 55 | **NOTE 3:** Soon to be on pip. 56 | 57 | --- 58 | 59 | ## Training 60 | 61 | For pretraining the backbone, use one of the bash scripts in `scripts/pretrain/`. 62 | 63 | After that, for offline linear evaluation, follow the examples in `scripts/linear`. 64 | 65 | 66 | **NOTE:** The scripts try to stay up to date and to follow each paper's recommended parameters as closely as possible, but check them before running. 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/main_pretrain.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team.
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 | 20 | import os 21 | from pprint import pprint 22 | 23 | import torch 24 | from pytorch_lightning import Trainer, seed_everything 25 | from pytorch_lightning.callbacks import LearningRateMonitor 26 | from pytorch_lightning.loggers import WandbLogger 27 | from pytorch_lightning.strategies.ddp import DDPStrategy 28 | 29 | from solo.args.setup import parse_args_pretrain 30 | from solo.data.classification_dataloader import prepare_data as prepare_data_classification 31 | from solo.data.pretrain_dataloader import ( 32 | prepare_dataloader, 33 | prepare_datasets, 34 | prepare_n_crop_transform, 35 | prepare_transform, 36 | ) 37 | from solo.methods import METHODS 38 | from solo.utils.auto_resumer import AutoResumer 39 | from solo.utils.checkpointer import Checkpointer 40 | from solo.utils.misc import make_contiguous 41 | 42 | try: 43 | from solo.data.dali_dataloader import PretrainDALIDataModule 44 | except ImportError: 45 | _dali_avaliable = False 46 | else: 47 | _dali_avaliable = True 48 | 49 | try: 50 | from solo.utils.auto_umap import AutoUMAP 51 | except ImportError: 52 | _umap_available = False 53 | else: 54 | _umap_available = True 55 | 56 | 57 | def main(): 58 | seed_everything(5) 59 | 60 | args = parse_args_pretrain() 61 | 62 | assert args.method in METHODS, f"Choose from {METHODS.keys()}" 63 | 64 | if args.num_large_crops != 2: 65 | assert args.method in ["wmse", "mae"] 66 | 67 | model = METHODS[args.method](**args.__dict__) 68 | make_contiguous(model) 69 | # can provide up to ~20% speed up 70 | if not args.no_channel_last: 71 | model = model.to(memory_format=torch.channels_last) 72 | 73 | # validation dataloader for when it is available 74 | if args.dataset == "custom" and (args.no_labels or args.val_data_path is None): 75 | val_loader = None 76 | elif args.dataset in ["imagenet100", "imagenet"] and (args.val_data_path is None): 77 | val_loader = None 78 | else: 79 | if args.data_format == "dali": 80 | val_data_format = "image_folder" 81 | else: 82 | 
val_data_format = args.data_format 83 | 84 | _, val_loader = prepare_data_classification( 85 | args.dataset, 86 | train_data_path=args.train_data_path, 87 | val_data_path=args.val_data_path, 88 | data_format=val_data_format, 89 | batch_size=args.batch_size, 90 | num_workers=args.num_workers, 91 | ) 92 | 93 | # pretrain dataloader 94 | if args.data_format == "dali": 95 | assert ( 96 | _dali_avaliable 97 | ), "Dali is not currently available, please install it first with pip3 install .[dali]." 98 | 99 | dali_datamodule = PretrainDALIDataModule( 100 | dataset=args.dataset, 101 | train_data_path=args.train_data_path, 102 | unique_augs=args.unique_augs, 103 | transform_kwargs=args.transform_kwargs, 104 | num_crops_per_aug=args.num_crops_per_aug, 105 | num_large_crops=args.num_large_crops, 106 | num_small_crops=args.num_small_crops, 107 | num_workers=args.num_workers, 108 | batch_size=args.batch_size, 109 | no_labels=args.no_labels, 110 | data_fraction=args.data_fraction, 111 | dali_device=args.dali_device, 112 | encode_indexes_into_labels=args.encode_indexes_into_labels, 113 | ) 114 | dali_datamodule.val_dataloader = lambda: val_loader 115 | else: 116 | transform_kwargs = ( 117 | args.transform_kwargs if args.unique_augs > 1 else [args.transform_kwargs] 118 | ) 119 | transform = prepare_n_crop_transform( 120 | [prepare_transform(args.dataset, **kwargs) for kwargs in transform_kwargs], 121 | num_crops_per_aug=args.num_crops_per_aug, 122 | ) 123 | 124 | if args.debug_augmentations: 125 | print("Transforms:") 126 | pprint(transform) 127 | 128 | train_dataset = prepare_datasets( 129 | args.dataset, 130 | transform, 131 | train_data_path=args.train_data_path, 132 | data_format=args.data_format, 133 | no_labels=args.no_labels, 134 | data_fraction=args.data_fraction, 135 | ) 136 | train_loader = prepare_dataloader( 137 | train_dataset, batch_size=args.batch_size, num_workers=args.num_workers 138 | ) 139 | 140 | # 1.7 will deprecate resume_from_checkpoint, but for the moment 141
| # the argument is the same, but we need to pass it as ckpt_path to trainer.fit 142 | ckpt_path, wandb_run_id = None, None 143 | if False and args.resume_from_checkpoint is None:  # `False and` short-circuits, so AutoResumer (--auto_resume) is effectively disabled 144 | auto_resumer = AutoResumer( 145 | checkpoint_dir=os.path.join(args.checkpoint_dir, args.method), 146 | max_hours=args.auto_resumer_max_hours, 147 | ) 148 | resume_from_checkpoint, wandb_run_id = auto_resumer.find_checkpoint(args) 149 | if resume_from_checkpoint is not None: 150 | print( 151 | "Resuming from previous checkpoint that matches specifications:", 152 | f"'{resume_from_checkpoint}'", 153 | ) 154 | ckpt_path = resume_from_checkpoint 155 | elif args.resume_from_checkpoint is not None: 156 | ckpt_path = args.resume_from_checkpoint 157 | del args.resume_from_checkpoint 158 | 159 | callbacks = [] 160 | 161 | if args.save_checkpoint: 162 | # save a checkpoint every `checkpoint_frequency` epochs 163 | ckpt = Checkpointer( 164 | args, 165 | logdir=os.path.join(args.checkpoint_dir, args.method), 166 | frequency=args.checkpoint_frequency, 167 | ) 168 | callbacks.append(ckpt) 169 | 170 | if args.auto_umap: 171 | assert ( 172 | _umap_available 173 | ), "UMAP is not currently available, please install it first with [umap]."
174 | auto_umap = AutoUMAP( 175 | args, 176 | logdir=os.path.join(args.auto_umap_dir, args.method), 177 | frequency=args.auto_umap_frequency, 178 | ) 179 | callbacks.append(auto_umap) 180 | 181 | # wandb logging 182 | if args.wandb: 183 | wandb_logger = WandbLogger( 184 | name=args.name, 185 | project=args.project, 186 | entity=args.entity, 187 | offline=False, 188 | resume="allow" if wandb_run_id else None, 189 | id=wandb_run_id, 190 | ) 191 | wandb_logger.watch(model, log="gradients", log_freq=100) 192 | wandb_logger.log_hyperparams(args) 193 | 194 | # lr logging 195 | lr_monitor = LearningRateMonitor(logging_interval="step") 196 | callbacks.append(lr_monitor) 197 | 198 | trainer = Trainer.from_argparse_args( 199 | args, 200 | logger=wandb_logger if args.wandb else None, 201 | callbacks=callbacks, 202 | enable_checkpointing=False, 203 | strategy=DDPStrategy(find_unused_parameters=False) 204 | if args.strategy == "ddp" 205 | else args.strategy, 206 | ) 207 | 208 | # fix for incompatibility with nvidia-dali and pytorch lightning 209 | # with dali 1.15 (this will be fixed on 1.16) 210 | # https://github.com/Lightning-AI/lightning/issues/12956 211 | try: 212 | from pytorch_lightning.loops import FitLoop 213 | 214 | class WorkaroundFitLoop(FitLoop): 215 | @property 216 | def prefetch_batches(self) -> int: 217 | return 1 218 | 219 | trainer.fit_loop = WorkaroundFitLoop( 220 | trainer.fit_loop.min_epochs, trainer.fit_loop.max_epochs 221 | ) 222 | except Exception:  # best-effort workaround; keep the default fit loop if it fails 223 | pass 224 | 225 | if args.data_format == "dali": 226 | trainer.fit(model, ckpt_path=ckpt_path, datamodule=dali_datamodule) 227 | else: 228 | trainer.fit(model, train_loader, val_loader, ckpt_path=ckpt_path) 229 | 230 | 231 | if __name__ == "__main__": 232 | main() 233 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/requirements.txt: -------------------------------------------------------------------------------- 1 | torch>=1.10.0 2 |
torchvision>=0.11.1 3 | einops 4 | pytorch-lightning==1.6.4 5 | torchmetrics==0.6.0 6 | lightning-bolts>=0.5.0 7 | tqdm 8 | wandb 9 | scipy 10 | timm 11 | scikit-learn -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/scripts/linear/simsiam_linear.sh: -------------------------------------------------------------------------------- 1 | python3 main_linear.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet18 \ 4 | --train_data_path datasets/imagenet100/train \ 5 | --val_data_path datasets/imagenet100/val \ 6 | --max_epochs 100 \ 7 | --devices 0 \ 8 | --accelerator gpu \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler step \ 12 | --lr 30.0 \ 13 | --lr_decay_steps 60 80 \ 14 | --weight_decay 0 \ 15 | --batch_size 256 \ 16 | --num_workers 10 \ 17 | --data_format dali \ 18 | --name simsiam-imagenet100-linear-eval \ 19 | --pretrained_feature_extractor $1 \ 20 | --project solo-learn \ 21 | --entity zhangq327 \ 22 | --wandb \ 23 | --save_checkpoint \ 24 | --auto_resume 25 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/scripts/pretrain/cifar/simsiam.sh: -------------------------------------------------------------------------------- 1 | python3 main_pretrain.py \ 2 | --dataset $1 \ 3 | --backbone resnet18 \ 4 | --train_data_path ./datasets \ 5 | --val_data_path ./datasets \ 6 | --max_epochs 1000 \ 7 | --devices 6 \ 8 | --accelerator gpu \ 9 | --precision 16 \ 10 | --optimizer sgd \ 11 | --scheduler warmup_cosine \ 12 | --lr 0.5 \ 13 | --classifier_lr 0.1 \ 14 | --weight_decay 1e-5 \ 15 | --batch_size 256 \ 16 | --num_workers 4 \ 17 | --crop_size 32 \ 18 | --brightness 0.4 \ 19 | --contrast 0.4 \ 20 | --saturation 0.4 \ 21 | --hue 0.1 \ 22 | --gaussian_prob 0.0 0.0 \ 23 | --crop_size 32 \ 24 | --num_crops_per_aug 1 1 \ 25 | --zero_init_residual \ 26 | --name simsiam-$1 \ 27 | --project solo-learn \ 28 | --entity zhangq327 \ 29 | 
--wandb \ 30 | --save_checkpoint \ 31 | --auto_resume \ 32 | --method simsiam \ 33 | --proj_hidden_dim 2048 \ 34 | --pred_hidden_dim 512 \ 35 | --proj_output_dim 2048 36 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/scripts/pretrain/imagenet100/simsiam.sh: -------------------------------------------------------------------------------- 1 | CUDA_VISIBLE_DEVICES=5 python3 main_pretrain.py \ 2 | --dataset imagenet100 \ 3 | --backbone resnet50 \ 4 | --train_data_path datasets/imagenet100/train \ 5 | --val_data_path datasets/imagenet100/val \ 6 | --max_epochs 400 \ 7 | --devices 0 \ 8 | --accelerator gpu \ 9 | --strategy ddp \ 10 | --sync_batchnorm \ 11 | --precision 16 \ 12 | --optimizer sgd \ 13 | --scheduler warmup_cosine \ 14 | --lr 0.5 \ 15 | --classifier_lr 0.1 \ 16 | --weight_decay 1e-5 \ 17 | --batch_size 128 \ 18 | --num_workers 4 \ 19 | --brightness 0.4 \ 20 | --contrast 0.4 \ 21 | --saturation 0.4 \ 22 | --hue 0.1 \ 23 | --num_crops_per_aug 2 \ 24 | --zero_init_residual \ 25 | --name simsiam-400ep-imagenet100 \ 26 | --entity zhangq327 \ 27 | --project solo-learn \ 28 | --wandb \ 29 | --save_checkpoint \ 30 | --auto_resume \ 31 | --method simsiam \ 32 | --proj_hidden_dim 2048 \ 33 | --pred_hidden_dim 512 \ 34 | --proj_output_dim 2048 \ 35 | --data_format image_folder 36 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/scripts/utils/convert_imgfolder_to_h5.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team.
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import argparse 21 | import os 22 | 23 | import h5py 24 | import numpy as np 25 | from tqdm import tqdm 26 | 27 | 28 | def convert_imgfolder_to_h5(folder_path: str, h5_path: str): 29 | """Converts image folder to a h5 dataset. 30 | 31 | Args: 32 | folder_path (str): path to the image folder. 33 | h5_path (str): output path of the h5 file. 
34 | """ 35 | 36 | with h5py.File(h5_path, "w") as h5: 37 | classes = os.listdir(folder_path) 38 | for class_name in tqdm(classes, desc="Processing classes"): 39 | cur_folder = os.path.join(folder_path, class_name) 40 | class_group = h5.create_group(class_name) 41 | for i, img_name in enumerate(os.listdir(cur_folder)): 42 | with open(os.path.join(cur_folder, img_name), "rb") as fid_img: 43 | binary_data = fid_img.read() 44 | data = np.frombuffer(binary_data, dtype="uint8") 45 | class_group.create_dataset( 46 | img_name, 47 | data=data, 48 | shape=data.shape, 49 | compression="gzip", 50 | compression_opts=9, 51 | ) 52 | 53 | 54 | if __name__ == "__main__": 55 | parser = argparse.ArgumentParser() 56 | parser.add_argument("--folder_path", type=str, required=True) 57 | parser.add_argument("--h5_path", type=str, required=True) 58 | args = parser.parse_args() 59 | convert_imgfolder_to_h5(args.folder_path, args.h5_path) 60 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from setuptools import find_packages, setup 21 | 22 | KW = ["artificial intelligence", "deep learning", "unsupervised learning", "contrastive learning"] 23 | 24 | 25 | EXTRA_REQUIREMENTS = { 26 | "dali": ["nvidia-dali-cuda110"], 27 | "umap": ["matplotlib", "seaborn", "pandas", "umap-learn"], 28 | "h5": ["h5py"], 29 | } 30 | 31 | 32 | def parse_requirements(path): 33 | with open(path) as f: 34 | requirements = [p.strip().split()[-1] for p in f.readlines()] 35 | return requirements 36 | 37 | 38 | setup( 39 | name="solo-learn", 40 | packages=find_packages(exclude=["bash_files", "docs", "downstream", "tests", "zoo"]), 41 | version="1.0.3", 42 | license="MIT", 43 | author="solo-learn development team", 44 | author_email="vturrisi@gmail.com, enrico.fini@gmail.com", 45 | url="https://github.com/vturrisi/solo-learn", 46 | keywords=KW, 47 | install_requires=[ 48 | "torch>=1.10.0", 49 | "torchvision>=0.11.1", 50 | "einops", 51 | "pytorch-lightning==1.6.4", 52 | "torchmetrics==0.6.0", 53 | "lightning-bolts>=0.5.0", 54 | "tqdm", 55 | "wandb", 56 | "scipy", 57 | "timm", 58 | "scikit-learn", 59 | ], 60 | extras_require=EXTRA_REQUIREMENTS, 61 | dependency_links=["https://developer.download.nvidia.com/compute/redist"], 62 | classifiers=[ 63 | "Programming Language :: Python :: 3.8", 64 | "License :: OSI Approved :: MIT License", 65 | "Operating System :: OS Independent", 66 | ], 67 | include_package_data=True, 68 | zip_safe=False, 69 | ) 70 | 
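`setup.py` defines `parse_requirements` but never calls it; `install_requires` repeats the contents of `requirements.txt` by hand. If you wanted to derive that list from the file instead, a minimal sketch could look like the one below. It keeps the original last-token rule (`p.strip().split()[-1]`) but, as an addition not present in the original, skips blank lines and `#` comments: on a blank line the original's `split()` returns an empty list and `[-1]` raises `IndexError`, and a comment line would be turned into a bogus requirement. The sample file contents are taken from `requirements.txt`; whether to wire the result into `setup(install_requires=...)` is an assumption you should check against your own packaging setup.

```python
import os
import tempfile


def parse_requirements(path):
    """Return one requirement specifier per meaningful line of `path`.

    Same last-token rule as setup.py's parse_requirements, but blank lines
    and '#' comment lines are skipped instead of raising or being kept.
    """
    requirements = []
    with open(path) as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue  # nothing to install on this line
            requirements.append(line.split()[-1])
    return requirements


# Usage: a small requirements file with a comment and a blank line
with tempfile.TemporaryDirectory() as tmp:
    req_path = os.path.join(tmp, "requirements.txt")
    with open(req_path, "w") as f:
        f.write("# core deps\ntorch>=1.10.0\ntorchvision>=0.11.1\n\npytorch-lightning==1.6.4\n")
    reqs = parse_requirements(req_path)

print(reqs)  # ['torch>=1.10.0', 'torchvision>=0.11.1', 'pytorch-lightning==1.6.4']
```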
-------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 | 20 | 21 | from solo import args, backbones, data, losses, methods, utils 22 | 23 | __all__ = ["args", "backbones", "data", "losses", "methods", "utils"] 24 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/args/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 | 20 | from solo.args import dataset, setup, utils 21 | 22 | __all__ = ["dataset", "setup", "utils"] 23 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/args/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/args/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/args/__pycache__/dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/args/__pycache__/dataset.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/args/__pycache__/setup.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/args/__pycache__/setup.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/args/__pycache__/utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/args/__pycache__/utils.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/args/dataset.py: 
-------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from argparse import ArgumentParser 21 | from pathlib import Path 22 | 23 | from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD 24 | 25 | 26 | def dataset_args(parser: ArgumentParser): 27 | """Adds dataset-related arguments to a parser. 28 | 29 | Args: 30 | parser (ArgumentParser): parser to add dataset args to. 
31 | """ 32 | 33 | SUPPORTED_DATASETS = [ 34 | "cifar10", 35 | "cifar100", 36 | "stl10", 37 | "imagenet", 38 | "imagenet100", 39 | "custom", 40 | ] 41 | 42 | parser.add_argument("--dataset", choices=SUPPORTED_DATASETS, type=str, required=True) 43 | 44 | # dataset path 45 | parser.add_argument("--train_data_path", type=Path, required=True) 46 | parser.add_argument("--val_data_path", type=Path, default=None) 47 | parser.add_argument( 48 | "--data_format", default="image_folder", choices=["image_folder", "dali", "h5"] 49 | ) 50 | 51 | # percentage of data used from training, leave -1.0 to use all data available 52 | parser.add_argument("--data_fraction", default=-1.0, type=float) 53 | 54 | 55 | def augmentations_args(parser: ArgumentParser): 56 | """Adds augmentation-related arguments to a parser. 57 | 58 | Args: 59 | parser (ArgumentParser): parser to add augmentation args to. 60 | """ 61 | 62 | # cropping 63 | parser.add_argument("--num_crops_per_aug", type=int, default=[2], nargs="+") 64 | 65 | # color jitter 66 | parser.add_argument("--brightness", type=float, required=True, nargs="+") 67 | parser.add_argument("--contrast", type=float, required=True, nargs="+") 68 | parser.add_argument("--saturation", type=float, required=True, nargs="+") 69 | parser.add_argument("--hue", type=float, required=True, nargs="+") 70 | parser.add_argument("--color_jitter_prob", type=float, default=[0.8], nargs="+") 71 | 72 | # other augmentation probabilities 73 | parser.add_argument("--gray_scale_prob", type=float, default=[0.2], nargs="+") 74 | parser.add_argument("--horizontal_flip_prob", type=float, default=[0.5], nargs="+") 75 | parser.add_argument("--gaussian_prob", type=float, default=[0.5], nargs="+") 76 | parser.add_argument("--solarization_prob", type=float, default=[0.0], nargs="+") 77 | parser.add_argument("--equalization_prob", type=float, default=[0.0], nargs="+") 78 | 79 | # cropping 80 | parser.add_argument("--crop_size", type=int, default=[224], nargs="+") 81 | 
parser.add_argument("--min_scale", type=float, default=[0.08], nargs="+") 82 | parser.add_argument("--max_scale", type=float, default=[1.0], nargs="+") 83 | 84 | # debug 85 | parser.add_argument("--debug_augmentations", action="store_true") 86 | 87 | 88 | def linear_augmentations_args(parser: ArgumentParser): 89 | parser.add_argument("--crop_size", type=int, default=[224], nargs="+") 90 | 91 | 92 | def custom_dataset_args(parser: ArgumentParser): 93 | """Adds custom data-related arguments to a parser. 94 | 95 | Args: 96 | parser (ArgumentParser): parser to add custom dataset args to. 97 | """ 98 | 99 | # custom dataset only 100 | parser.add_argument("--no_labels", action="store_true") 101 | 102 | # for custom dataset 103 | parser.add_argument("--mean", type=float, default=IMAGENET_DEFAULT_MEAN, nargs="+") 104 | parser.add_argument("--std", type=float, default=IMAGENET_DEFAULT_STD, nargs="+") 105 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/args/setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software.
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import argparse 21 | 22 | import pytorch_lightning as pl 23 | from solo.args.dataset import ( 24 | augmentations_args, 25 | custom_dataset_args, 26 | dataset_args, 27 | linear_augmentations_args, 28 | ) 29 | from solo.args.utils import additional_setup_linear, additional_setup_pretrain 30 | from solo.methods import METHODS 31 | from solo.utils.auto_resumer import AutoResumer 32 | from solo.utils.checkpointer import Checkpointer 33 | 34 | try: 35 | from solo.utils.auto_umap import AutoUMAP 36 | except ImportError: 37 | _umap_available = False 38 | else: 39 | _umap_available = True 40 | 41 | try: 42 | from solo.data.dali_dataloader import ClassificationDALIDataModule, PretrainDALIDataModule 43 | except ImportError: 44 | _dali_available = False 45 | else: 46 | _dali_available = True 47 | 48 | 49 | def parse_args_pretrain() -> argparse.Namespace: 50 | """Parses dataset, augmentation, pytorch lightning, model specific and additional args. 51 | 52 | First adds shared args such as dataset, augmentation and pytorch lightning args, then pulls the 53 | model name from the command and proceeds to add model specific args from the desired class. If 54 | --save_checkpoint is set, it adds checkpointer args (and AutoUMAP, AutoResumer or DALI args when requested). Finally, adds additional non-user given parameters. 55 | 56 | Returns: 57 | argparse.Namespace: a namespace containing all args needed for pretraining.
58 | """ 59 | 60 | parser = argparse.ArgumentParser() 61 | 62 | # add shared arguments 63 | dataset_args(parser) 64 | augmentations_args(parser) 65 | custom_dataset_args(parser) 66 | 67 | # add pytorch lightning trainer args 68 | parser = pl.Trainer.add_argparse_args(parser) 69 | 70 | # add method-specific arguments 71 | parser.add_argument("--method", type=str) 72 | 73 | # THIS LINE IS KEY TO PULL THE MODEL NAME 74 | temp_args, _ = parser.parse_known_args() 75 | 76 | # add model specific args 77 | parser = METHODS[temp_args.method].add_model_specific_args(parser) 78 | 79 | # add auto checkpoint/umap args 80 | parser.add_argument("--save_checkpoint", action="store_true") 81 | parser.add_argument("--auto_umap", action="store_true") 82 | parser.add_argument("--auto_resume", action="store_true") 83 | temp_args, _ = parser.parse_known_args() 84 | 85 | # optionally add checkpointer and AutoUMAP args 86 | if temp_args.save_checkpoint: 87 | parser = Checkpointer.add_checkpointer_args(parser) 88 | 89 | if _umap_available and temp_args.auto_umap: 90 | parser = AutoUMAP.add_auto_umap_args(parser) 91 | 92 | if temp_args.auto_resume: 93 | parser = AutoResumer.add_autoresumer_args(parser) 94 | 95 | if _dali_available and temp_args.data_format == "dali": 96 | parser = PretrainDALIDataModule.add_dali_args(parser) 97 | 98 | # parse args 99 | args = parser.parse_args() 100 | 101 | # prepare arguments with additional setup 102 | additional_setup_pretrain(args) 103 | 104 | return args 105 | 106 | 107 | def parse_args_linear() -> argparse.Namespace: 108 | """Parses feature extractor, dataset, pytorch lightning, linear eval specific and additional args. 109 | 110 | First adds an arg for the pretrained feature extractor, then adds dataset, pytorch lightning 111 | and linear eval specific args. If wandb is enabled, it adds checkpointer args. Finally, adds 112 | additional non-user given parameters. 
113 | 114 | Returns: 115 | argparse.Namespace: a namespace containing all args needed for pretraining. 116 | """ 117 | 118 | parser = argparse.ArgumentParser() 119 | 120 | parser.add_argument("--pretrained_feature_extractor", type=str, required=True) 121 | parser.add_argument("--pretrain_method", type=str, default=None) 122 | # add support for auto augment (randaug by default) 123 | parser.add_argument("--auto_augment", action="store_true") 124 | # for mixup, cutmix and drop_path (for transformers) 125 | parser.add_argument("--label_smoothing", type=float, default=0.0) 126 | parser.add_argument("--mixup", type=float, default=0.0) 127 | parser.add_argument("--cutmix", type=float, default=0.0) 128 | 129 | # just for transformers 130 | parser.add_argument("--drop_path", type=float, default=None) 131 | # type of pooling to use, either cls token or global average 132 | parser.add_argument("--global_pool", type=str, default="token", choices=["token", "avg"]) 133 | # different weight decay for each layer (using timm.optim.optim_factory) 134 | parser.add_argument("--layer_decay", type=float, default=0.0) 135 | 136 | # add shared arguments 137 | dataset_args(parser) 138 | linear_augmentations_args(parser) 139 | custom_dataset_args(parser) 140 | 141 | # add pytorch lightning trainer args 142 | parser = pl.Trainer.add_argparse_args(parser) 143 | 144 | # linear model 145 | parser = METHODS["linear"].add_model_specific_args(parser) 146 | 147 | # THIS LINE IS KEY TO PULL WANDB AND SAVE_CHECKPOINT 148 | parser.add_argument("--save_checkpoint", action="store_true") 149 | parser.add_argument("--auto_resume", action="store_true") 150 | temp_args, _ = parser.parse_known_args() 151 | 152 | # optionally add checkpointer 153 | if temp_args.save_checkpoint: 154 | parser = Checkpointer.add_checkpointer_args(parser) 155 | 156 | if temp_args.auto_resume: 157 | parser = AutoResumer.add_autoresumer_args(parser) 158 | 159 | if _dali_available and temp_args.data_format == "dali": 160 | parser = 
ClassificationDALIDataModule.add_dali_args(parser) 161 | 162 | # parse args 163 | args = parser.parse_args() 164 | additional_setup_linear(args) 165 | 166 | return args 167 | 168 | 169 | def parse_args_knn() -> argparse.Namespace: 170 | """Parses arguments for offline K-NN. 171 | 172 | Returns: 173 | argparse.Namespace: a namespace containing all args needed for pretraining. 174 | """ 175 | 176 | parser = argparse.ArgumentParser() 177 | 178 | # add knn args 179 | parser.add_argument("--pretrained_checkpoint_dir", type=str) 180 | parser.add_argument("--batch_size", type=int, default=16) 181 | parser.add_argument("--num_workers", type=int, default=10) 182 | parser.add_argument("--k", type=int, nargs="+") 183 | parser.add_argument("--temperature", type=float, nargs="+") 184 | parser.add_argument("--distance_function", type=str, nargs="+") 185 | parser.add_argument("--feature_type", type=str, nargs="+") 186 | 187 | # add shared arguments 188 | dataset_args(parser) 189 | custom_dataset_args(parser) 190 | 191 | # parse args 192 | args = parser.parse_args() 193 | 194 | return args 195 | 196 | 197 | def parse_args_umap() -> argparse.Namespace: 198 | """Parses arguments for offline UMAP. 199 | 200 | Returns: 201 | argparse.Namespace: a namespace containing all args needed for pretraining. 
202 | """ 203 | 204 | parser = argparse.ArgumentParser() 205 | 206 | # add knn args 207 | parser.add_argument("--pretrained_checkpoint_dir", type=str) 208 | parser.add_argument("--batch_size", type=int, default=16) 209 | parser.add_argument("--num_workers", type=int, default=10) 210 | 211 | # add shared arguments 212 | dataset_args(parser) 213 | custom_dataset_args(parser) 214 | 215 | # parse args 216 | args = parser.parse_args() 217 | 218 | return args 219 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/backbones/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 |
20 |
21 | from .resnet import resnet18, resnet50
22 |
23 | __all__ = [
24 |     "resnet18",
25 |     "resnet50",
26 | ]
27 |
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/backbones/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/backbones/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/backbones/resnet/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from .resnet import resnet18 as default_resnet18
21 | from .resnet import resnet50 as default_resnet50
22 |
23 |
24 | def resnet18(method, *args, **kwargs):
25 |     return default_resnet18(*args, **kwargs)
26 |
27 |
28 | def resnet50(method, *args, **kwargs):
29 |     return default_resnet50(*args, **kwargs)
30 |
31 |
32 | __all__ = ["resnet18", "resnet50"]
33 |
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/backbones/resnet/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/backbones/resnet/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/backbones/resnet/__pycache__/resnet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/backbones/resnet/__pycache__/resnet.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/backbones/resnet/resnet.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 | from torchvision.models import resnet18
21 | from torchvision.models import resnet50
22 |
23 | __all__ = ["resnet18", "resnet50"]
24 |
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 |
21 | from solo.data import classification_dataloader, h5_dataset, pretrain_dataloader
22 |
23 | __all__ = [
24 |     "classification_dataloader",
25 |     "h5_dataset",
26 |     "pretrain_dataloader",
27 | ]
28 |
29 |
30 | try:
31 |     from solo.data import dali_dataloader  # noqa: F401
32 | except ImportError:
33 |     pass
34 | else:
35 |     __all__.append("dali_dataloader")
36 |
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/data/__pycache__/classification_dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/data/__pycache__/classification_dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/data/__pycache__/dali_dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/data/__pycache__/dali_dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/data/__pycache__/h5_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/data/__pycache__/h5_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/data/__pycache__/pretrain_dataloader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/data/__pycache__/pretrain_dataloader.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/data/dataset_subset/imagenet100_classes.txt:
--------------------------------------------------------------------------------
1 | n02869837 n01749939 n02488291 n02107142 n13037406 n02091831 n04517823 n04589890 n03062245 n01773797 n01735189 n07831146 n07753275 n03085013 n04485082 n02105505 n01983481 n02788148 n03530642 n04435653 n02086910 n02859443 n13040303 n03594734 n02085620 n02099849 n01558993 n04493381 n02109047 n04111531 n02877765 n04429376 n02009229 n01978455 n02106550 n01820546 n01692333 n07714571 n02974003 n02114855 n03785016 n03764736 n03775546 n02087046 n07836838 n04099969 n04592741 n03891251 n02701002 n03379051 n02259212 n07715103 n03947888 n04026417 n02326432 n03637318 n01980166 n02113799 n02086240 n03903868 n02483362 n04127249 n02089973 n03017168 n02093428 n02804414 n02396427 n04418357 n02172182 n01729322 n02113978 n03787032 n02089867 n02119022 n03777754 n04238763 n02231487 n03032252 n02138441 n02104029 n03837869 n03494278 n04136333 n03794056 n03492542 n02018207 n04067472 n03930630 n03584829 n02123045 n04229816 n02100583 n03642806 n04336792 n03259280 n02116738 n02108089 n03424325 n01855672 n02090622
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/data/h5_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 |
21 | import io
22 | import os
23 | import logging
24 | from pathlib import Path
25 | from typing import Callable, Optional
26 |
27 | import h5py
28 | from PIL import Image
29 | from torch.utils.data import Dataset
30 | from tqdm import tqdm
31 |
32 |
33 | class H5Dataset(Dataset):
34 |     def __init__(
35 |         self,
36 |         dataset: str,
37 |         h5_path: str,
38 |         transform: Optional[Callable] = None,
39 |     ):
40 |         """H5 Dataset.
41 |         The dataset assumes that data is organized as:
42 |             "class_name"
43 |                 "img_name"
44 |                 "img_name"
45 |                 "img_name"
46 |             "class_name"
47 |                 "img_name"
48 |                 "img_name"
49 |                 "img_name"
50 |
51 |         Args:
52 |             dataset (str): dataset name.
53 |             h5_path (str): path of the h5 file.
54 |             transform (Callable): pipeline of transformations. Defaults to None.
55 |         """
56 |
57 |         self.h5_path = h5_path
58 |         self.h5_file = None
59 |         self.transform = transform
60 |
61 |         assert dataset in ["imagenet100", "imagenet"]
62 |
63 |         self._load_h5_data_info()
64 |
65 |         # filter if needed to avoid having a copy of imagenet100 data
66 |         if dataset == "imagenet100":
67 |             script_folder = Path(os.path.dirname(__file__))
68 |             classes_file = script_folder / "dataset_subset" / "imagenet100_classes.txt"
69 |             with open(classes_file, "r") as f:
70 |                 self.classes = f.readline().strip().split()
71 |             self.classes = sorted(self.classes)
72 |             self.class_to_idx = {cls_name: i for i, cls_name in enumerate(self.classes)}
73 |
74 |             class_set = set(self.classes)
75 |             new_data = []
76 |             for class_name, img_name, _ in self._data:
77 |                 if class_name in class_set:
78 |                     new_data.append((class_name, img_name, self.class_to_idx[class_name]))
79 |             if not new_data:
80 |                 logging.warning(
81 |                     "Skipped filtering. Tried to filter classes for imagenet100, "
82 |                     "but wasn't able to do so. Either make sure that you do not "
83 |                     "rely on the filtering, i.e. your h5 file is already filtered "
84 |                     "or make sure the class names are the default ones."
85 |                 )
86 |             else:
87 |                 self._data = new_data
88 |
89 |     def _load_h5_data_info(self):
90 |         self._data = []
91 |         h5_data_info_file = os.path.join(
92 |             os.path.expanduser("~"), os.path.basename(os.path.splitext(self.h5_path)[0]) + ".txt"
93 |         )
94 |         if not os.path.isfile(h5_data_info_file):
95 |             temp_h5_file = h5py.File(self.h5_path, "r")
96 |
97 |             # collect data from the h5 file directly
98 |             self.classes, self.class_to_idx = self._find_classes(temp_h5_file)
99 |             for class_name in tqdm(self.classes, desc="Collecting information about the h5 file"):
100 |                 y = self.class_to_idx[class_name]
101 |                 for img_name in temp_h5_file[class_name].keys():
102 |                     self._data.append((class_name, img_name, int(y)))
103 |
104 |             # save the info locally to speed up subsequent executions
105 |             with open(h5_data_info_file, "w") as f:
106 |                 for class_name, img_name, y in self._data:
107 |                     f.write(f"{class_name}/{img_name} {y}\n")
108 |         else:
109 |             # load data info file that was already generated by previous runs
110 |             with open(h5_data_info_file, "r") as f:
111 |                 for line in f:
112 |                     class_name_img, y = line.strip().split(" ")
113 |                     class_name, img_name = class_name_img.split("/")
114 |                     self._data.append((class_name, img_name, int(y)))
115 |
116 |     def _find_classes(self, h5_file: h5py.File):
117 |         classes = sorted(h5_file.keys())
118 |         class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}
119 |         return classes, class_to_idx
120 |
121 |     def _load_img(self, class_name: str, img: str):
122 |         img = self.h5_file[class_name][img][:]
123 |         img = Image.open(io.BytesIO(img)).convert("RGB")
124 |         return img
125 |
126 |     def __getitem__(self, index: int):
127 |         if self.h5_file is None:
128 |             self.h5_file = h5py.File(self.h5_path, "r")
129 |
130 |         class_name, img, y = self._data[index]
131 |
132 |         x = self._load_img(class_name, img)
133 |         if self.transform:
134 |             x = self.transform(x)
135 |
136 |         return x, y
137 |
138 |     def __len__(self):
139 |         return len(self._data)
140 |
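`H5Dataset` above does its label bookkeeping without touching image bytes. It caches one `class_name/img_name label` line per image in a `.txt` file in the home directory. For `imagenet100` it keeps only the subset classes and relabels them 0..K-1 in sorted class order.

The sketch below reproduces that bookkeeping with the standard library only, so it runs without h5py. The helper names `format_info_lines`, `parse_info_lines` and `filter_subset` are hypothetical and are not part of solo-learn. There is one behavioural difference: when no class matches, the original logs a warning and keeps the unfiltered data, whereas this sketch simply returns an empty list.

```python
from typing import Dict, Iterable, List, Tuple

# One record per image: (class_name, img_name, label), mirroring H5Dataset._data.
Record = Tuple[str, str, int]


def format_info_lines(records: Iterable[Record]) -> List[str]:
    """Serialise records as "class_name/img_name label" lines (the cache-file format)."""
    return [f"{class_name}/{img_name} {y}" for class_name, img_name, y in records]


def parse_info_lines(lines: Iterable[str]) -> List[Record]:
    """Inverse of format_info_lines; assumes names contain no '/' or ' '."""
    records = []
    for line in lines:
        class_name_img, y = line.strip().split(" ")
        class_name, img_name = class_name_img.split("/")
        records.append((class_name, img_name, int(y)))
    return records


def filter_subset(records: Iterable[Record], subset: Iterable[str]) -> List[Record]:
    """Keep only classes in `subset` and relabel them 0..K-1 in sorted class order."""
    classes = sorted(subset)
    class_to_idx: Dict[str, int] = {c: i for i, c in enumerate(classes)}
    return [(c, img, class_to_idx[c]) for c, img, _ in records if c in class_to_idx]


if __name__ == "__main__":
    records = [("n01", "a.JPEG", 0), ("n02", "b.JPEG", 1), ("n03", "c.JPEG", 2)]
    lines = format_info_lines(records)
    assert parse_info_lines(lines) == records  # the cache format round-trips
    print(filter_subset(records, ["n03", "n01"]))
    # [('n01', 'a.JPEG', 0), ('n03', 'c.JPEG', 1)]
```

Relabelling from the sorted subset, rather than keeping the original ImageNet indices, is what keeps the 100-class labels contiguous, so a 100-way classifier head can consume them directly.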
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 solo-learn development team.
2 |
3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of
4 | # this software and associated documentation files (the "Software"), to deal in
5 | # the Software without restriction, including without limitation the rights to use,
6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the
7 | # Software, and to permit persons to whom the Software is furnished to do so,
8 | # subject to the following conditions:
9 |
10 | # The above copyright notice and this permission notice shall be included in all copies
11 | # or substantial portions of the Software.
12 |
13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,
14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR
15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE
16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR
17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
18 | # DEALINGS IN THE SOFTWARE.
19 |
20 |
21 | from solo.losses.simsiam import simsiam_loss_func
22 |
23 |
24 | __all__ = [
25 |     "simsiam_loss_func",
26 | ]
27 |
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/barlow.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/barlow.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/byol.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/byol.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/deepclusterv2.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/deepclusterv2.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/dino.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/dino.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/mae.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/mae.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/mocov2plus.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/mocov2plus.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/mocov3.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/mocov3.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/nnclr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/nnclr.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/ressl.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/ressl.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/simclr.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/simclr.cpython-38.pyc
--------------------------------------------------------------------------------
/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/simsiam.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/simsiam.cpython-38.pyc
-------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/swav.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/swav.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/vibcreg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/vibcreg.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/vicreg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/vicreg.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/wmse.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/__pycache__/wmse.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/losses/simsiam.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development 
team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import torch 21 | import torch.nn.functional as F 22 | 23 | 24 | def simsiam_loss_func(p: torch.Tensor, z: torch.Tensor, simplified: bool = True) -> torch.Tensor: 25 | """Computes SimSiam's loss given a batch of predicted features p from view 1 and 26 | a batch of projected features z from view 2. 27 | 28 | Args: 29 | p (torch.Tensor): Tensor containing predicted features from view 1. 30 | z (torch.Tensor): Tensor containing projected features from view 2 (detached, so no gradient flows through it). 31 | simplified (bool): whether to use the faster computation, which gives the same result. 32 | 33 | Returns: 34 | torch.Tensor: SimSiam loss.
35 | """ 36 | 37 | if simplified: 38 | return -F.cosine_similarity(p, z.detach(), dim=-1).mean() 39 | 40 | p = F.normalize(p, dim=-1) 41 | z = F.normalize(z, dim=-1) 42 | 43 | return -(p * z.detach()).sum(dim=1).mean() 44 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 | 20 | from solo.methods.base import BaseMethod 21 | from solo.methods.linear import LinearModel 22 | from solo.methods.simsiam import SimSiam 23 | 24 | METHODS = { 25 | # base classes 26 | "base": BaseMethod, 27 | "linear": LinearModel, 28 | # methods 29 | "simsiam": SimSiam, 30 | } 31 | __all__ = [ 32 | "BaseMethod", 33 | "LinearModel", 34 | "SimSiam", 35 | ] 36 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/barlow_twins.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/barlow_twins.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/base.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/base.cpython-38.pyc
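`simsiam_loss_func` in `solo/losses/simsiam.py` offers two paths: the default `simplified=True` branch takes the mean negative cosine similarity directly, while the other branch L2-normalizes both inputs and takes the mean negative dot product. The framework-free sketch below uses plain Python lists in place of tensors (the helper names are hypothetical, not part of solo-learn). It checks that both paths give the same value and shows the symmetric two-view loss that `SimSiam.training_step` assembles from them. Autograd is not modelled, so the stop-gradient that `z.detach()` applies to the target branch appears only as a comment.

```python
import math
from typing import Sequence

Batch = Sequence[Sequence[float]]


def _norm(v: Sequence[float]) -> float:
    # Euclidean length of one feature vector
    return math.sqrt(sum(x * x for x in v))


def neg_cos_sim(p: Batch, z: Batch) -> float:
    """Mirror of the simplified branch: -mean_i cos(p_i, z_i).

    In the real loss z is detached (stop-gradient); plain floats carry no gradient.
    """
    sims = [
        sum(a * b for a, b in zip(pi, zi)) / (_norm(pi) * _norm(zi))
        for pi, zi in zip(p, z)
    ]
    return -sum(sims) / len(sims)


def neg_normalized_dot(p: Batch, z: Batch) -> float:
    """Mirror of the non-simplified branch: normalize rows, then -mean_i <p_i, z_i>."""
    dots = []
    for pi, zi in zip(p, z):
        pn = [a / _norm(pi) for a in pi]
        zn = [b / _norm(zi) for b in zi]
        dots.append(sum(a * b for a, b in zip(pn, zn)))
    return -sum(dots) / len(dots)


def symmetric_loss(p1: Batch, p2: Batch, z1: Batch, z2: Batch) -> float:
    # each view's prediction is matched against the other view's projection
    return neg_cos_sim(p1, z2) / 2 + neg_cos_sim(p2, z1) / 2
```

The loss is bounded in [-1, 1] and reaches -1 when every prediction points in the same direction as its target.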
-------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/byol.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/byol.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/deepclusterv2.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/deepclusterv2.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/dino.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/dino.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/linear.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/linear.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/mae.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/mae.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/mocov2plus.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/mocov2plus.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/mocov3.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/mocov3.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/nnbyol.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/nnbyol.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/nnclr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/nnclr.cpython-38.pyc 
-------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/nnsiam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/nnsiam.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/ressl.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/ressl.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/simclr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/simclr.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/simsiam.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/simsiam.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/supcon.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/supcon.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/swav.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/swav.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/vibcreg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/vibcreg.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/vicreg.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/vicreg.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/wmse.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/__pycache__/wmse.cpython-38.pyc 
-------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/simsiam.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import argparse 21 | from typing import Any, Dict, List, Sequence 22 | 23 | import torch 24 | import torch.nn as nn 25 | import torch.nn.functional as F 26 | from solo.losses.simsiam import simsiam_loss_func 27 | from solo.methods.base import BaseMethod 28 | import copy 29 | 30 | class SimSiam(BaseMethod): 31 | def __init__( 32 | self, 33 | proj_output_dim: int, 34 | proj_hidden_dim: int, 35 | pred_hidden_dim: int, 36 | **kwargs, 37 | ): 38 | """Implements SimSiam (https://arxiv.org/abs/2011.10566). 39 | 40 | Args: 41 | proj_output_dim (int): number of dimensions of projected features. 
42 | proj_hidden_dim (int): number of neurons of the hidden layers of the projector. 43 | pred_hidden_dim (int): number of neurons of the hidden layers of the predictor. 44 | """ 45 | 46 | super().__init__(**kwargs) 47 | 48 | # projector 49 | self.projector = nn.Sequential( 50 | nn.Linear(self.features_dim, proj_hidden_dim, bias=False), 51 | nn.BatchNorm1d(proj_hidden_dim), 52 | nn.ReLU(), 53 | nn.Linear(proj_hidden_dim, proj_hidden_dim, bias=False), 54 | nn.BatchNorm1d(proj_hidden_dim), 55 | nn.ReLU(), 56 | nn.Linear(proj_hidden_dim, proj_output_dim), 57 | nn.BatchNorm1d(proj_output_dim, affine=False), 58 | ) 59 | self.projector[6].bias.requires_grad = False # hack: disable this bias, as the layer is followed by BN 60 | 61 | # predictor 62 | self.predictor = nn.Sequential( 63 | nn.Linear(proj_output_dim, pred_hidden_dim, bias=False), 64 | nn.BatchNorm1d(pred_hidden_dim), 65 | nn.ReLU(), 66 | nn.Linear(pred_hidden_dim, proj_output_dim), 67 | ) 68 | 69 | # per-sample aggregation queue: one row per training image (at most 130000), one column per projected feature 70 | self.register_buffer("agg_queue", torch.randn(130000, proj_output_dim)) 71 | self.agg_queue = nn.functional.normalize(self.agg_queue, dim=1) 72 | self.register_buffer("temp", torch.randn(130000, proj_output_dim)) 73 | self.temp = nn.functional.normalize(self.temp, dim=1) 74 | 75 | 76 | 77 | self.epoch_counter = 0 # completed epochs; selects the loss variant in training_step 78 | @staticmethod 79 | def add_model_specific_args(parent_parser: argparse.ArgumentParser) -> argparse.ArgumentParser: 80 | parent_parser = super(SimSiam, SimSiam).add_model_specific_args(parent_parser) 81 | parser = parent_parser.add_argument_group("simsiam") 82 | 83 | # projector 84 | parser.add_argument("--proj_output_dim", type=int, default=128) 85 | parser.add_argument("--proj_hidden_dim", type=int, default=2048) 86 | 87 | # predictor 88 | parser.add_argument("--pred_hidden_dim", type=int, default=512) 89 | return parent_parser 90 | 91 | @property 92 | def 
learnable_params(self) -> List[dict]: 93 | """Adds projector and predictor parameters to the parent's learnable parameters.
94 | 95 | Returns: 96 | List[dict]: list of learnable parameters. 97 | """ 98 | 99 | extra_learnable_params: List[dict] = [ 100 | {"params": self.projector.parameters()}, 101 | {"params": self.predictor.parameters(), "static_lr": True}, 102 | ] 103 | return super().learnable_params + extra_learnable_params 104 | 105 | def forward(self, X: torch.Tensor) -> Dict[str, Any]: 106 | """Performs the forward pass of the backbone, the projector and the predictor. 107 | 108 | Args: 109 | X (torch.Tensor): a batch of images in the tensor format. 110 | 111 | Returns: 112 | Dict[str, Any]: 113 | a dict containing the outputs of the parent 114 | and the projected and predicted features. 115 | """ 116 | 117 | out = super().forward(X) 118 | z = self.projector(out["feats"]) 119 | p = self.predictor(z) 120 | out.update({"z": z, "p": p}) 121 | return out 122 | 123 | def training_step(self, batch: Sequence[Any], batch_idx: int) -> torch.Tensor: 124 | """Training step for SimSiam reusing BaseMethod training step. 125 | 126 | Args: 127 | batch (Sequence[Any]): a batch of data in the format of [img_indexes, [X], Y], where 128 | [X] is a list of size num_crops containing batches of images. 129 | batch_idx (int): index of the batch. 130 | 131 | Returns: 132 | torch.Tensor: total loss composed of SimSiam loss and classification loss. 
133 | """ 134 | 135 | out = super().training_step(batch, batch_idx) 136 | index = batch[0] # dataset indexes of the samples in this batch 137 | class_loss = out["loss"] 138 | z1, z2 = out["z"] 139 | p1, p2 = out["p"] 140 | 141 | # store the current view-2 projections; they replace the queue at the end of the epoch 142 | self.temp[index] = z2.detach().to(self.temp.dtype) 143 | # aggregate both projections 50/50 with the queued projection of the same sample 144 | queued = self.agg_queue[index].to(z1.dtype) 145 | z3 = 0.5 * z2 + 0.5 * queued 146 | z4 = 0.5 * z1 + 0.5 * queued 147 | 148 | # ------- negative cosine similarity loss ------- 149 | # plain SimSiam targets up to epoch 50, aggregated targets afterwards 150 | if self.epoch_counter <= 50: 151 | neg_cos_sim = simsiam_loss_func(p1, z2) / 2 + simsiam_loss_func(p2, z1) / 2 152 | else: 153 | neg_cos_sim = simsiam_loss_func(p1, z3) / 2 + simsiam_loss_func(p2, z4) / 2 154 | 155 | # calculate std of features 156 | z1_std = F.normalize(z1, dim=-1).std(dim=0).mean() 157 | z2_std = F.normalize(z2, dim=-1).std(dim=0).mean() 158 | z_std = (z1_std + z2_std) / 2 159 | 160 | metrics = { 161 | "train_neg_cos_sim": neg_cos_sim, 162 | "train_z_std": z_std, 163 | } 164 | self.log_dict(metrics, on_epoch=True, sync_dist=True) 165 | 166 | return neg_cos_sim + class_loss 167 | 168 | def training_epoch_end(self, outputs): 169 | # promote this epoch's projections to the aggregation queue (in place, without reallocating the buffer) 170 | self.agg_queue.copy_(self.temp) 171 | self.epoch_counter += 1 172 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team.
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 | 20 | from solo.utils import ( 21 | checkpointer, 22 | knn, 23 | lars, 24 | metrics, 25 | misc, 26 | momentum, 27 | sinkhorn_knopp, 28 | ) 29 | 30 | __all__ = [ 31 | "checkpointer", 32 | "knn", 33 | "misc", 34 | "lars", 35 | "metrics", 36 | "momentum", 37 | "sinkhorn_knopp", 38 | ] 39 | 40 | 41 | try: 42 | from solo.utils import auto_umap # noqa: F401 43 | except ImportError: 44 | pass 45 | else: 46 | __all__.append("auto_umap") 47 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/auto_resumer.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/auto_resumer.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/auto_umap.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/auto_umap.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/checkpointer.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/checkpointer.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/kmeans.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/kmeans.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/knn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/knn.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/lars.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/lars.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/misc.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/misc.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/momentum.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/momentum.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/sinkhorn_knopp.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/sinkhorn_knopp.cpython-38.pyc -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/whitening.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/PKU-ML/Message-Passing-Contrastive-Learning/6e3f1c82889ef668df0a3286b292daaf7b07be65/MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/__pycache__/whitening.cpython-38.pyc 
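`SimSiam.training_step` in `solo/methods/simsiam.py` implements the multi-stage aggregation. At each step it records every sample's current view-2 projection in `temp` and mixes both projections 50/50 with that sample's row in `agg_queue`. `training_epoch_end` then promotes `temp` to `agg_queue`. For the first 51 epochs (`epoch_counter <= 50`) the plain SimSiam targets are used instead. The sketch below reproduces only this bookkeeping on plain Python lists. The class name `AggregationQueue` and its methods are hypothetical. As an assumption of the sketch, the queue starts at zeros rather than the random unit vectors of the original; warm-up epochs never read it before it is replaced.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


class AggregationQueue:
    """Hypothetical, framework-free mirror of the agg_queue/temp bookkeeping."""

    def __init__(self, num_samples: int, dim: int, warmup_epochs: int = 50, weight: float = 0.5):
        # assumption: zeros instead of random unit vectors as the initial queue
        self.queue: List[Vector] = [[0.0] * dim for _ in range(num_samples)]
        self.pending: List[Vector] = [[0.0] * dim for _ in range(num_samples)]
        self.warmup_epochs = warmup_epochs
        self.weight = weight  # weight of the current projection; the original uses 0.5
        self.epoch = 0

    def _mix(self, z: Sequence[float], idx: int) -> Vector:
        # convex combination of the current projection and the queued one
        return [self.weight * a + (1.0 - self.weight) * b for a, b in zip(z, self.queue[idx])]

    def targets(
        self, indexes: Sequence[int], z1: Sequence[Sequence[float]], z2: Sequence[Sequence[float]]
    ) -> Tuple[List[Vector], List[Vector]]:
        """Return (target for p1, target for p2) for one batch."""
        for idx, z in zip(indexes, z2):
            self.pending[idx] = list(z)  # like self.temp[index] = z2.detach()
        if self.epoch <= self.warmup_epochs:
            # warm-up: plain SimSiam targets (p1 vs z2, p2 vs z1)
            return [list(z) for z in z2], [list(z) for z in z1]
        z3 = [self._mix(z, idx) for idx, z in zip(indexes, z2)]
        z4 = [self._mix(z, idx) for idx, z in zip(indexes, z1)]
        return z3, z4

    def end_epoch(self) -> None:
        # like training_epoch_end: the stored projections become the new queue
        self.queue = [list(row) for row in self.pending]
        self.epoch += 1
```

With `warmup_epochs=50` the switch matches `epoch_counter <= 50`. As in the original, a sample not visited during an epoch keeps its previous stored row, because `pending` (like `temp`) is never reset.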
-------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/auto_resumer.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from argparse import ArgumentParser, Namespace 4 | from collections import namedtuple 5 | from datetime import datetime, timedelta 6 | from pathlib import Path 7 | from typing import Union 8 | 9 | Checkpoint = namedtuple("Checkpoint", ["creation_time", "args", "checkpoint"]) 10 | 11 | 12 | class AutoResumer: 13 | SHOULD_MATCH = [ 14 | "batch_size", 15 | "weight_decay", 16 | "lr", 17 | "dataset", 18 | "backbone", 19 | "max_epochs", 20 | "method", 21 | "name", 22 | "project", 23 | "entity", 24 | "pretrained_feature_extractor", 25 | ] 26 | 27 | def __init__( 28 | self, 29 | checkpoint_dir: Union[str, Path] = Path("trained_models"), 30 | max_hours: int = 36, 31 | ): 32 | """Autoresumer object that automatically tries to find a checkpoint 33 | that is at most max_hours old. 34 | 35 | Args: 36 | checkpoint_dir (Union[str, Path], optional): base directory to search for checkpoints. 37 | Defaults to "trained_models". 38 | max_hours (int): maximum elapsed hours to consider checkpoint as valid. 39 | """ 40 | 41 | self.checkpoint_dir = checkpoint_dir 42 | self.max_hours = timedelta(hours=max_hours) 43 | 44 | @staticmethod 45 | def add_autoresumer_args(parent_parser: ArgumentParser): 46 | """Adds user-required arguments to a parser. 47 | 48 | Args: 49 | parent_parser (ArgumentParser): parser to add new args to. 50 | """ 51 | 52 | parser = parent_parser.add_argument_group("autoresumer") 53 | parser.add_argument("--auto_resumer_max_hours", default=36, type=int) 54 | return parent_parser 55 | 56 | def find_checkpoint(self, args: Namespace): 57 | """Finds a valid checkpoint that matches the arguments. 58 | 59 | Args: 60 | args (Namespace): namespace object containing all settings of the model.
61 | """ 62 | 63 | current_time = datetime.now() 64 | 65 | candidates = [] 66 | for rootdir, _, files in os.walk(self.checkpoint_dir): 67 | rootdir = Path(rootdir) 68 | if files: 69 | # skip directories that contain no .ckpt file 70 | try: 71 | checkpoint_file = [rootdir / f for f in files if f.endswith(".ckpt")][0] 72 | except IndexError: 73 | continue 74 | 75 | creation_time = datetime.fromtimestamp(os.path.getctime(checkpoint_file)) 76 | if current_time - creation_time < self.max_hours: 77 | ck = Checkpoint( 78 | creation_time=creation_time, 79 | args=rootdir / "args.json", 80 | checkpoint=checkpoint_file, 81 | ) 82 | candidates.append(ck) 83 | 84 | if candidates: 85 | # sort by most recent 86 | candidates = sorted(candidates, key=lambda ck: ck.creation_time, reverse=True) 87 | 88 | for candidate in candidates: 89 | candidate_args = Namespace(**json.loads(candidate.args.read_text())) 90 | if all( 91 | getattr(candidate_args, param, None) == getattr(args, param, None) 92 | for param in AutoResumer.SHOULD_MATCH 93 | ): 94 | wandb_run_id = getattr(candidate_args, "wandb_run_id", None) 95 | return candidate.checkpoint, wandb_run_id 96 | 97 | return None, None 98 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/checkpointer.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team.
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 
19 | 20 | import json 21 | import os 22 | import random 23 | import string 24 | import time 25 | from argparse import ArgumentParser, Namespace 26 | from pathlib import Path 27 | from typing import Optional, Union 28 | 29 | import pytorch_lightning as pl 30 | from pytorch_lightning.callbacks import Callback 31 | 32 | 33 | def random_string(letter_count=4, digit_count=4): 34 | tmp_random = random.Random(time.time()) 35 | rand_str = "".join(tmp_random.choice(string.ascii_lowercase) for _ in range(letter_count)) 36 | rand_str += "".join(tmp_random.choice(string.digits) for _ in range(digit_count)) 37 | rand_str = list(rand_str) 38 | tmp_random.shuffle(rand_str) 39 | return "".join(rand_str) 40 | 41 | 42 | class Checkpointer(Callback): 43 | def __init__( 44 | self, 45 | args: Namespace, 46 | logdir: Union[str, Path] = Path("trained_models"), 47 | frequency: int = 1, 48 | keep_previous_checkpoints: bool = False, 49 | ): 50 | """Custom checkpointer callback that stores checkpoints in an easier-to-access directory layout. 51 | 52 | Args: 53 | args (Namespace): namespace object containing at least a name attribute. 54 | logdir (Union[str, Path], optional): base directory to store checkpoints. 55 | Defaults to "trained_models". 56 | frequency (int, optional): number of epochs between each checkpoint. Defaults to 1. 57 | keep_previous_checkpoints (bool, optional): whether to keep previous checkpoints or not. 58 | Defaults to False. 59 | """ 60 | 61 | super().__init__() 62 | 63 | self.args = args 64 | self.logdir = Path(logdir) 65 | self.frequency = frequency 66 | self.keep_previous_checkpoints = keep_previous_checkpoints 67 | 68 | @staticmethod 69 | def add_checkpointer_args(parent_parser: ArgumentParser): 70 | """Adds user-required arguments to a parser. 71 | 72 | Args: 73 | parent_parser (ArgumentParser): parser to add new args to.
74 | """ 75 | 76 | parser = parent_parser.add_argument_group("checkpointer") 77 | parser.add_argument("--checkpoint_dir", default=Path("trained_models"), type=Path) 78 | parser.add_argument("--checkpoint_frequency", default=1, type=int) 79 | return parent_parser 80 | 81 | def initial_setup(self, trainer: pl.Trainer): 82 | """Creates the directories and does the initial setup needed. 83 | 84 | Args: 85 | trainer (pl.Trainer): pytorch lightning trainer object. 86 | """ 87 | 88 | if trainer.logger is None: 89 | if self.logdir.exists(): 90 | existing_versions = set(os.listdir(self.logdir)) 91 | else: 92 | existing_versions = set() 93 | version = "offline-" + random_string() 94 | while version in existing_versions: 95 | version = "offline-" + random_string() 96 | else: 97 | version = str(trainer.logger.version) 98 | self.wandb_run_id = version 99 | if version is not None: 100 | self.path = self.logdir / version 101 | self.ckpt_placeholder = f"{self.args.name}-{version}" + "-ep={}.ckpt" 102 | else: 103 | self.path = self.logdir 104 | self.ckpt_placeholder = f"{self.args.name}" + "-ep={}.ckpt" 105 | self.last_ckpt: Optional[Path] = None 106 | 107 | # create logging dirs 108 | if trainer.is_global_zero: 109 | os.makedirs(self.path, exist_ok=True) 110 | 111 | def save_args(self, trainer: pl.Trainer): 112 | """Stores arguments into a json file. 113 | 114 | Args: 115 | trainer (pl.Trainer): pytorch lightning trainer object. 116 | """ 117 | 118 | if trainer.is_global_zero: 119 | args = vars(self.args) 120 | args["wandb_run_id"] = getattr(self, "wandb_run_id", None) 121 | json_path = self.path / "args.json" 122 | json_path.write_text(json.dumps(args, default=lambda o: "")) 123 | 124 | def save(self, trainer: pl.Trainer): 125 | """Saves current checkpoint. 126 | 127 | Args: 128 | trainer (pl.Trainer): pytorch lightning trainer object.
129 | """ 130 | 131 | if trainer.is_global_zero and not trainer.sanity_checking: 132 | epoch = trainer.current_epoch # type: ignore 133 | ckpt = self.path / self.ckpt_placeholder.format(epoch) 134 | trainer.save_checkpoint(ckpt) 135 | 136 | if self.last_ckpt and self.last_ckpt != ckpt and not self.keep_previous_checkpoints: 137 | os.remove(self.last_ckpt) 138 | self.last_ckpt = ckpt 139 | 140 | def on_train_start(self, trainer: pl.Trainer, _): 141 | """Executes initial setup and saves arguments. 142 | 143 | Args: 144 | trainer (pl.Trainer): pytorch lightning trainer object. 145 | """ 146 | 147 | self.initial_setup(trainer) 148 | self.save_args(trainer) 149 | 150 | def on_train_epoch_end(self, trainer: pl.Trainer, _): 151 | """Tries to save current checkpoint at the end of each train epoch. 152 | 153 | Args: 154 | trainer (pl.Trainer): pytorch lightning trainer object. 155 | """ 156 | 157 | epoch = trainer.current_epoch # type: ignore 158 | if epoch % self.frequency == 0: 159 | self.save(trainer) 160 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/kmeans.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from typing import Any, Sequence 21 | 22 | import numpy as np 23 | import torch 24 | import torch.distributed as dist 25 | import torch.nn.functional as F 26 | from scipy.sparse import csr_matrix 27 | 28 | 29 | class KMeans: 30 | def __init__( 31 | self, 32 | world_size: int, 33 | rank: int, 34 | num_large_crops: int, 35 | dataset_size: int, 36 | proj_features_dim: int, 37 | num_prototypes: Sequence[int], 38 | kmeans_iters: int = 10, 39 | ): 40 | """Class that performs K-Means on the hypersphere. 41 | 42 | Args: 43 | world_size (int): number of processes in distributed training. 44 | rank (int): rank of the current process. 45 | num_large_crops (int): number of large crops (each with its own memory bank). 46 | dataset_size (int): total size of the dataset (number of samples). 47 | proj_features_dim (int): number of dimensions of the projected features. 48 | num_prototypes (Sequence[int]): number of prototypes for each clustering (one k-means run per entry). 49 | kmeans_iters (int, optional): number of iterations for the k-means clustering. 50 | Defaults to 10.
51 | """ 52 | self.world_size = world_size 53 | self.rank = rank 54 | self.num_large_crops = num_large_crops 55 | self.dataset_size = dataset_size 56 | self.proj_features_dim = proj_features_dim 57 | self.num_prototypes = num_prototypes 58 | self.kmeans_iters = kmeans_iters 59 | 60 | @staticmethod 61 | def get_indices_sparse(data: np.ndarray): 62 | cols = np.arange(data.size) 63 | M = csr_matrix((cols, (data.ravel(), cols)), shape=(int(data.max()) + 1, data.size)) 64 | return [np.unravel_index(row.data, data.shape) for row in M] 65 | 66 | def cluster_memory( 67 | self, 68 | local_memory_index: torch.Tensor, 69 | local_memory_embeddings: torch.Tensor, 70 | ) -> Sequence[Any]: 71 | """Performs K-Means clustering on the hypersphere and returns centroids and 72 | assignments for each sample. 73 | 74 | Args: 75 | local_memory_index (torch.Tensor): memory bank containing indices of the 76 | samples. 77 | local_memory_embeddings (torch.Tensor): memory bank containing embeddings 78 | of the samples. 79 | 80 | Returns: 81 | Sequence[Any]: assignments and centroids.
82 | """ 83 | j = 0 84 | device = local_memory_embeddings.device 85 | assignments = -torch.ones(len(self.num_prototypes), self.dataset_size).long() 86 | centroids_list = [] 87 | with torch.no_grad(): 88 | for i_K, K in enumerate(self.num_prototypes): 89 | # run distributed k-means 90 | 91 | # init centroids with elements from memory bank of rank 0 92 | centroids = torch.empty(K, self.proj_features_dim).to(device, non_blocking=True) 93 | if self.rank == 0: 94 | random_idx = torch.randperm(len(local_memory_embeddings[j]))[:K] 95 | assert len(random_idx) >= K, "please reduce the number of centroids" 96 | centroids = local_memory_embeddings[j][random_idx] 97 | if dist.is_available() and dist.is_initialized(): 98 | dist.broadcast(centroids, 0) 99 | 100 | for n_iter in range(self.kmeans_iters + 1): 101 | 102 | # E step 103 | dot_products = torch.mm(local_memory_embeddings[j], centroids.t()) 104 | _, local_assignments = dot_products.max(dim=1) 105 | 106 | # finish 107 | if n_iter == self.kmeans_iters: 108 | break 109 | 110 | # M step 111 | where_helper = self.get_indices_sparse(local_assignments.cpu().numpy()) 112 | counts = torch.zeros(K).to(device, non_blocking=True).int() 113 | emb_sums = torch.zeros(K, self.proj_features_dim).to(device, non_blocking=True) 114 | for k in range(len(where_helper)): 115 | if len(where_helper[k][0]) > 0: 116 | emb_sums[k] = torch.sum( 117 | local_memory_embeddings[j][where_helper[k][0]], 118 | dim=0, 119 | ) 120 | counts[k] = len(where_helper[k][0]) 121 | if dist.is_available() and dist.is_initialized(): 122 | dist.all_reduce(counts) 123 | dist.all_reduce(emb_sums) 124 | mask = counts > 0 125 | centroids[mask] = emb_sums[mask] / counts[mask].unsqueeze(1) 126 | 127 | # normalize centroids 128 | centroids = F.normalize(centroids, dim=1, p=2) 129 | 130 | centroids_list.append(centroids) 131 | 132 | if dist.is_available() and dist.is_initialized(): 133 | # gather the assignments 134 | assignments_all = torch.empty( 135 | self.world_size, 136 | 
local_assignments.size(0), 137 | dtype=local_assignments.dtype, 138 | device=local_assignments.device, 139 | ) 140 | assignments_all = list(assignments_all.unbind(0)) 141 | 142 | dist_process = dist.all_gather( 143 | assignments_all, local_assignments, async_op=True 144 | ) 145 | dist_process.wait() 146 | assignments_all = torch.cat(assignments_all).cpu() 147 | 148 | # gather the indexes 149 | indexes_all = torch.empty( 150 | self.world_size, 151 | local_memory_index.size(0), 152 | dtype=local_memory_index.dtype, 153 | device=local_memory_index.device, 154 | ) 155 | indexes_all = list(indexes_all.unbind(0)) 156 | dist_process = dist.all_gather(indexes_all, local_memory_index, async_op=True) 157 | dist_process.wait() 158 | indexes_all = torch.cat(indexes_all).cpu() 159 | 160 | else: 161 | assignments_all = local_assignments 162 | indexes_all = local_memory_index 163 | 164 | # log assignments 165 | assignments[i_K][indexes_all] = assignments_all 166 | 167 | # next memory bank to use 168 | j = (j + 1) % self.num_large_crops 169 | 170 | return assignments, centroids_list 171 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/knn.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from typing import Tuple 21 | 22 | import torch 23 | import torch.nn.functional as F 24 | from torchmetrics.metric import Metric 25 | 26 | 27 | class WeightedKNNClassifier(Metric): 28 | def __init__( 29 | self, 30 | k: int = 20, 31 | T: float = 0.07, 32 | max_distance_matrix_size: int = int(5e6), 33 | distance_fx: str = "cosine", 34 | epsilon: float = 0.00001, 35 | dist_sync_on_step: bool = False, 36 | ): 37 | """Implements the weighted k-NN classifier used for evaluation. 38 | 39 | Args: 40 | k (int, optional): number of neighbors. Defaults to 20. 41 | T (float, optional): temperature for the exponential. Only used with cosine 42 | distance. Defaults to 0.07. 43 | max_distance_matrix_size (int, optional): maximum number of elements in the 44 | distance matrix. Defaults to 5e6. 45 | distance_fx (str, optional): Distance function. Accepted arguments: "cosine" or 46 | "euclidean". Defaults to "cosine". 47 | epsilon (float, optional): Small value for numerical stability. Only used with 48 | euclidean distance. Defaults to 0.00001. 49 | dist_sync_on_step (bool, optional): whether to sync distributed values at every 50 | step. Defaults to False. 
51 | """ 52 | 53 | super().__init__(dist_sync_on_step=dist_sync_on_step, compute_on_step=False) 54 | 55 | self.k = k 56 | self.T = T 57 | self.max_distance_matrix_size = max_distance_matrix_size 58 | self.distance_fx = distance_fx 59 | self.epsilon = epsilon 60 | 61 | self.add_state("train_features", default=[], persistent=False) 62 | self.add_state("train_targets", default=[], persistent=False) 63 | self.add_state("test_features", default=[], persistent=False) 64 | self.add_state("test_targets", default=[], persistent=False) 65 | 66 | def update( 67 | self, 68 | train_features: torch.Tensor = None, 69 | train_targets: torch.Tensor = None, 70 | test_features: torch.Tensor = None, 71 | test_targets: torch.Tensor = None, 72 | ): 73 | """Updates the memory banks. If train (test) features are passed as input, the 74 | corresponding train (test) targets must be passed as well. 75 | 76 | Args: 77 | train_features (torch.Tensor, optional): a batch of train features. Defaults to None. 78 | train_targets (torch.Tensor, optional): a batch of train targets. Defaults to None. 79 | test_features (torch.Tensor, optional): a batch of test features. Defaults to None. 80 | test_targets (torch.Tensor, optional): a batch of test targets. Defaults to None. 81 | """ 82 | assert (train_features is None) == (train_targets is None) 83 | assert (test_features is None) == (test_targets is None) 84 | 85 | if train_features is not None: 86 | assert train_features.size(0) == train_targets.size(0) 87 | self.train_features.append(train_features.detach()) 88 | self.train_targets.append(train_targets.detach()) 89 | 90 | if test_features is not None: 91 | assert test_features.size(0) == test_targets.size(0) 92 | self.test_features.append(test_features.detach()) 93 | self.test_targets.append(test_targets.detach()) 94 | 95 | @torch.no_grad() 96 | def compute(self) -> Tuple[float]: 97 | """Computes weighted k-NN accuracy @1 and @5. 
 If cosine distance is selected, 98 | the weight is computed using the exponential of the temperature-scaled cosine 99 | similarity of the samples. If euclidean distance is selected, the weight corresponds 100 | to the inverse of the euclidean distance. 101 | 102 | Returns: 103 | Tuple[float]: k-NN accuracy @1 and @5. 104 | """ 105 | 106 | train_features = torch.cat(self.train_features) 107 | train_targets = torch.cat(self.train_targets) 108 | test_features = torch.cat(self.test_features) 109 | test_targets = torch.cat(self.test_targets) 110 | 111 | if self.distance_fx == "cosine": 112 | train_features = F.normalize(train_features) 113 | test_features = F.normalize(test_features) 114 | 115 | num_classes = int(torch.cat([train_targets, test_targets]).max()) + 1 116 | num_train_images = train_targets.size(0) 117 | num_test_images = test_targets.size(0) 118 | 119 | chunk_size = min( 120 | max(1, self.max_distance_matrix_size // num_train_images), 121 | num_test_images, 122 | ) 123 | k = min(self.k, num_train_images) 124 | 125 | top1, top5, total = 0.0, 0.0, 0 126 | retrieval_one_hot = torch.zeros(k, num_classes).to(train_features.device) 127 | for idx in range(0, num_test_images, chunk_size): 128 | # get the features for test images 129 | features = test_features[idx : min((idx + chunk_size), num_test_images), :] 130 | targets = test_targets[idx : min((idx + chunk_size), num_test_images)] 131 | batch_size = targets.size(0) 132 | 133 | # compute the similarities and the top-k neighbors 134 | if self.distance_fx == "cosine": 135 | similarities = torch.mm(features, train_features.t()) 136 | elif self.distance_fx == "euclidean": 137 | similarities = 1 / (torch.cdist(features, train_features) + self.epsilon) 138 | else: 139 | raise NotImplementedError 140 | 141 | similarities, indices = similarities.topk(k, largest=True, sorted=True) 142 | candidates = train_targets.view(1, -1).expand(batch_size, -1) 143 | retrieved_neighbors =
torch.gather(candidates, 1, indices) 144 | 145 | retrieval_one_hot.resize_(batch_size * k, num_classes).zero_() 146 | retrieval_one_hot.scatter_(1, retrieved_neighbors.view(-1, 1), 1) 147 | 148 | if self.distance_fx == "cosine": 149 | similarities = similarities.clone().div_(self.T).exp_() 150 | 151 | probs = torch.sum( 152 | torch.mul( 153 | retrieval_one_hot.view(batch_size, -1, num_classes), 154 | similarities.view(batch_size, -1, 1), 155 | ), 156 | 1, 157 | ) 158 | _, predictions = probs.sort(1, True) 159 | 160 | # find the predictions that match the target 161 | correct = predictions.eq(targets.data.view(-1, 1)) 162 | top1 = top1 + correct.narrow(1, 0, 1).sum().item() 163 | top5 = ( 164 | top5 + correct.narrow(1, 0, min(5, k, correct.size(-1))).sum().item() 165 | ) # top5 does not make sense if k < 5 166 | total += targets.size(0) 167 | 168 | top1 = top1 * 100.0 / total 169 | top5 = top5 * 100.0 / total 170 | 171 | self.reset() 172 | 173 | return top1, top5 174 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/lars.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | # Copied from Pytorch Lightning Bolts 21 | # (https://github.com/PyTorchLightning/lightning-bolts/blob/master/pl_bolts/optimizers/lars.py) 22 | 23 | import torch 24 | from torch.optim.optimizer import Optimizer, required 25 | 26 | 27 | class LARS(Optimizer): 28 | r"""Extends SGD in PyTorch with LARS scaling from the paper 29 | `Large batch training of Convolutional Networks `_. 30 | Args: 31 | params (iterable): iterable of parameters to optimize or dicts defining 32 | parameter groups 33 | lr (float): learning rate 34 | momentum (float, optional): momentum factor (default: 0) 35 | weight_decay (float, optional): weight decay (L2 penalty) (default: 0) 36 | dampening (float, optional): dampening for momentum (default: 0) 37 | nesterov (bool, optional): enables Nesterov momentum (default: False) 38 | eta (float, optional): trust coefficient for computing LR (default: 0.001) 39 | eps (float, optional): eps for division denominator (default: 1e-8) 40 | Example: 41 | >>> model = torch.nn.Linear(10, 1) 42 | >>> input = torch.randn(10) 43 | >>> target = torch.Tensor([1.]) 44 | >>> loss_fn = lambda input, target: (input - target) ** 2 45 | >>> # 46 | >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9) 47 | >>> optimizer.zero_grad() 48 | >>> loss_fn(model(input), target).backward() 49 | >>> optimizer.step() 50 | .. note:: 51 | The application of momentum in the SGD part is modified according to 52 | the PyTorch standards. LARS scaling fits into the equation in the 53 | following fashion. 54 | .. math:: 55 | \begin{aligned} 56 | g_{t+1} & = \text{lars\_lr} * (\beta * p_{t} + g_{t+1}), \\ 57 | v_{t+1} & = \mu * v_{t} + g_{t+1}, \\ 58 | p_{t+1} & = p_{t} - \text{lr} * v_{t+1}, 59 | \end{aligned} 60 | where :math:`p`, :math:`g`, :math:`v`, :math:`\mu` and :math:`\beta` denote the 61 | parameters, gradient, velocity, momentum, and weight decay respectively. 62 | The :math:`\text{lars\_lr}` is defined by Eq. 6 in the paper. 63 | The Nesterov version is analogously modified. 64 | .. warning:: 65 | When ``exclude_bias_n_norm`` is True, 1-D parameters (biases and normalization 66 | weights) are excluded from layer-wise LR scaling and weight decay. This is to ensure 67 | consistency with papers like SimCLR and BYOL. 68 | """ 69 | 70 | def __init__( 71 | self, 72 | params, 73 | lr=required, 74 | momentum=0, 75 | dampening=0, 76 | weight_decay=0, 77 | nesterov=False, 78 | eta=1e-3, 79 | eps=1e-8, 80 | clip_lars_lr=False, 81 | exclude_bias_n_norm=False, 82 | ): 83 | if lr is not required and lr < 0.0: 84 | raise ValueError(f"Invalid learning rate: {lr}") 85 | if momentum < 0.0: 86 | raise ValueError(f"Invalid momentum value: {momentum}") 87 | if weight_decay < 0.0: 88 | raise ValueError(f"Invalid weight_decay value: {weight_decay}") 89 | 90 | defaults = dict( 91 | lr=lr, 92 | momentum=momentum, 93 | dampening=dampening, 94 | weight_decay=weight_decay, 95 | nesterov=nesterov, 96 | eta=eta, 97 | eps=eps, 98 | clip_lars_lr=clip_lars_lr, 99 | exclude_bias_n_norm=exclude_bias_n_norm, 100 | ) 101 | if nesterov and (momentum <= 0 or dampening != 0): 102 | raise ValueError("Nesterov momentum requires a momentum and zero dampening") 103 | 104 | super().__init__(params, defaults) 105 | 106 | def __setstate__(self, state): 107 | super().__setstate__(state) 108 | 109 | for group in self.param_groups: 110 | group.setdefault("nesterov", False) 111 | 112 | @torch.no_grad() 113 | def step(self, closure=None): 114 | """Performs a single optimization step.
 115 | Args: 116 | closure (callable, optional): A closure that reevaluates the model 117 | and returns the loss. 118 | """ 119 | loss = None 120 | if closure is not None: 121 | with torch.enable_grad(): 122 | loss = closure() 123 | 124 | # LARS scaling and weight decay are skipped for 1-D params when exclude_bias_n_norm is set 125 | for group in self.param_groups: 126 | weight_decay = group["weight_decay"] 127 | momentum = group["momentum"] 128 | dampening = group["dampening"] 129 | nesterov = group["nesterov"] 130 | 131 | for p in group["params"]: 132 | if p.grad is None: 133 | continue 134 | 135 | d_p = p.grad 136 | p_norm = torch.norm(p.data) 137 | g_norm = torch.norm(p.grad.data) 138 | 139 | # lars scaling + weight decay part 140 | if p.ndim != 1 or not group["exclude_bias_n_norm"]: 141 | if p_norm != 0 and g_norm != 0: 142 | lars_lr = p_norm / (g_norm + p_norm * weight_decay + group["eps"]) 143 | lars_lr *= group["eta"] 144 | 145 | # clip so that the effective lr (lr * lars_lr) never exceeds the base lr 146 | if group["clip_lars_lr"]: 147 | lars_lr = min(lars_lr / group["lr"], 1) 148 | 149 | d_p = d_p.add(p, alpha=weight_decay) 150 | d_p *= lars_lr 151 | 152 | # sgd part 153 | if momentum != 0: 154 | param_state = self.state[p] 155 | if "momentum_buffer" not in param_state: 156 | buf = param_state["momentum_buffer"] = torch.clone(d_p).detach() 157 | else: 158 | buf = param_state["momentum_buffer"] 159 | buf.mul_(momentum).add_(d_p, alpha=1 - dampening) 160 | if nesterov: 161 | d_p = d_p.add(buf, alpha=momentum) 162 | else: 163 | d_p = buf 164 | 165 | p.add_(d_p, alpha=-group["lr"]) 166 | 167 | return loss 168 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/metrics.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team.
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | from typing import Dict, List, Sequence 21 | 22 | import torch 23 | 24 | 25 | def accuracy_at_k( 26 | outputs: torch.Tensor, targets: torch.Tensor, top_k: Sequence[int] = (1, 5) 27 | ) -> Sequence[torch.Tensor]: 28 | """Computes the accuracy over the k top predictions for the specified values of k. 29 | 30 | Args: 31 | outputs (torch.Tensor): output of a classifier (logits or probabilities). 32 | targets (torch.Tensor): ground truth labels. 33 | top_k (Sequence[int], optional): sequence of top k values to compute the accuracy over. 34 | Defaults to (1, 5). 35 | 36 | Returns: 37 | Sequence[torch.Tensor]: accuracies (in percent, as one-element tensors) at the desired k.
38 | """ 39 | 40 | with torch.no_grad(): 41 | maxk = max(top_k) 42 | batch_size = targets.size(0) 43 | 44 | _, pred = outputs.topk(maxk, 1, True, True) 45 | pred = pred.t() 46 | correct = pred.eq(targets.view(1, -1).expand_as(pred)) 47 | 48 | res = [] 49 | for k in top_k: 50 | correct_k = correct[:k].contiguous().view(-1).float().sum(0, keepdim=True) 51 | res.append(correct_k.mul_(100.0 / batch_size)) 52 | return res 53 | 54 | 55 | def weighted_mean(outputs: List[Dict], key: str, batch_size_key: str) -> float: 56 | """Computes the mean of the values of a key weighted by the batch size. 57 | 58 | Args: 59 | outputs (List[Dict]): list of dicts containing the outputs of a validation step. 60 | key (str): key of the metric of interest. 61 | batch_size_key (str): key of batch size values. 62 | 63 | Returns: 64 | float: weighted mean of the values of a key 65 | """ 66 | 67 | value = 0 68 | n = 0 69 | for out in outputs: 70 | value += out[batch_size_key] * out[key] 71 | n += out[batch_size_key] 72 | value = value / n 73 | return value.squeeze(0) 74 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/momentum.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 
12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | import math 21 | 22 | import torch 23 | from torch import nn 24 | 25 | 26 | @torch.no_grad() 27 | def initialize_momentum_params(online_net: nn.Module, momentum_net: nn.Module): 28 | """Copies the parameters of the online network to the momentum network. 29 | 30 | Args: 31 | online_net (nn.Module): online network (e.g. online backbone, online projection, etc...). 32 | momentum_net (nn.Module): momentum network (e.g. momentum backbone, 33 | momentum projection, etc...). 34 | """ 35 | 36 | params_online = online_net.parameters() 37 | params_momentum = momentum_net.parameters() 38 | for po, pm in zip(params_online, params_momentum): 39 | pm.data.copy_(po.data) 40 | pm.requires_grad = False 41 | 42 | 43 | class MomentumUpdater: 44 | def __init__(self, base_tau: float = 0.996, final_tau: float = 1.0): 45 | """Updates momentum parameters using exponential moving average. 46 | 47 | Args: 48 | base_tau (float, optional): base value of the weight decrease coefficient 49 | (should be in [0,1]). Defaults to 0.996. 50 | final_tau (float, optional): final value of the weight decrease coefficient 51 | (should be in [0,1]). Defaults to 1.0. 
52 | """ 53 | 54 | super().__init__() 55 | 56 | assert 0 <= base_tau <= 1 57 | assert 0 <= final_tau <= 1 and base_tau <= final_tau 58 | 59 | self.base_tau = base_tau 60 | self.cur_tau = base_tau 61 | self.final_tau = final_tau 62 | 63 | @torch.no_grad() 64 | def update(self, online_net: nn.Module, momentum_net: nn.Module): 65 | """Performs the momentum update for each param group. 66 | 67 | Args: 68 | online_net (nn.Module): online network (e.g. online backbone, online projection, etc...). 69 | momentum_net (nn.Module): momentum network (e.g. momentum backbone, 70 | momentum projection, etc...). 71 | """ 72 | 73 | for op, mp in zip(online_net.parameters(), momentum_net.parameters()): 74 | mp.data = self.cur_tau * mp.data + (1 - self.cur_tau) * op.data 75 | 76 | def update_tau(self, cur_step: int, max_steps: int): 77 | """Computes the next value for the weighting decrease coefficient tau using cosine annealing. 78 | 79 | Args: 80 | cur_step (int): number of gradient steps so far. 81 | max_steps (int): overall number of gradient steps in the whole training. 82 | """ 83 | 84 | self.cur_tau = ( 85 | self.final_tau 86 | - (self.final_tau - self.base_tau) * (math.cos(math.pi * cur_step / max_steps) + 1) / 2 87 | ) 88 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/sinkhorn_knopp.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team. 
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | # Adapted from https://github.com/facebookresearch/swav. 21 | 22 | import torch 23 | import torch.distributed as dist 24 | 25 | 26 | class SinkhornKnopp(torch.nn.Module): 27 | def __init__(self, num_iters: int = 3, epsilon: float = 0.05, world_size: int = 1): 28 | """Approximates optimal transport using the Sinkhorn-Knopp algorithm. 29 | 30 | A simple iterative method to approximate a doubly stochastic matrix is to alternately rescale 31 | the rows and columns of the matrix so that they sum to 1. 32 | 33 | Args: 34 | num_iters (int, optional): number of times to perform row and column normalization. 35 | Defaults to 3. 36 | epsilon (float, optional): weight for the entropy regularization term. Defaults to 0.05. 37 | world_size (int, optional): number of processes in distributed training. Defaults to 1.
38 | """ 39 | 40 | super().__init__() 41 | self.num_iters = num_iters 42 | self.epsilon = epsilon 43 | self.world_size = world_size 44 | 45 | @torch.no_grad() 46 | def forward(self, Q: torch.Tensor) -> torch.Tensor: 47 | """Produces assignments using the Sinkhorn-Knopp algorithm. 48 | 49 | Applies the entropy regularization, normalizes the Q matrix and then normalizes rows and 50 | columns in an alternating fashion num_iters times. Before returning, it rescales the 51 | columns so that the output is an assignment of samples to prototypes. 52 | 53 | Args: 54 | Q (torch.Tensor): cosine similarities between the features of the 55 | samples and the prototypes. 56 | 57 | Returns: 58 | torch.Tensor: assignment of samples to prototypes according to optimal transport. 59 | """ 60 | 61 | Q = torch.exp(Q / self.epsilon).t() 62 | B = Q.shape[1] * self.world_size 63 | K = Q.shape[0] # num prototypes 64 | 65 | # make the matrix sum to 1 66 | sum_Q = torch.sum(Q) 67 | if dist.is_available() and dist.is_initialized(): 68 | dist.all_reduce(sum_Q) 69 | Q /= sum_Q 70 | 71 | for _ in range(self.num_iters): 72 | # normalize each row: total weight per prototype must be 1/K 73 | sum_of_rows = torch.sum(Q, dim=1, keepdim=True) 74 | if dist.is_available() and dist.is_initialized(): 75 | dist.all_reduce(sum_of_rows) 76 | Q /= sum_of_rows 77 | Q /= K 78 | 79 | # normalize each column: total weight per sample must be 1/B 80 | Q /= torch.sum(Q, dim=0, keepdim=True) 81 | Q /= B 82 | 83 | Q *= B # the columns must sum to 1 so that Q is an assignment 84 | return Q.t() 85 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/solo/utils/whitening.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 solo-learn development team.
2 | 3 | # Permission is hereby granted, free of charge, to any person obtaining a copy of 4 | # this software and associated documentation files (the "Software"), to deal in 5 | # the Software without restriction, including without limitation the rights to use, 6 | # copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the 7 | # Software, and to permit persons to whom the Software is furnished to do so, 8 | # subject to the following conditions: 9 | 10 | # The above copyright notice and this permission notice shall be included in all copies 11 | # or substantial portions of the Software. 12 | 13 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, 14 | # INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR 15 | # PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE 16 | # FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR 17 | # OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER 18 | # DEALINGS IN THE SOFTWARE. 19 | 20 | 21 | from typing import Optional 22 | 23 | import torch 24 | import torch.nn as nn 25 | from torch.cuda.amp import custom_fwd 26 | from torch.nn.functional import conv2d 27 | 28 | 29 | class Whitening2d(nn.Module): 30 | def __init__(self, output_dim: int, eps: float = 0.0): 31 | """Layer that computes hard whitening for W-MSE using the Cholesky decomposition. 32 | 33 | Args: 34 | output_dim (int): number of dimensions of the projected features. 35 | eps (float, optional): eps for numerical stability in Cholesky decomposition. Defaults 36 | to 0.0. 37 | """ 38 | 39 | super(Whitening2d, self).__init__() 40 | self.output_dim = output_dim 41 | self.eps = eps 42 | 43 | @custom_fwd(cast_inputs=torch.float32) 44 | def forward(self, x: torch.Tensor) -> torch.Tensor: 45 | """Performs whitening using the Cholesky decomposition.
46 | 47 | Args: 48 | x (torch.Tensor): a batch or slice of projected features. 49 | 50 | Returns: 51 | torch.Tensor: a batch or slice of whitened features. 52 | """ 53 | 54 | x = x.unsqueeze(2).unsqueeze(3) 55 | m = x.mean(0).view(self.output_dim, -1).mean(-1).view(1, -1, 1, 1) 56 | xn = x - m 57 | 58 | T = xn.permute(1, 0, 2, 3).contiguous().view(self.output_dim, -1) 59 | f_cov = torch.mm(T, T.permute(1, 0)) / (T.shape[-1] - 1) 60 | 61 | eye = torch.eye(self.output_dim).type(f_cov.type()) 62 | 63 | f_cov_shrinked = (1 - self.eps) * f_cov + self.eps * eye 64 | 65 | inv_sqrt = torch.triangular_solve(eye, torch.linalg.cholesky(f_cov_shrinked), upper=False)[ 66 | 0 67 | ] 68 | inv_sqrt = inv_sqrt.contiguous().view(self.output_dim, self.output_dim, 1, 1) 69 | 70 | decorrelated = conv2d(xn, inv_sqrt) 71 | 72 | return decorrelated.squeeze(2).squeeze(2) 73 | 74 | 75 | class iterative_normalization_py(torch.autograd.Function): 76 | @staticmethod 77 | def forward(ctx, *args) -> torch.Tensor: 78 | X, running_mean, running_wmat, nc, ctx.T, eps, momentum, training = args 79 | 80 | # reshape NxCxHxW to G x D x (NxHxW), i.e., g x d x m 81 | ctx.g = X.size(1) // nc 82 | x = X.transpose(0, 1).contiguous().view(ctx.g, nc, -1) 83 | _, d, m = x.size() 84 | saved = [] 85 | if training: 86 | # calculate the centered activation by subtracting the mini-batch mean 87 | mean = x.mean(-1, keepdim=True) 88 | xc = x - mean 89 | saved.append(xc) 90 | # calculate covariance matrix 91 | P = [None] * (ctx.T + 1) 92 | P[0] = torch.eye(d).to(X).expand(ctx.g, d, d) 93 | Sigma = torch.baddbmm( 94 | beta=eps, 95 | input=P[0], 96 | alpha=1.0 / m, 97 | batch1=xc, 98 | batch2=xc.transpose(1, 2), 99 | ) 100 | # reciprocal of trace of Sigma: shape [g, 1, 1] 101 | rTr = (Sigma * P[0]).sum((1, 2), keepdim=True).reciprocal_() 102 | saved.append(rTr) 103 | Sigma_N = Sigma * rTr 104 | saved.append(Sigma_N) 105 | for k in range(ctx.T): 106 | P[k + 1] = torch.baddbmm( 107 | beta=1.5, 108 | input=P[k], 109 | alpha=-0.5, 110 |
batch1=torch.matrix_power(P[k], 3), 111 | batch2=Sigma_N, 112 | ) 113 | saved.extend(P) 114 | wm = P[ctx.T].mul_( 115 | rTr.sqrt() 116 | ) # whitening matrix: the inverse square root of Sigma, i.e., Sigma^{-1/2} 117 | 118 | running_mean.copy_(momentum * mean + (1.0 - momentum) * running_mean) 119 | running_wmat.copy_(momentum * wm + (1.0 - momentum) * running_wmat) 120 | else: 121 | xc = x - running_mean 122 | wm = running_wmat 123 | xn = wm.matmul(xc) 124 | Xn = xn.view(X.size(1), X.size(0), *X.size()[2:]).transpose(0, 1).contiguous() 125 | ctx.save_for_backward(*saved) 126 | return Xn 127 | 128 | @staticmethod 129 | def backward(ctx, *grad_outputs): 130 | (grad,) = grad_outputs 131 | saved = ctx.saved_tensors 132 | if len(saved) == 0: 133 | return None, None, None, None, None, None, None, None 134 | 135 | xc = saved[0] # centered input 136 | rTr = saved[1] # reciprocal of the trace of Sigma 137 | sn = saved[2].transpose(-2, -1) # normalized Sigma 138 | P = saved[3:] # intermediate matrices P[0..T] of the Newton iterations 139 | g, d, m = xc.size() 140 | 141 | g_ = grad.transpose(0, 1).contiguous().view_as(xc) 142 | g_wm = g_.matmul(xc.transpose(-2, -1)) 143 | g_P = g_wm * rTr.sqrt() 144 | wm = P[ctx.T] 145 | g_sn = 0 146 | for k in range(ctx.T, 1, -1): 147 | P[k - 1].transpose_(-2, -1) 148 | P2 = P[k - 1].matmul(P[k - 1]) 149 | g_sn += P2.matmul(P[k - 1]).matmul(g_P) 150 | g_tmp = g_P.matmul(sn) 151 | g_P.baddbmm_(beta=1.5, alpha=-0.5, batch1=g_tmp, batch2=P2) 152 | g_P.baddbmm_(beta=1, alpha=-0.5, batch1=P2, batch2=g_tmp) 153 | g_P.baddbmm_(beta=1, alpha=-0.5, batch1=P[k - 1].matmul(g_tmp), batch2=P[k - 1]) 154 | g_sn += g_P 155 | g_tr = ((-sn.matmul(g_sn) + g_wm.transpose(-2, -1).matmul(wm)) * P[0]).sum( 156 | (1, 2), keepdim=True 157 | ) * P[0] 158 | g_sigma = (g_sn + g_sn.transpose(-2, -1) + 2.0 * g_tr) * (-0.5 / m * rTr) 159 | g_x = torch.baddbmm(wm.matmul(g_ - g_.mean(-1, keepdim=True)), g_sigma, xc) 160 | grad_input = ( 161 | g_x.view(grad.size(1), grad.size(0), *grad.size()[2:]).transpose(0, 1).contiguous() 162 |
) 163 | return grad_input, None, None, None, None, None, None, None 164 | 165 | 166 | class IterNorm(torch.nn.Module): 167 | def __init__( 168 | self, 169 | num_features: int, 170 | num_groups: int = 64, 171 | num_channels: Optional[int] = None, 172 | T: int = 5, 173 | dim: int = 2, 174 | eps: float = 1.0e-5, 175 | momentum: float = 0.1, 176 | affine: bool = True, 177 | ): 178 | super(IterNorm, self).__init__() 179 | # assert dim == 4, 'IterNorm does not support 2D' 180 | self.T = T 181 | self.eps = eps 182 | self.momentum = momentum 183 | self.num_features = num_features 184 | self.affine = affine 185 | self.dim = dim 186 | if num_channels is None: 187 | num_channels = (num_features - 1) // num_groups + 1 188 | num_groups = num_features // num_channels 189 | while num_features % num_channels != 0: 190 | num_channels //= 2 191 | num_groups = num_features // num_channels 192 | assert ( 193 | num_groups > 0 and num_features % num_groups == 0 194 | ), f"num features={num_features}, num groups={num_groups}" 195 | self.num_groups = num_groups 196 | self.num_channels = num_channels 197 | shape = [1] * dim 198 | shape[1] = self.num_features 199 | if self.affine: 200 | self.weight = nn.Parameter(torch.Tensor(*shape)) 201 | self.bias = nn.Parameter(torch.Tensor(*shape)) 202 | else: 203 | self.register_parameter("weight", None) 204 | self.register_parameter("bias", None) 205 | 206 | self.register_buffer("running_mean", torch.zeros(num_groups, num_channels, 1)) 207 | # running whiten matrix 208 | self.register_buffer( 209 | "running_wm", 210 | torch.eye(num_channels).expand(num_groups, num_channels, num_channels).clone(), 211 | ) 212 | 213 | self.reset_parameters() 214 | 215 | def reset_parameters(self): 216 | if self.affine: 217 | torch.nn.init.ones_(self.weight) 218 | torch.nn.init.zeros_(self.bias) 219 | 220 | @custom_fwd(cast_inputs=torch.float32) 221 | def forward(self, X: torch.Tensor) -> torch.Tensor: 222 | X_hat = iterative_normalization_py.apply( 223 | X, 224 | 
self.running_mean, 225 | self.running_wm, 226 | self.num_channels, 227 | self.T, 228 | self.eps, 229 | self.momentum, 230 | self.training, 231 | ) 232 | # affine 233 | if self.affine: 234 | return X_hat * self.weight + self.bias 235 | 236 | return X_hat 237 | 238 | def extra_repr(self): 239 | return ( 240 | f"{self.num_features}, num_channels={self.num_channels}, T={self.T}, eps={self.eps}, " 241 | f"momentum={self.momentum}, affine={self.affine}" 242 | ) 243 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/zoo/cifar10.sh: -------------------------------------------------------------------------------- 1 | mkdir trained_models 2 | cd trained_models 3 | mkdir cifar10 4 | cd cifar10 5 | 6 | # Barlow Twins 7 | mkdir barlow_twins 8 | cd barlow_twins 9 | gdown https://drive.google.com/uc?id=1x7y44E05vuobibfObT4n3jqLI8QNVESV 10 | gdown https://drive.google.com/uc?id=1Mxfq2YGQ53bNRV2fNYzvYIneM5ZGeb2h 11 | cd .. 12 | 13 | # BYOL 14 | mkdir byol 15 | cd byol 16 | gdown https://drive.google.com/uc?id=1zOE8O2yPyhE23LMoesMoDPdLyh1qbI8k 17 | gdown https://drive.google.com/uc?id=1l1XIWE1ailKzsQnUPGDgyvK0escOsta6 18 | cd .. 19 | 20 | # DeepCluster V2 21 | mkdir deepclusterv2 22 | cd deepclusterv2 23 | gdown https://drive.google.com/uc?id=13L_QlwrBRJhdeCaVdgkRYWfvoh4PIWwj 24 | gdown https://drive.google.com/uc?id=17jRJ-LC56uWRuNluWXecXHjTxomuGs_T 25 | cd .. 26 | 27 | # DINO 28 | mkdir dino 29 | cd dino 30 | gdown https://drive.google.com/uc?id=1Wv9w5j22YitGAWi4p3IJYzLVo4fQkpSu 31 | gdown https://drive.google.com/uc?id=1PBElgMN5gjZsK3o1L55jNnb5A1ebbOvu 32 | cd .. 33 | 34 | # MoCo V2+ 35 | mkdir mocov2plus 36 | cd mocov2plus 37 | gdown https://drive.google.com/uc?id=1viIUTHmLdozDWtzMicV4oOyC50iL2QDU 38 | gdown https://drive.google.com/uc?id=1ZLpgK13N8rgBxvqRbyGFd_8mF03pStIx 39 | cd ..
40 | 41 | # MoCo V3 42 | mkdir mocov3 43 | cd mocov3 44 | gdown https://drive.google.com/uc?id=1EFHWBLYFsglZYPYsBc0YrtihrzBZRe7h 45 | gdown https://drive.google.com/uc?id=1Gb6TCWoY2aN8AK3UnuZu4IxIktqbDybP 46 | cd .. 47 | 48 | # NNCLR 49 | mkdir nnclr 50 | cd nnclr 51 | gdown https://drive.google.com/uc?id=1zKReUmJ35vRnQxfSxn7yRVRW_oy3LUDF 52 | gdown https://drive.google.com/uc?id=1UyI9r19PoFGqHjd5r1UEpstCSTkleja7 53 | cd .. 54 | 55 | # ReSSL 56 | mkdir ressl 57 | cd ressl 58 | gdown https://drive.google.com/uc?id=1UdDWvgpyvj3VFVm0lq-WrGj0-GTcEpHq 59 | gdown https://drive.google.com/uc?id=1XkkYUuEI79__4GpCCDhuFEbv0BbRdCBh 60 | cd .. 61 | 62 | # SimCLR 63 | mkdir simclr 64 | cd simclr 65 | gdown https://drive.google.com/uc?id=15fI7gb9M92jZWBZoGLvarYDiNYK3RN2O 66 | gdown https://drive.google.com/uc?id=1HMJof4v2B5S-khepI_x8bgFv72I5KMc9 67 | cd .. 68 | 69 | # Simsiam 70 | mkdir simsiam 71 | cd simsiam 72 | gdown https://drive.google.com/uc?id=1ZMGGTziK0DbCP43fDx2rPFrtJxCLJDmb 73 | gdown https://drive.google.com/uc?id=1hh1QrQiWfRej-8D6L67T_F7Je9-EUUg2 74 | cd .. 75 | 76 | # SupCon 77 | mkdir supcon 78 | cd supcon 79 | gdown https://drive.google.com/uc?id=1tkk_r7tYozLgf9khW6LiGxaTvJQ4c5sA 80 | gdown https://drive.google.com/uc?id=1OhZul-rtBVUOqvOkORk8HgOLIXGCNEzB 81 | cd .. 82 | 83 | # SwAV 84 | mkdir swav 85 | cd swav 86 | gdown https://drive.google.com/uc?id=1CPok55wwN_4QecEjubdLeBo_9qWSJTHw 87 | gdown https://drive.google.com/uc?id=1t59f1Q8ifx8tAySGpD2pmvogNcR1USEo 88 | cd .. 89 | 90 | # VIbCReg 91 | mkdir vibcreg 92 | cd vibcreg 93 | gdown https://drive.google.com/uc?id=1dHsKrhCcwWIXFwQJ4oVPgLcEcT3SecQV 94 | gdown https://drive.google.com/uc?id=1OPsUf8VnKo5w6T8-rEQFaodUNxvQ8CTT 95 | cd .. 96 | 97 | # VICReg 98 | mkdir vicreg 99 | cd vicreg 100 | gdown https://drive.google.com/uc?id=1TeliMNt5bOchqJj2u_JjB0_ahKB5LKi5 101 | gdown https://drive.google.com/uc?id=1dsdPL-5QNS9LyHypYN6VQfEuiNWLKJqN 102 | cd .. 
103 | 104 | # W-MSE 105 | mkdir wmse 106 | cd wmse 107 | gdown https://drive.google.com/uc?id=1jTjpmVTi9rtzy3NPEEp_61py-jeHy5fi 108 | gdown https://drive.google.com/uc?id=1YLuqazfSDOruSiu4Kl6OAexDnt5LKEIT 109 | cd .. 110 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/zoo/cifar100.sh: -------------------------------------------------------------------------------- 1 | mkdir trained_models 2 | cd trained_models 3 | mkdir cifar100 4 | cd cifar100 5 | 6 | # Barlow Twins 7 | mkdir barlow_twins 8 | cd barlow_twins 9 | gdown https://drive.google.com/uc?id=17cZt3DorfiCYb0ZauLHv0iM-YDGYa-mE 10 | gdown https://drive.google.com/uc?id=17Me99dh-XfTV-fniXn0Cy-ZcwGa9dRZe 11 | cd .. 12 | 13 | # BYOL 14 | mkdir byol 15 | cd byol 16 | gdown https://drive.google.com/uc?id=1fE7TdRboFJnYXr8JSY_tGmuFGitI8l23 17 | gdown https://drive.google.com/uc?id=1qsBJoO1ROAEUeQtvl8hOBDnLXZKY8Ziy 18 | cd .. 19 | 20 | # DeepCluster V2 21 | mkdir deepclusterv2 22 | cd deepclusterv2 23 | gdown https://drive.google.com/uc?id=1grFfh0aaVYpeuYbgFYB4rmfj9uvXhYSd 24 | gdown https://drive.google.com/uc?id=12jBsv8Fd2vk6OD5khbl4qp7szfSCiERD 25 | cd .. 26 | 27 | # DINO 28 | mkdir dino 29 | cd dino 30 | gdown https://drive.google.com/uc?id=16gdp5L_a9BVcRvcU4f-NUJCsIpX3Oecr 31 | gdown https://drive.google.com/uc?id=1M4UVug_ARfNW_sjnRbc0KBceBXVKVVxH 32 | cd .. 33 | 34 | # MoCo V2+ 35 | mkdir mocov2plus 36 | cd mocov2plus 37 | gdown https://drive.google.com/uc?id=1KNkCA2Hr70QsmOSif9_UUndFerOb7Jft 38 | gdown https://drive.google.com/uc?id=1T_SpFAEhZap2fvKnUvk8hzL-C7Nzad93 39 | cd .. 40 | 41 | # MoCo V3 42 | mkdir mocov3 43 | cd mocov3 44 | gdown https://drive.google.com/uc?id=1QAuKJmegGCJrntAL80tfTrbi2fI4sPl- 45 | gdown https://drive.google.com/uc?id=1jtJEi66g5z7dBn0FDcSL7zoU4ArEEyqU 46 | cd .. 
47 | 48 | # NNCLR 49 | mkdir nnclr 50 | cd nnclr 51 | gdown https://drive.google.com/uc?id=1aodwBlGK6EqrC_kthk8JcuxVcY4S5CF9 52 | gdown https://drive.google.com/uc?id=14Z8REvCrdW8eZ0kwxmNioIPneyCSAk0E 53 | cd .. 54 | 55 | # ReSSL 56 | mkdir ressl 57 | cd ressl 58 | gdown https://drive.google.com/uc?id=16sKNdpScv5FckpC02W41mjETXL6T5u2S 59 | gdown https://drive.google.com/uc?id=1niA588wO6KX1dcbhfelb_vumByHgDfVV 60 | cd .. 61 | 62 | # SimCLR 63 | mkdir simclr 64 | cd simclr 65 | gdown https://drive.google.com/uc?id=17YGC7y4yxkVAF8ZNezdtmN-uc70jz3zq 66 | gdown https://drive.google.com/uc?id=1bmrfJxEK505_ky0m7q7ZJSDpFfgqIuQ6 67 | cd .. 68 | 69 | # Simsiam 70 | mkdir simsiam 71 | cd simsiam 72 | gdown https://drive.google.com/uc?id=1DStn9PAEMJtzh1Mxb3NjfTtm5vaNgRM5 73 | gdown https://drive.google.com/uc?id=1y03EtFuMi5fZGPJZfN3hkONe99WsFBOJ 74 | cd .. 75 | 76 | # SupCon 77 | mkdir supcon 78 | cd supcon 79 | gdown https://drive.google.com/uc?id=1QhPHENtgYttIF1Dn1srA4dAkIiC_5P7W 80 | gdown https://drive.google.com/uc?id=1QsZs9TfWoycrHBBUrliWe-cqkGQ9epAD 81 | cd .. 82 | 83 | # SwAV 84 | mkdir swav 85 | cd swav 86 | gdown https://drive.google.com/uc?id=1oJzFfayNpcShK1bZtDK58HthcKY2bpns 87 | gdown https://drive.google.com/uc?id=14ed_7MG_pg-G_qjQcxVc8MUZWwFcz3mF 88 | cd .. 89 | 90 | # VIbCReg 91 | mkdir vibcreg 92 | cd vibcreg 93 | gdown https://drive.google.com/uc?id=1akNcewHzh4ideoQPWakaXWGDxfGoxkNu 94 | gdown https://drive.google.com/uc?id=1cdvZXUmmDptSe-RkYyiQXyREwthvMuxW 95 | cd .. 96 | 97 | # VICReg 98 | mkdir vicreg 99 | cd vicreg 100 | gdown https://drive.google.com/uc?id=1kH78BUBKprrsxL2KRKmorVQ9vJHsMsID 101 | gdown https://drive.google.com/uc?id=1TJk6G6KY1URPpruhKIDuovv66U-mnQHo 102 | cd .. 103 | 104 | # W-MSE 105 | mkdir wmse 106 | cd wmse 107 | gdown https://drive.google.com/uc?id=1_6EmYFqAW_U8DFv72KUaAe-BV8xkRxsp 108 | gdown https://drive.google.com/uc?id=1uIeg5EKEMefeBIyYFm9SBmChJPBc-0g_ 109 | cd .. 
110 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/zoo/imagenet.sh: -------------------------------------------------------------------------------- 1 | mkdir trained_models 2 | cd trained_models 3 | mkdir imagenet 4 | cd imagenet 5 | 6 | # Barlow Twins 7 | mkdir barlow_twins 8 | cd barlow_twins 9 | gdown https://drive.google.com/uc?id=1u0NH9L-p1Y99KVFxjBmZ_wlY-l6YJFtP 10 | gdown https://drive.google.com/uc?id=1EKdbR72-gtNE782254tjXi9UR2NiwEWh 11 | cd .. 12 | 13 | # BYOL 14 | mkdir byol 15 | cd byol 16 | gdown https://drive.google.com/uc?id=1TheL_4tmDWByCxg8XHke5VEz_lcYHH64 17 | gdown https://drive.google.com/uc?id=18gG0Jo59cFVX4qNUO119jIhzHcJkAmGz 18 | cd .. 19 | 20 | # MoCo V2+ 21 | mkdir mocov2plus 22 | cd mocov2plus 23 | gdown https://drive.google.com/uc?id=1BBauwWTJV38BCf56KtOK9TJWLyjNH-mP 24 | gdown https://drive.google.com/uc?id=1JMpGSYjefFzxT5GTbEc_2d4THxOxC3Ca 25 | cd .. 26 | -------------------------------------------------------------------------------- /MULTI-STAGE-GRAPH-AGGREGATION/zoo/imagenet100.sh: -------------------------------------------------------------------------------- 1 | mkdir trained_models 2 | cd trained_models 3 | mkdir imagenet100 4 | cd imagenet100 5 | 6 | # Barlow Twins 7 | mkdir barlow_twins 8 | cd barlow_twins 9 | gdown https://drive.google.com/uc?id=1C2qQSqp8cXvfrwHVG9MuGTPT2TOTsGla # checkpoint 10 | gdown https://drive.google.com/uc?id=1TY10aa97P4Fl7EgSjTy_u_QME9tkcU4r # args 11 | cd .. 12 | 13 | # BYOL 14 | mkdir byol 15 | cd byol 16 | gdown https://drive.google.com/uc?id=1cgJaSRr3HPZRNMwzYwwS5Vwtkna3LgGs # checkpoint 17 | gdown https://drive.google.com/uc?id=1EIluSRGaV0Ft1UQecGhpkFUCKVwMmtv9 # args 18 | cd .. 19 | 20 | # DeepCluster V2 21 | mkdir deepclusterv2 22 | cd deepclusterv2 23 | gdown https://drive.google.com/uc?id=1ANWOVMFMa-9eRWTKRGiUkNenJYD-McjT # checkpoint 24 | gdown https://drive.google.com/uc?id=18oOypleOOHQ7z9XL9zUTgDB7zpRdbMti # args 25 | cd .. 
26 | 27 | # DINO 28 | mkdir dino 29 | cd dino 30 | gdown https://drive.google.com/uc?id=1MkuNjlIMqzuRwdG_K6NoDrGQH2GtssXV # checkpoint 31 | gdown https://drive.google.com/uc?id=1MlYaqPsp_pEaDR7nDRbxv3oOsMTVBHg9 # args 32 | cd .. 33 | 34 | # DINO (vit tiny) 35 | mkdir dino-vit 36 | cd dino-vit 37 | gdown https://drive.google.com/uc?id=11rHOKD4EQB2AJ1C2tLHz0pjai6MqwT9v # checkpoint 38 | gdown https://drive.google.com/uc?id=15pQbMd0xiLZNsmozBmsKA3_HVEdpSqmy # args 39 | cd .. 40 | 41 | # MoCo V2+ 42 | mkdir mocov2plus 43 | cd mocov2plus 44 | gdown https://drive.google.com/uc?id=1aXGypKbIqV8BqtVOzpk2lRJWXb--XejO # checkpoint 45 | gdown https://drive.google.com/uc?id=1s5rzHSqAMRKaUR4ZP3HWbCm8QLXU6JQ8 # args 46 | cd .. 47 | 48 | # MoCo V3 49 | mkdir mocov3 50 | cd mocov3 51 | gdown https://drive.google.com/uc?id=1cUaAdx-6NXCkeSMo-mQtpPnYk7zA4Gg4 # checkpoint 52 | gdown https://drive.google.com/uc?id=1mb6ZRKF1CdGP0rdJI2yjyStZ-FCFjsi4 # args 53 | cd .. 54 | 55 | # MoCo V3 R50 56 | mkdir mocov3-r50 57 | cd mocov3-r50 58 | gdown https://drive.google.com/uc?id=1KiwHisYRmzYjLYDm1zQxZlUKe8BkI2i8 # checkpoint 59 | gdown https://drive.google.com/uc?id=16pix6gNybXnssMpXlzjKnMWl9lRfdv20 # args 60 | cd .. 61 | 62 | # NNCLR 63 | mkdir nnclr 64 | cd nnclr 65 | gdown https://drive.google.com/uc?id=1rj9-YBUNX0wHVLjQuksOOubfEZJsrsjF # checkpoint 66 | gdown https://drive.google.com/uc?id=1GBT6-QkhuDLexfVgwM0SWbzuXJ6QrF9o # args 67 | cd .. 68 | 69 | # ReSSL 70 | mkdir ressl 71 | cd ressl 72 | gdown https://drive.google.com/uc?id=1AH3hFcakrGKXzxmzO2LBHWjk-Mgu5PUN # checkpoint 73 | gdown https://drive.google.com/uc?id=1XWKERLv_YgFQ_Oy33TD8DhTfH9qSoVQa # args 74 | cd .. 75 | 76 | # SimCLR 77 | mkdir simclr 78 | cd simclr 79 | gdown https://drive.google.com/uc?id=1dU88Sh5F_8J_UXXEQ8FOWS85g8eFEZVa # checkpoint 80 | gdown https://drive.google.com/uc?id=1865vcQhuvGeNm0iQ9g87APLwLvYNuqcn # args 81 | cd .. 
82 | 83 | # Simsiam 84 | mkdir simsiam 85 | cd simsiam 86 | gdown https://drive.google.com/uc?id=1cwAyDCpU36zmQ6-r4Ww7YqiZ41vbjckQ # checkpoint 87 | gdown https://drive.google.com/uc?id=1EU43HZKrLu_ZTV3CVAkjkFS6HORhtmR9 # args 88 | cd .. 89 | 90 | # SupCon 91 | mkdir supcon 92 | cd supcon 93 | gdown https://drive.google.com/uc?id=1-NRvw7J9WrQKBvDhmuirQmTklMlQasxI # checkpoint 94 | gdown https://drive.google.com/uc?id=1IKTW20UTWlHSO4RsgO1QxakYs5ZscY26 # args 95 | cd .. 96 | 97 | # SwAV 98 | mkdir swav 99 | cd swav 100 | gdown https://drive.google.com/uc?id=1nDiXHb8ce6_qDyZ8EcqDXi6ptI4A_t6B # checkpoint 101 | gdown https://drive.google.com/uc?id=1h1-YEqEw5Zj7wl0Gkxiz6WC3qpwa2FgL # args 102 | cd .. 103 | 104 | # VIbCReg 105 | mkdir vibcreg 106 | cd vibcreg 107 | gdown https://drive.google.com/uc?id=1VDUvp0zghvnUgwhWS-s7PuCA1KTPEPPX # checkpoint 108 | gdown https://drive.google.com/uc?id=14rEyW3cZyUxctjLQunIjyuJMI3DbUQ-b # args 109 | cd .. 110 | 111 | # VICReg 112 | mkdir vicreg 113 | cd vicreg 114 | gdown https://drive.google.com/uc?id=1yAxL-NTOYN6kGi2VtPeo7cPKFxXFKYyP # checkpoint 115 | gdown https://drive.google.com/uc?id=1A5QaOlUGaId3qECmQusoPDiYx8tjfKDz # args 116 | cd .. 117 | 118 | # W-MSE 119 | mkdir wmse 120 | cd wmse 121 | gdown https://drive.google.com/uc?id=1yYhOsIpbHqJGhqlbMTYwBxMOJkz7rSwo # checkpoint 122 | gdown https://drive.google.com/uc?id=1Q88g4Rtz_k4FR9QvXwFqAZ-kYuL4dXl- # args 123 | cd .. 124 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Message-Passing-Contrastive-Learning 2 | 3 | This repository includes a PyTorch implementation of the ICLR 2023 paper [A Message Passing Perspective on Learning Dynamics of Contrastive Learning](https://openreview.net/pdf?id=VBTJqqWjxMv) authored by [Yifei Wang*](https://yifeiwang.me), Qi Zhang*, Tianqi Du, Jiansheng Yang, Zhouchen Lin, [Yisen Wang](https://yisenwang.github.io/). 
4 | 5 | Multi-stage Graph Aggregation and Graph-Attention are two methods inspired by the connection between message passing and contrastive learning; as the tables below show, they consistently improve the performance of the self-supervised methods they are applied to (SimSiam and SimCLR). 6 | 7 | 8 | 9 | | Backbone | Method | CIFAR-10 | CIFAR-100 | ImageNet-100 | 10 | |-----------|---------------|:--------:|:---------:|:------------:| 11 | | ResNet-18 | SimSiam | 83.8 | 56.3 | 68.8 | 12 | | | SimSiam-Multi | 84.8 | 58.9 | 70.5 | 13 | | ResNet-50 | SimSiam | 85.9 | 58.4 | 70.9 | 14 | | | SimSiam-Multi | 87.0 | 59.8 | 72.3 | 15 | 16 | 17 | 18 | 19 | 20 | | Backbone | Method | CIFAR-10 | CIFAR-100 | ImageNet-100 | 21 | |-----------|-------------|----------|-----------|--------------| 22 | | ResNet-18 | SimCLR | 84.5 | 56.1 | 62.3 | 23 | | | SimCLR-Attn | 85.4 | 56.9 | 63.1 | 24 | | ResNet-50 | SimCLR | 88.2 | 59.8 | 66.0 | 25 | | | SimCLR-Attn | 89.4 | 60.7 | 66.7 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | ## Instructions 36 | 37 | ### Environment Setup 38 | 39 | To install the environment for Multi-stage Graph Aggregation, run the following commands: 40 | ``` 41 | cd MULTI-STAGE-GRAPH-AGGREGATION 42 | pip3 install .[dali,umap,h5] --extra-index-url https://developer.download.nvidia.com/compute/redist 43 | ``` 44 | 45 | To install the environment for Graph-Attention, run the following commands: 46 | ``` 47 | cd GRAPH-ATTENTION 48 | conda env create -f environment.yml 49 | conda activate simclr_pytorch 50 | ``` 51 | 52 | When pretraining the model with the proposed methods, please first enter the corresponding directory (``MULTI-STAGE-GRAPH-AGGREGATION``/``GRAPH-ATTENTION``). 53 | 54 | ### Pretraining with Multi-stage Graph Aggregation 55 | 56 | 57 | 58 | Taking SimSiam on CIFAR-10 as an example, pretrain the model with the Multi-stage Graph Aggregation technique using the following command: 59 | 60 | 61 | ``` 62 | ./scripts/pretrain/cifar/simsiam.sh 63 | ``` 64 | 65 | The code provides an online linear classifier.
The offline downstream linear performance can be evaluated with: 66 | 67 | ``` 68 | ./scripts/linear/simsiam_linear.sh 69 | ``` 70 | 71 | 72 | ### Pretraining with Graph-Attention 73 | 74 | 75 | 76 | Taking SimCLR on CIFAR-10 as an example, pretrain the model with the Graph-Attention technique using the following command: 77 | 78 | ``` 79 | python train.py --config configs/cifar_train_epochs200_bs512.yaml 80 | ``` 81 | 82 | The downstream linear performance can then be evaluated with: 83 | 84 | ``` 85 | python train.py --config configs/cifar_eval.yaml --encoder_ckpt 86 | ``` 87 | 88 | More details on running the code can be found in [MULTI-STAGE-GRAPH-AGGREGATION/README_simsiam.md](MULTI-STAGE-GRAPH-AGGREGATION/README_simsiam.md) and [GRAPH-ATTENTION/README_simclr.md](GRAPH-ATTENTION/README_simclr.md). 89 | 90 | 91 | ## Modifications 92 | 93 | We follow the default settings of SimSiam and SimCLR; the main modifications are: 94 | 95 | In [MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/simsiam.py](MULTI-STAGE-GRAPH-AGGREGATION/solo/methods/simsiam.py), to implement Multi-stage Graph Aggregation, we add a memory bank that stores the latest features and modify the loss function so that it also incorporates the features stored from the last epoch. 96 | 97 | In [GRAPH-ATTENTION/models/losses.py](GRAPH-ATTENTION/models/losses.py), to implement Graph Attention, we compute the similarity between the features in the same batch and use it to reweight the InfoNCE loss.
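The README does not spell out the weighting formula used in `models/losses.py`, so the following pure-Python sketch shows only one hypothetical way to reweight InfoNCE by in-batch similarity, not the repository's implementation. Under this assumption, each anchor's negatives are weighted by a softmax over its cosine similarities to the other samples in the batch, so more similar samples receive more weight. The names `reweighted_info_nce`, `attention_weights`, `temperature` and `tau_attn` are illustrative and not part of the codebase.

```python
import math
from typing import List, Optional

Vector = List[float]


def normalize(v: Vector) -> Vector:
    """L2-normalize a vector (an all-zero vector is returned unchanged)."""
    norm = math.sqrt(sum(x * x for x in v))
    return list(v) if norm == 0.0 else [x / norm for x in v]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def attention_weights(z: List[Vector], tau_attn: float) -> List[List[float]]:
    """Row i is a softmax over j != i of <z_i, z_j> / tau_attn; the diagonal gets weight 0."""
    n = len(z)
    rows = []
    for i in range(n):
        scores = {j: dot(z[i], z[j]) / tau_attn for j in range(n) if j != i}
        top = max(scores.values())  # subtract the max for a numerically stable softmax
        exps = {j: math.exp(s - top) for j, s in scores.items()}
        total = sum(exps.values())
        rows.append([exps[j] / total if j != i else 0.0 for j in range(n)])
    return rows


def reweighted_info_nce(
    z1: List[Vector],
    z2: List[Vector],
    temperature: float = 0.5,
    tau_attn: Optional[float] = None,
) -> float:
    """InfoNCE between paired views z1[i] <-> z2[i] with (hypothetically) reweighted negatives.

    Assumed form: the negative term for anchor i is
    (n - 1) * sum_{j != i} a_ij * exp(<z1_i, z2_j> / temperature), with a_ij taken from
    attention_weights(z1, tau_attn). tau_attn=None gives uniform weights 1 / (n - 1),
    which recovers the standard InfoNCE loss.
    """
    n = len(z1)
    assert n >= 2 and len(z2) == n, "need at least two paired samples"
    z1 = [normalize(v) for v in z1]
    z2 = [normalize(v) for v in z2]
    if tau_attn is None:
        attn = [[0.0 if j == i else 1.0 / (n - 1) for j in range(n)] for i in range(n)]
    else:
        # In an autograd framework these weights would typically be detached (treated as constants).
        attn = attention_weights(z1, tau_attn)
    loss = 0.0
    for i in range(n):
        logits = [dot(z1[i], z2[j]) / temperature for j in range(n)]
        positive = math.exp(logits[i])
        negatives = sum((n - 1) * attn[i][j] * math.exp(logits[j]) for j in range(n) if j != i)
        loss += math.log(positive + negatives) - logits[i]
    return loss / n


z1 = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
z2 = [[0.9, 0.1], [0.1, 0.9], [0.8, 1.2]]
print(reweighted_info_nce(z1, z2))                # standard InfoNCE (uniform weights)
print(reweighted_info_nce(z1, z2, tau_attn=0.1))  # similarity-reweighted negatives
```

Because uniform weights reduce the function to plain InfoNCE, the reweighting can be checked in isolation; a real training loss would operate on batched tensors rather than Python lists.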
98 | 99 | 100 | 101 | ## Citing this work 102 | 103 | 104 | If you find our code useful, please cite: 105 | ``` 106 | @inproceedings{ 107 | wang2023message, 108 | title={A Message Passing Perspective on Learning Dynamics of Contrastive Learning}, 109 | author={Yifei Wang and Qi Zhang and Tianqi Du and Jiansheng Yang and Zhouchen Lin and Yisen Wang}, 110 | booktitle={International Conference on Learning Representations}, 111 | year={2023}, 112 | } 113 | ``` 114 | 115 | 116 | ## Acknowledgement 117 | 118 | Our code borrows the SimSiam and SimCLR implementations from these repositories: 119 | 120 | https://github.com/vturrisi/solo-learn 121 | 122 | https://github.com/google-research/simclr 123 | 124 | 125 | --------------------------------------------------------------------------------