├── ldm ├── data │ ├── __init__.py │ ├── __pycache__ │ │ ├── vg.cpython-37.pyc │ │ ├── base.cpython-37.pyc │ │ ├── coco.cpython-37.pyc │ │ └── __init__.cpython-37.pyc │ ├── base.py │ └── vg.py ├── modules │ ├── cgip │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── cgip.cpython-37.pyc │ │ │ ├── tools.cpython-37.pyc │ │ │ └── __init__.cpython-37.pyc │ │ └── tools.py │ ├── encoders │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── modules.cpython-37.pyc │ │ └── modules.py │ ├── distributions │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── distributions.cpython-37.pyc │ │ └── distributions.py │ ├── diffusionmodules │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── model.cpython-37.pyc │ │ │ ├── util.cpython-37.pyc │ │ │ ├── __init__.cpython-37.pyc │ │ │ └── openaimodel.cpython-37.pyc │ ├── losses │ │ ├── __init__.py │ │ ├── contperceptual.py │ │ └── vqperceptual.py │ ├── __pycache__ │ │ ├── ema.cpython-37.pyc │ │ ├── attention.cpython-37.pyc │ │ └── x_transformer.cpython-37.pyc │ ├── image_degradation │ │ ├── utils │ │ │ └── test.png │ │ └── __init__.py │ ├── ema.py │ └── attention.py ├── models │ ├── diffusion │ │ ├── __init__.py │ │ └── __pycache__ │ │ │ ├── ddim.cpython-37.pyc │ │ │ ├── ddpm.cpython-37.pyc │ │ │ └── __init__.cpython-37.pyc │ └── __pycache__ │ │ └── autoencoder.cpython-37.pyc ├── __pycache__ │ └── util.cpython-37.pyc ├── lr_scheduler.py └── util.py ├── sg_image_pretraining ├── clip │ ├── __init__.py │ ├── bpe_simple_vocab_16e6.txt.gz │ └── simple_tokenizer.py ├── .DS_Store ├── global_var.py ├── sgCLIP │ ├── constants.py │ ├── contrastive_losses.py │ ├── create_model.py │ ├── create_model_coco.py │ ├── generative_loss.py │ ├── mm_transformer_module.py │ ├── model.py │ └── masked_loss.py ├── ReadMe.md ├── model_configs │ └── RN50.json ├── training │ ├── precision.py │ ├── scheduler.py │ ├── logger.py │ ├── distributed.py │ ├── configs.py │ ├── configs_coco.py │ └── train_mim.py ├── datasets │ ├── dataloader_builder.py │ ├── dataloader_builder_coco.py │ ├── transform.py │ └── vg_dataset.py ├── trainer.py ├── trainer_coco.py └── utils.py ├── setup.py ├── sg2im ├── __init__.py ├── data │ ├── __init__.py │ ├── utils.py │ └── vg.py └── vis.py ├── scripts ├── download_coco.sh ├── download_vg.sh ├── download_first_stages.sh ├── download_models.sh ├── datamodule.py └── train_searcher.py ├── DATA.md ├── config_vg.yaml ├── config_coco.yaml ├── README.md └── sgdiff.yaml /ldm/data/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/cgip/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__init__.py: 
-------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /sg_image_pretraining/clip/__init__.py: -------------------------------------------------------------------------------- 1 | from .clip import * 2 | -------------------------------------------------------------------------------- /ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /sg_image_pretraining/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/sg_image_pretraining/.DS_Store -------------------------------------------------------------------------------- /sg_image_pretraining/global_var.py: -------------------------------------------------------------------------------- 1 | import torch 2 | device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") -------------------------------------------------------------------------------- /ldm/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/vg.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/data/__pycache__/vg.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/base.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/data/__pycache__/base.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/coco.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/data/__pycache__/coco.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/data/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/data/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/ema.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/__pycache__/ema.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /ldm/models/__pycache__/autoencoder.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/models/__pycache__/autoencoder.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/attention.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/__pycache__/attention.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/cgip/__pycache__/cgip.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/cgip/__pycache__/cgip.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/cgip/__pycache__/tools.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/cgip/__pycache__/tools.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddim.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/models/diffusion/__pycache__/ddim.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/ddpm.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/models/diffusion/__pycache__/ddpm.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/__pycache__/x_transformer.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/__pycache__/x_transformer.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/cgip/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/cgip/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /sg_image_pretraining/clip/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/sg_image_pretraining/clip/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /sg_image_pretraining/sgCLIP/constants.py: -------------------------------------------------------------------------------- 1 | OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) 2 | OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) -------------------------------------------------------------------------------- /ldm/models/diffusion/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/models/diffusion/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/__init__.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/encoders/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/encoders/__pycache__/modules.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/encoders/__pycache__/modules.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/model.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/diffusionmodules/__pycache__/model.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/diffusionmodules/__pycache__/util.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/distributions/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/diffusionmodules/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/distributions/__pycache__/distributions.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/distributions/__pycache__/distributions.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YangLing0818/SGDiff/HEAD/ldm/modules/diffusionmodules/__pycache__/openaimodel.cpython-37.pyc -------------------------------------------------------------------------------- /ldm/modules/image_degradation/__init__.py: -------------------------------------------------------------------------------- 1 | from ldm.modules.image_degradation.bsrgan import degradation_bsrgan_variant as degradation_fn_bsr 2 | from ldm.modules.image_degradation.bsrgan_light import degradation_bsrgan_variant as degradation_fn_bsr_light 3 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | setup( 4 | name='latent-diffusion', 5 | version='0.0.1', 6 | description='', 7 | packages=find_packages(), 8 | install_requires=[ 9 | 'torch', 10 | 'numpy', 11 | 'tqdm', 12 | ], 13 | ) 
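# Annotation (not part of the original file): given this minimal setup.py, an editable install
# such as `pip install -e .` is the usual way to put the repository's packages on the import path;
# install_requires only lists torch, numpy, and tqdm, so the remaining dependencies are assumed
# to come from the conda environment created from sgdiff.yaml (see README.md).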
-------------------------------------------------------------------------------- /sg_image_pretraining/ReadMe.md: -------------------------------------------------------------------------------- 1 | ### Masked Contrastive Pre-Training of Scene Graphs and Images 2 | 3 | 1. Modify the `````` in the ```training/configs.py``` and ```training/configs_coco.py``` 4 | 5 | 2. Run 6 | 7 | ```shell 8 | conda activate sgdiff 9 | # for vg 10 | python trainer.py 11 | # for coco 12 | python trainer_coco.py 13 | ``` 14 | 15 | -------------------------------------------------------------------------------- /sg_image_pretraining/model_configs/RN50.json: -------------------------------------------------------------------------------- 1 | { 2 | "vision_cfg": { 3 | "layers": [3,4,6,3], 4 | "image_size": 256, 5 | "width": 64, 6 | "head_width": 64, 7 | "mlp_ratio": 4.0, 8 | "patch_size": null 9 | }, 10 | "graph_cfg": { 11 | "layers": 5, 12 | "width": 512 13 | }, 14 | "embed_dim": 1024 15 | } -------------------------------------------------------------------------------- /sg_image_pretraining/training/precision.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from contextlib import suppress 3 | 4 | # amp_bfloat16 is more stable than amp float16 for clip training 5 | def get_autocast(precision): 6 | if precision == 'amp': 7 | return torch.cuda.amp.autocast 8 | elif precision == 'amp_bfloat16': 9 | return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16) 10 | else: 11 | return suppress -------------------------------------------------------------------------------- /sg2im/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
16 | 17 | -------------------------------------------------------------------------------- /scripts/download_coco.sh: -------------------------------------------------------------------------------- 1 | COCO_DIR=datasets/coco 2 | mkdir -p $COCO_DIR 3 | 4 | wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip -O $COCO_DIR/annotations_trainval2017.zip 5 | wget http://images.cocodataset.org/annotations/stuff_annotations_trainval2017.zip -O $COCO_DIR/stuff_annotations_trainval2017.zip 6 | wget http://images.cocodataset.org/zips/train2017.zip -O $COCO_DIR/train2017.zip 7 | wget http://images.cocodataset.org/zips/val2017.zip -O $COCO_DIR/val2017.zip 8 | 9 | unzip $COCO_DIR/annotations_trainval2017.zip -d $COCO_DIR 10 | unzip $COCO_DIR/stuff_annotations_trainval2017.zip -d $COCO_DIR 11 | unzip $COCO_DIR/train2017.zip -d $COCO_DIR/images 12 | unzip $COCO_DIR/val2017.zip -d $COCO_DIR/images 13 | -------------------------------------------------------------------------------- /sg2im/data/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 16 | 17 | from .utils import imagenet_preprocess, imagenet_deprocess 18 | from .utils import imagenet_deprocess_batch 19 | -------------------------------------------------------------------------------- /ldm/data/base.py: -------------------------------------------------------------------------------- 1 | from abc import abstractmethod 2 | from torch.utils.data import Dataset, ConcatDataset, ChainDataset, IterableDataset 3 | 4 | 5 | class Txt2ImgIterableBaseDataset(IterableDataset): 6 | ''' 7 | Define an interface to make the IterableDatasets for text2img data chainable 8 | ''' 9 | def __init__(self, num_records=0, valid_ids=None, size=256): 10 | super().__init__() 11 | self.num_records = num_records 12 | self.valid_ids = valid_ids 13 | self.sample_ids = valid_ids 14 | self.size = size 15 | 16 | print(f'{self.__class__.__name__} dataset contains {self.__len__()} examples.') 17 | 18 | def __len__(self): 19 | return self.num_records 20 | 21 | @abstractmethod 22 | def __iter__(self): 23 | pass -------------------------------------------------------------------------------- /sg_image_pretraining/training/scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from global_var import * 3 | 4 | def assign_learning_rate(optimizer, new_lr): 5 | for param_group in optimizer.param_groups: 6 | param_group["lr"] = new_lr 7 | 8 | 9 | def _warmup_lr(base_lr, warmup_length, step): 10 | return base_lr * (step + 1) / warmup_length 11 | 12 | 13 | def cosine_lr(optimizer, base_lr, warmup_length, steps): 14 | def _lr_adjuster(step): 15 | if step < warmup_length: 16 | lr = _warmup_lr(base_lr, warmup_length, step) 17 | else: 18 | e = step - warmup_length 19 | es = steps - warmup_length 20 | lr = 0.5 * (1 + 
np.cos(np.pi * e / es)) * base_lr 21 | assign_learning_rate(optimizer, lr) 22 | return lr 23 | return _lr_adjuster -------------------------------------------------------------------------------- /sg_image_pretraining/datasets/dataloader_builder.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from datasets.vg_dataset import vg_collate_fn, build_vg_dsets 3 | 4 | def build_vg_loaders(args): 5 | 6 | vocab, train_dset, val_dset = build_vg_dsets(args) 7 | collate_fn = vg_collate_fn 8 | 9 | loader_kwargs = { 10 | 'batch_size': args.batch_size, 11 | 'num_workers': args.workers, 12 | 'shuffle': True, 13 | 'collate_fn': collate_fn, 14 | } 15 | train_loader = DataLoader(train_dset, **loader_kwargs) 16 | train_loader.num_samples = len(train_dset) 17 | 18 | loader_kwargs['batch_size'] = args.val_batch_size 19 | loader_kwargs['shuffle'] = False 20 | val_loader = DataLoader(val_dset, **loader_kwargs) 21 | val_loader.num_samples = len(val_dset) 22 | return vocab, train_loader, val_loader 23 | -------------------------------------------------------------------------------- /sg_image_pretraining/datasets/dataloader_builder_coco.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import DataLoader 2 | from datasets.coco_dataset import coco_collate_fn, build_coco_dsets 3 | 4 | def build_coco_loaders(args): 5 | 6 | vocab, train_dset, val_dset = build_coco_dsets(args) 7 | collate_fn = coco_collate_fn 8 | 9 | loader_kwargs = { 10 | 'batch_size': args.batch_size, 11 | 'num_workers': args.workers, 12 | 'shuffle': True, 13 | 'collate_fn': collate_fn, 14 | } 15 | train_loader = DataLoader(train_dset, **loader_kwargs) 16 | train_loader.num_samples = len(train_dset) 17 | 18 | loader_kwargs['batch_size'] = args.val_batch_size 19 | loader_kwargs['shuffle'] = False 20 | val_loader = DataLoader(val_dset, **loader_kwargs) 21 | val_loader.num_samples = len(val_dset) 22 | 23 | return vocab, train_loader, val_loader 24 | 25 | -------------------------------------------------------------------------------- /sg_image_pretraining/training/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | def setup_logging(log_file, level, include_host=False): 4 | if include_host: 5 | import socket 6 | hostname = socket.gethostname() 7 | formatter = logging.Formatter( 8 | f'%(asctime)s | {hostname} | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 9 | else: 10 | formatter = logging.Formatter('%(asctime)s | %(levelname)s | %(message)s', datefmt='%Y-%m-%d,%H:%M:%S') 11 | 12 | logging.root.setLevel(level) 13 | loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict] 14 | for logger in loggers: 15 | logger.setLevel(level) 16 | 17 | stream_handler = logging.StreamHandler() 18 | stream_handler.setFormatter(formatter) 19 | logging.root.addHandler(stream_handler) 20 | 21 | if log_file: 22 | file_handler = logging.FileHandler(filename=log_file) 23 | file_handler.setFormatter(formatter) 24 | logging.root.addHandler(file_handler) -------------------------------------------------------------------------------- /scripts/download_vg.sh: -------------------------------------------------------------------------------- 1 | VG_DIR=datasets/vg 2 | mkdir -p $VG_DIR 3 | 4 | wget https://visualgenome.org/static/data/dataset/objects.json.zip -O $VG_DIR/objects.json.zip 5 | wget 
https://visualgenome.org/static/data/dataset/attributes.json.zip -O $VG_DIR/attributes.json.zip 6 | wget https://visualgenome.org/static/data/dataset/relationships.json.zip -O $VG_DIR/relationships.json.zip 7 | wget https://visualgenome.org/static/data/dataset/object_alias.txt -O $VG_DIR/object_alias.txt 8 | wget https://visualgenome.org/static/data/dataset/relationship_alias.txt -O $VG_DIR/relationship_alias.txt 9 | wget https://visualgenome.org/static/data/dataset/image_data.json.zip -O $VG_DIR/image_data.json.zip 10 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images.zip -O $VG_DIR/images.zip 11 | wget https://cs.stanford.edu/people/rak248/VG_100K_2/images2.zip -O $VG_DIR/images2.zip 12 | 13 | unzip $VG_DIR/objects.json.zip -d $VG_DIR 14 | unzip $VG_DIR/attributes.json.zip -d $VG_DIR 15 | unzip $VG_DIR/relationships.json.zip -d $VG_DIR 16 | unzip $VG_DIR/image_data.json.zip -d $VG_DIR 17 | unzip $VG_DIR/images.zip -d $VG_DIR/images 18 | unzip $VG_DIR/images2.zip -d $VG_DIR/images 19 | -------------------------------------------------------------------------------- /scripts/download_first_stages.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/first_stage_models/kl-f4/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f4.zip 3 | wget -O models/first_stage_models/kl-f8/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f8.zip 4 | wget -O models/first_stage_models/kl-f16/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f16.zip 5 | wget -O models/first_stage_models/kl-f32/model.zip https://ommer-lab.com/files/latent-diffusion/kl-f32.zip 6 | wget -O models/first_stage_models/vq-f4/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4.zip 7 | wget -O models/first_stage_models/vq-f4-noattn/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f4-noattn.zip 8 | wget -O models/first_stage_models/vq-f8/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8.zip 9 | wget -O models/first_stage_models/vq-f8-n256/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f8-n256.zip 10 | wget -O models/first_stage_models/vq-f16/model.zip https://ommer-lab.com/files/latent-diffusion/vq-f16.zip 11 | 12 | 13 | 14 | cd models/first_stage_models/kl-f4 15 | unzip -o model.zip 16 | 17 | cd ../kl-f8 18 | unzip -o model.zip 19 | 20 | cd ../kl-f16 21 | unzip -o model.zip 22 | 23 | cd ../kl-f32 24 | unzip -o model.zip 25 | 26 | cd ../vq-f4 27 | unzip -o model.zip 28 | 29 | cd ../vq-f4-noattn 30 | unzip -o model.zip 31 | 32 | cd ../vq-f8 33 | unzip -o model.zip 34 | 35 | cd ../vq-f8-n256 36 | unzip -o model.zip 37 | 38 | cd ../vq-f16 39 | unzip -o model.zip 40 | 41 | cd ../.. 
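# Annotation (not part of the original script): for SGDiff itself, only the vq-f8 first-stage
# model appears to be needed -- config_vg.yaml and config_coco.yaml both point the
# first_stage_config ckpt_path at pretrained/vq-f8-model.ckpt, and README.md links the
# corresponding vq-f8.zip download directly.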
-------------------------------------------------------------------------------- /sg_image_pretraining/sgCLIP/contrastive_losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn import functional as F 4 | import torch.distributed.nn 5 | has_distributed = True 6 | 7 | class ClipLoss(nn.Module): 8 | 9 | def __init__( 10 | self, 11 | local_loss=False, 12 | gather_with_grad=False, 13 | cache_labels=False, 14 | rank=0, 15 | world_size=1, 16 | ): 17 | super().__init__() 18 | self.local_loss = local_loss 19 | self.gather_with_grad = gather_with_grad 20 | self.cache_labels = cache_labels 21 | self.rank = rank 22 | self.world_size = world_size 23 | 24 | self.prev_num_logits = 0 25 | self.labels = {} 26 | 27 | def forward(self, image_features, graph_features, logit_scale): 28 | device = image_features.device 29 | 30 | logits_per_image = logit_scale * image_features @ graph_features.T 31 | logits_per_graph = logit_scale * graph_features @ image_features.T 32 | 33 | num_logits = logits_per_image.shape[0] 34 | if self.prev_num_logits != num_logits or device not in self.labels: 35 | labels = torch.arange(num_logits, device=device, dtype=torch.long) 36 | else: 37 | labels = self.labels[device] 38 | 39 | total_loss = ( 40 | F.cross_entropy(logits_per_image, labels) + 41 | F.cross_entropy(logits_per_graph, labels) 42 | ) / 2 43 | return total_loss -------------------------------------------------------------------------------- /DATA.md: -------------------------------------------------------------------------------- 1 | ## Step 1: Install COCO API 2 | To train new models you will need to install the [COCO Python API](https://github.com/cocodataset/cocoapi). Unfortunately, installing this package via pip often leads to build errors, but you can install it from source like this: 3 | 4 | ```bash 5 | cd ~ 6 | git clone https://github.com/cocodataset/cocoapi.git 7 | cd cocoapi/PythonAPI/ 8 | python setup.py install 9 | ``` 10 | 11 | ## Step 2: Preparing the data 12 | 13 | ### Visual Genome 14 | Run the following script to download and unpack the relevant parts of the Visual Genome dataset: 15 | 16 | ```bash 17 | bash scripts/download_vg.sh 18 | ``` 19 | 20 | This will create the directory `datasets/vg` and will download about 15 GB of data to this directory; after unpacking it will take about 30 GB of disk space. 21 | 22 | After downloading the Visual Genome dataset, we need to preprocess it. This will split the data into train / val / test splits, consolidate all scene graphs into HDF5 files, and apply several heuristics to clean the data. In particular, we ignore images that are too small and only consider object and attribute categories that appear at least a minimum number of times in the training set; we also ignore objects that are too small, and set minimum and maximum values on the number of objects and relationships that appear per image. 23 | 24 | ```bash 25 | python scripts/preprocess_vg.py 26 | ``` 27 | 28 | This will create files `train.h5`, `val.h5`, `test.h5`, and `vocab.json` in the directory `datasets/vg`. 29 | 30 | ### COCO 31 | Run the following script to download and unpack the relevant parts of the COCO dataset: 32 | 33 | ```bash 34 | bash scripts/download_coco.sh 35 | ``` 36 | 37 | This will create the directory `datasets/coco` and will download about 21 GB of data to this directory; after unpacking it will take about 60 GB of disk space.
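As a quick sanity check after the Visual Genome preprocessing step above, the generated `vocab.json` can be inspected with a few lines of Python. This is only a sketch: it assumes the vocabulary keys that `ldm/modules/cgip/tools.py` reads (`object_name_to_idx` and `pred_name_to_idx`), whose sizes should line up with the `num_objs` / `num_preds` values in `config_vg.yaml`.

```python
import json

# Load the vocabulary written by scripts/preprocess_vg.py
with open("datasets/vg/vocab.json") as f:
    vocab = json.load(f)

# Keys consumed by ldm/modules/cgip/tools.py when encoding scene graphs
print(len(vocab["object_name_to_idx"]), "object categories")   # compare with num_objs in config_vg.yaml
print(len(vocab["pred_name_to_idx"]), "predicate categories")  # compare with num_preds in config_vg.yaml
```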
38 | -------------------------------------------------------------------------------- /scripts/download_models.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | wget -O models/ldm/celeba256/celeba-256.zip https://ommer-lab.com/files/latent-diffusion/celeba.zip 3 | wget -O models/ldm/ffhq256/ffhq-256.zip https://ommer-lab.com/files/latent-diffusion/ffhq.zip 4 | wget -O models/ldm/lsun_churches256/lsun_churches-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_churches.zip 5 | wget -O models/ldm/lsun_beds256/lsun_beds-256.zip https://ommer-lab.com/files/latent-diffusion/lsun_bedrooms.zip 6 | wget -O models/ldm/text2img256/model.zip https://ommer-lab.com/files/latent-diffusion/text2img.zip 7 | wget -O models/ldm/cin256/model.zip https://ommer-lab.com/files/latent-diffusion/cin.zip 8 | wget -O models/ldm/semantic_synthesis512/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis.zip 9 | wget -O models/ldm/semantic_synthesis256/model.zip https://ommer-lab.com/files/latent-diffusion/semantic_synthesis256.zip 10 | wget -O models/ldm/bsr_sr/model.zip https://ommer-lab.com/files/latent-diffusion/sr_bsr.zip 11 | wget -O models/ldm/layout2img-openimages256/model.zip https://ommer-lab.com/files/latent-diffusion/layout2img_model.zip 12 | wget -O models/ldm/inpainting_big/model.zip https://ommer-lab.com/files/latent-diffusion/inpainting_big.zip 13 | 14 | 15 | 16 | cd models/ldm/celeba256 17 | unzip -o celeba-256.zip 18 | 19 | cd ../ffhq256 20 | unzip -o ffhq-256.zip 21 | 22 | cd ../lsun_churches256 23 | unzip -o lsun_churches-256.zip 24 | 25 | cd ../lsun_beds256 26 | unzip -o lsun_beds-256.zip 27 | 28 | cd ../text2img256 29 | unzip -o model.zip 30 | 31 | cd ../cin256 32 | unzip -o model.zip 33 | 34 | cd ../semantic_synthesis512 35 | unzip -o model.zip 36 | 37 | cd ../semantic_synthesis256 38 | unzip -o model.zip 39 | 40 | cd ../bsr_sr 41 | unzip -o model.zip 42 | 43 | cd ../layout2img-openimages256 44 | unzip -o model.zip 45 | 46 | cd ../inpainting_big 47 | unzip -o model.zip 48 | 49 | cd ../.. 
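# Annotation (not part of the original script): these are the pretrained LDM checkpoints from the
# upstream latent-diffusion model zoo (celeba, ffhq, lsun, semantic synthesis, inpainting, ...);
# they do not appear to be referenced by config_vg.yaml or config_coco.yaml, which only expect
# pretrained/vq-f8-model.ckpt plus the sip_vg.pt / sip_coco.pt checkpoints described in README.md.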
50 | -------------------------------------------------------------------------------- /sg_image_pretraining/sgCLIP/create_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | from typing import Optional, Tuple 4 | from sgCLIP.model import sgCLIP, convert_weights_to_fp16 5 | 6 | 7 | def create_model( 8 | args, 9 | graph_vocab: dict, 10 | model_config_json: str, 11 | precision: str = 'fp32', 12 | device: torch.device = torch.device('cpu'), 13 | force_quick_gelu: bool = False, 14 | pretrained_image: bool = False, 15 | ): 16 | if model_config_json != '': 17 | with open(model_config_json, 'r') as f: 18 | model_cfg = json.load(f) 19 | else: 20 | model_cfg = { 21 | "graph_cfg": { 22 | "layers": args.num_graph_layer, 23 | "width": args.graph_width, 24 | }, 25 | "embed_dim": args.embed_dim, 26 | } 27 | 28 | if force_quick_gelu: 29 | model_cfg["quick_gelu"] = True 30 | 31 | if pretrained_image: 32 | if 'timm_model_name' in model_cfg.get('vision_cfg', {}): 33 | model_cfg['vision_cfg']['timm_model_pretrained'] = True 34 | else: 35 | assert False, 'pretrained image towers currently only supported for timm models' 36 | 37 | model = sgCLIP(graph_vocab=graph_vocab, **model_cfg) 38 | 39 | model.to(device=device) 40 | if precision == "fp16": 41 | assert device.type != 'cpu' 42 | convert_weights_to_fp16(model) 43 | 44 | return model 45 | 46 | def create_model_and_transforms( 47 | args, 48 | graph_vocab: dict, 49 | model_config_json: str, 50 | precision: str = 'fp32', 51 | device: torch.device = torch.device('cpu'), 52 | force_quick_gelu: bool = False, 53 | pretrained_image: bool = False, 54 | ): 55 | model = create_model(args, graph_vocab, model_config_json, precision, device, force_quick_gelu=force_quick_gelu, pretrained_image=pretrained_image) 56 | 57 | return model -------------------------------------------------------------------------------- /sg_image_pretraining/sgCLIP/create_model_coco.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import json 3 | from typing import Optional, Tuple 4 | from sgCLIP.model import sgCLIP, convert_weights_to_fp16 5 | 6 | def create_model( 7 | args, 8 | graph_vocab: dict, 9 | model_config_json: str, 10 | precision: str = 'fp32', 11 | device: torch.device = torch.device('cpu'), 12 | force_quick_gelu: bool = False, 13 | pretrained_image: bool = False, 14 | ): 15 | if model_config_json != '': 16 | with open(model_config_json, 'r') as f: 17 | model_cfg = json.load(f) 18 | else: 19 | model_cfg = { 20 | "graph_cfg": { 21 | "layers": args.num_graph_layer, 22 | "width": args.graph_width, 23 | }, 24 | "embed_dim": args.embed_dim, 25 | } 26 | 27 | if force_quick_gelu: 28 | model_cfg["quick_gelu"] = True 29 | 30 | if pretrained_image: 31 | if 'timm_model_name' in model_cfg.get('vision_cfg', {}): 32 | model_cfg['vision_cfg']['timm_model_pretrained'] = True 33 | else: 34 | assert False, 'pretrained image towers currently only supported for timm models' 35 | 36 | model = sgCLIP(graph_vocab=graph_vocab, **model_cfg) 37 | 38 | model.to(device=device) 39 | if precision == "fp16": 40 | assert device.type != 'cpu' 41 | convert_weights_to_fp16(model) 42 | 43 | return model 44 | 45 | def create_model_and_transforms( 46 | args, 47 | graph_vocab: dict, 48 | model_config_json: str, 49 | precision: str = 'fp32', 50 | device: torch.device = torch.device('cpu'), 51 | force_quick_gelu: bool = False, 52 | pretrained_image: bool = False, 53 | ): 54 | model = 
create_model(args, graph_vocab, model_config_json, precision, device, force_quick_gelu=force_quick_gelu, pretrained_image=pretrained_image) 55 | 56 | return model -------------------------------------------------------------------------------- /sg_image_pretraining/sgCLIP/generative_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.distributed.nn 4 | from sgCLIP.mm_transformer_module import BasicTransformerBlock 5 | from utils import boxes_to_mask 6 | has_distributed = True 7 | 8 | class ReconstractMaskedImageFromSceneGraphLoss(nn.Module): 9 | def __init__(self, triple_dim, image_dim, num_img_patches=50, num_triple=15, sg_only=False): 10 | super().__init__() 11 | 12 | self.image_dim = image_dim 13 | 14 | if sg_only: 15 | self.register_buffer('attn_mask', self.build_attention_mask(tri_length=num_triple, img_length=num_img_patches), persistent=False) 16 | else: 17 | self.attn_mask = None 18 | 19 | self.transformer = BasicTransformerBlock(dim=image_dim, n_heads=8, d_head=64, dropout=0., context_dim=triple_dim) 20 | 21 | self.criterion = nn.MSELoss() 22 | 23 | def forward(self, local_graph_fea, local_masked_image_fea, local_gt_image_fea): 24 | local_masked_image_fea = local_masked_image_fea.permute(1, 0, 2).contiguous() 25 | local_gt_image_fea = local_gt_image_fea.permute(1, 0, 2).contiguous() 26 | 27 | local_reconstructed_img_fea = self.transformer(local_masked_image_fea, context=local_graph_fea) 28 | 29 | rec_loss = self.criterion(local_reconstructed_img_fea, local_gt_image_fea) 30 | return rec_loss 31 | 32 | 33 | class ReconstractMaskedSceneGraphFromImageLoss(nn.Module): 34 | def __init__(self, triple_dim, image_dim, num_img_patches=50, num_triple=15, sg_only=False): 35 | super().__init__() 36 | 37 | self.triple_dim = triple_dim 38 | 39 | if sg_only: 40 | self.register_buffer('attn_mask', self.build_attention_mask(tri_length=num_triple, img_length=num_img_patches), persistent=False) 41 | else: 42 | self.attn_mask = None 43 | 44 | self.transformer = BasicTransformerBlock(dim=triple_dim, n_heads=8, d_head=64, dropout=0., context_dim=image_dim) 45 | 46 | self.criterion = nn.MSELoss() 47 | 48 | def forward(self, local_graph_fea, local_masked_graph_fea, local_gt_image_fea): 49 | local_gt_image_fea = local_gt_image_fea.permute(1, 0, 2).contiguous() 50 | 51 | local_reconstructed_graph_fea = self.transformer(local_masked_graph_fea, context=local_gt_image_fea) 52 | 53 | rec_loss = self.criterion(local_reconstructed_graph_fea, local_graph_fea) 54 | return rec_loss -------------------------------------------------------------------------------- /ldm/modules/cgip/tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import copy 4 | 5 | def create_tensor_by_assign_samples_to_img(samples, sample_to_img, max_sample_per_img, batch_size): 6 | dtype, device = samples.dtype, samples.device 7 | N = batch_size 8 | D = samples.shape[1] 9 | assert (sample_to_img.max() + 1) == N 10 | 11 | samples_per_img = [] 12 | for i in range(N): 13 | s_idxs = (sample_to_img == i).nonzero().view(-1) 14 | sub_sample = samples[s_idxs] 15 | len_cur = sub_sample.shape[0] 16 | if len_cur > max_sample_per_img: 17 | sub_sample = sub_sample[:max_sample_per_img, :] 18 | if len_cur < max_sample_per_img: 19 | zero_vector = torch.zeros([1, D]).to(device) 20 | padding_vectors = torch.cat([copy.deepcopy(zero_vector) for _ in range(max_sample_per_img - 
len_cur)], dim=0) # [res, D] 21 | sub_sample = torch.cat([sub_sample, padding_vectors], dim=0) 22 | sub_sample = sub_sample.unsqueeze(0) 23 | samples_per_img.append(sub_sample) 24 | samples_per_img = torch.cat(samples_per_img, dim=0).to(device) 25 | 26 | return samples_per_img 27 | 28 | def idx_to_one_hot(idx, num_classes): 29 | result = F.one_hot(idx, num_classes) 30 | result = result.float() 31 | return result 32 | 33 | 34 | def sample_json(vocab, scene_graphs): 35 | objs, triples, obj_to_img, triple_to_img = encode_scene_graphs(vocab, scene_graphs) 36 | return objs, triples, obj_to_img, triple_to_img 37 | 38 | 39 | def encode_scene_graphs(vocab, scene_graphs): 40 | if isinstance(scene_graphs, dict): 41 | scene_graphs = [scene_graphs] 42 | 43 | objs, triples, obj_to_img = [], [], [] 44 | obj_offset = 0 45 | for i, sg in enumerate(scene_graphs): 46 | sg['objects'].append('__image__') 47 | image_idx = len(sg['objects']) - 1 48 | for j in range(image_idx): 49 | sg['relationships'].append([j, '__in_image__', image_idx]) 50 | 51 | for obj in sg['objects']: 52 | obj_idx = vocab['object_name_to_idx'].get(obj, None) 53 | if obj_idx is None: 54 | raise ValueError('Object "%s" not in vocab' % obj) 55 | objs.append(obj_idx) 56 | obj_to_img.append(i) 57 | for s, p, o in sg['relationships']: 58 | pred_idx = vocab['pred_name_to_idx'].get(p, None) 59 | if pred_idx is None: 60 | raise ValueError('Relationship "%s" not in vocab' % p) 61 | triples.append([s + obj_offset, pred_idx, o + obj_offset]) 62 | obj_offset += len(sg['objects']) 63 | objs = torch.tensor(objs, dtype=torch.int64) 64 | triples = torch.tensor(triples, dtype=torch.int64) 65 | obj_to_img = torch.tensor(obj_to_img, dtype=torch.int64) 66 | 67 | T = triples.shape[0] 68 | triple_to_img = torch.zeros([T, ], dtype=torch.int64) 69 | return objs, triples, obj_to_img, triple_to_img -------------------------------------------------------------------------------- /config_vg.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | cond_stage_forward: encode_graph_local_global 17 | monitor: val/loss_simple_ema 18 | unet_config: 19 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 20 | params: 21 | image_size: 32 22 | in_channels: 4 23 | out_channels: 4 24 | model_channels: 256 25 | attention_resolutions: 26 | - 4 27 | - 2 28 | - 1 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 4 34 | num_head_channels: 32 35 | use_spatial_transformer: true 36 | transformer_depth: 1 37 | context_local_dim: 1536 38 | context_dim: 512 39 | first_stage_config: 40 | target: ldm.models.autoencoder.VQModelInterface 41 | params: 42 | embed_dim: 4 43 | n_embed: 16384 44 | ckpt_path: pretrained/vq-f8-model.ckpt 45 | ddconfig: 46 | double_z: false 47 | z_channels: 4 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 2 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: 59 | - 32 60 | dropout: 0.0 61 | lossconfig: 62 | target: torch.nn.Identity 63 | cond_stage_config: 64 | target: ldm.modules.cgip.cgip.CGIPModel 65 | params: 66 | num_objs: 179 67 | num_preds: 46 68 | layers: 
5 69 | width: 512 70 | embed_dim: 512 71 | ckpt_path: pretrained/sip_vg.pt 72 | data: 73 | target: scripts.datamodule.DataModuleFromConfig 74 | params: 75 | batch_size: 16 76 | num_workers: 4 77 | wrap: false 78 | train: 79 | target: ldm.data.vg.VGTrain 80 | params: 81 | vocab: ./datasets/vg/vocab.json 82 | h5_path: ./datasets/vg/train.h5 83 | image_dir: ./datasets/vg/images 84 | image_size: 256 85 | max_objects: 10 86 | validation: 87 | target: ldm.data.vg.VGTrain 88 | params: 89 | vocab: ./datasets/vg/vocab.json 90 | h5_path: ./datasets/vg/val.h5 91 | image_dir: ./datasets/vg/images 92 | image_size: 256 93 | max_objects: 10 94 | 95 | lightning: 96 | callbacks: 97 | image_logger: 98 | target: trainer.ImageLogger 99 | params: 100 | batch_frequency: 10000 101 | max_images: 16 102 | increase_log_steps: False 103 | trainer: 104 | benchmark: True -------------------------------------------------------------------------------- /sg_image_pretraining/training/distributed.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from global_var import * 4 | 5 | def is_global_master(args): 6 | return args.rank == 0 7 | 8 | 9 | def is_local_master(args): 10 | return args.local_rank == 0 11 | 12 | 13 | def is_master(args, local=False): 14 | return is_local_master(args) if local else is_global_master(args) 15 | 16 | 17 | def is_using_distributed(): 18 | if 'WORLD_SIZE' in os.environ: 19 | return int(os.environ['WORLD_SIZE']) > 1 20 | if 'SLURM_NTASKS' in os.environ: 21 | return int(os.environ['SLURM_NTASKS']) > 1 22 | return False 23 | 24 | 25 | def world_info_from_env(): 26 | local_rank = 0 27 | for v in ('LOCAL_RANK', 'MPI_LOCALRANKID', 'SLURM_LOCALID', 'OMPI_COMM_WORLD_LOCAL_RANK'): 28 | if v in os.environ: 29 | local_rank = int(os.environ[v]) 30 | break 31 | global_rank = 0 32 | for v in ('RANK', 'PMI_RANK', 'SLURM_PROCID', 'OMPI_COMM_WORLD_RANK'): 33 | if v in os.environ: 34 | global_rank = int(os.environ[v]) 35 | break 36 | world_size = 1 37 | for v in ('WORLD_SIZE', 'PMI_SIZE', 'SLURM_NTASKS', 'OMPI_COMM_WORLD_SIZE'): 38 | if v in os.environ: 39 | world_size = int(os.environ[v]) 40 | break 41 | 42 | return local_rank, global_rank, world_size 43 | 44 | 45 | def init_distributed_device(args): 46 | # Distributed training = training on more than one GPU. 47 | # Works in both single and multi-node scenarios. 
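# (added note) Ranks and world size are read from the environment -- SLURM_* variables or the
# LOCAL_RANK/RANK/WORLD_SIZE variables set by torchrun / torch.distributed.launch -- via
# world_info_from_env() below; args.device ends up as 'cuda:<local_rank>', 'cuda:0', or 'cpu'.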
48 | args.distributed = False 49 | args.world_size = 1 50 | args.rank = 0 # global rank 51 | args.local_rank = 0 52 | if is_using_distributed(): 53 | if 'SLURM_PROCID' in os.environ: 54 | # DDP via SLURM 55 | args.local_rank, args.rank, args.world_size = world_info_from_env() 56 | # SLURM var -> torch.distributed vars in case needed 57 | os.environ['LOCAL_RANK'] = str(args.local_rank) 58 | os.environ['RANK'] = str(args.rank) 59 | os.environ['WORLD_SIZE'] = str(args.world_size) 60 | torch.distributed.init_process_group( 61 | backend=args.dist_backend, 62 | init_method=args.dist_url, 63 | world_size=args.world_size, 64 | rank=args.rank, 65 | ) 66 | else: 67 | # DDP via torchrun, torch.distributed.launch 68 | args.local_rank, _, _ = world_info_from_env() 69 | torch.distributed.init_process_group( 70 | backend=args.dist_backend, 71 | init_method=args.dist_url) 72 | args.world_size = torch.distributed.get_world_size() 73 | args.rank = torch.distributed.get_rank() 74 | args.distributed = True 75 | 76 | if torch.cuda.is_available(): 77 | if args.distributed and not args.no_set_device_rank: 78 | device = 'cuda:%d' % args.local_rank 79 | else: 80 | device = 'cuda:0' 81 | torch.cuda.set_device(device) 82 | else: 83 | device = 'cpu' 84 | args.device = device 85 | device = torch.device(device) 86 | return device 87 | -------------------------------------------------------------------------------- /scripts/datamodule.py: -------------------------------------------------------------------------------- 1 | import pytorch_lightning as pl 2 | from torch.utils.data import DataLoader, Dataset 3 | from functools import partial 4 | from ldm.data.vg import vg_collate_fn 5 | from ldm.util import instantiate_from_config 6 | 7 | 8 | class WrappedDataset(Dataset): 9 | """Wraps an arbitrary object with __len__ and __getitem__ into a pytorch dataset""" 10 | 11 | def __init__(self, dataset): 12 | self.data = dataset 13 | 14 | def __len__(self): 15 | return len(self.data) 16 | 17 | def __getitem__(self, idx): 18 | return self.data[idx] 19 | 20 | class DataModuleFromConfig(pl.LightningDataModule): 21 | def __init__(self, batch_size, train=None, validation=None, test=None, predict=None, 22 | wrap=False, num_workers=None, shuffle_test_loader=False, use_worker_init_fn=False, 23 | shuffle_val_dataloader=False): 24 | super().__init__() 25 | self.batch_size = batch_size 26 | self.dataset_configs = dict() 27 | self.num_workers = num_workers if num_workers is not None else batch_size * 2 28 | if train is not None: 29 | self.dataset_configs["train"] = train 30 | self.train_dataloader = self._train_dataloader 31 | if validation is not None: 32 | self.dataset_configs["validation"] = validation 33 | self.val_dataloader = partial(self._val_dataloader, shuffle=shuffle_val_dataloader) 34 | if test is not None: 35 | self.dataset_configs["test"] = test 36 | self.test_dataloader = partial(self._test_dataloader, shuffle=shuffle_test_loader) 37 | if predict is not None: 38 | self.dataset_configs["predict"] = predict 39 | self.predict_dataloader = self._predict_dataloader 40 | self.wrap = wrap 41 | 42 | def prepare_data(self): 43 | for data_cfg in self.dataset_configs.values(): 44 | instantiate_from_config(data_cfg) 45 | 46 | def setup(self, stage=None): 47 | self.datasets = dict( 48 | (k, instantiate_from_config(self.dataset_configs[k])) 49 | for k in self.dataset_configs) 50 | if self.wrap: 51 | for k in self.datasets: 52 | self.datasets[k] = WrappedDataset(self.datasets[k]) 53 | 54 | def _train_dataloader(self): 55 | return 
DataLoader(self.datasets["train"], batch_size=self.batch_size, 56 | num_workers=self.num_workers, shuffle=True, collate_fn=vg_collate_fn) 57 | 58 | def _val_dataloader(self, shuffle=False): 59 | return DataLoader(self.datasets["validation"], batch_size=self.batch_size, 60 | num_workers=self.num_workers, shuffle=shuffle, collate_fn=vg_collate_fn) 61 | 62 | def _test_dataloader(self, shuffle=False): 63 | return DataLoader(self.datasets["test"], batch_size=self.batch_size, 64 | num_workers=self.num_workers, shuffle=shuffle, collate_fn=vg_collate_fn) 65 | 66 | def _predict_dataloader(self, shuffle=False): 67 | return DataLoader(self.datasets["predict"], batch_size=self.batch_size, 68 | num_workers=self.num_workers, collate_fn=vg_collate_fn) 69 | -------------------------------------------------------------------------------- /config_coco.yaml: -------------------------------------------------------------------------------- 1 | model: 2 | base_learning_rate: 1.0e-06 3 | target: ldm.models.diffusion.ddpm.LatentDiffusion 4 | params: 5 | linear_start: 0.0015 6 | linear_end: 0.0195 7 | num_timesteps_cond: 1 8 | log_every_t: 200 9 | timesteps: 1000 10 | first_stage_key: image 11 | cond_stage_key: class_label 12 | image_size: 32 13 | channels: 4 14 | cond_stage_trainable: false 15 | conditioning_key: crossattn 16 | cond_stage_forward: encode_graph_local_global 17 | monitor: val/loss_simple_ema 18 | unet_config: 19 | target: ldm.modules.diffusionmodules.openaimodel.UNetModel 20 | params: 21 | image_size: 32 22 | in_channels: 4 23 | out_channels: 4 24 | model_channels: 256 25 | attention_resolutions: 26 | - 4 27 | - 2 28 | - 1 29 | num_res_blocks: 2 30 | channel_mult: 31 | - 1 32 | - 2 33 | - 4 34 | num_head_channels: 32 35 | use_spatial_transformer: true 36 | transformer_depth: 1 37 | context_local_dim: 1536 38 | context_dim: 512 39 | first_stage_config: 40 | target: ldm.models.autoencoder.VQModelInterface 41 | params: 42 | embed_dim: 4 43 | n_embed: 16384 44 | ckpt_path: pretrained/vq-f8-model.ckpt 45 | ddconfig: 46 | double_z: false 47 | z_channels: 4 48 | resolution: 256 49 | in_channels: 3 50 | out_ch: 3 51 | ch: 128 52 | ch_mult: 53 | - 1 54 | - 2 55 | - 2 56 | - 4 57 | num_res_blocks: 2 58 | attn_resolutions: 59 | - 32 60 | dropout: 0.0 61 | lossconfig: 62 | target: torch.nn.Identity 63 | cond_stage_config: 64 | target: ldm.modules.cgip.cgip.CGIPModel 65 | params: 66 | num_objs: 184 67 | num_preds: 7 68 | layers: 5 69 | width: 512 70 | embed_dim: 512 71 | ckpt_path: pretrained/sip_coco.pt 72 | data: 73 | target: scripts.datamodule.DataModuleFromConfig 74 | params: 75 | batch_size: 16 76 | num_workers: 4 77 | wrap: false 78 | train: 79 | target: ldm.data.coco.COCOTrain 80 | params: 81 | image_dir: ./datasets/coco/images/train2017 82 | instances_json: ./datasets/coco/annotations/instances_train2017.json 83 | stuff_json: ./datasets/coco/annotations/stuff_train2017.json 84 | stuff_only: True 85 | image_size: 256 86 | validation: 87 | target: ldm.data.coco.COCOValidation 88 | params: 89 | image_dir: ./datasets/coco/images/val2017 90 | instances_json: ./datasets/coco/annotations/instances_val2017.json 91 | stuff_json: ./datasets/coco/annotations/stuff_val2017.json 92 | stuff_only: True 93 | image_size: 256 94 | 95 | lightning: 96 | callbacks: 97 | image_logger: 98 | target: trainer.ImageLogger 99 | params: 100 | batch_frequency: 10000 101 | max_images: 16 102 | increase_log_steps: False 103 | trainer: 104 | benchmark: True 
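# Annotation (not part of the original config): the two checkpoint paths above are expected under
# ./pretrained -- vq-f8-model.ckpt is the VQ-f8 autoencoder (first_stage_config) and sip_coco.pt is
# the masked contrastive SG-image model for COCO (cond_stage_config), both linked in README.md.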
-------------------------------------------------------------------------------- /ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 
74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 * np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 
81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Diffusion-Based Scene Graph to Image Generation with Masked Contrastive Pre-Training 2 | 3 | 4 | Official implementation of [Diffusion-Based Scene Graph to Image Generation with Masked Contrastive Pre-Training](https://arxiv.org/abs/2211.11138). 5 | 6 | 🚩 New Updates: We release [LAION-SG](https://arxiv.org/abs/2412.08580), [a large-scale dataset](https://huggingface.co/datasets/mengcy/LAION-SG) with high-quality structural scene graph (SG) annotations that precisely describe the attributes and relationships of multiple objects, effectively representing the semantic structure of complex scenes. Based on LAION-SG, we also provide a new foundation model, [SDXL-SG](https://drive.google.com/file/d/1mdC3Np4KkV9V24K1gcyddsG5AIv5S0MT/view?usp=sharing), which incorporates structural annotations into the generation process. 7 | ## Overview of the Proposed SGDiff 8 | 9 |
image
10 | 11 | 12 | 13 | 14 | ## Environment 15 | ``` 16 | git clone https://github.com/YangLing0818/SGDiff.git 17 | cd SGDiff 18 | 19 | conda env create -f sgdiff.yaml 20 | conda activate sgdiff 21 | mkdir pretrained 22 | ``` 23 | 24 | 25 | ## Data and Model Preparation 26 | 27 | The instructions for data pre-processing can be [found here](https://github.com/YangLing0818/SGDiff/blob/main/DATA.md). 28 | 29 | Our masked contrastive pre-trained models on SG-image pairs for the COCO and VG datasets are provided [here](https://www.dropbox.com/scl/fo/lccvtxuwxxblo3atnxlmg/h?rlkey=duy7dcwmy3a64auqoqiw8dv2e&dl=0); please download them and put them in the 'pretrained' directory. 30 | 31 | The pretrained VQVAE for encoding images into the latent space can be obtained from https://ommer-lab.com/files/latent-diffusion/vq-f8.zip 32 | 33 | ## Masked Contrastive Pre-Training 34 | 35 | The instructions for SG-image pre-training can be found in the folder "sg_image_pretraining/". 36 | 37 | ## Diffusion Training 38 | Kindly note that one **should not skip the training stage** and test directly. For a single GPU, one can use 39 | ```shell 40 | python trainer.py --base CONFIG_PATH -t --gpus 0, 41 | ``` 42 | 43 | ***NOT OFFICIAL:*** 44 | Alternatively, if you do not want to train the model from scratch, you can download trained weights from the following links: 45 | [VG weight](https://drive.google.com/file/d/1bzYgv_lmCUL7wrh9G3t3169ITbAuMbYo/view?usp=sharing), [COCO weight](https://drive.google.com/file/d/1HAj2C3xHTrm-txVCq_cSSbr5NvFPnasR/view?usp=sharing) 46 | 47 | These checkpoints were trained for only 150 epochs. 48 | 49 | ## Sampling 50 | 51 | ```shell 52 | python testset_ddim_sampler.py 53 | ``` 54 | 55 | ## Citation 56 | If you find the code useful, please cite our papers: 57 | ``` 58 | @article{yang2022diffusionsg, 59 | title={Diffusion-based scene graph to image generation with masked contrastive pre-training}, 60 | author={Yang, Ling and Huang, Zhilin and Song, Yang and Hong, Shenda and Li, Guohao and Zhang, Wentao and Cui, Bin and Ghanem, Bernard and Yang, Ming-Hsuan}, 61 | journal={arXiv preprint arXiv:2211.11138}, 62 | year={2022} 63 | } 64 | 65 | @article{li2024laion, 66 | title={LAION-SG: An Enhanced Large-Scale Dataset for Training Complex Image-Text Models with Structural Annotations}, 67 | author={Li, Zejian and Meng, Chenye and Li, Yize and Yang, Ling and Zhang, Shengyuan and Ma, Jiarui and Li, Jiayi and Yang, Guang and Yang, Changyuan and Yang, Zhiyuan and others}, 68 | journal={arXiv preprint arXiv:2412.08580}, 69 | year={2024} 70 | } 71 | ``` 72 | -------------------------------------------------------------------------------- /sg2im/data/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License.
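# This module provides image pre-/de-processing helpers for the sg2im data pipeline: ImageNet-normalization transforms (imagenet_preprocess / imagenet_deprocess), batch deprocessing back to uint8 images in [0, 255], a PIL-based Resize transform, and split_graph_batch for splitting a batched scene graph back into per-image triples and object data.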
16 | 17 | import PIL 18 | import torch 19 | import torchvision.transforms as T 20 | 21 | 22 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 23 | IMAGENET_STD = [0.229, 0.224, 0.225] 24 | 25 | INV_IMAGENET_MEAN = [-m for m in IMAGENET_MEAN] 26 | INV_IMAGENET_STD = [1.0 / s for s in IMAGENET_STD] 27 | 28 | 29 | def imagenet_preprocess(): 30 | return T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) 31 | 32 | 33 | def rescale(x): 34 | lo, hi = x.min(), x.max() 35 | return x.sub(lo).div(hi - lo) 36 | 37 | 38 | def imagenet_deprocess(rescale_image=True): 39 | transforms = [ 40 | T.Normalize(mean=[0, 0, 0], std=INV_IMAGENET_STD), 41 | T.Normalize(mean=INV_IMAGENET_MEAN, std=[1.0, 1.0, 1.0]), 42 | ] 43 | if rescale_image: 44 | transforms.append(rescale) 45 | return T.Compose(transforms) 46 | 47 | 48 | def imagenet_deprocess_batch(imgs, rescale=True): 49 | """ 50 | Input: 51 | - imgs: FloatTensor of shape (N, C, H, W) giving preprocessed images 52 | 53 | Output: 54 | - imgs_de: ByteTensor of shape (N, C, H, W) giving deprocessed images 55 | in the range [0, 255] 56 | """ 57 | if isinstance(imgs, torch.autograd.Variable): 58 | imgs = imgs.data 59 | imgs = imgs.cpu().clone() 60 | deprocess_fn = imagenet_deprocess(rescale_image=rescale) 61 | imgs_de = [] 62 | for i in range(imgs.size(0)): 63 | img_de = deprocess_fn(imgs[i])[None] 64 | img_de = img_de.mul(255).clamp(0, 255).byte() 65 | imgs_de.append(img_de) 66 | imgs_de = torch.cat(imgs_de, dim=0) 67 | return imgs_de 68 | 69 | 70 | class Resize(object): 71 | def __init__(self, size, interp=PIL.Image.BILINEAR): 72 | if isinstance(size, tuple): 73 | H, W = size 74 | self.size = (W, H) 75 | else: 76 | self.size = (size, size) 77 | self.interp = interp 78 | 79 | def __call__(self, img): 80 | return img.resize(self.size, self.interp) 81 | 82 | 83 | def unpack_var(v): 84 | if isinstance(v, torch.autograd.Variable): 85 | return v.data 86 | return v 87 | 88 | 89 | def split_graph_batch(triples, obj_data, obj_to_img, triple_to_img): 90 | triples = unpack_var(triples) 91 | obj_data = [unpack_var(o) for o in obj_data] 92 | obj_to_img = unpack_var(obj_to_img) 93 | triple_to_img = unpack_var(triple_to_img) 94 | 95 | triples_out = [] 96 | obj_data_out = [[] for _ in obj_data] 97 | obj_offset = 0 98 | N = obj_to_img.max() + 1 99 | for i in range(N): 100 | o_idxs = (obj_to_img == i).nonzero().view(-1) 101 | t_idxs = (triple_to_img == i).nonzero().view(-1) 102 | 103 | cur_triples = triples[t_idxs].clone() 104 | cur_triples[:, 0] -= obj_offset 105 | cur_triples[:, 2] -= obj_offset 106 | triples_out.append(cur_triples) 107 | 108 | for j, o_data in enumerate(obj_data): 109 | cur_o_data = None 110 | if o_data is not None: 111 | cur_o_data = o_data[o_idxs] 112 | obj_data_out[j].append(cur_o_data) 113 | 114 | obj_offset += o_idxs.size(0) 115 | 116 | return triples_out, obj_data_out 117 | 118 | -------------------------------------------------------------------------------- /sg_image_pretraining/datasets/transform.py: -------------------------------------------------------------------------------- 1 | from typing import Optional, Sequence, Tuple 2 | import torch 3 | import torch.nn as nn 4 | import torchvision.transforms.functional as F 5 | from torchvision.transforms import Normalize, Compose, RandomResizedCrop, InterpolationMode, ToTensor, Resize, CenterCrop 6 | from sgCLIP.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD 7 | import PIL 8 | from global_var import * 9 | 10 | class VGResize(object): 11 | def __init__(self, size, interp=PIL.Image.BILINEAR): 12 | if 
isinstance(size, tuple): 13 | H, W = size 14 | self.size = (W, H) 15 | else: 16 | self.size = (size, size) 17 | self.interp = interp 18 | 19 | def __call__(self, img): 20 | return img.resize(self.size, self.interp) 21 | 22 | class COCOResize(object): 23 | def __init__(self, size, interp=PIL.Image.BILINEAR): 24 | if isinstance(size, tuple): 25 | H, W = size 26 | self.size = (W, H) 27 | else: 28 | self.size = (size, size) 29 | self.interp = interp 30 | 31 | def __call__(self, img): 32 | return img.resize(self.size, self.interp) 33 | 34 | class ResizeMaxSize(nn.Module): 35 | 36 | def __init__(self, max_size, interpolation=InterpolationMode.BICUBIC, fn='max', fill=0): 37 | super().__init__() 38 | if not isinstance(max_size, int): 39 | raise TypeError(f"Size should be int. Got {type(max_size)}") 40 | self.max_size = max_size 41 | self.interpolation = interpolation 42 | self.fn = min if fn == 'min' else min 43 | self.fill = fill 44 | 45 | def forward(self, img): 46 | if isinstance(img, torch.Tensor): 47 | height, width = img.shape[:2] 48 | else: 49 | width, height = img.size 50 | scale = self.max_size / float(max(height, width)) 51 | if scale != 1.0: 52 | new_size = tuple(round(dim * scale) for dim in (height, width)) 53 | img = F.resize(img, new_size, self.interpolation) 54 | pad_h = self.max_size - new_size[0] 55 | pad_w = self.max_size - new_size[1] 56 | img = F.pad(img, padding=[pad_w//2, pad_h//2, pad_w - pad_w//2, pad_h - pad_h//2], fill=self.fill) 57 | return img 58 | 59 | 60 | def _convert_to_rgb(image): 61 | return image.convert('RGB') 62 | 63 | 64 | def image_transform( 65 | image_size: int, 66 | is_train: bool, 67 | mean: Optional[Tuple[float, ...]] = None, 68 | std: Optional[Tuple[float, ...]] = None, 69 | resize_longest_max: bool = False, 70 | fill_color: int = 0, 71 | ): 72 | mean = mean or OPENAI_DATASET_MEAN 73 | if not isinstance(mean, (list, tuple)): 74 | mean = (mean,) * 3 75 | 76 | std = std or OPENAI_DATASET_STD 77 | if not isinstance(std, (list, tuple)): 78 | std = (std,) * 3 79 | 80 | if isinstance(image_size, (list, tuple)) and image_size[0] == image_size[1]: 81 | image_size = image_size[0] 82 | 83 | normalize = Normalize(mean=mean, std=std) 84 | if is_train: 85 | return Compose([ 86 | RandomResizedCrop(image_size, scale=(0.9, 1.0), interpolation=InterpolationMode.BICUBIC), 87 | _convert_to_rgb, 88 | ToTensor(), 89 | normalize, 90 | ]) 91 | else: 92 | if resize_longest_max: 93 | transforms = [ 94 | ResizeMaxSize(image_size, fill=fill_color) 95 | ] 96 | else: 97 | transforms = [ 98 | Resize(image_size, interpolation=InterpolationMode.BICUBIC), 99 | CenterCrop(image_size), 100 | ] 101 | transforms.extend([ 102 | _convert_to_rgb, 103 | ToTensor(), 104 | normalize, 105 | ]) 106 | return Compose(transforms) 107 | -------------------------------------------------------------------------------- /ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 
15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /sg_image_pretraining/clip/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import html 3 | import os 4 | from functools import lru_cache 5 | 6 | import ftfy 7 | import regex as re 8 | 9 | 10 | @lru_cache() 11 | def default_bpe(): 12 | return 
os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 13 | 14 | 15 | @lru_cache() 16 | def bytes_to_unicode(): 17 | """ 18 | Returns list of utf-8 byte and a corresponding list of unicode strings. 19 | The reversible bpe codes work on unicode strings. 20 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 21 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 22 | This is a signficant percentage of your normal, say, 32K bpe vocab. 23 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 24 | And avoids mapping to whitespace/control characters the bpe code barfs on. 25 | """ 26 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 27 | cs = bs[:] 28 | n = 0 29 | for b in range(2**8): 30 | if b not in bs: 31 | bs.append(b) 32 | cs.append(2**8+n) 33 | n += 1 34 | cs = [chr(n) for n in cs] 35 | return dict(zip(bs, cs)) 36 | 37 | 38 | def get_pairs(word): 39 | """Return set of symbol pairs in a word. 40 | Word is represented as tuple of symbols (symbols being variable-length strings). 41 | """ 42 | pairs = set() 43 | prev_char = word[0] 44 | for char in word[1:]: 45 | pairs.add((prev_char, char)) 46 | prev_char = char 47 | return pairs 48 | 49 | 50 | def basic_clean(text): 51 | text = ftfy.fix_text(text) 52 | text = html.unescape(html.unescape(text)) 53 | return text.strip() 54 | 55 | 56 | def whitespace_clean(text): 57 | text = re.sub(r'\s+', ' ', text) 58 | text = text.strip() 59 | return text 60 | 61 | 62 | class SimpleTokenizer(object): 63 | def __init__(self, bpe_path: str = default_bpe()): 64 | self.byte_encoder = bytes_to_unicode() 65 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 66 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 67 | merges = merges[1:49152-256-2+1] 68 | merges = [tuple(merge.split()) for merge in merges] 69 | vocab = list(bytes_to_unicode().values()) 70 | vocab = vocab + [v+'' for v in vocab] 71 | for merge in merges: 72 | vocab.append(''.join(merge)) 73 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 74 | self.encoder = dict(zip(vocab, range(len(vocab)))) 75 | self.decoder = {v: k for k, v in self.encoder.items()} 76 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 77 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 78 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 79 | 80 | def bpe(self, token): 81 | if token in self.cache: 82 | return self.cache[token] 83 | word = tuple(token[:-1]) + ( token[-1] + '',) 84 | pairs = get_pairs(word) 85 | 86 | if not pairs: 87 | return token+'' 88 | 89 | while True: 90 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 91 | if bigram not in self.bpe_ranks: 92 | break 93 | first, second = bigram 94 | new_word = [] 95 | i = 0 96 | while i < len(word): 97 | try: 98 | j = word.index(first, i) 99 | new_word.extend(word[i:j]) 100 | i = j 101 | except: 102 | new_word.extend(word[i:]) 103 | break 104 | 105 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 106 | new_word.append(first+second) 107 | i += 2 108 | else: 109 | new_word.append(word[i]) 110 | i += 1 111 | new_word = tuple(new_word) 112 | word = new_word 113 | if len(word) == 1: 114 | break 115 | else: 116 | pairs = get_pairs(word) 117 | word = ' 
'.join(word) 118 | self.cache[token] = word 119 | return word 120 | 121 | def encode(self, text): 122 | bpe_tokens = [] 123 | text = whitespace_clean(basic_clean(text)).lower() 124 | for token in re.findall(self.pat, text): 125 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 126 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 127 | return bpe_tokens 128 | 129 | def decode(self, tokens): 130 | text = ''.join([self.decoder[token] for token in tokens]) 131 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 132 | return text 133 | -------------------------------------------------------------------------------- /ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx 
== 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /sgdiff.yaml: -------------------------------------------------------------------------------- 1 | name: sgdiff 2 | channels: 3 | - pytorch 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge 6 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/ 7 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 8 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 9 | - defaults 10 | dependencies: 11 | - _libgcc_mutex=0.1=conda_forge 12 | - _openmp_mutex=4.5=2_gnu 13 | - aom=3.5.0=h27087fc_0 14 | - blas=1.0=mkl 15 | - brotlipy=0.7.0=py37h540881e_1004 16 | - bzip2=1.0.8=h7f98852_4 17 | - ca-certificates=2022.9.24=ha878542_0 18 | - certifi=2022.9.24=pyhd8ed1ab_0 19 | - cffi=1.15.1=py37h43b0acd_1 20 | - charset-normalizer=2.1.1=pyhd8ed1ab_0 21 | - cryptography=38.0.2=py37h5994e8b_1 22 | - cudatoolkit=11.3.1=h2bc3f7f_2 23 | - expat=2.5.0=h27087fc_0 24 | - ffmpeg=5.1.2=gpl_hc51e5dc_103 25 | - font-ttf-dejavu-sans-mono=2.37=hab24e00_0 26 | - font-ttf-inconsolata=3.000=h77eed37_0 27 | - font-ttf-source-code-pro=2.038=h77eed37_0 28 | - font-ttf-ubuntu=0.83=hab24e00_0 29 | - fontconfig=2.14.1=hc2a2eb6_0 30 | - fonts-conda-ecosystem=1=0 31 | - 
fonts-conda-forge=1=0 32 | - freetype=2.12.1=hca18f0e_0 33 | - gettext=0.21.1=h27087fc_0 34 | - gmp=6.2.1=h58526e2_0 35 | - gnutls=3.7.8=hf3e180e_0 36 | - icu=70.1=h27087fc_0 37 | - idna=3.4=pyhd8ed1ab_0 38 | - intel-openmp=2021.4.0=h06a4308_3561 39 | - jpeg=9e=h166bdaf_2 40 | - lame=3.100=h166bdaf_1003 41 | - lcms2=2.12=hddcbb42_0 42 | - ld_impl_linux-64=2.39=hc81fddc_0 43 | - lerc=4.0.0=h27087fc_0 44 | - libdeflate=1.14=h166bdaf_0 45 | - libdrm=2.4.113=h166bdaf_0 46 | - libffi=3.4.2=h7f98852_5 47 | - libgcc-ng=12.2.0=h65d4601_19 48 | - libgomp=12.2.0=h65d4601_19 49 | - libiconv=1.17=h166bdaf_0 50 | - libidn2=2.3.4=h166bdaf_0 51 | - libnsl=2.0.0=h7f98852_0 52 | - libpciaccess=0.16=h516909a_0 53 | - libpng=1.6.38=h753d276_0 54 | - libsqlite=3.39.4=h753d276_0 55 | - libstdcxx-ng=12.2.0=h46fd767_19 56 | - libtasn1=4.19.0=h166bdaf_0 57 | - libtiff=4.4.0=h55922b4_4 58 | - libunistring=0.9.10=h7f98852_0 59 | - libuuid=2.32.1=h7f98852_1000 60 | - libva=2.16.0=h166bdaf_0 61 | - libvpx=1.11.0=h9c3ff4c_3 62 | - libwebp-base=1.2.4=h166bdaf_0 63 | - libxcb=1.13=h7f98852_1004 64 | - libxml2=2.10.3=h7463322_0 65 | - libzlib=1.2.13=h166bdaf_4 66 | - mkl=2021.4.0=h06a4308_640 67 | - mkl-service=2.4.0=py37h402132d_0 68 | - mkl_fft=1.3.1=py37h3e078e5_1 69 | - mkl_random=1.2.2=py37h219a48f_0 70 | - ncurses=6.3=h27087fc_1 71 | - nettle=3.8.1=hc379101_1 72 | - numpy=1.21.5=py37h6c91a56_3 73 | - numpy-base=1.21.5=py37ha15fc14_3 74 | - openh264=2.3.1=h27087fc_1 75 | - openjpeg=2.5.0=h7d73246_1 76 | - openssl=3.0.7=h166bdaf_0 77 | - p11-kit=0.24.1=hc5aa10d_0 78 | - pillow=9.2.0=py37h850a105_2 79 | - pip=22.3=pyhd8ed1ab_0 80 | - pthread-stubs=0.4=h36c2ea0_1001 81 | - pycparser=2.21=pyhd8ed1ab_0 82 | - pyopenssl=22.1.0=pyhd8ed1ab_0 83 | - pysocks=1.7.1=py37h89c1867_5 84 | - python=3.7.12=hf930737_100_cpython 85 | - python_abi=3.7=2_cp37m 86 | - pytorch=1.12.1=py3.7_cuda11.3_cudnn8.3.2_0 87 | - pytorch-mutex=1.0=cuda 88 | - readline=8.1.2=h0f457ee_0 89 | - requests=2.28.1=pyhd8ed1ab_1 90 | - setuptools=65.5.0=pyhd8ed1ab_0 91 | - six=1.16.0=pyh6c4a22f_0 92 | - sqlite=3.39.4=h4ff8645_0 93 | - svt-av1=1.3.0=h27087fc_0 94 | - tk=8.6.12=h27826a3_0 95 | - torchaudio=0.12.1=py37_cu113 96 | - torchvision=0.13.1=py37_cu113 97 | - typing_extensions=4.4.0=pyha770c72_0 98 | - urllib3=1.26.11=pyhd8ed1ab_0 99 | - wheel=0.37.1=pyhd8ed1ab_0 100 | - x264=1!164.3095=h166bdaf_2 101 | - x265=3.5=h924138e_3 102 | - xorg-fixesproto=5.0=h7f98852_1002 103 | - xorg-kbproto=1.0.7=h7f98852_1002 104 | - xorg-libx11=1.7.2=h7f98852_0 105 | - xorg-libxau=1.0.9=h7f98852_0 106 | - xorg-libxdmcp=1.1.3=h7f98852_0 107 | - xorg-libxext=1.3.4=h7f98852_1 108 | - xorg-libxfixes=5.0.3=h7f98852_1004 109 | - xorg-xextproto=7.3.0=h7f98852_1002 110 | - xorg-xproto=7.0.31=h7f98852_1007 111 | - xz=5.2.6=h166bdaf_0 112 | - zstd=1.5.2=h6239696_4 113 | - pip: 114 | - absl-py==1.3.0 115 | - aiohttp==3.8.3 116 | - aiosignal==1.2.0 117 | - antlr4-python3-runtime==4.9.3 118 | - async-timeout==4.0.2 119 | - asynctest==0.13.0 120 | - attrs==22.1.0 121 | - blobfile==2.0.0 122 | - cachetools==5.2.0 123 | - cycler==0.11.0 124 | - einops==0.5.0 125 | - filelock==3.8.0 126 | - fonttools==4.38.0 127 | - frozenlist==1.3.1 128 | - fsspec==2022.8.2 129 | - ftfy==6.1.1 130 | - future==0.18.2 131 | - google-auth==2.13.0 132 | - google-auth-oauthlib==0.4.6 133 | - grpcio==1.49.1 134 | - h5py==3.7.0 135 | - huggingface-hub==0.10.1 136 | - imageio==2.22.2 137 | - importlib-metadata==5.0.0 138 | - kiwisolver==1.4.4 139 | - lxml==4.9.1 140 | - markdown==3.4.1 141 | - markupsafe==2.1.1 
142 | - matplotlib==3.5.3 143 | - multidict==6.0.2 144 | - networkx==2.6.3 145 | - oauthlib==3.2.1 146 | - omegaconf==2.2.3 147 | - packaging==21.3 148 | - pandas==1.3.5 149 | - protobuf==3.19.6 150 | - pyasn1==0.4.8 151 | - pyasn1-modules==0.2.8 152 | - pycocotools==2.0.5 153 | - pycryptodomex==3.15.0 154 | - pydeprecate==0.3.1 155 | - pyparsing==3.0.9 156 | - python-dateutil==2.8.2 157 | - pytorch-lightning==1.4.2 158 | - pytz==2022.5 159 | - pywavelets==1.3.0 160 | - pyyaml==6.0 161 | - regex==2022.9.13 162 | - requests-oauthlib==1.3.1 163 | - rsa==4.9 164 | - scikit-image==0.19.3 165 | - scipy==1.7.3 166 | - tensorboard==2.10.1 167 | - tensorboard-data-server==0.6.1 168 | - tensorboard-plugin-wit==1.8.1 169 | - test-tube==0.7.5 170 | - tifffile==2021.11.2 171 | - timm==0.6.11 172 | - torchmetrics==0.5.0 173 | - tqdm==4.64.1 174 | - wcwidth==0.2.5 175 | - werkzeug==2.2.2 176 | - yarl==1.8.1 177 | - zipp==3.9.0 178 | - taming-transformers==0.0.1 179 | - transformers==4.3.1 180 | -------------------------------------------------------------------------------- /scripts/train_searcher.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | import scann 4 | import argparse 5 | import glob 6 | from multiprocessing import cpu_count 7 | from tqdm import tqdm 8 | 9 | from ldm.util import parallel_data_prefetch 10 | 11 | 12 | def search_bruteforce(searcher): 13 | return searcher.score_brute_force().build() 14 | 15 | 16 | def search_partioned_ah(searcher, dims_per_block, aiq_threshold, reorder_k, 17 | partioning_trainsize, num_leaves, num_leaves_to_search): 18 | return searcher.tree(num_leaves=num_leaves, 19 | num_leaves_to_search=num_leaves_to_search, 20 | training_sample_size=partioning_trainsize). 
\ 21 | score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder(reorder_k).build() 22 | 23 | 24 | def search_ah(searcher, dims_per_block, aiq_threshold, reorder_k): 25 | return searcher.score_ah(dims_per_block, anisotropic_quantization_threshold=aiq_threshold).reorder( 26 | reorder_k).build() 27 | 28 | def load_datapool(dpath): 29 | 30 | 31 | def load_single_file(saved_embeddings): 32 | compressed = np.load(saved_embeddings) 33 | database = {key: compressed[key] for key in compressed.files} 34 | return database 35 | 36 | def load_multi_files(data_archive): 37 | database = {key: [] for key in data_archive[0].files} 38 | for d in tqdm(data_archive, desc=f'Loading datapool from {len(data_archive)} individual files.'): 39 | for key in d.files: 40 | database[key].append(d[key]) 41 | 42 | return database 43 | 44 | print(f'Load saved patch embedding from "{dpath}"') 45 | file_content = glob.glob(os.path.join(dpath, '*.npz')) 46 | 47 | if len(file_content) == 1: 48 | data_pool = load_single_file(file_content[0]) 49 | elif len(file_content) > 1: 50 | data = [np.load(f) for f in file_content] 51 | prefetched_data = parallel_data_prefetch(load_multi_files, data, 52 | n_proc=min(len(data), cpu_count()), target_data_type='dict') 53 | 54 | data_pool = {key: np.concatenate([od[key] for od in prefetched_data], axis=1)[0] for key in prefetched_data[0].keys()} 55 | else: 56 | raise ValueError(f'No npz-files in specified path "{dpath}" is this directory existing?') 57 | 58 | print(f'Finished loading of retrieval database of length {data_pool["embedding"].shape[0]}.') 59 | return data_pool 60 | 61 | 62 | def train_searcher(opt, 63 | metric='dot_product', 64 | partioning_trainsize=None, 65 | reorder_k=None, 66 | # todo tune 67 | aiq_thld=0.2, 68 | dims_per_block=2, 69 | num_leaves=None, 70 | num_leaves_to_search=None,): 71 | 72 | data_pool = load_datapool(opt.database) 73 | k = opt.knn 74 | 75 | if not reorder_k: 76 | reorder_k = 2 * k 77 | 78 | # normalize 79 | # embeddings = 80 | searcher = scann.scann_ops_pybind.builder(data_pool['embedding'] / np.linalg.norm(data_pool['embedding'], axis=1)[:, np.newaxis], k, metric) 81 | pool_size = data_pool['embedding'].shape[0] 82 | 83 | print(*(['#'] * 100)) 84 | print('Initializing scaNN searcher with the following values:') 85 | print(f'k: {k}') 86 | print(f'metric: {metric}') 87 | print(f'reorder_k: {reorder_k}') 88 | print(f'anisotropic_quantization_threshold: {aiq_thld}') 89 | print(f'dims_per_block: {dims_per_block}') 90 | print(*(['#'] * 100)) 91 | print('Start training searcher....') 92 | print(f'N samples in pool is {pool_size}') 93 | 94 | # this reflects the recommended design choices proposed at 95 | # https://github.com/google-research/google-research/blob/aca5f2e44e301af172590bb8e65711f0c9ee0cfd/scann/docs/algorithms.md 96 | if pool_size < 2e4: 97 | print('Using brute force search.') 98 | searcher = search_bruteforce(searcher) 99 | elif 2e4 <= pool_size and pool_size < 1e5: 100 | print('Using asymmetric hashing search and reordering.') 101 | searcher = search_ah(searcher, dims_per_block, aiq_thld, reorder_k) 102 | else: 103 | print('Using using partioning, asymmetric hashing search and reordering.') 104 | 105 | if not partioning_trainsize: 106 | partioning_trainsize = data_pool['embedding'].shape[0] // 10 107 | if not num_leaves: 108 | num_leaves = int(np.sqrt(pool_size)) 109 | 110 | if not num_leaves_to_search: 111 | num_leaves_to_search = max(num_leaves // 20, 1) 112 | 113 | print('Partitioning params:') 114 | 
print(f'num_leaves: {num_leaves}') 115 | print(f'num_leaves_to_search: {num_leaves_to_search}') 116 | # self.searcher = self.search_ah(searcher, dims_per_block, aiq_thld, reorder_k) 117 | searcher = search_partioned_ah(searcher, dims_per_block, aiq_thld, reorder_k, 118 | partioning_trainsize, num_leaves, num_leaves_to_search) 119 | 120 | print('Finish training searcher') 121 | searcher_savedir = opt.target_path 122 | os.makedirs(searcher_savedir, exist_ok=True) 123 | searcher.serialize(searcher_savedir) 124 | print(f'Saved trained searcher under "{searcher_savedir}"') 125 | 126 | if __name__ == '__main__': 127 | sys.path.append(os.getcwd()) 128 | parser = argparse.ArgumentParser() 129 | parser.add_argument('--database', 130 | '-d', 131 | default='data/rdm/retrieval_databases/openimages', 132 | type=str, 133 | help='path to folder containing the clip feature of the database') 134 | parser.add_argument('--target_path', 135 | '-t', 136 | default='data/rdm/searchers/openimages', 137 | type=str, 138 | help='path to the target folder where the searcher shall be stored.') 139 | parser.add_argument('--knn', 140 | '-k', 141 | default=20, 142 | type=int, 143 | help='number of nearest neighbors, for which the searcher shall be optimized') 144 | 145 | opt, _ = parser.parse_known_args() 146 | 147 | train_searcher(opt,) -------------------------------------------------------------------------------- /ldm/data/vg.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import PIL 4 | from PIL import Image 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | import torch 8 | import random 9 | import h5py 10 | import json 11 | 12 | class VGDatabase(Dataset): 13 | def __init__(self, vocab, h5_path, image_dir, image_size=256, max_objects=10, max_samples=None, 14 | include_relationships=True, use_orphaned_objects=True): 15 | with open(vocab, 'r') as f: 16 | vocab = json.load(f) 17 | self.image_dir = image_dir 18 | self.image_size = (image_size, image_size) 19 | self.vocab = vocab 20 | self.num_objects = len(vocab['object_idx_to_name']) 21 | self.use_orphaned_objects = use_orphaned_objects 22 | self.max_objects = max_objects 23 | self.max_samples = max_samples 24 | self.include_relationships = include_relationships 25 | 26 | transform = [Resize(self.image_size), transforms.ToTensor()] # augmentation 27 | self.transform = transforms.Compose(transform) 28 | 29 | self.data = {} 30 | with h5py.File(h5_path, 'r') as f: 31 | for k, v in f.items(): 32 | if k == 'image_paths': 33 | self.image_paths = list(v) 34 | else: 35 | self.data[k] = torch.IntTensor(np.asarray(v)) 36 | 37 | def __len__(self): 38 | num = self.data['object_names'].size(0) 39 | if self.max_samples is not None: 40 | return min(self.max_samples, num) 41 | return num 42 | 43 | def __getitem__(self, index): 44 | img_path = os.path.join(self.image_dir, str(self.image_paths[index], encoding="utf-8")) 45 | 46 | with open(img_path, 'rb') as f: 47 | with PIL.Image.open(f) as image: 48 | WW, HH = image.size 49 | image = self.transform(image.convert('RGB')) 50 | 51 | image = image * 2 - 1 52 | 53 | obj_idxs_with_rels = set() 54 | obj_idxs_without_rels = set(range(self.data['objects_per_image'][index].item())) 55 | for r_idx in range(self.data['relationships_per_image'][index]): 56 | s = self.data['relationship_subjects'][index, r_idx].item() 57 | o = self.data['relationship_objects'][index, r_idx].item() 58 | obj_idxs_with_rels.add(s) 59 | 
obj_idxs_with_rels.add(o) 60 | obj_idxs_without_rels.discard(s) 61 | obj_idxs_without_rels.discard(o) 62 | 63 | obj_idxs = list(obj_idxs_with_rels) 64 | obj_idxs_without_rels = list(obj_idxs_without_rels) 65 | if len(obj_idxs) > self.max_objects - 1: 66 | obj_idxs = random.sample(obj_idxs, self.max_objects) 67 | if len(obj_idxs) < self.max_objects - 1 and self.use_orphaned_objects: 68 | num_to_add = self.max_objects - 1 - len(obj_idxs) 69 | num_to_add = min(num_to_add, len(obj_idxs_without_rels)) 70 | obj_idxs += random.sample(obj_idxs_without_rels, num_to_add) 71 | O = len(obj_idxs) + 1 72 | 73 | objs = torch.LongTensor(O).fill_(-1) 74 | 75 | boxes = torch.FloatTensor([[0, 0, 1, 1]]).repeat(O, 1) 76 | obj_idx_mapping = {} 77 | for i, obj_idx in enumerate(obj_idxs): 78 | objs[i] = self.data['object_names'][index, obj_idx].item() 79 | x, y, w, h = self.data['object_boxes'][index, obj_idx].tolist() 80 | x0 = float(x) / WW 81 | y0 = float(y) / HH 82 | x1 = float(x + w) / WW 83 | y1 = float(y + h) / HH 84 | boxes[i] = torch.FloatTensor([x0, y0, x1, y1]) 85 | obj_idx_mapping[obj_idx] = i 86 | 87 | objs[O - 1] = self.vocab['object_name_to_idx']['__image__'] 88 | 89 | triples = [] 90 | for r_idx in range(self.data['relationships_per_image'][index].item()): 91 | if not self.include_relationships: 92 | break 93 | s = self.data['relationship_subjects'][index, r_idx].item() 94 | p = self.data['relationship_predicates'][index, r_idx].item() 95 | o = self.data['relationship_objects'][index, r_idx].item() 96 | s = obj_idx_mapping.get(s, None) 97 | o = obj_idx_mapping.get(o, None) 98 | if s is not None and o is not None: 99 | triples.append([s, p, o]) 100 | 101 | in_image = self.vocab['pred_name_to_idx']['__in_image__'] 102 | for i in range(O - 1): 103 | triples.append([i, in_image, O - 1]) 104 | 105 | triples = torch.LongTensor(triples) 106 | return image, objs, boxes, triples 107 | 108 | 109 | class VGTrain(VGDatabase): 110 | def __init__(self, vocab, h5_path, image_dir, **kwargs): 111 | super().__init__(vocab=vocab, h5_path=h5_path, image_dir=image_dir, **kwargs) 112 | 113 | class VGValidation(VGDatabase): 114 | def __init__(self, vocab, h5_path, image_dir, **kwargs): 115 | super().__init__(vocab=vocab, h5_path=h5_path, image_dir=image_dir, **kwargs) 116 | 117 | 118 | def vg_collate_fn(batch): 119 | all_imgs, all_objs, all_boxes, all_triples = [], [], [], [] 120 | all_obj_to_img, all_triple_to_img = [], [] 121 | obj_offset = 0 122 | for i, (img, objs, boxes, triples) in enumerate(batch): 123 | all_imgs.append(img[None]) 124 | O, T = objs.size(0), triples.size(0) 125 | all_objs.append(objs) 126 | all_boxes.append(boxes) 127 | triples = triples.clone() 128 | triples[:, 0] += obj_offset 129 | triples[:, 2] += obj_offset 130 | all_triples.append(triples) 131 | 132 | all_obj_to_img.append(torch.LongTensor(O).fill_(i)) 133 | all_triple_to_img.append(torch.LongTensor(T).fill_(i)) 134 | obj_offset += O 135 | 136 | all_imgs = torch.cat(all_imgs) 137 | all_objs = torch.cat(all_objs) 138 | all_boxes = torch.cat(all_boxes) 139 | all_triples = torch.cat(all_triples) 140 | all_obj_to_img = torch.cat(all_obj_to_img) 141 | all_triple_to_img = torch.cat(all_triple_to_img) 142 | 143 | out = (all_imgs, all_objs, all_boxes, all_triples, all_obj_to_img, all_triple_to_img) 144 | return out 145 | 146 | class Resize(object): 147 | def __init__(self, size, interp=PIL.Image.BILINEAR): 148 | if isinstance(size, tuple): 149 | H, W = size 150 | self.size = (W, H) 151 | else: 152 | self.size = (size, size) 153 | self.interp = 
interp 154 | 155 | def __call__(self, img): 156 | return img.resize(self.size, self.interp) -------------------------------------------------------------------------------- /ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 
67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. [{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /sg_image_pretraining/training/configs.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataclasses import dataclass 3 | from typing import Tuple, Union, List, Optional 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--train_data", type=str, default='/home/huangzl/Data/datasets/vg/images', help="Path to training dataset") 8 | 9 | parser.add_argument("--vocab_json", type=str, default='/home/huangzl/Data/datasets/vg/vocab.json', help="Path to json file of vocab.") 10 | parser.add_argument("--train_h5", type=str, default='/home/huangzl/Data/datasets/vg/train.h5', help="Path to h5 file of train dataset.") 11 | parser.add_argument("--val_h5", type=str, default='/home/huangzl/Data/datasets/vg/val.h5', help="Path to h5 file of validate dataset.") 12 | parser.add_argument("--max_objects_per_image", type=int, default=10, help="Max objects of each image.") 13 | parser.add_argument("--use_orphaned_objects", type=bool, default=True, help="Use orphaned objects or not in the image.") 14 | parser.add_argument("--include_relationships", type=bool, default=True, help="Obtain relationships annotations between objects in the dataset.") 15 | parser.add_argument("--model_config_json", type=str, default='', help="Path to json file of model configs.") 16 | parser.add_argument("--image_size", type=int, default=224, help="Image size for training.") 17 | 18 | # tower config 19 | parser.add_argument("--graph_width", type=int, default=512, help="Width of Graph Tower.") 20 | parser.add_argument("--num_graph_layer", type=int, default=5, 
help="Number of layers in Graph Tower.") 21 | parser.add_argument("--embed_dim", type=int, default=512, help="Dimension of embeddings.") 22 | 23 | # training config 24 | parser.add_argument("--name", type=str, default=None, help="Optional identifier for the experiment when storing logs. Otherwise use current time.") 25 | parser.add_argument("--workers", type=int, default=1, help="Number of dataloader workers per GPU.") 26 | parser.add_argument("--batch_size", type=int, default=64, help="Batch size per GPU.") 27 | parser.add_argument("--val_batch_size", type=int, default=128, help="Batch size per GPU for Validation.") 28 | parser.add_argument("--epochs", type=int, default=100, help="Number of epochs to train for.") 29 | parser.add_argument("--lr", type=float, default=5.0e-4, help="Learning rate.") 30 | parser.add_argument("--beta1", type=float, default=0.9, help="Adam beta 1.") 31 | parser.add_argument("--beta2", type=float, default=0.999, help="Adam beta 2.") 32 | parser.add_argument("--eps", type=float, default=1.0e-8, help="Adam epsilon.") 33 | parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") 34 | parser.add_argument("--warmup", type=int, default=10000, help="Number of steps to warmup for.") 35 | parser.add_argument("--use_bn_sync", default=False, action="store_true", help="Whether to use batch norm sync.") 36 | parser.add_argument("--skip_scheduler", action="store_true", default=False, help="Use this flag to skip the learning rate decay.") 37 | parser.add_argument("--save_frequency", type=int, default=1, help="How often to save checkpoints. epoch level.") 38 | parser.add_argument("--save_most_recent", action="store_true", default=False, help="Always save the most recent model trained to epoch_latest.pt.") 39 | parser.add_argument("--logs", type=str, default="./logs/", help="Where to store tensorboard logs. 
Use None to avoid storing logs.") 40 | parser.add_argument("--log_local", action="store_true", default=False, help="log files on local master, otherwise global master only.") 41 | 42 | parser.add_argument("--val_frequency", type=int, default=1, help="How often to run evaluation with val data.") 43 | parser.add_argument("--precision", choices=["amp", "amp_bfloat16", "fp16", "fp32"], default="amp", help="Floating point precision.") 44 | parser.add_argument("--pretrained", default='', type=str, help="Use a pretrained CLIP model weights with the specified tag or file path.") 45 | parser.add_argument("--pretrained-image", default=False, action='store_true', help="Load imagenet pretrained weights for image tower backbone if available.") 46 | 47 | parser.add_argument("--lock_image", default=False, action='store_true', help="Lock full image tower by disabling gradients.") 48 | parser.add_argument("--lock_image_unlocked_groups", type=int, default=0, help="Leave last n image tower layer groups unlocked.") 49 | parser.add_argument("--lock_image_freeze_bn_stats", default=False, action='store_true', help="Freeze BatchNorm running stats in image tower for any locked layers.") 50 | parser.add_argument('--image_mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override default image mean value of dataset') 51 | parser.add_argument('--image_std', type=float, nargs='+', default=None, metavar='STD', help='Override default image std deviation of of dataset') 52 | parser.add_argument("--grad_checkpointing", default=False, action='store_true', help="Enable gradient checkpointing.") 53 | parser.add_argument("--local_loss", default=False, action="store_true", help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)") 54 | parser.add_argument("--gather_with_grad", default=False, action="store_true", help="enable full distributed gradient for feature gather") 55 | parser.add_argument("--force_quick_gelu", default=False, action='store_true', help="Force use of QuickGELU activation for non-OpenAI transformer models.") 56 | 57 | parser.add_argument("--dist_url", default="env://", type=str, help="url used to set up distributed training") 58 | parser.add_argument("--dist_backend", default="nccl", type=str, help="distributed backend") 59 | parser.add_argument("--report_to", default='', type=str, help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']") 60 | parser.add_argument("--debug", default=False, action="store_true", help="If true, more information is logged.") 61 | parser.add_argument("--ddp_static_graph", default=False, action='store_true', help="Enable static graph optimization for DDP in PyTorch >= 1.11.") 62 | parser.add_argument("--no_set_device_rank", default=False, action="store_true", help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).") 63 | parser.add_argument("--seed", type=int, default=9768, help="Default random seed.") 64 | parser.add_argument("--norm_gradient_clip", type=float, default=None, help="Gradient clip.") 65 | args = parser.parse_args() 66 | 67 | return args 68 | 69 | @dataclass 70 | class CLIPVisionCfg: 71 | layers: Union[Tuple[int, int, int, int], int] 72 | width: int 73 | head_width: int 74 | image_size: int 75 | mlp_ratio: float 76 | patch_size: int = None 77 | timm_model_name: str = None 78 | timm_model_pretrained: bool = None 79 | timm_pool: str = None 80 | timm_proj: str = None 81 | 82 | 83 | @dataclass 84 | class CLIPGraphCfg: 85 | layers: int 86 | width: int 
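As a minimal illustration (an assumed usage pattern, not code taken from this repository), the graph-tower options parsed above can be collected into the CLIPGraphCfg dataclass that the sgCLIP constructor expects alongside embed_dim:

from training.configs import parse_args, CLIPGraphCfg

args = parse_args()
# Gather the graph-tower hyperparameters (--num_graph_layer, --graph_width) into the config dataclass.
graph_cfg = CLIPGraphCfg(layers=args.num_graph_layer, width=args.graph_width)
# embed_dim is passed to the model separately (see sgCLIP in sgCLIP/model.py).
embed_dim = args.embed_dim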
-------------------------------------------------------------------------------- /sg_image_pretraining/sgCLIP/mm_transformer_module.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | 9 | def exists(val): 10 | return val is not None 11 | 12 | 13 | def uniq(arr): 14 | return{el: True for el in arr}.keys() 15 | 16 | 17 | def default(val, d): 18 | if exists(val): 19 | return val 20 | return d() if isfunction(d) else d 21 | 22 | 23 | def max_neg_value(t): 24 | return -torch.finfo(t.dtype).max 25 | 26 | 27 | def init_(tensor): 28 | dim = tensor.shape[-1] 29 | std = 1 / math.sqrt(dim) 30 | tensor.uniform_(-std, std) 31 | return tensor 32 | 33 | 34 | class GEGLU(nn.Module): 35 | def __init__(self, dim_in, dim_out): 36 | super().__init__() 37 | self.proj = nn.Linear(dim_in, dim_out * 2) 38 | 39 | def forward(self, x): 40 | x, gate = self.proj(x).chunk(2, dim=-1) 41 | return x * F.gelu(gate) 42 | 43 | 44 | class FeedForward(nn.Module): 45 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 46 | super().__init__() 47 | inner_dim = int(dim * mult) 48 | dim_out = default(dim_out, dim) 49 | project_in = nn.Sequential( 50 | nn.Linear(dim, inner_dim), 51 | nn.GELU() 52 | ) if not glu else GEGLU(dim, inner_dim) 53 | 54 | self.net = nn.Sequential( 55 | project_in, 56 | nn.Dropout(dropout), 57 | nn.Linear(inner_dim, dim_out) 58 | ) 59 | 60 | def forward(self, x): 61 | return self.net(x) 62 | 63 | 64 | def zero_module(module): 65 | for p in module.parameters(): 66 | p.detach().zero_() 67 | return module 68 | 69 | 70 | def Normalize(in_channels): 71 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 72 | 73 | 74 | class SpatialSelfAttention(nn.Module): 75 | def __init__(self, in_channels): 76 | super().__init__() 77 | self.in_channels = in_channels 78 | 79 | self.norm = Normalize(in_channels) 80 | self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 81 | self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 82 | self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 83 | self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0) 84 | 85 | def forward(self, x): 86 | h_ = x 87 | h_ = self.norm(h_) 88 | q = self.q(h_) 89 | k = self.k(h_) 90 | v = self.v(h_) 91 | 92 | b,c,h,w = q.shape 93 | q = rearrange(q, 'b c h w -> b (h w) c') 94 | k = rearrange(k, 'b c h w -> b c (h w)') 95 | w_ = torch.einsum('bij,bjk->bik', q, k) 96 | 97 | w_ = w_ * (int(c)**(-0.5)) 98 | w_ = torch.nn.functional.softmax(w_, dim=2) 99 | 100 | v = rearrange(v, 'b c h w -> b c (h w)') 101 | w_ = rearrange(w_, 'b i j -> b j i') 102 | h_ = torch.einsum('bij,bjk->bik', v, w_) 103 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 104 | h_ = self.proj_out(h_) 105 | 106 | return x+h_ 107 | 108 | 109 | class CrossAttention(nn.Module): 110 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 111 | super().__init__() 112 | inner_dim = dim_head * heads 113 | context_dim = default(context_dim, query_dim) 114 | 115 | self.scale = dim_head ** -0.5 116 | self.heads = heads 117 | 118 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 119 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 120 | self.to_v 
= nn.Linear(context_dim, inner_dim, bias=False) 121 | 122 | self.to_out = nn.Sequential( 123 | nn.Linear(inner_dim, query_dim), 124 | nn.Dropout(dropout) 125 | ) 126 | 127 | def forward(self, x, context=None, mask=None): 128 | h = self.heads 129 | 130 | q = self.to_q(x) 131 | context = default(context, x) 132 | k = self.to_k(context) 133 | v = self.to_v(context) 134 | 135 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 136 | 137 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 138 | 139 | if exists(mask): 140 | mask = rearrange(mask, 'b ... -> b (...)') 141 | max_neg_value = -torch.finfo(sim.dtype).max 142 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 143 | sim.masked_fill_(~mask, max_neg_value) 144 | 145 | attn = sim.softmax(dim=-1) 146 | 147 | out = einsum('b i j, b j d -> b i d', attn, v) 148 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 149 | return self.to_out(out) 150 | 151 | 152 | class BasicTransformerBlock(nn.Module): 153 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 154 | super().__init__() 155 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 156 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 157 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 158 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 159 | self.norm1 = nn.LayerNorm(dim) 160 | self.norm2 = nn.LayerNorm(dim) 161 | self.norm3 = nn.LayerNorm(dim) 162 | 163 | 164 | def forward(self, x, context=None): 165 | x = self.attn1(self.norm1(x)) + x 166 | x = self.attn2(self.norm2(x), context=context) + x 167 | x = self.ff(self.norm3(x)) + x 168 | return x 169 | 170 | 171 | class SpatialTransformer(nn.Module): 172 | def __init__(self, in_channels, n_heads, d_head, 173 | depth=1, dropout=0., context_dim=None): 174 | super().__init__() 175 | self.in_channels = in_channels 176 | inner_dim = n_heads * d_head 177 | self.norm = Normalize(in_channels) 178 | 179 | self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) 180 | 181 | self.transformer_blocks = nn.ModuleList( 182 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 183 | for d in range(depth)] 184 | ) 185 | 186 | self.proj_out = zero_module(nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)) 187 | 188 | def forward(self, x, context=None): 189 | b, c, h, w = x.shape 190 | x_in = x 191 | x = self.norm(x) 192 | x = self.proj_in(x) 193 | x = rearrange(x, 'b c h w -> b (h w) c') 194 | for block in self.transformer_blocks: 195 | x = block(x, context=context) 196 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 197 | x = self.proj_out(x) 198 | return x + x_in 199 | -------------------------------------------------------------------------------- /sg_image_pretraining/sgCLIP/model.py: -------------------------------------------------------------------------------- 1 | """ CLIP Model 2 | 3 | Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
4 | """ 5 | from dataclasses import dataclass 6 | from typing import Tuple, Union 7 | import numpy as np 8 | import torch 9 | import torch.nn.functional as F 10 | from torch import nn 11 | from sgCLIP.module import ModifiedResNet, QuickGELU, VisualTransformer, GraphTripleConv, GraphTripleConvNet, GraphAggregationNetwork, Attention 12 | from training.configs import CLIPVisionCfg, CLIPGraphCfg 13 | from utils import create_tensor_by_assign_samples_to_img, get_linear_feas_by_hook 14 | import clip 15 | from global_var import * 16 | 17 | class sgCLIP(nn.Module): 18 | def __init__(self, 19 | graph_vocab: dict, 20 | graph_cfg: CLIPGraphCfg, 21 | embed_dim: int, 22 | max_sample_per_img: int=15, 23 | ): 24 | super().__init__() 25 | if isinstance(graph_cfg, dict): 26 | graph_cfg = CLIPGraphCfg(**graph_cfg) 27 | 28 | self.clip_model, preprocess = clip.load("ViT-B/32", device=device) 29 | self.clip_model.eval().requires_grad_(False).to(device) 30 | 31 | num_objs = len(graph_vocab['object_idx_to_name']) 32 | num_preds = len(graph_vocab['pred_idx_to_name']) 33 | self.num_objs = num_objs 34 | self.num_preds = num_preds 35 | self.max_sample_per_img = max_sample_per_img 36 | self.obj_embeddings = nn.Embedding(num_objs + 1, embed_dim) 37 | self.pred_embeddings = nn.Embedding(num_preds, embed_dim) 38 | 39 | self.graph_conv = GraphTripleConv(embed_dim, output_dim=embed_dim, hidden_dim=graph_cfg.width, pooling='avg', mlp_normalization='none') 40 | self.graph_net = GraphTripleConvNet(embed_dim, num_layers=graph_cfg.layers, hidden_dim=graph_cfg.width, pooling='avg', mlp_normalization='none') 41 | self.graph_projection = nn.Linear(embed_dim * 2, embed_dim) 42 | 43 | self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) 44 | 45 | 46 | def initialize_parameters(self): 47 | nn.init.constant_(self.logit_scale, np.log(1 / 0.07)) 48 | 49 | if hasattr(self.graph_conv, 'init_parameters'): 50 | self.graph_conv.init_parameters() 51 | if hasattr(self.graph_net, 'init_parameters'): 52 | self.graph_net.init_parameters() 53 | if hasattr(self.graph_projection, 'init_parameters'): 54 | self.graph_projection.init_parameters() 55 | 56 | @torch.jit.ignore 57 | def set_grad_checkpointing(self, enable=False): 58 | self.visual.set_grad_checkpointing(enable) 59 | 60 | def encode_image_local_global(self, image): 61 | with torch.no_grad(): 62 | extract_linear_feas = get_linear_feas_by_hook(self.clip_model.visual) 63 | global_image_fea = self.clip_model.encode_image(image) 64 | local_image_fea = extract_linear_feas[-1].extract_fea 65 | return local_image_fea.detach(), global_image_fea.detach() 66 | 67 | def encode_graph_local_global(self, img, graph): 68 | batch_size, _, H, W = img.shape 69 | 70 | objs, boxes, triples, obj_to_img, triples_to_img = graph 71 | s, p, o = triples.chunk(3, dim=1) 72 | s, p, o = [x.squeeze(1) for x in [s, p, o]] 73 | edges = torch.stack([s, o], dim=1) 74 | 75 | obj_vecs = self.obj_embeddings(objs) 76 | pred_vecs = self.pred_embeddings(p) 77 | 78 | if isinstance(self.graph_conv, nn.Linear): 79 | obj_vecs = self.graph_conv(obj_vecs) 80 | else: 81 | obj_vecs, pred_vecs = self.graph_conv(obj_vecs, pred_vecs, edges) 82 | if self.graph_net is not None: 83 | obj_vecs, pred_vecs = self.graph_net(obj_vecs, pred_vecs, edges) 84 | 85 | # Global Branch 86 | obj_fea = self.pool_samples(obj_vecs, obj_to_img) 87 | pred_fea = self.pool_samples(pred_vecs, triples_to_img) 88 | graph_global_fea = self.graph_projection(torch.cat([obj_fea, pred_fea], dim=1)) 89 | 90 | # Local Branch 91 | s_obj_vec, o_obj_vec = 
obj_vecs[s], obj_vecs[o] 92 | triple_vec = torch.cat([s_obj_vec, pred_vecs, o_obj_vec], dim=1) 93 | graph_local_fea = create_tensor_by_assign_samples_to_img(samples=triple_vec, sample_to_img=triples_to_img, 94 | max_sample_per_img=self.max_sample_per_img, 95 | batch_size=batch_size) 96 | 97 | return graph_local_fea, graph_global_fea 98 | 99 | def forward(self, image, graph): 100 | local_image_feature, global_image_features = self.encode_image_local_global(image) 101 | norm_global_image_features = F.normalize(global_image_features, dim=-1) 102 | local_graph_features, global_graph_features = self.encode_graph_local_global(image, graph) 103 | norm_global_graph_features = F.normalize(global_graph_features, dim=-1) 104 | 105 | return local_image_feature, local_graph_features, norm_global_image_features, norm_global_graph_features, self.logit_scale.exp() 106 | 107 | def pool_samples(self, samples, obj_to_img, pooling='avg'): 108 | dtype, device = samples.dtype, samples.device 109 | O, D = samples.size() 110 | 111 | N = obj_to_img.data.max().item() + 1 112 | 113 | out = torch.zeros(N, D, dtype=dtype, device=device) 114 | idx = obj_to_img.view(O, 1).expand(O, D) 115 | out = out.scatter_add(0, idx, samples) 116 | 117 | if pooling == 'avg': 118 | ones = torch.ones(O, dtype=dtype, device=device) 119 | obj_counts = torch.zeros(N, dtype=dtype, device=device) 120 | obj_counts = obj_counts.scatter_add(0, obj_to_img, ones) 121 | obj_counts = obj_counts.clamp(min=1) 122 | out = out / obj_counts.view(N, 1) 123 | elif pooling != 'sum': 124 | raise ValueError('Invalid pooling "%s"' % pooling) 125 | 126 | return out 127 | 128 | def convert_weights_to_fp16(model: nn.Module): 129 | """Convert applicable model parameters to fp16""" 130 | 131 | def _convert_weights_to_fp16(l): 132 | if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): 133 | l.weight.data = l.weight.data.half() 134 | if l.bias is not None: 135 | l.bias.data = l.bias.data.half() 136 | 137 | if isinstance(l, (nn.MultiheadAttention, Attention)): 138 | for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: 139 | tensor = getattr(l, attr) 140 | if tensor is not None: 141 | tensor.data = tensor.data.half() 142 | 143 | for name in ["text_projection", "proj"]: 144 | if hasattr(l, name): 145 | attr = getattr(l, name) 146 | if attr is not None: 147 | attr.data = attr.data.half() 148 | 149 | model.apply(_convert_weights_to_fp16) 150 | 151 | def idx_to_one_hot(idx, num_classes): 152 | result = F.one_hot(idx, num_classes) 153 | result = result.float() 154 | return result 155 | -------------------------------------------------------------------------------- /sg_image_pretraining/datasets/vg_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import torchvision.transforms as transforms 3 | import h5py 4 | import PIL 5 | import os 6 | import random 7 | import numpy as np 8 | import json 9 | import torch 10 | from datasets.transform import VGResize 11 | from utils import imagenet_preprocess 12 | from global_var import * 13 | 14 | class VgSceneGraphDataset(Dataset): 15 | def __init__(self, vocab, h5_path, image_dir, image_size=(256, 256), normalize_images=True, max_objects=10, max_samples=None, 16 | include_relationships=True, use_orphaned_objects=True): 17 | super(VgSceneGraphDataset, self).__init__() 18 | 19 | self.image_dir = image_dir 20 | self.image_size = image_size 21 | self.vocab = vocab 22 | self.num_objects = 
len(vocab['object_idx_to_name']) 23 | self.num_attributes = len(vocab['attribute_idx_to_name']) 24 | self.use_orphaned_objects = use_orphaned_objects 25 | self.max_objects = max_objects 26 | self.max_samples = max_samples 27 | self.include_relationships = include_relationships 28 | 29 | transform = [VGResize(image_size), transforms.ToTensor()] 30 | 31 | transform.append(imagenet_preprocess()) 32 | self.transform = transforms.Compose(transform) 33 | 34 | self.data = {} 35 | with h5py.File(h5_path, 'r') as f: 36 | for k, v in f.items(): 37 | if k == 'image_paths': 38 | self.image_paths = list(v) 39 | else: 40 | self.data[k] = torch.IntTensor(np.asarray(v)) 41 | 42 | def __len__(self): 43 | num = self.data['object_names'].size(0) 44 | if self.max_samples is not None: 45 | return min(self.max_samples, num) 46 | return num 47 | 48 | def __getitem__(self, index): 49 | img_path = os.path.join(self.image_dir, str(self.image_paths[index], encoding="utf-8")) 50 | 51 | with open(img_path, 'rb') as f: 52 | with PIL.Image.open(f) as image: 53 | WW, HH = image.size 54 | image = self.transform(image.convert('RGB')) 55 | 56 | H, W = self.image_size 57 | 58 | obj_idxs_with_rels = set() 59 | obj_idxs_without_rels = set(range(self.data['objects_per_image'][index].item())) 60 | for r_idx in range(self.data['relationships_per_image'][index]): 61 | s = self.data['relationship_subjects'][index, r_idx].item() 62 | o = self.data['relationship_objects'][index, r_idx].item() 63 | obj_idxs_with_rels.add(s) 64 | obj_idxs_with_rels.add(o) 65 | obj_idxs_without_rels.discard(s) 66 | obj_idxs_without_rels.discard(o) 67 | 68 | obj_idxs = list(obj_idxs_with_rels) 69 | obj_idxs_without_rels = list(obj_idxs_without_rels) 70 | if len(obj_idxs) > self.max_objects - 1: 71 | obj_idxs = random.sample(obj_idxs, self.max_objects) 72 | if len(obj_idxs) < self.max_objects - 1 and self.use_orphaned_objects: 73 | num_to_add = self.max_objects - 1 - len(obj_idxs) 74 | num_to_add = min(num_to_add, len(obj_idxs_without_rels)) 75 | obj_idxs += random.sample(obj_idxs_without_rels, num_to_add) 76 | O = len(obj_idxs) + 1 77 | 78 | objs = torch.LongTensor(O).fill_(-1) 79 | 80 | attributes = torch.LongTensor(O).fill_(-1) 81 | boxes = torch.FloatTensor([[0, 0, 1, 1]]).repeat(O, 1) 82 | obj_idx_mapping = {} 83 | for i, obj_idx in enumerate(obj_idxs): 84 | objs[i] = self.data['object_names'][index, obj_idx].item() 85 | x, y, w, h = self.data['object_boxes'][index, obj_idx].tolist() 86 | x0 = float(x) / WW 87 | y0 = float(y) / HH 88 | x1 = float(x + w) / WW 89 | y1 = float(y + h) / HH 90 | boxes[i] = torch.FloatTensor([x0, y0, x1, y1]) 91 | obj_idx_mapping[obj_idx] = i 92 | attributes[i] = self.data['attributes_per_object'][index, obj_idx].item() 93 | 94 | objs[O - 1] = self.vocab['object_name_to_idx']['__image__'] 95 | 96 | triples = [] 97 | for r_idx in range(self.data['relationships_per_image'][index].item()): 98 | if not self.include_relationships: 99 | break 100 | s = self.data['relationship_subjects'][index, r_idx].item() 101 | p = self.data['relationship_predicates'][index, r_idx].item() 102 | o = self.data['relationship_objects'][index, r_idx].item() 103 | s = obj_idx_mapping.get(s, None) 104 | o = obj_idx_mapping.get(o, None) 105 | if s is not None and o is not None: 106 | triples.append([s, p, o]) 107 | 108 | in_image = self.vocab['pred_name_to_idx']['__in_image__'] 109 | for i in range(O - 1): 110 | triples.append([i, in_image, O - 1]) 111 | 112 | triples = torch.LongTensor(triples) 113 | return image, objs, boxes, triples, attributes 114 
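# Minimal usage sketch for VgSceneGraphDataset, assuming Visual Genome has been
# preprocessed as described in DATA.md; the paths below are placeholders and the
# vocab json must contain the tables read above (object_idx_to_name,
# attribute_idx_to_name, object_name_to_idx, pred_name_to_idx).
def _example_vg_batch(vocab_json='path/to/vocab.json',
                      h5_path='path/to/train.h5',
                      image_dir='path/to/images'):
    from torch.utils.data import DataLoader

    with open(vocab_json, 'r') as f:
        vocab = json.load(f)
    dset = VgSceneGraphDataset(vocab, h5_path, image_dir, image_size=(256, 256),
                               max_objects=10, include_relationships=True)
    image, objs, boxes, triples, attributes = dset[0]
    # image: (3, H, W) float tensor; objs/attributes: (O,) ids; boxes: (O, 4) in
    # normalized [x0, y0, x1, y1]; triples: (T, 3) rows of [subject_idx, predicate_id, object_idx].
    loader = DataLoader(dset, batch_size=4, collate_fn=vg_collate_fn)  # vg_collate_fn is defined below
    imgs, batch_objs, batch_boxes, batch_triples, obj_to_img, triple_to_img = next(iter(loader))
    return imgs, batch_objs, batch_boxes, batch_triples, obj_to_img, triple_to_img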
| 115 | 116 | def vg_collate_fn(batch): 117 | all_imgs, all_objs, all_boxes, all_triples, all_attributes = [], [], [], [], [] 118 | all_obj_to_img, all_triple_to_img = [], [] 119 | obj_offset = 0 120 | for i, (img, objs, boxes, triples, attributes) in enumerate(batch): 121 | all_imgs.append(img[None]) 122 | O, T = objs.size(0), triples.size(0) 123 | all_objs.append(objs) 124 | all_boxes.append(boxes) 125 | triples = triples.clone() 126 | triples[:, 0] += obj_offset 127 | triples[:, 2] += obj_offset 128 | all_triples.append(triples) 129 | 130 | all_obj_to_img.append(torch.LongTensor(O).fill_(i)) 131 | all_triple_to_img.append(torch.LongTensor(T).fill_(i)) 132 | obj_offset += O 133 | 134 | all_attributes.append(attributes) 135 | 136 | all_imgs = torch.cat(all_imgs) 137 | all_objs = torch.cat(all_objs) 138 | all_boxes = torch.cat(all_boxes) 139 | all_triples = torch.cat(all_triples) 140 | all_obj_to_img = torch.cat(all_obj_to_img) 141 | all_triple_to_img = torch.cat(all_triple_to_img) 142 | 143 | out = (all_imgs, all_objs, all_boxes, all_triples, all_obj_to_img, all_triple_to_img) 144 | return out 145 | 146 | 147 | def build_vg_dsets(args): 148 | with open(args.vocab_json, 'r') as f: 149 | vocab = json.load(f) 150 | dset_kwargs = { 151 | 'vocab': vocab, 152 | 'h5_path': args.train_h5, 153 | 'image_dir': args.train_data, 154 | 'image_size': (args.image_size, args.image_size), 155 | 'max_samples': None, 156 | 'max_objects': args.max_objects_per_image, 157 | 'use_orphaned_objects': args.use_orphaned_objects, 158 | 'include_relationships': args.include_relationships, 159 | } 160 | train_dset = VgSceneGraphDataset(**dset_kwargs) 161 | iter_per_epoch = len(train_dset) // args.batch_size 162 | print('There are %d iterations per epoch' % iter_per_epoch) 163 | 164 | dset_kwargs['h5_path'] = args.val_h5 165 | del dset_kwargs['max_samples'] 166 | val_dset = VgSceneGraphDataset(**dset_kwargs) 167 | 168 | return vocab, train_dset, val_dset 169 | -------------------------------------------------------------------------------- /sg2im/vis.py: -------------------------------------------------------------------------------- 1 | import tempfile, os 2 | import torch 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from matplotlib.patches import Rectangle 6 | from imageio import imread 7 | 8 | 9 | """ 10 | Utilities for making visualizations. 
11 | """ 12 | 13 | 14 | def draw_layout(vocab, objs, boxes, masks=None, size=256, 15 | show_boxes=False, bgcolor=(0, 0, 0)): 16 | if bgcolor == 'white': 17 | bgcolor = (255, 255, 255) 18 | 19 | cmap = plt.get_cmap('rainbow') 20 | colors = cmap(np.linspace(0, 1, len(objs))) 21 | 22 | with torch.no_grad(): 23 | objs = objs.cpu().clone() 24 | boxes = boxes.cpu().clone() 25 | boxes *= size 26 | 27 | if masks is not None: 28 | masks = masks.cpu().clone() 29 | 30 | bgcolor = np.asarray(bgcolor) 31 | bg = np.ones((size, size, 1)) * bgcolor 32 | plt.imshow(bg.astype(np.uint8)) 33 | 34 | plt.gca().set_xlim(0, size) 35 | plt.gca().set_ylim(size, 0) 36 | plt.gca().set_aspect(1.0, adjustable='box') 37 | 38 | for i, obj in enumerate(objs): 39 | name = vocab['object_idx_to_name'][obj] 40 | if name == '__image__': 41 | continue 42 | box = boxes[i] 43 | 44 | if masks is None: 45 | continue 46 | mask = masks[i].numpy() 47 | mask /= mask.max() 48 | 49 | r, g, b, a = colors[i] 50 | colored_mask = mask[:, :, None] * np.asarray(colors[i]) 51 | 52 | x0, y0, x1, y1 = box 53 | plt.imshow(colored_mask, extent=(x0, x1, y1, y0), 54 | interpolation='bicubic', alpha=1.0) 55 | 56 | if show_boxes: 57 | for i, obj in enumerate(objs): 58 | name = vocab['object_idx_to_name'][obj] 59 | if name == '__image__': 60 | continue 61 | box = boxes[i] 62 | 63 | draw_box(box, colors[i], name) 64 | 65 | 66 | def draw_box(box, color, text=None): 67 | """ 68 | Draw a bounding box using pyplot, optionally with a text box label. 69 | 70 | Inputs: 71 | - box: Tensor or list with 4 elements: [x0, y0, x1, y1] in [0, W] x [0, H] 72 | coordinate system. 73 | - color: pyplot color to use for the box. 74 | - text: (Optional) String; if provided then draw a label for this box. 75 | """ 76 | TEXT_BOX_HEIGHT = 10 77 | if torch.is_tensor(box) and box.dim() == 2: 78 | box = box.view(-1) 79 | assert box.size(0) == 4 80 | x0, y0, x1, y1 = box 81 | assert y1 > y0, box 82 | assert x1 > x0, box 83 | w, h = x1 - x0, y1 - y0 84 | rect = Rectangle((x0, y0), w, h, fc='none', lw=2, ec=color) 85 | plt.gca().add_patch(rect) 86 | if text is not None: 87 | text_rect = Rectangle((x0, y0), w, TEXT_BOX_HEIGHT, fc=color, alpha=0.5) 88 | plt.gca().add_patch(text_rect) 89 | tx = 0.5 * (x0 + x1) 90 | ty = y0 + TEXT_BOX_HEIGHT / 2.0 91 | plt.text(tx, ty, text, va='center', ha='center') 92 | 93 | 94 | def draw_scene_graph(objs, triples, vocab=None, **kwargs): 95 | """ 96 | Use GraphViz to draw a scene graph. If vocab is not passed then we assume 97 | that objs and triples are python lists containing strings for object and 98 | relationship names. 99 | 100 | Using this requires that GraphViz is installed. 
On Ubuntu 16.04 this is easy: 101 | sudo apt-get install graphviz 102 | """ 103 | output_filename = kwargs.pop('output_filename', 'graph.png') 104 | orientation = kwargs.pop('orientation', 'V') 105 | edge_width = kwargs.pop('edge_width', 6) 106 | arrow_size = kwargs.pop('arrow_size', 1.5) 107 | binary_edge_weight = kwargs.pop('binary_edge_weight', 1.2) 108 | ignore_dummies = kwargs.pop('ignore_dummies', True) 109 | 110 | if orientation not in ['V', 'H']: 111 | raise ValueError('Invalid orientation "%s"' % orientation) 112 | rankdir = {'H': 'LR', 'V': 'TD'}[orientation] 113 | 114 | if vocab is not None: 115 | # Decode object and relationship names 116 | assert torch.is_tensor(objs) 117 | assert torch.is_tensor(triples) 118 | objs_list, triples_list = [], [] 119 | for i in range(objs.size(0)): 120 | objs_list.append(vocab['object_idx_to_name'][objs[i].item()]) 121 | for i in range(triples.size(0)): 122 | s = triples[i, 0].item() 123 | p = vocab['pred_idx_to_name'][triples[i, 1].item()] 124 | o = triples[i, 2].item() 125 | triples_list.append([s, p, o]) 126 | objs, triples = objs_list, triples_list 127 | 128 | # General setup, and style for object nodes 129 | lines = [ 130 | 'digraph{', 131 | 'graph [size="5,3",ratio="compress",dpi="300",bgcolor="transparent"]', 132 | 'rankdir=%s' % rankdir, 133 | 'nodesep="0.5"', 134 | 'ranksep="0.5"', 135 | 'node [shape="box",style="rounded,filled",fontsize="48",color="none"]', 136 | 'node [fillcolor="lightpink1"]', 137 | ] 138 | # Output nodes for objects 139 | for i, obj in enumerate(objs): 140 | if ignore_dummies and obj == '__image__': 141 | continue 142 | lines.append('%d [label="%s"]' % (i, obj)) 143 | 144 | # Output relationships 145 | next_node_id = len(objs) 146 | lines.append('node [fillcolor="lightblue1"]') 147 | for s, p, o in triples: 148 | if ignore_dummies and p == '__in_image__': 149 | continue 150 | lines += [ 151 | '%d [label="%s"]' % (next_node_id, p), 152 | '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % ( 153 | s, next_node_id, edge_width, arrow_size, binary_edge_weight), 154 | '%d->%d [penwidth=%f,arrowsize=%f,weight=%f]' % ( 155 | next_node_id, o, edge_width, arrow_size, binary_edge_weight) 156 | ] 157 | next_node_id += 1 158 | lines.append('}') 159 | 160 | # Now it gets slightly hacky. Write the graphviz spec to a temporary 161 | # text file 162 | ff, dot_filename = tempfile.mkstemp() 163 | with open(dot_filename, 'w') as f: 164 | for line in lines: 165 | f.write('%s\n' % line) 166 | os.close(ff) 167 | 168 | # Shell out to invoke graphviz; this will save the resulting image to disk, 169 | # so we read it, delete it, then return it. 
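# (The shell-out below assumes the Graphviz `dot` binary is on PATH, as noted in the
# docstring above; os.system does not raise if the command fails.)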
170 | output_format = os.path.splitext(output_filename)[1][1:] 171 | os.system('dot -T%s %s > %s' % (output_format, dot_filename, output_filename)) 172 | os.remove(dot_filename) 173 | img = imread(output_filename) 174 | os.remove(output_filename) 175 | 176 | return img 177 | 178 | 179 | if __name__ == '__main__': 180 | o_idx_to_name = ['cat', 'dog', 'hat', 'skateboard'] 181 | p_idx_to_name = ['riding', 'wearing', 'on', 'next to', 'above'] 182 | o_name_to_idx = {s: i for i, s in enumerate(o_idx_to_name)} 183 | p_name_to_idx = {s: i for i, s in enumerate(p_idx_to_name)} 184 | vocab = { 185 | 'object_idx_to_name': o_idx_to_name, 186 | 'object_name_to_idx': o_name_to_idx, 187 | 'pred_idx_to_name': p_idx_to_name, 188 | 'pred_name_to_idx': p_name_to_idx, 189 | } 190 | 191 | objs = [ 192 | 'cat', 193 | 'cat', 194 | 'skateboard', 195 | 'hat', 196 | ] 197 | objs = torch.LongTensor([o_name_to_idx[o] for o in objs]) 198 | triples = [ 199 | [0, 'next to', 1], 200 | [0, 'riding', 2], 201 | [1, 'wearing', 3], 202 | [3, 'above', 2], 203 | ] 204 | triples = [[s, p_name_to_idx[p], o] for s, p, o in triples] 205 | triples = torch.LongTensor(triples) 206 | 207 | draw_scene_graph(objs, triples, vocab, orientation='V') 208 | 209 | -------------------------------------------------------------------------------- /ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import rearrange, repeat 6 | import kornia 7 | 8 | 9 | from ldm.modules.x_transformer import Encoder, TransformerWrapper 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, n_classes=1000, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.embedding = nn.Embedding(n_classes, embed_dim) 26 | 27 | def forward(self, batch, key=None): 28 | if key is None: 29 | key = self.key 30 | # this is for use in crossattn 31 | c = batch[key][:, None] 32 | c = self.embedding(c) 33 | return c 34 | 35 | 36 | class TransformerEmbedder(AbstractEncoder): 37 | """Some transformer encoder layers""" 38 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=77, device="cuda"): 39 | super().__init__() 40 | self.device = device 41 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 42 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 43 | 44 | def forward(self, tokens): 45 | tokens = tokens.to(self.device) # meh 46 | z = self.transformer(tokens, return_embeddings=True) 47 | return z 48 | 49 | def encode(self, x): 50 | return self(x) 51 | 52 | 53 | class BERTTokenizer(AbstractEncoder): 54 | """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" 55 | def __init__(self, device="cuda", vq_interface=True, max_length=77): 56 | super().__init__() 57 | from transformers import BertTokenizerFast 58 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 59 | self.device = device 60 | self.vq_interface = vq_interface 61 | self.max_length = max_length 62 | 63 | def forward(self, text): 64 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 65 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 66 | tokens = batch_encoding["input_ids"].to(self.device) 67 | return tokens 68 | 69 | @torch.no_grad() 70 | def encode(self, text): 71 | tokens = self(text) 72 | if not self.vq_interface: 73 | return tokens 74 | return None, None, [None, None, tokens] 75 | 76 | def decode(self, text): 77 | return text 78 | 79 | 80 | class BERTEmbedder(AbstractEncoder): 81 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 82 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=77, 83 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 84 | super().__init__() 85 | self.use_tknz_fn = use_tokenizer 86 | if self.use_tknz_fn: 87 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len) 88 | self.device = device 89 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 90 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 91 | emb_dropout=embedding_dropout) 92 | 93 | def forward(self, text): 94 | if self.use_tknz_fn: 95 | tokens = self.tknz_fn(text)#.to(self.device) 96 | else: 97 | tokens = text 98 | z = self.transformer(tokens, return_embeddings=True) 99 | return z 100 | 101 | def encode(self, text): 102 | # output of length 77 103 | return self(text) 104 | 105 | 106 | class SpatialRescaler(nn.Module): 107 | def __init__(self, 108 | n_stages=1, 109 | method='bilinear', 110 | multiplier=0.5, 111 | in_channels=3, 112 | out_channels=None, 113 | bias=False): 114 | super().__init__() 115 | self.n_stages = n_stages 116 | assert self.n_stages >= 0 117 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 118 | self.multiplier = multiplier 119 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 120 | self.remap_output = out_channels is not None 121 | if self.remap_output: 122 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 123 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 124 | 125 | def forward(self,x): 126 | for stage in range(self.n_stages): 127 | x = self.interpolator(x, scale_factor=self.multiplier) 128 | 129 | 130 | if self.remap_output: 131 | x = self.channel_mapper(x) 132 | return x 133 | 134 | def encode(self, x): 135 | return self(x) 136 | 137 | 138 | class FrozenCLIPTextEmbedder(nn.Module): 139 | """ 140 | Uses the CLIP transformer encoder for text. 
141 | """ 142 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 143 | super().__init__() 144 | self.model, _ = clip.load(version, jit=False, device="cpu") 145 | self.device = device 146 | self.max_length = max_length 147 | self.n_repeat = n_repeat 148 | self.normalize = normalize 149 | 150 | def freeze(self): 151 | self.model = self.model.eval() 152 | for param in self.parameters(): 153 | param.requires_grad = False 154 | 155 | def forward(self, text): 156 | tokens = clip.tokenize(text).to(self.device) 157 | z = self.model.encode_text(tokens) 158 | if self.normalize: 159 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 160 | return z 161 | 162 | def encode(self, text): 163 | z = self(text) 164 | if z.ndim==2: 165 | z = z[:, None, :] 166 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 167 | return z 168 | 169 | 170 | class FrozenClipImageEmbedder(nn.Module): 171 | """ 172 | Uses the CLIP image encoder. 173 | """ 174 | def __init__( 175 | self, 176 | model, 177 | jit=False, 178 | device='cuda' if torch.cuda.is_available() else 'cpu', 179 | antialias=False, 180 | ): 181 | super().__init__() 182 | self.model, _ = clip.load(name=model, device=device, jit=jit) 183 | 184 | self.antialias = antialias 185 | 186 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 187 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 188 | 189 | def preprocess(self, x): 190 | # normalize to [0,1] 191 | x = kornia.geometry.resize(x, (224, 224), 192 | interpolation='bicubic',align_corners=True, 193 | antialias=self.antialias) 194 | x = (x + 1.) / 2. 195 | # renormalize according to clip 196 | x = kornia.enhance.normalize(x, self.mean, self.std) 197 | return x 198 | 199 | def forward(self, x): 200 | # x is assumed to be in range [-1,1] 201 | return self.model.encode_image(self.preprocess(x)) 202 | 203 | -------------------------------------------------------------------------------- /sg_image_pretraining/sgCLIP/masked_loss.py: -------------------------------------------------------------------------------- 1 | from global_var import * 2 | import torch 3 | import torch.nn as nn 4 | import torch.distributed.nn 5 | from sgCLIP.module import GraphTripleConv, GraphTripleConvNet, Attention, Transformer 6 | from utils import boxes_to_mask 7 | has_distributed = True 8 | 9 | class MaskedSceneGraphLoss(nn.Module): 10 | def __init__(self, triple_dim, max_relationships_per_image=15, threshold=0.3): 11 | super().__init__() 12 | self.transformer = Transformer(width=triple_dim, layers=2, heads=4) 13 | self.triple_mlp = nn.Sequential(nn.Linear(triple_dim, triple_dim)) 14 | self.criterion = nn.MSELoss(reduction='mean') 15 | 16 | self.threshold = threshold 17 | 18 | def forward(self, triple_per_img): 19 | triple_per_img = self.triple_mlp(triple_per_img) 20 | rec_loss = self.calculate_reconstruction_loss(triple_per_img) 21 | return rec_loss 22 | 23 | def calculate_reconstruction_loss(self, triple_fea): 24 | batch_size = triple_fea.shape[0] 25 | len_triple_fea = triple_fea.shape[1] 26 | 27 | triple_mask = (torch.rand([batch_size, len_triple_fea, 1]) > self.threshold).float().detach().to(device) # [0 for mask] 28 | masked_triple_fea = triple_mask * triple_fea 29 | 30 | rec_fea = self.reconstruct_missing_sg(masked_triple_fea) 31 | gt_fea = triple_fea.detach() 32 | 33 | valid_mask = (torch.mean(triple_fea, dim=2, keepdim=True)).float() 34 | loss_mask = (1 - triple_mask) * valid_mask.detach() 
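# triple_mask is 0 at the randomly masked positions, and valid_mask (the per-triple
# feature mean) vanishes on all-zero padding rows, so the MSE below is effectively
# restricted to masked, non-padded triples.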
35 | loss_mask.detach() 36 | 37 | rec_loss = self.criterion(rec_fea * loss_mask, gt_fea * loss_mask) 38 | return rec_loss 39 | 40 | def reconstruct_missing_sg(self, triple_fea): 41 | 42 | triple_fea = triple_fea.permute(1, 0, 2).contiguous() 43 | rec_fea = self.transformer(triple_fea) 44 | rec_fea = rec_fea.permute(1, 0, 2).contiguous() 45 | return rec_fea 46 | 47 | class Img2MaskedSceneGraphLoss(nn.Module): 48 | def __init__(self, triple_dim, image_dim, max_relationships_per_image=30, image_size=32, threshold=0.3): 49 | super().__init__() 50 | self.transformer = Transformer(width=image_dim, layers=2, heads=4) 51 | self.triple_mlp = nn.Sequential(nn.Linear(triple_dim, image_dim)) 52 | self.criterion = nn.MSELoss(reduction='mean') 53 | 54 | self.threshold = threshold 55 | 56 | self.image_size = image_size 57 | self.image_dim = image_dim 58 | 59 | def forward(self, triple_per_img, image_feature): 60 | image_feature = image_feature.detach() 61 | triple_per_img = self.triple_mlp(triple_per_img) 62 | rec_loss = self.calculate_reconstruction_loss(triple_per_img, image_feature) 63 | return rec_loss 64 | 65 | def calculate_reconstruction_loss(self, triple_fea, img_fea): 66 | batch_size = triple_fea.shape[0] 67 | img_fea = img_fea.permute(0,2,3,1).contiguous() 68 | img_fea = img_fea.view(batch_size, self.image_size * self.image_size, self.image_dim) 69 | len_img_fea = img_fea.shape[1] 70 | assert len_img_fea == self.image_size * self.image_size 71 | len_triple_fea = triple_fea.shape[1] 72 | 73 | triple_mask = (torch.rand([batch_size, len_triple_fea, 1]) > self.threshold).float().to(device) 74 | masked_triple_fea = triple_mask * triple_fea 75 | 76 | rec_fea = self.reconstruct_missing_sg(masked_triple_fea, img_fea) 77 | gt_fea = torch.cat([triple_fea, img_fea], dim=1) 78 | 79 | valid_mask = torch.ones([batch_size, len_triple_fea + len_img_fea, 1]) * 1.0 80 | valid_mask = valid_mask.to(device) 81 | valid_mask[:, :len_triple_fea, :] = (torch.mean(triple_fea.detach(), dim=2, keepdim=True)).float() 82 | loss_mask = torch.ones([batch_size, len_triple_fea + len_img_fea, 1]) * 1.0 83 | loss_mask = loss_mask.to(device) 84 | loss_mask[:, :len_triple_fea, :] = triple_mask 85 | loss_mask = (1 - loss_mask) * valid_mask 86 | loss_mask.detach() 87 | 88 | rec_loss = self.criterion(rec_fea * loss_mask, gt_fea * loss_mask) 89 | return rec_loss 90 | 91 | def reconstruct_missing_sg(self, triple_fea, img_fea): 92 | 93 | input_fea = torch.cat([triple_fea, img_fea], dim=1) # [B, N+HW, C] 94 | input_fea = input_fea.permute(1, 0, 2).contiguous() 95 | rec_fea = self.transformer(input_fea) 96 | rec_fea = rec_fea.permute(1, 0, 2).contiguous() 97 | return rec_fea 98 | 99 | class SceneGraph2MakedImgLoss(nn.Module): 100 | def __init__(self, triple_dim, image_dim, sg_only, max_relationships_per_image=30, image_size=32, threshold=0.3): 101 | super().__init__() 102 | 103 | self.transformer = Transformer(width=image_dim, layers=2, heads=4) 104 | self.triple_mlp = nn.Sequential(nn.Linear(triple_dim, image_dim)) 105 | self.criterion = nn.MSELoss(reduction='mean') 106 | 107 | self.threshold = threshold 108 | 109 | self.image_size = image_size 110 | self.image_dim = image_dim 111 | 112 | if sg_only: 113 | self.register_buffer('attn_mask', self.build_attention_mask(tri_length=max_relationships_per_image, img_length=image_size * image_size), persistent=False) 114 | else: 115 | self.attn_mask = None 116 | 117 | def forward(self, triple_per_img, image_feature, gt_boxes, obj_to_img): 118 | triple_per_img = self.triple_mlp(triple_per_img) 119 
| rec_loss = self.calculate_reconstruction_loss(triple_per_img, image_feature, gt_boxes, obj_to_img) 120 | return rec_loss 121 | 122 | def build_box_mask(self, boxes_gt, obj_to_img, H, W=None, threshold=0.2): 123 | bbox_mask = boxes_to_mask(boxes_gt, obj_to_img, H, W, threshold) 124 | return bbox_mask 125 | 126 | def build_attention_mask(self, tri_length, img_length): 127 | total_length = tri_length + img_length 128 | mask = torch.empty(total_length, total_length) 129 | mask.fill_(1) 130 | mask[tri_length:, tri_length:].fill_(float("-inf")) 131 | return mask 132 | 133 | def calculate_reconstruction_loss(self, triple_fea, img_fea, boxes_gt, obj_to_img): 134 | batch_size = triple_fea.shape[0] 135 | 136 | img_fea = img_fea.permute(0, 2, 3, 1).contiguous() 137 | img_fea = img_fea.view(batch_size, self.image_size * self.image_size, self.image_dim) 138 | 139 | len_img_fea = img_fea.shape[1] 140 | assert len_img_fea == self.image_size * self.image_size 141 | len_triple_fea = triple_fea.shape[1] 142 | 143 | image_mask = self.build_box_mask(boxes_gt, obj_to_img, H=self.image_size, W=self.image_size, threshold=self.threshold) 144 | image_mask = image_mask.view(batch_size, -1, 1) 145 | masked_img_fea = image_mask * img_fea 146 | 147 | rec_fea = self.reconstruct_missing_patches(triple_fea, masked_img_fea) 148 | gt_fea = torch.cat([triple_fea, img_fea], dim=1) 149 | 150 | loss_mask = torch.ones([batch_size, len_triple_fea + len_img_fea, 1]) * 1.0 151 | loss_mask = loss_mask.to(device) 152 | loss_mask[:, len_triple_fea:, :] = image_mask 153 | loss_mask = 1 - loss_mask 154 | loss_mask.detach() 155 | 156 | rec_loss = self.criterion(rec_fea * loss_mask, gt_fea * loss_mask) 157 | return rec_loss 158 | 159 | def reconstruct_missing_patches(self, triple_fea, img_fea): 160 | input_fea = torch.cat([triple_fea, img_fea], dim=1) # [B, N+HW, C] 161 | input_fea = input_fea.permute(1, 0, 2).contiguous() 162 | rec_fea = self.transformer(input_fea, attn_mask=self.attn_mask) 163 | rec_fea = rec_fea.permute(1, 0, 2).contiguous() 164 | return rec_fea -------------------------------------------------------------------------------- /sg_image_pretraining/trainer.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | from datetime import datetime 5 | import numpy as np 6 | import torch 7 | from torch import optim 8 | from torch.cuda.amp import GradScaler 9 | import torch.utils.tensorboard as tensorboard 10 | from sgCLIP.create_model import create_model_and_transforms 11 | from training.distributed import is_master, init_distributed_device, world_info_from_env 12 | from training.logger import setup_logging 13 | from training.configs import parse_args 14 | from training.scheduler import cosine_lr 15 | from training.train_mim import train_one_epoch, validate_one_epoch#, evaluate 16 | from datasets.dataloader_builder import build_vg_loaders 17 | from global_var import * 18 | 19 | def random_seed(seed=42, rank=0): 20 | torch.manual_seed(seed + rank) 21 | np.random.seed(seed + rank) 22 | random.seed(seed + rank) 23 | 24 | 25 | def trainer(): 26 | args = parse_args() 27 | 28 | if torch.cuda.is_available(): 29 | torch.backends.cuda.matmul.allow_tf32 = True 30 | torch.backends.cudnn.benchmark = True 31 | torch.backends.cudnn.deterministic = False 32 | 33 | if args.name is None: 34 | args.name = '-'.join([ 35 | datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), 36 | f"lr_{args.lr}", 37 | f"b_{args.batch_size}", 38 | f"j_{args.workers}", 39 | 
f"p_{args.precision}", 40 | ]) 41 | 42 | args.distributed = False 43 | args.local_rank, args.rank, args.world_size = world_info_from_env() 44 | 45 | args.log_path = None 46 | if is_master(args, local=args.log_local): 47 | log_base_path = os.path.join(args.logs, args.name) 48 | os.makedirs(log_base_path, exist_ok=True) 49 | log_filename = f'out-{args.rank}' if args.log_local else 'out.log' 50 | args.log_path = os.path.join(log_base_path, log_filename) 51 | if os.path.exists(args.log_path): 52 | print( 53 | "Error. Experiment already exists. Use --name {} to specify a new experiment." 54 | ) 55 | return -1 56 | 57 | args.log_level = logging.DEBUG if args.debug else logging.INFO 58 | setup_logging(args.log_path, args.log_level) 59 | 60 | device = init_distributed_device(args) 61 | 62 | args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to 63 | if is_master(args): 64 | args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else '' 65 | args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") 66 | for dirname in [args.tensorboard_path, args.checkpoint_path]: 67 | if dirname: 68 | os.makedirs(dirname, exist_ok=True) 69 | else: 70 | args.tensorboard_path = '' 71 | args.checkpoint_path = '' 72 | 73 | assert args.precision in ['amp', 'amp_bfloat16', 'fp16', 'fp32'] 74 | if args.precision == 'fp16': 75 | logging.warning( 76 | 'It is recommended to use AMP mixed-precision instead of FP16. ' 77 | 'FP16 support needs further verification and tuning, especially for train.') 78 | 79 | if args.distributed: 80 | logging.info( 81 | f'Running in distributed mode with multiple processes. Device: {args.device}.' 82 | f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') 83 | else: 84 | logging.info(f'Running with a single process. 
Device {args.device}.') 85 | 86 | random_seed(args.seed, 0) 87 | 88 | graph_vocab, train_dataloader, val_dataloader = build_vg_loaders(args) 89 | 90 | model = create_model_and_transforms( 91 | args, 92 | graph_vocab, 93 | args.model_config_json, 94 | args.precision, 95 | device=device, 96 | force_quick_gelu=args.force_quick_gelu, 97 | pretrained_image=args.pretrained_image, 98 | image_mean=args.image_mean, 99 | image_std=args.image_std, 100 | ) 101 | 102 | random_seed(args.seed, args.rank) 103 | if is_master(args): 104 | logging.info("Model:") 105 | logging.info(f"{str(model)}") 106 | logging.info("Params:") 107 | params_file = os.path.join(args.logs, args.name, "params.txt") 108 | with open(params_file, "w") as f: 109 | for name in sorted(vars(args)): 110 | val = getattr(args, name) 111 | logging.info(f" {name}: {val}") 112 | f.write(f"{name}: {val}\n") 113 | 114 | if args.distributed: 115 | if args.use_bn_sync: 116 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 117 | ddp_args = {} 118 | if args.ddp_static_graph: 119 | ddp_args['static_graph'] = True 120 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) 121 | 122 | optimizer = None 123 | scaler = None 124 | if args.train_data: 125 | 126 | exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n 127 | include = lambda n, p: not exclude(n, p) 128 | 129 | named_parameters = list(model.named_parameters()) 130 | gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] 131 | rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] 132 | 133 | optimizer = optim.AdamW( 134 | [ 135 | {"params": gain_or_bias_params, "weight_decay": 0.}, 136 | {"params": rest_params, "weight_decay": args.wd}, 137 | ], 138 | lr=args.lr, 139 | betas=(args.beta1, args.beta2), 140 | eps=args.eps, 141 | ) 142 | 143 | scaler = GradScaler() if args.precision == "amp" else None 144 | 145 | start_epoch = 0 146 | total_steps = len(train_dataloader) * args.epochs 147 | scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) 148 | 149 | args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args) 150 | writer = None 151 | if args.save_logs and args.tensorboard: 152 | assert tensorboard is not None, "Please install tensorboard." 
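# The writer is only created here on the master process when 'tensorboard' (or 'all')
# is listed in --report_to; it is passed to train_one_epoch / validate_one_epoch below.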
153 | writer = tensorboard.SummaryWriter(args.tensorboard_path) 154 | 155 | logging.debug('Finished loading wandb.') 156 | 157 | for epoch in range(start_epoch, args.epochs): 158 | if is_master(args): 159 | logging.info(f'Start epoch {epoch}') 160 | 161 | train_one_epoch(model, train_dataloader, epoch, optimizer, scaler, scheduler, args, writer) 162 | with torch.no_grad(): 163 | validate_one_epoch(model, val_dataloader, epoch, args, writer) 164 | completed_epoch = epoch + 1 165 | 166 | if args.save_logs: 167 | checkpoint_dict = { 168 | "epoch": completed_epoch, 169 | "name": args.name, 170 | "state_dict": model.state_dict(), 171 | "optimizer": optimizer.state_dict(), 172 | } 173 | if scaler is not None: 174 | checkpoint_dict["scaler"] = scaler.state_dict() 175 | 176 | if completed_epoch == args.epochs or ( 177 | args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 178 | ): 179 | torch.save( 180 | checkpoint_dict, 181 | os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), 182 | ) 183 | if args.save_most_recent: 184 | torch.save( 185 | checkpoint_dict, 186 | os.path.join(args.checkpoint_path, f"epoch_latest.pt"), 187 | ) 188 | 189 | if __name__ == "__main__": 190 | trainer() 191 | -------------------------------------------------------------------------------- /sg_image_pretraining/training/configs_coco.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from dataclasses import dataclass 3 | from typing import Tuple, Union, List, Optional 4 | 5 | def parse_args(): 6 | parser = argparse.ArgumentParser() 7 | parser.add_argument("--coco_train_image_dir", type=str, default='/home/huangzl/Data/datasets/coco/images/train2017', help="Path to training dataset") 8 | parser.add_argument("--coco_val_image_dir", type=str, default='/home/huangzl/Data/datasets/coco/images/val2017', help="Path to training dataset") 9 | 10 | parser.add_argument("--coco_train_instances_json", type=str, default='/home/huangzl/Data/datasets/coco/annotations/instances_train2017.json', help="") 11 | parser.add_argument("--coco_train_stuff_json", type=str, default='/home/huangzl/Data/datasets/coco/annotations/stuff_train2017.json', help="") 12 | parser.add_argument("--coco_val_instances_json", type=str, default='/home/huangzl/Data/datasets/coco/annotations/instances_val2017.json', help="") 13 | parser.add_argument("--coco_val_stuff_json", type=str, default='/home/huangzl/Data/datasets/coco/annotations/stuff_val2017.json', help="") 14 | parser.add_argument("--coco_stuff_only", type=str, default='/home/huangzl/Data/datasets/vg/val.h5', help="") 15 | parser.add_argument("--image_size", type=int, default=224, help="Image size for training.") 16 | parser.add_argument("--mask_size", type=int, default=16, help="") 17 | parser.add_argument("--num_train_samples", type=int, default=None, help="") 18 | parser.add_argument("--num_val_samples", type=int, default=1024, help="") 19 | parser.add_argument("--include_relationships", type=bool, default=True, help="Use orphaned objects or not in the image.") 20 | parser.add_argument("--min_object_size", type=float, default=0.02, help="") 21 | parser.add_argument("--min_objects_per_image", type=int, default=3, help="") 22 | parser.add_argument("--max_objects_per_image", type=int, default=8, help="") 23 | parser.add_argument("--coco_include_other", type=bool, default=False, help="") 24 | parser.add_argument("--instance_whitelist", type=list, default=None, help="") 25 | parser.add_argument("--stuff_whitelist", 
type=list, default=None, help="") 26 | 27 | parser.add_argument("--batch_size", type=int, default=64, help="Batch size per GPU.") 28 | parser.add_argument("--val_batch_size", type=int, default=128, help="Batch size per GPU for Validation.") 29 | 30 | parser.add_argument("--model_config_json", type=str, default='', help="Path to json file of model configs.") 31 | parser.add_argument("--graph_width", type=int, default=512, help="Width of Graph Tower.") 32 | parser.add_argument("--num_graph_layer", type=int, default=5, help="Number of layers in Graph Tower.") 33 | parser.add_argument("--embed_dim", type=int, default=512, help="Dimension of embeddings.") 34 | 35 | parser.add_argument("--name", type=str, default=None, help="Optional identifier for the experiment when storing logs. Otherwise use current time.") 36 | parser.add_argument("--workers", type=int, default=1, help="Number of dataloader workers per GPU.") 37 | parser.add_argument("--epochs", type=int, default=100, help="Number of epochs to train for.") 38 | parser.add_argument("--lr", type=float, default=5.0e-4, help="Learning rate.") 39 | parser.add_argument("--beta1", type=float, default=0.9, help="Adam beta 1.") 40 | parser.add_argument("--beta2", type=float, default=0.999, help="Adam beta 2.") 41 | parser.add_argument("--eps", type=float, default=1.0e-8, help="Adam epsilon.") 42 | parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") 43 | parser.add_argument("--warmup", type=int, default=10000, help="Number of steps to warmup for.") 44 | parser.add_argument("--use_bn_sync", default=False, action="store_true", help="Whether to use batch norm sync.") 45 | parser.add_argument("--skip_scheduler", action="store_true", default=False, help="Use this flag to skip the learning rate decay.") 46 | parser.add_argument("--save_frequency", type=int, default=1, help="How often to save checkpoints. epoch level.") 47 | parser.add_argument("--save_most_recent", action="store_true", default=False, help="Always save the most recent model trained to epoch_latest.pt.") 48 | parser.add_argument("--logs", type=str, default="./logs/", help="Where to store tensorboard logs. 
Use None to avoid storing logs.") 49 | parser.add_argument("--log_local", action="store_true", default=False, help="log files on local master, otherwise global master only.") 50 | 51 | parser.add_argument("--val_frequency", type=int, default=1, help="How often to run evaluation with val data.") 52 | parser.add_argument("--precision", choices=["amp", "amp_bfloat16", "fp16", "fp32"], default="amp", help="Floating point precision.") 53 | parser.add_argument("--pretrained", default='', type=str, help="Use a pretrained CLIP model weights with the specified tag or file path.") 54 | parser.add_argument("--pretrained-image", default=False, action='store_true', help="Load imagenet pretrained weights for image tower backbone if available.") 55 | 56 | parser.add_argument("--lock_image", default=False, action='store_true', help="Lock full image tower by disabling gradients.") 57 | parser.add_argument("--lock_image_unlocked_groups", type=int, default=0, help="Leave last n image tower layer groups unlocked.") 58 | parser.add_argument("--lock_image_freeze_bn_stats", default=False, action='store_true', help="Freeze BatchNorm running stats in image tower for any locked layers.") 59 | parser.add_argument('--image_mean', type=float, nargs='+', default=None, metavar='MEAN', help='Override default image mean value of dataset') 60 | parser.add_argument('--image_std', type=float, nargs='+', default=None, metavar='STD', help='Override default image std deviation of of dataset') 61 | parser.add_argument("--grad_checkpointing", default=False, action='store_true', help="Enable gradient checkpointing.") 62 | parser.add_argument("--local_loss", default=False, action="store_true", help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)") 63 | parser.add_argument("--gather_with_grad", default=False, action="store_true", help="enable full distributed gradient for feature gather") 64 | parser.add_argument("--force_quick_gelu", default=False, action='store_true', help="Force use of QuickGELU activation for non-OpenAI transformer models.") 65 | 66 | parser.add_argument("--dist_url", default="env://", type=str, help="url used to set up distributed training") 67 | parser.add_argument("--dist_backend", default="nccl", type=str, help="distributed backend") 68 | parser.add_argument("--report_to", default='', type=str, help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']") 69 | parser.add_argument("--debug", default=False, action="store_true", help="If true, more information is logged.") 70 | parser.add_argument("--ddp_static_graph", default=False, action='store_true', help="Enable static graph optimization for DDP in PyTorch >= 1.11.") 71 | parser.add_argument("--no_set_device_rank", default=False, action="store_true", help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).") 72 | parser.add_argument("--seed", type=int, default=9768, help="Default random seed.") 73 | parser.add_argument("--norm_gradient_clip", type=float, default=None, help="Gradient clip.") 74 | args = parser.parse_args() 75 | 76 | return args 77 | 78 | @dataclass 79 | class CLIPVisionCfg: 80 | layers: Union[Tuple[int, int, int, int], int] 81 | width: int 82 | head_width: int 83 | image_size: int 84 | mlp_ratio: float 85 | patch_size: int = None 86 | timm_model_name: str = None 87 | timm_model_pretrained: bool = None 88 | timm_pool: str = None 89 | timm_proj: str = None 90 | 91 | 92 | @dataclass 93 | class CLIPGraphCfg: 94 | layers: int 95 | width: int 
-------------------------------------------------------------------------------- /sg_image_pretraining/trainer_coco.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | import random 4 | from datetime import datetime 5 | import numpy as np 6 | import torch 7 | from torch import optim 8 | from torch.cuda.amp import GradScaler 9 | import torch.utils.tensorboard as tensorboard 10 | from sgCLIP.create_model_coco import create_model_and_transforms 11 | from training.distributed import is_master, init_distributed_device, world_info_from_env 12 | from training.logger import setup_logging 13 | from training.configs_coco import parse_args 14 | from training.scheduler import cosine_lr 15 | from training.train_mim import train_one_epoch, validate_one_epoch#, evaluate 16 | from datasets.dataloader_builder_coco import build_coco_loaders 17 | from global_var import * 18 | 19 | def random_seed(seed=42, rank=0): 20 | torch.manual_seed(seed + rank) 21 | np.random.seed(seed + rank) 22 | random.seed(seed + rank) 23 | 24 | 25 | def trainer(): 26 | args = parse_args() 27 | 28 | if torch.cuda.is_available(): 29 | torch.backends.cuda.matmul.allow_tf32 = True 30 | torch.backends.cudnn.benchmark = True 31 | torch.backends.cudnn.deterministic = False 32 | 33 | if args.name is None: 34 | args.name = '-'.join([ 35 | datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), 36 | f"lr_{args.lr}", 37 | f"b_{args.batch_size}", 38 | f"j_{args.workers}", 39 | f"p_{args.precision}", 40 | ]) 41 | 42 | args.distributed = False 43 | args.local_rank, args.rank, args.world_size = world_info_from_env() 44 | 45 | args.log_path = None 46 | if is_master(args, local=args.log_local): 47 | log_base_path = os.path.join(args.logs, args.name) 48 | os.makedirs(log_base_path, exist_ok=True) 49 | log_filename = f'out-{args.rank}' if args.log_local else 'out.log' 50 | args.log_path = os.path.join(log_base_path, log_filename) 51 | if os.path.exists(args.log_path): 52 | print( 53 | "Error. Experiment already exists. Use --name {} to specify a new experiment." 54 | ) 55 | return -1 56 | 57 | args.log_level = logging.DEBUG if args.debug else logging.INFO 58 | setup_logging(args.log_path, args.log_level) 59 | 60 | device = init_distributed_device(args) 61 | 62 | args.tensorboard = 'tensorboard' in args.report_to or 'all' in args.report_to 63 | if is_master(args): 64 | args.tensorboard_path = os.path.join(args.logs, args.name, "tensorboard") if args.tensorboard else '' 65 | args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") 66 | for dirname in [args.tensorboard_path, args.checkpoint_path]: 67 | if dirname: 68 | os.makedirs(dirname, exist_ok=True) 69 | else: 70 | args.tensorboard_path = '' 71 | args.checkpoint_path = '' 72 | 73 | assert args.precision in ['amp', 'amp_bfloat16', 'fp16', 'fp32'] 74 | if args.precision == 'fp16': 75 | logging.warning( 76 | 'It is recommended to use AMP mixed-precision instead of FP16. ' 77 | 'FP16 support needs further verification and tuning, especially for train.') 78 | 79 | if args.distributed: 80 | logging.info( 81 | f'Running in distributed mode with multiple processes. Device: {args.device}.' 82 | f'Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}.') 83 | else: 84 | logging.info(f'Running with a single process. 
Device {args.device}.') 85 | 86 | random_seed(args.seed, 0) 87 | 88 | graph_vocab, train_dataloader, val_dataloader = build_coco_loaders(args) 89 | 90 | model = create_model_and_transforms( 91 | args, 92 | graph_vocab, 93 | args.model_config_json, 94 | args.precision, 95 | device=device, 96 | force_quick_gelu=args.force_quick_gelu, 97 | pretrained_image=args.pretrained_image, 98 | image_mean=args.image_mean, 99 | image_std=args.image_std, 100 | ) 101 | 102 | random_seed(args.seed, args.rank) 103 | if is_master(args): 104 | logging.info("Model:") 105 | logging.info(f"{str(model)}") 106 | logging.info("Params:") 107 | params_file = os.path.join(args.logs, args.name, "params.txt") 108 | with open(params_file, "w") as f: 109 | for name in sorted(vars(args)): 110 | val = getattr(args, name) 111 | logging.info(f" {name}: {val}") 112 | f.write(f"{name}: {val}\n") 113 | 114 | if args.distributed: 115 | if args.use_bn_sync: 116 | model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) 117 | ddp_args = {} 118 | if args.ddp_static_graph: 119 | ddp_args['static_graph'] = True 120 | model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[device], **ddp_args) 121 | 122 | optimizer = None 123 | scaler = None 124 | if args.coco_train_image_dir: 125 | 126 | exclude = lambda n, p: p.ndim < 2 or "bn" in n or "ln" in n or "bias" in n or 'logit_scale' in n 127 | include = lambda n, p: not exclude(n, p) 128 | 129 | named_parameters = list(model.named_parameters()) 130 | gain_or_bias_params = [p for n, p in named_parameters if exclude(n, p) and p.requires_grad] 131 | rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] 132 | 133 | optimizer = optim.AdamW( 134 | [ 135 | {"params": gain_or_bias_params, "weight_decay": 0.}, 136 | {"params": rest_params, "weight_decay": args.wd}, 137 | ], 138 | lr=args.lr, 139 | betas=(args.beta1, args.beta2), 140 | eps=args.eps, 141 | ) 142 | 143 | scaler = GradScaler() if args.precision == "amp" else None 144 | 145 | start_epoch = 0 146 | total_steps = len(train_dataloader) * args.epochs 147 | scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) 148 | 149 | args.save_logs = args.logs and args.logs.lower() != 'none' and is_master(args) 150 | writer = None 151 | if args.save_logs and args.tensorboard: 152 | assert tensorboard is not None, "Please install tensorboard." 
153 | writer = tensorboard.SummaryWriter(args.tensorboard_path) 154 | 155 | logging.debug('Finished loading wandb.') 156 | 157 | for epoch in range(start_epoch, args.epochs): 158 | if is_master(args): 159 | logging.info(f'Start epoch {epoch}') 160 | 161 | train_one_epoch(model, train_dataloader, epoch, optimizer, scaler, scheduler, args, writer) 162 | with torch.no_grad(): 163 | validate_one_epoch(model, val_dataloader, epoch, args, writer) 164 | completed_epoch = epoch + 1 165 | 166 | if args.save_logs: 167 | checkpoint_dict = { 168 | "epoch": completed_epoch, 169 | "name": args.name, 170 | "state_dict": model.state_dict(), 171 | "optimizer": optimizer.state_dict(), 172 | } 173 | if scaler is not None: 174 | checkpoint_dict["scaler"] = scaler.state_dict() 175 | 176 | if completed_epoch == args.epochs or ( 177 | args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 178 | ): 179 | torch.save( 180 | checkpoint_dict, 181 | os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), 182 | ) 183 | if args.save_most_recent: 184 | torch.save( 185 | checkpoint_dict, 186 | os.path.join(args.checkpoint_path, f"epoch_latest.pt"), 187 | ) 188 | 189 | if __name__ == "__main__": 190 | trainer() 191 | -------------------------------------------------------------------------------- /sg2im/data/vg.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # 3 | # Copyright 2018 Google LLC 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 | # See the License for the specific language governing permissions and 15 | # limitations under the License. 
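# Visual Genome scene-graph dataset utilities: VgSceneGraphDataset yields
# (image, objs, boxes, triples) samples from the preprocessed HDF5 files, and
# vg_collate_fn / vg_uncollate_fn batch and un-batch variable numbers of objects
# and relationship triples per image.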
16 | 17 | import os 18 | import random 19 | from collections import defaultdict 20 | 21 | import torch 22 | from torch.utils.data import Dataset 23 | import torchvision.transforms as T 24 | 25 | import numpy as np 26 | import h5py 27 | import PIL 28 | 29 | from .utils import imagenet_preprocess, Resize 30 | 31 | 32 | class VgSceneGraphDataset(Dataset): 33 | def __init__(self, vocab, h5_path, image_dir, image_size=(256, 256), 34 | normalize_images=True, max_objects=10, max_samples=None, 35 | include_relationships=True, use_orphaned_objects=True): 36 | super(VgSceneGraphDataset, self).__init__() 37 | 38 | self.image_dir = image_dir 39 | self.image_size = image_size 40 | self.vocab = vocab 41 | self.num_objects = len(vocab['object_idx_to_name']) 42 | self.use_orphaned_objects = use_orphaned_objects 43 | self.max_objects = max_objects 44 | self.max_samples = max_samples 45 | self.include_relationships = include_relationships 46 | 47 | transform = [Resize(image_size), T.ToTensor()] 48 | if normalize_images: 49 | transform.append(imagenet_preprocess()) 50 | self.transform = T.Compose(transform) 51 | 52 | self.data = {} 53 | with h5py.File(h5_path, 'r') as f: 54 | for k, v in f.items(): 55 | if k == 'image_paths': 56 | self.image_paths = list(v) 57 | else: 58 | self.data[k] = torch.IntTensor(np.asarray(v)) 59 | 60 | def __len__(self): 61 | num = self.data['object_names'].size(0) 62 | if self.max_samples is not None: 63 | return min(self.max_samples, num) 64 | return num 65 | 66 | def __getitem__(self, index): 67 | """ 68 | Returns a tuple of: 69 | - image: FloatTensor of shape (C, H, W) 70 | - objs: LongTensor of shape (O,) 71 | - boxes: FloatTensor of shape (O, 4) giving boxes for objects in 72 | (x0, y0, x1, y1) format, in a [0, 1] coordinate system. 73 | - triples: LongTensor of shape (T, 3) where triples[t] = [i, p, j] 74 | means that (objs[i], p, objs[j]) is a triple. 
75 | """ 76 | img_path = os.path.join(self.image_dir, self.image_paths[index]) 77 | 78 | with open(img_path, 'rb') as f: 79 | with PIL.Image.open(f) as image: 80 | WW, HH = image.size 81 | image = self.transform(image.convert('RGB')) 82 | 83 | H, W = self.image_size 84 | 85 | # Figure out which objects appear in relationships and which don't 86 | obj_idxs_with_rels = set() 87 | obj_idxs_without_rels = set(range(self.data['objects_per_image'][index].item())) 88 | for r_idx in range(self.data['relationships_per_image'][index]): 89 | s = self.data['relationship_subjects'][index, r_idx].item() 90 | o = self.data['relationship_objects'][index, r_idx].item() 91 | obj_idxs_with_rels.add(s) 92 | obj_idxs_with_rels.add(o) 93 | obj_idxs_without_rels.discard(s) 94 | obj_idxs_without_rels.discard(o) 95 | 96 | obj_idxs = list(obj_idxs_with_rels) 97 | obj_idxs_without_rels = list(obj_idxs_without_rels) 98 | if len(obj_idxs) > self.max_objects - 1: 99 | obj_idxs = random.sample(obj_idxs, self.max_objects) 100 | if len(obj_idxs) < self.max_objects - 1 and self.use_orphaned_objects: 101 | num_to_add = self.max_objects - 1 - len(obj_idxs) 102 | num_to_add = min(num_to_add, len(obj_idxs_without_rels)) 103 | obj_idxs += random.sample(obj_idxs_without_rels, num_to_add) 104 | O = len(obj_idxs) + 1 105 | 106 | objs = torch.LongTensor(O).fill_(-1) 107 | 108 | boxes = torch.FloatTensor([[0, 0, 1, 1]]).repeat(O, 1) 109 | obj_idx_mapping = {} 110 | for i, obj_idx in enumerate(obj_idxs): 111 | objs[i] = self.data['object_names'][index, obj_idx].item() 112 | x, y, w, h = self.data['object_boxes'][index, obj_idx].tolist() 113 | x0 = float(x) / WW 114 | y0 = float(y) / HH 115 | x1 = float(x + w) / WW 116 | y1 = float(y + h) / HH 117 | boxes[i] = torch.FloatTensor([x0, y0, x1, y1]) 118 | obj_idx_mapping[obj_idx] = i 119 | 120 | # The last object will be the special __image__ object 121 | objs[O - 1] = self.vocab['object_name_to_idx']['__image__'] 122 | 123 | triples = [] 124 | for r_idx in range(self.data['relationships_per_image'][index].item()): 125 | if not self.include_relationships: 126 | break 127 | s = self.data['relationship_subjects'][index, r_idx].item() 128 | p = self.data['relationship_predicates'][index, r_idx].item() 129 | o = self.data['relationship_objects'][index, r_idx].item() 130 | s = obj_idx_mapping.get(s, None) 131 | o = obj_idx_mapping.get(o, None) 132 | if s is not None and o is not None: 133 | triples.append([s, p, o]) 134 | 135 | # Add dummy __in_image__ relationships for all objects 136 | in_image = self.vocab['pred_name_to_idx']['__in_image__'] 137 | for i in range(O - 1): 138 | triples.append([i, in_image, O - 1]) 139 | 140 | triples = torch.LongTensor(triples) 141 | return image, objs, boxes, triples 142 | 143 | 144 | def vg_collate_fn(batch): 145 | """ 146 | Collate function to be used when wrapping a VgSceneGraphDataset in a 147 | DataLoader. Returns a tuple of the following: 148 | 149 | - imgs: FloatTensor of shape (N, C, H, W) 150 | - objs: LongTensor of shape (O,) giving categories for all objects 151 | - boxes: FloatTensor of shape (O, 4) giving boxes for all objects 152 | - triples: FloatTensor of shape (T, 3) giving all triples, where 153 | triples[t] = [i, p, j] means that [objs[i], p, objs[j]] is a triple 154 | - obj_to_img: LongTensor of shape (O,) mapping objects to images; 155 | obj_to_img[i] = n means that objs[i] belongs to imgs[n] 156 | - triple_to_img: LongTensor of shape (T,) mapping triples to images; 157 | triple_to_img[t] = n means that triples[t] belongs to imgs[n]. 
158 | """ 159 | # batch is a list, and each element is (image, objs, boxes, triples) 160 | all_imgs, all_objs, all_boxes, all_triples = [], [], [], [] 161 | all_obj_to_img, all_triple_to_img = [], [] 162 | obj_offset = 0 163 | for i, (img, objs, boxes, triples) in enumerate(batch): 164 | all_imgs.append(img[None]) 165 | O, T = objs.size(0), triples.size(0) 166 | all_objs.append(objs) 167 | all_boxes.append(boxes) 168 | triples = triples.clone() 169 | triples[:, 0] += obj_offset 170 | triples[:, 2] += obj_offset 171 | all_triples.append(triples) 172 | 173 | all_obj_to_img.append(torch.LongTensor(O).fill_(i)) 174 | all_triple_to_img.append(torch.LongTensor(T).fill_(i)) 175 | obj_offset += O 176 | 177 | all_imgs = torch.cat(all_imgs) 178 | all_objs = torch.cat(all_objs) 179 | all_boxes = torch.cat(all_boxes) 180 | all_triples = torch.cat(all_triples) 181 | all_obj_to_img = torch.cat(all_obj_to_img) 182 | all_triple_to_img = torch.cat(all_triple_to_img) 183 | 184 | out = (all_imgs, all_objs, all_boxes, all_triples, 185 | all_obj_to_img, all_triple_to_img) 186 | return out 187 | 188 | 189 | def vg_uncollate_fn(batch): 190 | """ 191 | Inverse operation to the above. 192 | """ 193 | imgs, objs, boxes, triples, obj_to_img, triple_to_img = batch 194 | out = [] 195 | obj_offset = 0 196 | for i in range(imgs.size(0)): 197 | cur_img = imgs[i] 198 | o_idxs = (obj_to_img == i).nonzero().view(-1) 199 | t_idxs = (triple_to_img == i).nonzero().view(-1) 200 | cur_objs = objs[o_idxs] 201 | cur_boxes = boxes[o_idxs] 202 | cur_triples = triples[t_idxs].clone() 203 | cur_triples[:, 0] -= obj_offset 204 | cur_triples[:, 2] -= obj_offset 205 | obj_offset += cur_objs.size(0) 206 | out.append((cur_img, cur_objs, cur_boxes, cur_triples)) 207 | return out 208 | 209 | -------------------------------------------------------------------------------- /sg_image_pretraining/utils.py: -------------------------------------------------------------------------------- 1 | from itertools import repeat 2 | import collections.abc 3 | import torchvision.transforms as T 4 | from torch import nn as nn 5 | from torchvision.ops.misc import FrozenBatchNorm2d 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | import copy 10 | from global_var import * 11 | 12 | IMAGENET_MEAN = [0.485, 0.456, 0.406] 13 | IMAGENET_STD = [0.229, 0.224, 0.225] 14 | 15 | INV_IMAGENET_MEAN = [-m for m in IMAGENET_MEAN] 16 | INV_IMAGENET_STD = [1.0 / s for s in IMAGENET_STD] 17 | 18 | 19 | def imagenet_preprocess(): 20 | return T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD) 21 | def generate_box_mask(boxes_gt, obj_to_img, H, W=None, threshold=0.2): 22 | if W is None: 23 | W = H 24 | bbox_mask = boxes_to_mask(boxes_gt, obj_to_img, H, W, threshold) 25 | return bbox_mask 26 | 27 | class HookTool: 28 | def __init__(self): 29 | self.extract_fea_in = None 30 | self.extract_fea = None 31 | 32 | def hook_fun(self, module, fea_in, fea_out): 33 | self.extract_fea_in = fea_in 34 | self.extract_fea = fea_out 35 | 36 | 37 | def get_linear_feas_by_hook(model): 38 | fea_hooks = [] 39 | for n, m in model.named_modules(): 40 | if isinstance(m, torch.nn.Linear): 41 | cur_hook = HookTool() 42 | m.register_forward_hook(cur_hook.hook_fun) 43 | fea_hooks.append(cur_hook) 44 | 45 | return fea_hooks 46 | 47 | def boxes_to_mask(boxes, obj_to_img, H, W=None, threshold=0.2): 48 | O = obj_to_img.size()[0] 49 | if W is None: 50 | W = H 51 | 52 | grid = _boxes_to_grid(boxes, H, W) 53 | x_rand = torch.rand([O, 1]) 54 | mask_indicator = (x_rand > 
threshold).float() 55 | mask_in = mask_indicator.view(O, 1, 1, 1).expand(O, 1, 8, 8) 56 | mask_in = mask_in.to(device) 57 | sampled = F.grid_sample(mask_in, grid) 58 | sampled = (sampled > 0).float().to(device) 59 | 60 | out = assign_mask_to_img(sampled, obj_to_img) 61 | out = out.to(device) 62 | mask = 1.0 - out 63 | 64 | return mask.to(device) 65 | 66 | 67 | def assign_mask_to_img(samples, obj_to_img): 68 | dtype, device = samples.dtype, samples.device 69 | O, D, H, W = samples.size() 70 | N = obj_to_img.data.max().item() + 1 71 | 72 | out = torch.zeros(N, D, H, W, dtype=dtype, device=device) 73 | idx = obj_to_img.view(O, 1, 1, 1).expand(O, D, H, W) 74 | out = out.scatter_add(0, idx, samples) 75 | 76 | return out 77 | 78 | def _boxes_to_grid(boxes, H, W): 79 | O = boxes.size(0) 80 | 81 | boxes = boxes.view(O, 4, 1, 1) 82 | 83 | x0, y0 = boxes[:, 0], boxes[:, 1] 84 | x1, y1 = boxes[:, 2], boxes[:, 3] 85 | ww = x1 - x0 86 | hh = y1 - y0 87 | 88 | X = torch.linspace(0, 1, steps=W).view(1, 1, W).to(boxes) 89 | Y = torch.linspace(0, 1, steps=H).view(1, H, 1).to(boxes) 90 | 91 | X = (X - x0) / ww 92 | Y = (Y - y0) / hh 93 | 94 | X = X.expand(O, H, W) 95 | Y = Y.expand(O, H, W) 96 | grid = torch.stack([X, Y], dim=3) 97 | 98 | grid = grid.mul(2).sub(1) 99 | 100 | return grid 101 | 102 | def create_tensor_by_assign_samples_to_img(samples, sample_to_img, max_sample_per_img, batch_size): 103 | dtype, device = samples.dtype, samples.device 104 | N = batch_size 105 | D = samples.shape[1] 106 | assert (sample_to_img.max() + 1) == N 107 | 108 | samples_per_img = [] 109 | for i in range(N): 110 | s_idxs = (sample_to_img == i).nonzero().view(-1) 111 | sub_sample = samples[s_idxs] 112 | len_cur = sub_sample.shape[0] 113 | if len_cur > max_sample_per_img: 114 | sub_sample = sub_sample[:max_sample_per_img, :] 115 | if len_cur < max_sample_per_img: 116 | zero_vector = torch.zeros([1, D]).to(device) 117 | padding_vectors = torch.cat([copy.deepcopy(zero_vector) for _ in range(max_sample_per_img - len_cur)], dim=0) # [res, D] 118 | sub_sample = torch.cat([sub_sample, padding_vectors], dim=0) 119 | sub_sample = sub_sample.unsqueeze(0) 120 | samples_per_img.append(sub_sample) 121 | samples_per_img = torch.cat(samples_per_img, dim=0).to(device) 122 | 123 | return samples_per_img 124 | 125 | def idx_to_one_hot(idx, num_classes): 126 | result = F.one_hot(idx, num_classes) 127 | result = result.float().to(device) 128 | return result 129 | 130 | 131 | def freeze_batch_norm_2d(module, module_match={}, name=''): 132 | res = module 133 | is_match = True 134 | if module_match: 135 | is_match = name in module_match 136 | if is_match and isinstance(module, (nn.modules.batchnorm.BatchNorm2d, nn.modules.batchnorm.SyncBatchNorm)): 137 | res = FrozenBatchNorm2d(module.num_features) 138 | res.num_features = module.num_features 139 | res.affine = module.affine 140 | if module.affine: 141 | res.weight.data = module.weight.data.clone().detach() 142 | res.bias.data = module.bias.data.clone().detach() 143 | res.running_mean.data = module.running_mean.data 144 | res.running_var.data = module.running_var.data 145 | res.eps = module.eps 146 | else: 147 | for child_name, child in module.named_children(): 148 | full_child_name = '.'.join([name, child_name]) if name else child_name 149 | new_child = freeze_batch_norm_2d(child, module_match, full_child_name) 150 | if new_child is not child: 151 | res.add_module(child_name, new_child) 152 | return res 153 | 154 | 155 | def _ntuple(n): 156 | def parse(x): 157 | if isinstance(x, 
collections.abc.Iterable): 158 | return x 159 | return tuple(repeat(x, n)) 160 | return parse 161 | 162 | 163 | to_1tuple = _ntuple(1) 164 | to_2tuple = _ntuple(2) 165 | to_3tuple = _ntuple(3) 166 | to_4tuple = _ntuple(4) 167 | to_ntuple = lambda n, x: _ntuple(n)(x) 168 | 169 | 170 | def boxes_to_layout(vecs, boxes, obj_to_img, H, W=None, pooling='sum'): 171 | 172 | O, D = vecs.size() 173 | if W is None: 174 | W = H 175 | 176 | grid = _boxes_to_grid(boxes, H, W) 177 | 178 | img_in = vecs.view(O, D, 1, 1).expand(O, D, 8, 8) 179 | sampled = F.grid_sample(img_in, grid) # (O, D, H, W) 180 | 181 | out = _pool_samples(sampled, obj_to_img, pooling=pooling) 182 | 183 | return out 184 | 185 | 186 | def masks_to_layout(vecs, boxes, masks, obj_to_img, H, W=None, pooling='sum'): 187 | O, D = vecs.size() 188 | M = masks.size(1) 189 | assert masks.size() == (O, M, M) 190 | if W is None: 191 | W = H 192 | 193 | grid = _boxes_to_grid(boxes, H, W) 194 | 195 | img_in = vecs.view(O, D, 1, 1) * masks.float().view(O, 1, M, M) 196 | sampled = F.grid_sample(img_in, grid) 197 | 198 | out = _pool_samples(sampled, obj_to_img, pooling=pooling) 199 | return out 200 | 201 | 202 | def _boxes_to_grid(boxes, H, W): 203 | O = boxes.size(0) 204 | 205 | boxes = boxes.view(O, 4, 1, 1) 206 | 207 | x0, y0 = boxes[:, 0], boxes[:, 1] 208 | x1, y1 = boxes[:, 2], boxes[:, 3] 209 | ww = x1 - x0 210 | hh = y1 - y0 211 | 212 | X = torch.linspace(0, 1, steps=W).view(1, 1, W).to(boxes) 213 | Y = torch.linspace(0, 1, steps=H).view(1, H, 1).to(boxes) 214 | 215 | X = (X - x0) / ww 216 | Y = (Y - y0) / hh 217 | 218 | X = X.expand(O, H, W) 219 | Y = Y.expand(O, H, W) 220 | grid = torch.stack([X, Y], dim=3) 221 | 222 | grid = grid.mul(2).sub(1) 223 | 224 | return grid 225 | 226 | 227 | def _pool_samples(samples, obj_to_img, pooling='sum'): 228 | dtype, device = samples.dtype, samples.device 229 | O, D, H, W = samples.size() 230 | N = obj_to_img.data.max().item() + 1 231 | 232 | out = torch.zeros(N, D, H, W, dtype=dtype, device=device) 233 | idx = obj_to_img.view(O, 1, 1, 1).expand(O, D, H, W) 234 | out = out.scatter_add(0, idx, samples) 235 | 236 | if pooling == 'avg': 237 | ones = torch.ones(O, dtype=dtype, device=device) 238 | obj_counts = torch.zeros(N, dtype=dtype, device=device) 239 | obj_counts = obj_counts.scatter_add(0, obj_to_img, ones) 240 | print(obj_counts) 241 | obj_counts = obj_counts.clamp(min=1) 242 | out = out / obj_counts.view(N, 1, 1, 1) 243 | elif pooling != 'sum': 244 | raise ValueError('Invalid pooling "%s"' % pooling) 245 | 246 | return out -------------------------------------------------------------------------------- /ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. 
+ logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 
| codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 
| def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = 
torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... -> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 
224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /sg_image_pretraining/training/train_mim.py: -------------------------------------------------------------------------------- 1 | import json 2 | import logging 3 | import math 4 | import os 5 | import time 6 | import numpy as np 7 | import torch 8 | import torch.nn.functional as F 9 | from sgCLIP.contrastive_losses import ClipLoss 10 | from sgCLIP.generative_loss import ReconstractMaskedImageFromSceneGraphLoss 11 | from training.distributed import is_master 12 | from training.precision import get_autocast 13 | from utils import generate_box_mask 14 | from global_var import * 15 | 16 | 17 | class AverageMeter(object): 18 | def __init__(self): 19 | self.reset() 20 | 21 | def reset(self): 22 | self.val = 0 23 | self.avg = 0 24 | self.sum = 0 25 | self.count = 0 26 | 27 | def update(self, val, n=1): 28 | self.val = val 29 | self.sum += val * n 30 | self.count += n 31 | self.avg = self.sum / self.count 32 | 33 | 34 | def train_one_epoch(model, dataloader, epoch, optimizer, scaler, scheduler, args, tb_writer=None): 35 | device = torch.device(args.device) 36 | autocast = get_autocast(args.precision) 37 | 38 | model.train() 39 | clip_loss = ClipLoss( 40 | local_loss=args.local_loss, 41 | gather_with_grad=args.gather_with_grad, 42 | cache_labels=True, 43 | rank=args.rank, 44 | world_size=args.world_size) 45 | 46 | mim_loss = ReconstractMaskedImageFromSceneGraphLoss( 47 | triple_dim=1536, 48 | # image_dim=3072, 49 | image_dim=768, 50 | num_img_patches=50, 51 | num_triple=15, 52 | sg_only=False 53 | ) 54 | mim_loss = mim_loss.to(device) 55 | 56 | num_batches_per_epoch = len(dataloader) 57 | sample_digits = math.ceil(math.log(num_batches_per_epoch + 1, 10)) 58 | 59 | total_loss_m = AverageMeter() 60 | c_loss_m = AverageMeter() 61 | g_loss_m = AverageMeter() 62 | batch_time_m = AverageMeter() 63 | data_time_m = AverageMeter() 64 | end = time.time() 65 | for i, batch in enumerate(dataloader): 66 | step = num_batches_per_epoch * epoch + i 67 | scheduler(step) 68 | 69 | images, objects, boxes, triples, obj_to_img, triple_to_img = [x.to(device=device, non_blocking=True) for x in batch] 70 | graphs = [objects, boxes, triples, obj_to_img, triple_to_img] 71 | 72 | data_time_m.update(time.time() - end) 73 | optimizer.zero_grad() 74 | 75 | with autocast(): 76 | 
local_gt_image_feature, local_graph_features, norm_global_gt_image_features, norm_global_graph_features, logit_scale = \ 77 | model(images, graphs) 78 | 79 | batch_size, _, H, W = images.shape 80 | box_mask_for_img = generate_box_mask(boxes_gt=boxes, obj_to_img=obj_to_img, H=H, W=W, threshold=0.2) 81 | masked_images = images * box_mask_for_img.to(device) 82 | with torch.no_grad(): 83 | local_masked_image_feature, _, norm_global_masked_image_features, _, _ = model(masked_images.detach(), graphs) 84 | 85 | c_loss = clip_loss(norm_global_gt_image_features, norm_global_graph_features, logit_scale) 86 | g_loss = mim_loss(local_graph_features, local_masked_image_feature, local_gt_image_feature) 87 | total_loss = c_loss + g_loss 88 | 89 | if scaler is not None: 90 | scaler.scale(total_loss).backward() 91 | 92 | if args.norm_gradient_clip is not None: 93 | scaler.unscale_(optimizer) 94 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0) 95 | scaler.step(optimizer) 96 | scaler.update() 97 | else: 98 | total_loss.backward() 99 | if args.norm_gradient_clip is not None: 100 | torch.nn.utils.clip_grad_norm_(model.parameters(), args.norm_gradient_clip, norm_type=2.0) 101 | optimizer.step() 102 | 103 | with torch.no_grad(): 104 | unwrap_model(model).logit_scale.clamp_(0, math.log(100)) 105 | 106 | batch_time_m.update(time.time() - end) 107 | end = time.time() 108 | batch_count = i + 1 109 | if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): 110 | batch_size = len(images) 111 | num_samples = batch_count * batch_size * args.world_size 112 | samples_per_epoch = dataloader.num_samples 113 | percent_complete = 100.0 * batch_count / num_batches_per_epoch 114 | 115 | total_loss_m.update(total_loss.item(), batch_size) 116 | c_loss_m.update(total_loss.item(), batch_size) 117 | g_loss_m.update(total_loss.item(), batch_size) 118 | logit_scale_scalar = logit_scale.item() 119 | logging.info( 120 | f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " 121 | f"Total Loss: {total_loss_m.val:#.5g} ({total_loss_m.avg:#.4g}) " 122 | f"Contras Loss: {c_loss_m.val:#.5g} ({c_loss_m.avg:#.4g}) " 123 | f"Gen Loss: {g_loss_m.val:#.5g} ({g_loss_m.avg:#.4g}) " 124 | f"Data (t): {data_time_m.avg:.3f} " 125 | f"Batch (t): {batch_time_m.avg:.3f}, {args.batch_size * args.world_size / batch_time_m.val:#g}/s " 126 | f"LR: {optimizer.param_groups[0]['lr']:5f} " 127 | f"Logit Scale: {logit_scale_scalar:.3f}" 128 | ) 129 | 130 | log_data = { 131 | "total loss": total_loss_m.val, 132 | "contras loss": c_loss_m.val, 133 | "gen loss": g_loss_m.val, 134 | "data_time": data_time_m.val, 135 | "batch_time": batch_time_m.val, 136 | "samples_per_scond": args.batch_size * args.world_size / batch_time_m.val, 137 | "scale": logit_scale_scalar, 138 | "lr": optimizer.param_groups[0]["lr"] 139 | } 140 | for name, val in log_data.items(): 141 | name = "train/" + name 142 | if tb_writer is not None: 143 | tb_writer.add_scalar(name, val, step) 144 | 145 | batch_time_m.reset() 146 | data_time_m.reset() 147 | 148 | 149 | def validate_one_epoch(model, dataloader, epoch, args, tb_writer=None): 150 | device = torch.device(args.device) 151 | autocast = get_autocast(args.precision) 152 | 153 | model.eval() 154 | 155 | total_acc_g2i = 0. 156 | total_acc_i2g = 0. 
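# Running sums of per-batch graph-to-image and image-to-graph retrieval
# accuracies; they are averaged over batch_count once the validation loop finishes.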
157 | batch_count = 0 158 | batch_size = 0 159 | for i, batch in enumerate(dataloader): 160 | batch_count += 1 161 | 162 | images, objects, boxes, triples, obj_to_img, triple_to_img = [x.to(device=device, non_blocking=True) for x in 163 | batch] 164 | graphs = [objects, boxes, triples, obj_to_img, triple_to_img] 165 | 166 | with autocast(): 167 | local_gt_image_feature, local_graph_features, norm_global_gt_image_features, norm_global_graph_features, logit_scale = \ 168 | model(images, graphs) 169 | 170 | batch_size = images.shape[0] 171 | assert batch_size > 1 172 | 173 | acc_g2i, acc_i2g = validate_acc(norm_global_gt_image_features, norm_global_graph_features, device) 174 | total_acc_g2i += acc_g2i.item() 175 | total_acc_i2g += acc_i2g.item() 176 | 177 | avg_acc_g2i = total_acc_g2i / batch_count 178 | avg_acc_i2g = total_acc_i2g / batch_count 179 | 180 | logging.info( 181 | f"Validate Epoch: {epoch} " 182 | f"Validate Batch Size: {batch_size}" 183 | f"Average accuracy of g2i: {avg_acc_g2i:.3f} @ {batch_size}" 184 | f"Average accuracy of i2g: {avg_acc_i2g:.3f} @ {batch_size}" 185 | ) 186 | 187 | log_data = { 188 | "avg_acc_g2i": avg_acc_g2i, 189 | "avg_acc_i2g": avg_acc_i2g, 190 | } 191 | for name, val in log_data.items(): 192 | name = "validate/" + name 193 | if tb_writer is not None: 194 | tb_writer.add_scalar(name, val, epoch) 195 | 196 | print( 197 | "@batch_size %d \n average accuracy of graph-to-image is %f \n average accuracy of image-to-graph is %f \n" % ( 198 | batch_size, avg_acc_g2i, avg_acc_i2g)) 199 | 200 | 201 | def validate_acc(img_emb, graph_emb, device): 202 | img_emb = img_emb.detach().to(device) 203 | graph_emb = graph_emb.detach().to(device) 204 | with torch.no_grad(): 205 | B, D = img_emb.shape 206 | sim = torch.matmul(img_emb, graph_emb.T) 207 | 208 | pred_graph_to_img = sim - torch.max(sim, dim=0, keepdim=True).values.expand(B, B) 209 | pred_img_to_graph = sim - torch.max(sim, dim=1, keepdim=True).values.expand(B, B) 210 | 211 | pred_graph_to_img = (pred_graph_to_img >= 0).int().to(device) 212 | pred_img_to_graph = (pred_img_to_graph >= 0).int().to(device) 213 | gt_label_mask = mask_correlated_samples(B).to(device) 214 | 215 | pred_graph_to_img = pred_graph_to_img * gt_label_mask 216 | pred_img_to_graph = pred_img_to_graph * gt_label_mask 217 | 218 | correct_pred_graph_to_img = torch.sum(pred_graph_to_img).to(device) 219 | correct_pred_img_to_graph = torch.sum(pred_img_to_graph).to(device) 220 | 221 | acc_graph_to_img = correct_pred_graph_to_img * 1.0 / B 222 | acc_img_to_graph = correct_pred_img_to_graph * 1.0 / B 223 | 224 | return acc_graph_to_img, acc_img_to_graph 225 | 226 | 227 | def mask_correlated_samples(batch_size): 228 | N = batch_size 229 | mask = torch.zeros((N, N), dtype=torch.int) 230 | mask = mask.fill_diagonal_(1) 231 | return mask 232 | 233 | 234 | def unwrap_model(model): 235 | if hasattr(model, 'module'): 236 | return model.module 237 | else: 238 | return model 239 | 240 | 241 | if __name__ == '__main__': 242 | img_emb = torch.randn([64, 128]) 243 | graph_emb = torch.randn([64, 128]) 244 | img_emb = F.normalize(img_emb, dim=-1) 245 | graph_emb = F.normalize(graph_emb, dim=-1) 246 | acc_g2i, acc_i2g = validate_acc(img_emb, graph_emb) 247 | print(acc_g2i, acc_i2g) --------------------------------------------------------------------------------
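# Illustrative sketch (not part of the repository) of how the retrieval-accuracy helper
# defined in train_mim.py above can be exercised on random embeddings. Unlike the
# module's own __main__ block, it passes the `device` argument that validate_acc
# requires; the `training.train_mim` import path assumes the sg_image_pretraining
# package layout used elsewhere in this repo (e.g. in trainer_coco.py).
import torch
import torch.nn.functional as F

from training.train_mim import validate_acc

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Unit-normalized stand-ins for global image and scene-graph embeddings.
img_emb = F.normalize(torch.randn(64, 128), dim=-1).to(device)
graph_emb = F.normalize(torch.randn(64, 128), dim=-1).to(device)

acc_g2i, acc_i2g = validate_acc(img_emb, graph_emb, device)
print(acc_g2i.item(), acc_i2g.item())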