├── requirements.txt ├── cheff ├── sr │ ├── __init__.py │ ├── utils.py │ ├── sampler.py │ ├── schedule.py │ └── model.py ├── ldm │ ├── models │ │ ├── __init__.py │ │ └── diffusion │ │ │ ├── __init__.py │ │ │ └── classifier.py │ ├── modules │ │ ├── __init__.py │ │ ├── encoders │ │ │ ├── __init__.py │ │ │ └── modules.py │ │ ├── distributions │ │ │ ├── __init__.py │ │ │ └── distributions.py │ │ ├── diffusionmodules │ │ │ └── __init__.py │ │ ├── image_degradation │ │ │ ├── __init__.py │ │ │ └── utils │ │ │ │ └── test.png │ │ ├── losses │ │ │ ├── __init__.py │ │ │ ├── contperceptual.py │ │ │ └── vqperceptual.py │ │ ├── ema.py │ │ └── attention.py │ ├── __init__.py │ ├── lr_scheduler.py │ └── util.py ├── __init__.py └── machex.py ├── chexzero ├── util │ ├── __init__.py │ ├── data_utils.py │ ├── prompt_utils.py │ └── plot_grounding.py ├── components │ ├── __init__.py │ ├── pooling.py │ ├── mlp.py │ ├── bbox_prediction.py │ ├── soft_roi_pool.py │ ├── classification_losses.py │ └── bbox_losses.py ├── chexzero │ ├── bpe_simple_vocab_16e6.txt.gz │ ├── LICENSE │ ├── NOTICE.md │ ├── simple_tokenizer.py │ ├── eval.py │ ├── clip.py │ └── README.md ├── txt_encoder │ ├── __init__.py │ └── chexzero_txt_encoder.py ├── img_encoder │ ├── __init__.py │ └── chexzero_img_encoder.py └── main.py ├── .gitmodules ├── .gitignore ├── config.yaml ├── README.md ├── src ├── utils.py ├── data.py └── loss.py └── models ├── clip.py ├── classifier.py └── cheff.py /requirements.txt: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cheff/sr/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cheff/ldm/models/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chexzero/util/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cheff/ldm/modules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /chexzero/components/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cheff/ldm/models/diffusion/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cheff/ldm/modules/encoders/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cheff/ldm/modules/distributions/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cheff/ldm/modules/diffusionmodules/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /cheff/ldm/modules/image_degradation/__init__.py: 
-------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /cheff/ldm/__init__.py: -------------------------------------------------------------------------------- 1 | from cheff.ldm.util import instantiate_from_config -------------------------------------------------------------------------------- /cheff/ldm/modules/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from cheff.ldm.modules.losses.contperceptual import LPIPSWithDiscriminator -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "chexpert"] 2 | path = chexpert 3 | url = https://github.com/jfhealthcare/Chexpert.git 4 | -------------------------------------------------------------------------------- /chexzero/chexzero/bpe_simple_vocab_16e6.txt.gz: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berkegokmen1/counterfactual-chexray-disease-editing/HEAD/chexzero/chexzero/bpe_simple_vocab_16e6.txt.gz -------------------------------------------------------------------------------- /cheff/ldm/modules/image_degradation/utils/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/berkegokmen1/counterfactual-chexray-disease-editing/HEAD/cheff/ldm/modules/image_degradation/utils/test.png -------------------------------------------------------------------------------- /cheff/__init__.py: -------------------------------------------------------------------------------- 1 | from cheff.ldm.inference import ( 2 | CheffAEModel, 3 | CheffLDM, 4 | CheffLDMT2I, 5 | CheffLDMT2IEdit, 6 | CheffLDMMaskCond, 7 | CheffLDMImageMaskCond, 8 | CheffLDMImageCond, 9 | ) 10 | from cheff.sr.sampler import CheffSRModel 11 | -------------------------------------------------------------------------------- /chexzero/txt_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from typing import List 4 | 5 | from torch import BoolTensor, Tensor 6 | 7 | from util.data_utils import TensorDataclassMixin 8 | 9 | 10 | @dataclass 11 | class TextEncoderOutput(TensorDataclassMixin): 12 | # (N x S x d) -> already projected to model space 13 | sentence_features: Tensor 14 | # (N x S) 15 | sentence_mask: BoolTensor 16 | sentences: List[List[str]] 17 | flattened_sentences: List[str] 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | __pycache__ 3 | *.pyc 4 | *.pyo 5 | *.pyd 6 | .Python 7 | env/ 8 | venv/ 9 | ENV/ 10 | env.bak/ 11 | venv.bak/ 12 | .idea/ 13 | .vscode/ 14 | .DS_Store 15 | .ipynb_checkpoints 16 | data/ 17 | logs/ 18 | outputs/ 19 | runs/ 20 | *.sqlite3 21 | *.db 22 | *.log 23 | *.pot 24 | *.mo 25 | *.cover 26 | htmlcov/ 27 | .tox/ 28 | .nox/ 29 | .coverage 30 | .coverage.* 31 | .cache 32 | nosetests.xml 33 | coverage.xml 34 | *.coveragerc 35 | *.codecov.yml 36 | *.pylintrc 37 | *.mypy_cache/ 38 | .dmypy.json 39 | .pyre/ 40 | .pytype/ 41 | .pyright/ 42 | .pytest_cache/ 43 | .ipynb_checkpoints -------------------------------------------------------------------------------- /chexzero/components/pooling.py: 
-------------------------------------------------------------------------------- 1 | from typing import Optional 2 | from torch import BoolTensor, nn 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | class GlobalAvgPool(nn.Module): 8 | def __init__(self): 9 | super(GlobalAvgPool, self).__init__() 10 | 11 | def forward(self, x, mask: Optional[BoolTensor] = None): 12 | N, *dims, d = x.shape 13 | x = x.view(N, -1, d) 14 | 15 | if mask is not None: 16 | mask = mask.view(N, -1).bool() 17 | x = torch.masked_fill(x, ~mask[:, :, None], 0.) 18 | # (N x d) 19 | pooled = x.sum(1) / (mask.sum(1)[:, None] + 1e-7) 20 | else: 21 | pooled = torch.mean(x, dim=1) 22 | 23 | return pooled 24 | -------------------------------------------------------------------------------- /chexzero/img_encoder/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from dataclasses import dataclass 3 | from typing import Optional 4 | 5 | from torch import BoolTensor, Tensor 6 | 7 | from util.data_utils import TensorDataclassMixin 8 | 9 | 10 | @dataclass 11 | class ImageEncoderOutput(TensorDataclassMixin): 12 | # (N x H x W x d) -> already projected to model space 13 | patch_features: Tensor 14 | # (N x H x W x d) 15 | pos_embeddings: Tensor 16 | # (N x d) 17 | global_features: Optional[Tensor] = None 18 | 19 | @property 20 | def device(self): 21 | return self.patch_features.device 22 | 23 | @property 24 | def dtype(self): 25 | return self.patch_features.dtype 26 | 27 | @property 28 | def N(self): 29 | return self.patch_features.shape[0] 30 | 31 | @property 32 | def d(self): 33 | return self.patch_features.shape[-1] 34 | 35 | -------------------------------------------------------------------------------- /chexzero/chexzero/LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2022 Rajpurkar Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
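A short usage sketch (not from the repository) for the masked pooling module shown above in `chexzero/components/pooling.py`: `GlobalAvgPool` flattens the spatial dimensions and, when a boolean `mask` is given (True = valid position), zeroes out masked features and averages only over the valid entries. The tensor shapes below are illustrative, and the import assumes the repository root is on `PYTHONPATH`.

```python
# Illustrative only: exercising GlobalAvgPool from chexzero/components/pooling.py.
import torch
from chexzero.components.pooling import GlobalAvgPool  # assumes repo root on PYTHONPATH

pool = GlobalAvgPool()
x = torch.randn(2, 7, 7, 64)                  # (N x H x W x d) patch features
mask = torch.ones(2, 7, 7, dtype=torch.bool)  # True = keep this position
mask[:, -2:, :] = False                       # pretend the last two rows are padding

pooled = pool(x, mask=mask)                   # (N x d): mean over unmasked positions only
print(pooled.shape)                           # torch.Size([2, 64])
```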
22 | -------------------------------------------------------------------------------- /config.yaml: -------------------------------------------------------------------------------- 1 | device: cuda 2 | 3 | model: 4 | 5 | clip: 6 | path: /path/to/chexzero/chex/best_64_5e-05_original_22000_0.864.pt 7 | 8 | cheff: 9 | ae_path: /path/to/chex/chexray-diffusion-ckpts/cheff_autoencoder.pt 10 | ldm_path: /path/to/chex/chexray-diffusion-ckpts/cheff_diff_uncond.pt 11 | load_external: false 12 | external_path: /path/to/chex/chexray-diffusion-ckpts/cheff_diff_uncond.pt 13 | 14 | classifier: 15 | model: resnet152 # attnresnet152 16 | restore: true 17 | path: /path/to/chex/chexpert/exp-resnet-pretrained-256/best_checkpoints/checkpoint_9.pt 18 | resize: 256 19 | lr: 0.0001 20 | pretrained: true 21 | 22 | train: 23 | method: attn 24 | 25 | data: 26 | path: /path/to/chex/CheXpert-v1.0-small 27 | prefix: CheXpert-v1.0-small/ 28 | resize: 256 29 | mini_data: 0 30 | batch_size: 4 31 | train_batch_size: 1 32 | test_batch_size: 1 33 | num_workers: 7 34 | mode: train 35 | prefetch_factor: 4 36 | 37 | finetune: 38 | experiment_name: null 39 | target: null 40 | log_dir: ./experiments/${finetune.experiment_name} 41 | num_timesteps: 50 42 | learning_rate: 1e-5 43 | num_epochs: 3 44 | lambdas: 45 | classification_step: 1.5 46 | l1_step: 0.1 47 | anchor_step: 0.01 48 | ### 49 | clip_direction_step: 0. 50 | clip_distance_step: 0. 51 | clip_step: 0 -------------------------------------------------------------------------------- /chexzero/chexzero/NOTICE.md: -------------------------------------------------------------------------------- 1 | # Third-Party Code: CheXZero 2 | Third-party code from https://github.com/rajpurkarlab/CheXzero 3 | 4 | ## Project Licenses 5 | The source code of this repository was derived from CheXzero (https://github.com/rajpurkarlab/CheXzero) and OpenAI CLIP (https://github.com/openai/CLIP). 6 | 7 | ### Open Source License / Copyright Notice 8 | ``` 9 | MIT License 10 | 11 | Copyright (c) 2021 OpenAI 12 | 13 | Permission is hereby granted, free of charge, to any person obtaining a copy 14 | of this software and associated documentation files (the "Software"), to deal 15 | in the Software without restriction, including without limitation the rights 16 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 17 | copies of the Software, and to permit persons to whom the Software is 18 | furnished to do so, subject to the following conditions: 19 | 20 | The above copyright notice and this permission notice shall be included in all 21 | copies or substantial portions of the Software. 22 | 23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 24 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 25 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 26 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 27 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 28 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 29 | SOFTWARE. 
30 | ``` -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Counterfactual Disease Removal and Generation in Chest X-Rays Using Diffusion Models 2 | 3 | [Paper](https://berkegokmen1.github.io/counterfactual-disease-removal-and-generation-chest-xray/static/paper.pdf) | [Project Website](https://berkegokmen1.github.io/counterfactual-disease-removal-and-generation-chest-xray/) | [BibTeX](#bibtex) 4 | 5 | ## Authors 6 | [Ahmet Berke Gökmen](https://berkegokmen1.github.io/), [Ender Konukoglu](https://people.ee.ethz.ch/~kender/) 7 | 8 | ![teaser](https://github.com/user-attachments/assets/4faf0674-66e3-45e7-bb56-c2c2caeb6ab1) 9 | 10 | ## TODO 11 | - [X] Release Website 12 | - [X] Release Code 13 | - [ ] Run Instructions 14 | 15 | ## Setup 16 | 17 | ```bash 18 | conda create -n chexray-editing python=3.10 19 | pip install -r requirements.txt [TODO] 20 | ``` 21 | 22 | ## Inference 23 | Please download the CheXzero, CheXpert, and chexray-diffusion checkpoints from their respective repositories and update the paths in `config.yaml`. 24 | 25 | In addition to the checkpoints, you'll need to download the `CheXpert-v1.0-small` dataset from the official CheXpert website, or you may use any chest X-ray image. 26 | ```bash 27 | python finetune_sample.py --config config.yaml --target "Pleural Effusion" --mode "removal" --experiment_name "demo" 28 | ``` 29 | 30 | ## Questions 31 | 32 | You may reach me through [LinkedIn](https://www.linkedin.com/in/berkegokmen/). 33 | 34 | ## This work would not have been possible without: 35 | - https://github.com/rajpurkarlab/CheXzero 36 | - https://github.com/jfhealthcare/Chexpert 37 | - https://github.com/saiboxx/chexray-diffusion 38 | 39 | ## BibTeX 40 | ``` 41 | ``` 42 | 43 |
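Before running the inference command in the README above, it can help to confirm that every checkpoint path in `config.yaml` actually resolves on your machine. The snippet below is only a sketch: it assumes the config is read with OmegaConf (suggested by the `${finetune.experiment_name}` interpolation in `config.yaml` and the attribute-style access such as `args.model.clip.path` elsewhere in the code); the key names mirror the config file shown earlier.

```python
# Sketch (not part of the repository): sanity-check config.yaml before running finetune_sample.py.
# OmegaConf is an assumption; adapt if the project loads its config differently.
import os
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")

checkpoints = {
    "CheXzero CLIP": cfg.model.clip.path,
    "Cheff autoencoder": cfg.model.cheff.ae_path,
    "Cheff LDM": cfg.model.cheff.ldm_path,
    "Classifier": cfg.model.classifier.path,
}
for name, path in checkpoints.items():
    status = "ok" if os.path.isfile(path) else "MISSING"
    print(f"{name:18s} {status}: {path}")

print("CheXpert data root:", cfg.data.path)
```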
44 | -------------------------------------------------------------------------------- /chexzero/main.py: -------------------------------------------------------------------------------- 1 | import chexzero.clip 2 | import torch 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from chexzero.clip import tokenize 6 | from PIL import Image 7 | 8 | 9 | model, preprocess = chexzero.clip.load("ViT-B/32", device="cpu", jit=False) 10 | model.load_state_dict( 11 | torch.load("/cluster/work/cvl/agoekmen/chexzero/best_64_5e-05_original_22000_0.864.pt", map_location="cpu") 12 | ) 13 | 14 | 15 | def encode_text(text: str): 16 | tokenized_text = tokenize(text) 17 | text_features = model.encode_text(tokenized_text) 18 | return text_features 19 | 20 | 21 | def encode_image(image: torch.Tensor): 22 | image_features = model.encode_image(image) 23 | return image_features 24 | 25 | 26 | def calculate_scores(text_features: torch.Tensor, image_features: torch.Tensor): 27 | # normalized features 28 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 29 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 30 | 31 | # cosine similarity as logits 32 | logit_scale = model.logit_scale.exp() 33 | logits_per_image = logit_scale * image_features @ text_features.t() 34 | logits_per_text = logit_scale * text_features @ image_features.t() 35 | 36 | # shape = [global_batch_size, global_batch_size] 37 | return logits_per_image, logits_per_text 38 | 39 | 40 | if __name__ == "__main__": 41 | text = ["Pleural Effusion", "No Pleural Effusion"] 42 | image = Image.open( 43 | "/cluster/home/agoekmen/projects/chexray-diffusion/removal-editing/all_exp/working_experiments/__attn/original.png" 44 | ) 45 | image = preprocess(image).unsqueeze(0) 46 | text_features = encode_text(text) 47 | image_features = encode_image(image) 48 | 49 | print(text_features.shape) 50 | print(image_features.shape) 51 | 52 | logits_per_image, logits_per_text = calculate_scores(text_features, image_features) 53 | print(logits_per_image) 54 | print(logits_per_text) 55 | 56 | probs = logits_per_image.softmax(dim=1) 57 | print(probs) 58 | -------------------------------------------------------------------------------- /chexzero/components/mlp.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from util.model_utils import get_activation 7 | 8 | class MLP(nn.Module): 9 | def __init__( 10 | self, 11 | n_layers: int, 12 | d_in: int, d_out: Optional[int] = None, 13 | use_bn=False, dropout=0.0, dropout_last_layer=True, act=nn.ReLU, d_hidden_factor: int = 4, d_hidden: Optional[int] = None): 14 | """ 15 | :param num_layers: If num_hidden_layers == 0, then only use identity, for num_hidden_layers == 1, then only use linear 16 | :param d_in: 17 | :param d_hidden: 18 | :param d_out: 19 | :param use_bn 20 | """ 21 | super(MLP, self).__init__() 22 | if act is None: 23 | act = nn.ReLU 24 | act = get_activation(act) 25 | 26 | if d_out is None: 27 | d_out = d_in 28 | if d_hidden is None: 29 | d_hidden = d_hidden_factor * d_out 30 | assert n_layers >= 0 31 | if n_layers == 0: 32 | assert d_in == d_out, f'If n_layers == 0, then d_in == d_out, but got {d_in} != {d_out}' 33 | self.layers = nn.Identity() 34 | else: 35 | current_dim_in = d_in 36 | layers = [] 37 | for _ in range(n_layers - 1): 38 | layers.append(nn.Linear(current_dim_in, d_hidden, bias=not use_bn)) 39 
| if use_bn: 40 | layers.append(nn.BatchNorm1d(d_hidden)) 41 | layers.append(act) 42 | if dropout > 0.0: 43 | layers.append(nn.Dropout(dropout)) 44 | current_dim_in = d_hidden 45 | layers.append(nn.Linear(current_dim_in, d_out)) 46 | if dropout_last_layer and dropout > 0.0: 47 | layers.append(nn.Dropout(dropout)) 48 | self.layers = nn.Sequential(*layers) 49 | 50 | def forward(self, x: torch.Tensor) -> torch.Tensor: 51 | *dims, d = x.shape 52 | if len(dims) > 1: 53 | x = x.reshape(-1, d) 54 | 55 | x = self.layers(x) 56 | return x.view(*dims, -1) 57 | -------------------------------------------------------------------------------- /cheff/sr/utils.py: -------------------------------------------------------------------------------- 1 | """Miscellaneous classes and functions.""" 2 | from typing import Any, Optional 3 | 4 | import numpy as np 5 | from matplotlib import animation 6 | from matplotlib import pyplot as plt 7 | from torch import Tensor 8 | from torchvision.transforms import ( 9 | Compose, 10 | Lambda, 11 | ToPILImage, 12 | ) 13 | from torchvision.utils import make_grid 14 | 15 | 16 | def transform_tensor_to_img() -> Compose: 17 | """Transform a tensor with a single element to a PIL image.""" 18 | return Compose( 19 | [ 20 | Lambda(lambda t: t.detach().cpu()), 21 | Lambda(lambda t: (t + 1) / 2), 22 | Lambda(lambda t: t.permute(1, 2, 0)), 23 | Lambda(lambda t: t * 255.0), 24 | Lambda(lambda t: t.numpy().astype(np.uint8)), 25 | ToPILImage(), 26 | ] 27 | ) 28 | 29 | 30 | def plot_image( 31 | img: Tensor, 32 | fig_size: Any = None, 33 | ncols: Optional[int] = None, 34 | show: bool = True, 35 | save_path: Optional[str] = None, 36 | ) -> None: 37 | """Plot a tensor containing image data.""" 38 | img = img.detach().cpu() 39 | 40 | # Shape of 4 implies multiple image inputs 41 | if len(img.shape) == 4: 42 | img = make_grid(img, nrow=ncols if ncols is not None else len(img)) 43 | 44 | plt.figure(figsize=fig_size) 45 | plt.imshow(img.permute(1, 2, 0)) 46 | plt.axis('off') 47 | 48 | if save_path is not None: 49 | plt.savefig(save_path, bbox_inches='tight') 50 | 51 | if show: 52 | plt.show() 53 | plt.close() 54 | 55 | 56 | def make_gif( 57 | img_arr: Tensor, 58 | save_path: str, 59 | ) -> None: 60 | """Create a GIF with the output of DiffusionController.generate().""" 61 | assert len(img_arr) == 5, 'Array has wrong shape.' 
62 | 63 | img_arr = img_arr.detach().cpu() 64 | 65 | fig = plt.figure(frameon=False) 66 | ims = [] 67 | 68 | for img_t in img_arr: 69 | grid = make_grid(img_t, nrow=img_t.shape[0] // 2) 70 | im = plt.imshow(grid.permute(1, 2, 0), animated=True) 71 | plt.axis('off') 72 | plt.tight_layout() 73 | ims.append([im]) 74 | 75 | fig.tight_layout() 76 | 77 | animate = animation.ArtistAnimation( 78 | fig, ims, interval=100, blit=True, repeat_delay=2000 79 | ) 80 | animate.save(save_path) 81 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import gc 4 | import torch 5 | import numpy as np 6 | import random 7 | import torchvision.transforms as T 8 | import torch.nn.functional as F 9 | import yaml 10 | import shutil 11 | 12 | 13 | def set_seed(seed): 14 | torch.manual_seed(seed) 15 | torch.cuda.manual_seed_all(seed) 16 | np.random.seed(seed) 17 | random.seed(seed) 18 | torch.backends.cudnn.deterministic = True 19 | torch.backends.cudnn.benchmark = False 20 | 21 | 22 | def set_device(args): 23 | args.device = torch.device("cuda" if torch.cuda.is_available() and args.device == "cuda" else "cpu") 24 | print(f"Using device: {args.device}") 25 | 26 | 27 | def clear_mem(verbose=True): 28 | res = gc.collect(), torch.cuda.empty_cache() 29 | if verbose: 30 | print(f"Cleared memory: {res}") 31 | return res 32 | 33 | 34 | def create_experiment_folder(args): 35 | path = os.path.join(args.experiment.path, args.experiment.prefix + time.ctime().replace(" ", "_").replace(":", "-")) 36 | os.makedirs(path, exist_ok=True) 37 | args.out_dir = path 38 | print(f"Created experiment folder: {path}") 39 | return path 40 | 41 | 42 | def create_denormalize(): 43 | mean = [0.5330] # CheXpert mean 44 | std = [0.0349] # CheXpert std 45 | denorm = T.Normalize(mean=[-m / s for m, s in zip(mean, std)], std=[1 / s for s in std]) 46 | 47 | return denorm 48 | 49 | 50 | def create_normalize(): 51 | mean = [0.5330] # CheXpert mean 52 | std = [0.0349] # CheXpert std 53 | t_norm = T.Normalize(mean=mean, std=std) 54 | return t_norm 55 | 56 | 57 | def clip_normalize(): 58 | t = T.Compose([T.Resize((512, 512)), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) 59 | 60 | return t 61 | 62 | 63 | def create_blur(kernel_size=5, sigma=(2.0, 2.0)): 64 | return T.GaussianBlur(kernel_size, sigma) 65 | 66 | 67 | def create_mask_dilate(): 68 | blur = create_blur() 69 | 70 | dilate = lambda mask, kernel_size: blur( 71 | F.max_pool2d(mask, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2) 72 | ).view(mask.shape) 73 | 74 | return dilate 75 | 76 | 77 | def save_config_yaml(args): 78 | shutil.copyfile(args.config_path, os.path.join(args.out_dir, "config.yaml")) 79 | 80 | 81 | def save_txt( 82 | path, 83 | name, 84 | text, 85 | ): 86 | if not isinstance(text, list): 87 | text = [text] 88 | 89 | with open(os.path.join(path, f"{name}.txt"), "w") as f: 90 | f.write("\n".join(text)) 91 | -------------------------------------------------------------------------------- /chexzero/components/bbox_prediction.py: -------------------------------------------------------------------------------- 1 | from typing import List 2 | import einops 3 | import torch 4 | from torchvision.ops import batched_nms, box_convert 5 | from ensemble_boxes import weighted_boxes_fusion 6 | 7 | 8 | def clip_bboxes(box_params): 9 | box_params = box_convert(box_params, 'cxcywh', 'xyxy') 10 | 
box_params = box_params.clamp(0., 1.) 11 | box_params = box_convert(box_params, 'xyxy', 'cxcywh') 12 | return box_params 13 | 14 | 15 | def apply_top1_filtering(boxes: List[torch.Tensor]) -> List[torch.Tensor]: 16 | filtered_boxes = [] 17 | for sample_boxes in boxes: 18 | labels = sample_boxes[:, 4] 19 | scores = sample_boxes[:, 5] 20 | unique_classes = torch.unique(labels) 21 | keep_inds = torch.stack([ 22 | (torch.where(labels == cls, 1, 0) * scores).argmax() 23 | for cls in unique_classes 24 | ]) if len(unique_classes) > 0 else torch.zeros(0, dtype=torch.long) 25 | filtered_boxes.append(sample_boxes[keep_inds]) 26 | return filtered_boxes 27 | 28 | 29 | def apply_top1_with_box_fusion(boxes: List[torch.Tensor]) -> List[torch.Tensor]: 30 | filtered_boxes = [] 31 | for sample_boxes in boxes: 32 | boxes = sample_boxes[:, :4] 33 | boxes_upper_left = boxes[:, :2] - boxes[:, 2:] / 2 34 | boxes_lower_right = boxes[:, :2] + boxes[:, 2:] / 2 35 | labels = sample_boxes[:, 4] 36 | scores = sample_boxes[:, 5] 37 | unique_classes = torch.unique(labels) 38 | filtered_sample_boxes = [] 39 | for c in unique_classes: 40 | cls_scores = scores[labels == c] 41 | if len(cls_scores) == 0: 42 | continue 43 | cls_boxes_upper_left = boxes_upper_left[labels == c] 44 | cls_boxes_lower_right = boxes_lower_right[labels == c] 45 | cls_fused_boxes = torch.cat([cls_boxes_upper_left.amin(dim=0), cls_boxes_lower_right.amax(dim=0)], dim=-1) 46 | wh = cls_fused_boxes[2:] - cls_fused_boxes[:2] 47 | cls_fused_boxes = torch.cat([cls_fused_boxes[:2] + wh / 2, wh], dim=-1) 48 | cls_top1_scores = cls_scores.amax() 49 | filtered_sample_boxes.append(torch.cat([cls_fused_boxes, c.unsqueeze(-1), cls_top1_scores.unsqueeze(-1)], dim=-1)) 50 | filtered_boxes.append(torch.stack(filtered_sample_boxes) if len(filtered_sample_boxes) > 0 else torch.zeros(0, 6)) 51 | return filtered_boxes 52 | 53 | 54 | def apply_nms(predicted_boxes: List[torch.Tensor], iou_threshold: float): 55 | predicted_boxes_after_nms = [] 56 | for sample_boxes in predicted_boxes: 57 | boxes_coords = box_convert(sample_boxes[:, 0:4], in_fmt='cxcywh', out_fmt='xyxy') 58 | cls_idxs = sample_boxes[:, 4] 59 | scores = sample_boxes[:, 5] 60 | nms_indices = batched_nms(boxes_coords, scores, cls_idxs, iou_threshold=iou_threshold) 61 | predicted_boxes_after_nms.append(sample_boxes[nms_indices, :]) 62 | return predicted_boxes_after_nms 63 | 64 | -------------------------------------------------------------------------------- /cheff/ldm/modules/ema.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class LitEma(nn.Module): 6 | def __init__(self, model, decay=0.9999, use_num_upates=True): 7 | super().__init__() 8 | if decay < 0.0 or decay > 1.0: 9 | raise ValueError('Decay must be between 0 and 1') 10 | 11 | self.m_name2s_name = {} 12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32)) 13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates 14 | else torch.tensor(-1,dtype=torch.int)) 15 | 16 | for name, p in model.named_parameters(): 17 | if p.requires_grad: 18 | #remove as '.'-character is not allowed in buffers 19 | s_name = name.replace('.','') 20 | self.m_name2s_name.update({name:s_name}) 21 | self.register_buffer(s_name,p.clone().detach().data) 22 | 23 | self.collected_params = [] 24 | 25 | def forward(self,model): 26 | decay = self.decay 27 | 28 | if self.num_updates >= 0: 29 | self.num_updates += 1 30 | decay = min(self.decay,(1 + 
self.num_updates) / (10 + self.num_updates)) 31 | 32 | one_minus_decay = 1.0 - decay 33 | 34 | with torch.no_grad(): 35 | m_param = dict(model.named_parameters()) 36 | shadow_params = dict(self.named_buffers()) 37 | 38 | for key in m_param: 39 | if m_param[key].requires_grad: 40 | sname = self.m_name2s_name[key] 41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key]) 42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key])) 43 | else: 44 | assert not key in self.m_name2s_name 45 | 46 | def copy_to(self, model): 47 | m_param = dict(model.named_parameters()) 48 | shadow_params = dict(self.named_buffers()) 49 | for key in m_param: 50 | if m_param[key].requires_grad: 51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data) 52 | else: 53 | assert not key in self.m_name2s_name 54 | 55 | def store(self, parameters): 56 | """ 57 | Save the current parameters for restoring later. 58 | Args: 59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 60 | temporarily stored. 61 | """ 62 | self.collected_params = [param.clone() for param in parameters] 63 | 64 | def restore(self, parameters): 65 | """ 66 | Restore the parameters stored with the `store` method. 67 | Useful to validate the model with EMA parameters without affecting the 68 | original optimization process. Store the parameters before the 69 | `copy_to` method. After validation (or model saving), use this to 70 | restore the former parameters. 71 | Args: 72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be 73 | updated with the stored parameters. 74 | """ 75 | for c_param, param in zip(self.collected_params, parameters): 76 | param.data.copy_(c_param.data) 77 | -------------------------------------------------------------------------------- /cheff/ldm/modules/distributions/distributions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | 5 | class AbstractDistribution: 6 | def sample(self): 7 | raise NotImplementedError() 8 | 9 | def mode(self): 10 | raise NotImplementedError() 11 | 12 | 13 | class DiracDistribution(AbstractDistribution): 14 | def __init__(self, value): 15 | self.value = value 16 | 17 | def sample(self): 18 | return self.value 19 | 20 | def mode(self): 21 | return self.value 22 | 23 | 24 | class DiagonalGaussianDistribution(object): 25 | def __init__(self, parameters, deterministic=False): 26 | self.parameters = parameters 27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) 28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0) 29 | self.deterministic = deterministic 30 | self.std = torch.exp(0.5 * self.logvar) 31 | self.var = torch.exp(self.logvar) 32 | if self.deterministic: 33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) 34 | 35 | def sample(self): 36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) 37 | return x 38 | 39 | def kl(self, other=None): 40 | if self.deterministic: 41 | return torch.Tensor([0.]) 42 | else: 43 | if other is None: 44 | return 0.5 * torch.sum(torch.pow(self.mean, 2) 45 | + self.var - 1.0 - self.logvar, 46 | dim=[1, 2, 3]) 47 | else: 48 | return 0.5 * torch.sum( 49 | torch.pow(self.mean - other.mean, 2) / other.var 50 | + self.var / other.var - 1.0 - self.logvar + other.logvar, 51 | dim=[1, 2, 3]) 52 | 53 | def nll(self, sample, dims=[1,2,3]): 54 | if self.deterministic: 55 | return torch.Tensor([0.]) 56 | logtwopi = np.log(2.0 
* np.pi) 57 | return 0.5 * torch.sum( 58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, 59 | dim=dims) 60 | 61 | def mode(self): 62 | return self.mean 63 | 64 | 65 | def normal_kl(mean1, logvar1, mean2, logvar2): 66 | """ 67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 68 | Compute the KL divergence between two gaussians. 69 | Shapes are automatically broadcasted, so batches can be compared to 70 | scalars, among other use cases. 71 | """ 72 | tensor = None 73 | for obj in (mean1, logvar1, mean2, logvar2): 74 | if isinstance(obj, torch.Tensor): 75 | tensor = obj 76 | break 77 | assert tensor is not None, "at least one argument must be a Tensor" 78 | 79 | # Force variances to be Tensors. Broadcasting helps convert scalars to 80 | # Tensors, but it does not work for torch.exp(). 81 | logvar1, logvar2 = [ 82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) 83 | for x in (logvar1, logvar2) 84 | ] 85 | 86 | return 0.5 * ( 87 | -1.0 88 | + logvar2 89 | - logvar1 90 | + torch.exp(logvar1 - logvar2) 91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) 92 | ) 93 | -------------------------------------------------------------------------------- /chexzero/util/data_utils.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | from typing import Any, Callable, Mapping, Sequence 3 | from PIL import Image 4 | import torch 5 | 6 | def load_pil_gray(path: str) -> Image.Image: 7 | return Image.open(path).convert('L') 8 | 9 | 10 | class TensorDataclassMixin: 11 | def __init__(self): 12 | super(TensorDataclassMixin, self).__init__() 13 | assert dataclasses.is_dataclass(self), f'{type(self)} has to be a dataclass to use TensorDataclassMixin' 14 | 15 | def apply(self, tensor_fn: Callable[[torch.Tensor], torch.Tensor], ignore=None, apply_to_list=False): 16 | if ignore is None and hasattr(self, 'IGNORE_APPLY'): 17 | ignore = self.IGNORE_APPLY 18 | def apply_to_value(value): 19 | if value is None: 20 | return None 21 | elif isinstance(value, torch.Tensor): 22 | return tensor_fn(value) 23 | elif isinstance(value, list): 24 | if apply_to_list: 25 | return tensor_fn(value) 26 | else: 27 | return [apply_to_value(el) for el in value] 28 | elif isinstance(value, tuple): 29 | return tuple(apply_to_value(el) for el in value) 30 | elif isinstance(value, dict): 31 | return {key: apply_to_value(el) for key, el in value.items()} 32 | elif isinstance(value, TensorDataclassMixin): 33 | return value.apply(tensor_fn) 34 | else: 35 | return value 36 | 37 | def apply_to_field(field: dataclasses.Field): 38 | value = getattr(self, field.name) 39 | if ignore is not None and field.name in ignore: 40 | return value 41 | else: 42 | try: 43 | return apply_to_value(value) 44 | except Exception as e: 45 | raise RuntimeError(f'Error while applying {tensor_fn} to {field.name} ({value})') from e 46 | 47 | return self.__class__(**{field.name: apply_to_field(field) for field in dataclasses.fields(self)}) 48 | 49 | def to(self, device, *args, non_blocking=True, **kwargs): 50 | return self.apply(lambda x: x.to(device, *args, non_blocking=non_blocking, **kwargs)) 51 | 52 | def view(self, *args): 53 | return self.apply(lambda x: x.view(*args)) 54 | 55 | def detach(self): 56 | return self.apply(lambda x: x.detach()) 57 | 58 | def unsqueeze(self, dim): 59 | return self.apply(lambda x: x.unsqueeze(dim)) 60 | 61 | def squeeze(self, dim): 62 | return self.apply(lambda x: 
x.squeeze(dim)) 63 | 64 | def __getitem__(self, *args): 65 | return self.apply(lambda x: x.__getitem__(*args), apply_to_list=True) 66 | 67 | def to_dict(self): 68 | return dataclasses.asdict(self) 69 | 70 | 71 | def to_device(data: Any, device: str, non_blocking=True): 72 | if data is None: 73 | return None 74 | if isinstance(data, torch.Tensor): 75 | if device == 'cpu': 76 | non_blocking = False 77 | return data.to(device, non_blocking=non_blocking) 78 | elif isinstance(data, Mapping): 79 | return {key: to_device(data[key], device, non_blocking=non_blocking) for key in data} 80 | elif isinstance(data, Sequence) and not isinstance(data, str): 81 | return [to_device(d, device, non_blocking=non_blocking) for d in data] 82 | elif isinstance(data, str): 83 | return data 84 | elif isinstance(data, TensorDataclassMixin): 85 | if device == 'cpu': 86 | non_blocking = False 87 | return data.to(device=device, non_blocking=non_blocking) 88 | else: 89 | raise TypeError(type(data)) 90 | -------------------------------------------------------------------------------- /cheff/sr/sampler.py: -------------------------------------------------------------------------------- 1 | import os 2 | from typing import Union, Tuple 3 | 4 | from PIL import Image 5 | import torch 6 | from torch import Tensor 7 | from torch.utils.data import Dataset, DataLoader 8 | from torchvision.transforms import InterpolationMode 9 | from torchvision.transforms.functional import resize, to_tensor, to_grayscale 10 | from torchvision.utils import save_image 11 | 12 | from cheff.sr.model import Unet 13 | from cheff.sr.schedule import ScheduleFactory 14 | from cheff.sr.diffusor import SR3Diffusor, SR3DDIMDiffusor 15 | 16 | 17 | class CheffSRModel: 18 | def __init__( 19 | self, 20 | model_path: str, 21 | device: Union[str, int, torch.device] = 'cuda' 22 | ) -> None: 23 | self.device = device 24 | self.model = Unet( 25 | dim=16, 26 | channels=2, 27 | out_dim=1, 28 | dim_mults=(1, 2, 4, 8, 16, 32, 32, 32), 29 | ) 30 | 31 | state_dict = torch.load(model_path, map_location='cpu') 32 | self.model.load_state_dict(state_dict['model']) 33 | self.model.to(self.device) 34 | self.model.eval() 35 | 36 | self.schedule = ScheduleFactory.get_schedule( 37 | name='cosine', timesteps=2000, device=self.device) 38 | 39 | def sample_directory( 40 | self, 41 | source_dir: str, 42 | target_dir: str, 43 | batch_size: int = 1, 44 | method: str = 'ddim', 45 | sampling_steps: int = 100, 46 | eta: float = 0. 47 | ) -> None: 48 | ds = DirectoryDataset(source_dir) 49 | loader = DataLoader(ds, batch_size=batch_size, pin_memory=True) 50 | 51 | os.makedirs(target_dir, exist_ok=True) 52 | 53 | for f_names, imgs in loader: 54 | imgs_sr = self.sample(imgs, method, sampling_steps, eta) 55 | 56 | for f_name, img_sr in zip(f_names, imgs_sr): 57 | path = os.path.join(target_dir, f_name) 58 | save_image(img_sr, path) 59 | 60 | def sample_path( 61 | self, 62 | path: str, 63 | method: str = 'ddim', 64 | sampling_steps: int = 100, 65 | eta: float = 0. 66 | ) -> Tensor: 67 | img = Image.open(path) 68 | img = to_tensor(to_grayscale(img)).unsqueeze(0) 69 | return self.sample(img, method, sampling_steps, eta) 70 | 71 | @torch.no_grad() 72 | def sample( 73 | self, 74 | img: Tensor, 75 | method: str = 'ddim', 76 | sampling_steps: int = 100, 77 | eta: float = 0. 
78 | ) -> Tensor: 79 | img = img.to(self.device) 80 | img = img * 2 - 1 81 | img = resize(img, [1024, 1024], InterpolationMode.BICUBIC) 82 | 83 | if method == 'ddim': 84 | diffusor = SR3DDIMDiffusor( 85 | model=self.model, 86 | schedule=self.schedule, 87 | sampling_steps=sampling_steps, 88 | eta=eta 89 | ) 90 | else: 91 | diffusor = SR3Diffusor( 92 | model=self.model, 93 | schedule=self.schedule, 94 | ) 95 | 96 | img_sr = diffusor.p_sample_loop(sr=img) 97 | img_sr.clamp_(-1, 1) 98 | img_sr = (img_sr + 1) / 2 99 | return img_sr 100 | 101 | 102 | class DirectoryDataset(Dataset): 103 | def __init__(self, root: str) -> None: 104 | self.root = root 105 | self.files = os.listdir(root) 106 | 107 | def __len__(self): 108 | return len(self.files) 109 | 110 | def __getitem__(self, idx: int) -> Tuple[str, Tensor]: 111 | fp = os.path.join(self.root, self.files[idx]) 112 | 113 | img = Image.open(fp) 114 | img = to_tensor(to_grayscale(img)) 115 | return self.files[idx], img 116 | -------------------------------------------------------------------------------- /chexzero/util/prompt_utils.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | from itertools import chain 5 | from typing import List 6 | 7 | import torch 8 | 9 | 10 | def flatten_prompts(prompts: List[List[str]], device): 11 | flattened_prompts: List[str] = [prompt for sub_prompts in prompts if sub_prompts is not None for prompt in sub_prompts] 12 | N = len(prompts) 13 | M_is: List[int] = [len(sub_prompts) if sub_prompts is not None else 0 for sub_prompts in prompts] 14 | # (N x max(M_i)) 15 | prompt_mask = torch.zeros(N, max(M_is) if len(M_is) > 0 else 0, dtype=torch.bool, device=device) 16 | for i, M_i in enumerate(M_is): 17 | prompt_mask[i, :M_i] = True 18 | return flattened_prompts, prompt_mask 19 | 20 | def flatten_prompts_2(prompts: List[List[List[str]]], device): 21 | flattened_prompts: List[str] = [prompt for sub_prompts in prompts for sub_sub_prompts in sub_prompts for prompt in sub_sub_prompts] 22 | N = len(prompts) 23 | M_is: List[int] = [len(sub_prompts) for sub_prompts in prompts] 24 | L_is_ms: List[List[int]] = [[len(sub_sub_prompts) for sub_sub_prompts in sub_prompts] 25 | for sub_prompts in prompts] 26 | max_L = max(max(L_i_m) for L_i_m in L_is_ms) 27 | prompt_mask = torch.zeros(N, max(M_is), max_L, dtype=torch.bool, device=device) 28 | for i, L_i_ms in enumerate(L_is_ms): 29 | for m, L_i_m in enumerate(L_i_ms): 30 | prompt_mask[i, m, :L_i_m] = True 31 | 32 | return flattened_prompts, prompt_mask 33 | 34 | 35 | def apply_placeholder(prompt: str, placeholder: str, replacements: List[str]) -> List[str]: 36 | placeholder = '{' + placeholder + '}' 37 | if placeholder not in prompt: 38 | return [prompt] 39 | else: 40 | return [prompt.replace(placeholder, replacement) for replacement in replacements] 41 | 42 | 43 | def fill_prompt_templates(prompts: List[str]) -> List[str]: 44 | for placeholder, replacements in template_placeholders.items(): 45 | prompts = list(chain(*[apply_placeholder(prompt, placeholder, replacements) for prompt in prompts])) 46 | assert not any('{' in prompt for prompt in prompts), [prompt for prompt in prompts if '{' in prompt] 47 | return prompts 48 | 49 | def localized_prompt_templates(prompts: List[str], region_templates: List[str]) -> List[List[str]]: 50 | prompts = [ 51 | [reg_template.format(prompt) for prompt in prompts] 52 | for reg_template in region_templates 53 | ] 54 | assert not any('{' in prompt for templ in prompts for prompt in templ) 55 | return 
prompts 56 | 57 | 58 | template_placeholders = { 59 | 'normal_adj': ['normal', 'unremarkable', 'clear'], 60 | 'shape_noun': ['size', 'silhouette', 'area', 'contours'], 61 | 'shape_adj': ['round'], 62 | 'state_verb': ['appears', 'is', 'are', 'remains', 'remain', 'appear', 'exists', 'exist'], 63 | 'indication_noun': ['signs', 'evidence', 'case of', 'presence', 'findings', 'suspicious findings'], 64 | 'indication_adj': ['noticeable', 'visible', 'seen', 'appearent', 'observable'], 65 | 'indication_verb': ['indicates', 'suggests', 'suggesting', 'indicating', 'consistent with'], 66 | 'passive_indication_verb': ['can be identified', 'can be seen', 'is present', 'is noted'], 67 | 'unchanged': ['has not improved', 'unchanged', 'remains'], 68 | 'limits_noun': ['limits'], 69 | 'moderate_adj': ['mild', 'moderate', 'extensive', 'small', 'slight', 'stable', 'intact', 'mild-moderate',], 70 | 'strong_adj': [ 71 | 'large', 'significant', 'acute', 'widespread', 'relevant', 'difficult', 'apparent', 'prominent', 72 | 'convincing', 'extensive', 'severe', 'critical', 'altered', 'patchy', 'degenerative', 'substantial', 73 | 'predominant', 'massive', 'noticeable'], 74 | 'increased_adj': ['elevated', 'enlarged', 'increased', 'larger', 'large', 'widened'], 75 | 'size_noun': ['enlargement'], 76 | 'visible_adj': ['visible', 'seen', 'appearent'], 77 | 'relation_adj': ['regarding',' relating to', 'concerning', 'involving'], 78 | 'support_dev_noun': ['catheter', 'tubes', 'support device', 'monitoring device', 'wires', 'pacemaker'], 79 | 'change': ['little change', 'unchanged',], 80 | 'lung_adj': ['pulmonary','pul', 'lung', 'airspace'] 81 | } -------------------------------------------------------------------------------- /cheff/ldm/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class LambdaWarmUpCosineScheduler: 5 | """ 6 | note: use with a base_lr of 1.0 7 | """ 8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0): 9 | self.lr_warm_up_steps = warm_up_steps 10 | self.lr_start = lr_start 11 | self.lr_min = lr_min 12 | self.lr_max = lr_max 13 | self.lr_max_decay_steps = max_decay_steps 14 | self.last_lr = 0. 15 | self.verbosity_interval = verbosity_interval 16 | 17 | def schedule(self, n, **kwargs): 18 | if self.verbosity_interval > 0: 19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}") 20 | if n < self.lr_warm_up_steps: 21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start 22 | self.last_lr = lr 23 | return lr 24 | else: 25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps) 26 | t = min(t, 1.0) 27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * ( 28 | 1 + np.cos(t * np.pi)) 29 | self.last_lr = lr 30 | return lr 31 | 32 | def __call__(self, n, **kwargs): 33 | return self.schedule(n,**kwargs) 34 | 35 | 36 | class LambdaWarmUpCosineScheduler2: 37 | """ 38 | supports repeated iterations, configurable via lists 39 | note: use with a base_lr of 1.0. 
40 | """ 41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0): 42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths) 43 | self.lr_warm_up_steps = warm_up_steps 44 | self.f_start = f_start 45 | self.f_min = f_min 46 | self.f_max = f_max 47 | self.cycle_lengths = cycle_lengths 48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths)) 49 | self.last_f = 0. 50 | self.verbosity_interval = verbosity_interval 51 | 52 | def find_in_interval(self, n): 53 | interval = 0 54 | for cl in self.cum_cycles[1:]: 55 | if n <= cl: 56 | return interval 57 | interval += 1 58 | 59 | def schedule(self, n, **kwargs): 60 | cycle = self.find_in_interval(n) 61 | n = n - self.cum_cycles[cycle] 62 | if self.verbosity_interval > 0: 63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 64 | f"current cycle {cycle}") 65 | if n < self.lr_warm_up_steps[cycle]: 66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 67 | self.last_f = f 68 | return f 69 | else: 70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle]) 71 | t = min(t, 1.0) 72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * ( 73 | 1 + np.cos(t * np.pi)) 74 | self.last_f = f 75 | return f 76 | 77 | def __call__(self, n, **kwargs): 78 | return self.schedule(n, **kwargs) 79 | 80 | 81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2): 82 | 83 | def schedule(self, n, **kwargs): 84 | cycle = self.find_in_interval(n) 85 | n = n - self.cum_cycles[cycle] 86 | if self.verbosity_interval > 0: 87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, " 88 | f"current cycle {cycle}") 89 | 90 | if n < self.lr_warm_up_steps[cycle]: 91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle] 92 | self.last_f = f 93 | return f 94 | else: 95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle]) 96 | self.last_f = f 97 | return f 98 | 99 | -------------------------------------------------------------------------------- /models/clip.py: -------------------------------------------------------------------------------- 1 | import chexzero.chexzero.clip 2 | from chexzero.chexzero.clip import tokenize 3 | import torch 4 | from torch.nn import functional as F 5 | from torchvision import transforms 6 | 7 | 8 | class CLIPEvaluator: 9 | """ 10 | https://huggingface.co/StanfordAIMI/XrayCLIP__vit-b-16__laion2b-s34b-b88k?library=transformers 11 | """ 12 | 13 | def __init__(self, args, model=None, preprocess=None): 14 | self.args = args 15 | self.model_path = args.model.clip.path 16 | self.device = args.device 17 | 18 | self.model = model 19 | self.preprocess = preprocess 20 | 21 | if self.model is None and self.preprocess is None: 22 | print("Loading CLIP model for CLIPEvaluator...") 23 | self.model, self.preprocess = chexzero.clip.load("ViT-B/32", device="cpu", jit=False) 24 | self.model.load_state_dict(torch.load(self.model_path, map_location=self.device)) 25 | 26 | self.model = self.model.to(self.device) 27 | self.model.eval() 28 | 29 | self.input_size = self.model.visual.input_resolution 30 | 31 | def clip_transform_for_tensor(self, tensor: torch.Tensor, target_size=None): 32 | if target_size is None: 33 | target_size = self.input_size 34 
| 35 | resize = transforms.Resize((target_size, target_size), interpolation=transforms.InterpolationMode.BICUBIC) 36 | 37 | resized = resize(tensor) 38 | 39 | # mean = [0.48145466, 0.4578275, 0.40821073] 40 | # std = [0.26862954, 0.26130258, 0.27577711] 41 | 42 | normalize = transforms.Normalize((101.48761, 101.48761, 101.48761), (83.43944, 83.43944, 83.43944)) 43 | normalize_2 = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) 44 | 45 | # return resized 46 | return normalize_2(resized) 47 | 48 | def score(self, images: torch.Tensor, texts=None, preprocess_images=True): 49 | assert texts is not None and texts != [], "Texts must be provided for scoring." 50 | 51 | tokenized_text = tokenize(texts).to(self.device) 52 | text_features = self.model.encode_text(tokenized_text) 53 | 54 | if preprocess_images: 55 | images = self.clip_transform_for_tensor(images) 56 | 57 | image_features = self.model.encode_image(images) 58 | 59 | # normalized features 60 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 61 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 62 | 63 | # cosine similarity as logits 64 | logit_scale = self.model.logit_scale.exp() 65 | logits_per_image = logit_scale * image_features @ text_features.t() 66 | logits_per_text = logit_scale * text_features @ image_features.t() 67 | 68 | # shape = [global_batch_size, global_batch_size] 69 | return logits_per_image, logits_per_text 70 | 71 | def encode_text(self, texts): 72 | if isinstance(texts, str): 73 | texts = [texts] 74 | tokenized_text = tokenize(texts).to(self.device) 75 | return self.model.encode_text(tokenized_text) 76 | 77 | def score_single(self, image: torch.Tensor, text: str, preprocess_image=True, text_features=None): 78 | if text_features is None: 79 | text_features = self.encode_text(text) 80 | 81 | # Preprocess the image if necessary 82 | if preprocess_image: 83 | # image already has batch dimension 84 | image = self.clip_transform_for_tensor(image) # add batch dimension 85 | 86 | # Extract image features 87 | image_features = self.model.encode_image(image) 88 | 89 | # Normalize the features 90 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 91 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 92 | 93 | # Compute cosine similarity between the single image and text 94 | similarity = (image_features @ text_features.T).squeeze() 95 | 96 | # Rescale similarity from [-1, 1] to [0, 1] 97 | score = (similarity + 1) / 2 98 | 99 | return score.item() 100 | -------------------------------------------------------------------------------- /cheff/machex.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from typing import Dict, Optional 4 | 5 | import torch 6 | from PIL import Image 7 | from torch.utils.data import Dataset, ConcatDataset 8 | from torchvision.transforms import Compose, ToTensor 9 | 10 | 11 | class ChestXrayDataset(Dataset): 12 | """Class for handling datasets in the MaCheX composition.""" 13 | 14 | def __init__(self, root: str, transforms: Optional[Compose] = None) -> None: 15 | """Initialize ChestXrayDataset.""" 16 | self.root = root 17 | json_path = os.path.join(self.root, 'index.json') 18 | self.index_dict = ChestXrayDataset._load_json(json_path) 19 | 20 | self.keys = list(self.index_dict.keys()) 21 | 22 | if transforms is None: 23 | self.transforms = ToTensor() 24 | else: 25 | self.transforms = 
transforms 26 | 27 | @staticmethod 28 | def _load_json(file_path: str) -> Dict: 29 | """Load a json file as dictionary.""" 30 | with open(file_path, 'r') as f: 31 | return json.load(f) 32 | 33 | def __len__(self): 34 | """Return length of the dataset.""" 35 | return len(self.keys) 36 | 37 | def __getitem__(self, idx: int) -> Dict: 38 | """Get dataset element.""" 39 | meta = self.index_dict[self.keys[idx]] 40 | 41 | img = Image.open(meta['path']) 42 | img = self.transforms(img) 43 | 44 | return {'img': img} 45 | 46 | 47 | class MaCheXDataset(Dataset): 48 | """Massive chest X-ray dataset.""" 49 | 50 | def __init__(self, root: str, transforms: Optional[Compose] = None) -> None: 51 | """Initialize MaCheXDataset""" 52 | self.root = root 53 | sub_dataset_roots = os.listdir(self.root) 54 | datasets = [ 55 | ChestXrayDataset(root=os.path.join(root, r), transforms=transforms) 56 | for r in sub_dataset_roots 57 | ] 58 | self.ds = ConcatDataset(datasets) 59 | 60 | def __len__(self): 61 | """Return length of the dataset.""" 62 | return len(self.ds) 63 | 64 | def __getitem__(self, idx: int) -> Dict: 65 | """Get dataset element.""" 66 | return self.ds[idx] 67 | 68 | 69 | class MimicT2IDataset(ChestXrayDataset): 70 | """Mimic subset with reports.""" 71 | 72 | def __init__(self, root: str, transforms: Optional[Compose] = None) -> None: 73 | root = os.path.join(root, 'mimic') 74 | super().__init__(root, transforms) 75 | 76 | def __getitem__(self, idx: int) -> Dict: 77 | """Get dataset element.""" 78 | meta = self.index_dict[self.keys[idx]] 79 | 80 | img = Image.open(meta['path']) 81 | img = self.transforms(img) 82 | 83 | return {'img': img, 'caption': meta['report']} 84 | 85 | 86 | class LabelChestXrayDataset(ChestXrayDataset): 87 | """A Chest X-ray dataset that returns class labels.""" 88 | 89 | def __init__(self, root: str, transforms: Optional[Compose] = None) -> None: 90 | super().__init__(root, transforms) 91 | keys = [] 92 | for key in self.keys: 93 | if self.index_dict[key].get('class_label') is not None: 94 | keys.append(key) 95 | self.keys = keys 96 | 97 | def __getitem__(self, idx: int) -> Dict: 98 | """Get dataset element.""" 99 | meta = self.index_dict[self.keys[idx]] 100 | 101 | img = Image.open(meta['path']) 102 | img = self.transforms(img) 103 | 104 | return {'img': img, 'class_label': torch.tensor(meta['class_label']).float()} 105 | 106 | 107 | class CombinedLabelChestXrayDataset(Dataset): 108 | def __init__(self, root: str, transforms: Optional[Compose] = None) -> None: 109 | """Initialize MaCheXDataset""" 110 | self.root = root 111 | sub_dataset_roots = ['mimic', 'chexpert'] 112 | datasets = [ 113 | LabelChestXrayDataset(root=os.path.join(root, r), transforms=transforms) 114 | for r in sub_dataset_roots 115 | ] 116 | self.ds = ConcatDataset(datasets) 117 | 118 | def __len__(self): 119 | """Return length of the dataset.""" 120 | return len(self.ds) 121 | 122 | def __getitem__(self, idx: int) -> Dict: 123 | """Get dataset element.""" 124 | return self.ds[idx] 125 | 126 | -------------------------------------------------------------------------------- /chexzero/components/soft_roi_pool.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from typing import Tuple, Union 3 | from torch import FloatTensor, Tensor, nn 4 | import torch 5 | from torch import autocast 6 | 7 | class SoftRoiPool(nn.Module): 8 | def forward(self, x: Tensor, box_params: Tensor, beta: Union[float, FloatTensor] = 2.) 
-> Tuple[Tensor, Tensor]: 9 | """ 10 | :param x: Featue map of image (N x H x W x d) 11 | :param roi_params: Parameters of bounding boxes (N x Q x 4) for Q boxes. 12 | format: (x_c, y_c, w, h) each in the range of [0, 1] (relative to image size) 13 | :return roi_features, roi_maps 14 | - roi_features: Pooled features of each roi (N x Q x d) 15 | - roi_maps: Attention maps of each roi (N x Q x H x W) 16 | """ 17 | N, H, W, d = x.shape 18 | 19 | # Compute kernel on sampling grid 20 | sampled_grid = get_sample_grid(H, W, device=x.device, dtype=x.dtype) # (H x W x 2) 21 | roi_patch_map = separable_generalized_gaussian_pdf(box_params.to(dtype=x.dtype), sampled_grid, beta=beta) # (N x Q x H x W) 22 | 23 | # with autocast(device_type='cuda', enabled=False): 24 | # Batched matrix multiplication and normalize 25 | roi_features = torch.einsum('nqhw,nhwd->nqd', roi_patch_map.float(), x.float()) # (N x Q x d) 26 | roi_features = roi_features / H * W 27 | 28 | return roi_features, roi_patch_map 29 | 30 | 31 | def get_sample_grid(H: int, W: int, device, dtype) -> Tensor: 32 | # (H x W) 33 | y, x = torch.meshgrid(torch.arange(H, device=device, dtype=dtype), 34 | torch.arange(W, device=device, dtype=dtype), 35 | indexing='ij') 36 | # (H x W x 2) 37 | sampled_grid = torch.stack([x, y], dim=-1) 38 | # consider pixel centers instead of left-upper position 39 | sampled_grid += 0.5 40 | # normalize positions into range [0, 1] 41 | sampled_grid[:, :, 0] /= W 42 | sampled_grid[:, :, 1] /= H 43 | return sampled_grid 44 | 45 | 46 | def generalized_gauss_1d_log_pdf(mu: Tensor, sigma: Tensor, sampled_grid: Tensor, 47 | beta: Union[float, FloatTensor]) -> Tensor: 48 | """ 49 | :param mu: (N x K) 50 | :param sigma: (N x K) 51 | :param sampled_grid: Sampled points (P) where P is the number of sampled points of the Gaussian pdf 52 | :return (N x K x P) 53 | """ 54 | assert len(sampled_grid.shape) == 1 55 | assert len(mu.shape) == 2 56 | assert len(sigma.shape) == 2 57 | 58 | if not isinstance(beta, (float, int)): 59 | assert isinstance(beta, Tensor) 60 | if beta.numel() > 1: 61 | assert beta.shape == mu.shape 62 | beta = beta[:, :, None] 63 | 64 | # (unnormalized) log pdf = -0.5*((x-mu)/sigma)^2 65 | # log_pdf = - (1 / beta) * ( 66 | # (sampled_grid[None, None] - mu[:, :, None]) / sigma[:, :, None] 67 | # ).pow(beta) 68 | log_pdf = -( 69 | (sampled_grid[None, None] - mu[:, :, None]).abs() / sigma[:, :, None] 70 | ).pow(beta) 71 | return log_pdf 72 | 73 | 74 | def separable_generalized_gaussian_pdf(box_params: Tensor, sampled_grid: Tensor, 75 | beta: Union[float, FloatTensor]) -> Tensor: 76 | """ 77 | :param box_params: (N x Q x 4) 78 | :param sampled_grid: (... x 2) 79 | :return: (N x Q x ...) 80 | """ 81 | N, K, _ = box_params.shape 82 | *dims, _ = sampled_grid.shape 83 | sampled_grid = sampled_grid.view(-1, 2) # (... x 2) 84 | mu = box_params[:, :, :2] # (N x K x 2) 85 | sigma = box_params[:, :, 2:] # (N x K x 2) 86 | # compute x and y Gaussian pdf's independently (in log-space and non-normalized) 87 | log_scores_x = generalized_gauss_1d_log_pdf(mu[..., 0], sigma[..., 0], 88 | sampled_grid[..., 0], beta) # (N x K x ...) 89 | log_scores_y = generalized_gauss_1d_log_pdf(mu[..., 1], sigma[..., 1], 90 | sampled_grid[..., 1], beta) # (N x K x ...) 91 | # combine them in log space (multiplication in prob space) 92 | log_scores = log_scores_x + log_scores_y # (N x K x ...) 
93 | 94 | # Normalize to max value = 1 95 | scores = torch.exp(log_scores) 96 | probs = scores / (scores.max(-1, keepdim=True).values + 1e-12) 97 | 98 | # Alternative: convert to probs by applying exp and then normalizing to sum equals 1 == softmax 99 | # probs = torch.softmax(log_scores, dim=-1) # (N x K x ...) 100 | 101 | return probs.view(N, K, *dims) 102 | 103 | -------------------------------------------------------------------------------- /chexzero/txt_encoder/chexzero_txt_encoder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import os 3 | from typing import List 4 | from torch import Tensor 5 | from torch import nn 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | # from model.components.mlp import MLP 10 | from chexzero.components.mlp import MLP 11 | 12 | from chexzero.util.model_utils import BaseModel, BaseModelConfig, MainModelConfig 13 | import chexzero.clip 14 | 15 | 16 | @dataclass 17 | class ChexzeroTextEncoderConfig(BaseModelConfig): 18 | model_path: str = os.path.expanduser( 19 | "~/models/third_party/chexzero/CheXzero_Models/best_64_5e-05_original_22000_0.864.pt" 20 | ) 21 | frozen_language_model: bool = True 22 | 23 | # Additional projection layers (after the language model projection) 24 | # 0 = no projection, 1 = linear, 2 = one hidden layer 25 | n_projection_layers: int = 0 26 | # whether to use batch norm in the addtional projection layers 27 | projection_bn: bool = False 28 | normalize_projected: bool = False 29 | 30 | 31 | class ChexzeroTextEncoder(BaseModel): 32 | CONFIG_CLS = ChexzeroTextEncoderConfig 33 | MODIFYABLE_CONFIGS = ("frozen_backbone",) 34 | 35 | def __init__(self, config: ChexzeroTextEncoderConfig, main_config: MainModelConfig): 36 | super().__init__(config) 37 | self.config: ChexzeroTextEncoderConfig 38 | 39 | self.d = main_config.d_model 40 | 41 | model, _ = chexzero.clip.load("ViT-B/32", device="cpu", jit=False) 42 | model.load_state_dict(torch.load(self.config.model_path, map_location="cpu")) 43 | self.d = main_config.d_model 44 | 45 | self.transformer = model.transformer 46 | self.token_embedding = model.token_embedding 47 | self.positional_embedding = model.positional_embedding 48 | self.ln_final = model.ln_final 49 | self.text_projection = model.text_projection 50 | 51 | if self.config.frozen_language_model: 52 | for param in self.parameters(): 53 | param.requires_grad = False 54 | 55 | d_backbone = self.text_projection.shape[1] 56 | 57 | self.projection = MLP( 58 | self.config.n_projection_layers, 59 | d_in=d_backbone, 60 | d_out=self.d, 61 | use_bn=self.config.projection_bn, 62 | act=main_config.act, 63 | dropout=main_config.dropout, 64 | ) 65 | 66 | self.cached_sentence_embeddings = {} 67 | 68 | @property 69 | def dtype(self): 70 | return self.token_embedding.weight.dtype 71 | 72 | @property 73 | def device(self): 74 | return self.token_embedding.weight.device 75 | 76 | def forward(self, input_ids: torch.Tensor, project: bool = True, **kwargs) -> Tensor: 77 | 78 | input_ids = input_ids.to(device=self.device) 79 | 80 | # Encode image using backbone 81 | with torch.set_grad_enabled(not self.config.frozen_language_model): 82 | x = self.token_embedding(input_ids).type(self.dtype) # [batch_size, n_ctx, d_model] 83 | 84 | x = x + self.positional_embedding.type(self.dtype) 85 | x = x.permute(1, 0, 2) # NLD -> LND 86 | x = self.transformer(x) 87 | x = x.permute(1, 0, 2) # LND -> NLD 88 | x = self.ln_final(x).type(self.dtype) 89 | 90 | # x.shape = [batch_size, 
n_ctx, transformer.width] 91 | # take features from the eot embedding (eot_token is the highest number in each sequence) 92 | # (N_sentences x d) 93 | features = x[torch.arange(x.shape[0]), input_ids.argmax(dim=-1)] @ self.text_projection 94 | 95 | if project: 96 | features = self.projection(features) 97 | return features 98 | 99 | def encode_sentences(self, sentences: List[str], cache=False, **kwargs) -> Tensor: 100 | if cache and self.config.frozen_language_model and all(s in self.cached_sentence_embeddings for s in sentences): 101 | features = torch.stack([self.cached_sentence_embeddings[s] for s in sentences], dim=0) 102 | else: 103 | input_ids = chexzero.clip.tokenize(sentences, context_length=77) 104 | features = self(input_ids, project=False, **kwargs) 105 | 106 | if cache and self.config.frozen_language_model: 107 | for s, f in zip(sentences, features): 108 | self.cached_sentence_embeddings[s] = f.detach().float() 109 | features = features.to(dtype=self.dtype) 110 | projected_features = self.projection(features) 111 | return projected_features 112 | -------------------------------------------------------------------------------- /models/classifier.py: -------------------------------------------------------------------------------- 1 | from torchvision.models import densenet121, resnet152 2 | from chexpert.models.efficientnet import construct_model 3 | from chexpert.models.attn_aug_conv import DenseNet, ResNet, Bottleneck 4 | from chexpert.dataset import ChexpertSmall, extract_patient_ids 5 | from torchvision.models.densenet import DenseNet121_Weights 6 | 7 | import torch 8 | import torch.nn as nn 9 | import numpy as np 10 | import os 11 | 12 | 13 | def load_model(args): 14 | 15 | classifier_args = args.model.classifier 16 | model_name, restore, restore_path, lr, pretrained = ( 17 | classifier_args.model.strip().lower(), 18 | classifier_args.restore, 19 | classifier_args.path, 20 | classifier_args.lr, 21 | classifier_args.pretrained, 22 | ) 23 | 24 | print("Loading model", model_name) 25 | 26 | # load model 27 | n_classes = len(ChexpertSmall.attr_names) 28 | if model_name == "densenet121": 29 | model = densenet121(weights=(DenseNet121_Weights.DEFAULT if pretrained else None)).to(args.device) 30 | # 1. replace output layer with chexpert number of classes (pretrained loads ImageNet n_classes) 31 | model.classifier = nn.Linear(model.classifier.in_features, out_features=n_classes).to(args.device) 32 | # 2. init output layer with default torchvision init 33 | nn.init.constant_(model.classifier.bias, 0) 34 | # 3. store locations of forward and backward hooks for grad-cam 35 | grad_cam_hooks = {"forward": model.features.norm5, "backward": model.classifier} 36 | # 4. 
init optimizer and scheduler 37 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 38 | scheduler = None 39 | # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True) 40 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [40000, 60000]) 41 | elif model_name == "aadensenet121": 42 | model = DenseNet( 43 | 32, 44 | (6, 12, 24, 16), 45 | 64, 46 | num_classes=n_classes, 47 | attn_params={"k": 0.2, "v": 0.1, "nh": 8, "relative": True, "input_dims": (320, 320)}, 48 | ).to(args.device) 49 | grad_cam_hooks = {"forward": model.features, "backward": model.classifier} 50 | attn_hooks = [model.features.transition1.conv, model.features.transition2.conv, model.features.transition3.conv] 51 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True) 52 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [40000, 60000]) 53 | elif model_name == "resnet152": 54 | model = resnet152(weights=pretrained).to(args.device) 55 | model.fc = nn.Linear(model.fc.in_features, out_features=n_classes).to(args.device) 56 | grad_cam_hooks = {"forward": model.layer4, "backward": model.fc} 57 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 58 | scheduler = None 59 | elif ( 60 | model_name == "aaresnet152" 61 | ): # resnet50 layers [3,4,6,3]; resnet101 layers [3,4,23,3]; resnet 152 layers [3,8,36,3] 62 | model = ResNet( 63 | Bottleneck, 64 | [3, 8, 36, 3], 65 | num_classes=n_classes, 66 | attn_params={"k": 0.2, "v": 0.1, "nh": 8, "relative": True, "input_dims": (320, 320)}, 67 | ).to(args.device) 68 | grad_cam_hooks = {"forward": model.layer4, "backward": model.fc} 69 | attn_hooks = ( 70 | [model.layer2[i].conv2 for i in range(len(model.layer2))] 71 | + [model.layer3[i].conv2 for i in range(len(model.layer3))] 72 | + [model.layer4[i].conv2 for i in range(len(model.layer4))] 73 | ) 74 | optimizer = torch.optim.Adam(model.parameters(), lr=lr) 75 | scheduler = None 76 | elif "efficientnet" in model_name: 77 | model = construct_model(model_name, n_classes=n_classes).to(args.device) 78 | grad_cam_hooks = {"forward": model.head[1], "backward": model.head[-1]} 79 | optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, momentum=0.9, eps=0.001) 80 | scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, lr_decay_factor) 81 | else: 82 | raise RuntimeError("Model architecture not supported.") 83 | 84 | if restore and os.path.isfile(restore_path): 85 | print("Restoring model weights from {}".format(restore_path)) 86 | model_checkpoint = torch.load(restore_path, map_location=args.device, weights_only=False) 87 | model.load_state_dict(model_checkpoint["state_dict"]) 88 | args.model.classifier.step = model_checkpoint["global_step"] 89 | del model_checkpoint 90 | 91 | model = model.to(args.device) 92 | 93 | return model, grad_cam_hooks 94 | 95 | 96 | def define_classifier_model(args): 97 | model, _ = load_model(args) 98 | model = model.to(args.device) 99 | 100 | model.eval() 101 | 102 | classifier_name = args.model.classifier.model.lower() 103 | 104 | print("Using classifier:", classifier_name) 105 | 106 | return model 107 | -------------------------------------------------------------------------------- /src/data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | from PIL import Image 4 | from torch.utils.data import Dataset 5 | from collections import defaultdict 6 | import numpy as np 7 | import pytorch_lightning as pl 8 | from tqdm import tqdm 9 | from 
torch.utils.data import random_split 10 | import pandas as pd 11 | from chexpert.dataset import ChexpertSmall 12 | import torchvision.transforms as T 13 | 14 | 15 | class ChestXRayDataset(Dataset): 16 | def __init__(self, opts, data_dir, mode="train", transform=None): # mode="train" or "valid" 17 | self.opts = opts 18 | self.mode = mode 19 | self.target = opts.finetune.target 20 | self.data_dir = data_dir 21 | self.transform = transform 22 | self.csv_path = os.path.join(data_dir, f"{mode}.csv") 23 | 24 | # Read CSV 25 | self.df = pd.read_csv(self.csv_path) 26 | self.df = self.df[ChexpertSmall.attr_names + ["Path"]] 27 | self.df = self.df.replace(-1, 0) 28 | self.df = self.df.fillna(0) 29 | 30 | # Filter out lateral images 31 | self.df = self.df[~self.df["Path"].str.contains("lateral", case=False, na=False)] 32 | 33 | self.diseased_image_paths = self.df[self.df[self.target] == 1]["Path"].tolist() 34 | self.healthy_image_paths = self.df[self.df[ChexpertSmall.attr_names].eq(0).all(axis=1)]["Path"].tolist() 35 | 36 | min_images = min(len(self.diseased_image_paths), len(self.healthy_image_paths)) 37 | print( 38 | f"Min images: {min_images}, Diseased: {len(self.diseased_image_paths)}, Healthy: {len(self.healthy_image_paths)}" 39 | ) 40 | 41 | self.diseased_image_paths = self.diseased_image_paths[:min_images] 42 | self.healthy_image_paths = self.healthy_image_paths[:min_images] 43 | 44 | def __len__(self): 45 | return len(self.diseased_image_paths) 46 | 47 | def __getitem__(self, idxs): 48 | return self.__fetch_one(idxs) 49 | 50 | def __fetch_one(self, idx): 51 | diseased_path = self.diseased_image_paths[idx] 52 | healthy_path = self.healthy_image_paths[idx] 53 | 54 | diseased_path = os.path.join(self.data_dir, diseased_path.replace(self.opts.data.prefix, "")) 55 | healthy_path = os.path.join(self.data_dir, healthy_path.replace(self.opts.data.prefix, "")) 56 | 57 | diseased_image = Image.open(diseased_path).convert("RGB") 58 | healthy_image = Image.open(healthy_path).convert("RGB") 59 | 60 | if self.transform: 61 | diseased_image = self.transform(diseased_image) 62 | healthy_image = self.transform(healthy_image) 63 | 64 | return diseased_image, healthy_image 65 | 66 | 67 | class ChestXRayDataModule(pl.LightningDataModule): 68 | def __init__(self, opts): 69 | super().__init__() 70 | self.opts = opts 71 | self.transform = T.Compose( 72 | [ 73 | T.Resize(self.opts.data.resize) if self.opts.data.resize else T.Lambda(lambda x: x), 74 | T.CenterCrop(320 if not self.opts.data.resize else self.opts.data.resize), 75 | T.ToTensor(), 76 | ] 77 | ) # expand to 3 channels 78 | 79 | self.train_dataset = None 80 | self.val_dataset = None 81 | self.test_dataset = None 82 | 83 | def get_train_dataset(self): 84 | ds = ChestXRayDataset(self.opts, self.opts.data.path, mode="train", transform=self.transform) 85 | print(f"Train dataset size: {len(ds)}") 86 | return ds 87 | 88 | def get_val_dataset(self): 89 | return self.get_test_dataset() 90 | 91 | def get_test_dataset(self): 92 | ds = ChestXRayDataset(self.opts, self.opts.data.path, mode="valid", transform=self.transform) 93 | print(f"Test dataset size: {len(ds)}") 94 | return ds 95 | 96 | def setup(self, stage=None): 97 | self.train_dataset = self.get_train_dataset() 98 | self.val_dataset = self.get_val_dataset() 99 | self.test_dataset = self.get_test_dataset() 100 | 101 | def train_dataloader(self): 102 | if self.train_dataset is None: 103 | self.train_dataset = self.get_train_dataset() 104 | 105 | return torch.utils.data.DataLoader( 106 | self.train_dataset, 107 
| batch_size=self.opts.data.train_batch_size, 108 | shuffle=False, 109 | num_workers=self.opts.data.num_workers, 110 | pin_memory=True, 111 | prefetch_factor=( 112 | self.opts.data.prefetch_factor 113 | if self.opts.data.prefetch_factor and self.opts.data.num_workers > 0 114 | else None 115 | ), 116 | ) 117 | 118 | def val_dataloader(self): 119 | return self.test_dataloader() 120 | 121 | def test_dataloader(self): 122 | if self.test_dataset is None: 123 | self.test_dataset = self.get_test_dataset() 124 | 125 | return torch.utils.data.DataLoader( 126 | self.test_dataset, 127 | batch_size=self.opts.data.test_batch_size, 128 | shuffle=False, 129 | num_workers=self.opts.data.num_workers, 130 | pin_memory=True, 131 | prefetch_factor=( 132 | self.opts.data.prefetch_factor 133 | if self.opts.data.prefetch_factor and self.opts.data.num_workers > 0 134 | else None 135 | ), 136 | ) 137 | 138 | 139 | if __name__ == "__main__": 140 | 141 | from omegaconf import OmegaConf 142 | 143 | opts = OmegaConf.load( 144 | "/cluster/home/agoekmen/projects/chexray-diffusion/removal-editing/src/config/finetune_config.yaml" 145 | ) 146 | data_module = ChestXRayDataModule(opts) 147 | data_module.setup() 148 | train_loader = data_module.train_dataloader() 149 | val_loader = data_module.val_dataloader() 150 | 151 | for a, b in train_loader: 152 | print(a.shape, b.shape) 153 | break 154 | -------------------------------------------------------------------------------- /cheff/ldm/modules/losses/contperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no? 5 | 6 | 7 | class LPIPSWithDiscriminator(nn.Module): 8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0, 9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 11 | disc_loss="hinge"): 12 | 13 | super().__init__() 14 | assert disc_loss in ["hinge", "vanilla"] 15 | self.kl_weight = kl_weight 16 | self.pixel_weight = pixelloss_weight 17 | self.perceptual_loss = LPIPS().eval() 18 | self.perceptual_weight = perceptual_weight 19 | # output log variance 20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init) 21 | 22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 23 | n_layers=disc_num_layers, 24 | use_actnorm=use_actnorm 25 | ).apply(weights_init) 26 | self.discriminator_iter_start = disc_start 27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss 28 | self.disc_factor = disc_factor 29 | self.discriminator_weight = disc_weight 30 | self.disc_conditional = disc_conditional 31 | 32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 33 | if last_layer is not None: 34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 36 | else: 37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 39 | 40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 42 | d_weight = d_weight * self.discriminator_weight 43 | return d_weight 44 | 45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx, 46 | global_step, 
last_layer=None, cond=None, split="train", 47 | weights=None): 48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 49 | if self.perceptual_weight > 0: 50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 51 | rec_loss = rec_loss + self.perceptual_weight * p_loss 52 | 53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar 54 | weighted_nll_loss = nll_loss 55 | if weights is not None: 56 | weighted_nll_loss = weights*nll_loss 57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0] 58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 59 | kl_loss = posteriors.kl() 60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0] 61 | 62 | # now the GAN part 63 | if optimizer_idx == 0: 64 | # generator update 65 | if cond is None: 66 | assert not self.disc_conditional 67 | logits_fake = self.discriminator(reconstructions.contiguous()) 68 | else: 69 | assert self.disc_conditional 70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 71 | g_loss = -torch.mean(logits_fake) 72 | 73 | if self.disc_factor > 0.0: 74 | try: 75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 76 | except RuntimeError: 77 | assert not self.training 78 | d_weight = torch.tensor(0.0) 79 | else: 80 | d_weight = torch.tensor(0.0) 81 | 82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss 84 | 85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(), 86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(), 87 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 88 | "{}/d_weight".format(split): d_weight.detach(), 89 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 90 | "{}/g_loss".format(split): g_loss.detach().mean(), 91 | } 92 | return loss, log 93 | 94 | if optimizer_idx == 1: 95 | # second pass for discriminator update 96 | if cond is None: 97 | logits_real = self.discriminator(inputs.contiguous().detach()) 98 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 99 | else: 100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 102 | 103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 105 | 106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 107 | "{}/logits_real".format(split): logits_real.detach().mean(), 108 | "{}/logits_fake".format(split): logits_fake.detach().mean() 109 | } 110 | return d_loss, log 111 | 112 | -------------------------------------------------------------------------------- /chexzero/components/classification_losses.py: -------------------------------------------------------------------------------- 1 | from functools import partial 2 | from math import prod 3 | from typing import Optional, Tuple 4 | import einops 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | def classify_features(features: torch.FloatTensor, pos_prompt_emb: torch.FloatTensor, neg_prompt_emb: torch.FloatTensor, 9 | normalized: bool = True, temp=1.0, softmax=False, 10 | threshold: 
Optional[float] = 0.5, return_logits=False) -> Tuple[torch.FloatTensor, torch.BoolTensor]: 11 | """ 12 | :param features: (N x ... x d) 13 | :param pos_prompt_emb: (N x ... x d) 14 | :param neg_prompt_emb: (N x ... x d) 15 | :return (N x ...) 16 | """ 17 | assert pos_prompt_emb.ndim == neg_prompt_emb.ndim 18 | assert features.ndim == pos_prompt_emb.ndim, f'{features.ndim} != {pos_prompt_emb.ndim}' 19 | features, pos_prompt_emb, neg_prompt_emb = torch.broadcast_tensors(features, pos_prompt_emb, neg_prompt_emb) 20 | 21 | if normalized: 22 | features = F.normalize(features, dim=-1) 23 | pos_prompt_emb = F.normalize(pos_prompt_emb, dim=-1) 24 | neg_prompt_emb = F.normalize(neg_prompt_emb, dim=-1) 25 | else: 26 | features = features.contiguous() 27 | pos_prompt_emb = pos_prompt_emb.contiguous() 28 | neg_prompt_emb = neg_prompt_emb.contiguous() 29 | 30 | 31 | N, *dims, d = features.shape 32 | n_dims = prod(dims) 33 | features = features.view(N, n_dims, d) 34 | pos_prompt_emb = pos_prompt_emb.view(N, n_dims, d) 35 | neg_prompt_emb = neg_prompt_emb.view(N, n_dims, d) 36 | 37 | pos_logits = torch.einsum('ijd,ijd->ij', features, pos_prompt_emb) / temp # (N x dims) 38 | neg_logits = torch.einsum('ijd,ijd->ij', features, neg_prompt_emb) / temp # (N x dims) 39 | 40 | if softmax: 41 | # (N x dims x 2) 42 | probs = torch.softmax(torch.stack([pos_logits, neg_logits], dim=-1), dim=-1) 43 | probs = probs[..., 0] # only positive probs 44 | else: 45 | probs = torch.sigmoid(pos_logits - neg_logits) 46 | probs = probs.view(N, *dims) 47 | 48 | preds = probs > threshold if threshold is not None else torch.ones_like(probs, dtype=bool) 49 | 50 | if not return_logits: 51 | return probs, preds 52 | 53 | if softmax: 54 | logits = torch.log_softmax(torch.stack([pos_logits, neg_logits], dim=-1), dim=-1) 55 | logits = logits[..., 0] # only positive probs 56 | else: 57 | logits = pos_logits - neg_logits 58 | logits = logits.view(N, *dims) 59 | 60 | return logits, probs, preds 61 | 62 | 63 | 64 | # ----------------------- Binary (and multilabel binary) losses ----------------------- # 65 | def binary_focal_loss_logits( 66 | logits: torch.Tensor, 67 | targets: torch.Tensor, 68 | alpha: float = 0.25, 69 | gamma: float = 2) -> torch.Tensor: 70 | logits, targets = torch.broadcast_tensors(logits, targets) 71 | targets = targets.to(dtype=logits.dtype) 72 | 73 | p = torch.sigmoid(logits) 74 | ce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none") 75 | p_t = p * targets + (1 - p) * (1 - targets) 76 | loss = ce_loss * ((1 - p_t) ** gamma) 77 | 78 | if torch.is_tensor(alpha) or alpha >= 0: 79 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 80 | loss = alpha_t * loss 81 | return loss 82 | 83 | def binary_focal_loss_probs( 84 | probs: torch.Tensor, 85 | targets: torch.Tensor, 86 | alpha: float = 0.25, 87 | gamma: float = 2, 88 | eps=1e-7) -> torch.Tensor: 89 | probs, targets = torch.broadcast_tensors(probs, targets) 90 | targets = targets.to(dtype=probs.dtype) 91 | 92 | p = probs 93 | with torch.autocast(device_type='cuda', enabled=False): 94 | ce_loss = F.binary_cross_entropy(probs.float().clamp(min=eps, max=1-eps), targets.float(), reduction="none") 95 | p_t = p * targets + (1 - p) * (1 - targets) 96 | loss = ce_loss * ((1 - p_t) ** gamma) 97 | 98 | if alpha >= 0: 99 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets) 100 | loss = alpha_t * loss 101 | return loss 102 | 103 | 104 | def autocompute_class_weights(targets: torch.Tensor, per_class: bool = False, cls_dim: int = None) -> 
torch.Tensor: 105 | if per_class: 106 | assert cls_dim is not None 107 | dims_before = targets.shape[:cls_dim] 108 | C = targets.shape[cls_dim] 109 | dims_after = targets.shape[cls_dim+1:] 110 | # (... x C x ...) where C is the dim at index cls_dim 111 | targets = targets.view(prod(dims_before), C, prod(dims_after)) 112 | 113 | N_pos = targets.sum(dim=0).sum(dim=-1) # (C) 114 | # (1 x ... x 1 x C x 1 x ...) 115 | N_pos = N_pos.view((1,) * len(dims_before) + (C,) + (1,) * len(dims_after)) 116 | N = targets.numel() / C # () 117 | else: 118 | # (...) 119 | targets = targets.view(-1) 120 | N_pos = targets.sum() 121 | N = targets.numel() 122 | 123 | N_neg = N - N_pos # (C) or () 124 | 125 | weight_pos = (N + 1) / (N_pos + 1) # (C) or () 126 | weight_neg = (N + 1) / (N_neg + 1) # (C) or () 127 | 128 | return weight_pos, weight_neg 129 | 130 | 131 | def get_focal_loss(logits: bool = True, auto_weight: bool = False, **kwargs): 132 | focal_loss_fn = binary_focal_loss_logits if logits else binary_focal_loss_probs 133 | if auto_weight: 134 | def _loss_fn(preds, targets): 135 | preds, targets = torch.broadcast_tensors(preds, targets) 136 | weight_pos, weight_neg = autocompute_class_weights(targets, per_class=False) 137 | alpha = weight_pos / (weight_pos + weight_neg) 138 | return focal_loss_fn(preds, targets, alpha=alpha, **kwargs) 139 | else: 140 | _loss_fn = partial(focal_loss_fn, **kwargs) 141 | return _loss_fn 142 | -------------------------------------------------------------------------------- /cheff/sr/schedule.py: -------------------------------------------------------------------------------- 1 | """Classes und functions for variance schedules.""" 2 | 3 | from abc import ABC, abstractmethod 4 | from typing import Any, Optional 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | from torch import Tensor 9 | 10 | 11 | class BaseSchedule(ABC): 12 | """Base class for deriving schedules.""" 13 | 14 | def __init__( 15 | self, timesteps: int, device: Optional[torch.device] = None, *args, **kwargs 16 | ) -> None: 17 | """Initialize BaseSchedule.""" 18 | self.timesteps = timesteps 19 | if device is None: 20 | self.device = torch.device('cpu') 21 | else: 22 | self.device = device 23 | 24 | self.betas = self._get_betas(timesteps).to(device) 25 | self.alphas = 1.0 - self.betas 26 | 27 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) 28 | self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0) 29 | 30 | self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas) 31 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod) 32 | self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod) 33 | 34 | self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod) 35 | self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod) 36 | self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1) 37 | 38 | self.post_var = ( 39 | self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) 40 | ) 41 | 42 | self.post_log_var_clipped = torch.log( 43 | torch.maximum(self.post_var, torch.tensor(1e-20)) 44 | ) 45 | self.post_mean_coef1 = ( 46 | self.betas 47 | * torch.sqrt(self.alphas_cumprod_prev) 48 | / (1.0 - self.alphas_cumprod) 49 | ) 50 | self.post_mean_coef2 = ( 51 | (1.0 - self.alphas_cumprod_prev) 52 | * torch.sqrt(self.alphas) 53 | / (1.0 - self.alphas_cumprod) 54 | ) 55 | 56 | @abstractmethod 57 | def _get_betas(self, timesteps: int) -> Tensor: 58 | """Get betas.""" 59 | pass 60 | 61 | 62 | class 
LinearSchedule(BaseSchedule): 63 | """Linear variance schedule.""" 64 | 65 | def __init__( 66 | self, 67 | timesteps: int, 68 | device: Optional[torch.device] = None, 69 | beta_start: float = 0.0001, 70 | beta_end: float = 0.02, 71 | *args, 72 | **kwargs 73 | ) -> None: 74 | """Initialize linear beta schedule.""" 75 | self.beta_start = beta_start 76 | self.beta_end = beta_end 77 | super().__init__(timesteps, device, *args, **kwargs) 78 | 79 | def _get_betas(self, timesteps: int) -> Tensor: 80 | """Get betas.""" 81 | return torch.linspace(self.beta_start, self.beta_end, timesteps) 82 | 83 | 84 | class CosineSchedule(BaseSchedule): 85 | """Cosine variance schedule.""" 86 | 87 | def __init__( 88 | self, 89 | timesteps: int, 90 | device: Optional[torch.device] = None, 91 | s: float = 0.008, 92 | *args, 93 | **kwargs 94 | ) -> None: 95 | """Initialize cosine beta schedule.""" 96 | self.s = s 97 | super().__init__(timesteps, device, *args, **kwargs) 98 | 99 | def _get_betas(self, timesteps: int) -> Tensor: 100 | """Get betas.""" 101 | steps = timesteps + 1 102 | x = torch.linspace(0, timesteps, steps) 103 | alphas_cumprod = ( 104 | torch.cos(((x / timesteps) + self.s) / (1 + self.s) * torch.pi * 0.5) ** 2 105 | ) 106 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0] 107 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) 108 | return torch.clip(betas, 0.0001, 0.9999) 109 | 110 | 111 | class QuadraticSchedule(BaseSchedule): 112 | """Quadratic variance schedule.""" 113 | 114 | def __init__( 115 | self, 116 | timesteps: int, 117 | device: Optional[torch.device] = None, 118 | beta_start: float = 0.0001, 119 | beta_end: float = 0.02, 120 | *args, 121 | **kwargs 122 | ) -> None: 123 | """Initialize quadratic beta schedule.""" 124 | self.beta_start = beta_start 125 | self.beta_end = beta_end 126 | super().__init__(timesteps, device, *args, **kwargs) 127 | 128 | def _get_betas(self, timesteps: int) -> Tensor: 129 | """Get betas.""" 130 | return ( 131 | torch.linspace(self.beta_start**0.5, self.beta_end**0.5, timesteps) ** 2 132 | ) 133 | 134 | 135 | class SigmoidSchedule(BaseSchedule): 136 | """Sigmoid variance schedule.""" 137 | 138 | def __init__( 139 | self, 140 | timesteps: int, 141 | device: Optional[torch.device] = None, 142 | beta_start: float = 0.0001, 143 | beta_end: float = 0.02, 144 | *args, 145 | **kwargs 146 | ) -> None: 147 | """Initialize sigmoid beta schedule.""" 148 | self.beta_start = beta_start 149 | self.beta_end = beta_end 150 | super().__init__(timesteps, device, *args, **kwargs) 151 | 152 | def _get_betas(self, timesteps: int) -> Tensor: 153 | """Get betas.""" 154 | betas = torch.linspace(-6, 6, timesteps) 155 | return ( 156 | torch.sigmoid(betas) * (self.beta_end - self.beta_start) + self.beta_start 157 | ) 158 | 159 | 160 | class ScheduleFactory: 161 | """Factory wrapper for variance schedules.""" 162 | 163 | @staticmethod 164 | def get_schedule(name: str, timesteps: int, *args, **kwargs) -> BaseSchedule: 165 | """Initialize a scheduler by name.""" 166 | cls: Any 167 | if name == 'linear': 168 | cls = LinearSchedule 169 | elif name == 'cosine': 170 | cls = CosineSchedule 171 | elif name == 'quadratic': 172 | cls = QuadraticSchedule 173 | elif name == 'sigmoid': 174 | cls = SigmoidSchedule 175 | else: 176 | raise ValueError( 177 | 'There is no matching schedule for name "{}".'.format(name) 178 | ) 179 | 180 | return cls(timesteps, *args, **kwargs) 181 | -------------------------------------------------------------------------------- 
/chexzero/chexzero/simple_tokenizer.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 OpenAI 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 23 | 24 | """ 25 | import gzip 26 | import html 27 | import os 28 | from functools import lru_cache 29 | 30 | import ftfy 31 | import regex as re 32 | 33 | 34 | @lru_cache() 35 | def default_bpe(): 36 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") 37 | 38 | 39 | @lru_cache() 40 | def bytes_to_unicode(): 41 | """ 42 | Returns list of utf-8 byte and a corresponding list of unicode strings. 43 | The reversible bpe codes work on unicode strings. 44 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. 45 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. 46 | This is a signficant percentage of your normal, say, 32K bpe vocab. 47 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings. 48 | And avoids mapping to whitespace/control characters the bpe code barfs on. 49 | """ 50 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) 51 | cs = bs[:] 52 | n = 0 53 | for b in range(2**8): 54 | if b not in bs: 55 | bs.append(b) 56 | cs.append(2**8+n) 57 | n += 1 58 | cs = [chr(n) for n in cs] 59 | return dict(zip(bs, cs)) 60 | 61 | 62 | def get_pairs(word): 63 | """Return set of symbol pairs in a word. 64 | Word is represented as tuple of symbols (symbols being variable-length strings). 
65 | """ 66 | pairs = set() 67 | prev_char = word[0] 68 | for char in word[1:]: 69 | pairs.add((prev_char, char)) 70 | prev_char = char 71 | return pairs 72 | 73 | 74 | def basic_clean(text): 75 | text = ftfy.fix_text(text) 76 | text = html.unescape(html.unescape(text)) 77 | return text.strip() 78 | 79 | 80 | def whitespace_clean(text): 81 | text = re.sub(r'\s+', ' ', text) 82 | text = text.strip() 83 | return text 84 | 85 | 86 | class SimpleTokenizer(object): 87 | def __init__(self, bpe_path: str = default_bpe()): 88 | self.byte_encoder = bytes_to_unicode() 89 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} 90 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') 91 | merges = merges[1:49152-256-2+1] 92 | merges = [tuple(merge.split()) for merge in merges] 93 | vocab = list(bytes_to_unicode().values()) 94 | vocab = vocab + [v+'' for v in vocab] 95 | for merge in merges: 96 | vocab.append(''.join(merge)) 97 | vocab.extend(['<|startoftext|>', '<|endoftext|>']) 98 | self.encoder = dict(zip(vocab, range(len(vocab)))) 99 | self.decoder = {v: k for k, v in self.encoder.items()} 100 | self.bpe_ranks = dict(zip(merges, range(len(merges)))) 101 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} 102 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) 103 | 104 | def bpe(self, token): 105 | if token in self.cache: 106 | return self.cache[token] 107 | word = tuple(token[:-1]) + ( token[-1] + '',) 108 | pairs = get_pairs(word) 109 | 110 | if not pairs: 111 | return token+'' 112 | 113 | while True: 114 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) 115 | if bigram not in self.bpe_ranks: 116 | break 117 | first, second = bigram 118 | new_word = [] 119 | i = 0 120 | while i < len(word): 121 | try: 122 | j = word.index(first, i) 123 | new_word.extend(word[i:j]) 124 | i = j 125 | except: 126 | new_word.extend(word[i:]) 127 | break 128 | 129 | if word[i] == first and i < len(word)-1 and word[i+1] == second: 130 | new_word.append(first+second) 131 | i += 2 132 | else: 133 | new_word.append(word[i]) 134 | i += 1 135 | new_word = tuple(new_word) 136 | word = new_word 137 | if len(word) == 1: 138 | break 139 | else: 140 | pairs = get_pairs(word) 141 | word = ' '.join(word) 142 | self.cache[token] = word 143 | return word 144 | 145 | def encode(self, text): 146 | bpe_tokens = [] 147 | text = whitespace_clean(basic_clean(text)).lower() 148 | for token in re.findall(self.pat, text): 149 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) 150 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) 151 | return bpe_tokens 152 | 153 | def decode(self, tokens): 154 | text = ''.join([self.decoder[token] for token in tokens]) 155 | text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') 156 | return text 157 | -------------------------------------------------------------------------------- /models/cheff.py: -------------------------------------------------------------------------------- 1 | import copy 2 | import os 3 | import torch 4 | 5 | from cheff import CheffLDMImageMaskCond 6 | from cheff import CheffLDMMaskCond 7 | from cheff import CheffLDMImageCond 8 | from cheff import CheffLDM 9 | from cheff.ldm.models.autoencoder import AutoencoderKL 10 | from cheff.ldm.models.diffusion.ddpm import LatentDiffusion 11 | from 
cheff.ldm.models.diffusion.ddim import DDIMSampler 12 | from cheff.ldm.modules.diffusionmodules.openaimodel import AttentionBlock 13 | from cheff.ldm.modules.diffusionmodules.openaimodel import ResBlock 14 | from src.utils import save_txt 15 | 16 | 17 | def get_cheff_ldm( 18 | args, load_checkpoint=True, load_base_uncond=False, load_image_mask_cond=False, load_image_cond=False 19 | ): 20 | print("Loading cheff ldm...") 21 | cheff_args = args.model.cheff 22 | model_path, ae_path = cheff_args.ldm_path, cheff_args.ae_path 23 | 24 | if load_base_uncond: 25 | print("Loading CheffLDM...") 26 | cheff_ldm = CheffLDM(model_path=model_path, ae_path=ae_path, device=args.device) 27 | return cheff_ldm 28 | 29 | if not load_image_mask_cond and not load_image_cond: 30 | print("Loading CheffLDMMaskCond...") 31 | cheff_ldm = CheffLDMMaskCond( 32 | model_path=model_path, ae_path=ae_path, device=args.device, load_checkpoint=load_checkpoint 33 | ) 34 | elif load_image_mask_cond: 35 | print("Loading CheffLDMImageMaskCond...") 36 | cheff_ldm = CheffLDMImageMaskCond( 37 | model_path=model_path, ae_path=ae_path, device=args.device, load_checkpoint=load_checkpoint 38 | ) 39 | elif load_image_cond: 40 | print("Loading CheffLDMImageCond...") 41 | cheff_ldm = CheffLDMImageCond( 42 | model_path=model_path, ae_path=ae_path, device=args.device, load_checkpoint=load_checkpoint 43 | ) 44 | else: 45 | raise ValueError("Invalid configuration for Cheff LDM") 46 | 47 | print("Cheff LDM loaded: ", load_checkpoint) 48 | 49 | if cheff_args.load_external: 50 | print("Loading external model...") 51 | cheff_ldm.model.load_state_dict(torch.load(cheff_args.external_path, map_location=args.device)) 52 | 53 | return cheff_ldm 54 | 55 | 56 | """ 57 | The effect of applying the erasure objective (6) depends 58 | on the subset of parameters that is fine-tuned. The main 59 | distinction is between cross-attention parameters and noncross-attention parameters. Cross-attention parameters, illustrated in Figure 3a, serve as a gateway to the prompt, directly 60 | depending on the text of the prompt, while other parameters 61 | (Figure 3b) tend to contribute to a visual concept even if the 62 | concept is not mentioned in the prompt. 63 | Therefore we propose fine tuning the cross attentions, 64 | ESD-x, when the erasure is required to be controlled and 65 | specific to the prompt, such as when a named artistic style 66 | should be erased. Further, we propose fine tuning unconditional layers (non-cross-attention modules), ESD-u, when 67 | the erasure is required to be independent of the text in the 68 | prompt, such as when the global concept of NSFW nudity 69 | should be erased. We refer to cross-attention-only finetuning as ESD-x-η (where η refers to the strength of the 70 | negative guidance), and we refer to the configuration that 71 | tunes only non-cross-attention parameters as ESD-u-η. 
For 72 | simplicity, we write ESD-x and ESD-u when η = 1 73 | """ 74 | 75 | all_train_methods = { 76 | "attn": [AttentionBlock], 77 | "res": [ResBlock], 78 | "attnres": [AttentionBlock, ResBlock], 79 | "all": None, 80 | "notime": None, 81 | } 82 | 83 | 84 | def get_trainable_params(model, args, specific_paths=["model.diffusion_model.input_blocks.0.0"], save_params=False): 85 | parameters = [] 86 | parameter_names = [] 87 | 88 | train_method = args.model.train.method.lower().strip() 89 | 90 | assert train_method in all_train_methods.keys(), "Unsupported train method" 91 | print("Using train method:", train_method) 92 | 93 | if train_method == "notime": 94 | # Go through all the parameters and only exclude the time embedding ones 95 | for name, param in model.named_parameters(): 96 | if "time_embed" not in name and "first_stage_model" not in name: 97 | param.requires_grad = True 98 | parameters.append(param) 99 | parameter_names.append(name) 100 | 101 | elif train_method == "all": 102 | for name, param in model.named_parameters(): 103 | if "first_stage_model" not in name: 104 | param.requires_grad = True 105 | parameters.append(param) 106 | parameter_names.append(name) 107 | 108 | else: 109 | train_filters = all_train_methods[train_method] 110 | 111 | def recurse(module, current_path=""): 112 | for name, child in module.named_children(): 113 | path = f"{current_path}.{name}" if current_path else name 114 | 115 | if any(isinstance(child, klass) for klass in train_filters) or ( 116 | specific_paths and path in specific_paths 117 | ): 118 | for param_name, param in child.named_parameters(): 119 | param.requires_grad = True 120 | parameters.append(param) 121 | parameter_names.append( 122 | f"{path} -- {name}.{param_name} -- {child.__class__.__name__} -- {param_name}" 123 | ) 124 | else: 125 | recurse(child, path) 126 | 127 | recurse(model) 128 | 129 | if save_params: 130 | save_txt( 131 | args.finetune.log_dir, 132 | "trainable_params", 133 | parameter_names, 134 | ) 135 | 136 | # with open(os.path.join(args.out_dir, "non_trainable_params.txt"), "w") as f: 137 | # f.write("\n".join(non_trainable_parameter_names)) 138 | 139 | print("Trainable parameters:", len(parameters)) 140 | 141 | return parameters 142 | 143 | 144 | def clone_model_for_sample(original_model, args): 145 | cloned_model = copy.deepcopy(original_model) 146 | 147 | parameters = get_trainable_params(cloned_model.model, args, save_params=True) 148 | 149 | cloned_model.model.train() 150 | cloned_model.model.model.diffusion_model.train() 151 | return cloned_model, parameters 152 | -------------------------------------------------------------------------------- /cheff/ldm/util.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | import torch 4 | import numpy as np 5 | from collections import abc 6 | from einops import rearrange 7 | from functools import partial 8 | 9 | import multiprocessing as mp 10 | from threading import Thread 11 | from queue import Queue 12 | 13 | from inspect import isfunction 14 | from PIL import Image, ImageDraw, ImageFont 15 | 16 | 17 | def log_txt_as_img(wh, xc, size=10): 18 | # wh a tuple of (width, height) 19 | # xc a list of captions to plot 20 | b = len(xc) 21 | txts = list() 22 | for bi in range(b): 23 | txt = Image.new("RGB", wh, color="white") 24 | draw = ImageDraw.Draw(txt) 25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size) 26 | nc = int(40 * (wh[0] / 256)) 27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, 
len(xc[bi]), nc)) 28 | 29 | try: 30 | draw.text((0, 0), lines, fill="black", font=font) 31 | except UnicodeEncodeError: 32 | print("Cant encode string for logging. Skipping.") 33 | 34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0 35 | txts.append(txt) 36 | txts = np.stack(txts) 37 | txts = torch.tensor(txts) 38 | return txts 39 | 40 | 41 | def ismap(x): 42 | if not isinstance(x, torch.Tensor): 43 | return False 44 | return (len(x.shape) == 4) and (x.shape[1] > 3) 45 | 46 | 47 | def isimage(x): 48 | if not isinstance(x, torch.Tensor): 49 | return False 50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1) 51 | 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | 62 | 63 | def mean_flat(tensor): 64 | """ 65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86 66 | Take the mean over all non-batch dimensions. 67 | """ 68 | return tensor.mean(dim=list(range(1, len(tensor.shape)))) 69 | 70 | 71 | def count_params(model, verbose=False): 72 | total_params = sum(p.numel() for p in model.parameters()) 73 | if verbose: 74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") 75 | return total_params 76 | 77 | 78 | def instantiate_from_config(config): 79 | if not "target" in config: 80 | if config == '__is_first_stage__': 81 | return None 82 | elif config == "__is_unconditional__": 83 | return None 84 | raise KeyError("Expected key `target` to instantiate.") 85 | return get_obj_from_str(config["target"])(**config.get("params", dict())) 86 | 87 | 88 | def get_obj_from_str(string, reload=False): 89 | module, cls = string.rsplit(".", 1) 90 | if reload: 91 | module_imp = importlib.import_module(module) 92 | importlib.reload(module_imp) 93 | return getattr(importlib.import_module(module, package=None), cls) 94 | 95 | 96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False): 97 | # create dummy dataset instance 98 | 99 | # run prefetching 100 | if idx_to_fn: 101 | res = func(data, worker_id=idx) 102 | else: 103 | res = func(data) 104 | Q.put([idx, res]) 105 | Q.put("Done") 106 | 107 | 108 | def parallel_data_prefetch( 109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False 110 | ): 111 | # if target_data_type not in ["ndarray", "list"]: 112 | # raise ValueError( 113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray." 114 | # ) 115 | if isinstance(data, np.ndarray) and target_data_type == "list": 116 | raise ValueError("list expected but function got ndarray.") 117 | elif isinstance(data, abc.Iterable): 118 | if isinstance(data, dict): 119 | print( 120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.' 121 | ) 122 | data = list(data.values()) 123 | if target_data_type == "ndarray": 124 | data = np.asarray(data) 125 | else: 126 | data = list(data) 127 | else: 128 | raise TypeError( 129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}." 
130 | ) 131 | 132 | if cpu_intensive: 133 | Q = mp.Queue(1000) 134 | proc = mp.Process 135 | else: 136 | Q = Queue(1000) 137 | proc = Thread 138 | # spawn processes 139 | if target_data_type == "ndarray": 140 | arguments = [ 141 | [func, Q, part, i, use_worker_id] 142 | for i, part in enumerate(np.array_split(data, n_proc)) 143 | ] 144 | else: 145 | step = ( 146 | int(len(data) / n_proc + 1) 147 | if len(data) % n_proc != 0 148 | else int(len(data) / n_proc) 149 | ) 150 | arguments = [ 151 | [func, Q, part, i, use_worker_id] 152 | for i, part in enumerate( 153 | [data[i: i + step] for i in range(0, len(data), step)] 154 | ) 155 | ] 156 | processes = [] 157 | for i in range(n_proc): 158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i]) 159 | processes += [p] 160 | 161 | # start processes 162 | print(f"Start prefetching...") 163 | import time 164 | 165 | start = time.time() 166 | gather_res = [[] for _ in range(n_proc)] 167 | try: 168 | for p in processes: 169 | p.start() 170 | 171 | k = 0 172 | while k < n_proc: 173 | # get result 174 | res = Q.get() 175 | if res == "Done": 176 | k += 1 177 | else: 178 | gather_res[res[0]] = res[1] 179 | 180 | except Exception as e: 181 | print("Exception: ", e) 182 | for p in processes: 183 | p.terminate() 184 | 185 | raise e 186 | finally: 187 | for p in processes: 188 | p.join() 189 | print(f"Prefetching complete. [{time.time() - start} sec.]") 190 | 191 | if target_data_type == 'ndarray': 192 | if not isinstance(gather_res[0], np.ndarray): 193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0) 194 | 195 | # order outputs 196 | return np.concatenate(gather_res, axis=0) 197 | elif target_data_type == 'list': 198 | out = [] 199 | for r in gather_res: 200 | out.extend(r) 201 | return out 202 | else: 203 | return gather_res 204 | -------------------------------------------------------------------------------- /chexzero/img_encoder/chexzero_img_encoder.py: -------------------------------------------------------------------------------- 1 | from dataclasses import dataclass 2 | import logging 3 | import os 4 | from typing import Optional 5 | import einops 6 | from torch import Tensor 7 | from torch import nn 8 | import torch 9 | import torch.nn.functional as F 10 | from chexzero.model import VisualTransformer 11 | from chexzero.components.mlp import MLP 12 | from chexzero.img_encoder import ImageEncoderOutput 13 | 14 | from util.model_utils import BaseModel, BaseModelConfig, MainModelConfig 15 | import chexzero.clip 16 | 17 | log = logging.getLogger(__name__) 18 | 19 | 20 | @dataclass 21 | class ChexzeroImageEncoderConfig(BaseModelConfig): 22 | model_path: str = os.path.expanduser( 23 | "~/models/third_party/chexzero/CheXzero_Models/best_64_5e-05_original_22000_0.864.pt" 24 | ) 25 | frozen_backbone: bool = False 26 | # freeze over full training, i.e. 
never unfreeze 27 | freeze_backbone_layers: Optional[int] = None 28 | 29 | add_cls_features: bool = False 30 | use_pretrained_projection: bool = True 31 | # 0 = no additionl projection, 1 = linear, 2 = one hidden layer 32 | additional_projection_layers: int = 0 33 | projection_bn: bool = False 34 | normalize_projected: bool = False 35 | 36 | 37 | class ChexzeroImageEncoder(BaseModel): 38 | CONFIG_CLS = ChexzeroImageEncoderConfig 39 | MODIFYABLE_CONFIGS = ("frozen_backbone",) 40 | 41 | def __init__(self, config: ChexzeroImageEncoderConfig, main_config: MainModelConfig): 42 | super().__init__(config) 43 | self.config: ChexzeroImageEncoderConfig 44 | 45 | model, _ = chexzero.clip.load("ViT-B/32", device="cpu", jit=False) 46 | model.load_state_dict(torch.load(self.config.model_path, map_location="cpu")) 47 | self.d = main_config.d_model 48 | 49 | self.backbone: VisualTransformer = model.visual 50 | d_backbone = self.backbone.output_dim if self.config.use_pretrained_projection else self.backbone.proj.shape[0] 51 | self.n_layers = len(self.backbone.transformer.resblocks) + 2 52 | if config.frozen_backbone: 53 | for param in self.backbone.parameters(): 54 | param.requires_grad = False 55 | log.info("Freezing backbone for the whole training") 56 | self.n_frozen_layers = self.n_layers 57 | self.n_currently_frozen_layers = self.n_layers 58 | elif config.freeze_backbone_layers is not None and config.freeze_backbone_layers != 0: 59 | n_transformer_layers = len(self.backbone.transformer.resblocks) 60 | if config.freeze_backbone_layers < 0: 61 | config.freeze_backbone_layers = (n_transformer_layers + 2) + config.freeze_backbone_layers 62 | self.freeze_layers(config.freeze_backbone_layers) 63 | self.n_frozen_layers = config.freeze_backbone_layers 64 | self.n_currently_frozen_layers = config.freeze_backbone_layers 65 | log.info(f"Freezing {self.n_frozen_layers}/{self.n_layers} layers of backbone for the whole training") 66 | 67 | self.patch_projection = MLP( 68 | self.config.additional_projection_layers, 69 | d_in=d_backbone, 70 | d_out=self.d, 71 | use_bn=self.config.projection_bn, 72 | act=main_config.act, 73 | dropout=main_config.dropout, 74 | ) 75 | 76 | def freeze_layers(self, n_frozen_layers: int): 77 | # first layer (embedding layer) 78 | emb_layer_requires_grad = n_frozen_layers == 0 79 | for param in self.backbone.conv1.parameters(): 80 | param.requires_grad = emb_layer_requires_grad 81 | self.backbone.class_embedding.requires_grad = emb_layer_requires_grad 82 | self.backbone.positional_embedding.requires_grad = emb_layer_requires_grad 83 | self.backbone.ln_pre.requires_grad = emb_layer_requires_grad 84 | 85 | # transformer layers 86 | n_transformer_layers = len(self.backbone.transformer.resblocks) 87 | n_frozen_transformer_layers = n_frozen_layers - 1 88 | for i, resblock in enumerate(self.backbone.transformer.resblocks): 89 | layer_requires_grad = i >= n_frozen_transformer_layers 90 | for param in resblock.parameters(): 91 | param.requires_grad = layer_requires_grad 92 | 93 | # last layer (projection layer) 94 | proj_layer_requires_grad = n_frozen_layers > n_transformer_layers + 1 95 | for param in self.backbone.ln_post.parameters(): 96 | param.requires_grad = proj_layer_requires_grad 97 | self.backbone.proj.requires_grad = proj_layer_requires_grad 98 | 99 | def forward(self, x: Tensor, **kwargs) -> ImageEncoderOutput: 100 | 101 | if x.ndim == 3: 102 | x = einops.repeat(x, "n h w -> n c h w", c=3) 103 | device = x.device 104 | dtype = x.dtype 105 | N, _, H, W = x.shape 106 | assert H == W, 
"Only square images are supported" 107 | input_resolution = H 108 | assert ( 109 | self.backbone.input_resolution == input_resolution 110 | ), f"Input resolution of backbone ({self.backbone.input_resolution}) does not match input resolution of image ({input_resolution})" 111 | 112 | # Encode image using backbone 113 | with torch.set_grad_enabled(not self.config.frozen_backbone): 114 | x = self.backbone.conv1(x) # shape = [*, width, grid, grid] 115 | N, _, H, W = x.shape 116 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] 117 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] 118 | x = torch.cat( 119 | [ 120 | self.backbone.class_embedding.to(x.dtype) 121 | + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), 122 | x, 123 | ], 124 | dim=1, 125 | ) # shape = [*, grid ** 2 + 1, width] 126 | pos_emb = self.backbone.positional_embedding.to(x.dtype) # shape = [grid ** 2 + 1, width] 127 | x = x + pos_emb 128 | x = self.backbone.ln_pre(x) 129 | 130 | x = x.permute(1, 0, 2) # NLD -> LND 131 | x = self.backbone.transformer(x) 132 | x = x.permute(1, 0, 2) # LND -> NLD 133 | 134 | # (N x d) 135 | cls_features = self.backbone.ln_post(x[:, 0, :]) 136 | if self.config.add_cls_features: 137 | patch_features = self.backbone.ln_post(x[:, 1:, :] + x[:, 0, :].unsqueeze(1)) 138 | else: 139 | # (N x HW x d) 140 | patch_features = self.backbone.ln_post(x[:, 1:, :]) 141 | 142 | if self.config.use_pretrained_projection: 143 | cls_features = cls_features @ self.backbone.proj 144 | patch_features = patch_features @ self.backbone.proj 145 | pos_embeddings = pos_emb[1:] @ self.backbone.proj 146 | else: 147 | pos_embeddings = pos_emb[1:] 148 | 149 | # (N, H, W, d_backbone) 150 | patch_features = einops.rearrange(patch_features, "n (h w) d -> n h w d", h=H, w=W) 151 | # (N, H, W, d_backbone) -> (N, H, W, d) 152 | pos_embeddings = einops.repeat(pos_embeddings, "(h w) d -> n h w d", h=H, w=W, n=N) 153 | 154 | # (N x H x W x d) 155 | projected_patch_features = self.patch_projection(patch_features) 156 | cls_features = self.patch_projection(cls_features.unsqueeze(1)).squeeze(1) 157 | if self.config.additional_projection_layers > 0: 158 | pos_embeddings = self.patch_projection(pos_embeddings) 159 | 160 | if self.config.normalize_projected: 161 | projected_patch_features = F.normalize(projected_patch_features, dim=-1) 162 | cls_features = F.normalize(cls_features, dim=-1) 163 | 164 | return ImageEncoderOutput( 165 | patch_features=projected_patch_features, pos_embeddings=pos_embeddings, global_features=cls_features 166 | ) 167 | -------------------------------------------------------------------------------- /chexzero/util/plot_grounding.py: -------------------------------------------------------------------------------- 1 | from matplotlib import gridspec 2 | import matplotlib 3 | import matplotlib.pyplot as plt 4 | import matplotlib.colors as mcolors 5 | import numpy as np 6 | import torch.nn.functional as F 7 | import torch 8 | import textwrap as twp 9 | 10 | import wandb 11 | 12 | from util.plot_utils import plot_img_with_bounding_boxes 13 | 14 | 15 | def plot_grounding(model_output: 'ChEX', max_samples: int = 10): 16 | N = len(model_output.sample_id) 17 | if max_samples is None or max_samples > N: 18 | max_samples = N 19 | 20 | figs = [] 21 | for i in range(max_samples): 22 | output_i = model_output[i] 23 | fig = plot_grounding_sample(output_i.x, output_i.sentences, 24 | output_i.encoded_img.patch_features, 25 | output_i.encoded_sentences.sentence_features, 
output_i.encoded_sentences.sentence_mask, 26 | output_i.grounding.boxes, output_i.grounding.multiboxes, 27 | output_i.grounding.box_features) 28 | figs.append(wandb.Image(fig)) 29 | return figs 30 | 31 | 32 | def plot_grounding_sample(img, sentences, 33 | patch_features, 34 | sentence_features, sentence_mask, 35 | boxes, multiboxes, box_features): 36 | n_sentences = len(sentences) 37 | assert sentence_mask.sum() == n_sentences, f'Expected {n_sentences} sentences, got {sentence_mask.sum()}' 38 | img_shape = img.shape 39 | img = img.numpy() 40 | 41 | sentence_features = sentence_features[sentence_mask] 42 | box_features = box_features[sentence_mask] 43 | if multiboxes is not None: 44 | boxes = multiboxes[sentence_mask].numpy() 45 | S, R, _ = boxes.shape 46 | else: 47 | boxes = boxes[sentence_mask].numpy() 48 | S, _ = boxes.shape 49 | 50 | # Similarities and neighbors 51 | # (S x S_r) 52 | sent_region_l2_pairwise = torch.cdist(sentence_features, box_features, p=2) 53 | # (S) 54 | sentence_region_l2 = sent_region_l2_pairwise.diagonal() 55 | # (S x S_r) 56 | sent_region_cos_pairwise = torch.nn.functional.cosine_similarity(sentence_features.unsqueeze(1), box_features.unsqueeze(0), dim=-1) 57 | # (S) 58 | sentence_region_cos = sent_region_cos_pairwise.diagonal() 59 | # (S) 60 | region_rank = (sent_region_cos_pairwise > sentence_region_cos.unsqueeze(-1)).sum(dim=-1) + 1 61 | sentence_rank = (sent_region_cos_pairwise > sentence_region_cos.unsqueeze(0)).sum(dim=0) + 1 62 | # (S) 63 | sentence_region_neighbor = sent_region_cos_pairwise.argmax(dim=-1) + 1 64 | region_sentence_neighbor = sent_region_cos_pairwise.argmax(dim=0) + 1 65 | 66 | *shape_patch, d = patch_features.shape 67 | # (S x H*W) 68 | sentence_patch_cos_pairwise = torch.nn.functional.cosine_similarity(sentence_features.unsqueeze(1), patch_features.view(-1, d).unsqueeze(0), dim=-1) 69 | # (H*W) 70 | patch_sentence_neighbor = sentence_patch_cos_pairwise.argmax(dim=0) + 1 71 | patch_sentence_neighbor = patch_sentence_neighbor.view(*shape_patch) 72 | patch_sentence_neighbor_upsampled = F.interpolate(patch_sentence_neighbor[None, None, ...].float(), size=img_shape, mode='nearest-exact').squeeze() 73 | upsample_factor = np.array(img_shape) / np.array(shape_patch) 74 | 75 | # To numpy 76 | sentence_region_l2 = sentence_region_l2.numpy() 77 | sentence_region_cos = sentence_region_cos.numpy() 78 | sentence_rank = sentence_rank.numpy() 79 | region_rank = region_rank.numpy() 80 | region_sentence_neighbor = region_sentence_neighbor.numpy() 81 | sentence_region_neighbor = sentence_region_neighbor.numpy() 82 | patch_sentence_neighbor = patch_sentence_neighbor.numpy() 83 | patch_sentence_neighbor_upsampled = patch_sentence_neighbor_upsampled.numpy() 84 | 85 | # IDs and colors, text wrap 86 | sentence_ids = list(range(1, n_sentences + 1)) 87 | sentence_colors = list(mcolors.TABLEAU_COLORS.values())[:n_sentences] if n_sentences <= 10 else matplotlib.color_sequences['tab20'][:n_sentences] 88 | sentences = [twp.fill(s, 70) for s in sentences] 89 | cmap = mcolors.LinearSegmentedColormap.from_list("cmap_name", sentence_colors, N=n_sentences) 90 | 91 | # Plotting... 
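    # Figure layout (2x3 grid): box overlays (top-left), per-patch nearest-sentence
    # map drawn over the image (top-middle), per-sentence cosine-similarity bars
    # (top-right), and a full-width table of per-sentence grounding statistics below.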
92 | fig = plt.figure(figsize=(12, 6)) 93 | gs = gridspec.GridSpec(nrows=2, ncols=3, height_ratios=[1, 1]) 94 | 95 | ax_boxes = fig.add_subplot(gs[0, 0]) 96 | ax_boxes.set_xticks([]) 97 | ax_boxes.set_yticks([]) 98 | sentence_ids_array = np.array(sentence_ids) # (S) 99 | if multiboxes is not None: 100 | # (S x R) 101 | sentence_ids_array = sentence_ids_array[:, None].repeat(R, axis=1) 102 | boxes_with_ids = np.concatenate([boxes, sentence_ids_array[:, None] - 1], axis=1) if multiboxes is None \ 103 | else np.concatenate([boxes, sentence_ids_array[:, :, None] - 1], axis=2) 104 | 105 | if multiboxes is not None: 106 | boxes_with_ids = boxes_with_ids.reshape(-1, 5) 107 | plot_img_with_bounding_boxes(ax_boxes, class_names=sentence_ids, 108 | img=img, target_list=boxes_with_ids, plot_pred=False, 109 | class_cmap=sentence_colors) 110 | 111 | boxes_upper_left = (boxes_with_ids[:, :2] - boxes_with_ids[:, 2:4] / 2) * img_shape[::-1] 112 | boxes_lower_right = (boxes_with_ids[:, :2] + boxes_with_ids[:, 2:4] / 2) * img_shape[::-1] 113 | ax_boxes.set_xlim(min(0, boxes_upper_left[:, 0].min()), max(img_shape[1], boxes_lower_right[:, 0].max())) 114 | ax_boxes.set_ylim(max(img_shape[0], boxes_lower_right[:, 1].max()), min(0, boxes_upper_left[:, 1].min())) 115 | 116 | ax_patch_neighbors = fig.add_subplot(gs[0, 1]) 117 | ax_patch_neighbors.imshow(img, cmap='gray') 118 | ax_patch_neighbors.imshow(patch_sentence_neighbor_upsampled, cmap=cmap, vmin=1, vmax=n_sentences, alpha=0.6) 119 | ax_patch_neighbors.set_xticks([]) 120 | ax_patch_neighbors.set_yticks([]) 121 | for y, neighbors in enumerate(patch_sentence_neighbor): 122 | y = (y + 0.5) * upsample_factor[0] - 1 123 | for x, neighbor in enumerate(neighbors): 124 | x = (x + 0.5) * upsample_factor[1] - 1 125 | ax_patch_neighbors.text(x, y, neighbor, ha='center', va='center', color='w') 126 | 127 | ax_sent_barplots = fig.add_subplot(gs[0, -1]) 128 | sentence_ids = np.array(sentence_ids) 129 | bar_pos = np.linspace(-0.3, 0.3, n_sentences) 130 | 131 | ax_sent_barplots.bar(bar_pos, sentence_region_cos, color=sentence_colors, width=0.6 / n_sentences) 132 | ax_sent_barplots.set_xticks([0]) 133 | ax_sent_barplots.set_xticklabels(['cos']) 134 | 135 | tab_data = [ 136 | *zip(sentence_ids, sentences, sentence_rank, region_rank, sentence_region_l2.round(2), sentence_region_cos.round(2), region_sentence_neighbor, sentence_region_neighbor) 137 | ] 138 | tab_colors = [ 139 | [col, col, col, col, col, col, sentence_colors[neighbor_s - 1], sentence_colors[neighbor_r - 1]] 140 | for col, neighbor_s, neighbor_r in zip(sentence_colors, region_sentence_neighbor, sentence_region_neighbor) 141 | ] 142 | 143 | ax_table = fig.add_subplot(gs[1, :]) 144 | ax_table.axis('off') 145 | ax_table.axis('tight') 146 | table = ax_table.table( 147 | cellText=tab_data, 148 | cellLoc='center', 149 | colLoc='center', 150 | loc='center', 151 | cellColours=tab_colors, 152 | colLabels=['ID', 'Sentence', 'Rank s', 'Rank r', 'L2', 'cos', '1-NN s', '1-NN r'], 153 | colWidths=[0.04, 0.7, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05] 154 | ) 155 | table.auto_set_font_size(False) 156 | table.set_fontsize(10) 157 | table.scale(1, 1.5) 158 | n_cols = 8 159 | 160 | # Apply custom cell renderer to each cell 161 | for i in range(len(sentences) + 1): 162 | cell = table.get_celld()[(i, 1)] 163 | cell.set_text_props(horizontalalignment='left') 164 | cell.PAD = 0.01 165 | lines = len(cell.get_text().get_text().splitlines()) 166 | for j in range(n_cols): 167 | cell = table.get_celld()[(i, j)] 168 | cell.set_height(lines * 0.11) 
169 | 170 | return fig 171 | 172 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from torchvision import transforms 5 | 6 | from chexzero.chexzero.clip import tokenize 7 | from models.clip import CLIPEvaluator 8 | 9 | 10 | class CLIPLoss(nn.Module): 11 | def __init__(self, args, pos_prompt=None, neg_prompt=None, clip_model=None, clip_preprocess=None): 12 | super().__init__() 13 | self.clip_model = CLIPEvaluator(args, model=clip_model, preprocess=clip_preprocess) 14 | self.pos_prompt = pos_prompt 15 | self.neg_prompt = neg_prompt 16 | 17 | # print(f"CLIPLoss: Positive prompt {pos_prompt}, Negative prompt {neg_prompt}") 18 | 19 | def forward_images(self, original, edited, preprocess_images): 20 | if preprocess_images: 21 | original = self.clip_model.clip_transform_for_tensor(original) 22 | edited = self.clip_model.clip_transform_for_tensor(edited) 23 | 24 | original_features = self.clip_model.model.encode_image(original) 25 | edited_features = self.clip_model.model.encode_image(edited) 26 | 27 | return original_features, edited_features 28 | 29 | def forward_prompts(self, pos_prompt=None, neg_prompt=None): 30 | if pos_prompt is None: 31 | pos_prompt = self.pos_prompt 32 | if neg_prompt is None: 33 | neg_prompt = self.neg_prompt 34 | 35 | pos_features = self.clip_model.encode_text(pos_prompt) 36 | neg_features = self.clip_model.encode_text(neg_prompt) 37 | 38 | return pos_features, neg_features 39 | 40 | def forward(self, generated_images, pos_prompt=None, neg_prompt=None, return_logits_per_image=False): 41 | if pos_prompt is None: 42 | pos_prompt = self.pos_prompt 43 | if neg_prompt is None: 44 | neg_prompt = self.neg_prompt 45 | 46 | # Get logits 47 | logits_per_image, _ = self.clip_model.score(generated_images, [pos_prompt, neg_prompt]) 48 | 49 | if return_logits_per_image: 50 | return logits_per_image 51 | 52 | # Apply softmax to get probabilities 53 | probs = F.softmax(logits_per_image, dim=1) 54 | 55 | # Get probabilities for positive and negative prompts 56 | pos_prob = probs[:, 0] # Positive prompt is first 57 | neg_prob = probs[:, 1] # Negative prompt is second 58 | 59 | # print("Sample probabilities:") 60 | # print(probs) 61 | 62 | eps = 1e-6 63 | # neg_prob = torch.clamp(neg_prob, min=eps) 64 | # pos_prob = torch.clamp(pos_prob, min=eps) 65 | 66 | # loss = -torch.log(neg_prob + eps) + torch.log(pos_prob + eps) 67 | probs = torch.clamp(probs, min=eps) 68 | target = torch.zeros_like(probs) 69 | target[:, 1] = 1.0 70 | loss = F.binary_cross_entropy_with_logits(probs, target) 71 | 72 | # print("Sample loss:", loss, loss.mean(), loss.shape, loss.requires_grad, torch.any(torch.isnan(loss))) 73 | 74 | return loss, probs 75 | 76 | 77 | class DirectionLoss(nn.Module): 78 | 79 | def __init__(self, loss_type="mse"): 80 | super(DirectionLoss, self).__init__() 81 | 82 | self.loss_type = loss_type 83 | 84 | self.loss_func = {"mse": torch.nn.MSELoss, "cosine": torch.nn.CosineSimilarity, "mae": torch.nn.L1Loss}[ 85 | loss_type 86 | ]() 87 | 88 | def forward(self, x, y): 89 | if self.loss_type == "cosine": 90 | return 1.0 - self.loss_func(x, y) 91 | 92 | return self.loss_func(x, y) 93 | 94 | 95 | class CLIPDirectionLoss(nn.Module): 96 | 97 | def __init__(self, args, pos_prompt=None, neg_prompt=None, clip_model=None, clip_preprocess=None): 98 | super().__init__() 99 | 100 | self.clip_model 
= CLIPEvaluator(args, model=clip_model, preprocess=clip_preprocess) 101 | self.pos_prompt = pos_prompt 102 | self.neg_prompt = neg_prompt 103 | 104 | self.device = args.device 105 | 106 | self.loss = DirectionLoss(loss_type="cosine") 107 | 108 | self.target_direction = self.compute_text_direction() 109 | 110 | print(f"CLIPLoss: Positive prompt {pos_prompt}, Negative prompt {neg_prompt}") 111 | 112 | def encode_text(self, texts): 113 | tokenized_text = tokenize(texts).to(self.device) 114 | text_features = self.clip_model.model.encode_text(tokenized_text) 115 | 116 | return text_features 117 | 118 | def encode_image(self, images): 119 | images = self.clip_model.clip_transform_for_tensor(images) 120 | 121 | image_features = self.clip_model.model.encode_image(images) 122 | 123 | return image_features 124 | 125 | def get_text_features(self, texts): 126 | text_features = self.encode_text(texts).detach() 127 | 128 | text_features = text_features / text_features.norm(dim=-1, keepdim=True) 129 | 130 | return text_features 131 | 132 | def get_image_features(self, images): 133 | image_features = self.encode_image(images) 134 | 135 | image_features = image_features / image_features.norm(dim=-1, keepdim=True) 136 | 137 | return image_features 138 | 139 | def compute_text_direction(self, source_class=None, target_class=None) -> torch.Tensor: 140 | print("Computing text direction") 141 | if source_class is None: 142 | source_class = self.pos_prompt 143 | 144 | if target_class is None: 145 | target_class = self.neg_prompt 146 | 147 | text_features = self.get_text_features([source_class, target_class]) 148 | source_features, target_features = text_features[0], text_features[1] 149 | 150 | text_direction = target_features - source_features # .mean(axis=0, keepdim=True) 151 | text_direction /= text_direction.norm(dim=-1, keepdim=True) 152 | print("Text direction", text_direction.shape) 153 | 154 | return text_direction 155 | 156 | def forward(self, src_images, target_images, pos_prompt=None, neg_prompt=None): 157 | if pos_prompt is None: 158 | pos_prompt = self.pos_prompt 159 | if neg_prompt is None: 160 | neg_prompt = self.neg_prompt 161 | 162 | if self.target_direction is None: 163 | self.target_direction = self.compute_text_direction() 164 | 165 | src_encoding = self.get_image_features(src_images) 166 | target_encoding = self.get_image_features(target_images) 167 | 168 | edit_direction = target_encoding - src_encoding 169 | edit_direction /= edit_direction.clone().norm(dim=-1, keepdim=True) + 1e-7 170 | 171 | # calculate the distance between src_encoding and target_encoding and it should be maximized 172 | 173 | return self.loss(edit_direction, self.target_direction).mean(), F.mse_loss(src_encoding, target_encoding) 174 | 175 | 176 | def custom_cross_entropy(pred_logits, target_logits): 177 | pred_probs = F.softmax(pred_logits, dim=-1) 178 | 179 | # Create a mask for non-inf values in target_logits 180 | mask = ~torch.isinf(target_logits) 181 | 182 | # Normalize the non-inf part of target_logits 183 | target_probs = torch.where(mask, F.softmax(target_logits, dim=-1), torch.zeros_like(target_logits)) 184 | 185 | # Compute cross-entropy loss only for non-inf positions 186 | loss = -torch.sum(target_probs * torch.log(pred_probs + 1e-8), dim=-1) 187 | 188 | # Add a large penalty for predicting non-zero probability where target is -inf 189 | inf_penalty = torch.where(~mask, pred_probs, torch.zeros_like(pred_probs)).sum(dim=-1) * 1e2 190 | 191 | return (loss + inf_penalty).mean() 192 | 193 | 194 | def 
kl_loss(pred_logits, target_logits): 195 | pred_log_softmax = F.log_softmax(pred_logits, dim=-1) 196 | target_softmax = F.softmax(target_logits, dim=-1) 197 | 198 | kl_loss = F.kl_div(pred_log_softmax, target_softmax, reduction="batchmean") 199 | return kl_loss 200 | 201 | 202 | def masked_mse_loss(pred, target, mask): 203 | """Take in the mask for the region to be inpainted""" 204 | if pred.shape[-1] != mask.shape[-1]: 205 | mask = F.interpolate(mask, size=(pred.shape[-1], pred.shape[-1]), mode="bilinear", align_corners=False) 206 | 207 | return torch.mean(((pred - target) * mask) ** 2) 208 | -------------------------------------------------------------------------------- /cheff/ldm/modules/encoders/modules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from functools import partial 4 | import clip 5 | from einops import repeat 6 | import kornia 7 | 8 | 9 | from cheff.ldm.modules.x_transformer import Encoder, TransformerWrapper # TODO: can we directly rely on lucidrains code and simply add this as a reuirement? --> test 10 | 11 | 12 | class AbstractEncoder(nn.Module): 13 | def __init__(self): 14 | super().__init__() 15 | 16 | def encode(self, *args, **kwargs): 17 | raise NotImplementedError 18 | 19 | 20 | 21 | class ClassEmbedder(nn.Module): 22 | def __init__(self, embed_dim, n_classes=1000, key='class'): 23 | super().__init__() 24 | self.key = key 25 | self.embedding = nn.Embedding(n_classes, embed_dim) 26 | 27 | def forward(self, batch, key=None): 28 | if key is None: 29 | key = self.key 30 | # this is for use in crossattn 31 | c = batch[key][:, None] 32 | c = self.embedding(c) 33 | return c 34 | 35 | 36 | class MultiClassEmbedder(nn.Module): 37 | def __init__(self, embed_dim, n_classes=1000, key='class'): 38 | super().__init__() 39 | self.key = key 40 | self.embedding = nn.Sequential( 41 | nn.Linear(in_features=n_classes, out_features=embed_dim), 42 | nn.GELU(), 43 | nn.Linear(in_features=embed_dim, out_features=embed_dim), 44 | nn.GELU(), 45 | nn.Linear(in_features=embed_dim, out_features=embed_dim), 46 | ) 47 | 48 | def forward(self, batch, key=None): 49 | if key is None: 50 | key = self.key 51 | c = batch[key] 52 | c = self.embedding(c) 53 | c = c.unsqueeze(1) 54 | return c 55 | 56 | 57 | class TransformerEmbedder(AbstractEncoder): 58 | """Some transformer encoder layers""" 59 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=150, device="cuda"): 60 | super().__init__() 61 | self.device = device 62 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 63 | attn_layers=Encoder(dim=n_embed, depth=n_layer)) 64 | 65 | def forward(self, tokens): 66 | tokens = tokens.to(self.device) # meh 67 | z = self.transformer(tokens, return_embeddings=True) 68 | return z 69 | 70 | def encode(self, x): 71 | return self(x) 72 | 73 | 74 | class BERTTokenizer(AbstractEncoder): 75 | """ Uses a pretrained BERT tokenizer by huggingface. 
Vocab size: 30522 (?)""" 76 | def __init__(self, device="cuda", vq_interface=True, max_length=150): 77 | super().__init__() 78 | from transformers import BertTokenizerFast # TODO: add to reuquirements 79 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased") 80 | self.device = device 81 | self.vq_interface = vq_interface 82 | self.max_length = max_length 83 | 84 | def forward(self, text): 85 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, 86 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt") 87 | tokens = batch_encoding["input_ids"].to(self.device) 88 | return tokens 89 | 90 | @torch.no_grad() 91 | def encode(self, text): 92 | tokens = self(text) 93 | if not self.vq_interface: 94 | return tokens 95 | return None, None, [None, None, tokens] 96 | 97 | def decode(self, text): 98 | return text 99 | 100 | 101 | class BERTEmbedder(AbstractEncoder): 102 | """Uses the BERT tokenizr model and add some transformer encoder layers""" 103 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=150, 104 | device="cuda",use_tokenizer=True, embedding_dropout=0.0): 105 | super().__init__() 106 | self.use_tknz_fn = use_tokenizer 107 | if self.use_tknz_fn: 108 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len, device=device) 109 | self.device = device 110 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len, 111 | attn_layers=Encoder(dim=n_embed, depth=n_layer), 112 | emb_dropout=embedding_dropout) 113 | 114 | def forward(self, text): 115 | if self.use_tknz_fn: 116 | tokens = self.tknz_fn(text)#.to(self.device) 117 | else: 118 | tokens = text 119 | z = self.transformer(tokens, return_embeddings=True) 120 | return z 121 | 122 | def encode(self, text): 123 | # output of length 77 124 | return self(text) 125 | 126 | 127 | class SpatialRescaler(nn.Module): 128 | def __init__(self, 129 | n_stages=1, 130 | method='bilinear', 131 | multiplier=0.5, 132 | in_channels=3, 133 | out_channels=None, 134 | bias=False): 135 | super().__init__() 136 | self.n_stages = n_stages 137 | assert self.n_stages >= 0 138 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area'] 139 | self.multiplier = multiplier 140 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method) 141 | self.remap_output = out_channels is not None 142 | if self.remap_output: 143 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.') 144 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias) 145 | 146 | def forward(self,x): 147 | for stage in range(self.n_stages): 148 | x = self.interpolator(x, scale_factor=self.multiplier) 149 | 150 | 151 | if self.remap_output: 152 | x = self.channel_mapper(x) 153 | return x 154 | 155 | def encode(self, x): 156 | return self(x) 157 | 158 | 159 | class FrozenCLIPTextEmbedder(nn.Module): 160 | """ 161 | Uses the CLIP transformer encoder for text. 
162 | """ 163 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True): 164 | super().__init__() 165 | self.model, _ = clip.load(version, jit=False, device="cpu") 166 | self.device = device 167 | self.max_length = max_length 168 | self.n_repeat = n_repeat 169 | self.normalize = normalize 170 | 171 | def freeze(self): 172 | self.model = self.model.eval() 173 | for param in self.parameters(): 174 | param.requires_grad = False 175 | 176 | def forward(self, text): 177 | tokens = clip.tokenize(text).to(self.device) 178 | z = self.model.encode_text(tokens) 179 | if self.normalize: 180 | z = z / torch.linalg.norm(z, dim=1, keepdim=True) 181 | return z 182 | 183 | def encode(self, text): 184 | z = self(text) 185 | if z.ndim==2: 186 | z = z[:, None, :] 187 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat) 188 | return z 189 | 190 | 191 | class FrozenClipImageEmbedder(nn.Module): 192 | """ 193 | Uses the CLIP image encoder. 194 | """ 195 | def __init__( 196 | self, 197 | model, 198 | jit=False, 199 | device='cuda' if torch.cuda.is_available() else 'cpu', 200 | antialias=False, 201 | ): 202 | super().__init__() 203 | self.model, _ = clip.load(name=model, device=device, jit=jit) 204 | 205 | self.antialias = antialias 206 | 207 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False) 208 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False) 209 | 210 | def preprocess(self, x): 211 | # normalize to [0,1] 212 | x = kornia.geometry.resize(x, (224, 224), 213 | interpolation='bicubic',align_corners=True, 214 | antialias=self.antialias) 215 | x = (x + 1.) / 2. 216 | # renormalize according to clip 217 | x = kornia.enhance.normalize(x, self.mean, self.std) 218 | return x 219 | 220 | def forward(self, x): 221 | # x is assumed to be in range [-1,1] 222 | return self.model.encode_image(self.preprocess(x)) 223 | 224 | -------------------------------------------------------------------------------- /cheff/ldm/modules/losses/vqperceptual.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | from einops import repeat 5 | 6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init 7 | from taming.modules.losses.lpips import LPIPS 8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss 9 | 10 | 11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights): 12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0] 13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3]) 14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3]) 15 | loss_real = (weights * loss_real).sum() / weights.sum() 16 | loss_fake = (weights * loss_fake).sum() / weights.sum() 17 | d_loss = 0.5 * (loss_real + loss_fake) 18 | return d_loss 19 | 20 | def adopt_weight(weight, global_step, threshold=0, value=0.): 21 | if global_step < threshold: 22 | weight = value 23 | return weight 24 | 25 | 26 | def measure_perplexity(predicted_indices, n_embed): 27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py 28 | # eval cluster perplexity. 
when perplexity == num_embeddings then all clusters are used exactly equally 29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed) 30 | avg_probs = encodings.mean(0) 31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp() 32 | cluster_use = torch.sum(avg_probs > 0) 33 | return perplexity, cluster_use 34 | 35 | def l1(x, y): 36 | return torch.abs(x-y) 37 | 38 | 39 | def l2(x, y): 40 | return torch.pow((x-y), 2) 41 | 42 | 43 | class VQLPIPSWithDiscriminator(nn.Module): 44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0, 45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0, 46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False, 47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips", 48 | pixel_loss="l1"): 49 | super().__init__() 50 | assert disc_loss in ["hinge", "vanilla"] 51 | assert perceptual_loss in ["lpips", "clips", "dists"] 52 | assert pixel_loss in ["l1", "l2"] 53 | self.codebook_weight = codebook_weight 54 | self.pixel_weight = pixelloss_weight 55 | if perceptual_loss == "lpips": 56 | print(f"{self.__class__.__name__}: Running with LPIPS.") 57 | self.perceptual_loss = LPIPS().eval() 58 | else: 59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<") 60 | self.perceptual_weight = perceptual_weight 61 | 62 | if pixel_loss == "l1": 63 | self.pixel_loss = l1 64 | else: 65 | self.pixel_loss = l2 66 | 67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels, 68 | n_layers=disc_num_layers, 69 | use_actnorm=use_actnorm, 70 | ndf=disc_ndf 71 | ).apply(weights_init) 72 | self.discriminator_iter_start = disc_start 73 | if disc_loss == "hinge": 74 | self.disc_loss = hinge_d_loss 75 | elif disc_loss == "vanilla": 76 | self.disc_loss = vanilla_d_loss 77 | else: 78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.") 79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.") 80 | self.disc_factor = disc_factor 81 | self.discriminator_weight = disc_weight 82 | self.disc_conditional = disc_conditional 83 | self.n_classes = n_classes 84 | 85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None): 86 | if last_layer is not None: 87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0] 88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0] 89 | else: 90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0] 91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0] 92 | 93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4) 94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach() 95 | d_weight = d_weight * self.discriminator_weight 96 | return d_weight 97 | 98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx, 99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None): 100 | if not exists(codebook_loss): 101 | codebook_loss = torch.tensor([0.]).to(inputs.device) 102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous()) 103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous()) 104 | if self.perceptual_weight > 0: 105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous()) 106 | rec_loss = rec_loss + self.perceptual_weight * p_loss 107 | else: 108 | p_loss = torch.tensor([0.0]) 109 | 110 | nll_loss = rec_loss 111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0] 
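        # Reduce over all elements (batch, channel, spatial). The commented-out variant
        # above divides by batch size only; the two differ by a constant factor, so the
        # optimum is unchanged but the effective loss scale is not.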
112 | nll_loss = torch.mean(nll_loss) 113 | 114 | # now the GAN part 115 | if optimizer_idx == 0: 116 | # generator update 117 | if cond is None: 118 | assert not self.disc_conditional 119 | logits_fake = self.discriminator(reconstructions.contiguous()) 120 | else: 121 | assert self.disc_conditional 122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1)) 123 | g_loss = -torch.mean(logits_fake) 124 | 125 | try: 126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer) 127 | except RuntimeError: 128 | assert not self.training 129 | d_weight = torch.tensor(0.0) 130 | 131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean() 133 | 134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), 135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(), 136 | "{}/nll_loss".format(split): nll_loss.detach().mean(), 137 | "{}/rec_loss".format(split): rec_loss.detach().mean(), 138 | "{}/p_loss".format(split): p_loss.detach().mean(), 139 | "{}/d_weight".format(split): d_weight.detach(), 140 | "{}/disc_factor".format(split): torch.tensor(disc_factor), 141 | "{}/g_loss".format(split): g_loss.detach().mean(), 142 | } 143 | if predicted_indices is not None: 144 | assert self.n_classes is not None 145 | with torch.no_grad(): 146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes) 147 | log[f"{split}/perplexity"] = perplexity 148 | log[f"{split}/cluster_usage"] = cluster_usage 149 | return loss, log 150 | 151 | if optimizer_idx == 1: 152 | # second pass for discriminator update 153 | if cond is None: 154 | logits_real = self.discriminator(inputs.contiguous().detach()) 155 | logits_fake = self.discriminator(reconstructions.contiguous().detach()) 156 | else: 157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1)) 158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1)) 159 | 160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start) 161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake) 162 | 163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(), 164 | "{}/logits_real".format(split): logits_real.detach().mean(), 165 | "{}/logits_fake".format(split): logits_fake.detach().mean() 166 | } 167 | return d_loss, log 168 | -------------------------------------------------------------------------------- /chexzero/chexzero/eval.py: -------------------------------------------------------------------------------- 1 | import subprocess 2 | import numpy as np 3 | import os 4 | import pandas as pd 5 | from PIL import Image 6 | import matplotlib.pyplot as plt 7 | from typing import List, Callable 8 | 9 | import torch 10 | from torch.utils import data 11 | from tqdm.notebook import tqdm 12 | import torch.nn as nn 13 | from torchvision.transforms import Compose, Normalize, Resize 14 | 15 | import sklearn 16 | from sklearn.metrics import matthews_corrcoef, confusion_matrix, accuracy_score, auc, roc_auc_score, roc_curve, classification_report 17 | from sklearn.metrics import precision_recall_curve, f1_score 18 | from sklearn.metrics import average_precision_score 19 | from sklearn.utils import resample 20 | 21 | import scipy 22 | import scipy.stats 23 | 24 | import sys 25 | sys.path.append('../..') 
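# Note: the relative path appended above is resolved against the current working
# directory at import time (not against this file's location), so the chexzero
# imports below depend on where the script or notebook is launched from.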
26 | 27 | import chexzero.clip 28 | from chexzero.model import CLIP 29 | 30 | def compute_mean(stats, is_df=True): 31 | spec_labels = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"] 32 | if is_df: 33 | spec_df = stats[spec_labels] 34 | res = np.mean(spec_df.iloc[0]) 35 | else: 36 | # cis is df, within bootstrap 37 | vals = [stats[spec_label][0] for spec_label in spec_labels] 38 | res = np.mean(vals) 39 | return res 40 | 41 | def accuracy(output, target, topk=(1,)): 42 | pred = output.topk(max(topk), 1, True, True)[1].t() 43 | print('pred: ', pred) 44 | 45 | expand = target.expand(-1, max(topk)) 46 | print('expand: ', expand) 47 | 48 | correct = pred.eq(expand) 49 | print('correct: ', correct) 50 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk] 51 | 52 | def sigmoid(x): 53 | z = 1/(1 + np.exp(-x)) 54 | return z 55 | 56 | ''' ROC CURVE ''' 57 | def plot_roc(y_pred, y_true, roc_name, plot=False): 58 | # given the test_ground_truth, and test_predictions 59 | fpr, tpr, thresholds = roc_curve(y_true, y_pred) 60 | 61 | roc_auc = auc(fpr, tpr) 62 | 63 | if plot: 64 | plt.figure(dpi=100) 65 | plt.title(roc_name) 66 | plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc) 67 | plt.legend(loc = 'lower right') 68 | plt.plot([0, 1], [0, 1],'r--') 69 | plt.xlim([0, 1]) 70 | plt.ylim([0, 1]) 71 | plt.ylabel('True Positive Rate') 72 | plt.xlabel('False Positive Rate') 73 | plt.show() 74 | return fpr, tpr, thresholds, roc_auc 75 | 76 | # J = TP/(TP+FN) + TN/(TN+FP) - 1 = tpr - fpr 77 | def choose_operating_point(fpr, tpr, thresholds): 78 | sens = 0 79 | spec = 0 80 | J = 0 81 | for _fpr, _tpr in zip(fpr, tpr): 82 | if _tpr - _fpr > J: 83 | sens = _tpr 84 | spec = 1-_fpr 85 | J = _tpr - _fpr 86 | return sens, spec 87 | 88 | ''' PRECISION-RECALL CURVE ''' 89 | def plot_pr(y_pred, y_true, pr_name, plot=False): 90 | precision, recall, thresholds = precision_recall_curve(y_true, y_pred) 91 | pr_auc = auc(recall, precision) 92 | # plot the precision-recall curves 93 | baseline = len(y_true[y_true==1]) / len(y_true) 94 | 95 | if plot: 96 | plt.figure(dpi=20) 97 | plt.title(pr_name) 98 | plt.plot(recall, precision, 'b', label='AUC = %0.2f' % pr_auc) 99 | # axis labels 100 | plt.legend(loc = 'lower right') 101 | plt.plot([0, 1], [baseline, baseline],'r--') 102 | plt.xlim([0, 1]) 103 | plt.ylim([0, 1]) 104 | plt.xlabel('Recall') 105 | plt.ylabel('Precision') 106 | # show the plot 107 | plt.show() 108 | return precision, recall, thresholds 109 | 110 | def evaluate(y_pred, y_true, cxr_labels, 111 | roc_name='Receiver Operating Characteristic', pr_name='Precision-Recall Curve', label_idx_map=None): 112 | 113 | ''' 114 | We expect `y_pred` and `y_true` to be numpy arrays, both of shape (num_samples, num_classes) 115 | 116 | `y_pred` is a numpy array consisting of probability scores with all values in range 0-1. 117 | 118 | `y_true` is a numpy array consisting of binary values representing if a class is present in 119 | the cxr. 120 | 121 | This function provides all relevant evaluation information, ROC, AUROC, Sensitivity, Specificity, 122 | PR-Curve, Precision, Recall for each class. 
123 | ''' 124 | import warnings 125 | warnings.filterwarnings('ignore') 126 | 127 | num_classes = y_pred.shape[-1] # number of total labels 128 | 129 | dataframes = [] 130 | for i in range(num_classes): 131 | # print('{}.'.format(cxr_labels[i])) 132 | 133 | if label_idx_map is None: 134 | y_pred_i = y_pred[:, i] # (num_samples,) 135 | y_true_i = y_true[:, i] # (num_samples,) 136 | 137 | else: 138 | y_pred_i = y_pred[:, i] # (num_samples,) 139 | 140 | true_index = label_idx_map[cxr_labels[i]] 141 | y_true_i = y_true[:, true_index] # (num_samples,) 142 | 143 | cxr_label = cxr_labels[i] 144 | 145 | ''' ROC CURVE ''' 146 | roc_name = cxr_label + ' ROC Curve' 147 | fpr, tpr, thresholds, roc_auc = plot_roc(y_pred_i, y_true_i, roc_name) 148 | 149 | sens, spec = choose_operating_point(fpr, tpr, thresholds) 150 | 151 | results = [[roc_auc]] 152 | df = pd.DataFrame(results, columns=[cxr_label+'_auc']) 153 | dataframes.append(df) 154 | 155 | ''' PRECISION-RECALL CURVE ''' 156 | pr_name = cxr_label + ' Precision-Recall Curve' 157 | precision, recall, thresholds = plot_pr(y_pred_i, y_true_i, pr_name) 158 | 159 | dfs = pd.concat(dataframes, axis=1) 160 | return dfs 161 | 162 | ''' Bootstrap and Confidence Intervals ''' 163 | def compute_cis(data, confidence_level=0.05): 164 | """ 165 | FUNCTION: compute_cis 166 | ------------------------------------------------------ 167 | Given a Pandas dataframe of (n, labels), return another 168 | Pandas dataframe that is (3, labels). 169 | 170 | Each row is lower bound, mean, upper bound of a confidence 171 | interval with `confidence`. 172 | 173 | Args: 174 | * data - Pandas Dataframe, of shape (num_bootstrap_samples, num_labels) 175 | * confidence_level (optional) - confidence level of interval 176 | 177 | Returns: 178 | * Pandas Dataframe, of shape (3, labels), representing mean, lower, upper 179 | """ 180 | data_columns = list(data) 181 | intervals = [] 182 | for i in data_columns: 183 | series = data[i] 184 | sorted_perfs = series.sort_values() 185 | lower_index = int(confidence_level/2 * len(sorted_perfs)) - 1 186 | upper_index = int((1 - confidence_level/2) * len(sorted_perfs)) - 1 187 | lower = sorted_perfs.iloc[lower_index].round(4) 188 | upper = sorted_perfs.iloc[upper_index].round(4) 189 | mean = round(sorted_perfs.mean(), 4) 190 | interval = pd.DataFrame({i : [mean, lower, upper]}) 191 | intervals.append(interval) 192 | intervals_df = pd.concat(intervals, axis=1) 193 | intervals_df.index = ['mean', 'lower', 'upper'] 194 | return intervals_df 195 | 196 | def bootstrap(y_pred, y_true, cxr_labels, n_samples=1000, label_idx_map=None): 197 | ''' 198 | This function will randomly sample with replacement 199 | from y_pred and y_true then evaluate `n` times 200 | and obtain AUROC scores for each. 201 | 202 | You can specify the number of samples that should be 203 | used with the `n_samples` parameter. 204 | 205 | Confidence intervals will be generated from each 206 | of the samples. 
207 | 208 | Note: 209 | * n_total_labels >= n_cxr_labels 210 | `n_total_labels` is greater iff alternative labels are being tested 211 | ''' 212 | np.random.seed(97) 213 | y_pred # (500, n_total_labels) 214 | y_true # (500, n_cxr_labels) 215 | 216 | idx = np.arange(len(y_true)) 217 | 218 | boot_stats = [] 219 | for i in tqdm(range(n_samples)): 220 | sample = resample(idx, replace=True, random_state=i) 221 | y_pred_sample = y_pred[sample] 222 | y_true_sample = y_true[sample] 223 | 224 | sample_stats = evaluate(y_pred_sample, y_true_sample, cxr_labels, label_idx_map=label_idx_map) 225 | boot_stats.append(sample_stats) 226 | 227 | boot_stats = pd.concat(boot_stats) # pandas array of evaluations for each sample 228 | return boot_stats, compute_cis(boot_stats) 229 | -------------------------------------------------------------------------------- /chexzero/chexzero/clip.py: -------------------------------------------------------------------------------- 1 | """ 2 | MIT License 3 | 4 | Copyright (c) 2021 OpenAI 5 | 6 | Permission is hereby granted, free of charge, to any person obtaining a copy 7 | of this software and associated documentation files (the "Software"), to deal 8 | in the Software without restriction, including without limitation the rights 9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 10 | copies of the Software, and to permit persons to whom the Software is 11 | furnished to do so, subject to the following conditions: 12 | 13 | The above copyright notice and this permission notice shall be included in all 14 | copies or substantial portions of the Software. 15 | 16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22 | SOFTWARE. 
23 | 24 | """ 25 | 26 | import hashlib 27 | import logging 28 | import os 29 | import urllib 30 | import warnings 31 | from typing import Union, List 32 | 33 | import torch 34 | from PIL import Image 35 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize 36 | from tqdm import tqdm 37 | 38 | from chexzero.chexzero.model import build_model 39 | from chexzero.chexzero.simple_tokenizer import SimpleTokenizer as _Tokenizer 40 | 41 | 42 | __all__ = ["available_models", "load", "tokenize"] 43 | _tokenizer = _Tokenizer() 44 | 45 | _MODELS = { 46 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", 47 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", 48 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", 49 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", 50 | } 51 | 52 | log = logging.getLogger(__name__) 53 | 54 | 55 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): 56 | os.makedirs(root, exist_ok=True) 57 | filename = os.path.basename(url) 58 | 59 | expected_sha256 = url.split("/")[-2] 60 | download_target = os.path.join(root, filename) 61 | 62 | if os.path.exists(download_target) and not os.path.isfile(download_target): 63 | raise RuntimeError(f"{download_target} exists and is not a regular file") 64 | 65 | if os.path.isfile(download_target): 66 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: 67 | return download_target 68 | else: 69 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") 70 | 71 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: 72 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True) as loop: 73 | while True: 74 | buffer = source.read(8192) 75 | if not buffer: 76 | break 77 | 78 | output.write(buffer) 79 | loop.update(len(buffer)) 80 | 81 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: 82 | raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") 83 | 84 | return download_target 85 | 86 | 87 | def _transform(n_px): 88 | return Compose( 89 | [ 90 | Resize(n_px, interpolation=Image.BICUBIC), 91 | CenterCrop(n_px), 92 | lambda image: image.convert("RGB"), 93 | ToTensor(), 94 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), 95 | ] 96 | ) 97 | 98 | 99 | def available_models() -> List[str]: 100 | """Returns the names of available CLIP models""" 101 | return list(_MODELS.keys()) 102 | 103 | 104 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): 105 | """Load a CLIP model 106 | 107 | Parameters 108 | ---------- 109 | name : str 110 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict 111 | 112 | device : Union[str, torch.device] 113 | The device to put the loaded model 114 | 115 | jit : bool 116 | Whether to load the optimized JIT model (default) or more hackable non-JIT model. 
117 | 118 | Returns 119 | ------- 120 | model : torch.nn.Module 121 | The CLIP model 122 | 123 | preprocess : Callable[[PIL.Image], torch.Tensor] 124 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input 125 | """ 126 | if name in _MODELS: 127 | model_path = _download(_MODELS[name]) 128 | elif os.path.isfile(name): 129 | model_path = name 130 | else: 131 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}") 132 | 133 | try: 134 | # loading JIT archive 135 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() 136 | state_dict = None 137 | except RuntimeError: 138 | # loading saved state dict 139 | if jit: 140 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") 141 | jit = False 142 | state_dict = torch.load(model_path, map_location="cpu") 143 | 144 | if not jit: 145 | model = build_model(state_dict or model.state_dict()).to(device) 146 | if str(device) == "cpu": 147 | model.float() 148 | return model, _transform(model.visual.input_resolution) 149 | 150 | # patch the device names 151 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) 152 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] 153 | 154 | def patch_device(module): 155 | graphs = [module.graph] if hasattr(module, "graph") else [] 156 | if hasattr(module, "forward1"): 157 | graphs.append(module.forward1.graph) 158 | 159 | for graph in graphs: 160 | for node in graph.findAllNodes("prim::Constant"): 161 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): 162 | node.copyAttributes(device_node) 163 | 164 | model.apply(patch_device) 165 | patch_device(model.encode_image) 166 | patch_device(model.encode_text) 167 | 168 | # patch dtype to float32 on CPU 169 | if str(device) == "cpu": 170 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) 171 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] 172 | float_node = float_input.node() 173 | 174 | def patch_float(module): 175 | graphs = [module.graph] if hasattr(module, "graph") else [] 176 | if hasattr(module, "forward1"): 177 | graphs.append(module.forward1.graph) 178 | 179 | for graph in graphs: 180 | for node in graph.findAllNodes("aten::to"): 181 | inputs = list(node.inputs()) 182 | for i in [1, 2]: # dtype can be the second or third argument to aten::to() 183 | if inputs[i].node()["value"] == 5: 184 | inputs[i].node().copyAttributes(float_node) 185 | 186 | model.apply(patch_float) 187 | patch_float(model.encode_image) 188 | patch_float(model.encode_text) 189 | 190 | model.float() 191 | 192 | return model, _transform(model.input_resolution.item()) 193 | 194 | 195 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: 196 | """ 197 | Returns the tokenized representation of given input string(s) 198 | 199 | Parameters 200 | ---------- 201 | texts : Union[str, List[str]] 202 | An input string or a list of input strings to tokenize 203 | 204 | context_length : int 205 | The context length to use; all CLIP models use 77 as the context length 206 | 207 | Returns 208 | ------- 209 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] 210 | """ 211 | if isinstance(texts, str): 212 | texts = [texts] 213 | 214 | sot_token = _tokenizer.encoder["<|startoftext|>"] 
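    # Each input string is wrapped as <|startoftext|> + BPE tokens + <|endoftext|>,
    # then right-padded with zeros (and truncated, with a warning) to context_length.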
215 | eot_token = _tokenizer.encoder["<|endoftext|>"] 216 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] 217 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) 218 | 219 | for i, tokens in enumerate(all_tokens): 220 | if len(tokens) > context_length: 221 | tokens = tokens[: context_length - 1] + [tokens[-1]] 222 | log.warning(f"Input {texts[i]} is too long for context length {context_length}") 223 | # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") 224 | result[i, : len(tokens)] = torch.tensor(tokens) 225 | 226 | return result 227 | -------------------------------------------------------------------------------- /cheff/ldm/modules/attention.py: -------------------------------------------------------------------------------- 1 | from inspect import isfunction 2 | import math 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, einsum 6 | from einops import rearrange, repeat 7 | 8 | from cheff.ldm.modules.diffusionmodules.util import checkpoint 9 | 10 | 11 | def exists(val): 12 | return val is not None 13 | 14 | 15 | def uniq(arr): 16 | return{el: True for el in arr}.keys() 17 | 18 | 19 | def default(val, d): 20 | if exists(val): 21 | return val 22 | return d() if isfunction(d) else d 23 | 24 | 25 | def max_neg_value(t): 26 | return -torch.finfo(t.dtype).max 27 | 28 | 29 | def init_(tensor): 30 | dim = tensor.shape[-1] 31 | std = 1 / math.sqrt(dim) 32 | tensor.uniform_(-std, std) 33 | return tensor 34 | 35 | 36 | # feedforward 37 | class GEGLU(nn.Module): 38 | def __init__(self, dim_in, dim_out): 39 | super().__init__() 40 | self.proj = nn.Linear(dim_in, dim_out * 2) 41 | 42 | def forward(self, x): 43 | x, gate = self.proj(x).chunk(2, dim=-1) 44 | return x * F.gelu(gate) 45 | 46 | 47 | class FeedForward(nn.Module): 48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): 49 | super().__init__() 50 | inner_dim = int(dim * mult) 51 | dim_out = default(dim_out, dim) 52 | project_in = nn.Sequential( 53 | nn.Linear(dim, inner_dim), 54 | nn.GELU() 55 | ) if not glu else GEGLU(dim, inner_dim) 56 | 57 | self.net = nn.Sequential( 58 | project_in, 59 | nn.Dropout(dropout), 60 | nn.Linear(inner_dim, dim_out) 61 | ) 62 | 63 | def forward(self, x): 64 | return self.net(x) 65 | 66 | 67 | def zero_module(module): 68 | """ 69 | Zero out the parameters of a module and return it. 
70 | """ 71 | for p in module.parameters(): 72 | p.detach().zero_() 73 | return module 74 | 75 | 76 | def Normalize(in_channels): 77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) 78 | 79 | 80 | class LinearAttention(nn.Module): 81 | def __init__(self, dim, heads=4, dim_head=32): 82 | super().__init__() 83 | self.heads = heads 84 | hidden_dim = dim_head * heads 85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) 86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 87 | 88 | def forward(self, x): 89 | b, c, h, w = x.shape 90 | qkv = self.to_qkv(x) 91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) 92 | k = k.softmax(dim=-1) 93 | context = torch.einsum('bhdn,bhen->bhde', k, v) 94 | out = torch.einsum('bhde,bhdn->bhen', context, q) 95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) 96 | return self.to_out(out) 97 | 98 | 99 | class SpatialSelfAttention(nn.Module): 100 | def __init__(self, in_channels): 101 | super().__init__() 102 | self.in_channels = in_channels 103 | 104 | self.norm = Normalize(in_channels) 105 | self.q = torch.nn.Conv2d(in_channels, 106 | in_channels, 107 | kernel_size=1, 108 | stride=1, 109 | padding=0) 110 | self.k = torch.nn.Conv2d(in_channels, 111 | in_channels, 112 | kernel_size=1, 113 | stride=1, 114 | padding=0) 115 | self.v = torch.nn.Conv2d(in_channels, 116 | in_channels, 117 | kernel_size=1, 118 | stride=1, 119 | padding=0) 120 | self.proj_out = torch.nn.Conv2d(in_channels, 121 | in_channels, 122 | kernel_size=1, 123 | stride=1, 124 | padding=0) 125 | 126 | def forward(self, x): 127 | h_ = x 128 | h_ = self.norm(h_) 129 | q = self.q(h_) 130 | k = self.k(h_) 131 | v = self.v(h_) 132 | 133 | # compute attention 134 | b,c,h,w = q.shape 135 | q = rearrange(q, 'b c h w -> b (h w) c') 136 | k = rearrange(k, 'b c h w -> b c (h w)') 137 | w_ = torch.einsum('bij,bjk->bik', q, k) 138 | 139 | w_ = w_ * (int(c)**(-0.5)) 140 | w_ = torch.nn.functional.softmax(w_, dim=2) 141 | 142 | # attend to values 143 | v = rearrange(v, 'b c h w -> b c (h w)') 144 | w_ = rearrange(w_, 'b i j -> b j i') 145 | h_ = torch.einsum('bij,bjk->bik', v, w_) 146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) 147 | h_ = self.proj_out(h_) 148 | 149 | return x+h_ 150 | 151 | 152 | class CrossAttention(nn.Module): 153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): 154 | super().__init__() 155 | inner_dim = dim_head * heads 156 | context_dim = default(context_dim, query_dim) 157 | 158 | self.scale = dim_head ** -0.5 159 | self.heads = heads 160 | 161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False) 162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False) 163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False) 164 | 165 | self.to_out = nn.Sequential( 166 | nn.Linear(inner_dim, query_dim), 167 | nn.Dropout(dropout) 168 | ) 169 | 170 | def forward(self, x, context=None, mask=None): 171 | h = self.heads 172 | 173 | q = self.to_q(x) 174 | context = default(context, x) 175 | k = self.to_k(context) 176 | v = self.to_v(context) 177 | 178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) 179 | 180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale 181 | 182 | if exists(mask): 183 | mask = rearrange(mask, 'b ... 
-> b (...)') 184 | max_neg_value = -torch.finfo(sim.dtype).max 185 | mask = repeat(mask, 'b j -> (b h) () j', h=h) 186 | sim.masked_fill_(~mask, max_neg_value) 187 | 188 | # attention, what we cannot get enough of 189 | attn = sim.softmax(dim=-1) 190 | 191 | out = einsum('b i j, b j d -> b i d', attn, v) 192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h) 193 | return self.to_out(out) 194 | 195 | 196 | class BasicTransformerBlock(nn.Module): 197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): 198 | super().__init__() 199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention 200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) 201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim, 202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none 203 | self.norm1 = nn.LayerNorm(dim) 204 | self.norm2 = nn.LayerNorm(dim) 205 | self.norm3 = nn.LayerNorm(dim) 206 | self.checkpoint = checkpoint 207 | 208 | def forward(self, x, context=None): 209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint) 210 | 211 | def _forward(self, x, context=None): 212 | x = self.attn1(self.norm1(x)) + x 213 | x = self.attn2(self.norm2(x), context=context) + x 214 | x = self.ff(self.norm3(x)) + x 215 | return x 216 | 217 | 218 | class SpatialTransformer(nn.Module): 219 | """ 220 | Transformer block for image-like data. 221 | First, project the input (aka embedding) 222 | and reshape to b, t, d. 223 | Then apply standard transformer action. 224 | Finally, reshape to image 225 | """ 226 | def __init__(self, in_channels, n_heads, d_head, 227 | depth=1, dropout=0., context_dim=None): 228 | super().__init__() 229 | self.in_channels = in_channels 230 | inner_dim = n_heads * d_head 231 | self.norm = Normalize(in_channels) 232 | 233 | self.proj_in = nn.Conv2d(in_channels, 234 | inner_dim, 235 | kernel_size=1, 236 | stride=1, 237 | padding=0) 238 | 239 | self.transformer_blocks = nn.ModuleList( 240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim) 241 | for d in range(depth)] 242 | ) 243 | 244 | self.proj_out = zero_module(nn.Conv2d(inner_dim, 245 | in_channels, 246 | kernel_size=1, 247 | stride=1, 248 | padding=0)) 249 | 250 | def forward(self, x, context=None): 251 | # note: if no context is given, cross-attention defaults to self-attention 252 | b, c, h, w = x.shape 253 | x_in = x 254 | x = self.norm(x) 255 | x = self.proj_in(x) 256 | x = rearrange(x, 'b c h w -> b (h w) c') 257 | for block in self.transformer_blocks: 258 | x = block(x, context=context) 259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w) 260 | x = self.proj_out(x) 261 | return x + x_in -------------------------------------------------------------------------------- /chexzero/chexzero/README.md: -------------------------------------------------------------------------------- 1 | # Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning 2 | 3 |
4 | 5 | Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning, Nat. Biomed. Eng (2022). 6 | [Paper] 7 |
Ekin Tiu, Ellie Talius, Pujan Patel, Curtis P. Langlotz, Andrew Y. Ng, Pranav Rajpurkar
8 |
9 | 10 | ```bash 11 | Tiu, E., Talius, E., Patel, P. et al. Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning. Nat. Biomed. Eng (2022). https://doi.org/10.1038/s41551-022-00936-9 12 | ``` 13 |
14 | 15 | Screen Shot 2022-09-15 at 10 57 16 AM 16 | 17 | This repository contains code to train a self-supervised learning model on chest X-ray images that lack explicit annotations and evalute this model's performance on pathology-classification tasks. 18 | 19 |
20 | 21 | Main Findings 22 | 23 | 24 | 1. **Automatically detecting pathologies in chest x-rays without explicit annotations:** Our method learns directly from the combination of images and unstructured radiology reports, thereby avoiding time-consuming labeling efforts. Our deep learning method is capable of predicting multiple pathologies and differential diagnoses that it had not explicitly seen during training. 25 | 2. **Matching radiologist performance on different tasks on an external test set:** Our method performed on par with human performance when evaluated on an external validation set (CheXpert) of chest x-ray images labeled for the presence of 14 different conditions by multiple radiologists. 26 | 3. **Outperforming approaches that train on explicitly labeled data on an external test set:** Using no labels, we outperformed a fully supervised approach (100% of labels) on 3 out of the 8 selected pathologies on a dataset (PadChest) collected in a different country. We further demonstrated high performance (AUC > 0.9) on 14 findings and at least 0.700 on 53 findings out of 107 radiographic findings that the method had not seen during training. 27 |
28 | 29 | 30 | ## Dependencies 31 | To clone all files: 32 | 33 | ```git clone https://github.com/rajpurkarlab/CheXzero.git``` 34 | 35 | To install Python dependencies: 36 | 37 | ```pip install -r requirements.txt``` 38 | 39 | ## Data 40 | ### Training Dataset 41 | 1. Download images come from [MIMIC-CXR JPG] https://physionet.org/content/mimic-cxr-jpg/2.0.0/ and reports from [MIMIC-CXR Database](https://physionet.org/content/mimic-cxr/2.0.0/) Note: in order to gain access to the data, you must be a credentialed user as defined on [PhysioNet](https://physionet.org/settings/credentialing/). 42 | 2. Copy the dataset into the `data/` directory. 43 | 3. Run `python run_preprocess.py` 44 | 4. This should preprocess the chest x-ray images into a Hierarchical Data Format (HDF) format used for training stored at `data/cxr.h5` and extract the impressions section as text from the corresponding chest x-ray radiology report stored at `data/mimic_impressions.csv` . 45 | 46 | ### Evaluation Dataset 47 | 48 | #### CheXpert Dataset 49 | The CheXpert dataset consists of chest radiographic examinations from Stanford Hospital, performed between October 2002 50 | and July 2017 in both inpatient and outpatient centers. Population-level characteristics are unavailable for the CheXpert test 51 | dataset, as they are used for official evaluation on the CheXpert leaderboard. 52 | 53 | The main data (CheXpert data) supporting the results of this study are available at https://aimi.stanford.edu/chexpert-chest-x-rays. 54 | 55 | The CheXpert **test** dataset has recently been made public, and can be found by following the steps in the [cheXpert-test-set-labels](https://github.com/rajpurkarlab/cheXpert-test-set-labels) repository. 56 | 57 | #### PadChest Dataset 58 | The PadChest dataset contains chest X-rays that were interpreted by 18 radiologists at the Hospital Universitario de San Juan, 59 | Alicante, Spain, from January 2009 to December 2017. The dataset contains 109,931 image studies and 168,861 images. 60 | PadChest also contains 206,222 study reports. 61 | 62 | The [PadChest](https://arxiv.org/abs/1901.07441) is publicly available at https://bimcv.cipf.es/bimcv-projects/padchest. Those who would like to use PadChest for experimentation should request access to PadChest at the [link](https://bimcv.cipf.es/bimcv-projects/padchest). 63 | 64 | ### Model Checkpoints 65 | Model checkpoints of CheXzero pre-trained on MIMIC-CXR are publicly available at the following [link](https://drive.google.com/drive/folders/1makFLiEMbSleYltaRxw81aBhEDMpVwno?usp=sharing). Download files and save them in the `./checkpoints/chexzero_weights` directory. 66 | 67 | ## Running Training 68 | Run the following command to perform CheXzero pretraining. 69 | ```bash 70 | python run_train.py --cxr_filepath "./data/cxr.h5" --txt_filepath "data/mimic_impressions.csv" 71 | ``` 72 | 73 | ### Arguments 74 | * `--cxr_filepath` Directory to load chest x-ray image data from. 75 | * `--txt_filepath` Directory to load radiology report impressions text from. 76 | 77 | Use `-h` flag to see all optional arguments. 78 | 79 | ## Zero-Shot Inference 80 | See the following [notebook](https://github.com/rajpurkarlab/CheXzero/blob/main/notebooks/zero_shot.ipynb) for an example of how to use CheXzero to perform zero-shot inference on a chest x-ray dataset. The example shows how to output predictions from the model ensemble and evaluate performance of the model if ground truth labels are available. 
81 | 82 | ```python 83 | import zero_shot 84 | 85 | # computes predictions for a set of images, returned as a np array of probabilities for each pathology 86 | predictions, y_pred_avg = zero_shot.ensemble_models( 87 | model_paths=model_paths, 88 | cxr_filepath=cxr_filepath, 89 | cxr_labels=cxr_labels, 90 | cxr_pair_template=cxr_pair_template, 91 | cache_dir=cache_dir, 92 | ) 93 | ``` 94 | ### Arguments 95 | * `model_paths: List[str]`: List of paths to all checkpoints to be used in the ensemble. To run on a single model, input a list containing a single path. 96 | * `cxr_filepath: str`: Path to the images `.h5` file. 97 | * `cxr_labels: List[str]`: List of pathologies to query in each image. 98 | * `cxr_pair_template: Tuple[str, str]`: contrasting templates used to query the model (see Figure 1 in the article for a visual explanation). 99 | * `cache_dir: str`: Directory in which to cache the predictions of each checkpoint; use it to avoid recomputing predictions. 100 | 101 | In order to use CheXzero for zero-shot inference, ensure the following requirements are met: 102 | * All input *`images`* must be stored in a single `.h5` (Hierarchical Data Format) file. See the [`img_to_h5`](https://github.com/rajpurkarlab/CheXzero/blob/main/preprocess_padchest.py#L156) function in [preprocess_padchest.py](https://github.com/rajpurkarlab/CheXzero/blob/main/preprocess_padchest.py) for an example of how to convert a list of paths to `.png` files into a valid `.h5` file. 103 | * The *ground truth `labels`* must be in a `.csv` dataframe where rows represent each image sample, and each column represents the binary labels for a particular pathology on each sample. 104 | * Ensure all [model checkpoints](https://drive.google.com/drive/folders/1makFLiEMbSleYltaRxw81aBhEDMpVwno?usp=sharing) are stored in `checkpoints/chexzero_weights/`, or in the `model_dir` that is specified in the notebook. 105 | 106 | ## Evaluation 107 | Given a numpy array of predictions (obtained from zero-shot inference) and a numpy array of ground truth labels, one can evaluate the performance of the model using the following code: 108 | ```python 109 | import zero_shot 110 | import eval 111 | 112 | # load ground truth labels into memory 113 | test_pred = y_pred_avg 114 | test_true = zero_shot.make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels) 115 | 116 | # evaluate model, no bootstrap 117 | cxr_results: pd.DataFrame = eval.evaluate(test_pred, test_true, cxr_labels) # eval on full test dataset 118 | 119 | # bootstrap evaluations for 95% confidence intervals 120 | bootstrap_results: Tuple[pd.DataFrame, pd.DataFrame] = eval.bootstrap(test_pred, test_true, cxr_labels) # (df of results for each bootstrap, df of CI) 121 | 122 | # print results with confidence intervals 123 | print(bootstrap_results[1]) 124 | ``` 125 | The results are represented as a `pd.DataFrame`, which can be saved as a `.csv`. 126 | 127 | ### CheXpert Test Dataset 128 | In order to replicate the results in the paper, zero-shot inference and evaluation can be performed on the now publicly available CheXpert test dataset. 129 | 1) Download the labels at [cheXpert-test-set-labels](https://github.com/rajpurkarlab/cheXpert-test-set-labels/blob/main/groundtruth.csv) and the image files from [Stanford AIMI](https://stanfordaimi.azurewebsites.net/datasets/23c56a0d-15de-405b-87c8-99c30138950c) and save them in the `./data` directory in `CheXzero/`.
The test dataset images should have the following directory structure: 130 | ``` 131 | data/ 132 | ├─ CheXpert/ 133 | │ ├─ test/ 134 | │ │ ├─ patient64741/ 135 | │ │ │ ├─ study1/ 136 | │ │ │ │ ├─ view1_frontal.jpg 137 | │ │ ├─ .../ 138 | ``` 139 | 140 | 2) Run `run_preprocess.py` script with the following arguments: 141 | ```bash 142 | python run_preprocess.py --dataset_type "chexpert-test" --cxr_out_path "./data/chexpert_test.h5" --chest_x_ray_path "./data/CheXpert/test/" 143 | ``` 144 | This should save a `.h5` version of the test dataset images which can be used for evaluation. 145 | 146 | 3) Open sample zero-shot [notebook](https://github.com/rajpurkarlab/CheXzero/blob/main/notebooks/zero_shot.ipynb) and run all cells. If the directory structure is set up correctly, then all cells should run without errors. 147 | 148 | ## Issues 149 | Please open new issue threads specifying the issue with the codebase or report issues directly to ekintiu@stanford.edu. 150 | 151 | ## Citation 152 | ```bash 153 | Tiu, E., Talius, E., Patel, P. et al. Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning. Nat. Biomed. Eng (2022). https://doi.org/10.1038/s41551-022-00936-9 154 | ``` 155 | 156 | ## License 157 | The source code for the site is licensed under the MIT license, which you can find in the `LICENSE` file. Also see `NOTICE.md` for attributions to third-party sources. 158 | -------------------------------------------------------------------------------- /cheff/sr/model.py: -------------------------------------------------------------------------------- 1 | """Classes and functions for neural networks.""" 2 | import math 3 | from functools import partial 4 | from typing import Optional, Tuple 5 | 6 | import torch 7 | from einops import rearrange 8 | from torch import ( 9 | Tensor, 10 | einsum, 11 | nn, 12 | ) 13 | 14 | 15 | class Residual(nn.Module): 16 | """Wrapper for residual connection of a function.""" 17 | 18 | def __init__(self, fn: nn.Module) -> None: 19 | """Initialize residual connection.""" 20 | super().__init__() 21 | self.fn = fn 22 | 23 | def forward(self, x: Tensor, *args, **kwargs) -> Tensor: 24 | """Pass a tensor through the module.""" 25 | return self.fn(x, *args, **kwargs) + x 26 | 27 | 28 | def get_upsample_conv(dim: int) -> nn.ConvTranspose2d: 29 | """Initialize transposed convolution layer.""" 30 | return nn.ConvTranspose2d(dim, dim, 4, 2, 1) 31 | 32 | 33 | def get_downsample_conv(dim: int) -> nn.Conv2d: 34 | """Initialize convolution layer.""" 35 | return nn.Conv2d(dim, dim, 4, 2, 1) 36 | 37 | 38 | class SinusoidalPositionEmbeddings(nn.Module): 39 | """Class for sinusoidal embeddings.""" 40 | 41 | def __init__(self, dim: int) -> None: 42 | """Initialize SinusoidalPositionEmbeddings.""" 43 | super().__init__() 44 | self.dim = dim 45 | 46 | def forward(self, time): 47 | """Pass a tensor through the module.""" 48 | device = time.device 49 | half_dim = self.dim // 2 50 | embeddings = math.log(10000) / (half_dim - 1) 51 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings) 52 | embeddings = time[:, None] * embeddings[None, :] 53 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1) 54 | return embeddings 55 | 56 | 57 | class Block(nn.Module): 58 | """Neural block with convolutions, norm and activations.""" 59 | 60 | def __init__(self, dim: int, dim_out: int, groups: int = 8) -> None: 61 | """Initialize block.""" 62 | super().__init__() 63 | self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) 64 
| self.norm = nn.GroupNorm(groups, dim_out) 65 | self.act = nn.SiLU() 66 | 67 | def forward(self, x: Tensor) -> Tensor: 68 | """Pass a tensor through the module.""" 69 | x = self.proj(x) 70 | x = self.norm(x) 71 | x = self.act(x) 72 | return x 73 | 74 | 75 | class ResnetBlock(nn.Module): 76 | """Residual block.""" 77 | 78 | def __init__( 79 | self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8 80 | ) -> None: 81 | """Initialize a residual block.""" 82 | super().__init__() 83 | self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out)) 84 | 85 | self.block1 = Block(dim, dim_out, groups=groups) 86 | self.block2 = Block(dim_out, dim_out, groups=groups) 87 | 88 | if dim != dim_out: 89 | self.res_conv = nn.Conv2d(dim, dim_out, 1) 90 | else: 91 | self.res_conv = nn.Identity() 92 | 93 | def forward(self, x: Tensor, t: Tensor) -> Tensor: 94 | """Pass a tensor through the module.""" 95 | h = self.block1(x) 96 | 97 | time_emb = self.mlp(t) 98 | h += rearrange(time_emb, 'b c -> b c 1 1') 99 | 100 | h = self.block2(h) 101 | 102 | return h + self.res_conv(x) 103 | 104 | 105 | class Attention(nn.Module): 106 | """Attention module.""" 107 | 108 | def __init__(self, dim: int, heads: int = 4, dim_head: int = 32) -> None: 109 | """Initialize Attention.""" 110 | super().__init__() 111 | self.scale = dim_head**-0.5 112 | self.heads = heads 113 | hidden_dim = dim_head * heads 114 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 115 | self.to_out = nn.Conv2d(hidden_dim, dim, 1) 116 | 117 | def forward(self, x: Tensor) -> Tensor: 118 | """Pass a tensor through the module.""" 119 | b, c, h, w = x.shape 120 | qkv = self.to_qkv(x).chunk(3, dim=1) 121 | q, k, v = map( 122 | lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv 123 | ) 124 | q = q * self.scale 125 | 126 | sim = einsum('b h d i, b h d j -> b h i j', q, k) 127 | sim = sim - sim.amax(dim=-1, keepdim=True).detach() 128 | attn = sim.softmax(dim=-1) 129 | 130 | out = einsum('b h i j, b h d j -> b h i d', attn, v) 131 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w) 132 | return self.to_out(out) 133 | 134 | 135 | class LinearAttention(nn.Module): 136 | """Linear attention module.""" 137 | 138 | def __init__(self, dim: int, heads: int = 4, dim_head: int = 32) -> None: 139 | """Initialize linear attention module.""" 140 | super().__init__() 141 | self.scale = dim_head**-0.5 142 | self.heads = heads 143 | hidden_dim = dim_head * heads 144 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) 145 | 146 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim)) 147 | 148 | def forward(self, x: Tensor) -> Tensor: 149 | """Pass a tensor through the module.""" 150 | b, c, h, w = x.shape 151 | qkv = self.to_qkv(x).chunk(3, dim=1) 152 | q, k, v = map( 153 | lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv 154 | ) 155 | 156 | q = q.softmax(dim=-2) 157 | k = k.softmax(dim=-1) 158 | 159 | q = q * self.scale 160 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v) 161 | 162 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q) 163 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w) 164 | return self.to_out(out) 165 | 166 | 167 | class PreNorm(nn.Module): 168 | """PreNorm Module.""" 169 | 170 | def __init__(self, dim: int, fn: nn.Module) -> None: 171 | """Initialize PreNorm.""" 172 | super().__init__() 173 | self.fn = fn 174 | self.norm = nn.GroupNorm(1, dim) 175 | 176 | def forward(self, x: Tensor) -> 
Tensor: 177 | """Pass a tensor through the module.""" 178 | x = self.norm(x) 179 | return self.fn(x) 180 | 181 | 182 | class Unet(nn.Module): 183 | """U-net architecture.""" 184 | 185 | def __init__( 186 | self, 187 | dim: int, 188 | init_dim: Optional[int] = None, 189 | out_dim: Optional[int] = None, 190 | dim_mults: Tuple = (1, 2, 4, 8), 191 | num_attention_layer: int = 5, 192 | channels: int = 3, 193 | block_groups: int = 8, 194 | ) -> None: 195 | """Initialize U-net.""" 196 | super().__init__() 197 | 198 | self.channels = channels 199 | 200 | init_dim = init_dim if init_dim is not None else dim // 3 * 2 201 | self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3) 202 | 203 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)] 204 | in_out = list(zip(dims[:-1], dims[1:])) 205 | 206 | block_klass = partial(ResnetBlock, groups=block_groups) 207 | 208 | time_dim = dim * 4 209 | self.time_mlp = nn.Sequential( 210 | SinusoidalPositionEmbeddings(dim), 211 | nn.Linear(dim, time_dim), 212 | nn.GELU(), 213 | nn.Linear(time_dim, time_dim), 214 | ) 215 | 216 | # layers 217 | self.downs = nn.ModuleList([]) 218 | self.ups = nn.ModuleList([]) 219 | num_resolutions = len(in_out) 220 | 221 | for ind, (dim_in, dim_out) in enumerate(in_out): 222 | is_last = ind >= (num_resolutions - 1) 223 | has_att = ind >= (num_resolutions - num_attention_layer + 1) 224 | 225 | self.downs.append( 226 | nn.ModuleList( 227 | [ 228 | block_klass(dim_in, dim_out, time_emb_dim=time_dim), 229 | block_klass(dim_out, dim_out, time_emb_dim=time_dim), 230 | Residual(PreNorm(dim_out, LinearAttention(dim_out))) 231 | if has_att 232 | else nn.Identity(), 233 | get_downsample_conv(dim_out) if not is_last else nn.Identity(), 234 | ] 235 | ) 236 | ) 237 | 238 | mid_dim = dims[-1] 239 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 240 | if num_attention_layer >= 1: 241 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) 242 | else: 243 | self.mid_attn = nn.Identity() 244 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) 245 | 246 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): 247 | is_last = ind >= (num_resolutions - 1) 248 | has_att = ind >= (num_resolutions - num_attention_layer + 1) 249 | 250 | self.ups.append( 251 | nn.ModuleList( 252 | [ 253 | block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim), 254 | block_klass(dim_in, dim_in, time_emb_dim=time_dim), 255 | Residual(PreNorm(dim_in, LinearAttention(dim_in))) 256 | if has_att 257 | else nn.Identity(), 258 | get_upsample_conv(dim_in) if not is_last else nn.Identity(), 259 | ] 260 | ) 261 | ) 262 | 263 | out_dim = out_dim if out_dim is not None else channels 264 | self.final_conv = nn.Sequential( 265 | Residual( 266 | nn.Sequential( 267 | Block(dim, dim, groups=block_groups), 268 | Block(dim, dim, groups=block_groups), 269 | ) 270 | ), 271 | nn.Conv2d(dim, out_dim, 1), 272 | ) 273 | 274 | def forward(self, x: Tensor, t: Tensor = None) -> Tensor: 275 | """Pass a tensor through the module.""" 276 | x = self.init_conv(x) 277 | t = self.time_mlp(t) 278 | 279 | h = [] 280 | 281 | # Downsampling 282 | for block1, block2, attn, downsample in self.downs: 283 | x = block1(x, t) 284 | x = block2(x, t) 285 | x = attn(x) 286 | h.append(x) 287 | x = downsample(x) 288 | 289 | # Bottleneck 290 | x = self.mid_block1(x, t) 291 | x = self.mid_attn(x) 292 | x = self.mid_block2(x, t) 293 | 294 | # Upsampling 295 | for block1, block2, attn, upsample in self.ups: 296 | x = torch.cat((x, h.pop()), dim=1) 297 | x = 
block1(x, t) 298 | x = block2(x, t) 299 | x = attn(x) 300 | x = upsample(x) 301 | 302 | return self.final_conv(x) 303 | -------------------------------------------------------------------------------- /chexzero/components/bbox_losses.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | from math import prod 4 | from typing import Tuple 5 | import einops 6 | from torch import Tensor 7 | import torch 8 | import torch.nn.functional as F 9 | from scipy.optimize import linear_sum_assignment 10 | 11 | 12 | def bbox_l1_ccost(boxes1, boxes2): 13 | """ 14 | :param boxes1: (... x R_s x 4) in (xc, yc, w, h) format 15 | :param boxes2: (... x R_t x 4) in (xc, yc, w, h) format 16 | :return: (... x R_s x R_t) 17 | """ 18 | assert boxes1.ndim == boxes2.ndim 19 | assert boxes1.ndim >= 3 20 | assert boxes1.shape[-1] >= 4 and boxes2.shape[-1] >= 4 21 | *batch_dims1, R_s, _ = boxes1.shape 22 | *batch_dims2, R_t, _ = boxes2.shape 23 | dims = prod(batch_dims1) 24 | assert batch_dims1 == batch_dims2 25 | # (... x R_s x 4) 26 | boxes1 = boxes1[..., :4].reshape(dims, R_s, 4) 27 | # (... x R_t x 4) 28 | boxes2 = boxes2[..., :4].reshape(dims, R_t, 4).to(dtype=boxes1.dtype) 29 | # (... x R_s x R_t) 30 | l1_cost = torch.cdist(boxes1, boxes2, p=1) 31 | return l1_cost.view(*batch_dims1, R_s, R_t) 32 | 33 | 34 | def bbox_l1_pcost(boxes1, boxes2): 35 | """ 36 | :param boxes1: (... x 4) in (xc, yc, w, h) format 37 | :param boxes2: (... x 4) in (xc, yc, w, h) format 38 | :return: (...) 39 | """ 40 | boxes1 = boxes1[..., :4] 41 | boxes2 = boxes2[..., :4] 42 | assert boxes1.shape == boxes2.shape 43 | return F.l1_loss(boxes1, boxes2, reduction='none').sum(dim=-1) 44 | 45 | 46 | def bbox_giou_ccost(boxes1, boxes2): 47 | """ 48 | :param boxes1: (... x R_s x 4) in (xc, yc, w, h) format 49 | :param boxes2: (... x R_t x 4) in (xc, yc, w, h) format 50 | :return: (... x R_s x R_t) 51 | """ 52 | assert boxes1.ndim == boxes2.ndim 53 | assert boxes1.ndim >= 3 54 | assert boxes1.shape[-1] >= 4 and boxes2.shape[-1] >= 4 55 | *batch_dims1, R_s, _ = boxes1.shape 56 | *batch_dims2, R_t, _ = boxes2.shape 57 | assert batch_dims1 == batch_dims2 58 | dims = prod(batch_dims1) 59 | # (... x R_s x 4) 60 | boxes1 = center_to_corners_format(boxes1[..., :4].reshape(dims, R_s, 4)) 61 | boxes2 = center_to_corners_format(boxes2[..., :4].reshape(dims, R_t, 4).to(dtype=boxes1.dtype)) 62 | # (... x R_s x R_t) 63 | giou_cost = 1. - box_giou(boxes1, boxes2) 64 | 65 | return giou_cost.view(*batch_dims1, R_s, R_t) 66 | 67 | def bbox_giou_pcost(boxes1, boxes2): 68 | """ 69 | :param boxes1: (... x 4) in (xc, yc, w, h) format 70 | :param boxes2: (... x 4) in (xc, yc, w, h) format 71 | :return: (...) 72 | """ 73 | boxes1 = center_to_corners_format(boxes1[..., :4]) 74 | boxes2 = center_to_corners_format(boxes2[..., :4]) 75 | assert boxes1.shape == boxes2.shape 76 | *batch_dims, _ = boxes1.shape 77 | boxes1 = boxes1.view(-1, 1, 1, 4) 78 | boxes2 = boxes2.view(-1, 1, 1, 4) 79 | # (..., 1, 1) 80 | giou_cost = 1. - box_giou(boxes1, boxes2) 81 | return giou_cost.view(*batch_dims) 82 | 83 | def box_giou(boxes1, boxes2): 84 | """ 85 | https://giou.stanford.edu/ 86 | :param boxes1: (... x N x M x 4) in (x0, y0, x1, y1) format 87 | :param boxes2: (... x N x M x 4) in (x0, y0, x1, y1) format 88 | :return: (... 
x N x M) giou 89 | """ 90 | # degenerate boxes gives inf / nan results 91 | # so do an early check 92 | if not (boxes1[..., 2:] >= boxes1[..., :2]).all(): 93 | raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}") 94 | if not (boxes2[..., 2:] >= boxes2[..., :2]).all(): 95 | raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}") 96 | 97 | # (... x N x M) 98 | iou, union = box_iou(boxes1, boxes2) 99 | 100 | top_left = torch.min(boxes1[..., :, None, :2], boxes2[..., None, :, :2]) 101 | bottom_right = torch.max(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:]) 102 | 103 | width_height = (bottom_right - top_left).clamp(min=0) # [...,N,M,2] 104 | area = width_height[..., :, :, 0] * width_height[..., :, :, 1] 105 | return iou - (area - union) / area.clamp(min=1e-7) 106 | 107 | def box_iou(boxes1, boxes2): 108 | """ 109 | :param boxes1: (... x N x 4) in (x0, y0, x1, y1) format 110 | :param boxes2: (... x M x 4) in (x0, y0, x1, y1) format 111 | :return: (... x N x M) iou and union 112 | """ 113 | area1 = box_area(boxes1) 114 | area2 = box_area(boxes2) 115 | 116 | left_top = torch.max(boxes1[..., :, None, :2], boxes2[..., None, :, :2]) # [...,N,M,2] 117 | right_bottom = torch.min(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:]) # [...,N,M,2] 118 | 119 | width_height = (right_bottom - left_top).clamp(min=0) # [...,N,M,2] 120 | inter = width_height[..., :, :, 0] * width_height[..., :, :, 1] # [...,N,M] 121 | # [...,N,M] 122 | union = area1[..., :, None] + area2[..., None, :] - inter 123 | 124 | iou = inter / union.clamp(min=1e-7) 125 | return iou, union 126 | 127 | 128 | def box_area(boxes: Tensor) -> Tensor: 129 | """ 130 | :param boxes: (..., 4) in (x0, y0, x1, y1) format 131 | :return: (...) 
132 | """ 133 | boxes = _upcast(boxes) 134 | return (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1]) 135 | 136 | 137 | @torch.no_grad() 138 | def match_multiregions(cost, mask, non_matched_region_mode: str, greedy_match=False): 139 | """ 140 | :param cost: (N x C x R_s x R_t) 141 | :param mask: (N x C x R_t) 142 | """ 143 | N, C, R_s, R_t = cost.shape 144 | cost = cost.flatten(0, 1) # (N*C x R_s x R_t) 145 | mask = mask.flatten(0, 1) # (N*C x R_t) 146 | # List (N*C) of (R_s, R_t') tensors where R_t' is the number of target boxes for the class and sample 147 | matches_list, assign_mask_list = zip(*[compute_matching(c, m, non_matched_region_mode, greedy_match=greedy_match) for c, m in zip(cost, mask)]) 148 | 149 | # (N x C x R_s) 150 | matches = einops.rearrange(torch.stack(matches_list), '(n c) r_s -> n c r_s', n=N, c=C) 151 | # (N x C x R_s x R_t) 152 | assign_mask = einops.rearrange(torch.stack(assign_mask_list), '(n c) r_s r_t -> n c r_s r_t', n=N, c=C) 153 | 154 | return matches, assign_mask 155 | 156 | 157 | def compute_matching(cost: torch.FloatTensor, mask: torch.BoolTensor, non_matched_region_mode: str, greedy_match=False): 158 | device = cost.device 159 | R_t_all = mask.shape[-1] 160 | cost = cost[:, mask] 161 | if not greedy_match: 162 | cost = cost.cpu() 163 | 164 | R_s, R_t = cost.shape 165 | if R_t == 0: 166 | return torch.zeros((R_s,), dtype=torch.bool, device=device), torch.zeros((R_s, R_t_all), dtype=torch.bool, device=device) 167 | 168 | if non_matched_region_mode == 'balanced_match': 169 | cost = extend_cost_for_balanced_match(cost) 170 | else: 171 | assert non_matched_region_mode == 'ignore', f'Unknown non_matched_region_mode {non_matched_region_mode}' 172 | 173 | if greedy_match: 174 | indices_s, indices_t = _do_matching_greedy(cost) 175 | else: 176 | indices_s, indices_t = _do_matching(cost) 177 | # (N_match) 178 | matches = indices_s[indices_t < R_t] 179 | # (N_match) 180 | indices_t = indices_t % R_t # for balanced match 181 | 182 | # (R_s) 183 | matches: torch.BoolTensor = torch.zeros((R_s,), dtype=torch.bool, device=device if greedy_match else 'cpu').scatter_(0, matches, True) 184 | # (R_s x R_t) 185 | # based on both, indices_s and indices_t 186 | assign_mask = torch.zeros((R_s, R_t), dtype=torch.bool, device=device if greedy_match else 'cpu') 187 | assign_mask[indices_s, indices_t] = True 188 | 189 | if not greedy_match: 190 | matches = matches.to(device, non_blocking=True) 191 | assign_mask = assign_mask.to(device, non_blocking=True) 192 | 193 | assign_mask_all = torch.zeros((R_s, R_t_all), dtype=torch.bool, device=device) 194 | assign_mask_all[:, mask] = assign_mask 195 | 196 | return matches, assign_mask_all 197 | 198 | 199 | def _do_matching(cost: torch.FloatTensor) -> Tuple[torch.LongTensor, torch.LongTensor]: 200 | indices_s, indices_t = linear_sum_assignment(cost.numpy()) 201 | indices_s = torch.from_numpy(indices_s) 202 | indices_t = torch.from_numpy(indices_t) 203 | return indices_s, indices_t 204 | 205 | 206 | def _do_matching_greedy(cost: torch.FloatTensor) -> Tuple[torch.LongTensor, torch.LongTensor]: 207 | """ 208 | :param R_s, R_t 209 | :return: (N_match), (N_match) 210 | """ 211 | indices_s, indices_t = [], [] 212 | R_s, R_t = cost.shape 213 | cost = cost.clone() 214 | 215 | for _ in range(min(R_s, R_t)): 216 | mins_R_t, argmins_R_t = cost.min(dim=1) 217 | s = mins_R_t.argmin() 218 | t = argmins_R_t[s] 219 | indices_s.append(s) 220 | indices_t.append(t) 221 | cost[s, :] = float('inf') 222 | cost[:, t] = float('inf') 223 | 224 | 
indices_s = torch.stack(indices_s) 225 | indices_t = torch.stack(indices_t) 226 | return indices_s, indices_t 227 | 228 | 229 | def extend_cost_for_balanced_match(cost): 230 | # cost: (R_s x R_t') 231 | R_s, R_t = cost.shape 232 | max_cost_diff = cost.max() - cost.min() 233 | 234 | if R_s > R_t: 235 | n_copies = R_s // R_t + (1 if R_s % R_t != 0 else 0) 236 | R_t_total = R_t * n_copies 237 | assert R_t_total >= R_s 238 | 239 | # (n_copies - 1) 240 | extra_costs = torch.full((n_copies - 1,), max_cost_diff + 1., 241 | dtype=torch.float, device=cost.device).cumsum(0) 242 | # (R_t_total) 243 | extra_costs = torch.cat([ 244 | torch.zeros((R_t,), dtype=torch.float, device=cost.device), 245 | extra_costs.repeat_interleave(R_t) 246 | ]) 247 | # (R_s x R_t_total) 248 | cost = cost.repeat(1, n_copies) + extra_costs[None, :] 249 | 250 | return cost 251 | 252 | 253 | # ----------------- Conversion utils ----------------- 254 | def center_to_corners_format(x): 255 | """ 256 | Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format 257 | (x_0, y_0, x_1, y_1). 258 | """ 259 | x_c, y_c, w, h, *args = x.unbind(-1) 260 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h), *args] 261 | return torch.stack(b, dim=-1) 262 | 263 | def corners_to_center_format(x): 264 | """ 265 | Converts a PyTorch tensor of bounding boxes of corners format (x_0, y_0, x_1, y_1) to center format 266 | (center_x, center_y, width, height). 267 | """ 268 | x_0, y_0, x_1, y_1, *args = x.unbind(-1) 269 | b = [(x_0 + x_1) / 2, (y_0 + y_1) / 2, (x_1 - x_0), (y_1 - y_0), *args] 270 | return torch.stack(b, dim=-1) 271 | 272 | def _upcast(t: Tensor) -> Tensor: 273 | # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type 274 | if t.is_floating_point(): 275 | return t if t.dtype in (torch.float32, torch.float64) else t.float() 276 | else: 277 | return t if t.dtype in (torch.int32, torch.int64) else t.int() 278 | -------------------------------------------------------------------------------- /cheff/ldm/models/diffusion/classifier.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import pytorch_lightning as pl 4 | from omegaconf import OmegaConf 5 | from torch.nn import functional as F 6 | from torch.optim import AdamW 7 | from torch.optim.lr_scheduler import LambdaLR 8 | from copy import deepcopy 9 | from einops import rearrange 10 | from glob import glob 11 | from natsort import natsorted 12 | 13 | from cheff.ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel 14 | from cheff.ldm.util import log_txt_as_img, default, ismap, instantiate_from_config 15 | 16 | __models__ = { 17 | 'class_label': EncoderUNetModel, 18 | 'segmentation': UNetModel 19 | } 20 | 21 | 22 | def disabled_train(self, mode=True): 23 | """Overwrite model.train with this function to make sure train/eval mode 24 | does not change anymore.""" 25 | return self 26 | 27 | 28 | class NoisyLatentImageClassifier(pl.LightningModule): 29 | 30 | def __init__(self, 31 | diffusion_path, 32 | num_classes, 33 | ckpt_path=None, 34 | pool='attention', 35 | label_key=None, 36 | diffusion_ckpt_path=None, 37 | scheduler_config=None, 38 | weight_decay=1.e-2, 39 | log_steps=10, 40 | monitor='val/loss', 41 | *args, 42 | **kwargs): 43 | super().__init__(*args, **kwargs) 44 | self.num_classes = num_classes 45 | # get latest config of sr model 46 | diffusion_config = 
natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1] 47 | self.diffusion_config = OmegaConf.load(diffusion_config).model 48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path 49 | self.load_diffusion() 50 | 51 | self.monitor = monitor 52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1 53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps 54 | self.log_steps = log_steps 55 | 56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \ 57 | else self.diffusion_model.cond_stage_key 58 | 59 | assert self.label_key is not None, 'label_key neither in sr model nor in model.params' 60 | 61 | if self.label_key not in __models__: 62 | raise NotImplementedError() 63 | 64 | self.load_classifier(ckpt_path, pool) 65 | 66 | self.scheduler_config = scheduler_config 67 | self.use_scheduler = self.scheduler_config is not None 68 | self.weight_decay = weight_decay 69 | 70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False): 71 | sd = torch.load(path, map_location="cpu") 72 | if "state_dict" in list(sd.keys()): 73 | sd = sd["state_dict"] 74 | keys = list(sd.keys()) 75 | for k in keys: 76 | for ik in ignore_keys: 77 | if k.startswith(ik): 78 | print("Deleting key {} from state_dict.".format(k)) 79 | del sd[k] 80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict( 81 | sd, strict=False) 82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") 83 | if len(missing) > 0: 84 | print(f"Missing Keys: {missing}") 85 | if len(unexpected) > 0: 86 | print(f"Unexpected Keys: {unexpected}") 87 | 88 | def load_diffusion(self): 89 | model = instantiate_from_config(self.diffusion_config) 90 | self.diffusion_model = model.eval() 91 | self.diffusion_model.train = disabled_train 92 | for param in self.diffusion_model.parameters(): 93 | param.requires_grad = False 94 | 95 | def load_classifier(self, ckpt_path, pool): 96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params) 97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels 98 | model_config.out_channels = self.num_classes 99 | if self.label_key == 'class_label': 100 | model_config.pool = pool 101 | 102 | self.model = __models__[self.label_key](**model_config) 103 | if ckpt_path is not None: 104 | print('#####################################################################') 105 | print(f'load from ckpt "{ckpt_path}"') 106 | print('#####################################################################') 107 | self.init_from_ckpt(ckpt_path) 108 | 109 | @torch.no_grad() 110 | def get_x_noisy(self, x, t, noise=None): 111 | noise = default(noise, lambda: torch.randn_like(x)) 112 | continuous_sqrt_alpha_cumprod = None 113 | if self.diffusion_model.use_continuous_noise: 114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1) 115 | # todo: make sure t+1 is correct here 116 | 117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise, 118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod) 119 | 120 | def forward(self, x_noisy, t, *args, **kwargs): 121 | return self.model(x_noisy, t) 122 | 123 | @torch.no_grad() 124 | def get_input(self, batch, k): 125 | x = batch[k] 126 | if len(x.shape) == 3: 127 | x = x[..., None] 128 | x = rearrange(x, 'b h w c -> b c h w') 129 | x = 
x.to(memory_format=torch.contiguous_format).float() 130 | return x 131 | 132 | @torch.no_grad() 133 | def get_conditioning(self, batch, k=None): 134 | if k is None: 135 | k = self.label_key 136 | assert k is not None, 'Needs to provide label key' 137 | 138 | targets = batch[k].to(self.device) 139 | 140 | if self.label_key == 'segmentation': 141 | targets = rearrange(targets, 'b h w c -> b c h w') 142 | for down in range(self.numd): 143 | h, w = targets.shape[-2:] 144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest') 145 | 146 | # targets = rearrange(targets,'b c h w -> b h w c') 147 | 148 | return targets 149 | 150 | def compute_top_k(self, logits, labels, k, reduction="mean"): 151 | _, top_ks = torch.topk(logits, k, dim=1) 152 | if reduction == "mean": 153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item() 154 | elif reduction == "none": 155 | return (top_ks == labels[:, None]).float().sum(dim=-1) 156 | 157 | def on_train_epoch_start(self): 158 | # save some memory 159 | self.diffusion_model.model.to('cpu') 160 | 161 | @torch.no_grad() 162 | def write_logs(self, loss, logits, targets): 163 | log_prefix = 'train' if self.training else 'val' 164 | log = {} 165 | log[f"{log_prefix}/loss"] = loss.mean() 166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k( 167 | logits, targets, k=1, reduction="mean" 168 | ) 169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k( 170 | logits, targets, k=5, reduction="mean" 171 | ) 172 | 173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True) 174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False) 175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True) 176 | lr = self.optimizers().param_groups[0]['lr'] 177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True) 178 | 179 | def shared_step(self, batch, t=None): 180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key) 181 | targets = self.get_conditioning(batch) 182 | if targets.dim() == 4: 183 | targets = targets.argmax(dim=1) 184 | if t is None: 185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long() 186 | else: 187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long() 188 | x_noisy = self.get_x_noisy(x, t) 189 | logits = self(x_noisy, t) 190 | 191 | loss = F.cross_entropy(logits, targets, reduction='none') 192 | 193 | self.write_logs(loss.detach(), logits.detach(), targets.detach()) 194 | 195 | loss = loss.mean() 196 | return loss, logits, x_noisy, targets 197 | 198 | def training_step(self, batch, batch_idx): 199 | loss, *_ = self.shared_step(batch) 200 | return loss 201 | 202 | def reset_noise_accs(self): 203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in 204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)} 205 | 206 | def on_validation_start(self): 207 | self.reset_noise_accs() 208 | 209 | @torch.no_grad() 210 | def validation_step(self, batch, batch_idx): 211 | loss, *_ = self.shared_step(batch) 212 | 213 | for t in self.noisy_acc: 214 | _, logits, _, targets = self.shared_step(batch, t) 215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean')) 216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean')) 217 | 218 | return loss 219 | 220 | def configure_optimizers(self): 221 | optimizer = 
AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay) 222 | 223 | if self.use_scheduler: 224 | scheduler = instantiate_from_config(self.scheduler_config) 225 | 226 | print("Setting up LambdaLR scheduler...") 227 | scheduler = [ 228 | { 229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule), 230 | 'interval': 'step', 231 | 'frequency': 1 232 | }] 233 | return [optimizer], scheduler 234 | 235 | return optimizer 236 | 237 | @torch.no_grad() 238 | def log_images(self, batch, N=8, *args, **kwargs): 239 | log = dict() 240 | x = self.get_input(batch, self.diffusion_model.first_stage_key) 241 | log['inputs'] = x 242 | 243 | y = self.get_conditioning(batch) 244 | 245 | if self.label_key == 'class_label': 246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"]) 247 | log['labels'] = y 248 | 249 | if ismap(y): 250 | log['labels'] = self.diffusion_model.to_rgb(y) 251 | 252 | for step in range(self.log_steps): 253 | current_time = step * self.log_time_interval 254 | 255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time) 256 | 257 | log[f'inputs@t{current_time}'] = x_noisy 258 | 259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes) 260 | pred = rearrange(pred, 'b h w c -> b c h w') 261 | 262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred) 263 | 264 | for key in log: 265 | log[key] = log[key][:N] 266 | 267 | return log 268 | --------------------------------------------------------------------------------