├── unagi ├── utils │ ├── __init__.py │ ├── file_utils.py │ └── image_utils.py ├── data │ ├── data_utils │ │ ├── __init__.py │ │ ├── collate_fns.py │ │ ├── meerkat_processors.py │ │ └── transform_util.py │ ├── transforms │ │ ├── image │ │ │ ├── utils.py │ │ │ ├── identity.py │ │ │ ├── invert.py │ │ │ ├── blur.py │ │ │ ├── equalize.py │ │ │ ├── smooth.py │ │ │ ├── auto_contrast.py │ │ │ ├── horizontal_filp.py │ │ │ ├── vertical_flip.py │ │ │ ├── to_tensor.py │ │ │ ├── reshape2d.py │ │ │ ├── normalize.py │ │ │ ├── posterize.py │ │ │ ├── solarize.py │ │ │ ├── color.py │ │ │ ├── contrast.py │ │ │ ├── sharpness.py │ │ │ ├── brightness.py │ │ │ ├── gaussian_blur.py │ │ │ ├── rotate.py │ │ │ ├── center_crop.py │ │ │ ├── random_grayscale.py │ │ │ ├── random_horizontal_flip.py │ │ │ ├── grayscale.py │ │ │ ├── shear_x.py │ │ │ ├── shear_y.py │ │ │ ├── resize.py │ │ │ ├── compose.py │ │ │ ├── transform.py │ │ │ ├── translate_x.py │ │ │ ├── translate_y.py │ │ │ ├── color_distortion.py │ │ │ ├── random_resize_crop.py │ │ │ ├── color_jitter.py │ │ │ ├── random_crop.py │ │ │ ├── cutout.py │ │ │ ├── resize_and_pad.py │ │ │ └── __init__.py │ │ ├── text │ │ │ ├── identity.py │ │ │ ├── __init__.py │ │ │ ├── compose.py │ │ │ ├── pretrained_lm_tokenize.py │ │ │ ├── transform.py │ │ │ └── back_translate.py │ │ ├── task │ │ │ ├── __init__.py │ │ │ └── transform.py │ │ └── __init__.py │ └── augmentations │ │ ├── __init__.py │ │ ├── cutout.py │ │ ├── brightness.py │ │ └── mixup.py ├── _version.py ├── tasks │ ├── loss_fns │ │ ├── base_loss.py │ │ ├── mask_loss.py │ │ └── ce_loss.py │ ├── output_layer_modules.py │ └── unagi_task_template.py ├── trainer │ └── __init__.py ├── models │ ├── ops │ │ ├── view_concat.py │ │ ├── sequence_concat.py │ │ ├── linear_proj.py │ │ ├── pool.py │ │ ├── view_select.py │ │ ├── image_reshape.py │ │ ├── einsum_reduce.py │ │ └── grayscale.py │ ├── decoders │ │ ├── view_concat.py │ │ ├── classifier.py │ │ ├── sequence │ │ │ ├── transformer.py │ │ │ └── mixer.py │ │ └── image │ │ │ └── resnet.py │ ├── embeddings │ │ └── base_embedding.py │ ├── encoders │ │ ├── sequence │ │ │ ├── mixer │ │ │ │ ├── mixer_modules.py │ │ │ │ └── mixer.py │ │ │ ├── bert │ │ │ │ └── bert.py │ │ │ └── transformer │ │ │ │ └── transformer.py │ │ ├── image │ │ │ └── resnet │ │ │ │ └── resnet.py │ │ └── base_sequence.py │ ├── __init__.py │ └── layers │ │ └── blocks.py ├── constants.py ├── datasets │ ├── __init__.py │ ├── mnist │ │ ├── utils.py │ │ └── mnist_dataset.py │ ├── meerkat_dataset.py │ ├── tiny_imagenet │ │ └── tinyimagenet_dataset.py │ └── cifar │ │ └── utils.py ├── unagi.py └── data_driver.py ├── config ├── loss │ ├── mse_loss.yaml │ ├── label_smoothing.yaml │ └── contrastive_loss.yaml ├── optimizer │ ├── adam.yaml │ └── adamw.yaml ├── encoder │ ├── resnet.yaml │ ├── bert.yaml │ ├── mixer.yaml │ └── transformer.yaml ├── dataflow │ ├── loader │ │ └── default.yaml │ ├── dataset │ │ ├── celeba.yaml │ │ ├── mnist_coarse.yaml │ │ ├── mnist.yaml │ │ ├── cifar10.yaml │ │ ├── cifar10_coarse.yaml │ │ ├── cifar10_train_as_test.yaml │ │ ├── cifar_samples.yaml │ │ ├── cifar10_subset.yaml │ │ ├── tinyimagenet_coarse.yaml │ │ ├── cifar100_coarse.yaml │ │ ├── cifar100.yaml │ │ └── tinyimagenet.yaml │ ├── transforms │ │ ├── cifar10_notransform.yaml │ │ ├── celeba.yaml │ │ ├── mnist.yaml │ │ ├── mnist_resnet.yaml │ │ ├── cifar10_resnet.yaml │ │ ├── cifar100.yaml │ │ ├── tinyimagenet.yaml │ │ ├── cifar10.yaml │ │ ├── nlvr2.yaml │ │ ├── cub200.yaml │ │ └── upmcfood101.yaml │ ├── celeba.yaml │ ├── mnist.yaml │ ├── 
cifar10.yaml │ ├── cifar100.yaml │ ├── cifar100_coarse.yaml │ ├── cifar10_coarse.yaml │ ├── cifar10_train_as_test.yaml │ ├── tinyimagenet_coarse.yaml │ └── cifar10_subset.yaml ├── pipeline │ ├── cep_image.yaml │ ├── cep_image_log.yaml │ ├── cep_image_supcon.yaml │ ├── cep_image_autoencoder.yaml │ ├── cep_image_supcon_simclr.yaml │ ├── cep_image_supcon_autoencoder.yaml │ ├── cep_image_supcon_autoencoder_joint.yaml │ └── cep_image_supcon_autoencoder_joint_grayscale.yaml ├── decoder │ ├── pool.yaml │ ├── grayscale.yaml │ ├── sequence_concat.yaml │ ├── linear_projection.yaml │ ├── view_select.yaml │ ├── classifier.yaml │ ├── resnet18decoder.yaml │ └── image_reshape.yaml ├── task │ ├── metric │ │ └── accuracy.yaml │ ├── loss │ │ ├── mse_loss.yaml │ │ ├── classifier_loss.yaml │ │ ├── lspread_contrastive_loss.yaml │ │ ├── simclr_contrastive_loss.yaml │ │ └── supcon_contrastive_loss.yaml │ ├── classification_image.yaml │ ├── classification_image_joint.yaml │ ├── callback │ │ ├── log_image.yaml │ │ └── log_embedding.yaml │ ├── classification_image_clip.yaml │ ├── classification_joint_clip.yaml │ ├── task_flow │ │ ├── classification_text.yaml │ │ ├── classification_image.yaml │ │ ├── supcon_image.yaml │ │ ├── autoencoder_image.yaml │ │ ├── autoencoder_image_grayscale.yaml │ │ └── supcon_autoencoder_joint_classification.yaml │ ├── classification_image_log.yaml │ ├── supcon_image.yaml │ ├── simclr_image.yaml │ ├── classification_text_clip.yaml │ ├── contrastive_image_clip.yaml │ ├── contrastive_joint_clip.yaml │ ├── autoencoder_image.yaml │ ├── contrastive_text_clip.yaml │ ├── autoencoder_image_grayscale.yaml │ └── entitymatching_pretrain.yaml ├── scheduler │ ├── cosine.yaml │ ├── multistep.yaml │ └── plateau.yaml ├── preprocessor │ └── supervised_task_preprocessing.yaml ├── embedding │ ├── identity.yaml │ ├── pretrained_lm.yaml │ └── square_patch.yaml ├── wandb │ └── default.yaml ├── config.yaml ├── experiment │ ├── test.yaml │ ├── mnist_cep_xfer_resnet18.yaml │ ├── mnist_supcon_cep_coarselabels_resnet18.yaml │ ├── mnist_lspread_cep_coarselabels_resnet18.yaml │ ├── cifar10_cep_autoencoder.yaml │ ├── mnist_cep_autoencoder_coarselabels_resnet18.yaml │ ├── mnist_supcon_simclr_cep_coarselabels_resnet18.yaml │ ├── cifar100_supcon_cep.yaml │ ├── cifar10_supcon_cep_autoencoder.yaml │ ├── cifar10_supcon_cep.yaml │ ├── cifar10_lspread_cep.yaml │ ├── cifar100_supcon_cep_autoencoder.yaml │ ├── cifar10_cep.yaml │ ├── cifar10_supcon_cep_autoencoder_joint.yaml │ ├── cifar100_cep.yaml │ ├── cifar100_cep_coarselabels.yaml │ ├── cifar10_cep_coarselabels.yaml │ ├── cifar10_supcon_cep_coarselabels.yaml │ ├── cifar10_cep_autoencoder_coarselabels.yaml │ ├── cifar100_supcon_cep_coarselabels.yaml │ ├── cifar10_supcon_cep_autoencoder_coarselabels.yaml │ ├── cifar100_supcon_cep_autoencoder_coarselabels.yaml │ ├── cifar10_supcon_cep_autoencoder_twobackbones_coarselabels.yaml │ ├── cifar10_supcon_cep_coarselabels_clip_pos.yaml │ ├── cifar10_lspread_cep_coarselabels.yaml │ ├── cifar10_lspread_cep_autoencoder_coarselabels.yaml │ ├── cifar10_supcon_simclr_cep_coarselabels.yaml │ ├── cifar10_supcon_cep_coarselabels_weighted_pos_in_denom.yaml │ ├── cifar10_lspread_cep_autoencoder_twobackbones_coarselabels.yaml │ ├── cifar10_cep_xfer.yaml │ ├── cifar100_cep_xfer.yaml │ ├── cifar10_cep_xfer_train_as_test.yaml │ ├── cifar10_cep_xfer_train_test_embs.yaml │ ├── cifar10_supcon_cep_coarselabels_resnet.yaml │ ├── cifar10_supcon_cep_coarselabels_resnet18.yaml │ ├── cifar10_cep_xfer_resnet.yaml │ ├── cifar10_cep_xfer_resnet18.yaml │ ├── 
cifar10_lspread_cep_coarselabels_resnet.yaml │ ├── cifar10_lspread_cep_coarselabels_resnet18.yaml │ ├── cifar10_supcon_simclr_cep_coarselabels_resnet.yaml │ ├── cifar10_supcon_simclr_cep_coarselabels_resnet18.yaml │ ├── cifar10_cep_twobackbones_xfer.yaml │ ├── cifar100_cep_twobackbones_xfer.yaml │ ├── cifar10_cep_autoencoder_coarselabels_resnet.yaml │ ├── cifar10_cep_autoencoder_coarselabels_resnet18.yaml │ └── mnist_cep_twobackbones_xfer.yaml ├── learner │ └── default.yaml ├── trainer │ ├── default.yaml │ └── full.yaml └── modules │ ├── cep_image.yaml │ ├── cep_image_autoencoder.yaml │ ├── cep_image_supcon.yaml │ ├── cep_image_supcon_simclr.yaml │ ├── cep_image_supcon_autoencoder.yaml │ ├── cep_image_supcon_autoencoder_joint.yaml │ └── cep_image_supcon_autoencoder_joint_grayscale.yaml ├── Makefile ├── static └── banner.png ├── requirements-dev.txt ├── setup.py ├── bin └── unagi └── .gitignore /unagi/utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /unagi/data/data_utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /config/loss/mse_loss.yaml: -------------------------------------------------------------------------------- 1 | module: mse_loss 2 | _target_: null 3 | -------------------------------------------------------------------------------- /config/optimizer/adam.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.Adam 2 | lr: 0.001 -------------------------------------------------------------------------------- /unagi/_version.py: -------------------------------------------------------------------------------- 1 | """Unagi version.""" 2 | __version__ = "0.0.1+dev" 3 | -------------------------------------------------------------------------------- /Makefile: -------------------------------------------------------------------------------- 1 | dev: 2 | pip install -r requirements-dev.txt 3 | pip install -e . 
4 | -------------------------------------------------------------------------------- /config/encoder/resnet.yaml: -------------------------------------------------------------------------------- 1 | module: resnet 2 | model : resnet18 3 | _target_: null -------------------------------------------------------------------------------- /config/loss/label_smoothing.yaml: -------------------------------------------------------------------------------- 1 | module: label_smoothing 2 | _target_: null 3 | -------------------------------------------------------------------------------- /config/optimizer/adamw.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.AdamW 2 | lr: 0.0003 3 | # weight_decay: 0.0 -------------------------------------------------------------------------------- /config/dataflow/loader/default.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 128 2 | eval_batch_size: 128 3 | num_workers: 16 -------------------------------------------------------------------------------- /config/pipeline/cep_image.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /task@Classification: classification_image 3 | -------------------------------------------------------------------------------- /static/banner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/HazyResearch/thanos-code/HEAD/static/banner.png -------------------------------------------------------------------------------- /config/decoder/pool.yaml: -------------------------------------------------------------------------------- 1 | module: pool 2 | d_input: null 3 | d_output: null 4 | _target_: null 5 | 6 | -------------------------------------------------------------------------------- /config/pipeline/cep_image_log.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /task@Classification: classification_image_log 3 | -------------------------------------------------------------------------------- /config/task/metric/accuracy.yaml: -------------------------------------------------------------------------------- 1 | module: accuracy 2 | inputs: [[classifier, 0], [_output_, label]] 3 | -------------------------------------------------------------------------------- /config/dataflow/dataset/celeba.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_._name_ 2 | val_split: 0.1 3 | seed: 42 4 | root_dir: None 5 | -------------------------------------------------------------------------------- /config/decoder/grayscale.yaml: -------------------------------------------------------------------------------- 1 | module: grayscale 2 | d_input: null 3 | d_output: null 4 | _target_: null 5 | dim: 1 6 | -------------------------------------------------------------------------------- /config/loss/contrastive_loss.yaml: -------------------------------------------------------------------------------- 1 | module: l_spread 2 | views: 2 3 | temp: 0.5 4 | lc_norm: True 5 | _target_: null 6 | -------------------------------------------------------------------------------- /config/scheduler/cosine.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.lr_scheduler.CosineAnnealingLR 2 | T_max: 100 3 | eta_min: 0 4 | 
-------------------------------------------------------------------------------- /config/scheduler/multistep.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.lr_scheduler.MultiStepLR 2 | milestones: [10, 20] 3 | gamma: 0.1 -------------------------------------------------------------------------------- /config/decoder/sequence_concat.yaml: -------------------------------------------------------------------------------- 1 | module: sequence_concat 2 | d_input: null 3 | d_output: null 4 | _target_: null 5 | -------------------------------------------------------------------------------- /config/preprocessor/supervised_task_preprocessing.yaml: -------------------------------------------------------------------------------- 1 | module: supervised 2 | _target_: null 3 | path_to_checkpoint: null 4 | -------------------------------------------------------------------------------- /config/decoder/linear_projection.yaml: -------------------------------------------------------------------------------- 1 | module: linear_proj 2 | d_input: null #embedding dimension 3 | d_output: null 4 | _target_: null -------------------------------------------------------------------------------- /config/scheduler/plateau.yaml: -------------------------------------------------------------------------------- 1 | _target_: torch.optim.lr_scheduler.ReduceLROnPlateau 2 | mode: min 3 | factor: 0.1 4 | patience: 10 5 | -------------------------------------------------------------------------------- /config/pipeline/cep_image_supcon.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /task@Classification: classification_image 3 | - /task@SupCon: supcon_image 4 | -------------------------------------------------------------------------------- /config/decoder/view_select.yaml: -------------------------------------------------------------------------------- 1 | module: view_select 2 | n_views: null 3 | d_input: null 4 | d_output: null 5 | view_idx: null 6 | _target_: null 7 | -------------------------------------------------------------------------------- /config/pipeline/cep_image_autoencoder.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /task@Classification: classification_image 3 | - /task@Autoencoder: autoencoder_image 4 | -------------------------------------------------------------------------------- /config/dataflow/dataset/mnist_coarse.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_._name_ 2 | val_split: 0.1 3 | seed: 42 4 | coarse_labels: True 5 | class_names: ['<5','>=5'] 6 | -------------------------------------------------------------------------------- /config/task/loss/mse_loss.yaml: -------------------------------------------------------------------------------- 1 | module: MSELoss #uid of loss function in loss section 2 | inputs: [[image_decoder, 0], [image_unflatten, 0]] 3 | weight: 1.0 4 | -------------------------------------------------------------------------------- /config/task/loss/classifier_loss.yaml: -------------------------------------------------------------------------------- 1 | module: ClassifierLoss #uid of loss function in loss section 2 | inputs: [[classifier, 0], [_output_, label]] 3 | weight: 1.0 4 | -------------------------------------------------------------------------------- /config/dataflow/dataset/mnist.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _group_._name_ 2 | val_split: 0.1 3 | seed: 42 4 | coarse_labels: False 5 | class_names: ['0','1','2','3','4','5','6','7','8','9'] 6 | -------------------------------------------------------------------------------- /config/embedding/identity.yaml: -------------------------------------------------------------------------------- 1 | module: identity 2 | d_input: null 3 | d_model: null #256 4 | patch_size: 32 5 | path_to_checkpoint: null 6 | source_module: null 7 | _target_: null 8 | -------------------------------------------------------------------------------- /config/embedding/pretrained_lm.yaml: -------------------------------------------------------------------------------- 1 | module: pretrained_lm 2 | d_input: null 3 | d_model: null #256 4 | patch_size: 4 5 | pretrained_lm_name: bert-base-uncased 6 | _target_: null 7 | -------------------------------------------------------------------------------- /config/pipeline/cep_image_supcon_simclr.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /task@Classification: classification_image 3 | - /task@SupCon: supcon_image 4 | - /task@SimClr: simclr_image 5 | -------------------------------------------------------------------------------- /requirements-dev.txt: -------------------------------------------------------------------------------- 1 | hydra-core 2 | mypy 3 | types-PyYAML 4 | pytest 5 | torch 6 | torchmetrics==0.6.0 7 | pytorch-lightning==1.4.8 8 | rich 9 | wandb 10 | matplotlib 11 | sklearn -------------------------------------------------------------------------------- /config/embedding/square_patch.yaml: -------------------------------------------------------------------------------- 1 | module: square_patch 2 | d_input: null 3 | d_model: null #256 4 | patch_size: 4 5 | path_to_checkpoint: null 6 | source_module: null 7 | _target_: null 8 | -------------------------------------------------------------------------------- /config/wandb/default.yaml: -------------------------------------------------------------------------------- 1 | project: unagi 2 | group: "" 3 | job_type: training 4 | mode: disabled # choices=['online', 'offline', 'disabled'] 5 | id: null # pass correct id to resume experiment! 
-------------------------------------------------------------------------------- /config/dataflow/dataset/cifar10.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_._name_ 2 | val_split: 0.1 3 | seed: 42 4 | class_names: ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 5 | -------------------------------------------------------------------------------- /config/pipeline/cep_image_supcon_autoencoder.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /task@Classification: classification_image 3 | - /task@SupCon: supcon_image 4 | - /task@Autoencoder: autoencoder_image 5 | -------------------------------------------------------------------------------- /config/dataflow/dataset/cifar10_coarse.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_._name_ 2 | val_split: 0.1 3 | seed: 42 4 | coarse_labels: True 5 | class_names: ['vehicles', 'animals'] 6 | return_train_as_test: False 7 | -------------------------------------------------------------------------------- /config/encoder/bert.yaml: -------------------------------------------------------------------------------- 1 | module: bert 2 | freeze_layers: True 3 | pretrained_lm_name: bert-base-uncased 4 | use_cls_token: True 5 | use_all_tokens: False 6 | pretrained_weights: None 7 | _target_: null 8 | -------------------------------------------------------------------------------- /config/task/loss/lspread_contrastive_loss.yaml: -------------------------------------------------------------------------------- 1 | module: ContrastiveLoss #uid of loss function in loss section 2 | inputs: [[image_view_select_0, 0], [image_view_select_1, 0], [_output_, label]] 3 | weight: 1.0 4 | -------------------------------------------------------------------------------- /config/task/loss/simclr_contrastive_loss.yaml: -------------------------------------------------------------------------------- 1 | module: ContrastiveLoss #uid of loss function in loss section 2 | inputs: [[image_view_select_0, 0], [image_view_select_1, 0], [_output_, label]] 3 | weight: 1.0 4 | -------------------------------------------------------------------------------- /config/task/loss/supcon_contrastive_loss.yaml: -------------------------------------------------------------------------------- 1 | module: ContrastiveLoss #uid of loss function in loss section 2 | inputs: [[image_view_select_0, 0], [image_view_select_1, 0], [_output_, label]] 3 | weight: 1.0 4 | -------------------------------------------------------------------------------- /config/decoder/classifier.yaml: -------------------------------------------------------------------------------- 1 | module: classifier 2 | d_input: null #256 -- fill this inside the Python program 3 | d_output: null #10 -- fill this inside the Python program 4 | _target_: null 5 | path_to_checkpoint: null 6 | -------------------------------------------------------------------------------- /config/encoder/mixer.yaml: -------------------------------------------------------------------------------- 1 | module: mixer 2 | d_model: 256 3 | n_heads: 8 4 | l_max: 1024 # can be computed based on embedding 5 | n_layers: 4 6 | dropout: 0.1 7 | head_dropout: 0.1 8 | mlp_dim: None 9 | tie_them_all: False -------------------------------------------------------------------------------- /config/decoder/resnet18decoder.yaml:
-------------------------------------------------------------------------------- 1 | module: resnet18decoder 2 | d_input: null #embedding dimension 3 | input_height: null #height of the image to generate 4 | _target_: null 5 | d_output: null 6 | path_to_checkpoint: null 7 | -------------------------------------------------------------------------------- /config/dataflow/dataset/cifar10_train_as_test.yaml: -------------------------------------------------------------------------------- 1 | cifar10: 2 | val_split: 0.1 3 | seed: 42 4 | class_names: ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 5 | return_train_as_test: True 6 | -------------------------------------------------------------------------------- /config/dataflow/dataset/cifar_samples.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_._name_ 2 | val_split: 0.1 3 | seed: 42 4 | class_names: ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 5 | path_to_numpy_file: None 6 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/utils.py: -------------------------------------------------------------------------------- 1 | def categorize_value(level, value_range, type="int"): 2 | val = value_range[0] + level * (value_range[1] - value_range[0]) 3 | 4 | return int(val) if type == "int" else float(val) 5 | -------------------------------------------------------------------------------- /unagi/tasks/loss_fns/base_loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | 3 | 4 | class UnagiLoss(nn.Module): 5 | def __init__(self): 6 | super().__init__() 7 | 8 | def forward(self): 9 | raise NotImplementedError 10 | -------------------------------------------------------------------------------- /config/decoder/image_reshape.yaml: -------------------------------------------------------------------------------- 1 | module: image_reshape 2 | d_input: null #embedding dimension 3 | output_height: null #height of the image to generate 4 | output_width: null #width of the image to generate 5 | d_output: null 6 | _target_: null -------------------------------------------------------------------------------- /config/config.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - /modules@model: cep_image 5 | - /pipeline@tasks: cep_image 6 | - /dataflow: cifar10 7 | - /trainer: default 8 | - /learner: default 9 | - /wandb: default 10 | -------------------------------------------------------------------------------- /config/task/classification_image.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task_flow: classification_image 3 | - loss@losses.classifier_loss: classifier_loss 4 | - metric@metrics.accuracy: accuracy 5 | 6 | task_weight: 1.0 7 | torchmetrics: {} 8 | callbacks: {} 9 | -------------------------------------------------------------------------------- /config/task/classification_image_joint.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task_flow: supcon_autoencoder_joint_classification 3 | - loss@losses.classifier_loss: classifier_loss 4 | - metric@metrics.accuracy: accuracy 5 | 6 | task_weight: 1.0 7 | torchmetrics: {} 8 | callbacks: {} 9 |
-------------------------------------------------------------------------------- /config/dataflow/dataset/cifar10_subset.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_._name_ 2 | val_split: 0.1 3 | seed: 42 4 | class_names: ['plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck'] 5 | subset_split_seed: 42 6 | subset_split_percent: 0.5 7 | coarse_labels: False -------------------------------------------------------------------------------- /config/experiment/test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /modules@model: cep_image 5 | - override /pipeline@tasks: cep_image 6 | 7 | trainer: 8 | max_epochs: 5 9 | limit_train_batches: 5 10 | 11 | wandb: 12 | mode: disabled 13 | -------------------------------------------------------------------------------- /unagi/trainer/__init__.py: -------------------------------------------------------------------------------- 1 | from unagi.models import MODULE_DICTS 2 | from unagi.tasks import LOSS_MODULE_REGISTRY, TASK_PREPROCESSING_LAYER 3 | 4 | MODULE_REGISTRY = { 5 | "preprocessors": TASK_PREPROCESSING_LAYER, 6 | "losses": LOSS_MODULE_REGISTRY, 7 | **MODULE_DICTS, 8 | } 9 | -------------------------------------------------------------------------------- /config/experiment/mnist_cep_xfer_resnet18.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/cifar10_xfer_resnet18 4 | - override /dataflow: mnist_coarse 5 | - override /dataflow/transforms@dataflow.transforms: mnist_resnet 6 | 7 | wandb: 8 | group: cifar10_cep_xfer_resnet18 9 | -------------------------------------------------------------------------------- /config/task/callback/log_image.yaml: -------------------------------------------------------------------------------- 1 | module: log_image 2 | logging_batch_idx: null # which batch to log; if -1, log all batches 3 | inputs: null # which input to log 4 | input_names: null # names of each input 5 | log_every_n_epochs: 1 6 | max_images: 16 # maximum number of images to log per input 7 | -------------------------------------------------------------------------------- /unagi/models/ops/view_concat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ViewConcat(nn.Module): 6 | def __init__(self, **kwargs): 7 | super().__init__() 8 | self.name = "view_concat" 9 | 10 | def forward(self, *args): 11 | return torch.stack(args, dim=1) 12 | -------------------------------------------------------------------------------- /unagi/models/decoders/view_concat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ViewConcat(nn.Module): 6 | def __init__(self, **kwargs): 7 | super().__init__() 8 | self.name = "view_concat" 9 | 10 | def forward(self, *args): 11 | return torch.stack(args, dim=1) 12 | -------------------------------------------------------------------------------- /unagi/models/ops/sequence_concat.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class SequenceConcat(nn.Module): 6 | def __init__(self, **kwargs): 7 | super().__init__() 8 | self.name = "sequence_concat" 9 | 10 | def forward(self, *args): 
11 | return torch.cat(args, dim=1) 12 | -------------------------------------------------------------------------------- /config/experiment/mnist_supcon_cep_coarselabels_resnet18.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/cifar10_supcon_cep_coarselabels_resnet18 4 | - override /dataflow: mnist_coarse 5 | - override /dataflow/transforms@dataflow.transforms: mnist_resnet 6 | 7 | wandb: 8 | group: mnist_supcon_cep_resnet18_cl 9 | -------------------------------------------------------------------------------- /unagi/constants.py: -------------------------------------------------------------------------------- 1 | TEXT = "text" 2 | TIMESERIES = "timeseries" 3 | IMAGE = "image" 4 | TYPE = "type" 5 | AUGMENTATIONS = "augmentations" 6 | CONTRASTIVE = "contrastive" 7 | MASKED = "masked" 8 | NAME = "name" 9 | DATASET = "dataset" 10 | TASKS = "tasks" 11 | RAW = "raw" 12 | PATCH = "patch" 13 | FEATURE = "feature" 14 | -------------------------------------------------------------------------------- /unagi/models/embeddings/base_embedding.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class EmbeddingModule(nn.Module): 5 | def __init__( 6 | self, 7 | d_input: int, 8 | d_model: int, 9 | ): 10 | super().__init__() 11 | self.d_input = d_input 12 | self.d_model = d_model 13 | -------------------------------------------------------------------------------- /unagi/models/ops/linear_proj.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class LinearProj(nn.Module): 5 | def __init__(self, d_input, d_output, **kwargs): 6 | super().__init__() 7 | self.linear_proj = nn.Linear(d_input, d_output) 8 | 9 | def forward(self, x): 10 | return self.linear_proj(x) 11 | -------------------------------------------------------------------------------- /config/experiment/mnist_lspread_cep_coarselabels_resnet18.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/cifar10_lspread_cep_coarselabels_resnet18 4 | - override /dataflow: mnist_coarse 5 | - override /dataflow/transforms@dataflow.transforms: mnist_resnet 6 | 7 | wandb: 8 | group: mnist_lspread_cep_resnet18_cl 9 | -------------------------------------------------------------------------------- /config/experiment/cifar10_cep_autoencoder.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /modules@model: cep_image_autoencoder 5 | - override /pipeline@tasks: cep_image_autoencoder 6 | 7 | dataflow: 8 | x: 9 | image: 10 | views: 2 11 | 12 | wandb: 13 | group: cifar10_cep_autoencoder 14 | -------------------------------------------------------------------------------- /config/experiment/mnist_cep_autoencoder_coarselabels_resnet18.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/cifar10_cep_autoencoder_coarselabels_resnet18 4 | - override /dataflow: mnist 5 | - override /dataflow/transforms@dataflow.transforms: mnist_resnet 6 | 7 | wandb: 8 | group: mnist_cep_autoencoder_resnet18_cl 9 | -------------------------------------------------------------------------------- /config/experiment/mnist_supcon_simclr_cep_coarselabels_resnet18.yaml: 
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - /experiment/cifar10_supcon_simclr_cep_coarselabels_resnet18 4 | - override /dataflow: mnist_coarse 5 | - override /dataflow/transforms@dataflow.transforms: mnist_resnet 6 | 7 | wandb: 8 | group: mnist_supcon_simclr_cep_resnet18_cl 9 | -------------------------------------------------------------------------------- /config/task/classification_image_clip.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task_flow: classification_image 3 | - loss@losses.classifier_loss_image: classifier_loss 4 | - metric@metrics.accuracy: accuracy 5 | 6 | task_weight: 1.0 7 | torchmetrics: {} 8 | callbacks: {} 9 | 10 | 11 | task_flow: 12 | classifier: 13 | module: ImageClassifier 14 | 15 | -------------------------------------------------------------------------------- /config/experiment/cifar100_supcon_cep.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar100 5 | - override /modules@model: cep_image_supcon 6 | - override /pipeline@tasks: cep_image_supcon 7 | 8 | dataflow: 9 | x: 10 | image: 11 | views: 2 12 | 13 | wandb: 14 | group: cifar100_supcon_cep 15 | -------------------------------------------------------------------------------- /unagi/data/transforms/text/identity.py: -------------------------------------------------------------------------------- 1 | from unagi.data.transforms.image.transform import UnagiTransform 2 | 3 | 4 | class Identity(UnagiTransform): 5 | def __init__(self, name=None, prob=1.0, level=0): 6 | super().__init__(name, prob, level) 7 | 8 | def transform(self, pil_img, label, **kwargs): 9 | return pil_img, label 10 | -------------------------------------------------------------------------------- /config/task/classification_joint_clip.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task_flow: clip_joint_classifier 3 | - loss@losses.classifier_loss_joint: classifier_loss 4 | - metric@metrics.accuracy: accuracy 5 | 6 | task_weight: 1.0 7 | torchmetrics: {} 8 | callbacks: {} 9 | 10 | 11 | task_flow: 12 | classifier: 13 | module: JointClassifier 14 | 15 | 16 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/identity.py: -------------------------------------------------------------------------------- 1 | from unagi.data.transforms.image.transform import UnagiTransform 2 | 3 | 4 | class Identity(UnagiTransform): 5 | def __init__(self, name=None, prob=1.0, level=0): 6 | super().__init__(name, prob, level) 7 | 8 | def transform(self, pil_img, label, **kwargs): 9 | return pil_img, label 10 | -------------------------------------------------------------------------------- /config/experiment/cifar10_supcon_cep_autoencoder.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /modules@model: cep_image_supcon_autoencoder 5 | - override /pipeline@tasks: cep_image_supcon_autoencoder 6 | 7 | 8 | dataflow: 9 | x: 10 | image: 11 | views: 2 12 | 13 | wandb: 14 | group: cifar10_supcon_cep_autoencoder 15 | -------------------------------------------------------------------------------- /unagi/data/transforms/task/__init__.py: 
-------------------------------------------------------------------------------- 1 | from unagi.data.transforms.task.transform import ( 2 | GroupTransform, 3 | IdentityTransform, 4 | MaskGen, 5 | TupleTransform, 6 | ) 7 | 8 | ALL_TRANSFORMS = { 9 | "Contrastive": GroupTransform, 10 | "MaskGenerator": MaskGen, 11 | "Mask": TupleTransform, 12 | "Identity": IdentityTransform, 13 | } 14 | -------------------------------------------------------------------------------- /config/experiment/cifar10_supcon_cep.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /modules@model: cep_image_supcon 5 | - override /pipeline@tasks: cep_image_supcon 6 | 7 | model: 8 | losses: 9 | ContrastiveLoss: 10 | type: sup_con 11 | 12 | 13 | dataflow: 14 | x: 15 | image: 16 | views: 2 17 | 18 | wandb: 19 | group: cifar10_supcon_cep 20 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/invert.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageOps 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class Invert(UnagiTransform): 7 | def __init__(self, name=None, prob=1.0, level=0): 8 | super().__init__(name, prob, level) 9 | 10 | def transform(self, pil_img, label, **kwargs): 11 | return ImageOps.invert(pil_img), label 12 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/blur.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageFilter 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class Blur(UnagiTransform): 7 | def __init__(self, name=None, prob=1.0, level=0): 8 | super().__init__(name, prob, level) 9 | 10 | def transform(self, pil_img, label, **kwargs): 11 | return pil_img.filter(ImageFilter.BLUR), label 12 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/equalize.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageOps 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class Equalize(UnagiTransform): 7 | def __init__(self, name=None, prob=1.0, level=0): 8 | super().__init__(name, prob, level) 9 | 10 | def transform(self, pil_img, label, **kwargs): 11 | return ImageOps.equalize(pil_img), label 12 | -------------------------------------------------------------------------------- /unagi/data/augmentations/__init__.py: -------------------------------------------------------------------------------- 1 | from unagi.data.augmentations.brightness import Brightness 2 | from unagi.data.augmentations.cutout import Cutout 3 | from unagi.data.augmentations.mixup import Mixup 4 | 5 | AUGMENTATIONS = { 6 | "mixup": Mixup, 7 | # "invert": Invert, 8 | "cutout": Cutout, 9 | # "solarize": Solarize, 10 | "brightness": Brightness, 11 | # "rotate": Rotate, 12 | } 13 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/smooth.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageFilter 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class Smooth(UnagiTransform): 7 | def __init__(self, name=None, prob=1.0, level=0): 8 | 
super().__init__(name, prob, level) 9 | 10 | def transform(self, pil_img, label, **kwargs): 11 | return pil_img.filter(ImageFilter.SMOOTH), label 12 | -------------------------------------------------------------------------------- /unagi/data/transforms/text/__init__.py: -------------------------------------------------------------------------------- 1 | from unagi.data.transforms.text.back_translate import BackTranslate 2 | from unagi.data.transforms.text.identity import Identity 3 | from unagi.data.transforms.text.pretrained_lm_tokenize import PretrainedLMTokenize 4 | 5 | ALL_TRANSFORMS = { 6 | "PretrainedLMTokenize": PretrainedLMTokenize, 7 | "BackTranslate": BackTranslate, 8 | "Identity": Identity, 9 | } 10 | -------------------------------------------------------------------------------- /config/dataflow/transforms/cifar10_notransform.yaml: -------------------------------------------------------------------------------- 1 | image_pil: 2 | - type: Identity 3 | image_pil_default_transform: 4 | - type: ToTensor 5 | - type: Normalize 6 | mean: [0.49139968, 0.48215841, 0.44653091] 7 | std: [0.24703223, 0.24348513, 0.26158784] 8 | - type: Reshape2D 9 | h_dim: 3 # num_channels 10 | w_dim: 1024 #32 x 32 x 3 11 | 12 | image_pil_no_transform: 13 | - type: Identity 14 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/auto_contrast.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageOps 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class AutoContrast(UnagiTransform): 7 | def __init__(self, name=None, prob=1.0, level=0): 8 | super().__init__(name, prob, level) 9 | 10 | def transform(self, pil_img, label, **kwargs): 11 | return ImageOps.autocontrast(pil_img), label 12 | -------------------------------------------------------------------------------- /config/encoder/transformer.yaml: -------------------------------------------------------------------------------- 1 | module: transformer 2 | d_model: null 3 | learn_pos: True # has to be false for mixer - args.model is not "mixer", 4 | n_heads: 8 5 | head_dropout: 0.1 6 | n_layers: 7 7 | patch_size: 4 8 | dropout: 0.05 9 | use_cls_token: True 10 | ret_cls_token: True # if True, only return embedding of CLS token 11 | l_max: 65 12 | _target_: null 13 | path_to_checkpoint: null 14 | source_module: null 15 | -------------------------------------------------------------------------------- /unagi/data/transforms/__init__.py: -------------------------------------------------------------------------------- 1 | from unagi.data.transforms.image import ALL_TRANSFORMS as ALL_IMAGE_TRANSFORMS 2 | from unagi.data.transforms.task import ALL_TRANSFORMS as ALL_TASK_TRANSFORMS 3 | from unagi.data.transforms.text import ALL_TRANSFORMS as ALL_TEXT_TRANSFORMS 4 | 5 | ALL_TRANSFORMS = { 6 | "text": ALL_TEXT_TRANSFORMS, 7 | "image": ALL_IMAGE_TRANSFORMS, 8 | "task": ALL_TASK_TRANSFORMS, 9 | } 10 | -------------------------------------------------------------------------------- /config/experiment/cifar10_lspread_cep.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /modules@model: cep_image_supcon 5 | - override /pipeline@tasks: cep_image_supcon 6 | 7 | model: 8 | losses: 9 | ContrastiveLoss: 10 | module: l_spread 11 | type: l_spread 12 | 13 | 14 | dataflow: 15 | x: 16 | image: 17 | views: 2 18 | 19 | wandb: 20 | 
group: cifar10_lspread_cep 21 | -------------------------------------------------------------------------------- /config/task/callback/log_embedding.yaml: -------------------------------------------------------------------------------- 1 | module: log_embedding 2 | logging_batch_idx: null # which batch to log; if -1, log all batches 3 | inputs: null # which input to log 4 | input_names: null # names of each input 5 | log_every_n_epochs: 1 6 | batch_size: ${dataflow.batch_size} 7 | eval_batch_size: ${dataflow.eval_batch_size} 8 | plot_embeddings: True 9 | plot_embeddings_stride: 1 10 | class_names: ${dataflow.class_names} 11 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/horizontal_filp.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class HorizontalFlip(UnagiTransform): 7 | def __init__(self, name=None, prob=1.0, level=0): 8 | super().__init__(name, prob, level) 9 | 10 | def transform(self, pil_img, label, **kwargs): 11 | return pil_img.transpose(Image.FLIP_LEFT_RIGHT), label 12 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/vertical_flip.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class VerticalFlip(UnagiTransform): 7 | def __init__(self, name=None, prob=1.0, level=0): 8 | super().__init__(name, prob, level) 9 | 10 | def transform(self, pil_img, label, **kwargs): 11 | return pil_img.transpose(Image.FLIP_TOP_BOTTOM), label 12 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/to_tensor.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms as transforms 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class ToTensor(UnagiTransform): 7 | def __init__(self, name=None, prob=1.0, level=0): 8 | super().__init__(name, prob, level) 9 | 10 | def transform(self, pil_img, label, **kwargs): 11 | return transforms.ToTensor()(pil_img), label 12 | -------------------------------------------------------------------------------- /config/experiment/cifar100_supcon_cep_autoencoder.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar100 5 | - override /modules@model: cep_image_supcon_autoencoder 6 | - override /pipeline@tasks: cep_image_supcon_autoencoder 7 | 8 | dataflow: 9 | x: 10 | image: 11 | views: 2 12 | 13 | trainer: 14 | max_epochs: 600 15 | 16 | wandb: 17 | group: cifar100_supcon_cep_autoencoder 18 | -------------------------------------------------------------------------------- /config/experiment/cifar10_cep.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /modules@model: cep_image 5 | - override /pipeline@tasks: cep_image 6 | 7 | learner: 8 | checkpoint_scheduler: 9 | monitor: val/Classification_accuracy 10 | mode: max 11 | 12 | dataflow: 13 | x: 14 | image: 15 | views: 2 16 | 17 | trainer: 18 | max_epochs: 600 19 | 20 | wandb: 21 | group: cifar10_cep 22 | 
-------------------------------------------------------------------------------- /unagi/data/transforms/image/reshape2d.py: -------------------------------------------------------------------------------- 1 | from unagi.data.transforms.image.transform import UnagiTransform 2 | 3 | 4 | class Reshape2D(UnagiTransform): 5 | def __init__(self, h_dim, w_dim, name=None, prob=1.0, level=0): 6 | self.h_dim = h_dim 7 | self.w_dim = w_dim 8 | super().__init__(name, prob, level) 9 | 10 | def transform(self, pil_img, label, **kwargs): 11 | return pil_img.view(self.h_dim, self.w_dim), label 12 | -------------------------------------------------------------------------------- /unagi/models/ops/pool.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class PoolDecoder(nn.Module): 5 | def __init__(self, **kwargs): 6 | super().__init__() 7 | # NOTE: compute d_input as module instantiation time 8 | # d_input = sum(d_model of all encoders being fed to Classifier) 9 | 10 | def forward(self, x): 11 | """ 12 | x: intermediate outpus from encoder. shape: (B, S, H) 13 | """ 14 | return x.mean(-2) 15 | -------------------------------------------------------------------------------- /config/dataflow/transforms/celeba.yaml: -------------------------------------------------------------------------------- 1 | image_pil: 2 | - type: CenterCrop 3 | size: [178, 218] 4 | - type: Resize 5 | prob: 1.0 6 | size: 224 7 | 8 | image_pil_default_transform: 9 | - type: ToTensor 10 | - type: Normalize 11 | mean: [0.485, 0.456, 0.406] 12 | std: [0.229, 0.224, 0.225] 13 | # - type: Reshape2D 14 | # h_dim: 3 # num_channels 15 | # w_dim: 1024 #32 x 32 x 3 16 | 17 | image_pil_no_transform: 18 | - type: Identity -------------------------------------------------------------------------------- /config/experiment/cifar10_supcon_cep_autoencoder_joint.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /modules@model: cep_image_supcon_autoencoder_joint 5 | - override /pipeline@tasks: cep_image_supcon_autoencoder_joint 6 | 7 | 8 | dataflow: 9 | x: 10 | image: 11 | views: 2 12 | 13 | learner: 14 | checkpoint_scheduler: 15 | monitor: val/Classification_accuracy 16 | mode: max 17 | 18 | wandb: 19 | group: cifar10_supcon_cep_autoencoder 20 | -------------------------------------------------------------------------------- /config/task/task_flow/classification_text.yaml: -------------------------------------------------------------------------------- 1 | supervised_task_preprocessing: 2 | module: SupervisedTaskPreprocessing # this is the UID 3 | inputs: [[_input_, inputs]] 4 | text_pre_encoder: 5 | module: TextPreEncoder 6 | inputs: [[supervised_task_preprocessing, text]] 7 | text_encoder: 8 | module: TextEncoder 9 | inputs: [[text_pre_encoder, 0]] # for pre_encoders :=> 0: input, 1: target pre encoding 10 | classifier: 11 | module: Classifier 12 | inputs: [[text_encoder, 0]] 13 | -------------------------------------------------------------------------------- /config/task/task_flow/classification_image.yaml: -------------------------------------------------------------------------------- 1 | supervised_task_preprocessing: 2 | module: SupervisedTaskPreprocessing # this is the UID 3 | inputs: [[_input_, inputs]] 4 | image_pre_encoder: 5 | module: ImagePreEncoder 6 | inputs: [[supervised_task_preprocessing, image]] 7 | image_encoder: 8 | module: ImageEncoder 9 | inputs: 
[[image_pre_encoder, 0]] # for pre_encoders :=> 0: input, 1: target pre encoding 10 | classifier: 11 | module: Classifier 12 | inputs: [[image_encoder, 0]] 13 | -------------------------------------------------------------------------------- /unagi/models/ops/view_select.py: -------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | from torch import nn 3 | 4 | 5 | class ViewSelect(nn.Module): 6 | def __init__(self, view_idx, n_views, **kwargs): 7 | super().__init__() 8 | self.name = "view_select" 9 | self.view_idx = view_idx 10 | self.n_views = n_views 11 | 12 | def forward(self, input): 13 | embs = rearrange(input, "(b v) ... -> b v ...", v=self.n_views) 14 | embs = embs[:, self.view_idx, ...] 15 | return embs 16 | -------------------------------------------------------------------------------- /unagi/tasks/output_layer_modules.py: -------------------------------------------------------------------------------- 1 | from torch.nn import functional as F 2 | 3 | 4 | def multiclass_classification(module_name, immediate_output_dict): 5 | return F.softmax( 6 | immediate_output_dict[module_name][len(immediate_output_dict[module_name]) - 1], 7 | dim=1, 8 | ) 9 | 10 | 11 | def multilabel_classification(module_name, immediate_output_dict): 12 | return F.sigmoid( 13 | immediate_output_dict[module_name][len(immediate_output_dict[module_name]) - 1] 14 | ) 15 | -------------------------------------------------------------------------------- /config/dataflow/dataset/tinyimagenet_coarse.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_._name_ 2 | val_split: 0.1 3 | seed: 42 4 | coarse_labels: True 5 | class_names: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66'] -------------------------------------------------------------------------------- /config/experiment/cifar100_cep.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar100 5 | - override /modules@model: cep_image 6 | - override /pipeline@tasks: cep_image_log 7 | 8 | 9 | learner: 10 | monitor: val/Classification_accuracy 11 | checkpoint_scheduler: 12 | monitor: val/Classification_accuracy 13 | mode: max 14 | 15 | dataflow: 16 | x: 17 | image: 18 | views: 2 19 | 20 | trainer: 21 | max_epochs: 600 22 | 23 | wandb: 24 | group: cifar100_cep 25 | -------------------------------------------------------------------------------- /config/experiment/cifar100_cep_coarselabels.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar100_coarse 5 | - override /modules@model: cep_image 6 | - override /pipeline@tasks: cep_image 7 | 8 | model: 9 | decoders: 10 | Classifier: 11 | num_classes: 20 12 | 13 | dataflow: 14 | x: 15 | image: 16 | views: 2 17 | y: 18 | coarse_label: 19 | transform_with: image 20 | 21 | 22 | trainer: 23 | max_epochs: 600 24 | 25 | wandb: 26 | group: cifar100_cep_cl 27 | 28 | -------------------------------------------------------------------------------- 
/config/learner/default.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /optimizer: adamw 3 | - /scheduler: cosine 4 | 5 | interval: epoch 6 | monitor: val/Classification_accuracy 7 | name: trainer/lr 8 | modules_to_freeze: [] 9 | task_scheduler: round_robin # [sequential, round_robin, mixed] 10 | checkpoint_scheduler: 11 | dirpath: null 12 | monitor: loss 13 | mode: min 14 | filename: best 15 | save_last: True 16 | sequential_scheduler_config: 17 | fillup: False 18 | round_robin_scheduler_config: 19 | fillup: False 20 | mixed_scheduler_config: 21 | fillup: False 22 | -------------------------------------------------------------------------------- /config/dataflow/dataset/cifar100_coarse.yaml: -------------------------------------------------------------------------------- 1 | # @package _group_._name_ 2 | val_split: 0.1 3 | seed: 42 4 | coarse_labels: True 5 | coarse_labels_u: False 6 | class_names: ['aquatic_mammals', 'fish', 'flowers', 'food_containers', 'fruit_and_vegetables', 'household_electrical_devices', 'household_furniture', 'insects', 'large_carnivores', 'large_man-made_outdoor_things', 'large_natural_outdoor_scenes', 'large_omnivores_and_herbivores', 'medium_mammals', 'non-insect_invertebrates', 'people', 'reptiles', 'small_mammals', 'trees', 'vehicles_1', 'vehicles_2'] 7 | -------------------------------------------------------------------------------- /config/task/classification_image_log.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task_flow: classification_image 3 | - loss@losses.classifier_loss: classifier_loss 4 | - metric@metrics.accuracy: accuracy 5 | - callback@callbacks.encoder_embedding: log_embedding 6 | 7 | task_weight: 1.0 8 | torchmetrics: {} 9 | callbacks: 10 | encoder_embedding: 11 | logging_batch_idx: 0 12 | inputs: [[image_encoder, 0], [_input_, index], [_output_, label]] 13 | input_names: ['ImageEncoderEmbedding', 'sample_uid', 'labels'] 14 | log_every_n_epochs: 50 15 | plot_embeddings: True 16 | -------------------------------------------------------------------------------- /config/experiment/cifar10_cep_coarselabels.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_coarse 5 | - override /modules@model: cep_image 6 | - override /pipeline@tasks: cep_image 7 | 8 | model: 9 | decoders: 10 | Classifier: 11 | d_output : 2 12 | 13 | learner: 14 | checkpoint_scheduler: 15 | monitor: val/Classification_accuracy 16 | mode: max 17 | 18 | dataflow: 19 | x: 20 | image: 21 | views: 2 22 | 23 | trainer: 24 | max_epochs: 600 25 | 26 | wandb: 27 | group: cifar10_cep_cl 28 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/normalize.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms as transforms 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class Normalize(UnagiTransform): 7 | def __init__(self, mean, std, name=None, prob=1.0, level=0): 8 | self.mean = mean 9 | self.std = std 10 | self.transform_func = transforms.Normalize(mean, std) 11 | 12 | super().__init__(name, prob, level) 13 | 14 | def transform(self, pil_img, label, **kwargs): 15 | return self.transform_func(pil_img), label 16 | 
-------------------------------------------------------------------------------- /config/dataflow/transforms/mnist.yaml: -------------------------------------------------------------------------------- 1 | image_pil: 2 | - type: HorizontalFlip 3 | prob: 0.5 4 | - type: RandomResizedCrop 5 | prob: 1.0 6 | size: 28 7 | scale: [0.2, 1.] 8 | - type: HorizontalFlip 9 | prob: 0.5 10 | - type: ColorJitter 11 | brightness: 0.4 12 | contrast: 0.4 13 | saturation: 0.4 14 | hue: 0.1 15 | prob: 0.8 16 | image_pil_default_transform: 17 | - type: ToTensor 18 | - type: Normalize 19 | mean: [0.1307] 20 | std: [0.3081] 21 | - type: Reshape2D 22 | h_dim: 1 # num_channels 23 | w_dim: 784 #32 x 32 x 3 24 | -------------------------------------------------------------------------------- /config/task/task_flow/supcon_image.yaml: -------------------------------------------------------------------------------- 1 | supervised_task_preprocessing: 2 | module: SupervisedTaskPreprocessing # this is the UID 3 | inputs: [[_input_, inputs]] 4 | image_pre_encoder: 5 | module: ImagePreEncoder 6 | inputs: [[supervised_task_preprocessing, image]] 7 | image_encoder: 8 | module: ImageEncoder 9 | inputs: [[image_pre_encoder, 0]] # for pre_encoders :=> 0: input, 1: target pre encoding 10 | image_view_select_0: 11 | module: ViewSelect0 12 | inputs: [[image_encoder, 0]] 13 | image_view_select_1: 14 | module: ViewSelect1 15 | inputs: [[image_encoder, 0]] 16 | -------------------------------------------------------------------------------- /unagi/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from unagi.datasets.celeba.celeba_dataset import CelebA 2 | from unagi.datasets.cifar.cifar_dataset import CIFAR10, CIFAR100 3 | from unagi.datasets.mnist.mnist_dataset import MNIST 4 | from unagi.datasets.tiny_imagenet.tinyimagenet_dataset import TinyImageNet 5 | 6 | DATASET_CLASSES = { 7 | "cifar10": CIFAR10, 8 | "cifar100": CIFAR100, 9 | "cifar10_coarse": CIFAR10, 10 | "cifar100_coarse": CIFAR100, 11 | "tinyimagenet": TinyImageNet, 12 | "tinyimagenet_coarse": TinyImageNet, 13 | "mnist": MNIST, 14 | "celeba": CelebA, 15 | } 16 | -------------------------------------------------------------------------------- /config/dataflow/transforms/mnist_resnet.yaml: -------------------------------------------------------------------------------- 1 | image_pil: 2 | - type: HorizontalFlip 3 | prob: 0.5 4 | - type: RandomResizedCrop 5 | prob: 1.0 6 | size: 28 7 | scale: [0.2, 1.] 
8 | - type: HorizontalFlip 9 | prob: 0.5 10 | - type: ColorJitter 11 | brightness: 0.4 12 | contrast: 0.4 13 | saturation: 0.4 14 | hue: 0.1 15 | prob: 0.8 16 | image_pil_default_transform: 17 | - type: Grayscale 18 | num_output_channels: 3 19 | - type: ToTensor 20 | - type: Normalize 21 | mean: [0.1307, 0.1307, 0.1307] 22 | std: [0.3081, 0.3081, 0.3081] 23 | -------------------------------------------------------------------------------- /config/trainer/default.yaml: -------------------------------------------------------------------------------- 1 | gpus: 1 # set `1` to train on GPU, `0` to train on CPU only 2 | accumulate_grad_batches: 1 3 | max_epochs: 200 4 | gradient_clip_val: 0.0 5 | log_every_n_steps: 10 6 | limit_train_batches: 1.0 # train on full dataset, can be used to toggle quick run 7 | limit_val_batches: 1.0 # train on full dataset, can be used to toggle quick run 8 | weights_summary: top # Set to 'full' to see every layer 9 | progress_bar_refresh_rate: 1 10 | track_grad_norm: -1 # Set to 2 to track norms of gradients 11 | resume_from_checkpoint: null 12 | reload_dataloaders_every_n_epochs: 1 13 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/posterize.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageOps 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | from unagi.data.transforms.image.utils import categorize_value 5 | 6 | 7 | class Posterize(UnagiTransform): 8 | 9 | value_range = (0, 4) 10 | 11 | def __init__(self, name=None, prob=1.0, level=0): 12 | super().__init__(name, prob, level) 13 | 14 | def transform(self, pil_img, label, **kwargs): 15 | degree = categorize_value(self.level, self.value_range, "int") 16 | return ImageOps.posterize(pil_img, degree), label 17 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/solarize.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageOps 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | from unagi.data.transforms.image.utils import categorize_value 5 | 6 | 7 | class Solarize(UnagiTransform): 8 | 9 | value_range = (0, 256) 10 | 11 | def __init__(self, name=None, prob=1.0, level=0): 12 | super().__init__(name, prob, level) 13 | 14 | def transform(self, pil_img, label, **kwargs): 15 | degree = categorize_value(self.level, self.value_range, "float") 16 | return ImageOps.solarize(pil_img, degree), label 17 | -------------------------------------------------------------------------------- /config/task/task_flow/autoencoder_image.yaml: -------------------------------------------------------------------------------- 1 | supervised_task_preprocessing: 2 | module: SupervisedTaskPreprocessing # this is the UID 3 | inputs: [[_input_, inputs]] 4 | image_pre_encoder: 5 | module: ImagePreEncoder 6 | inputs: [[supervised_task_preprocessing, image]] 7 | image_encoder: 8 | module: ImageEncoder 9 | inputs: [[image_pre_encoder, 0]] # for pre_encoders :=> 0: input, 1: target pre encoding 10 | image_decoder: 11 | module: ImageDecoder 12 | inputs: [[image_encoder, 0]] 13 | image_unflatten: 14 | module: ImageReshape 15 | inputs: [[supervised_task_preprocessing, image]] 16 | -------------------------------------------------------------------------------- /unagi/models/ops/image_reshape.py: 
-------------------------------------------------------------------------------- 1 | from einops import rearrange 2 | from torch import nn 3 | 4 | 5 | class ImageReshape(nn.Module): 6 | def __init__(self, d_input, output_height, output_width, **kwargs): 7 | super().__init__() 8 | self.name = "image_reshape" 9 | self.d_input = d_input 10 | self.output_height = output_height 11 | self.output_width = output_width 12 | 13 | def forward(self, input): 14 | embs = rearrange( 15 | input, "... (h w) -> ... h w", h=self.output_height, w=self.output_width 16 | ) 17 | return embs 18 | -------------------------------------------------------------------------------- /config/experiment/cifar10_supcon_cep_coarselabels.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_coarse 5 | - override /modules@model: cep_image_supcon 6 | - override /pipeline@tasks: cep_image_supcon 7 | 8 | 9 | model: 10 | decoders: 11 | Classifier: 12 | d_output : 2 13 | 14 | dataflow: 15 | x: 16 | image: 17 | views: 2 18 | 19 | learner: 20 | checkpoint_scheduler: 21 | monitor: val/Classification_accuracy 22 | mode: max 23 | 24 | trainer: 25 | max_epochs: 600 26 | 27 | wandb: 28 | group: cifar10_supcon_cep_cl 29 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/color.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageEnhance 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | from unagi.data.transforms.image.utils import categorize_value 5 | 6 | 7 | class Color(UnagiTransform): 8 | 9 | value_range = (0.1, 1.9) 10 | 11 | def __init__(self, name=None, prob=1.0, level=0): 12 | super().__init__(name, prob, level) 13 | 14 | def transform(self, pil_img, label, **kwargs): 15 | degree = categorize_value(self.level, self.value_range, "float") 16 | return ImageEnhance.Color(pil_img).enhance(degree), label 17 | -------------------------------------------------------------------------------- /config/experiment/cifar10_cep_autoencoder_coarselabels.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_coarse 5 | - override /modules@model: cep_image_autoencoder 6 | - override /pipeline@tasks: cep_image_autoencoder 7 | 8 | model: 9 | decoders: 10 | Classifier: 11 | d_output: 2 12 | 13 | dataflow: 14 | x: 15 | image: 16 | views: 2 17 | 18 | trainer: 19 | max_epochs: 600 20 | 21 | learner: 22 | checkpoint_scheduler: 23 | monitor: val/Classification_accuracy 24 | mode: max 25 | 26 | wandb: 27 | group: cifar10_cep_autoencoder_cl 28 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/contrast.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageEnhance 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | from unagi.data.transforms.image.utils import categorize_value 5 | 6 | 7 | class Contrast(UnagiTransform): 8 | 9 | value_range = (0.1, 1.9) 10 | 11 | def __init__(self, name=None, prob=1.0, level=0): 12 | super().__init__(name, prob, level) 13 | 14 | def transform(self, pil_img, label, **kwargs): 15 | degree = categorize_value(self.level, self.value_range, "float") 16 | return ImageEnhance.Contrast(pil_img).enhance(degree), label 17 |
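The PIL-enhancement transforms in this group (Color and Contrast here, Sharpness and Brightness next) all follow one pattern: the configured level is mapped into the class's value_range by categorize_value and handed to PIL's ImageEnhance as an enhancement factor. A hypothetical usage sketch (it assumes a level in [0, 1] maps roughly linearly onto value_range, so level=0.5 lands near the neutral factor 1.0, and that UnagiTransform.__call__ forwards to transform()):

    from PIL import Image

    from unagi.data.transforms.image.contrast import Contrast

    img = Image.new("RGB", (32, 32), color=(120, 60, 30))
    # prob=1.0 applies the transform unconditionally; level selects the enhancement strength.
    out, label = Contrast(prob=1.0, level=0.5)(img, label=None)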
-------------------------------------------------------------------------------- /unagi/data/transforms/image/sharpness.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageEnhance 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | from unagi.data.transforms.image.utils import categorize_value 5 | 6 | 7 | class Sharpness(UnagiTransform): 8 | 9 | value_range = (0.1, 1.9) 10 | 11 | def __init__(self, name=None, prob=1.0, level=0): 12 | super().__init__(name, prob, level) 13 | 14 | def transform(self, pil_img, label, **kwargs): 15 | degree = categorize_value(self.level, self.value_range, "float") 16 | return ImageEnhance.Sharpness(pil_img).enhance(degree), label 17 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/brightness.py: -------------------------------------------------------------------------------- 1 | from PIL import ImageEnhance 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | from unagi.data.transforms.image.utils import categorize_value 5 | 6 | 7 | class Brightness(UnagiTransform): 8 | 9 | value_range = (0.1, 1.9) 10 | 11 | def __init__(self, name=None, prob=1.0, level=0): 12 | super().__init__(name, prob, level) 13 | 14 | def transform(self, pil_img, label, **kwargs): 15 | degree = categorize_value(self.level, self.value_range, "float") 16 | return ImageEnhance.Brightness(pil_img).enhance(degree), label 17 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/gaussian_blur.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class GaussianBlur(UnagiTransform): 7 | def __init__(self, kernel_size, sigma=(0.1, 2.0), name=None, prob=1.0, level=0): 8 | self.kernel_size = kernel_size 9 | self.sigma = sigma 10 | self.transform_func = transforms.GaussianBlur(self.kernel_size, self.sigma) 11 | 12 | super().__init__(name, prob, level) 13 | 14 | def transform(self, pil_img, label, **kwargs): 15 | return self.transform_func(pil_img), label 16 | -------------------------------------------------------------------------------- /config/modules/cep_image.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /preprocessor@preprocessors.SupervisedTaskPreprocessing: supervised_task_preprocessing 3 | - /embedding@embeddings.ImagePreEncoder: square_patch 4 | - /encoder@encoders.ImageEncoder: transformer 5 | - /decoder@decoders.Classifier: classifier 6 | - /loss@losses.ClassifierLoss: label_smoothing 7 | 8 | train: True 9 | label_smoothing: True 10 | 11 | 12 | embeddings: 13 | ImagePreEncoder: 14 | d_model: 256 15 | 16 | # encoders: 17 | # ImageEncoder: 18 | # d_model: 256 19 | 20 | # decoders: 21 | # Classifier: 22 | # d_input: 256 23 | # d_output: 10 24 | -------------------------------------------------------------------------------- /config/task/supcon_image.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task_flow: supcon_image 3 | - loss@losses.contrastive_loss: supcon_contrastive_loss 4 | #- callback@callbacks.encoder_embedding: log_embedding 5 | 6 | task_weight: 1.0 7 | metrics: {} 8 | torchmetrics: {} 9 | callbacks: {} 10 | #losses: 11 | # supcon_contrastive_loss: 12 | # module: ContrastiveLossSup 13 | 
#callbacks: 14 | # encoder_embedding: 15 | # logging_batch_idx: 0 16 | # inputs: [[image_encoder, 0], [_input_, index], [_output_, label]] 17 | # input_names: ['ImageEncoderEmbedding', 'sample_uid', 'labels'] 18 | # log_every_n_epochs: 50 19 | # plot_embeddings: True 20 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/rotate.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | from unagi.data.transforms.image.utils import categorize_value 5 | 6 | 7 | class Rotate(UnagiTransform): 8 | 9 | value_range = (0, 30) 10 | 11 | def __init__(self, name=None, prob=1.0, level=0): 12 | super().__init__(name, prob, level) 13 | 14 | def transform(self, pil_img, label, **kwargs): 15 | degree = categorize_value(self.level, self.value_range, "float") 16 | if random.random() > 0.5: 17 | degree = -degree 18 | return pil_img.rotate(degree), label 19 | -------------------------------------------------------------------------------- /unagi/models/ops/einsum_reduce.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class EinsumReduceDecoder(nn.Module): 6 | def __init__(self, d_model, **kwargs): 7 | super().__init__() 8 | # NOTE: compute d_input at module instantiation time 9 | # d_input = sum(d_model of all encoders being fed to Classifier) 10 | self.attend = nn.Linear(d_model, 1) 11 | 12 | def forward(self, x): 13 | """ 14 | x: intermediate outputs from encoder. shape: (B, S, H) 15 | """ 16 | attn = self.attend(x).softmax(dim=1) # attention weights over the sequence dim, shape (B, S, 1) 17 | return torch.einsum("b s o, b s d -> b d", attn, x) 18 | -------------------------------------------------------------------------------- /config/experiment/cifar100_supcon_cep_coarselabels.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar100_coarse 5 | - override /modules@model: cep_image_supcon 6 | - override /pipeline@tasks: cep_image_supcon 7 | 8 | model: 9 | decoders: 10 | Classifier: 11 | d_output: 20 12 | 13 | dataflow: 14 | x: 15 | image: 16 | views: 2 17 | 18 | trainer: 19 | max_epochs: 600 20 | 21 | learner: 22 | monitor: val/Classification_accuracy 23 | checkpoint_scheduler: 24 | monitor: val/Classification_accuracy 25 | mode: max 26 | 27 | wandb: 28 | group: cifar100_supcon_cep_cl 29 | -------------------------------------------------------------------------------- /config/experiment/cifar10_supcon_cep_autoencoder_coarselabels.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_coarse 5 | - override /modules@model: cep_image_supcon_autoencoder 6 | - override /pipeline@tasks: cep_image_supcon_autoencoder 7 | 8 | model: 9 | decoders: 10 | Classifier: 11 | d_output: 2 12 | 13 | dataflow: 14 | x: 15 | image: 16 | views: 2 17 | 18 | trainer: 19 | max_epochs: 600 20 | 21 | learner: 22 | checkpoint_scheduler: 23 | monitor: val/Classification_accuracy 24 | mode: max 25 | 26 | wandb: 27 | group: cifar10_supcon_cep_autoencoder_cl 28 | -------------------------------------------------------------------------------- /config/experiment/cifar100_supcon_cep_autoencoder_coarselabels.yaml: -------------------------------------------------------------------------------- 1 | #
@package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar100_coarse 5 | - override /modules@model: cep_image_supcon_autoencoder 6 | - override /pipeline@tasks: cep_image_supcon_autoencoder 7 | 8 | model: 9 | decoders: 10 | Classifier: 11 | d_output: 20 12 | 13 | dataflow: 14 | x: 15 | image: 16 | views: 2 17 | 18 | trainer: 19 | max_epochs: 600 20 | 21 | learner: 22 | checkpoint_scheduler: 23 | monitor: val/Classification_accuracy 24 | mode: max 25 | 26 | wandb: 27 | group: cifar100_supcon_cep_autoencoder_cl 28 | -------------------------------------------------------------------------------- /config/dataflow/transforms/cifar10_resnet.yaml: -------------------------------------------------------------------------------- 1 | image_pil: 2 | - type: HorizontalFlip 3 | prob: 0.5 4 | - type: RandomResizedCrop 5 | prob: 1.0 6 | size: 32 7 | scale: [0.2, 1.] 8 | - type: HorizontalFlip 9 | prob: 0.5 10 | - type: ColorJitter 11 | brightness: 0.4 12 | contrast: 0.4 13 | saturation: 0.4 14 | hue: 0.1 15 | prob: 0.8 16 | - type: RandomGrayscale 17 | p: 0.2 18 | image_pil_default_transform: 19 | - type: ToTensor 20 | - type: Normalize 21 | mean: [0.49139968, 0.48215841, 0.44653091] 22 | std: [0.24703223, 0.24348513, 0.26158784] 23 | 24 | image_pil_no_transform: 25 | - type: Identity 26 | -------------------------------------------------------------------------------- /config/task/simclr_image.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task_flow: supcon_image 3 | - loss@losses.simclr_contrastive_loss: simclr_contrastive_loss 4 | # - callback@callbacks.encoder_embedding: log_embedding 5 | 6 | task_weight: 1.0 7 | metrics: {} 8 | torchmetrics: {} 9 | callbacks: {} 10 | # callbacks: 11 | # encoder_embedding: 12 | # logging_batch_idx: 0 13 | # inputs: [[image_encoder, 0], [_input_, index], [_output_, label]] 14 | # input_names: ['ImageEncoderEmbedding', 'sample_uid', 'labels'] 15 | # log_every_n_epochs: 50 16 | # plot_embeddings: True 17 | 18 | losses: 19 | simclr_contrastive_loss: 20 | module: ContrastiveLossSim 21 | 22 | -------------------------------------------------------------------------------- /unagi/datasets/mnist/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def sparse2coarse(targets, scramble=False, dataset="mnist"): 5 | """Convert Pytorch MNIST sparse targets. 
6 | trainset = torchvision.datasets.CIFAR100(path) 7 | trainset.targets = sparse2coarse(trainset.targets) 8 | """ 9 | if dataset == "mnist": 10 | sparse_coarse_array = [ 11 | 0, 12 | 0, 13 | 0, 14 | 0, 15 | 0, 16 | 1, 17 | 1, 18 | 1, 19 | 1, 20 | 1, 21 | ] 22 | 23 | targets = np.array(sparse_coarse_array)[targets] 24 | return targets.tolist() 25 | -------------------------------------------------------------------------------- /config/experiment/cifar10_supcon_cep_autoencoder_twobackbones_coarselabels.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_coarse 5 | - override /modules@model: cep_image_supcon_autoencoder_joint 6 | - override /pipeline@tasks: cep_image_supcon_autoencoder_joint 7 | 8 | model: 9 | decoders: 10 | Classifier: 11 | d_output: 2 12 | 13 | dataflow: 14 | x: 15 | image: 16 | views: 2 17 | 18 | trainer: 19 | max_epochs: 600 20 | 21 | learner: 22 | checkpoint_scheduler: 23 | monitor: val/Classification_accuracy 24 | mode: max 25 | 26 | wandb: 27 | group: cifar10_supcon_cep_autoencoder_twobackbones_cl 28 | -------------------------------------------------------------------------------- /config/task/task_flow/autoencoder_image_grayscale.yaml: -------------------------------------------------------------------------------- 1 | supervised_task_preprocessing: 2 | module: SupervisedTaskPreprocessing # this is the UID 3 | inputs: [[_input_, inputs]] 4 | image_pre_encoder: 5 | module: ImagePreEncoder 6 | inputs: [[supervised_task_preprocessing, image]] 7 | image_encoder: 8 | module: ImageEncoder 9 | inputs: [[image_pre_encoder, 0]] # for pre_encoders :=> 0: input, 1: target pre encoding 10 | image_decoder: 11 | module: ImageDecoder 12 | inputs: [[image_encoder, 0]] 13 | gray_scale: 14 | module: Grayscale 15 | inputs: [[image_decoder, 0]] 16 | image_unflatten: 17 | module: ImageReshape 18 | inputs: [[supervised_task_preprocessing, image]] 19 | -------------------------------------------------------------------------------- /config/task/classification_text_clip.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task_flow: classification_text 3 | - loss@losses.classifier_loss_text: classifier_loss 4 | - metric@metrics.accuracy: accuracy 5 | # - callback@callbacks.classifier_embedding: log_embedding 6 | 7 | task_weight: 1.0 8 | torchmetrics: {} 9 | callbacks: {} 10 | # callbacks: 11 | # classifier_embedding: 12 | # logging_batch_idx: 0 13 | # inputs: [[classifier, 0], [_input_, index], [_output_, label]] 14 | # input_names: ['TextClassifierEmbeedding', 'sample_uid', 'labels'] 15 | # log_every_n_epochs: 1 16 | # plot_embeddings: True 17 | 18 | 19 | task_flow: 20 | classifier: 21 | module: TextClassifier 22 | -------------------------------------------------------------------------------- /config/dataflow/transforms/cifar100.yaml: -------------------------------------------------------------------------------- 1 | image_pil: 2 | - type: HorizontalFlip 3 | prob: 0.5 4 | - type: RandomResizedCrop 5 | prob: 1.0 6 | size: 32 7 | scale: [0.2, 1.] 
8 | - type: HorizontalFlip 9 | prob: 0.5 10 | - type: ColorJitter 11 | brightness: 0.4 12 | contrast: 0.4 13 | saturation: 0.4 14 | hue: 0.1 15 | prob: 0.8 16 | - type: RandomGrayscale 17 | p: 0.2 18 | image_pil_default_transform: 19 | - type: ToTensor 20 | - type: Normalize 21 | mean: [0.50707516, 0.48654887, 0.44091784] 22 | std: [0.26733429, 0.25643846, 0.27615047] 23 | - type: Reshape2D 24 | h_dim: 3 # num_channels 25 | w_dim: 1024 #32 x 32 x 3 26 | 27 | -------------------------------------------------------------------------------- /config/experiment/cifar10_supcon_cep_coarselabels_clip_pos.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_coarse 5 | - override /modules@model: cep_image_supcon 6 | - override /pipeline@tasks: cep_image_supcon 7 | 8 | 9 | model: 10 | decoders: 11 | Classifier: 12 | d_output : 2 13 | losses: 14 | ContrastiveLoss: 15 | lc_norm: True 16 | clip_pos: 0.95 17 | 18 | dataflow: 19 | x: 20 | image: 21 | views: 2 22 | 23 | learner: 24 | checkpoint_scheduler: 25 | monitor: val/Classification_accuracy 26 | mode: max 27 | 28 | trainer: 29 | max_epochs: 600 30 | 31 | wandb: 32 | group: cifar10_supcon_cep_cl_clip_pos 33 | -------------------------------------------------------------------------------- /config/task/task_flow/supcon_autoencoder_joint_classification.yaml: -------------------------------------------------------------------------------- 1 | supervised_task_preprocessing: 2 | module: SupervisedTaskPreprocessing # this is the UID 3 | inputs: [[_input_, inputs]] 4 | image_pre_encoder_1: 5 | module: ImagePreEncoderSup 6 | inputs: [[supervised_task_preprocessing, image]] 7 | image_pre_encoder_2: 8 | module: ImagePreEncoderAuto 9 | inputs: [[supervised_task_preprocessing, image]] 10 | image_encoder_sup: 11 | module: ImageEncoderSup 12 | inputs: [[image_pre_encoder_1, 0]] 13 | image_encoder_auto: 14 | module: ImageEncoderAuto 15 | inputs: [[image_pre_encoder_2, 0]] 16 | classifier: 17 | module: Classifier 18 | inputs: [[image_encoder_sup, 0], [image_encoder_auto, 0]] 19 | -------------------------------------------------------------------------------- /config/dataflow/transforms/tinyimagenet.yaml: -------------------------------------------------------------------------------- 1 | image_pil: 2 | - type: HorizontalFlip 3 | prob: 0.5 4 | - type: RandomResizedCrop 5 | prob: 1.0 6 | size: 64 7 | scale: [0.2, 1.] 
8 | - type: HorizontalFlip 9 | prob: 0.5 10 | - type: ColorJitter 11 | brightness: 0.4 12 | contrast: 0.4 13 | saturation: 0.4 14 | hue: 0.1 15 | prob: 0.8 16 | - type: RandomGrayscale 17 | p: 0.2 18 | image_pil_default_transform: 19 | - type: ToTensor 20 | - type: Normalize 21 | mean: [0.485, 0.456, 0.406] 22 | std: [0.229, 0.224, 0.225] 23 | - type: Reshape2D 24 | h_dim: 3 # num_channels 25 | w_dim: 4096 #32 x 32 x 3 26 | 27 | image_pil_no_transform: 28 | - type: Identity -------------------------------------------------------------------------------- /unagi/models/ops/grayscale.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torchvision import transforms as transforms 4 | 5 | 6 | class Grayscale(nn.Module): 7 | def __init__(self, dim=1, resize=None, **kwargs): 8 | super().__init__() 9 | self.dim = dim 10 | self.resize = resize 11 | if self.resize: 12 | self.resize_func = transforms.Resize( 13 | self.resize, transforms.InterpolationMode.BILINEAR 14 | ) 15 | 16 | def forward(self, x): 17 | grayscale_image = torch.mean(x, dim=self.dim, keepdim=True) 18 | if self.resize: 19 | return self.resize_func(grayscale_image) 20 | 21 | return grayscale_image 22 | -------------------------------------------------------------------------------- /config/experiment/cifar10_lspread_cep_coarselabels.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_coarse 5 | - override /modules@model: cep_image_supcon 6 | - override /pipeline@tasks: cep_image_supcon 7 | 8 | 9 | model: 10 | decoders: 11 | Classifier: 12 | d_output : 2 13 | losses: 14 | ContrastiveLoss: 15 | module: l_spread 16 | type: l_spread 17 | 18 | dataflow: 19 | x: 20 | image: 21 | views: 2 22 | 23 | learner: 24 | checkpoint_scheduler: 25 | monitor: val/Classification_accuracy 26 | mode: max 27 | save_last: True 28 | 29 | trainer: 30 | max_epochs: 600 31 | 32 | wandb: 33 | group: cifar10_lspread_cep_cl 34 | -------------------------------------------------------------------------------- /unagi/models/decoders/classifier.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ClassificationDecoder(nn.Module): 6 | def __init__(self, d_input, d_output, **kwargs): 7 | super().__init__() 8 | # NOTE: compute d_input as module instantiation time 9 | # d_input = sum(d_model of all encoders being fed to Classifier) 10 | self.classification_layer = nn.Linear(d_input, d_output) 11 | 12 | def forward( 13 | self, 14 | *final_outs, 15 | ): 16 | """ 17 | final_outs List[Tensor]: intermediate outputs from encoder. 
shape: (B, S, H) 18 | """ 19 | fx = torch.cat(final_outs, dim=-1) 20 | return self.classification_layer(fx) 21 | -------------------------------------------------------------------------------- /config/dataflow/celeba.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: celeba 3 | - transforms: celeba 4 | - loader@_here_: default 5 | 6 | # train: # the split for training, accepts str or list of strs 7 | # splits: train 8 | # val: # the split for validation, accepts str or list of strs 9 | # splits: val 10 | # test: # the split for testing, accepts str or list of strs 11 | # splits: test 12 | 13 | 14 | x: 15 | image: 16 | transform: image_pil 17 | default_transform: image_pil_default_transform 18 | views: 1 19 | mask: null # TODO: Should be added as a default 20 | type: image # TODO: currently needs to be a mandatory param 21 | 22 | y: 23 | label: 24 | transform_with: image 25 | 26 | data_dir: None 27 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/center_crop.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms as transforms 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class CenterCrop(UnagiTransform): 7 | def __init__(self, size, name=None, prob=1.0, level=0): 8 | self.size = size 9 | self.transform_func = transforms.CenterCrop(self.size) 10 | super().__init__(name, prob, level) 11 | 12 | def transform(self, pil_img, label, **kwargs): 13 | return self.transform_func(pil_img), label 14 | 15 | def __repr__(self): 16 | return ( 17 | f"<Transform ({self.name}), prob={self.prob}, level={self.level}, " 18 | f"size={self.size}>" 19 | ) 20 | -------------------------------------------------------------------------------- /config/task/contrastive_image_clip.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task_flow: clip_image 3 | - loss@losses.contrastive_loss_image: lspread_contrastive_loss 4 | # - callback@callbacks.encoder_embedding: log_embedding 5 | 6 | 7 | task_weight: 1.0 8 | metrics: {} 9 | torchmetrics: {} 10 | callbacks: {} 11 | # callbacks: 12 | # encoder_embedding: 13 | # logging_batch_idx: 0 14 | # inputs: [[image_encoder, 0], [_input_, index], [_output_, label]] 15 | # input_names: ['ImageEncoderEmbedding', 'sample_uid', 'labels'] 16 | # log_every_n_epochs: 50 17 | # plot_embeddings: True 18 | 19 | losses: 20 | contrastive_loss_image: 21 | inputs: [[image_view_select_0_proj, 0], [image_view_select_1_proj, 0], [_output_, label]] 22 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/random_grayscale.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms as transforms 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class RandomGrayscale(UnagiTransform): 7 | def __init__(self, p=0.1, name=None, prob=1.0, level=0): 8 | self.p = p 9 | self.transform_func = transforms.RandomGrayscale(self.p) 10 | 11 | super().__init__(name, prob, level) 12 | 13 | def transform(self, pil_img, label, **kwargs): 14 | return self.transform_func(pil_img), label 15 | 16 | def __repr__(self): 17 | return ( 18 | f"<Transform ({self.name}), prob={self.prob}, level={self.level}, p={self.p}>" 19 | ) 20 | -------------------------------------------------------------------------------- /config/experiment/cifar10_supcon_cep_coarselabels_weighted_pos_in_denom.yaml:
-------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_coarse 5 | - override /modules@model: cep_image_supcon 6 | - override /pipeline@tasks: cep_image_supcon 7 | 8 | 9 | model: 10 | decoders: 11 | Classifier: 12 | d_output : 2 13 | losses: 14 | ContrastiveLoss: 15 | pos_in_denom: True 16 | pos_in_denom_weight: 2.0 17 | 18 | dataflow: 19 | x: 20 | image: 21 | views: 2 22 | 23 | learner: 24 | checkpoint_scheduler: 25 | monitor: val/Classification_accuracy 26 | mode: max 27 | 28 | trainer: 29 | max_epochs: 600 30 | 31 | wandb: 32 | group: cifar10_supcon_cep_cl_weighted_pos_in_denom 33 | -------------------------------------------------------------------------------- /unagi/models/decoders/sequence/transformer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from unagi.models.encoders.base_sequence import SequenceModule 4 | from unagi.models.encoders.sequence.transformer.transformer_modules import MHA_Decoder 5 | 6 | 7 | class TransformerDecoder(SequenceModule): 8 | def __init__(self, d_model, n_heads, dropout=0.1, head_dropout=0.1, **kwargs): 9 | super().__init__() 10 | self.blocks = nn.ModuleList( 11 | [MHA_Decoder(d_model, n_heads, dropout=dropout, head_dropout=head_dropout)] 12 | ) 13 | 14 | def forward(self, x, target, state=None, mask=None, *args, **kwargs): 15 | for b in self.blocks: 16 | tgt = b(target, x, src_mask=mask, tgt_mask=mask) 17 | return tgt 18 | -------------------------------------------------------------------------------- /config/dataflow/mnist.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: mnist 3 | - transforms: mnist 4 | - loader@_here_: default 5 | 6 | # train: # the split for training, accepts str or list of strs 7 | # splits: train 8 | # val: # the split for validation, accepts str or list of strs 9 | # splits: val 10 | # test: # the split for testing, accepts str or list of strs 11 | # splits: test 12 | 13 | class_names: ${dataflow.dataset.mnist.class_names} 14 | x: 15 | image: 16 | transform: image_pil 17 | default_transform: image_pil_default_transform 18 | views: 1 19 | mask: null # TODO: Should be added as a default 20 | type: image # TODO: currently needs to be a mandatory param 21 | 22 | y: 23 | label: 24 | transform_with: image 25 | 26 | data_dir: None 27 | -------------------------------------------------------------------------------- /config/experiment/cifar10_lspread_cep_autoencoder_twobackbones_coarselabels.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_coarse 5 | - override /modules@model: cep_image_supcon_autoencoder_joint 6 | - override /pipeline@tasks: cep_image_supcon_autoencoder_joint 7 | 8 | model: 9 | decoders: 10 | Classifier: 11 | d_output: 2 12 | losses: 13 | ContrastiveLoss: 14 | module: l_spread 15 | type: l_spread 16 | 17 | dataflow: 18 | x: 19 | image: 20 | views: 2 21 | 22 | trainer: 23 | max_epochs: 600 24 | 25 | learner: 26 | checkpoint_scheduler: 27 | monitor: val/Classification_accuracy 28 | mode: max 29 | 30 | wandb: 31 | group: cifar10_lspread_cep_autoencoder_twobackbones_cl 32 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/shear_x.py: 
-------------------------------------------------------------------------------- 1 | import random 2 | 3 | from PIL import Image 4 | 5 | from unagi.data.transforms.image.transform import UnagiTransform 6 | from unagi.data.transforms.image.utils import categorize_value 7 | 8 | 9 | class ShearX(UnagiTransform): 10 | 11 | value_range = (0.0, 0.3) 12 | 13 | def __init__(self, name=None, prob=1.0, level=0): 14 | super().__init__(name, prob, level) 15 | 16 | def transform(self, pil_img, label, **kwargs): 17 | degree = categorize_value(self.level, self.value_range, "float") 18 | if random.random() > 0.5: 19 | degree = -degree 20 | return ( 21 | pil_img.transform(pil_img.size, Image.AFFINE, (1, degree, 0, 0, 1, 0)), 22 | label, 23 | ) 24 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/shear_y.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from PIL import Image 4 | 5 | from unagi.data.transforms.image.transform import UnagiTransform 6 | from unagi.data.transforms.image.utils import categorize_value 7 | 8 | 9 | class ShearY(UnagiTransform): 10 | 11 | value_range = (0.0, 0.3) 12 | 13 | def __init__(self, name=None, prob=1.0, level=0): 14 | super().__init__(name, prob, level) 15 | 16 | def transform(self, pil_img, label, **kwargs): 17 | degree = categorize_value(self.level, self.value_range, "float") 18 | if random.random() > 0.5: 19 | degree = -degree 20 | return ( 21 | pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, 0, degree, 1, 0)), 22 | label, 23 | ) 24 | -------------------------------------------------------------------------------- /config/experiment/cifar10_cep_xfer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /modules@model: cep_image 5 | - override /pipeline@tasks: cep_image 6 | 7 | # Note: modify path_to_checkpoint to point to the checkpoint of coarse training run 8 | model: 9 | encoders: 10 | ImageEncoder: 11 | path_to_checkpoint: None 12 | embeddings: 13 | ImagePreEncoder: 14 | path_to_checkpoint: None 15 | 16 | # Note: customize dirpath 17 | learner: 18 | modules_to_freeze: [ImageEncoder, ImagePreEncoder] 19 | optimizer: 20 | lr: .001 21 | checkpoint_scheduler: 22 | monitor: val/Classification_accuracy 23 | mode: max 24 | dirpath: None 25 | 26 | trainer: 27 | max_epochs: 100 28 | 29 | wandb: 30 | group: cifar10_cep_xfer 31 | -------------------------------------------------------------------------------- /config/dataflow/cifar10.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: cifar10 3 | - transforms: cifar10 4 | - loader@_here_: default 5 | 6 | # train: # the split for training, accepts str or list of strs 7 | # splits: train 8 | # val: # the split for validation, accepts str or list of strs 9 | # splits: val 10 | # test: # the split for testing, accepts str or list of strs 11 | # splits: test 12 | 13 | class_names: ${dataflow.dataset.cifar10.class_names} 14 | 15 | x: 16 | image: 17 | transform: image_pil 18 | default_transform: image_pil_default_transform 19 | views: 1 20 | mask: null # TODO: Should be added as a default 21 | type: image # TODO: currently needs to be a mandatory param 22 | 23 | y: 24 | label: 25 | transform_with: image 26 | 27 | data_dir: None 28 | -------------------------------------------------------------------------------- /config/dataflow/cifar100.yaml: 
-------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: cifar100 3 | - transforms: cifar100 4 | - loader@_here_: default 5 | 6 | # train: # the split for training, accepts str or list of strs 7 | # splits: train 8 | # val: # the split for validation, accepts str or list of strs 9 | # splits: val 10 | # test: # the split for testing, accepts str or list of strs 11 | # splits: test 12 | class_names: ${dataflow.dataset.cifar100.class_names} 13 | 14 | x: 15 | image: 16 | transform: image_pil 17 | default_transform: image_pil_default_transform 18 | views: 1 19 | mask: null # TODO: Should be added as a default 20 | type: image # TODO: currently needs to be a mandatory param 21 | 22 | y: 23 | label: 24 | transform_with: image 25 | 26 | data_dir: None 27 | -------------------------------------------------------------------------------- /config/dataflow/cifar100_coarse.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: cifar100_coarse 3 | - transforms: cifar100 4 | - loader@_here_: default 5 | 6 | # train: # the split for training, accepts str or list of strs 7 | # splits: train 8 | # val: # the split for validation, accepts str or list of strs 9 | # splits: val 10 | # test: # the split for testing, accepts str or list of strs 11 | # splits: test 12 | class_names: ${dataflow.dataset.cifar100_coarse.class_names} 13 | 14 | x: 15 | image: 16 | transform: image_pil 17 | default_transform: image_pil_default_transform 18 | views: 1 19 | mask: null # TODO: Should be added as a default 20 | type: image # TODO: currently needs to be a mandatory param 21 | 22 | y: 23 | label: 24 | transform_with: image 25 | 26 | data_dir: None 27 | -------------------------------------------------------------------------------- /config/dataflow/cifar10_coarse.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: cifar10_coarse 3 | - transforms: cifar10 4 | - loader@_here_: default 5 | 6 | # train: # the split for training, accepts str or list of strs 7 | # splits: train 8 | # val: # the split for validation, accepts str or list of strs 9 | # splits: val 10 | # test: # the split for testing, accepts str or list of strs 11 | # splits: test 12 | 13 | class_names: ${dataflow.dataset.cifar10_coarse.class_names} 14 | 15 | x: 16 | image: 17 | transform: image_pil 18 | default_transform: image_pil_default_transform 19 | views: 1 20 | mask: null # TODO: Should be added as a default 21 | type: image # TODO: currently needs to be a mandatory param 22 | 23 | y: 24 | label: 25 | transform_with: image 26 | 27 | data_dir: None 28 | -------------------------------------------------------------------------------- /config/task/autoencoder_image.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task_flow: autoencoder_image 3 | - loss@losses.mse_loss: mse_loss 4 | #- callback@callbacks.autoencoder_image: log_image 5 | #- callback@callbacks.encoder_embedding: log_embedding 6 | 7 | task_weight: 1.0 8 | metrics: {} 9 | torchmetrics: {} 10 | callbacks: {} 11 | #callbacks: 12 | #autoencoder_image: 13 | # logging_batch_idx: 0 14 | # inputs: [[image_unflatten, 0], [image_decoder, 0]] 15 | # log_every_n_epochs: 50 16 | # input_names: ['Original', 'Reconstruction'] 17 | #encoder_embedding: 18 | # logging_batch_idx: 0 19 | # inputs: [[image_encoder, 0], [_input_, index], [_output_, label]] 20 | # input_names: 
['ImageEncoderEmbedding', 'sample_uid', 'labels'] 21 | # log_every_n_epochs: 50 22 | # plot_embeddings: True 23 | -------------------------------------------------------------------------------- /config/dataflow/cifar10_train_as_test.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: cifar10_train_as_test 3 | - transforms: cifar10 4 | - loader@_here_: default 5 | 6 | # train: # the split for training, accepts str or list of strs 7 | # splits: train 8 | # val: # the split for validation, accepts str or list of strs 9 | # splits: val 10 | # test: # the split for testing, accepts str or list of strs 11 | # splits: test 12 | 13 | class_names: ${dataflow.dataset.cifar10.class_names} 14 | 15 | x: 16 | image: 17 | transform: image_pil 18 | default_transform: image_pil_default_transform 19 | views: 1 20 | mask: null # TODO: Should be added as a default 21 | type: image # TODO: currently needs to be a mandatory param 22 | 23 | y: 24 | label: 25 | transform_with: image 26 | 27 | data_dir: None 28 | -------------------------------------------------------------------------------- /config/experiment/cifar100_cep_xfer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar100 5 | - override /modules@model: cep_image 6 | - override /pipeline@tasks: cep_image_log 7 | 8 | # Note: modify path_to_checkpoint to point to checkpoint of coarse training run 9 | model: 10 | encoders: 11 | ImageEncoder: 12 | path_to_checkpoint: None 13 | embeddings: 14 | ImagePreEncoder: 15 | path_to_checkpoint: None 16 | 17 | # Note: customize dirpath 18 | learner: 19 | modules_to_freeze: [ImageEncoder, ImagePreEncoder] 20 | optimizer: 21 | lr: .001 22 | checkpoint_scheduler: 23 | monitor: val/Classification_accuracy 24 | mode: max 25 | dirpath: None 26 | 27 | trainer: 28 | max_epochs: 100 29 | 30 | wandb: 31 | group: cifar100_cep 32 | -------------------------------------------------------------------------------- /config/dataflow/tinyimagenet_coarse.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: tinyimagenet_coarse 3 | - transforms: tinyimagenet 4 | - loader@_here_: default 5 | 6 | # train: # the split for training, accepts str or list of strs 7 | # splits: train 8 | # val: # the split for validation, accepts str or list of strs 9 | # splits: val 10 | # test: # the split for testing, accepts str or list of strs 11 | # splits: test 12 | 13 | class_names: ${dataflow.dataset.tinyimagenet_coarse.class_names} 14 | 15 | x: 16 | image: 17 | transform: image_pil 18 | default_transform: image_pil_default_transform 19 | views: 1 20 | mask: null # TODO: Should be added as a default 21 | type: image # TODO: currently needs to be a mandatory param 22 | 23 | y: 24 | label: 25 | transform_with: image 26 | 27 | data_dir: None 28 | -------------------------------------------------------------------------------- /config/pipeline/cep_image_supcon_autoencoder_joint.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /task@Classification: classification_image_joint 3 | - /task@SupCon: supcon_image 4 | - /task@Autoencoder: autoencoder_image 5 | 6 | 7 | SupCon: 8 | task_flow: 9 | image_pre_encoder: 10 | module: ImagePreEncoderSup 11 | inputs: [[supervised_task_preprocessing, image]] 12 | image_encoder: 13 | module: ImageEncoderSup 14 | inputs: 
[[image_pre_encoder, 0]] # for pre_encoders :=> 0: input, 1: target pre encoding 15 | Autoencoder: 16 | task_flow: 17 | image_pre_encoder: 18 | module: ImagePreEncoderAuto 19 | inputs: [[supervised_task_preprocessing, image]] 20 | image_encoder: 21 | module: ImageEncoderAuto 22 | inputs: [[image_pre_encoder, 0]] # for pre_encoders :=> 0: input, 1: target pre encoding 23 | -------------------------------------------------------------------------------- /config/task/contrastive_text_clip.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task_flow: clip_text 3 | - loss@losses.contrastive_loss_text: lspread_contrastive_loss 4 | # - callback@callbacks.encoder_embedding: log_embedding 5 | # - callback@callbacks.classifier_embedding: log_embedding 6 | 7 | 8 | task_weight: 1.0 9 | metrics: {} 10 | torchmetrics: {} 11 | callbacks: {} 12 | # callbacks: 13 | # encoder_embedding: 14 | # logging_batch_idx: 0 15 | # inputs: [[text_encoder, 0], [_input_, index], [_output_, label]] 16 | # input_names: ['TextEncoderEmbedding', 'sample_uid', 'labels'] 17 | # log_every_n_epochs: 1 18 | # plot_embeddings: True 19 | 20 | losses: 21 | contrastive_loss_text: 22 | module: ContrastiveLossText 23 | inputs: [[text_view_select_0_proj, 0], [text_view_select_1_proj, 0], [_output_, label]] 24 | 25 | -------------------------------------------------------------------------------- /config/pipeline/cep_image_supcon_autoencoder_joint_grayscale.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /task@Classification: classification_image_joint 3 | - /task@SupCon: supcon_image 4 | - /task@Autoencoder: autoencoder_image_grayscale 5 | 6 | 7 | SupCon: 8 | task_flow: 9 | image_pre_encoder: 10 | module: ImagePreEncoderSup 11 | inputs: [[supervised_task_preprocessing, image]] 12 | image_encoder: 13 | module: ImageEncoderSup 14 | inputs: [[image_pre_encoder, 0]] # for pre_encoders :=> 0: input, 1: target pre encoding 15 | Autoencoder: 16 | task_flow: 17 | image_pre_encoder: 18 | module: ImagePreEncoderAuto 19 | inputs: [[supervised_task_preprocessing, image]] 20 | image_encoder: 21 | module: ImageEncoderAuto 22 | inputs: [[image_pre_encoder, 0]] # for pre_encoders :=> 0: input, 1: target pre encoding 23 | -------------------------------------------------------------------------------- /config/experiment/cifar10_cep_xfer_train_as_test.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_train_as_test 5 | - override /modules@model: cep_image 6 | - override /pipeline@tasks: cep_image_log 7 | 8 | # Note: modify path_to_checkpoint to point to the checkpoint of coarse training run 9 | model: 10 | encoders: 11 | ImageEncoder: 12 | path_to_checkpoint: None 13 | embeddings: 14 | ImagePreEncoder: 15 | path_to_checkpoint: None 16 | 17 | # Note: customize dirpath 18 | learner: 19 | modules_to_freeze: [ImageEncoder, ImagePreEncoder] 20 | optimizer: 21 | lr: .001 22 | checkpoint_scheduler: 23 | monitor: val/Classification_accuracy 24 | mode: max 25 | dirpath: None 26 | 27 | trainer: 28 | max_epochs: 1 29 | 30 | wandb: 31 | group: cifar10_cep_xfer_train_as_test 32 | -------------------------------------------------------------------------------- /unagi/models/decoders/image/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | from torchvision.models 
import resnet18, resnet34, resnet50 # noqa: F401 3 | 4 | 5 | class ResnetDecoder(nn.Module): 6 | def __init__( 7 | self, 8 | decoder_hidden_dim, 9 | decoder_projection_dim, 10 | model="resnet18", 11 | d_model=None, 12 | **kwargs, 13 | ): 14 | super().__init__() 15 | self.d_model = d_model 16 | 17 | if not self.d_model: 18 | encoder = eval(model)() 19 | self.d_model = encoder.fc.in_features 20 | 21 | self.decoder = nn.Sequential( 22 | nn.Linear(self.d_model, decoder_hidden_dim), 23 | nn.ReLU(), 24 | nn.Linear(decoder_hidden_dim, decoder_projection_dim), 25 | ) 26 | 27 | def forward(self, x): 28 | return self.decoder(x) 29 | -------------------------------------------------------------------------------- /config/dataflow/cifar10_subset.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - dataset: cifar10 3 | - transforms: cifar10 4 | - loader@_here_: default 5 | 6 | # train: # the split for training, accepts str or list of strs 7 | # splits: train 8 | # val: # the split for validation, accepts str or list of strs 9 | # splits: val 10 | # test: # the split for testing, accepts str or list of strs 11 | # splits: test 12 | 13 | class_names: ${dataflow.dataset.cifar10.class_names} 14 | 15 | dataset: 16 | cifar10: 17 | subset_split_percent: 0.5 18 | subset_split_seed: 42 19 | 20 | x: 21 | image: 22 | transform: image_pil 23 | default_transform: image_pil_default_transform 24 | views: 1 25 | mask: null # TODO: Should be added as a default 26 | type: image # TODO: currently needs to be a mandatory param 27 | 28 | y: 29 | label: 30 | transform_with: image 31 | 32 | data_dir: None 33 | -------------------------------------------------------------------------------- /config/experiment/cifar10_cep_xfer_train_test_embs.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /modules@model: cep_image 5 | - override /pipeline@tasks: cep_image_log 6 | - override /dataflow/transforms@dataflow.transforms: cifar10_notransform 7 | 8 | # Note: modify path_to_checkpoint to point to the checkpoint of coarse training run 9 | model: 10 | encoders: 11 | ImageEncoder: 12 | path_to_checkpoint: None 13 | embeddings: 14 | ImagePreEncoder: 15 | path_to_checkpoint: None 16 | 17 | # Note: customize dirpath 18 | learner: 19 | modules_to_freeze: [ImageEncoder, ImagePreEncoder] 20 | optimizer: 21 | lr: .001 22 | checkpoint_scheduler: 23 | monitor: val/Classification_accuracy 24 | mode: max 25 | dirpath: None 26 | 27 | trainer: 28 | max_epochs: 1 29 | 30 | wandb: 31 | group: cifar10_cep_xfer_train_test_embs 32 | -------------------------------------------------------------------------------- /config/dataflow/transforms/nlvr2.yaml: -------------------------------------------------------------------------------- 1 | image_pil: 2 | - type: HorizontalFlip 3 | prob: 0.5 4 | - type: ResizeAndCrop 5 | prob: 1.0 6 | resized_width: 530 7 | resized_height: 416 8 | #- type: RandomResizedCrop 9 | # prob: 1.0 10 | # size: 224 11 | # scale: [0.2, 1.]
12 | - type: HorizontalFlip 13 | prob: 0.5 14 | - type: ColorJitter 15 | brightness: 0.4 16 | contrast: 0.4 17 | saturation: 0.4 18 | hue: 0.1 19 | prob: 0.8 20 | - type: RandomGrayscale 21 | p: 0.2 22 | text_default_transforms: 23 | - type: Identity 24 | 25 | image_pil_default_transform: 26 | - type: ResizeAndCrop 27 | prob: 1.0 28 | resized_width: 530 29 | resized_height: 416 30 | - type: Resize 31 | prob: 1.0 32 | size: [224, 224] 33 | - type: ToTensor 34 | - type: Reshape2D 35 | h_dim: 3 # num_channels 36 | w_dim: 50176 #224 * 224 37 | -------------------------------------------------------------------------------- /config/task/autoencoder_image_grayscale.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task_flow: autoencoder_image_grayscale 3 | - loss@losses.mse_loss: mse_loss 4 | #- callback@callbacks.autoencoder_image: log_image 5 | #- callback@callbacks.encoder_embedding: log_embedding 6 | 7 | task_weight: 1.0 8 | metrics: {} 9 | torchmetrics: {} 10 | callbacks: {} 11 | losses: 12 | mse_loss: 13 | inputs: [[gray_scale, 0], [image_unflatten, 0]] 14 | 15 | #callbacks: 16 | #autoencoder_image: 17 | # logging_batch_idx: 0 18 | # inputs: [[image_unflatten, 0], [image_decoder, 0]] 19 | # log_every_n_epochs: 50 20 | # input_names: ['Original', 'Reconstruction'] 21 | #encoder_embedding: 22 | # logging_batch_idx: 0 23 | # inputs: [[image_encoder, 0], [_input_, index], [_output_, label]] 24 | # input_names: ['ImageEncoderEmbedding', 'sample_uid', 'labels'] 25 | # log_every_n_epochs: 50 26 | # plot_embeddings: True 27 | -------------------------------------------------------------------------------- /unagi/datasets/meerkat_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | 3 | 4 | class MeerkatDataset(Dataset): 5 | """Torch dataset wrapper around meerkat dp""" 6 | 7 | def __init__(self, datapanel, xs, ys): 8 | self.dataset = datapanel 9 | self.x_names = xs 10 | self.y_names = ys 11 | 12 | def __len__(self): 13 | return len(self.dataset) 14 | 15 | def __getitem__(self, idx): 16 | # if self.x_names is single element, return single element 17 | if len(self.x_names) > 1: 18 | x = [self.dataset[idx][input_feat] for input_feat in self.x_names] 19 | else: 20 | x = self.dataset[idx][self.x_names[0]] 21 | if len(self.y_names) > 1: 22 | y = [self.dataset[idx][output_feat] for output_feat in self.y_names] 23 | else: 24 | y = self.dataset[idx][self.y_names[0]] 25 | return (x, y) 26 | -------------------------------------------------------------------------------- /config/dataflow/transforms/cub200.yaml: -------------------------------------------------------------------------------- 1 | image_pil: 2 | - type: HorizontalFlip 3 | prob: 0.5 4 | - type: RandomResizedCrop 5 | prob: 1.0 6 | size: 144 7 | scale: [0.2, 1.] 
8 | - type: HorizontalFlip 9 | prob: 0.5 10 | - type: ColorJitter 11 | brightness: 0.4 12 | contrast: 0.4 13 | saturation: 0.4 14 | hue: 0.1 15 | prob: 0.8 16 | - type: RandomGrayscale 17 | p: 0.2 18 | 19 | image_pil_default_transform: 20 | - type: Resize 21 | size: [144, 144] 22 | - type: ToTensor 23 | - type: Normalize 24 | mean: [0.485, 0.456, 0.406] 25 | std: [0.229, 0.224, 0.225] 26 | - type: Reshape2D 27 | h_dim: 3 # num_channels 28 | w_dim: 20736 # 112 x 112 29 | 30 | image_pil_default_transform_pretrained: 31 | - type: Resize 32 | size: [144, 144] 33 | - type: ToTensor 34 | - type: Normalize 35 | mean: [0.485, 0.456, 0.406] 36 | std: [0.229, 0.224, 0.225] 37 | 38 | -------------------------------------------------------------------------------- /unagi/tasks/unagi_task_template.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from functools import partial 3 | 4 | from emmental.scorer import Scorer 5 | from emmental.task import EmmentalTask 6 | from torch import nn 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def create_unagi_task( 12 | model_name, 13 | model, 14 | dataset_name, 15 | task_flow, 16 | loss_module, 17 | loss_fns, 18 | output_classification, 19 | task_metric, 20 | n_views, 21 | ): 22 | loss = loss_module 23 | output = output_classification 24 | 25 | logger.info(f"Built model: {model_name}") 26 | 27 | return EmmentalTask( 28 | name=dataset_name, 29 | module_pool=nn.ModuleDict({"base_model": model}), 30 | task_flow=task_flow, 31 | loss_func=partial(loss, "base_model", model, loss_fns, n_views), 32 | output_func=partial(output, "base_model"), 33 | scorer=Scorer(metrics=task_metric), 34 | ) 35 | -------------------------------------------------------------------------------- /config/modules/cep_image_autoencoder.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /preprocessor@preprocessors.SupervisedTaskPreprocessing: supervised_task_preprocessing 3 | - /embedding@embeddings.ImagePreEncoder: square_patch 4 | - /encoder@encoders.ImageEncoder: transformer 5 | - /decoder@decoders.Classifier: classifier 6 | - /decoder@decoders.Grayscale: grayscale 7 | - /decoder@decoders.ImageDecoder: resnet18decoder 8 | - /decoder@decoders.ImageReshape: image_reshape 9 | - /loss@losses.ClassifierLoss: label_smoothing 10 | - /loss@losses.MSELoss: mse_loss 11 | 12 | train: True 13 | label_smoothing: True 14 | 15 | embeddings: 16 | ImagePreEncoder: 17 | d_model: 256 18 | 19 | encoders: 20 | ImageEncoder: 21 | d_model: 256 22 | 23 | decoders: 24 | Classifier: 25 | d_input: 256 26 | ImageDecoder: 27 | d_input: 256 28 | input_height: 32 29 | ImageReshape: 30 | d_input: 1024 31 | output_height: 32 32 | output_width: 32 33 | 34 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/resize.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms as transforms 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class Resize(UnagiTransform): 7 | def __init__( 8 | self, 9 | size, 10 | name=None, 11 | prob=1.0, 12 | level=0, 13 | interpolation=transforms.InterpolationMode.BILINEAR, 14 | ): 15 | self.size = size 16 | self.interpolation = interpolation 17 | self.transform_func = transforms.Resize(self.size, self.interpolation) 18 | 19 | super().__init__(name, prob, level) 20 | 21 | def transform(self, pil_img, label, **kwargs): 
22 | return self.transform_func(pil_img), label 23 | 24 | def __repr__(self): 25 | return ( 26 | f"" 28 | ) 29 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/compose.py: -------------------------------------------------------------------------------- 1 | class Compose(object): 2 | """Composes several transforms together. 3 | 4 | Originally from: 5 | https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html#Compose 6 | 7 | Args: 8 | transforms (list of ``Transform`` objects): list of transforms to compose. 9 | """ 10 | 11 | def __init__(self, transforms): 12 | self.transforms = transforms 13 | 14 | def __call__(self, img, label, **kwargs): 15 | for idx, t in enumerate(self.transforms): 16 | kwargs["idx"] = idx 17 | img, label = t(img, label, **kwargs) 18 | return img, label 19 | 20 | def __repr__(self): 21 | format_string = self.__class__.__name__ + "(" 22 | for t in self.transforms: 23 | format_string += "\n" 24 | format_string += " {0}".format(t) 25 | format_string += "\n)" 26 | 27 | return format_string 28 | -------------------------------------------------------------------------------- /unagi/data/transforms/text/compose.py: -------------------------------------------------------------------------------- 1 | class Compose(object): 2 | """Composes several transforms together. 3 | 4 | Originally from: 5 | https://pytorch.org/docs/stable/_modules/torchvision/transforms/transforms.html#Compose 6 | 7 | Args: 8 | transforms (list of ``Transform`` objects): list of transforms to compose. 9 | """ 10 | 11 | def __init__(self, transforms): 12 | self.transforms = transforms 13 | 14 | def __call__(self, text, label, **kwargs): 15 | for idx, t in enumerate(self.transforms): 16 | kwargs["idx"] = idx 17 | text, label = t(text, label, **kwargs) 18 | return text, label 19 | 20 | def __repr__(self): 21 | format_string = self.__class__.__name__ + "(" 22 | for t in self.transforms: 23 | format_string += "\n" 24 | format_string += " {0}".format(t) 25 | format_string += "\n)" 26 | 27 | return format_string 28 | -------------------------------------------------------------------------------- /unagi/data/augmentations/cutout.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from unagi.data.transforms.image.cutout import Cutout as CutoutTransform 4 | from unagi.data.transforms.image.transform import UnagiTransform 5 | 6 | 7 | class Cutout(UnagiTransform): 8 | def __init__( 9 | self, 10 | name=None, 11 | prob=1.0, 12 | level=0, 13 | alpha=1.0, 14 | same_class_ratio=-1.0, 15 | prob_label=False, 16 | ): 17 | self.alpha = alpha 18 | self.same_class_ratio = same_class_ratio 19 | self.prob_label = prob_label 20 | self.cutout = CutoutTransform(prob=prob, level=level) 21 | 22 | super().__init__(name, prob, level) 23 | 24 | def transform(self, pil_img, label, dp_x, dp_y): 25 | if random.random() < self.prob: 26 | cutout_img, cutout_label = self.cutout(pil_img, label) 27 | return cutout_img, cutout_label 28 | else: 29 | return pil_img, label 30 | -------------------------------------------------------------------------------- /config/modules/cep_image_supcon.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /preprocessor@preprocessors.SupervisedTaskPreprocessing: supervised_task_preprocessing 3 | - /embedding@embeddings.ImagePreEncoder: square_patch 4 | - /encoder@encoders.ImageEncoder: transformer 5 | - 
/decoder@decoders.Classifier: classifier 6 | - /decoder@decoders.ViewSelect0: view_select 7 | - /decoder@decoders.ViewSelect1: view_select 8 | - /loss@losses.ClassifierLoss: label_smoothing 9 | - /loss@losses.ContrastiveLoss: contrastive_loss 10 | 11 | train: True 12 | label_smoothing: True 13 | 14 | losses: 15 | ContrastiveLoss: 16 | module: sup_con 17 | type: sup_con 18 | 19 | embeddings: 20 | ImagePreEncoder: 21 | d_model: 256 22 | 23 | # encoders: 24 | # ImageEncoder: 25 | # d_model: 256 26 | 27 | decoders: 28 | ViewSelect0: 29 | n_views: 2 30 | view_idx: 0 31 | ViewSelect1: 32 | n_views: 2 33 | view_idx: 1 34 | # Classifier: 35 | # d_input: 256 36 | # num_classes: 10 37 | -------------------------------------------------------------------------------- /config/dataflow/transforms/upmcfood101.yaml: -------------------------------------------------------------------------------- 1 | image_pil: 2 | - type: HorizontalFlip 3 | prob: 0.5 4 | - type: RandomResizedCrop 5 | prob: 1.0 6 | size: 224 7 | scale: [0.2, 1.] 8 | - type: HorizontalFlip 9 | prob: 0.5 10 | - type: ColorJitter 11 | brightness: 0.4 12 | contrast: 0.4 13 | saturation: 0.4 14 | hue: 0.1 15 | prob: 0.8 16 | - type: RandomGrayscale 17 | p: 0.2 18 | 19 | text_default_transforms: 20 | - type: Identity 21 | 22 | image_pil_default_transform: 23 | - type: Resize 24 | size: [224, 224] 25 | - type: ToTensor 26 | - type: Normalize 27 | mean: [0.485, 0.456, 0.406] 28 | std: [0.229, 0.224, 0.225] 29 | - type: Reshape2D 30 | h_dim: 3 # num_channelss 31 | w_dim: 12544 #32 x 32 x 3 32 | 33 | image_pil_default_transform_pretrained: 34 | - type: Resize 35 | size: [224, 224] 36 | - type: ToTensor 37 | - type: Normalize 38 | mean: [0.485, 0.456, 0.406] 39 | std: [0.229, 0.224, 0.225] 40 | 41 | -------------------------------------------------------------------------------- /config/task/entitymatching_pretrain.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - task_flow: entitymatching_pretrain 3 | - loss@losses.contrastive_loss_clip: lspread_contrastive_loss 4 | - loss@losses.contrastive_loss_lspread: lspread_contrastive_loss 5 | # - callback@callbacks.encoder_embedding: log_embedding 6 | 7 | 8 | task_weight: 1.0 9 | metrics: {} 10 | torchmetrics: {} 11 | callbacks: {} 12 | # callbacks: 13 | # encoder_embedding: 14 | # logging_batch_idx: 0 15 | # inputs: [[image_encoder, 0], [_input_, index], [_output_, label]] 16 | # input_names: ['ImageEncoderEmbedding', 'sample_uid', 'labels'] 17 | # log_every_n_epochs: 50 18 | # plot_embeddings: True 19 | 20 | losses: 21 | contrastive_loss_clip: 22 | module: ContrastiveLossClip 23 | inputs: [[text_view_select_0_left, 0], [text_view_select_0_right, 0], [_output_, label]] 24 | contrastive_loss_lspread: 25 | module: ContrastiveLossLspread 26 | inputs: [[text_view_select_0_left, 0], [text_view_select_0_right, 0], [_input_, price]] -------------------------------------------------------------------------------- /unagi/data/augmentations/brightness.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from unagi.data.transforms.image.brightness import Brightness as BrightnessTransform 4 | from unagi.data.transforms.image.transform import UnagiTransform 5 | 6 | 7 | class Brightness(UnagiTransform): 8 | def __init__( 9 | self, 10 | name=None, 11 | prob=1.0, 12 | level=0, 13 | alpha=1.0, 14 | same_class_ratio=-1.0, 15 | prob_label=False, 16 | ): 17 | self.alpha = alpha 18 | self.same_class_ratio = 
same_class_ratio 19 | self.prob_label = prob_label 20 | self.brightness = BrightnessTransform(prob=prob, level=level) 21 | 22 | super().__init__(name, prob, level) 23 | 24 | def transform(self, pil_img, label, dp_x, dp_y): 25 | if random.random() < self.prob: 26 | cutout_img, cutout_label = self.brightness(pil_img, label) 27 | return cutout_img, cutout_label 28 | else: 29 | return pil_img, label 30 | -------------------------------------------------------------------------------- /unagi/data/data_utils/collate_fns.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List, Tuple, Union 2 | 3 | from einops import rearrange 4 | from torch import Tensor 5 | 6 | from unagi.trainer.data import default_unagi_collate_fn 7 | 8 | 9 | def unagi_collate_fn( 10 | # is_train, 11 | # feature_type_map, 12 | # feature_view_map, 13 | batch: Union[List[Tuple[Dict[str, Any], Dict[str, Tensor]]], List[Dict[str, Any]]], 14 | ): 15 | (x_dict, y_dict) = default_unagi_collate_fn(batch) 16 | # x_dict["is_train"] = is_train 17 | # x_dict["feature_type_map"] = feature_type_map 18 | # x_dict["labels"] = y_dict["labels"] 19 | """x_dict.update( 20 | y_dict 21 | ) # ADD THIS LINE, AND IN YOUR DATALOADER ADD MORE LABELES 22 | """ 23 | new_x_dict = {} 24 | new_x_dict["index"] = x_dict["index"] 25 | del x_dict["index"] 26 | new_x_dict["inputs"] = x_dict 27 | new_y_dict = {k: rearrange(v, "b v ... -> (b v) ...") for k, v in y_dict.items()} 28 | return (new_x_dict, new_y_dict) 29 | -------------------------------------------------------------------------------- /config/experiment/cifar10_supcon_cep_coarselabels_resnet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_coarse 5 | - override /modules@model: cep_image_supcon 6 | - override /pipeline@tasks: cep_image_supcon 7 | - override /embedding@model.embeddings.ImagePreEncoder: identity 8 | - override /encoder@model.encoders.ImageEncoder: resnet 9 | - override /dataflow/transforms@dataflow.transforms: cifar10_resnet 10 | 11 | 12 | model: 13 | embeddings: 14 | ImagePreEncoder: 15 | d_model: 2048 16 | encoders: 17 | ImageEncoder: 18 | model: resnet50 19 | d_model: 2048 20 | use_pretrained: False 21 | decoders: 22 | Classifier: 23 | d_input: 2048 24 | d_output: 2 25 | 26 | dataflow: 27 | x: 28 | image: 29 | views: 2 30 | 31 | learner: 32 | checkpoint_scheduler: 33 | monitor: val/Classification_accuracy 34 | mode: max 35 | 36 | trainer: 37 | max_epochs: 200 38 | 39 | wandb: 40 | group: cifar10_supcon_cep_resnet_cl 41 | -------------------------------------------------------------------------------- /config/experiment/cifar10_supcon_cep_coarselabels_resnet18.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_coarse 5 | - override /modules@model: cep_image_supcon 6 | - override /pipeline@tasks: cep_image_supcon 7 | - override /embedding@model.embeddings.ImagePreEncoder: identity 8 | - override /encoder@model.encoders.ImageEncoder: resnet 9 | - override /dataflow/transforms@dataflow.transforms: cifar10_resnet 10 | 11 | 12 | model: 13 | embeddings: 14 | ImagePreEncoder: 15 | d_model: 512 16 | encoders: 17 | ImageEncoder: 18 | model: resnet18 19 | d_model: 512 20 | use_pretrained: False 21 | decoders: 22 | Classifier: 23 | d_input: 512 24 | d_output: 2 25 | 26 | dataflow: 27 | x: 
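The rearrange(v, "b v ... -> (b v) ...") call in unagi_collate_fn above folds each sample's views into the batch dimension so that downstream losses see one row per view. A small illustrative sketch (the shapes are arbitrary):

import torch
from einops import rearrange

labels = torch.zeros(8, 2, dtype=torch.long)  # 8 samples, 2 views each
flat = rearrange(labels, "b v ... -> (b v) ...")
print(flat.shape)  # torch.Size([16]): views are folded into the batch dimension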
28 | image: 29 | views: 2 30 | 31 | learner: 32 | checkpoint_scheduler: 33 | monitor: val/Classification_accuracy 34 | mode: max 35 | 36 | trainer: 37 | max_epochs: 200 38 | 39 | wandb: 40 | group: cifar10_supcon_cep_resnet18_cl 41 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/transform.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | 4 | class UnagiTransform(object): 5 | """Base UnagiTransform transfrom class. 6 | 7 | Args: 8 | name(str): Transformation name. 9 | prob(float): Transformation probability. 10 | level(int): Transformation level. 11 | """ 12 | 13 | def __init__(self, name=None, prob=1.0, level=0): 14 | self.name = name if name is not None else type(self).__name__ 15 | self.prob = prob 16 | 17 | assert 0 <= level <= 1.0, "Invalid level, level must be in [0, 1.0]." 18 | 19 | self.level = level 20 | 21 | def transform(self, pil_img, label, **kwargs): 22 | return pil_img, label 23 | 24 | def __call__(self, pil_img, label, **kwargs): 25 | if random.random() <= self.prob: 26 | return self.transform(pil_img, label, **kwargs) 27 | else: 28 | return pil_img, label 29 | 30 | def __repr__(self): 31 | return f"" 32 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/translate_x.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from PIL import Image 4 | 5 | from unagi.data.transforms.image.transform import UnagiTransform 6 | from unagi.data.transforms.image.utils import categorize_value 7 | 8 | 9 | class TranslateX(UnagiTransform): 10 | def __init__(self, name=None, prob=1.0, level=0, max_degree=10): 11 | self.max_degree = max_degree 12 | self.value_range = (0, self.max_degree) 13 | super().__init__(name, prob, level) 14 | 15 | def transform(self, pil_img, label, **kwargs): 16 | degree = categorize_value(self.level, self.value_range, "float") 17 | if random.random() > 0.5: 18 | degree = -degree 19 | return ( 20 | pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, degree, 0, 1, 0)), 21 | label, 22 | ) 23 | 24 | def __repr__(self): 25 | return ( 26 | f"" 28 | ) 29 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/translate_y.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | from PIL import Image 4 | 5 | from unagi.data.transforms.image.transform import UnagiTransform 6 | from unagi.data.transforms.image.utils import categorize_value 7 | 8 | 9 | class TranslateY(UnagiTransform): 10 | def __init__(self, name=None, prob=1.0, level=0, max_degree=10): 11 | self.max_degree = max_degree 12 | self.value_range = (0, self.max_degree) 13 | super().__init__(name, prob, level) 14 | 15 | def transform(self, pil_img, label, **kwargs): 16 | degree = categorize_value(self.level, self.value_range, "float") 17 | if random.random() > 0.5: 18 | degree = -degree 19 | return ( 20 | pil_img.transform(pil_img.size, Image.AFFINE, (1, 0, 0, 0, 1, degree)), 21 | label, 22 | ) 23 | 24 | def __repr__(self): 25 | return ( 26 | f"" 28 | ) 29 | -------------------------------------------------------------------------------- /config/experiment/cifar10_cep_xfer_resnet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /modules@model: cep_image 5 | - override 
/pipeline@tasks: cep_image 6 | - override /embedding@model.embeddings.ImagePreEncoder: identity 7 | - override /encoder@model.encoders.ImageEncoder: resnet 8 | - override /dataflow/transforms@dataflow.transforms: cifar10_resnet 9 | 10 | # Note: modify path_to_checkpoint to point to the checkpoint of coarse training run 11 | model: 12 | embeddings: 13 | ImagePreEncoder: 14 | d_model: 2048 15 | encoders: 16 | ImageEncoder: 17 | path_to_checkpoint: None 18 | model: resnet50 19 | d_model: 2048 20 | use_pretrained: False 21 | decoders: 22 | Classifier: 23 | d_input: 2048 24 | 25 | # Note: customize dirpath 26 | learner: 27 | modules_to_freeze: [ImageEncoder, ImagePreEncoder] 28 | optimizer: 29 | lr: .001 30 | checkpoint_scheduler: 31 | monitor: val/Classification_accuracy 32 | mode: max 33 | dirpath: None 34 | 35 | trainer: 36 | max_epochs: 100 37 | 38 | wandb: 39 | group: cifar10_cep_xfer_resnet 40 | -------------------------------------------------------------------------------- /config/experiment/cifar10_cep_xfer_resnet18.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /modules@model: cep_image 5 | - override /pipeline@tasks: cep_image 6 | - override /embedding@model.embeddings.ImagePreEncoder: identity 7 | - override /encoder@model.encoders.ImageEncoder: resnet 8 | - override /dataflow/transforms@dataflow.transforms: cifar10_resnet 9 | 10 | # Note: modify path_to_checkpoint to point to the checkpoint of coarse training run 11 | model: 12 | embeddings: 13 | ImagePreEncoder: 14 | d_model: 512 15 | encoders: 16 | ImageEncoder: 17 | path_to_checkpoint: None 18 | model: resnet18 19 | d_model: 512 20 | use_pretrained: False 21 | decoders: 22 | Classifier: 23 | d_input: 512 24 | 25 | # Note: customize dirpath 26 | learner: 27 | modules_to_freeze: [ImageEncoder, ImagePreEncoder] 28 | optimizer: 29 | lr: .001 30 | checkpoint_scheduler: 31 | monitor: val/Classification_accuracy 32 | mode: max 33 | dirpath: None 34 | 35 | trainer: 36 | max_epochs: 100 37 | 38 | wandb: 39 | group: cifar10_cep_xfer_resnet18 40 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/color_distortion.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class ColorDistortion(UnagiTransform): 7 | def __init__(self, name=None, prob=1.0, level=0, strength=0.5): 8 | super().__init__(name, prob, level) 9 | self.strength = strength 10 | self.color_jitter = transforms.ColorJitter( 11 | 0.8 * self.strength, 12 | 0.8 * self.strength, 13 | 0.8 * self.strength, 14 | 0.2 * self.strength, 15 | ) 16 | self.rnd_color_jitter = transforms.RandomApply([self.color_jitter], p=0.8) 17 | self.rnd_gray = transforms.RandomGrayscale(p=0.2) 18 | self.color_distort = transforms.Compose([self.rnd_color_jitter, self.rnd_gray]) 19 | 20 | def transform(self, pil_img, label, **kwargs): 21 | return self.color_distort(pil_img), label 22 | 23 | def __repr__(self): 24 | return ( 25 | f"= delta 10 | X = X[idx] 11 | Y = Y[idx] 12 | return X, ((Y < 0).astype(int)).reshape(-1) 13 | else: 14 | return X, Y 15 | 16 | 17 | def pil_loader(path): 18 | # open path as file to avoid ResourceWarning 19 | # (https://github.com/python-pillow/Pillow/issues/835) 20 | with open(path, "rb") as f: 21 | img = Image.open(f) 22 | return img.convert("RGB") 23 | 24 | 25 | 
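To make the contract of UnagiTransform above concrete: __call__ applies transform only when a uniform draw falls at or below prob, and level is a knob in [0, 1] that concrete transforms map onto their own parameter ranges (e.g. via categorize_value). A minimal sketch of a custom subclass; the class name below is illustrative and not part of this repo:

from unagi.data.transforms.image.transform import UnagiTransform


class NoOpExample(UnagiTransform):
    def transform(self, pil_img, label, **kwargs):
        # A real transform would modify pil_img here, typically scaled by self.level.
        return pil_img, label


t = NoOpExample(prob=0.5, level=0.3)
# Roughly half the calls run transform(); the rest return (pil_img, label) unchanged.
# new_img, new_label = t(pil_img, label)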
def accimage_loader(path): 26 | import accimage 27 | 28 | try: 29 | return accimage.Image(path) 30 | except IOError: 31 | # Potentially a decoding problem, fall back to PIL.Image 32 | return pil_loader(path) 33 | 34 | 35 | def default_loader(path): 36 | from torchvision import get_image_backend 37 | 38 | if get_image_backend() == "accimage": 39 | return accimage_loader(path) 40 | else: 41 | return pil_loader(path) 42 | -------------------------------------------------------------------------------- /config/experiment/cifar10_cep_twobackbones_xfer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /modules@model: cep_image_supcon_autoencoder_joint 5 | - override /pipeline@tasks: cep_image_supcon_autoencoder_joint 6 | 7 | # Note: modify path_to_checkpoint to point to the checkpoint of coarse training run 8 | model: 9 | encoders: 10 | ImageEncoderSup: 11 | source_module: ImageEncoder 12 | path_to_checkpoint: None 13 | ImageEncoderAuto: 14 | source_module: ImageEncoder 15 | path_to_checkpoint: None 16 | embeddings: 17 | ImagePreEncoderSup: 18 | source_module: ImagePreEncoder 19 | path_to_checkpoint: None 20 | ImagePreEncoderAuto: 21 | source_module: ImagePreEncoder 22 | path_to_checkpoint: None 23 | 24 | # Note: customize dirpath 25 | learner: 26 | modules_to_freeze: [ImageEncoderSup, ImagePreEncoderSup, ImageEncoderAuto, ImagePreEncoderAuto, ImageDecoder] 27 | optimizer: 28 | lr: .001 29 | checkpoint_scheduler: 30 | monitor: val/Classification_accuracy 31 | mode: max 32 | dirpath: None 33 | 34 | trainer: 35 | max_epochs: 100 36 | 37 | wandb: 38 | group: cifar10_cep_twobackbones_xfer 39 | -------------------------------------------------------------------------------- /unagi/data/transforms/text/pretrained_lm_tokenize.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer 3 | 4 | from unagi.data.transforms.text.transform import UnagiTransform 5 | 6 | 7 | class PretrainedLMTokenize(UnagiTransform): 8 | def __init__( 9 | self, 10 | name=None, 11 | prob=1.0, 12 | level=0, 13 | model="bert-base-uncased", 14 | padding="max_length", 15 | truncation=True, 16 | max_length=128, 17 | ): 18 | super().__init__(name, prob, level) 19 | self.tokenizer = AutoTokenizer.from_pretrained(model) 20 | self.padding = padding 21 | self.truncation = truncation 22 | self.max_length = max_length 23 | 24 | def transform(self, text, label, **kwargs): 25 | if isinstance(text, str): 26 | tokens = torch.LongTensor( 27 | self.tokenizer( 28 | text, 29 | padding=self.padding, 30 | truncation=self.truncation, 31 | max_length=self.max_length, 32 | )["input_ids"] 33 | ) 34 | else: 35 | tokens = text 36 | return tokens, label 37 | -------------------------------------------------------------------------------- /config/experiment/cifar100_cep_twobackbones_xfer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar100 5 | - override /modules@model: cep_image_supcon_autoencoder_joint 6 | - override /pipeline@tasks: cep_image_supcon_autoencoder_joint 7 | 8 | # Note: modify path_to_checkpoint to point to the checkpoint of coarse training run 9 | model: 10 | encoders: 11 | ImageEncoderSup: 12 | source_module: ImageEncoder 13 | path_to_checkpoint: None 14 | ImageEncoderAuto: 15 | source_module: ImageEncoder 16 | path_to_checkpoint: None 
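A usage sketch for PretrainedLMTokenize above: with the defaults it returns a fixed-length LongTensor of token ids, which is what the BERT encoder later consumes. The model name and sentence below are just examples:

from unagi.data.transforms.text.pretrained_lm_tokenize import PretrainedLMTokenize

tok = PretrainedLMTokenize(model="bert-base-uncased", max_length=128)
tokens, label = tok("a short example sentence", None)
print(tokens.shape)  # torch.Size([128]): padded/truncated to max_length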
17 | embeddings: 18 | ImagePreEncoderSup: 19 | source_module: ImagePreEncoder 20 | path_to_checkpoint: None 21 | ImagePreEncoderAuto: 22 | source_module: ImagePreEncoder 23 | path_to_checkpoint: None 24 | 25 | # Note: customize dirpath 26 | learner: 27 | modules_to_freeze: [ImageEncoderSup, ImagePreEncoderSup, ImageEncoderAuto, ImagePreEncoderAuto, ImageDecoder] 28 | optimizer: 29 | lr: .001 30 | checkpoint_scheduler: 31 | monitor: val/Classification_accuracy 32 | mode: max 33 | dirpath: None 34 | 35 | trainer: 36 | max_epochs: 100 37 | 38 | wandb: 39 | group: cifar100_supauto_twobackbones_xfer 40 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/random_resize_crop.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms as transforms 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class RandomResizedCrop(UnagiTransform): 7 | def __init__( 8 | self, 9 | size, 10 | name=None, 11 | prob=1.0, 12 | level=0, 13 | scale=(0.08, 1.0), 14 | ratio=(0.75, 1.333_333_333_333_333_3), 15 | interpolation=transforms.InterpolationMode.BILINEAR, 16 | ): 17 | self.size = size 18 | self.scale = scale 19 | self.ratio = ratio 20 | self.interpolation = interpolation 21 | self.transform_func = transforms.RandomResizedCrop( 22 | self.size, self.scale, self.ratio, self.interpolation 23 | ) 24 | 25 | super().__init__(name, prob, level) 26 | 27 | def transform(self, pil_img, label, **kwargs): 28 | return self.transform_func(pil_img), label 29 | 30 | def __repr__(self): 31 | return ( 32 | f"" 35 | ) 36 | -------------------------------------------------------------------------------- /config/experiment/cifar10_cep_autoencoder_coarselabels_resnet.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_coarse 5 | - override /modules@model: cep_image_autoencoder 6 | - override /pipeline@tasks: cep_image_autoencoder 7 | - override /embedding@model.embeddings.ImagePreEncoder: identity 8 | - override /encoder@model.encoders.ImageEncoder: resnet 9 | - override /dataflow/transforms@dataflow.transforms: cifar10_resnet 10 | 11 | model: 12 | embeddings: 13 | ImagePreEncoder: 14 | d_model: 2048 15 | encoders: 16 | ImageEncoder: 17 | model: resnet50 18 | d_model: 2048 19 | use_pretrained: False 20 | decoders: 21 | ImageDecoder: 22 | d_input: 2048 23 | Classifier: 24 | d_input: 2048 25 | d_output: 2 26 | 27 | tasks: 28 | Autoencoder: 29 | task_flow: 30 | image_unflatten: 31 | module: ImagePreEncoder 32 | inputs: [[supervised_task_preprocessing, image]] 33 | 34 | dataflow: 35 | x: 36 | image: 37 | views: 2 38 | 39 | trainer: 40 | max_epochs: 200 41 | 42 | learner: 43 | checkpoint_scheduler: 44 | monitor: val/Classification_accuracy 45 | mode: max 46 | 47 | wandb: 48 | group: cifar10_cep_autoencoder_resnet_cl 49 | -------------------------------------------------------------------------------- /config/experiment/cifar10_cep_autoencoder_coarselabels_resnet18.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: cifar10_coarse 5 | - override /modules@model: cep_image_autoencoder 6 | - override /pipeline@tasks: cep_image_autoencoder 7 | - override /embedding@model.embeddings.ImagePreEncoder: identity 8 | - override /encoder@model.encoders.ImageEncoder: 
resnet 9 | - override /dataflow/transforms@dataflow.transforms: cifar10_resnet 10 | 11 | model: 12 | embeddings: 13 | ImagePreEncoder: 14 | d_model: 512 15 | encoders: 16 | ImageEncoder: 17 | model: resnet18 18 | d_model: 512 19 | use_pretrained: False 20 | decoders: 21 | ImageDecoder: 22 | d_input: 512 23 | Classifier: 24 | d_input: 512 25 | d_output: 2 26 | 27 | tasks: 28 | Autoencoder: 29 | task_flow: 30 | image_unflatten: 31 | module: ImagePreEncoder 32 | inputs: [[supervised_task_preprocessing, image]] 33 | 34 | dataflow: 35 | x: 36 | image: 37 | views: 2 38 | 39 | trainer: 40 | max_epochs: 200 41 | 42 | learner: 43 | checkpoint_scheduler: 44 | monitor: val/Classification_accuracy 45 | mode: max 46 | 47 | wandb: 48 | group: cifar10_cep_autoencoder_resnet18_cl 49 | -------------------------------------------------------------------------------- /config/modules/cep_image_supcon_autoencoder.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /preprocessor@preprocessors.SupervisedTaskPreprocessing: supervised_task_preprocessing 3 | - /embedding@embeddings.ImagePreEncoder: square_patch 4 | - /encoder@encoders.ImageEncoder: transformer 5 | - /decoder@decoders.Classifier: classifier 6 | - /decoder@decoders.ViewSelect0: view_select 7 | - /decoder@decoders.ViewSelect1: view_select 8 | - /decoder@decoders.ImageDecoder: resnet18decoder 9 | - /decoder@decoders.ImageReshape: image_reshape 10 | - /loss@losses.ClassifierLoss: label_smoothing 11 | - /loss@losses.ContrastiveLoss: contrastive_loss 12 | - /loss@losses.MSELoss: mse_loss 13 | 14 | train: True 15 | label_smoothing: True 16 | 17 | embeddings: 18 | ImagePreEncoder: 19 | d_model: 256 20 | 21 | encoders: 22 | ImageEncoder: 23 | d_model: 256 24 | 25 | decoders: 26 | ViewSelect0: 27 | n_views: 2 28 | view_idx: 0 29 | ViewSelect1: 30 | n_views: 2 31 | view_idx: 1 32 | Classifier: 33 | d_input: 256 34 | ImageDecoder: 35 | d_input: 256 36 | input_height: 32 37 | ImageReshape: 38 | d_input: 1024 39 | output_height: 32 40 | output_width: 32 41 | 42 | losses: 43 | ContrastiveLoss: 44 | module: sup_con 45 | type: sup_con 46 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/color_jitter.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms as transforms 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class ColorJitter(UnagiTransform): 7 | def __init__( 8 | self, 9 | brightness=0.0, 10 | contrast=0.0, 11 | saturation=0.0, 12 | hue=0.0, 13 | name=None, 14 | prob=1.0, 15 | level=0, 16 | ): 17 | self.brightness = brightness 18 | self.contrast = contrast 19 | self.saturation = saturation 20 | self.hue = hue 21 | self.transform_func = transforms.ColorJitter( 22 | brightness=self.brightness, 23 | contrast=self.contrast, 24 | saturation=self.saturation, 25 | hue=self.hue, 26 | ) 27 | 28 | super().__init__(name, prob, level) 29 | 30 | def transform(self, pil_img, label, **kwargs): 31 | return self.transform_func(pil_img), label 32 | 33 | def __repr__(self): 34 | return ( 35 | f"" 34 | 35 | def get_prob(self): 36 | if self.prob == 1: 37 | return self.prob 38 | return random.random() 39 | 40 | def get_level(self): 41 | return random.randint(0, 10 ** PRECISION) / float(10 ** PRECISION) 42 | -------------------------------------------------------------------------------- /unagi/models/encoders/sequence/mixer/mixer_modules.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from unagi.models.layers.blocks import FFN, PreNorm, Residual 4 | 5 | 6 | class mixer(nn.Module): 7 | def __init__(self, d, n=64, dropout=0.0): 8 | super().__init__() 9 | self.f = FFN(n, n << 1) 10 | 11 | def forward(self, x): 12 | # b x p x c 13 | return self.f(x.transpose(1, 2)).transpose(1, 2) 14 | 15 | 16 | # Encoder and decoder blocks 17 | 18 | 19 | def _prenorm(d, x, drop_path=0.0): 20 | return Residual(d, PreNorm(d, x), drop_path=drop_path) 21 | 22 | 23 | class mixer_encoder(nn.Module): 24 | def __init__( 25 | self, 26 | d, 27 | num_heads, 28 | l_max, # should be equal to the sequence length 29 | mlp_dim=None, 30 | dropout=0.1, 31 | drop_path=0.0, 32 | head_dropout=None, 33 | ): 34 | super().__init__() 35 | 36 | def _pre(x): 37 | return _prenorm(d, x, drop_path=drop_path) 38 | 39 | self.mlp = _pre(mixer(d, n=l_max, dropout=dropout)) 40 | mlp_dim = d << 1 if mlp_dim is None else mlp_dim 41 | self.ffn = _pre(FFN(d, mlp_dim, dropout=dropout)) 42 | 43 | def forward(self, x, mask=None): 44 | x = self.mlp(x) 45 | return self.ffn(x) 46 | -------------------------------------------------------------------------------- /unagi/models/encoders/image/resnet/resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torchvision.models import resnet18, resnet34, resnet50 # noqa: F401 4 | 5 | 6 | class ResnetEncoder(nn.Module): 7 | def __init__( 8 | self, 9 | model="resnet18", 10 | use_pretrained=True, 11 | **kwargs, 12 | ): 13 | super().__init__() 14 | 15 | encoder = eval(model)(pretrained=use_pretrained) 16 | self.f = [] 17 | """for name, module in encoder.named_children(): 18 | if name == "conv1": 19 | module = nn.Conv2d( 20 | 3, 64, kernel_size=3, stride=1, padding=1, bias=False 21 | ) 22 | if not isinstance(module, nn.Linear) and not isinstance( 23 | module, nn.MaxPool2d 24 | ): 25 | self.f.append(module)""" 26 | for name, module in encoder.named_children(): 27 | if not isinstance(module, nn.Linear): 28 | self.f.append(module) 29 | self.f = nn.Sequential(*self.f) 30 | self.feature_size = encoder.fc.in_features 31 | self.d_model = encoder.fc.in_features 32 | 33 | def forward(self, x): 34 | x = self.f(x) 35 | x = torch.flatten(x, start_dim=1) 36 | return x 37 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/random_crop.py: -------------------------------------------------------------------------------- 1 | from torchvision import transforms as transforms 2 | 3 | from unagi.data.transforms.image.transform import UnagiTransform 4 | 5 | 6 | class RandomCrop(UnagiTransform): 7 | def __init__( 8 | self, 9 | size, 10 | padding=None, 11 | pad_if_needed=False, 12 | fill=0, 13 | padding_mode="constant", 14 | name=None, 15 | prob=1.0, 16 | level=0, 17 | ): 18 | self.size = size 19 | self.padding = padding 20 | self.pad_if_needed = pad_if_needed 21 | self.fill = fill 22 | self.padding_mode = padding_mode 23 | 24 | self.transform_func = transforms.RandomCrop( 25 | self.size, self.padding, self.pad_if_needed, self.fill, self.padding_mode 26 | ) 27 | 28 | super().__init__(name, prob, level) 29 | 30 | def transform(self, pil_img, label, **kwargs): 31 | return self.transform_func(pil_img), label 32 | 33 | def __repr__(self): 34 | return ( 35 | f"" 39 | ) 40 | -------------------------------------------------------------------------------- 
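A quick sanity check for the ResnetEncoder above: because the final nn.Linear is dropped and the output of the average pool is flattened, the encoder returns a (batch, d_model) matrix, e.g. 512 features for resnet18. The input size below is arbitrary:

import torch

from unagi.models.encoders.image.resnet.resnet import ResnetEncoder

enc = ResnetEncoder(model="resnet18", use_pretrained=False)
feats = enc(torch.randn(4, 3, 32, 32))  # batch of 4 RGB images
print(feats.shape, enc.d_model)         # torch.Size([4, 512]) 512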
/config/dataflow/dataset/tinyimagenet.yaml: -------------------------------------------------------------------------------- 1 | val_split: 0.1 2 | seed: 42 3 | class_names: ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '83', '84', '85', '86', '87', '88', '89', '90', '91', '92', '93', '94', '95', '96', '97', '98', '99', '100', '101', '102', '103', '104', '105', '106', '107', '108', '109', '110', '111', '112', '113', '114', '115', '116', '117', '118', '119', '120', '121', '122', '123', '124', '125', '126', '127', '128', '129', '130', '131', '132', '133', '134', '135', '136', '137', '138', '139', '140', '141', '142', '143', '144', '145', '146', '147', '148', '149', '150', '151', '152', '153', '154', '155', '156', '157', '158', '159', '160', '161', '162', '163', '164', '165', '166', '167', '168', '169', '170', '171', '172', '173', '174', '175', '176', '177', '178', '179', '180', '181', '182', '183', '184', '185', '186', '187', '188', '189', '190', '191', '192', '193', '194', '195', '196', '197', '198', '199'] -------------------------------------------------------------------------------- /unagi/unagi.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | import pytorch_lightning as pl 4 | 5 | from unagi.data_driver import get_data 6 | from unagi.trainer.trainer import UnagiModule 7 | 8 | logger = logging.getLogger(__name__) 9 | 10 | 11 | def main(config): 12 | # Create dataloaders 13 | data = get_data(config.dataflow) 14 | 15 | # set seed 16 | if ( 17 | "random_seed" in config.dataflow.keys() 18 | and config.dataflow.random_seed is not None 19 | ): 20 | pl.seed_everything(seed=config.dataflow.random_seed) 21 | 22 | if config.model.train: 23 | unagi_module = UnagiModule( 24 | config=config, 25 | dataset=data.dataset, 26 | train_dataloaders=data.train_dataloaders, 27 | val_dataloaders=data.val_dataloaders, 28 | test_dataloaders=data.test_dataloaders, 29 | ) 30 | 31 | if "wandb" in config.keys(): 32 | logger = pl.loggers.WandbLogger(**{**config.wandb, "config": config}) 33 | 34 | # Create trainer 35 | trainer = pl.Trainer( 36 | **{ 37 | **config.trainer, 38 | "logger": logger, 39 | "callbacks": unagi_module.configure_callbacks(), 40 | } 41 | ) 42 | trainer.fit(unagi_module) 43 | trainer.test(ckpt_path="best") 44 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | """For pip.""" 2 | 3 | from setuptools import setup 4 | 5 | exec(open("unagi/_version.py").read()) 6 | setup( 7 | name="unagi", 8 | version=__version__, 9 | description="Official repo for the paper 'Perfectly Balanced: Improving Transfer and Robustness of Supervised Contrastive Learning'", 10 | long_description=open("README.md").read(), 11 | packages=['unagi'], 12 | scripts=["bin/unagi"], 13 | install_requires=[ 14 | "cmake>=3.21.2, <4.0.0", 15 | "datasets>=1.11.0, <2.0.0", 16 | "einops>=0.3.2, <1.0.0", 17 | "meerkat-ml", 18 | "opt-einsum>=3.3.0, <4.0.0", 19 | "pykeops>=1.5, <2.0", 20 | "pytorch-lightning>=1.4.5, 
<1.4.9", 21 | "torch", 22 | "torchvision>=0.10.0, <2.0.0", 23 | "transformers", 24 | ], 25 | include_package_data=True, 26 | url="https://github.com/HazyResearch/thanos-code", 27 | classifiers=[ # https://pypi.python.org/pypi?:action=list_classifiers 28 | "Development Status :: 3 - Alpha", 29 | "Intended Audience :: Developers", 30 | "Programming Language :: Python", 31 | "Programming Language :: Python :: 3", 32 | "Programming Language :: Python :: 3.8", 33 | "Programming Language :: Python :: 3.9", 34 | "Programming Language :: Python :: 3 :: Only", 35 | ], 36 | python_requires=">=3.8", 37 | author="HazyResearch Team", 38 | ) 39 | -------------------------------------------------------------------------------- /config/modules/cep_image_supcon_autoencoder_joint.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /preprocessor@preprocessors.SupervisedTaskPreprocessing: supervised_task_preprocessing 3 | - /embedding@embeddings.ImagePreEncoderSup: square_patch 4 | - /embedding@embeddings.ImagePreEncoderAuto: square_patch 5 | - /encoder@encoders.ImageEncoderSup: transformer 6 | - /encoder@encoders.ImageEncoderAuto: transformer 7 | - /decoder@decoders.Classifier: classifier 8 | - /decoder@decoders.ViewSelect0: view_select 9 | - /decoder@decoders.ViewSelect1: view_select 10 | - /decoder@decoders.ImageDecoder: resnet18decoder 11 | - /decoder@decoders.ImageReshape: image_reshape 12 | - /loss@losses.ClassifierLoss: label_smoothing 13 | - /loss@losses.ContrastiveLoss: contrastive_loss 14 | - /loss@losses.MSELoss: mse_loss 15 | 16 | train: True 17 | label_smoothing: True 18 | 19 | embeddings: 20 | ImagePreEncoderSup: 21 | d_model: 256 22 | ImagePreEncoderAuto: 23 | d_model: 256 24 | 25 | encoders: 26 | ImageEncoderSup: 27 | d_model: 256 28 | ImageEncoderAuto: 29 | d_model: 256 30 | 31 | decoders: 32 | ViewSelect0: 33 | n_views: 2 34 | view_idx: 0 35 | ViewSelect1: 36 | n_views: 2 37 | view_idx: 1 38 | Classifier: 39 | d_input: 512 40 | ImageDecoder: 41 | d_input: 256 42 | input_height: 32 43 | ImageReshape: 44 | d_input: 1024 45 | output_height: 32 46 | output_width: 32 47 | 48 | losses: 49 | ContrastiveLoss: 50 | module: sup_con 51 | type: sup_con 52 | -------------------------------------------------------------------------------- /config/modules/cep_image_supcon_autoencoder_joint_grayscale.yaml: -------------------------------------------------------------------------------- 1 | defaults: 2 | - /preprocessor@preprocessors.SupervisedTaskPreprocessing: supervised_task_preprocessing 3 | - /embedding@embeddings.ImagePreEncoderSup: square_patch 4 | - /embedding@embeddings.ImagePreEncoderAuto: square_patch 5 | - /encoder@encoders.ImageEncoderSup: transformer 6 | - /encoder@encoders.ImageEncoderAuto: transformer 7 | - /decoder@decoders.Grayscale: grayscale 8 | - /decoder@decoders.Classifier: classifier 9 | - /decoder@decoders.ViewSelect0: view_select 10 | - /decoder@decoders.ViewSelect1: view_select 11 | - /decoder@decoders.ImageDecoder: resnet18decoder 12 | - /decoder@decoders.ImageReshape: image_reshape 13 | - /loss@losses.ClassifierLoss: label_smoothing 14 | - /loss@losses.ContrastiveLoss: contrastive_loss 15 | - /loss@losses.MSELoss: mse_loss 16 | 17 | train: True 18 | label_smoothing: True 19 | 20 | embeddings: 21 | ImagePreEncoderSup: 22 | d_model: 256 23 | ImagePreEncoderAuto: 24 | d_model: 256 25 | 26 | encoders: 27 | ImageEncoderSup: 28 | d_model: 256 29 | ImageEncoderAuto: 30 | d_model: 256 31 | 32 | decoders: 33 | ViewSelect0: 34 | 
n_views: 2 35 | view_idx: 0 36 | ViewSelect1: 37 | n_views: 2 38 | view_idx: 1 39 | Classifier: 40 | d_input: 512 41 | ImageDecoder: 42 | d_input: 256 43 | input_height: 32 44 | ImageReshape: 45 | d_input: 1024 46 | output_height: 32 47 | output_width: 32 48 | 49 | losses: 50 | ContrastiveLoss: 51 | module: sup_con 52 | type: sup_con 53 | -------------------------------------------------------------------------------- /config/trainer/full.yaml: -------------------------------------------------------------------------------- 1 | logger: true 2 | checkpoint_callback: null 3 | enable_checkpointing: true 4 | callbacks: null 5 | default_root_dir: null 6 | gradient_clip_val: null 7 | gradient_clip_algorithm: null 8 | process_position: 0 9 | num_nodes: 1 10 | num_processes: 1 11 | devices: null 12 | gpus: -1 13 | auto_select_gpus: false 14 | tpu_cores: null 15 | ipus: null 16 | log_gpu_memory: null 17 | progress_bar_refresh_rate: null 18 | enable_progress_bar: true 19 | overfit_batches: 0.0 20 | track_grad_norm: -1 21 | check_val_every_n_epoch: 1 22 | fast_dev_run: false 23 | accumulate_grad_batches: null 24 | max_epochs: null 25 | min_epochs: null 26 | max_steps: -1 27 | min_steps: null 28 | max_time: null 29 | limit_train_batches: 1.0 30 | limit_val_batches: 1.0 31 | limit_test_batches: 1.0 32 | limit_predict_batches: 1.0 33 | val_check_interval: 1.0 34 | flush_logs_every_n_steps: 1 35 | log_every_n_steps: 50 36 | accelerator: null 37 | strategy: dp 38 | sync_batchnorm: false 39 | precision: 32 40 | enable_model_summary: true 41 | weights_summary: top 42 | weights_save_path: null 43 | num_sanity_val_steps: 2 44 | resume_from_checkpoint: null 45 | profiler: null 46 | benchmark: false 47 | deterministic: false 48 | reload_dataloaders_every_n_epochs: 1 49 | reload_dataloaders_every_epoch: True 50 | auto_lr_find: false 51 | replace_sampler_ddp: true 52 | detect_anomaly: false 53 | auto_scale_batch_size: false 54 | prepare_data_per_node: null 55 | plugins: null 56 | amp_backend: native 57 | amp_level: null 58 | move_metrics_to_cpu: false 59 | multiple_trainloader_mode: max_size_cycle 60 | stochastic_weight_avg: false 61 | terminate_on_nan: null 62 | -------------------------------------------------------------------------------- /config/experiment/mnist_cep_twobackbones_xfer.yaml: -------------------------------------------------------------------------------- 1 | # @package _global_ 2 | defaults: 3 | - _self_ 4 | - override /dataflow: mnist 5 | - override /modules@model: cep_image_supcon_autoencoder_joint_grayscale 6 | - override /pipeline@tasks: cep_image_supcon_autoencoder_joint_grayscale 7 | 8 | dataflow: 9 | dataset: 10 | mnist: 11 | coarse_labels: False 12 | # Note: modify path_to_checkpoint to point to the checkpoint of coarse training run 13 | model: 14 | encoders: 15 | ImageEncoderSup: 16 | source_module: ImageEncoder 17 | path_to_checkpoint: None 18 | l_max: 50 19 | ImageEncoderAuto: 20 | source_module: ImageEncoder 21 | path_to_checkpoint: None 22 | l_max: 50 23 | embeddings: 24 | ImagePreEncoderSup: 25 | source_module: ImagePreEncoder 26 | path_to_checkpoint: None 27 | ImagePreEncoderAuto: 28 | source_module: ImagePreEncoder 29 | path_to_checkpoint: None 30 | decoders: 31 | Grayscale: 32 | d_output: 2 33 | resize: 28 34 | ImageDecoder: 35 | d_output: 2 36 | input_height: 28 37 | ImageReshape: 38 | d_output: 2 39 | d_input: 784 40 | output_height: 28 41 | output_width: 28 42 | 43 | # Note: customize dirpath 44 | learner: 45 | modules_to_freeze: [ImageEncoderSup, ImagePreEncoderSup, 
ImageEncoderAuto, ImagePreEncoderAuto, ImageDecoder] 46 | optimizer: 47 | lr: .001 48 | checkpoint_scheduler: 49 | monitor: val/Classification_accuracy 50 | mode: max 51 | dirpath: None 52 | 53 | trainer: 54 | max_epochs: 100 55 | 56 | wandb: 57 | group: mnist_twobackbone_xfer 58 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/cutout.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from PIL import ImageDraw 3 | 4 | from unagi.data.transforms.image.transform import UnagiTransform 5 | from unagi.data.transforms.image.utils import categorize_value 6 | 7 | 8 | class Cutout(UnagiTransform): 9 | def __init__(self, name=None, prob=1.0, level=0, max_pixel=20, color=None): 10 | self.max_pixel = max_pixel 11 | self.value_range = (0, self.max_pixel) 12 | self.color = color 13 | super().__init__(name, prob, level) 14 | 15 | def transform(self, pil_img, label, **kwargs): 16 | pil_img = pil_img.copy() 17 | degree = categorize_value(self.level, self.value_range, "int") 18 | width, height = pil_img.size 19 | 20 | x0 = np.random.uniform(width) 21 | y0 = np.random.uniform(height) 22 | 23 | x0 = int(max(0, x0 - degree / 2.0)) 24 | y0 = int(max(0, y0 - degree / 2.0)) 25 | x1 = min(width, x0 + degree) 26 | y1 = min(height, y0 + degree) 27 | 28 | xy = (x0, y0, x1, y1) 29 | 30 | if self.color is not None: 31 | color = self.color 32 | elif pil_img.mode == "RGB": 33 | color = (125, 123, 114) 34 | elif pil_img.mode == "L": 35 | color = 121 36 | else: 37 | raise ValueError(f"Unspported image mode {pil_img.mode}") 38 | 39 | ImageDraw.Draw(pil_img).rectangle(xy, color) 40 | 41 | return pil_img, label 42 | 43 | def __repr__(self): 44 | return ( 45 | f"" 47 | ) 48 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/resize_and_pad.py: -------------------------------------------------------------------------------- 1 | from PIL import Image, ImageOps 2 | from torchvision import transforms as transforms 3 | 4 | from unagi.data.transforms.image.transform import UnagiTransform 5 | 6 | 7 | class ResizeAndPad(UnagiTransform): 8 | def __init__( 9 | self, 10 | resized_width, 11 | resized_height, 12 | name=None, 13 | prob=1.0, 14 | level=0, 15 | ratio=(0.75, 1.333_333_333_333_333_3), 16 | interpolation=transforms.InterpolationMode.BILINEAR, 17 | ): 18 | self.resized_height = resized_height 19 | self.resized_width = resized_width 20 | 21 | super().__init__(name, prob, level) 22 | 23 | def transform(self, pil_img, label, **kwargs): 24 | original_size = pil_img.size 25 | ratio = float(self.resized_width) / max(original_size) 26 | new_size = (int(self.resized_width * ratio), int(self.resized_height * ratio)) 27 | pil_img = pil_img.resize(new_size, Image.ANTIALIAS) 28 | delta_w = self.resized_width - new_size[0] 29 | delta_h = self.resized_height - new_size[1] 30 | padding = ( 31 | delta_w // 2, 32 | delta_h // 2, 33 | delta_w - (delta_w // 2), 34 | delta_h - (delta_h // 2), 35 | ) 36 | resized_img = ImageOps.expand(pil_img, padding) 37 | return resized_img 38 | 39 | def __repr__(self): 40 | return ( 41 | f"" 44 | ) 45 | -------------------------------------------------------------------------------- /unagi/models/encoders/sequence/mixer/mixer.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | from unagi.models.encoders.base_sequence import SequenceModule 4 | from 
unagi.models.encoders.sequence.mixer.mixer_modules import mixer_encoder 5 | 6 | 7 | class MixerEncoder(SequenceModule): 8 | def __init__( 9 | self, 10 | d_model, 11 | n_heads, 12 | l_max, # can be computed based on embedding 13 | n_layers=4, 14 | dropout=0.1, 15 | head_dropout=0.1, 16 | mlp_dim=None, 17 | tie_them_all=False, 18 | **kwargs, 19 | ): 20 | super().__init__() 21 | 22 | def _block(): 23 | return mixer_encoder( 24 | d_model, 25 | n_heads, 26 | l_max=l_max, 27 | mlp_dim=mlp_dim, 28 | head_dropout=head_dropout, 29 | dropout=dropout, 30 | ) 31 | 32 | _f = ( 33 | [_block()] * n_layers 34 | if tie_them_all 35 | else [_block() for k in range(n_layers)] 36 | ) 37 | 38 | _f += [nn.LayerNorm(d_model)] 39 | self.f = nn.Sequential(*_f) 40 | 41 | self.d_model = d_model 42 | self.n_heads = n_heads 43 | self.mlp_dim = mlp_dim 44 | self.head_dropout = head_dropout 45 | self.dropout = dropout 46 | self.n_layers = n_layers 47 | self.tie_them_all = tie_them_all 48 | 49 | def forward(self, x, state=None, mask=None, *args, **kwargs): 50 | # print(f"px={px.size()} mask={mask.size()}") 51 | if mask is not None: 52 | mask = self.truncate(mask) 53 | x = x.masked_fill(~mask.unsqueeze(-1), 0) if mask is not None else x 54 | x = self.f(x) 55 | return x, state 56 | -------------------------------------------------------------------------------- /unagi/data/transforms/task/transform.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from unagi.data.transforms.image.compose import Compose 4 | 5 | 6 | class IdentityTransform: 7 | def __init__(self): 8 | pass 9 | 10 | def __call__(self, x, label): 11 | return x, label 12 | 13 | 14 | class GroupTransform: 15 | def __init__(self, transform, views=2): 16 | self.t = transform 17 | self.views = views 18 | self.squeeze = False 19 | 20 | def __call__(self, x, label, **kwargs): 21 | grouped_contrastive_transforms_lst, label_lst = zip( 22 | *[self.t(x, label, **kwargs) for i in range(self.views)] 23 | ) 24 | grouped_contrastive_transforms = torch.stack(grouped_contrastive_transforms_lst) 25 | if label is not None: 26 | label = torch.stack(label_lst, dim=1) 27 | 28 | return grouped_contrastive_transforms, label 29 | 30 | 31 | class MaskGen: 32 | def __init__(self, views, mask_length, mask_prob=0.05): 33 | self.mask_length = mask_length 34 | self.views = views 35 | self.mask_prob = mask_prob 36 | 37 | def __call__(self, x, label, **kwargs): 38 | return torch.rand(self.views, self.mask_length) < self.mask_prob 39 | 40 | 41 | class TupleTransform: 42 | def __init__(self, *args): 43 | self.fs = args 44 | 45 | def __call__(self, x, label, **kwargs): 46 | input = [ 47 | f(x, label, **kwargs)[0] 48 | if isinstance(f, Compose) 49 | else f(x, label, **kwargs) 50 | for f in self.fs 51 | ] 52 | for f in self.fs: 53 | if isinstance(f, Compose): 54 | label = f(x, label, **kwargs)[1] 55 | return input, label 56 | # return tuple([f(x, label)[0] for f in self.fs]) 57 | -------------------------------------------------------------------------------- /bin/unagi: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | import hydra 4 | from omegaconf import DictConfig, OmegaConf 5 | import rich 6 | import rich.tree 7 | import rich.syntax 8 | 9 | from unagi.unagi import main as unagi_main 10 | 11 | def print_config( 12 | config: DictConfig, 13 | resolve: bool = True, 14 | ) -> None: 15 | """ 16 | Prints content of DictConfig using Rich library and its tree structure. 
17 | Args: 18 | config (DictConfig): Configuration composed by Hydra. 19 | resolve (bool, optional): Whether to resolve reference fields of DictConfig. 20 | """ 21 | 22 | style = "dim" 23 | tree = rich.tree.Tree("CONFIG", style=style, guide_style=style) 24 | 25 | fields = config.keys() 26 | for field in fields: 27 | branch = tree.add(field, style=style, guide_style=style) 28 | 29 | config_section = config.get(field) 30 | branch_content = str(config_section) 31 | if isinstance(config_section, DictConfig): 32 | branch_content = OmegaConf.to_yaml(config_section, resolve=resolve) 33 | 34 | branch.add(rich.syntax.Syntax(branch_content, "yaml")) 35 | 36 | rich.print(tree) 37 | 38 | with open("config_tree.txt", "w") as fp: 39 | rich.print(tree, file=fp) 40 | 41 | 42 | @hydra.main(config_path="../config", config_name="config.yaml") 43 | def main(args): 44 | args = process_config(args) 45 | print_config(args) 46 | unagi_main(args) 47 | 48 | 49 | def process_config(config): 50 | # Add a resolver to evaluate arbitrary Python expressions 51 | OmegaConf.register_new_resolver("eval", eval) 52 | 53 | # Enable adding new keys to config 54 | OmegaConf.set_struct(config, False) 55 | 56 | return config 57 | 58 | 59 | if __name__ == "__main__": 60 | 61 | # Call main without any arguments! 62 | main() 63 | -------------------------------------------------------------------------------- /unagi/data_driver.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from types import SimpleNamespace 3 | 4 | from unagi.datasets import DATASET_CLASSES 5 | 6 | logger = logging.getLogger(__name__) 7 | 8 | 9 | def get_data(dataflow_config): 10 | """ 11 | Builds datasets and dataloaders from config file. 12 | 13 | # Inputs 14 | :param config: (dict) dictionary representation of experiment config file 15 | 16 | # Returns 17 | :return: SimpleNamespace containing datasets and dataloaders (train, val, test). 18 | A dataloader is built for every task / dataset type split. 19 | """ 20 | 21 | datasets = list(dataflow_config.dataset.keys()) 22 | assert len(datasets) == 1, "Only one dataset is supported." 
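The eval resolver registered in process_config (bin/unagi above) lets config files compute values inline via OmegaConf interpolation. A small sketch of what that enables; the key name is illustrative, not taken from these configs:

from omegaconf import OmegaConf

OmegaConf.register_new_resolver("eval", eval)

cfg = OmegaConf.create({"n_patches": "${eval:'(32 // 4) ** 2'}"})
print(cfg.n_patches)  # 64, evaluated lazily when the key is accessed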
23 | dataset_name = datasets[0] 24 | 25 | dataset = DATASET_CLASSES[dataset_name]( 26 | data_dir=dataflow_config.data_dir, 27 | x_transforms=dataflow_config.x, 28 | y_transforms=dataflow_config.y, 29 | transform_pool=dataflow_config.transforms, 30 | **dataflow_config.dataset[dataset_name], 31 | ) 32 | train_dataloaders = dataset.train_dataloader( 33 | batch_size=dataflow_config.batch_size, 34 | num_workers=dataflow_config.num_workers, 35 | drop_last=True, 36 | ) 37 | val_dataloaders = dataset.val_dataloader( 38 | batch_size=dataflow_config.batch_size, 39 | num_workers=dataflow_config.num_workers, 40 | drop_last=True, 41 | ) 42 | test_dataloaders = dataset.test_dataloader( 43 | batch_size=dataflow_config.batch_size, 44 | num_workers=dataflow_config.num_workers, 45 | drop_last=True, 46 | ) 47 | 48 | return SimpleNamespace( 49 | dataset=dataset, 50 | train_dataloaders=train_dataloaders, 51 | val_dataloaders=val_dataloaders, 52 | test_dataloaders=test_dataloaders, 53 | ) 54 | -------------------------------------------------------------------------------- /unagi/data/transforms/text/back_translate.py: -------------------------------------------------------------------------------- 1 | import math 2 | import os 3 | import pickle 4 | from zipfile import ZipFile 5 | 6 | import numpy as np 7 | 8 | from unagi.data.transforms.text.transform import UnagiTransform 9 | 10 | 11 | class BackTranslate(UnagiTransform): 12 | def __init__(self, name=None, prob=1.0, level=0, select_prob=0.5): 13 | super().__init__(name, prob, level) 14 | self.select_prob = select_prob 15 | # Check data exists or not. 16 | if not os.path.exists("data/seq2seq.pkl"): 17 | with ZipFile("data/seq2seq.pkl.zip", "r") as zip: 18 | zip.extractall("data/") 19 | data = open("data/seq2seq.pkl", "rb") 20 | self.seq2seq = pickle.load(data) 21 | 22 | def transform(self, text, label, **kwargs): 23 | ori_text = " ".join(text).strip() 24 | num_sents = len(self.seq2seq[ori_text][0]) 25 | select_prob = np.random.random(size=(num_sents,)) 26 | 27 | new_sents = [ 28 | self.seq2seq[ori_text][0][i] 29 | if select_prob[i] > self.select_prob 30 | else self.seq2seq[ori_text][1][i] 31 | for i in range(num_sents) 32 | ] 33 | new_text = " ".join(new_sents).strip() 34 | 35 | return (self.replace_with_length_check(ori_text, new_text).split(" "), label) 36 | 37 | def replace_with_length_check( 38 | self, ori_text, new_text, use_min_length=10, use_max_length_diff_ratio=0.5 39 | ): 40 | """Use new_text if the text length satisfies several constraints.""" 41 | if len(ori_text) < use_min_length or len(new_text) < use_min_length: 42 | return ori_text 43 | length_diff_ratio = 1.0 * (len(new_text) - len(ori_text)) / len(ori_text) 44 | if math.fabs(length_diff_ratio) > use_max_length_diff_ratio: 45 | return ori_text 46 | return new_text 47 | -------------------------------------------------------------------------------- /unagi/tasks/loss_fns/mask_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | from unagi.tasks.loss_fns.base_loss import UnagiLoss 5 | 6 | 7 | class BatchMask(UnagiLoss): 8 | def __init__(self): 9 | super().__init__() 10 | 11 | def forward(self, last_layer, embs): 12 | # embs == output of embedding layers 13 | # last_layer == output of the decoder 14 | # Both embs and last_layer 15 | # batch x sentence x dims 16 | # for each prediction in the last layer, we assume no duplicate ftm. 
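        # The einsum below scores every predicted position against every
        # embedded position: with last_layer of shape (b, s, d) and embs of
        # shape (b, t, d), a[b, i, j] is the dot product between prediction i
        # and embedding j. Taking log_softmax over the last axis and reading
        # off the diagonal keeps each position's score for its own embedding,
        # i.e. a cross-entropy whose "classes" are the positions themselves.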
17 | a = torch.einsum("b s d, b t d -> b s t", last_layer, embs) 18 | return -torch.diagonal(a.log_softmax(-1), dim1=1, dim2=2).mean() 19 | 20 | 21 | class BatchMaskDup(UnagiLoss): 22 | def __init__(self, eps=1e-5): 23 | super().__init__() 24 | print("Using Batchmask") 25 | self.eps = eps 26 | 27 | def forward(self, last_layer, embs): 28 | # embs == output of embedding layers 29 | # last_layer == output of the decoder 30 | # 31 | # Both embs and last_layer 32 | # batch x sentence x dims 33 | # for each prediction in the last layer, we assume no duplicate ftm. 34 | def _g(x, y): 35 | return torch.einsum("b s d, b t d -> b s t", x, y) 36 | 37 | def _dupe_check(x): 38 | b, s, _ = x.size() 39 | x = F.normalize(x, dim=-1) 40 | mask = _g(x, x) > 1 - self.eps 41 | mask = mask.masked_fill( 42 | torch.triu(torch.ones(b, s, s, device=x.device)) > 0, False 43 | ) 44 | # The mask is true, if there is a duplicate that comes before it in order. 45 | # As a result, only the first duplicate is counted. 46 | return mask.any(-1) 47 | 48 | a = _g(last_layer, embs) 49 | a = a.masked_fill(_dupe_check(embs).unsqueeze(1), -1e9) 50 | return -torch.diagonal(a.log_softmax(-1), dim1=1, dim2=2).mean() 51 | -------------------------------------------------------------------------------- /unagi/models/encoders/sequence/bert/bert.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from transformers import AutoTokenizer, BertModel 3 | 4 | from unagi.models.encoders.base_sequence import SequenceModule 5 | 6 | 7 | class BertEncoder(SequenceModule): 8 | def __init__( 9 | self, 10 | freeze_layers=True, 11 | pretrained_lm_name="bert-base-uncased", 12 | use_cls_token=True, 13 | use_all_tokens=False, 14 | pretrained_weights=None, 15 | **kwargs, 16 | ): 17 | super().__init__() 18 | self.f = BertModel.from_pretrained(pretrained_lm_name) 19 | self.tokenizer = AutoTokenizer.from_pretrained(pretrained_lm_name) 20 | self.f = self.f.train() 21 | self.use_cls_token = use_cls_token 22 | self.use_all_tokens = use_all_tokens 23 | 24 | """if freeze_layers: 25 | for param in self.f.parameters(): 26 | param.requires_grad = False""" 27 | 28 | self.d_model = self.f.encoder.layer[-1].output.dense.out_features 29 | self.padding = "max_length" 30 | self.truncation = True 31 | self.max_length = 128 32 | 33 | def forward(self, x): 34 | # tok_out = self.tokenizer( 35 | # x, 36 | # padding=self.padding, 37 | # truncation=self.truncation, 38 | # max_length=self.max_length, 39 | # ) 40 | # input_ids = torch.LongTensor(tok_out["input_ids"]) 41 | # attention_mask = torch.LongTensor(tok_out["attention_mask"]) 42 | input_ids = x 43 | attention_mask = (x != 0).long() 44 | token_type_ids = torch.zeros_like(input_ids) 45 | 46 | # output = self.f(inputs_embeds=x, return_dict=True) 47 | output = self.f( 48 | input_ids, 49 | attention_mask=attention_mask, 50 | token_type_ids=token_type_ids, 51 | return_dict=True, 52 | ) 53 | if self.use_cls_token: 54 | # return output["pooler_output"] 55 | return output["last_hidden_state"][:, 0, :].squeeze(dim=1) 56 | else: 57 | return output["last_hidden_state"] 58 | -------------------------------------------------------------------------------- /unagi/models/decoders/sequence/mixer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from einops import rearrange 3 | from torch import nn 4 | 5 | from unagi.models.encoders.base_sequence import SequenceModule 6 | from unagi.models.encoders.sequence.mixer.mixer_modules 
import mixer_encoder 7 | 8 | 9 | class MixerDecoder(SequenceModule): 10 | def __init__( 11 | self, 12 | d_model, 13 | n_heads, 14 | l_max, # can be computed based on embedding 15 | n_layers=4, 16 | dropout=0.1, 17 | head_dropout=0.1, 18 | mlp_dim=None, 19 | tie_them_all=False, 20 | **kwargs, 21 | ): 22 | super().__init__() 23 | 24 | def _block(): 25 | return mixer_encoder( 26 | d_model, 27 | n_heads, 28 | l_max=l_max, 29 | mlp_dim=mlp_dim, 30 | head_dropout=head_dropout, 31 | dropout=dropout, 32 | ) 33 | 34 | _f = ( 35 | [_block()] * n_layers 36 | if tie_them_all 37 | else [_block() for k in range(n_layers)] 38 | ) 39 | 40 | _f += [nn.LayerNorm(d_model)] 41 | self.decoder = nn.Sequential(*_f) 42 | 43 | self.d_model = d_model 44 | self.n_heads = n_heads 45 | self.mlp_dim = mlp_dim 46 | self.head_dropout = head_dropout 47 | self.dropout = dropout 48 | self.n_layers = n_layers 49 | self.tie_them_all = tie_them_all 50 | 51 | self.device = torch.device("cpu") 52 | if torch.cuda.is_available(): 53 | self.device = torch.device("cuda") 54 | 55 | def forward(self, x, state=None, mask=None, *args, **kwargs): 56 | # x: (batch, sequence, d_model) 57 | pooled_output = x.mean(-2) 58 | if not hasattr(self, "expand"):  # lazily build the expansion layer on first use 59 | self.expand = nn.Linear(self.d_model, self.d_model * x.size(1)).to( 60 | self.device 61 | ) 62 | 63 | temp_enc = rearrange(self.expand(pooled_output), "b (s d) -> b s d", d=self.d_model) 64 | decoded_outputs = self.decoder(temp_enc)  # nn.Sequential takes a single input 65 | return decoded_outputs 66 | -------------------------------------------------------------------------------- /unagi/datasets/mnist/mnist_dataset.py: -------------------------------------------------------------------------------- 1 | import meerkat as mk 2 | import torchvision 3 | 4 | from unagi.datasets.base_dataset import UnagiDatasetBuilder 5 | from unagi.datasets.meerkat_dataset import MeerkatDataset 6 | from unagi.datasets.mnist.utils import sparse2coarse 7 | 8 | 9 | class MNIST(UnagiDatasetBuilder): 10 | """Dataset to load MNIST dataset.""" 11 | 12 | _name_ = "mnist" 13 | # TODO: these can be modified by the transforms (e.g.
grayscale) 14 | # and need to be up to date 15 | input_shapes = { 16 | "image": (1, 28, 28), 17 | } 18 | output_shapes = { 19 | "label": (10,), 20 | } 21 | 22 | @property 23 | def init_defaults(self): 24 | return { 25 | "val_split": 0.1, 26 | "seed": 42, # For validation split 27 | "coarse_labels": False, 28 | } 29 | 30 | def setup(self): 31 | self.dataset_train = torchvision.datasets.MNIST( 32 | root=self.data_dir, 33 | train=True, 34 | download=True, 35 | ) 36 | self.dataset_train, self.dataset_val = self.split_train_val( 37 | val_split=self.val_split 38 | ) 39 | self.dataset_test = torchvision.datasets.MNIST( 40 | root=self.data_dir, 41 | train=False, 42 | download=True, 43 | ) 44 | self.dataset_train = self.to_meerkat(self.dataset_train) 45 | self.dataset_val = self.to_meerkat(self.dataset_val) 46 | self.dataset_test = self.to_meerkat(self.dataset_test) 47 | 48 | def to_meerkat(self, dataset): 49 | if self.coarse_labels: 50 | # TODO: split train and val 51 | img_pil, label = [], [] 52 | 53 | for _, (x, y) in enumerate(dataset): 54 | img_pil.append(x) 55 | label.append(y) 56 | coarse_label = sparse2coarse(label, dataset="mnist") 57 | obj = { 58 | "image": mk.ListColumn(img_pil), 59 | "label": mk.TensorColumn(coarse_label), 60 | } 61 | dp = mk.DataPanel(obj) 62 | 63 | # TODO: combine this with the UnagiDataset as an option 64 | dataset = MeerkatDataset(dp, xs=["image"], ys=["label"]) 65 | self.output_shapes["label"] = (2,) 66 | return dataset 67 | -------------------------------------------------------------------------------- /unagi/models/encoders/base_sequence.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | class SequenceModule(nn.Module): 5 | """Abstract sequence model class. All layers that the backbones 6 | use must adhere to this 7 | 8 | A sequence model is a layer that transforms an input of shape 9 | (n_batch, l_sequence, d_input) to (n_batch, l_sequence, d_output) 10 | Additionally, it returns a "state" which can be any additional information 11 | For example, RNN and SSM layers may return their hidden state, 12 | while some types of transformer layers (e.g. Transformer-XL) may want to pass 13 | through state as well 14 | - default_state receives a batch_shape with device and returns an initial state 15 | - step simulates a single step of the sequence (e.g. one unroll for an RNN). 
16 | It receives a state and single input (n_batch, d_input) and returns a state 17 | and output (n_batch, d_output) 18 | - forward is a sequence-to-sequence transformation that receives an optional state 19 | """ 20 | 21 | # def __init__(self, transposed=False, *args, **kwargs): 22 | # """ model should support regular (B, L, H) and transposed (B, H, L) 23 | # axes ordering """ 24 | # self.transposed = transposed 25 | 26 | @property 27 | def d_output(self): 28 | return self._d_output 29 | 30 | @d_output.setter 31 | def d_output(self, d): 32 | self._d_output = d 33 | 34 | @property 35 | def state_to_tensor(self): 36 | """Returns a function mapping a state to a single tensor, 37 | in case one wants to use the hidden state instead of the output 38 | for final prediction""" 39 | return lambda _: None 40 | 41 | @property 42 | def d_state(self): 43 | """Returns dimension of output of self.state_to_tensor""" 44 | return None 45 | 46 | @property 47 | def transposed(self): 48 | return self._transposed 49 | 50 | @transposed.setter 51 | def transposed(self, x): 52 | self._transposed = x 53 | 54 | def default_state(self, *batch_shape, device=None): 55 | # TODO device shouldn't be needed; models should store their own 56 | # initial state at initialization 57 | return None 58 | 59 | def step(self, x, state=None, *args, **kwargs): 60 | return x, state 61 | 62 | def forward(self, x, state=None, *args, **kwargs): 63 | return x, state 64 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Local files 2 | wandb/ 3 | logs/ 4 | train/ 5 | 6 | !config/* 7 | 8 | # Byte-compiled / optimized / DLL files 9 | __pycache__/ 10 | *.py[cod] 11 | *$py.class 12 | 13 | # C extensions 14 | *.so 15 | 16 | # Distribution / packaging 17 | .Python 18 | build/ 19 | develop-eggs/ 20 | dist/ 21 | downloads/ 22 | eggs/ 23 | .eggs/ 24 | lib/ 25 | lib64/ 26 | parts/ 27 | sdist/ 28 | var/ 29 | wheels/ 30 | pip-wheel-metadata/ 31 | share/python-wheels/ 32 | *.egg-info/ 33 | .installed.cfg 34 | *.egg 35 | MANIFEST 36 | 37 | # PyInstaller 38 | # Usually these files are written by a python script from a template 39 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 40 | *.manifest 41 | *.spec 42 | 43 | # Installer logs 44 | pip-log.txt 45 | pip-delete-this-directory.txt 46 | 47 | # Unit test / coverage reports 48 | htmlcov/ 49 | .tox/ 50 | .nox/ 51 | .coverage 52 | .coverage.* 53 | .cache 54 | nosetests.xml 55 | coverage.xml 56 | *.cover 57 | *.py,cover 58 | .hypothesis/ 59 | .pytest_cache/ 60 | 61 | # Translations 62 | *.mo 63 | *.pot 64 | 65 | # Django stuff: 66 | *.log 67 | local_settings.py 68 | db.sqlite3 69 | db.sqlite3-journal 70 | 71 | # Flask stuff: 72 | instance/ 73 | .webassets-cache 74 | 75 | # Scrapy stuff: 76 | .scrapy 77 | 78 | # Sphinx documentation 79 | docs/_build/ 80 | 81 | # PyBuilder 82 | target/ 83 | 84 | # Jupyter Notebook 85 | .ipynb_checkpoints 86 | 87 | # IPython 88 | profile_default/ 89 | ipython_config.py 90 | 91 | # pyenv 92 | .python-version 93 | 94 | # pipenv 95 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 96 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 97 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 98 | # install all needed dependencies. 99 | #Pipfile.lock 100 | 101 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 102 | __pypackages__/ 103 | 104 | # Celery stuff 105 | celerybeat-schedule 106 | celerybeat.pid 107 | 108 | # SageMath parsed files 109 | *.sage.py 110 | 111 | # Environments 112 | .env 113 | .venv 114 | env/ 115 | venv/ 116 | ENV/ 117 | env.bak/ 118 | venv.bak/ 119 | 120 | # Spyder project settings 121 | .spyderproject 122 | .spyproject 123 | 124 | # Rope project settings 125 | .ropeproject 126 | 127 | # mkdocs documentation 128 | /site 129 | 130 | # mypy 131 | .mypy_cache/ 132 | .dmypy.json 133 | dmypy.json 134 | 135 | # vim 136 | **.swp 137 | -------------------------------------------------------------------------------- /unagi/tasks/loss_fns/ce_loss.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import Tensor 6 | 7 | from unagi.tasks.loss_fns.base_loss import UnagiLoss 8 | 9 | 10 | class SoftCrossEntropyLoss(UnagiLoss): 11 | """Calculate the CrossEntropyLoss with soft targets. 12 | :param weight: Weight to assign to each of the classes. Default: None 13 | :type weight: list of float 14 | :param reduction: The way to reduce the losses: 'none' | 'mean' | 'sum'. 15 | 'none': no reduction, 16 | 'mean': the mean of the losses, 17 | 'sum': the sum of the losses. 18 | :type reduction: str 19 | """ 20 | 21 | def __init__(self, weight: List[float] = None, reduction: str = "mean"): 22 | super().__init__() 23 | if weight is None: 24 | self.weight = None 25 | else: 26 | self.register_buffer("weight", torch.Tensor(weight)) 27 | 28 | self.reduction = reduction 29 | 30 | def forward(self, input: Tensor, target: Tensor) -> Tensor: # type:ignore 31 | """Calculate the loss. 32 | :param input: prediction logits 33 | :param target: target probabilities 34 | :return: loss 35 | """ 36 | n, k = input.shape 37 | losses = input.new_zeros(n) 38 | 39 | for i in range(k): 40 | cls_idx = input.new_full((n,), i, dtype=torch.long) 41 | loss = F.cross_entropy(input, cls_idx, reduction="none") 42 | if self.weight is not None: 43 | loss = loss * self.weight[i] 44 | losses += target[:, i].float() * loss 45 | 46 | if self.reduction == "mean": 47 | losses = losses.mean() 48 | elif self.reduction == "sum": 49 | losses = losses.sum() 50 | elif self.reduction != "none": 51 | raise ValueError(f"Unrecognized reduction: {self.reduction}") 52 | 53 | return losses 54 | 55 | 56 | class LabelSmoothing(UnagiLoss): 57 | """NLL loss with label smoothing.""" 58 | 59 | def __init__(self, smoothing=0.0): 60 | """Constructor for the LabelSmoothing module. 
61 | :param smoothing: label smoothing factor 62 | """ 63 | super().__init__() 64 | self.confidence = 1.0 - smoothing 65 | self.smoothing = smoothing 66 | 67 | def forward(self, x, target): 68 | logprobs = torch.nn.functional.log_softmax(x, dim=-1) 69 | nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1)) 70 | # nll_loss = -logprobs.gather(dim=-1, index=target) 71 | nll_loss = nll_loss.squeeze(1) 72 | smooth_loss = -logprobs.mean(dim=-1) 73 | # smooth_loss = smooth_loss.unsqueeze(-1) # added 74 | loss = self.confidence * nll_loss + self.smoothing * smooth_loss 75 | return loss.mean() 76 | -------------------------------------------------------------------------------- /unagi/data/augmentations/mixup.py: -------------------------------------------------------------------------------- 1 | import random 2 | from typing import Tuple 3 | 4 | import numpy as np 5 | import torch 6 | 7 | from unagi.data.transforms.image.transform import UnagiTransform 8 | 9 | 10 | class Mixup(UnagiTransform): 11 | def __init__( 12 | self, 13 | name=None, 14 | prob=1.0, 15 | level=0, 16 | alpha=1.0, 17 | same_class_ratio=-1.0, 18 | prob_label=False, 19 | ): 20 | self.alpha = alpha 21 | self.same_class_ratio = same_class_ratio 22 | self.prob_label = prob_label 23 | 24 | super().__init__(name, prob, level) 25 | 26 | def transform(self, pil_img, label, dp_x, dp_y): 27 | """ 28 | Note: for Mixup, apply all transforms (including converting to Tensor) 29 | before applying mixup. 30 | 31 | pil_img (Tensor) 32 | dp_x: input column 33 | dp_y: label column 34 | """ 35 | if random.random() < self.prob: 36 | if self.alpha > 0.0: 37 | mix_ratio = np.random.beta(self.alpha, self.alpha) 38 | else: 39 | mix_ratio = 1.0 40 | 41 | idx = np.random.randint(len(dp_x)) 42 | tot_cnt = len(dp_x) 43 | 44 | if self.same_class_ratio >= 0: # get idx 45 | same_class = ( 46 | True if np.random.rand() <= self.same_class_ratio else False 47 | ) 48 | for i in np.random.permutation(tot_cnt): 49 | if same_class == torch.equal(dp_y["labels"][i], label): 50 | idx = i 51 | break 52 | 53 | cand_img = dp_x[idx] 54 | cand_label = dp_y[idx] 55 | # Calc all transforms before mixup 56 | 57 | if isinstance(cand_img, Tuple): 58 | cand_img = cand_img[0] 59 | if isinstance(pil_img, Tuple): 60 | pil_img = pil_img[0] 61 | 62 | mixup_img = mix_ratio * pil_img + (1 - mix_ratio) * cand_img 63 | 64 | if label is not None: 65 | if self.prob_label: 66 | mixup_label = mix_ratio * label + (1 - mix_ratio) * cand_label 67 | else: 68 | mixup_label = ( 69 | label if np.random.random() < mix_ratio else cand_label 70 | ) 71 | else: 72 | mixup_label = label 73 | 74 | return mixup_img, mixup_label 75 | 76 | else: 77 | 78 | return pil_img, label 79 | 80 | def __repr__(self): 81 | return ( 82 | f"<Mixup(name={self.name}, prob={self.prob}, level={self.level}, " 83 | f"alpha={self.alpha}, same_class_ratio={self.same_class_ratio}, " 84 | f"prob_label={self.prob_label})>" 85 | ) 86 | -------------------------------------------------------------------------------- /unagi/models/__init__.py: -------------------------------------------------------------------------------- 1 | from unagi.models.decoders.classifier import ClassificationDecoder 2 | from unagi.models.decoders.image.resnet import ResnetDecoder 3 | from unagi.models.decoders.image.resnet_autoencoder import ( 4 | Resnet18Decoder, 5 | Resnet50Decoder, 6 | ) 7 | from unagi.models.decoders.sequence.mixer import MixerDecoder 8 | from unagi.models.decoders.sequence.transformer import TransformerDecoder 9 | from unagi.models.embeddings.embeddings import ( 10 | CategoricalEmbed, 11 | Conv2DEmbed, 12 | ConvEmbed, 13 | IdentityEmbed, 14 | LinearPatchEmbed, 15 | PretrainedLMEmbed, 16 |
SquarePatchEmbed, 17 | ) 18 | from unagi.models.encoders.image.resnet.resnet import ResnetEncoder 19 | from unagi.models.encoders.sequence.bert.bert import BertEncoder 20 | from unagi.models.encoders.sequence.mixer.mixer import MixerEncoder 21 | from unagi.models.encoders.sequence.transformer.transformer import TransformerEncoder 22 | from unagi.models.layers.patch_augmentations import ( 23 | BrightnessLayer, 24 | CutoutLayer, 25 | InvertLayer, 26 | MixUpLayer, 27 | RotateLayer, 28 | SolarizeLayer, 29 | ) 30 | from unagi.models.ops.grayscale import Grayscale 31 | from unagi.models.ops.image_reshape import ImageReshape 32 | from unagi.models.ops.linear_proj import LinearProj 33 | from unagi.models.ops.pool import PoolDecoder 34 | from unagi.models.ops.sequence_concat import SequenceConcat 35 | from unagi.models.ops.view_concat import ViewConcat 36 | from unagi.models.ops.view_select import ViewSelect 37 | 38 | MODULE_DICTS = { 39 | "embeddings": { 40 | "square_patch": SquarePatchEmbed, 41 | "linear_patch": LinearPatchEmbed, 42 | "categorical": CategoricalEmbed, 43 | "conv2d": Conv2DEmbed, 44 | "conv1d": ConvEmbed, 45 | "pretrained_lm": PretrainedLMEmbed, 46 | "identity": IdentityEmbed, 47 | "sequence_concat": SequenceConcat, 48 | }, 49 | "encoders": { 50 | "mixer": MixerEncoder, 51 | "transformer": TransformerEncoder, 52 | "resnet": ResnetEncoder, 53 | "bert": BertEncoder, 54 | }, 55 | "decoders": { 56 | "classifier": ClassificationDecoder, 57 | "pool": PoolDecoder, 58 | "view_select": ViewSelect, 59 | "view_concat": ViewConcat, 60 | "transformer": TransformerDecoder, 61 | "mixer": MixerDecoder, 62 | "resnet": ResnetDecoder, 63 | "resnet18decoder": Resnet18Decoder, 64 | "resnet50decoder": Resnet50Decoder, 65 | "image_reshape": ImageReshape, 66 | "sequence_concat": SequenceConcat, 67 | "linear_proj": LinearProj, 68 | "grayscale": Grayscale, 69 | }, 70 | } 71 | 72 | AUGMENTATION_LAYERS = { 73 | "patch": { 74 | "mixup": MixUpLayer, 75 | "invert": InvertLayer, 76 | "cutout": CutoutLayer, 77 | "solarize": SolarizeLayer, 78 | "brightness": BrightnessLayer, 79 | "rotate": RotateLayer, 80 | }, 81 | "feature": { 82 | "mixup": MixUpLayer, 83 | "invert": InvertLayer, 84 | "cutout": CutoutLayer, 85 | "solarize": SolarizeLayer, 86 | "brightness": BrightnessLayer, 87 | "rotate": RotateLayer, 88 | }, 89 | } 90 | -------------------------------------------------------------------------------- /unagi/datasets/tiny_imagenet/tinyimagenet_dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | import meerkat as mk 4 | import torchvision 5 | 6 | from unagi.datasets.base_dataset import UnagiDatasetBuilder 7 | from unagi.datasets.meerkat_dataset import MeerkatDataset 8 | from unagi.datasets.tiny_imagenet.utils import create_val_img_folder, sparse2coarse 9 | 10 | 11 | class TinyImageNet(UnagiDatasetBuilder): 12 | """Dataset to load TinyImageNet dataset.""" 13 | 14 | _name_ = "tinyimagenet" 15 | # TODO: these can be modified by the transforms (e.g. 
grayscale) 16 | # and need to be up to date 17 | input_shapes = { 18 | "image": (3, 64, 64), 19 | } 20 | output_shapes = { 21 | "label": (200,), 22 | } 23 | 24 | @property 25 | def init_defaults(self): 26 | return { 27 | "val_split": 0.1, 28 | "seed": 42, # For validation split 29 | "coarse_labels": False, 30 | "root_folder": None, 31 | } 32 | 33 | def setup(self): 34 | if self.root_folder is None: 35 | raise ValueError( 36 | "Please specify the path to the root folder containing TinyImageNet" 37 | ) 38 | 39 | dp = {"train": {}, "val": {}} 40 | 41 | for split in ["train", "val"]: 42 | if split in ["val"]: 43 | folder_path = os.path.join(self.root_folder, split, "images") 44 | # reorganize the val images into class folders if not already done 45 | if ( 46 | sum( 47 | [ 48 | os.path.isdir(os.path.join(folder_path, x)) 49 | for x in os.listdir(folder_path) 50 | ] 51 | ) 52 | == 0 53 | ): 54 | print("Creating val image folder") 55 | create_val_img_folder(self.root_folder) 56 | 57 | else: 58 | folder_path = os.path.join(self.root_folder, split) 59 | 60 | labels = sorted(os.listdir(folder_path)) 61 | 62 | class_to_idx = {cls: i for i, cls in enumerate(labels)} 63 | 64 | # get image paths 65 | img_paths, classes = zip( 66 | *torchvision.datasets.DatasetFolder.make_dataset( 67 | folder_path, 68 | class_to_idx=class_to_idx, 69 | extensions="jpeg", 70 | ) 71 | ) 72 | 73 | if self.coarse_labels: 74 | classes = sparse2coarse(list(classes)) 75 | self.output_shapes["label"] = (67,) 76 | 77 | split_dp = mk.DataPanel( 78 | { 79 | "image": mk.ImageColumn.from_filepaths(list(img_paths)), 80 | "label": mk.TensorColumn(classes), 81 | } 82 | ) 83 | 84 | dp[split] = split_dp 85 | 86 | self.dataset_train = MeerkatDataset( 87 | dp["train"], xs=list(self.input_shapes.keys()), ys=["label"] 88 | ) 89 | self.dataset_val = MeerkatDataset( 90 | dp["val"], xs=list(self.input_shapes.keys()), ys=["label"] 91 | ) 92 | self.dataset_test = MeerkatDataset( 93 | dp["val"], xs=list(self.input_shapes.keys()), ys=["label"] 94 | ) 95 | -------------------------------------------------------------------------------- /unagi/data/transforms/image/__init__.py: -------------------------------------------------------------------------------- 1 | from unagi.data.transforms.image.auto_contrast import AutoContrast 2 | from unagi.data.transforms.image.blur import Blur 3 | from unagi.data.transforms.image.brightness import Brightness 4 | from unagi.data.transforms.image.center_crop import CenterCrop 5 | from unagi.data.transforms.image.color import Color 6 | from unagi.data.transforms.image.color_distortion import ColorDistortion 7 | from unagi.data.transforms.image.color_jitter import ColorJitter 8 | from unagi.data.transforms.image.contrast import Contrast 9 | from unagi.data.transforms.image.cutout import Cutout 10 | from unagi.data.transforms.image.equalize import Equalize 11 | from unagi.data.transforms.image.gaussian_blur import GaussianBlur 12 | from unagi.data.transforms.image.grayscale import Grayscale 13 | from unagi.data.transforms.image.horizontal_filp import HorizontalFlip 14 | from unagi.data.transforms.image.identity import Identity 15 | from unagi.data.transforms.image.invert import Invert 16 | from unagi.data.transforms.image.normalize import Normalize 17 | from unagi.data.transforms.image.posterize import Posterize 18 | from unagi.data.transforms.image.random_crop import RandomCrop 19 | from unagi.data.transforms.image.random_grayscale import RandomGrayscale 20 | from unagi.data.transforms.image.random_horizontal_flip import RandomHorizontalFlip 21 | from
unagi.data.transforms.image.random_resize_crop import RandomResizedCrop 22 | from unagi.data.transforms.image.reshape2d import Reshape2D 23 | from unagi.data.transforms.image.resize import Resize 24 | from unagi.data.transforms.image.resize_and_pad import ResizeAndPad 25 | from unagi.data.transforms.image.rotate import Rotate 26 | from unagi.data.transforms.image.sharpness import Sharpness 27 | from unagi.data.transforms.image.shear_x import ShearX 28 | from unagi.data.transforms.image.shear_y import ShearY 29 | from unagi.data.transforms.image.smooth import Smooth 30 | from unagi.data.transforms.image.solarize import Solarize 31 | from unagi.data.transforms.image.to_tensor import ToTensor 32 | from unagi.data.transforms.image.translate_x import TranslateX 33 | from unagi.data.transforms.image.translate_y import TranslateY 34 | from unagi.data.transforms.image.vertical_flip import VerticalFlip 35 | 36 | ALL_TRANSFORMS = { 37 | "AutoContrast": AutoContrast, 38 | "Blur": Blur, 39 | "Brightness": Brightness, 40 | "GaussianBlur": GaussianBlur, 41 | "CenterCrop": CenterCrop, 42 | "Color": Color, 43 | "Contrast": Contrast, 44 | "Cutout": Cutout, 45 | "Equalize": Equalize, 46 | "GaussianBlur": GaussianBlur, 47 | "Grayscale": Grayscale, 48 | "ColorDistortion": ColorDistortion, 49 | "HorizontalFlip": HorizontalFlip, 50 | "Identity": Identity, 51 | "Invert": Invert, 52 | "Posterize": Posterize, 53 | "RandomCrop": RandomCrop, 54 | "RandomResizedCrop": RandomResizedCrop, 55 | "Resize": Resize, 56 | "Rotate": Rotate, 57 | "Sharpness": Sharpness, 58 | "ShearX": ShearX, 59 | "ShearY": ShearY, 60 | "Smooth": Smooth, 61 | "Solarize": Solarize, 62 | "TranslateX": TranslateX, 63 | "TranslateY": TranslateY, 64 | "VerticalFlip": VerticalFlip, 65 | "ToTensor": ToTensor, 66 | "Normalize": Normalize, 67 | "Reshape2D": Reshape2D, 68 | "RandomHorizontalFlip": RandomHorizontalFlip, 69 | "ResizeAndPad": ResizeAndPad, 70 | "ColorJitter": ColorJitter, 71 | "RandomGrayscale": RandomGrayscale, 72 | } 73 | -------------------------------------------------------------------------------- /unagi/models/encoders/sequence/transformer/transformer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | from unagi.models.encoders.base_sequence import SequenceModule 5 | from unagi.models.encoders.sequence.transformer.transformer_modules import ( 6 | MHA_Encoder, 7 | MHA_Encoder_Cat, 8 | ) 9 | 10 | 11 | class TransformerEncoder(SequenceModule): 12 | def __init__( 13 | self, 14 | d_model, 15 | n_heads, 16 | l_max=512, 17 | n_layers=4, 18 | dropout=0.1, 19 | head_dropout=0.1, 20 | mlp_dim=None, 21 | tie_them_all=False, 22 | cat=False, 23 | d_cat=None, 24 | att_dim=None, 25 | learn_pos=True, 26 | use_cls_token=True, 27 | use_all_tokens=False, 28 | ret_cls_token=True, 29 | **kwargs 30 | ): 31 | super().__init__() 32 | 33 | if cat: 34 | assert d_cat is not None 35 | _f = [] 36 | dim = d_model 37 | for _ in range(n_layers): 38 | layer = MHA_Encoder_Cat( 39 | dim, 40 | n_heads, 41 | mlp_dim=mlp_dim, 42 | att_dim=att_dim, 43 | d_out=d_cat, 44 | head_dropout=head_dropout, 45 | drop_path=dropout, 46 | dropout=dropout, 47 | ) 48 | _f += [layer] 49 | dim += 2 * d_cat 50 | _f += [nn.LayerNorm(dim)] 51 | _f += [nn.Linear(dim, d_model)] 52 | else: 53 | 54 | def _block(drop_path=0.0): 55 | return MHA_Encoder( 56 | d_model, 57 | n_heads, 58 | mlp_dim=mlp_dim, 59 | head_dropout=head_dropout, 60 | drop_path=drop_path, 61 | dropout=dropout, 62 | ) 63 | 64 | if tie_them_all: 
65 | _f = [_block()] * n_layers 66 | else: 67 | _f = [ 68 | _block( 69 | drop_path=k * dropout / (n_layers - 1) if n_layers > 1 else 0 70 | ) 71 | for k in range(n_layers) 72 | ] 73 | _f += [nn.LayerNorm(d_model)] 74 | self.f = nn.Sequential(*_f) 75 | self.use_cls_token = use_cls_token 76 | self.use_all_tokens = use_all_tokens 77 | self.ret_cls_token = ret_cls_token 78 | self.learn_pos = learn_pos 79 | self.pe = nn.Parameter(1e-1 * torch.randn(l_max + 2, d_model).clamp(-1, 1)) 80 | self.cls_token = ( 81 | nn.Parameter(1e-1 * torch.randn(1, 1, d_model).clamp(-1, 1)) 82 | if self.use_cls_token 83 | else None 84 | ) 85 | 86 | def add_tokens(self, x): 87 | b, _, d = x.size() 88 | if self.use_cls_token: 89 | x = torch.cat( 90 | [self.cls_token.expand(b, self.cls_token.size(1), d), x], 91 | dim=1, 92 | ) 93 | if self.learn_pos: 94 | x += self.pe[0 : x.size(1)] 95 | return x 96 | 97 | def forward(self, x, state=None, *args, **kwargs): 98 | x = self.add_tokens(x) 99 | x = self.f(x) 100 | if self.ret_cls_token: 101 | x = x[:, 0] 102 | return x 103 | -------------------------------------------------------------------------------- /unagi/data/data_utils/meerkat_processors.py: -------------------------------------------------------------------------------- 1 | # flake8: noqa 2 | # from __future__ import annotation 3 | 4 | import logging 5 | from typing import Collection, List, Sequence, Tuple 6 | 7 | import meerkat as mk 8 | import pandas as pd 9 | import torch 10 | from meerkat.columns.lambda_column import LambdaColumn 11 | from meerkat.tools.lazy_loader import LazyLoader 12 | 13 | from unagi.data.transforms.task import GroupTransform, TupleTransform 14 | 15 | folder = LazyLoader("torchvision.datasets.folder") 16 | 17 | logger = logging.getLogger(__name__) 18 | 19 | """ 20 | class MultiImageColumn(LambdaColumn): 21 | def __init__( 22 | self, 23 | data: Sequence[Tuple(str, str)] = None, 24 | transform: List[callable] = None, 25 | loader: callable = None, 26 | *args, 27 | **kwargs, 28 | ): 29 | super(MultiImageColumn, self).__init__( 30 | mk.PandasSeriesColumn.from_data(data), *args, **kwargs 31 | ) 32 | self.loader = self.default_loader if loader is None else loader 33 | self.transform = transform 34 | 35 | def fn(self, filepaths: Tuple(str, str)): 36 | image_0, image_1 = self.loader(filepaths[0]), self.loader(filepaths[1]) 37 | image_0, image_1 = self.transform[0](image_0), self.transform[1](image_1) 38 | image_cat = torch.cat((image_0, image_1), 1) 39 | return self.transform[2](image_cat) 40 | 41 | @classmethod 42 | def from_filepaths( 43 | cls, 44 | filepaths: List[Sequence[str]], 45 | loader: callable = None, 46 | transform: List[callable] = None, 47 | *args, 48 | **kwargs, 49 | ): 50 | return cls(data=filepaths, loader=loader, transform=transform, *args, **kwargs) 51 | 52 | @classmethod 53 | def default_loader(cls, *args, **kwargs): 54 | return folder.default_loader(*args, **kwargs) 55 | 56 | @classmethod 57 | def _state_keys(cls) -> Collection: 58 | return (super()._state_keys() | {"transform", "loader"}) - {"fn"} 59 | 60 | def _repr_pandas_(self) -> pd.Series: 61 | return "ImageCell(" + self.data.data.reset_index(drop=True) + ")" 62 | """ 63 | 64 | 65 | class TextTransformCell(mk.AbstractCell): 66 | def __init__(self, input_text: str, transforms): 67 | self.input = input_text 68 | self.transforms = transforms 69 | self._token_ids = None 70 | 71 | def get(self): 72 | if self._token_ids is None: 73 | token_ids = self.transforms(self.input, None)[0] 74 | self._token_ids = token_ids 75 | return 
self._token_ids 76 | 77 | def data(self): 78 | return self.input 79 | 80 | def __repr__(self): 81 | return "TextTransformCell" 82 | 83 | 84 | class PILImgTransformCell(mk.AbstractCell): 85 | def __init__(self, pil_image, transforms): 86 | self.pil_image = pil_image 87 | self.transforms = transforms 88 | 89 | def get(self): 90 | if self.transforms is None: 91 | return self.pil_image 92 | else: 93 | transformed_img = self.transforms(self.pil_image, None) 94 | if not isinstance(self.transforms, TupleTransform) and not isinstance( 95 | self.transforms, GroupTransform 96 | ): 97 | transformed_img = transformed_img[0] 98 | return transformed_img 99 | 100 | def data(self): 101 | return self.pil_image 102 | 103 | def transforms(self): 104 | return self.transforms 105 | 106 | def __repr__(self): 107 | return "PILImgTransformCell" 108 | -------------------------------------------------------------------------------- /unagi/datasets/cifar/utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | # https://github.com/ryanchankh/cifar100coarse/blob/master/sparse2coarse.py 5 | def sparse2coarse(targets, scramble=False, dataset="cifar10"): 6 | """Convert Pytorch CIFAR100 sparse targets to coarse targets. 7 | Usage: 8 | trainset = torchvision.datasets.CIFAR100(path) 9 | trainset.targets = sparse2coarse(trainset.targets) 10 | """ 11 | if dataset == "cifar100": 12 | sparse_coarse_array = [ 13 | 4, 14 | 1, 15 | 14, 16 | 8, 17 | 0, 18 | 6, 19 | 7, 20 | 7, 21 | 18, 22 | 3, 23 | 3, 24 | 14, 25 | 9, 26 | 18, 27 | 7, 28 | 11, 29 | 3, 30 | 9, 31 | 7, 32 | 11, 33 | 6, 34 | 11, 35 | 5, 36 | 10, 37 | 7, 38 | 6, 39 | 13, 40 | 15, 41 | 3, 42 | 15, 43 | 0, 44 | 11, 45 | 1, 46 | 10, 47 | 12, 48 | 14, 49 | 16, 50 | 9, 51 | 11, 52 | 5, 53 | 5, 54 | 19, 55 | 8, 56 | 8, 57 | 15, 58 | 13, 59 | 14, 60 | 17, 61 | 18, 62 | 10, 63 | 16, 64 | 4, 65 | 17, 66 | 4, 67 | 2, 68 | 0, 69 | 17, 70 | 4, 71 | 18, 72 | 17, 73 | 10, 74 | 3, 75 | 2, 76 | 12, 77 | 12, 78 | 16, 79 | 12, 80 | 1, 81 | 9, 82 | 19, 83 | 2, 84 | 10, 85 | 0, 86 | 1, 87 | 16, 88 | 12, 89 | 9, 90 | 13, 91 | 15, 92 | 13, 93 | 16, 94 | 19, 95 | 2, 96 | 4, 97 | 6, 98 | 19, 99 | 5, 100 | 5, 101 | 8, 102 | 19, 103 | 18, 104 | 1, 105 | 2, 106 | 15, 107 | 6, 108 | 0, 109 | 17, 110 | 8, 111 | 14, 112 | 13, 113 | ] 114 | else: 115 | # index of original labels: 116 | # [b'airplane', b'automobile', b'bird', b'cat', b'deer', 117 | # b'dog', b'frog', b'horse', b'ship', b'truck'] 118 | sparse_coarse_array = [1, 1, 0, 0, 0, 0, 0, 0, 1, 1] 119 | 120 | targets = np.array(sparse_coarse_array)[targets] 121 | return targets.tolist() 122 | 123 | 124 | def get_superclass_subclass_mapping(): 125 | return { 126 | 0: [4, 30, 55, 72, 95], 127 | 1: [1, 32, 67, 73, 91], 128 | 2: [54, 62, 70, 82, 92], 129 | 3: [9, 10, 16, 28, 61], 130 | 4: [0, 51, 53, 57, 83], 131 | 5: [22, 39, 40, 86, 87], 132 | 6: [5, 20, 25, 84, 94], 133 | 7: [6, 7, 14, 18, 24], 134 | 8: [3, 42, 43, 88, 97], 135 | 9: [12, 17, 37, 68, 76], 136 | 10: [23, 33, 49, 60, 71], 137 | 11: [15, 19, 21, 31, 38], 138 | 12: [34, 63, 64, 66, 75], 139 | 13: [26, 45, 77, 79, 99], 140 | 14: [2, 11, 35, 46, 98], 141 | 15: [27, 29, 44, 78, 93], 142 | 16: [36, 50, 65, 74, 80], 143 | 17: [47, 52, 56, 59, 96], 144 | 18: [8, 13, 48, 58, 90], 145 | 19: [41, 69, 81, 85, 89], 146 | } 147 | -------------------------------------------------------------------------------- /unagi/data/data_utils/transform_util.py: 
-------------------------------------------------------------------------------- 1 | from copy import deepcopy 2 | 3 | from unagi.data.transforms import ALL_TRANSFORMS 4 | from unagi.data.transforms.image.compose import Compose 5 | 6 | 7 | def get_transforms( 8 | input_features: dict, 9 | dataset_split: str, 10 | augmentations: dict, 11 | default_transforms: dict = {}, 12 | ): 13 | """ 14 | Gets the list of transforms for each input feature. 15 | 16 | # Inputs 17 | :param input_features: (dict) contains all input feature metadata, including 18 | the transform information used by this module. 19 | 20 | # Returns 21 | :return: a dict mapping each input feature name to its relevant transforms. 22 | """ 23 | ifeat_to_transforms = {} 24 | for name, inpt_feat in input_features.items(): 25 | transforms_list = [] 26 | feat_type = inpt_feat["type"] 27 | # key that has a corresponding mapping in the augmentations.raw section 28 | if dataset_split == "train": 29 | augmentation_key = ( 30 | inpt_feat["transform"] if "transform" in inpt_feat.keys() else None 31 | ) 32 | if augmentation_key is not None and augmentations is not None: 33 | augmentation_list = augmentations[augmentation_key] 34 | for aug in augmentation_list: 35 | aug_type = aug["type"] 36 | aug = deepcopy(aug) 37 | del aug["type"] 38 | if aug_type in ALL_TRANSFORMS[feat_type]: 39 | transforms_list.append(ALL_TRANSFORMS[feat_type][aug_type](**aug)) 40 | else: 41 | raise ValueError( 42 | f"Unknown transform type: {aug_type} for feature: {name}" 43 | ) 44 | # check if a default transformation is specified in the experiment config; 45 | # if yes, overwrite the preset dataset transformation 46 | default_transforms_key = ( 47 | inpt_feat["default_transform"] 48 | if "default_transform" in inpt_feat.keys() 49 | else None 50 | ) 51 | if default_transforms_key is not None and augmentations is not None: 52 | augmentation_list = augmentations[default_transforms_key] 53 | for aug in augmentation_list: 54 | aug_type = aug["type"] 55 | aug = deepcopy(aug) 56 | del aug["type"] 57 | if aug_type in ALL_TRANSFORMS[feat_type]: 58 | transforms_list.append(ALL_TRANSFORMS[feat_type][aug_type](**aug)) 59 | else: 60 | raise ValueError( 61 | f"Unknown transform type: {aug_type} for feature: {name}" 62 | ) 63 | else: 64 | # use the dataset's preset transforms 65 | if feat_type in default_transforms: 66 | transforms_list.extend(default_transforms[feat_type]) 67 | composed_transforms = Compose(transforms_list) 68 | if inpt_feat["views"] >= 1: 69 | contrastive_transform = ALL_TRANSFORMS["task"]["Contrastive"] 70 | composed_transforms = contrastive_transform( 71 | composed_transforms, 72 | inpt_feat["views"] if dataset_split == "train" else 1, 73 | ) 74 | if inpt_feat["mask"]: 75 | tuple_transform = ALL_TRANSFORMS["task"]["Mask"] 76 | mask_gen = ALL_TRANSFORMS["task"]["MaskGenerator"] 77 | composed_transforms = tuple_transform( 78 | composed_transforms, 79 | mask_gen( 80 | 1, # task_config["contrastive"]["contrastive_views"], 81 | inpt_feat["mask_length"], 82 | ), 83 | ) 84 | 85 | ifeat_to_transforms[name] = composed_transforms 86 | 87 | return ifeat_to_transforms 88 | -------------------------------------------------------------------------------- /unagi/models/layers/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | from torch import nn 4 | 5 | 6 | class Transpose(nn.Module): 7 | def __init__(self, i, j): 8 | super().__init__()  9 | self.i = i 10 | self.j = j 11 | 12 | def forward(self, x): 13 | return x.transpose(self.i, self.j)
14 | 15 | 16 | class Truncate(nn.Module): 17 | def __init__(self, max_sequence_length): 18 | super().__init__() 19 | self.max_sequence_length = max_sequence_length 20 | 21 | def forward(self, input): 22 | if self.max_sequence_length is not None: 23 | # NOTE: assumes input is of form (batch, seq_length, hidden_dim) 24 | # or (batch, seq_length) 25 | if input.size(1) > self.max_sequence_length: 26 | input = input[:, 0 : self.max_sequence_length, :] 27 | elif input.size(1) < self.max_sequence_length: 28 | if len(input.size()) == 2: 29 | pad = (0, self.max_sequence_length - input.size(1)) 30 | elif len(input.size()) == 3: 31 | pad = (0, 0, self.max_sequence_length - input.size(1), 0) 32 | input = F.pad(input, pad, mode="constant", value=0) 33 | return input 34 | 35 | 36 | class PreNorm(nn.Module): 37 | def __init__(self, d, f): 38 | super().__init__() 39 | self.f = f 40 | self.norm = nn.LayerNorm(d) 41 | 42 | def forward(self, x, **kwargs): 43 | return self.f(self.norm(x), **kwargs) 44 | 45 | 46 | class FFN(nn.Module): 47 | def __init__(self, d, mlp_dim, out_dim=None, dropout=0.1): 48 | super().__init__() 49 | out_dim = d if out_dim is None else out_dim 50 | self.f = nn.Sequential( 51 | nn.Linear(d, mlp_dim), 52 | nn.GELU(), 53 | nn.Dropout(dropout), 54 | nn.Linear(mlp_dim, out_dim), 55 | nn.Dropout(dropout), 56 | ) 57 | 58 | def forward(self, x, **kwargs): 59 | return self.f(x, **kwargs) 60 | 61 | 62 | # https://github.com/SHI-Labs/Compact-Transformers/blob/f6d43e50ece006b933eeb27b087a0c3cad3bc635/src/transformers.py#L90 63 | 64 | 65 | class DropPath(nn.Module): 66 | def __init__(self, drop_prob): 67 | super().__init__() 68 | self.keep_prob = 1 - drop_prob 69 | 70 | def forward(self, x): 71 | if self.keep_prob >= 1.0 or not self.training: 72 | return x 73 | # work with diff dim tensors, not just 2D ConvNets 74 | shape = (x.shape[0],) + (1,) * (x.ndim - 1) 75 | random_tensor = self.keep_prob + torch.rand( 76 | shape, dtype=x.dtype, device=x.device 77 | ) 78 | random_tensor.floor_() # binarize 79 | return x.div(self.keep_prob) * random_tensor 80 | 81 | 82 | class Residual(nn.Module): 83 | def __init__(self, d, f, trainable=False, per_channel=False, drop_path=0.0): 84 | super().__init__() 85 | _init = [1.0] * d if per_channel else [1.0] 86 | self.scalar = nn.Parameter(torch.tensor(_init)) if trainable else 1.0 87 | self.f = f 88 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 89 | 90 | def forward(self, x, **kwargs): 91 | return self.drop_path(self.f(x, **kwargs)) + x * self.scalar 92 | 93 | 94 | class Cat(nn.Module): 95 | def __init__(self, f, drop_path=0.0): 96 | super().__init__() 97 | self.f = f 98 | self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() 99 | 100 | def forward(self, x, **kwargs): 101 | y = self.drop_path(self.f(x, **kwargs)) 102 | return torch.cat([x, y], dim=-1) 103 | 104 | 105 | class Classifier(nn.Module): 106 | def __init__(self, input_dim, target_dim): 107 | super().__init__() 108 | self.classification_layer = nn.Linear(input_dim, target_dim) 109 | 110 | def forward(self, x, **kwargs): 111 | return self.classification_layer(x) 112 | --------------------------------------------------------------------------------
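A minimal usage sketch for the blocks defined in unagi/models/layers/blocks.py: it composes a pre-norm residual FFN block with the linear Classifier head. The sizes, the drop-path rate, and the final mean-pooling step are assumptions made for the example, not code taken from the repository.

import torch

from unagi.models.layers.blocks import FFN, Classifier, PreNorm, Residual

d_model, mlp_dim = 64, 256  # illustrative sizes, not repo defaults

# Residual(d, f, drop_path=p) computes drop_path(f(x)) + x (with a unit scalar by default);
# PreNorm applies LayerNorm before calling f.
block = Residual(
    d_model,
    PreNorm(d_model, FFN(d_model, mlp_dim, dropout=0.1)),
    drop_path=0.1,
)
head = Classifier(input_dim=d_model, target_dim=10)

x = torch.randn(8, 16, d_model)  # (batch, sequence, d_model)
y = block(x)                     # residual pre-norm FFN keeps the shape (8, 16, d_model)
logits = head(y.mean(dim=1))     # mean-pool over the sequence, then classify -> (8, 10)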