├── requirements.txt
├── cheff
│   ├── sr
│   │   ├── __init__.py
│   │   ├── utils.py
│   │   ├── sampler.py
│   │   ├── schedule.py
│   │   └── model.py
│   ├── ldm
│   │   ├── models
│   │   │   ├── __init__.py
│   │   │   └── diffusion
│   │   │       ├── __init__.py
│   │   │       └── classifier.py
│   │   ├── modules
│   │   │   ├── __init__.py
│   │   │   ├── encoders
│   │   │   │   ├── __init__.py
│   │   │   │   └── modules.py
│   │   │   ├── distributions
│   │   │   │   ├── __init__.py
│   │   │   │   └── distributions.py
│   │   │   ├── diffusionmodules
│   │   │   │   └── __init__.py
│   │   │   ├── image_degradation
│   │   │   │   ├── __init__.py
│   │   │   │   └── utils
│   │   │   │       └── test.png
│   │   │   ├── losses
│   │   │   │   ├── __init__.py
│   │   │   │   ├── contperceptual.py
│   │   │   │   └── vqperceptual.py
│   │   │   ├── ema.py
│   │   │   └── attention.py
│   │   ├── __init__.py
│   │   ├── lr_scheduler.py
│   │   └── util.py
│   ├── __init__.py
│   └── machex.py
├── chexzero
│   ├── util
│   │   ├── __init__.py
│   │   ├── data_utils.py
│   │   ├── prompt_utils.py
│   │   └── plot_grounding.py
│   ├── components
│   │   ├── __init__.py
│   │   ├── pooling.py
│   │   ├── mlp.py
│   │   ├── bbox_prediction.py
│   │   ├── soft_roi_pool.py
│   │   ├── classification_losses.py
│   │   └── bbox_losses.py
│   ├── chexzero
│   │   ├── bpe_simple_vocab_16e6.txt.gz
│   │   ├── LICENSE
│   │   ├── NOTICE.md
│   │   ├── simple_tokenizer.py
│   │   ├── eval.py
│   │   ├── clip.py
│   │   └── README.md
│   ├── txt_encoder
│   │   ├── __init__.py
│   │   └── chexzero_txt_encoder.py
│   ├── img_encoder
│   │   ├── __init__.py
│   │   └── chexzero_img_encoder.py
│   └── main.py
├── .gitmodules
├── .gitignore
├── config.yaml
├── README.md
├── src
│   ├── utils.py
│   ├── data.py
│   └── loss.py
└── models
    ├── clip.py
    ├── classifier.py
    └── cheff.py
/requirements.txt:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cheff/sr/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cheff/ldm/models/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chexzero/util/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cheff/ldm/modules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/chexzero/components/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cheff/ldm/models/diffusion/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cheff/ldm/modules/encoders/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cheff/ldm/modules/distributions/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cheff/ldm/modules/diffusionmodules/__init__.py:
--------------------------------------------------------------------------------
1 |
--------------------------------------------------------------------------------
/cheff/ldm/modules/image_degradation/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/cheff/ldm/__init__.py:
--------------------------------------------------------------------------------
1 | from cheff.ldm.util import instantiate_from_config
--------------------------------------------------------------------------------
/cheff/ldm/modules/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from cheff.ldm.modules.losses.contperceptual import LPIPSWithDiscriminator
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "chexpert"]
2 | path = chexpert
3 | url = https://github.com/jfhealthcare/Chexpert.git
4 |
--------------------------------------------------------------------------------
/chexzero/chexzero/bpe_simple_vocab_16e6.txt.gz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/berkegokmen1/counterfactual-chexray-disease-editing/HEAD/chexzero/chexzero/bpe_simple_vocab_16e6.txt.gz
--------------------------------------------------------------------------------
/cheff/ldm/modules/image_degradation/utils/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/berkegokmen1/counterfactual-chexray-disease-editing/HEAD/cheff/ldm/modules/image_degradation/utils/test.png
--------------------------------------------------------------------------------
/cheff/__init__.py:
--------------------------------------------------------------------------------
1 | from cheff.ldm.inference import (
2 | CheffAEModel,
3 | CheffLDM,
4 | CheffLDMT2I,
5 | CheffLDMT2IEdit,
6 | CheffLDMMaskCond,
7 | CheffLDMImageMaskCond,
8 | CheffLDMImageCond,
9 | )
10 | from cheff.sr.sampler import CheffSRModel
11 |
--------------------------------------------------------------------------------
/chexzero/txt_encoder/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from dataclasses import dataclass
3 | from typing import List
4 |
5 | from torch import BoolTensor, Tensor
6 |
7 | from util.data_utils import TensorDataclassMixin
8 |
9 |
10 | @dataclass
11 | class TextEncoderOutput(TensorDataclassMixin):
12 | # (N x S x d) -> already projected to model space
13 | sentence_features: Tensor
14 | # (N x S)
15 | sentence_mask: BoolTensor
16 | sentences: List[List[str]]
17 | flattened_sentences: List[str]
18 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | __pycache__
3 | *.pyc
4 | *.pyo
5 | *.pyd
6 | .Python
7 | env/
8 | venv/
9 | ENV/
10 | env.bak/
11 | venv.bak/
12 | .idea/
13 | .vscode/
14 | .DS_Store
15 | .ipynb_checkpoints
16 | data/
17 | logs/
18 | outputs/
19 | runs/
20 | *.sqlite3
21 | *.db
22 | *.log
23 | *.pot
24 | *.mo
25 | *.cover
26 | htmlcov/
27 | .tox/
28 | .nox/
29 | .coverage
30 | .coverage.*
31 | .cache
32 | nosetests.xml
33 | coverage.xml
34 | *.coveragerc
35 | *.codecov.yml
36 | *.pylintrc
37 | *.mypy_cache/
38 | .dmypy.json
39 | .pyre/
40 | .pytype/
41 | .pyright/
42 | .pytest_cache/
43 | .ipynb_checkpoints
--------------------------------------------------------------------------------
/chexzero/components/pooling.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | from torch import BoolTensor, nn
3 | import torch
4 | import torch.nn.functional as F
5 |
6 |
7 | class GlobalAvgPool(nn.Module):
8 | def __init__(self):
9 | super(GlobalAvgPool, self).__init__()
10 |
11 | def forward(self, x, mask: Optional[BoolTensor] = None):
12 | N, *dims, d = x.shape
13 | x = x.view(N, -1, d)
14 |
15 | if mask is not None:
16 | mask = mask.view(N, -1).bool()
17 | x = torch.masked_fill(x, ~mask[:, :, None], 0.)
18 | # (N x d)
19 | pooled = x.sum(1) / (mask.sum(1)[:, None] + 1e-7)
20 | else:
21 | pooled = torch.mean(x, dim=1)
22 |
23 | return pooled
24 |
--------------------------------------------------------------------------------
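A minimal usage sketch for `GlobalAvgPool` above (not part of the repository); the shapes follow the in-code comments, and the repository root is assumed to be on `PYTHONPATH`:

```python
import torch
from chexzero.components.pooling import GlobalAvgPool

# (N x H x W x d) patch features plus a boolean validity mask per spatial position.
x = torch.randn(2, 7, 7, 64)
mask = torch.ones(2, 7, 7, dtype=torch.bool)
mask[:, -1, :] = False  # pretend the last row of patches is padding

pool = GlobalAvgPool()
pooled = pool(x, mask=mask)  # masked mean over all spatial positions
print(pooled.shape)          # torch.Size([2, 64])
```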
/chexzero/img_encoder/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from dataclasses import dataclass
3 | from typing import Optional
4 |
5 | from torch import BoolTensor, Tensor
6 |
7 | from util.data_utils import TensorDataclassMixin
8 |
9 |
10 | @dataclass
11 | class ImageEncoderOutput(TensorDataclassMixin):
12 | # (N x H x W x d) -> already projected to model space
13 | patch_features: Tensor
14 | # (N x H x W x d)
15 | pos_embeddings: Tensor
16 | # (N x d)
17 | global_features: Optional[Tensor] = None
18 |
19 | @property
20 | def device(self):
21 | return self.patch_features.device
22 |
23 | @property
24 | def dtype(self):
25 | return self.patch_features.dtype
26 |
27 | @property
28 | def N(self):
29 | return self.patch_features.shape[0]
30 |
31 | @property
32 | def d(self):
33 | return self.patch_features.shape[-1]
34 |
35 |
--------------------------------------------------------------------------------
/chexzero/chexzero/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 Rajpurkar Lab
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/config.yaml:
--------------------------------------------------------------------------------
1 | device: cuda
2 |
3 | model:
4 |
5 | clip:
6 | path: /path/to/chexzero/chex/best_64_5e-05_original_22000_0.864.pt
7 |
8 | cheff:
9 | ae_path: /path/to/chex/chexray-diffusion-ckpts/cheff_autoencoder.pt
10 | ldm_path: /path/to/chex/chexray-diffusion-ckpts/cheff_diff_uncond.pt
11 | load_external: false
12 | external_path: /path/to/chex/chexray-diffusion-ckpts/cheff_diff_uncond.pt
13 |
14 | classifier:
15 | model: resnet152 # attnresnet152
16 | restore: true
17 | path: /path/to/chex/chexpert/exp-resnet-pretrained-256/best_checkpoints/checkpoint_9.pt
18 | resize: 256
19 | lr: 0.0001
20 | pretrained: true
21 |
22 | train:
23 | method: attn
24 |
25 | data:
26 | path: /path/to/chex/CheXpert-v1.0-small
27 | prefix: CheXpert-v1.0-small/
28 | resize: 256
29 | mini_data: 0
30 | batch_size: 4
31 | train_batch_size: 1
32 | test_batch_size: 1
33 | num_workers: 7
34 | mode: train
35 | prefetch_factor: 4
36 |
37 | finetune:
38 | experiment_name: null
39 | target: null
40 | log_dir: ./experiments/${finetune.experiment_name}
41 | num_timesteps: 50
42 | learning_rate: 1e-5
43 | num_epochs: 3
44 | lambdas:
45 | classification_step: 1.5
46 | l1_step: 0.1
47 | anchor_step: 0.01
48 | ###
49 | clip_direction_step: 0.
50 | clip_distance_step: 0.
51 | clip_step: 0
--------------------------------------------------------------------------------
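The `${finetune.experiment_name}` interpolation suggests the config is consumed with OmegaConf; the loader actually used by the repository is not shown in this dump, so the following is only a hedged loading sketch:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.load("config.yaml")
cfg.finetune.experiment_name = "demo"      # fill the interpolated field
cfg.finetune.target = "Pleural Effusion"

print(cfg.model.classifier.model)          # resnet152
print(cfg.finetune.log_dir)                # ./experiments/demo (interpolation resolved on access)
```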
/chexzero/chexzero/NOTICE.md:
--------------------------------------------------------------------------------
1 | # Third-Party Code: CheXZero
2 | Third-party code from https://github.com/rajpurkarlab/CheXzero
3 |
4 | ## Project Licenses
5 | The source code of this repository was derived from CheXzero (https://github.com/rajpurkarlab/CheXzero) and OpenAI CLIP (https://github.com/openai/CLIP).
6 |
7 | ### Open Source License / Copyright Notice
8 | ```
9 | MIT License
10 |
11 | Copyright (c) 2021 OpenAI
12 |
13 | Permission is hereby granted, free of charge, to any person obtaining a copy
14 | of this software and associated documentation files (the "Software"), to deal
15 | in the Software without restriction, including without limitation the rights
16 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
17 | copies of the Software, and to permit persons to whom the Software is
18 | furnished to do so, subject to the following conditions:
19 |
20 | The above copyright notice and this permission notice shall be included in all
21 | copies or substantial portions of the Software.
22 |
23 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
24 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
25 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
26 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
27 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
28 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
29 | SOFTWARE.
30 | ```
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Counterfactual Disease Removal and Generation in Chest X-Rays Using Diffusion Models
2 |
3 | [Paper](https://berkegokmen1.github.io/counterfactual-disease-removal-and-generation-chest-xray/static/paper.pdf) | [Project Website](https://berkegokmen1.github.io/counterfactual-disease-removal-and-generation-chest-xray/) | [BibTeX](#bibtex)
4 |
5 | ## Authors
6 | [Ahmet Berke Gökmen](https://berkegokmen1.github.io/), [Ender Konukoglu](https://people.ee.ethz.ch/~kender/)
7 |
8 | 
9 |
10 | ## TODO
11 | - [X] Release Website
12 | - [X] Release Code
13 | - [ ] Run Instructions
14 |
15 | ## Setup
16 |
17 | ```bash
18 | conda create -n chexray-editing python=3.10
19 | pip install -r requirements.txt [TODO]
20 | ```
21 |
22 | ## Inference
23 | Please download the CheXzero, CheXpert and chexray-diffusion checkpoints from their respective repositories and update the paths in `config.yaml`.
24 | 
25 | In addition to the checkpoints, you'll need to download the `CheXpert-v1.0-small` dataset from the official CheXpert website, or you may use any chest X-ray image.
26 | ```bash
27 | python finetune_sample.py --config config.yaml --target "Pleural Effusion" --mode "removal" --experiment_name "demo"
28 | ```
29 |
30 | ## Questions
31 |
32 | You may reach me through [LinkedIn](https://www.linkedin.com/in/berkegokmen/).
33 |
34 | ## This work would not have been possible without:
35 | - https://github.com/rajpurkarlab/CheXzero
36 | - https://github.com/jfhealthcare/Chexpert
37 | - https://github.com/saiboxx/chexray-diffusion
38 |
39 | ## BibTeX
40 | ```
41 | ```
42 |
43 |
44 |
--------------------------------------------------------------------------------
/chexzero/main.py:
--------------------------------------------------------------------------------
1 | import chexzero.clip
2 | import torch
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from chexzero.clip import tokenize
6 | from PIL import Image
7 |
8 |
9 | model, preprocess = chexzero.clip.load("ViT-B/32", device="cpu", jit=False)
10 | model.load_state_dict(
11 | torch.load("/cluster/work/cvl/agoekmen/chexzero/best_64_5e-05_original_22000_0.864.pt", map_location="cpu")
12 | )
13 |
14 |
15 | def encode_text(text):
16 | tokenized_text = tokenize(text)
17 | text_features = model.encode_text(tokenized_text)
18 | return text_features
19 |
20 |
21 | def encode_image(image: torch.Tensor):
22 | image_features = model.encode_image(image)
23 | return image_features
24 |
25 |
26 | def calculate_scores(text_features: torch.Tensor, image_features: torch.Tensor):
27 | # normalized features
28 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
29 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
30 |
31 | # cosine similarity as logits
32 | logit_scale = model.logit_scale.exp()
33 | logits_per_image = logit_scale * image_features @ text_features.t()
34 | logits_per_text = logit_scale * text_features @ image_features.t()
35 |
36 | # shape = [global_batch_size, global_batch_size]
37 | return logits_per_image, logits_per_text
38 |
39 |
40 | if __name__ == "__main__":
41 | text = ["Pleural Effusion", "No Pleural Effusion"]
42 | image = Image.open(
43 | "/cluster/home/agoekmen/projects/chexray-diffusion/removal-editing/all_exp/working_experiments/__attn/original.png"
44 | )
45 | image = preprocess(image).unsqueeze(0)
46 | text_features = encode_text(text)
47 | image_features = encode_image(image)
48 |
49 | print(text_features.shape)
50 | print(image_features.shape)
51 |
52 | logits_per_image, logits_per_text = calculate_scores(text_features, image_features)
53 | print(logits_per_image)
54 | print(logits_per_text)
55 |
56 | probs = logits_per_image.softmax(dim=1)
57 | print(probs)
58 |
--------------------------------------------------------------------------------
/chexzero/components/mlp.py:
--------------------------------------------------------------------------------
1 | from typing import Optional
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from util.model_utils import get_activation
7 |
8 | class MLP(nn.Module):
9 | def __init__(
10 | self,
11 | n_layers: int,
12 | d_in: int, d_out: Optional[int] = None,
13 | use_bn=False, dropout=0.0, dropout_last_layer=True, act=nn.ReLU, d_hidden_factor: int = 4, d_hidden: Optional[int] = None):
14 | """
15 |         :param n_layers: 0 = identity (requires d_in == d_out), 1 = a single linear layer, >1 = hidden layers followed by a linear output layer
16 |         :param d_in: input feature dimension
17 |         :param d_hidden: hidden dimension; defaults to d_hidden_factor * d_out
18 |         :param d_out: output feature dimension; defaults to d_in
19 |         :param use_bn: insert BatchNorm1d after each hidden linear layer
20 | """
21 | super(MLP, self).__init__()
22 | if act is None:
23 | act = nn.ReLU
24 | act = get_activation(act)
25 |
26 | if d_out is None:
27 | d_out = d_in
28 | if d_hidden is None:
29 | d_hidden = d_hidden_factor * d_out
30 | assert n_layers >= 0
31 | if n_layers == 0:
32 | assert d_in == d_out, f'If n_layers == 0, then d_in == d_out, but got {d_in} != {d_out}'
33 | self.layers = nn.Identity()
34 | else:
35 | current_dim_in = d_in
36 | layers = []
37 | for _ in range(n_layers - 1):
38 | layers.append(nn.Linear(current_dim_in, d_hidden, bias=not use_bn))
39 | if use_bn:
40 | layers.append(nn.BatchNorm1d(d_hidden))
41 | layers.append(act)
42 | if dropout > 0.0:
43 | layers.append(nn.Dropout(dropout))
44 | current_dim_in = d_hidden
45 | layers.append(nn.Linear(current_dim_in, d_out))
46 | if dropout_last_layer and dropout > 0.0:
47 | layers.append(nn.Dropout(dropout))
48 | self.layers = nn.Sequential(*layers)
49 |
50 | def forward(self, x: torch.Tensor) -> torch.Tensor:
51 | *dims, d = x.shape
52 | if len(dims) > 1:
53 | x = x.reshape(-1, d)
54 |
55 | x = self.layers(x)
56 | return x.view(*dims, -1)
57 |
--------------------------------------------------------------------------------
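A usage sketch for `MLP` (not part of the repository). `util.model_utils.get_activation` is not included in this dump and is assumed to turn the given activation class into a module instance; both the repository root and the `chexzero/` directory are assumed to be on `PYTHONPATH`, since the modules mix both import styles:

```python
import torch
from chexzero.components.mlp import MLP

# 2 layers: one hidden block (Linear -> ReLU -> Dropout) followed by a linear output layer.
mlp = MLP(n_layers=2, d_in=512, d_out=128, dropout=0.1)

tokens = torch.randn(4, 32, 512)  # (N x S x d_in); leading dims are flattened internally
out = mlp(tokens)
print(out.shape)                  # torch.Size([4, 32, 128])
```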
/cheff/sr/utils.py:
--------------------------------------------------------------------------------
1 | """Miscellaneous classes and functions."""
2 | from typing import Any, Optional
3 |
4 | import numpy as np
5 | from matplotlib import animation
6 | from matplotlib import pyplot as plt
7 | from torch import Tensor
8 | from torchvision.transforms import (
9 | Compose,
10 | Lambda,
11 | ToPILImage,
12 | )
13 | from torchvision.utils import make_grid
14 |
15 |
16 | def transform_tensor_to_img() -> Compose:
17 | """Transform a tensor with a single element to a PIL image."""
18 | return Compose(
19 | [
20 | Lambda(lambda t: t.detach().cpu()),
21 | Lambda(lambda t: (t + 1) / 2),
22 | Lambda(lambda t: t.permute(1, 2, 0)),
23 | Lambda(lambda t: t * 255.0),
24 | Lambda(lambda t: t.numpy().astype(np.uint8)),
25 | ToPILImage(),
26 | ]
27 | )
28 |
29 |
30 | def plot_image(
31 | img: Tensor,
32 | fig_size: Any = None,
33 | ncols: Optional[int] = None,
34 | show: bool = True,
35 | save_path: Optional[str] = None,
36 | ) -> None:
37 | """Plot a tensor containing image data."""
38 | img = img.detach().cpu()
39 |
40 | # Shape of 4 implies multiple image inputs
41 | if len(img.shape) == 4:
42 | img = make_grid(img, nrow=ncols if ncols is not None else len(img))
43 |
44 | plt.figure(figsize=fig_size)
45 | plt.imshow(img.permute(1, 2, 0))
46 | plt.axis('off')
47 |
48 | if save_path is not None:
49 | plt.savefig(save_path, bbox_inches='tight')
50 |
51 | if show:
52 | plt.show()
53 | plt.close()
54 |
55 |
56 | def make_gif(
57 | img_arr: Tensor,
58 | save_path: str,
59 | ) -> None:
60 | """Create a GIF with the output of DiffusionController.generate()."""
61 | assert len(img_arr) == 5, 'Array has wrong shape.'
62 |
63 | img_arr = img_arr.detach().cpu()
64 |
65 | fig = plt.figure(frameon=False)
66 | ims = []
67 |
68 | for img_t in img_arr:
69 | grid = make_grid(img_t, nrow=img_t.shape[0] // 2)
70 | im = plt.imshow(grid.permute(1, 2, 0), animated=True)
71 | plt.axis('off')
72 | plt.tight_layout()
73 | ims.append([im])
74 |
75 | fig.tight_layout()
76 |
77 | animate = animation.ArtistAnimation(
78 | fig, ims, interval=100, blit=True, repeat_delay=2000
79 | )
80 | animate.save(save_path)
81 |
--------------------------------------------------------------------------------
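A short sketch of the plotting helpers above (not part of the repository; it assumes the full `cheff` package imports cleanly, as its `__init__.py` pulls in modules not shown in this dump), using a random image in [-1, 1] as the diffusion models produce:

```python
import torch
from cheff.sr.utils import plot_image, transform_tensor_to_img

img = torch.rand(1, 1, 256, 256) * 2 - 1  # fake single-channel image in [-1, 1]

# Save a preview without opening a window.
plot_image(img, fig_size=(4, 4), show=False, save_path="preview.png")

# Convert a single (C x H x W) tensor back to a PIL image.
pil_img = transform_tensor_to_img()(img[0])
pil_img.save("preview_pil.png")
```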
/src/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import gc
4 | import torch
5 | import numpy as np
6 | import random
7 | import torchvision.transforms as T
8 | import torch.nn.functional as F
9 | import yaml
10 | import shutil
11 |
12 |
13 | def set_seed(seed):
14 | torch.manual_seed(seed)
15 | torch.cuda.manual_seed_all(seed)
16 | np.random.seed(seed)
17 | random.seed(seed)
18 | torch.backends.cudnn.deterministic = True
19 | torch.backends.cudnn.benchmark = False
20 |
21 |
22 | def set_device(args):
23 | args.device = torch.device("cuda" if torch.cuda.is_available() and args.device == "cuda" else "cpu")
24 | print(f"Using device: {args.device}")
25 |
26 |
27 | def clear_mem(verbose=True):
28 | res = gc.collect(), torch.cuda.empty_cache()
29 | if verbose:
30 | print(f"Cleared memory: {res}")
31 | return res
32 |
33 |
34 | def create_experiment_folder(args):
35 | path = os.path.join(args.experiment.path, args.experiment.prefix + time.ctime().replace(" ", "_").replace(":", "-"))
36 | os.makedirs(path, exist_ok=True)
37 | args.out_dir = path
38 | print(f"Created experiment folder: {path}")
39 | return path
40 |
41 |
42 | def create_denormalize():
43 | mean = [0.5330] # CheXpert mean
44 | std = [0.0349] # CheXpert std
45 | denorm = T.Normalize(mean=[-m / s for m, s in zip(mean, std)], std=[1 / s for s in std])
46 |
47 | return denorm
48 |
49 |
50 | def create_normalize():
51 | mean = [0.5330] # CheXpert mean
52 | std = [0.0349] # CheXpert std
53 | t_norm = T.Normalize(mean=mean, std=std)
54 | return t_norm
55 |
56 |
57 | def clip_normalize():
58 | t = T.Compose([T.Resize((512, 512)), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
59 |
60 | return t
61 |
62 |
63 | def create_blur(kernel_size=5, sigma=(2.0, 2.0)):
64 | return T.GaussianBlur(kernel_size, sigma)
65 |
66 |
67 | def create_mask_dilate():
68 | blur = create_blur()
69 |
70 | dilate = lambda mask, kernel_size: blur(
71 | F.max_pool2d(mask, kernel_size=kernel_size, stride=1, padding=(kernel_size - 1) // 2)
72 | ).view(mask.shape)
73 |
74 | return dilate
75 |
76 |
77 | def save_config_yaml(args):
78 | shutil.copyfile(args.config_path, os.path.join(args.out_dir, "config.yaml"))
79 |
80 |
81 | def save_txt(
82 | path,
83 | name,
84 | text,
85 | ):
86 | if not isinstance(text, list):
87 | text = [text]
88 |
89 | with open(os.path.join(path, f"{name}.txt"), "w") as f:
90 | f.write("\n".join(text))
91 |
--------------------------------------------------------------------------------
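A sketch of the normalization and mask-dilation helpers above (not part of the repository; assumes the repository root is on `PYTHONPATH` so `src` is importable):

```python
import torch
from src.utils import create_normalize, create_denormalize, create_mask_dilate

norm, denorm = create_normalize(), create_denormalize()
x = torch.rand(1, 1, 256, 256)                        # single-channel CheXpert-style image
print(torch.allclose(denorm(norm(x)), x, atol=1e-4))  # True: denormalize inverts normalize

dilate = create_mask_dilate()
mask = torch.zeros(1, 1, 256, 256)
mask[:, :, 100:140, 100:140] = 1.0
soft_mask = dilate(mask, 15)                          # grow the region, then blur its border
print(soft_mask.shape)                                # torch.Size([1, 1, 256, 256])
```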
/chexzero/components/bbox_prediction.py:
--------------------------------------------------------------------------------
1 | from typing import List
2 | import einops
3 | import torch
4 | from torchvision.ops import batched_nms, box_convert
5 | from ensemble_boxes import weighted_boxes_fusion
6 |
7 |
8 | def clip_bboxes(box_params):
9 | box_params = box_convert(box_params, 'cxcywh', 'xyxy')
10 | box_params = box_params.clamp(0., 1.)
11 | box_params = box_convert(box_params, 'xyxy', 'cxcywh')
12 | return box_params
13 |
14 |
15 | def apply_top1_filtering(boxes: List[torch.Tensor]) -> List[torch.Tensor]:
16 | filtered_boxes = []
17 | for sample_boxes in boxes:
18 | labels = sample_boxes[:, 4]
19 | scores = sample_boxes[:, 5]
20 | unique_classes = torch.unique(labels)
21 | keep_inds = torch.stack([
22 | (torch.where(labels == cls, 1, 0) * scores).argmax()
23 | for cls in unique_classes
24 | ]) if len(unique_classes) > 0 else torch.zeros(0, dtype=torch.long)
25 | filtered_boxes.append(sample_boxes[keep_inds])
26 | return filtered_boxes
27 |
28 |
29 | def apply_top1_with_box_fusion(boxes: List[torch.Tensor]) -> List[torch.Tensor]:
30 | filtered_boxes = []
31 | for sample_boxes in boxes:
32 | boxes = sample_boxes[:, :4]
33 | boxes_upper_left = boxes[:, :2] - boxes[:, 2:] / 2
34 | boxes_lower_right = boxes[:, :2] + boxes[:, 2:] / 2
35 | labels = sample_boxes[:, 4]
36 | scores = sample_boxes[:, 5]
37 | unique_classes = torch.unique(labels)
38 | filtered_sample_boxes = []
39 | for c in unique_classes:
40 | cls_scores = scores[labels == c]
41 | if len(cls_scores) == 0:
42 | continue
43 | cls_boxes_upper_left = boxes_upper_left[labels == c]
44 | cls_boxes_lower_right = boxes_lower_right[labels == c]
45 | cls_fused_boxes = torch.cat([cls_boxes_upper_left.amin(dim=0), cls_boxes_lower_right.amax(dim=0)], dim=-1)
46 | wh = cls_fused_boxes[2:] - cls_fused_boxes[:2]
47 | cls_fused_boxes = torch.cat([cls_fused_boxes[:2] + wh / 2, wh], dim=-1)
48 | cls_top1_scores = cls_scores.amax()
49 | filtered_sample_boxes.append(torch.cat([cls_fused_boxes, c.unsqueeze(-1), cls_top1_scores.unsqueeze(-1)], dim=-1))
50 | filtered_boxes.append(torch.stack(filtered_sample_boxes) if len(filtered_sample_boxes) > 0 else torch.zeros(0, 6))
51 | return filtered_boxes
52 |
53 |
54 | def apply_nms(predicted_boxes: List[torch.Tensor], iou_threshold: float):
55 | predicted_boxes_after_nms = []
56 | for sample_boxes in predicted_boxes:
57 | boxes_coords = box_convert(sample_boxes[:, 0:4], in_fmt='cxcywh', out_fmt='xyxy')
58 | cls_idxs = sample_boxes[:, 4]
59 | scores = sample_boxes[:, 5]
60 | nms_indices = batched_nms(boxes_coords, scores, cls_idxs, iou_threshold=iou_threshold)
61 | predicted_boxes_after_nms.append(sample_boxes[nms_indices, :])
62 | return predicted_boxes_after_nms
63 |
64 |
--------------------------------------------------------------------------------
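A small sketch of the NMS and top-1 filtering helpers (not part of the repository). Each box row is `(cx, cy, w, h, class_id, score)` as read by the functions above; note the module also imports `einops` and `ensemble_boxes`, which therefore need to be installed:

```python
import torch
from chexzero.components.bbox_prediction import apply_nms, apply_top1_filtering

sample_boxes = torch.tensor([
    [0.50, 0.50, 0.20, 0.20, 0.0, 0.90],
    [0.52, 0.51, 0.22, 0.18, 0.0, 0.60],  # heavy overlap with the first class-0 box
    [0.20, 0.30, 0.10, 0.10, 1.0, 0.80],
])

kept = apply_nms([sample_boxes], iou_threshold=0.5)
print(kept[0].shape[0])   # 2: the overlapping class-0 box is suppressed

top1 = apply_top1_filtering([sample_boxes])
print(top1[0].shape[0])   # 2: one (highest-scoring) box kept per class
```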
/cheff/ldm/modules/ema.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class LitEma(nn.Module):
6 | def __init__(self, model, decay=0.9999, use_num_upates=True):
7 | super().__init__()
8 | if decay < 0.0 or decay > 1.0:
9 | raise ValueError('Decay must be between 0 and 1')
10 |
11 | self.m_name2s_name = {}
12 | self.register_buffer('decay', torch.tensor(decay, dtype=torch.float32))
13 | self.register_buffer('num_updates', torch.tensor(0,dtype=torch.int) if use_num_upates
14 | else torch.tensor(-1,dtype=torch.int))
15 |
16 | for name, p in model.named_parameters():
17 | if p.requires_grad:
18 | #remove as '.'-character is not allowed in buffers
19 | s_name = name.replace('.','')
20 | self.m_name2s_name.update({name:s_name})
21 | self.register_buffer(s_name,p.clone().detach().data)
22 |
23 | self.collected_params = []
24 |
25 | def forward(self,model):
26 | decay = self.decay
27 |
28 | if self.num_updates >= 0:
29 | self.num_updates += 1
30 | decay = min(self.decay,(1 + self.num_updates) / (10 + self.num_updates))
31 |
32 | one_minus_decay = 1.0 - decay
33 |
34 | with torch.no_grad():
35 | m_param = dict(model.named_parameters())
36 | shadow_params = dict(self.named_buffers())
37 |
38 | for key in m_param:
39 | if m_param[key].requires_grad:
40 | sname = self.m_name2s_name[key]
41 | shadow_params[sname] = shadow_params[sname].type_as(m_param[key])
42 | shadow_params[sname].sub_(one_minus_decay * (shadow_params[sname] - m_param[key]))
43 | else:
44 | assert not key in self.m_name2s_name
45 |
46 | def copy_to(self, model):
47 | m_param = dict(model.named_parameters())
48 | shadow_params = dict(self.named_buffers())
49 | for key in m_param:
50 | if m_param[key].requires_grad:
51 | m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
52 | else:
53 | assert not key in self.m_name2s_name
54 |
55 | def store(self, parameters):
56 | """
57 | Save the current parameters for restoring later.
58 | Args:
59 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
60 | temporarily stored.
61 | """
62 | self.collected_params = [param.clone() for param in parameters]
63 |
64 | def restore(self, parameters):
65 | """
66 | Restore the parameters stored with the `store` method.
67 | Useful to validate the model with EMA parameters without affecting the
68 | original optimization process. Store the parameters before the
69 | `copy_to` method. After validation (or model saving), use this to
70 | restore the former parameters.
71 | Args:
72 | parameters: Iterable of `torch.nn.Parameter`; the parameters to be
73 | updated with the stored parameters.
74 | """
75 | for c_param, param in zip(self.collected_params, parameters):
76 | param.data.copy_(c_param.data)
77 |
--------------------------------------------------------------------------------
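A usage sketch for `LitEma` (not part of the repository; assumes the full `cheff` package, including `cheff.ldm.inference` referenced by its `__init__.py` but not shown in this dump, is importable): keep shadow weights during training, swap them in for evaluation, then restore the raw weights.

```python
import torch
from torch import nn
from cheff.ldm.modules.ema import LitEma

model = nn.Linear(4, 4)
ema = LitEma(model, decay=0.999)

# Training-loop sketch: update the shadow weights after each parameter update.
for _ in range(10):
    with torch.no_grad():
        model.weight.add_(0.01 * torch.randn_like(model.weight))
    ema(model)

ema.store(model.parameters())    # stash the raw weights
ema.copy_to(model)               # evaluate with EMA weights here
ema.restore(model.parameters())  # back to the raw training weights
```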
/cheff/ldm/modules/distributions/distributions.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 |
4 |
5 | class AbstractDistribution:
6 | def sample(self):
7 | raise NotImplementedError()
8 |
9 | def mode(self):
10 | raise NotImplementedError()
11 |
12 |
13 | class DiracDistribution(AbstractDistribution):
14 | def __init__(self, value):
15 | self.value = value
16 |
17 | def sample(self):
18 | return self.value
19 |
20 | def mode(self):
21 | return self.value
22 |
23 |
24 | class DiagonalGaussianDistribution(object):
25 | def __init__(self, parameters, deterministic=False):
26 | self.parameters = parameters
27 | self.mean, self.logvar = torch.chunk(parameters, 2, dim=1)
28 | self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
29 | self.deterministic = deterministic
30 | self.std = torch.exp(0.5 * self.logvar)
31 | self.var = torch.exp(self.logvar)
32 | if self.deterministic:
33 | self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
34 |
35 | def sample(self):
36 | x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
37 | return x
38 |
39 | def kl(self, other=None):
40 | if self.deterministic:
41 | return torch.Tensor([0.])
42 | else:
43 | if other is None:
44 | return 0.5 * torch.sum(torch.pow(self.mean, 2)
45 | + self.var - 1.0 - self.logvar,
46 | dim=[1, 2, 3])
47 | else:
48 | return 0.5 * torch.sum(
49 | torch.pow(self.mean - other.mean, 2) / other.var
50 | + self.var / other.var - 1.0 - self.logvar + other.logvar,
51 | dim=[1, 2, 3])
52 |
53 | def nll(self, sample, dims=[1,2,3]):
54 | if self.deterministic:
55 | return torch.Tensor([0.])
56 | logtwopi = np.log(2.0 * np.pi)
57 | return 0.5 * torch.sum(
58 | logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
59 | dim=dims)
60 |
61 | def mode(self):
62 | return self.mean
63 |
64 |
65 | def normal_kl(mean1, logvar1, mean2, logvar2):
66 | """
67 | source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12
68 | Compute the KL divergence between two gaussians.
69 | Shapes are automatically broadcasted, so batches can be compared to
70 | scalars, among other use cases.
71 | """
72 | tensor = None
73 | for obj in (mean1, logvar1, mean2, logvar2):
74 | if isinstance(obj, torch.Tensor):
75 | tensor = obj
76 | break
77 | assert tensor is not None, "at least one argument must be a Tensor"
78 |
79 | # Force variances to be Tensors. Broadcasting helps convert scalars to
80 | # Tensors, but it does not work for torch.exp().
81 | logvar1, logvar2 = [
82 | x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor)
83 | for x in (logvar1, logvar2)
84 | ]
85 |
86 | return 0.5 * (
87 | -1.0
88 | + logvar2
89 | - logvar1
90 | + torch.exp(logvar1 - logvar2)
91 | + ((mean1 - mean2) ** 2) * torch.exp(-logvar2)
92 | )
93 |
--------------------------------------------------------------------------------
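A sketch of `DiagonalGaussianDistribution` (not part of the repository): the autoencoder emits mean and log-variance stacked along the channel axis, which this class splits apart.

```python
import torch
from cheff.ldm.modules.distributions.distributions import DiagonalGaussianDistribution

moments = torch.randn(2, 8, 16, 16)   # 4 mean channels + 4 log-variance channels
posterior = DiagonalGaussianDistribution(moments)

z = posterior.sample()                # (2, 4, 16, 16) latent sample
kl = posterior.kl()                   # per-sample KL to a standard normal, shape (2,)
nll = posterior.nll(z)                # per-sample negative log-likelihood, shape (2,)
print(z.shape, kl.shape, nll.shape)
```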
/chexzero/util/data_utils.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import Any, Callable, Mapping, Sequence
3 | from PIL import Image
4 | import torch
5 |
6 | def load_pil_gray(path: str) -> Image.Image:
7 | return Image.open(path).convert('L')
8 |
9 |
10 | class TensorDataclassMixin:
11 | def __init__(self):
12 | super(TensorDataclassMixin, self).__init__()
13 | assert dataclasses.is_dataclass(self), f'{type(self)} has to be a dataclass to use TensorDataclassMixin'
14 |
15 | def apply(self, tensor_fn: Callable[[torch.Tensor], torch.Tensor], ignore=None, apply_to_list=False):
16 | if ignore is None and hasattr(self, 'IGNORE_APPLY'):
17 | ignore = self.IGNORE_APPLY
18 | def apply_to_value(value):
19 | if value is None:
20 | return None
21 | elif isinstance(value, torch.Tensor):
22 | return tensor_fn(value)
23 | elif isinstance(value, list):
24 | if apply_to_list:
25 | return tensor_fn(value)
26 | else:
27 | return [apply_to_value(el) for el in value]
28 | elif isinstance(value, tuple):
29 | return tuple(apply_to_value(el) for el in value)
30 | elif isinstance(value, dict):
31 | return {key: apply_to_value(el) for key, el in value.items()}
32 | elif isinstance(value, TensorDataclassMixin):
33 | return value.apply(tensor_fn)
34 | else:
35 | return value
36 |
37 | def apply_to_field(field: dataclasses.Field):
38 | value = getattr(self, field.name)
39 | if ignore is not None and field.name in ignore:
40 | return value
41 | else:
42 | try:
43 | return apply_to_value(value)
44 | except Exception as e:
45 | raise RuntimeError(f'Error while applying {tensor_fn} to {field.name} ({value})') from e
46 |
47 | return self.__class__(**{field.name: apply_to_field(field) for field in dataclasses.fields(self)})
48 |
49 | def to(self, device, *args, non_blocking=True, **kwargs):
50 | return self.apply(lambda x: x.to(device, *args, non_blocking=non_blocking, **kwargs))
51 |
52 | def view(self, *args):
53 | return self.apply(lambda x: x.view(*args))
54 |
55 | def detach(self):
56 | return self.apply(lambda x: x.detach())
57 |
58 | def unsqueeze(self, dim):
59 | return self.apply(lambda x: x.unsqueeze(dim))
60 |
61 | def squeeze(self, dim):
62 | return self.apply(lambda x: x.squeeze(dim))
63 |
64 | def __getitem__(self, *args):
65 | return self.apply(lambda x: x.__getitem__(*args), apply_to_list=True)
66 |
67 | def to_dict(self):
68 | return dataclasses.asdict(self)
69 |
70 |
71 | def to_device(data: Any, device: str, non_blocking=True):
72 | if data is None:
73 | return None
74 | if isinstance(data, torch.Tensor):
75 | if device == 'cpu':
76 | non_blocking = False
77 | return data.to(device, non_blocking=non_blocking)
78 | elif isinstance(data, Mapping):
79 | return {key: to_device(data[key], device, non_blocking=non_blocking) for key in data}
80 | elif isinstance(data, Sequence) and not isinstance(data, str):
81 | return [to_device(d, device, non_blocking=non_blocking) for d in data]
82 | elif isinstance(data, str):
83 | return data
84 | elif isinstance(data, TensorDataclassMixin):
85 | if device == 'cpu':
86 | non_blocking = False
87 | return data.to(device=device, non_blocking=non_blocking)
88 | else:
89 | raise TypeError(type(data))
90 |
--------------------------------------------------------------------------------
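A sketch showing how `TensorDataclassMixin` is meant to be combined with a dataclass (not part of the repository; assumes the `chexzero/` directory is on `PYTHONPATH` so `util` resolves, matching the import style used above):

```python
from dataclasses import dataclass

import torch
from util.data_utils import TensorDataclassMixin, to_device


@dataclass
class Batch(TensorDataclassMixin):
    pixels: torch.Tensor
    labels: torch.Tensor


batch = Batch(pixels=torch.rand(2, 1, 32, 32), labels=torch.tensor([0, 1]))
print(batch.unsqueeze(0).pixels.shape)  # torch.Size([1, 2, 1, 32, 32]) -- applied field-wise

first = batch[0]                        # __getitem__ is forwarded to every tensor field
print(first.labels)                     # tensor(0)

cpu_batch = to_device(batch, "cpu")     # recurses through dataclasses, dicts and lists
```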
/cheff/sr/sampler.py:
--------------------------------------------------------------------------------
1 | import os
2 | from typing import Union, Tuple
3 |
4 | from PIL import Image
5 | import torch
6 | from torch import Tensor
7 | from torch.utils.data import Dataset, DataLoader
8 | from torchvision.transforms import InterpolationMode
9 | from torchvision.transforms.functional import resize, to_tensor, to_grayscale
10 | from torchvision.utils import save_image
11 |
12 | from cheff.sr.model import Unet
13 | from cheff.sr.schedule import ScheduleFactory
14 | from cheff.sr.diffusor import SR3Diffusor, SR3DDIMDiffusor
15 |
16 |
17 | class CheffSRModel:
18 | def __init__(
19 | self,
20 | model_path: str,
21 | device: Union[str, int, torch.device] = 'cuda'
22 | ) -> None:
23 | self.device = device
24 | self.model = Unet(
25 | dim=16,
26 | channels=2,
27 | out_dim=1,
28 | dim_mults=(1, 2, 4, 8, 16, 32, 32, 32),
29 | )
30 |
31 | state_dict = torch.load(model_path, map_location='cpu')
32 | self.model.load_state_dict(state_dict['model'])
33 | self.model.to(self.device)
34 | self.model.eval()
35 |
36 | self.schedule = ScheduleFactory.get_schedule(
37 | name='cosine', timesteps=2000, device=self.device)
38 |
39 | def sample_directory(
40 | self,
41 | source_dir: str,
42 | target_dir: str,
43 | batch_size: int = 1,
44 | method: str = 'ddim',
45 | sampling_steps: int = 100,
46 | eta: float = 0.
47 | ) -> None:
48 | ds = DirectoryDataset(source_dir)
49 | loader = DataLoader(ds, batch_size=batch_size, pin_memory=True)
50 |
51 | os.makedirs(target_dir, exist_ok=True)
52 |
53 | for f_names, imgs in loader:
54 | imgs_sr = self.sample(imgs, method, sampling_steps, eta)
55 |
56 | for f_name, img_sr in zip(f_names, imgs_sr):
57 | path = os.path.join(target_dir, f_name)
58 | save_image(img_sr, path)
59 |
60 | def sample_path(
61 | self,
62 | path: str,
63 | method: str = 'ddim',
64 | sampling_steps: int = 100,
65 | eta: float = 0.
66 | ) -> Tensor:
67 | img = Image.open(path)
68 | img = to_tensor(to_grayscale(img)).unsqueeze(0)
69 | return self.sample(img, method, sampling_steps, eta)
70 |
71 | @torch.no_grad()
72 | def sample(
73 | self,
74 | img: Tensor,
75 | method: str = 'ddim',
76 | sampling_steps: int = 100,
77 | eta: float = 0.
78 | ) -> Tensor:
79 | img = img.to(self.device)
80 | img = img * 2 - 1
81 | img = resize(img, [1024, 1024], InterpolationMode.BICUBIC)
82 |
83 | if method == 'ddim':
84 | diffusor = SR3DDIMDiffusor(
85 | model=self.model,
86 | schedule=self.schedule,
87 | sampling_steps=sampling_steps,
88 | eta=eta
89 | )
90 | else:
91 | diffusor = SR3Diffusor(
92 | model=self.model,
93 | schedule=self.schedule,
94 | )
95 |
96 | img_sr = diffusor.p_sample_loop(sr=img)
97 | img_sr.clamp_(-1, 1)
98 | img_sr = (img_sr + 1) / 2
99 | return img_sr
100 |
101 |
102 | class DirectoryDataset(Dataset):
103 | def __init__(self, root: str) -> None:
104 | self.root = root
105 | self.files = os.listdir(root)
106 |
107 | def __len__(self):
108 | return len(self.files)
109 |
110 | def __getitem__(self, idx: int) -> Tuple[str, Tensor]:
111 | fp = os.path.join(self.root, self.files[idx])
112 |
113 | img = Image.open(fp)
114 | img = to_tensor(to_grayscale(img))
115 | return self.files[idx], img
116 |
--------------------------------------------------------------------------------
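A hedged usage sketch for `CheffSRModel` (not part of the repository); the checkpoint path is a placeholder for the chexray-diffusion super-resolution weights referenced in the README and `config.yaml`:

```python
import torch
from cheff.sr.sampler import CheffSRModel

sr_model = CheffSRModel(
    model_path="/path/to/chexray-diffusion-ckpts/cheff_sr.pt",  # placeholder path
    device="cuda" if torch.cuda.is_available() else "cpu",
)

# Upscale a single chest X-ray with 100 DDIM steps.
img_sr = sr_model.sample_path("original.png", method="ddim", sampling_steps=100)
print(img_sr.shape)  # expected (1, 1, 1024, 1024), values in [0, 1]
```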
/chexzero/util/prompt_utils.py:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | from itertools import chain
5 | from typing import List
6 |
7 | import torch
8 |
9 |
10 | def flatten_prompts(prompts: List[List[str]], device):
11 | flattened_prompts: List[str] = [prompt for sub_prompts in prompts if sub_prompts is not None for prompt in sub_prompts]
12 | N = len(prompts)
13 | M_is: List[int] = [len(sub_prompts) if sub_prompts is not None else 0 for sub_prompts in prompts]
14 | # (N x max(M_i))
15 | prompt_mask = torch.zeros(N, max(M_is) if len(M_is) > 0 else 0, dtype=torch.bool, device=device)
16 | for i, M_i in enumerate(M_is):
17 | prompt_mask[i, :M_i] = True
18 | return flattened_prompts, prompt_mask
19 |
20 | def flatten_prompts_2(prompts: List[List[List[str]]], device):
21 | flattened_prompts: List[str] = [prompt for sub_prompts in prompts for sub_sub_prompts in sub_prompts for prompt in sub_sub_prompts]
22 | N = len(prompts)
23 | M_is: List[int] = [len(sub_prompts) for sub_prompts in prompts]
24 | L_is_ms: List[List[int]] = [[len(sub_sub_prompts) for sub_sub_prompts in sub_prompts]
25 | for sub_prompts in prompts]
26 | max_L = max(max(L_i_m) for L_i_m in L_is_ms)
27 | prompt_mask = torch.zeros(N, max(M_is), max_L, dtype=torch.bool, device=device)
28 | for i, L_i_ms in enumerate(L_is_ms):
29 | for m, L_i_m in enumerate(L_i_ms):
30 | prompt_mask[i, m, :L_i_m] = True
31 |
32 | return flattened_prompts, prompt_mask
33 |
34 |
35 | def apply_placeholder(prompt: str, placeholder: str, replacements: List[str]) -> List[str]:
36 | placeholder = '{' + placeholder + '}'
37 | if placeholder not in prompt:
38 | return [prompt]
39 | else:
40 | return [prompt.replace(placeholder, replacement) for replacement in replacements]
41 |
42 |
43 | def fill_prompt_templates(prompts: List[str]) -> List[str]:
44 | for placeholder, replacements in template_placeholders.items():
45 | prompts = list(chain(*[apply_placeholder(prompt, placeholder, replacements) for prompt in prompts]))
46 | assert not any('{' in prompt for prompt in prompts), [prompt for prompt in prompts if '{' in prompt]
47 | return prompts
48 |
49 | def localized_prompt_templates(prompts: List[str], region_templates: List[str]) -> List[List[str]]:
50 | prompts = [
51 | [reg_template.format(prompt) for prompt in prompts]
52 | for reg_template in region_templates
53 | ]
54 | assert not any('{' in prompt for templ in prompts for prompt in templ)
55 | return prompts
56 |
57 |
58 | template_placeholders = {
59 | 'normal_adj': ['normal', 'unremarkable', 'clear'],
60 | 'shape_noun': ['size', 'silhouette', 'area', 'contours'],
61 | 'shape_adj': ['round'],
62 | 'state_verb': ['appears', 'is', 'are', 'remains', 'remain', 'appear', 'exists', 'exist'],
63 | 'indication_noun': ['signs', 'evidence', 'case of', 'presence', 'findings', 'suspicious findings'],
64 | 'indication_adj': ['noticeable', 'visible', 'seen', 'appearent', 'observable'],
65 | 'indication_verb': ['indicates', 'suggests', 'suggesting', 'indicating', 'consistent with'],
66 | 'passive_indication_verb': ['can be identified', 'can be seen', 'is present', 'is noted'],
67 | 'unchanged': ['has not improved', 'unchanged', 'remains'],
68 | 'limits_noun': ['limits'],
69 | 'moderate_adj': ['mild', 'moderate', 'extensive', 'small', 'slight', 'stable', 'intact', 'mild-moderate',],
70 | 'strong_adj': [
71 | 'large', 'significant', 'acute', 'widespread', 'relevant', 'difficult', 'apparent', 'prominent',
72 | 'convincing', 'extensive', 'severe', 'critical', 'altered', 'patchy', 'degenerative', 'substantial',
73 | 'predominant', 'massive', 'noticeable'],
74 | 'increased_adj': ['elevated', 'enlarged', 'increased', 'larger', 'large', 'widened'],
75 | 'size_noun': ['enlargement'],
76 | 'visible_adj': ['visible', 'seen', 'appearent'],
77 | 'relation_adj': ['regarding',' relating to', 'concerning', 'involving'],
78 | 'support_dev_noun': ['catheter', 'tubes', 'support device', 'monitoring device', 'wires', 'pacemaker'],
79 | 'change': ['little change', 'unchanged',],
80 | 'lung_adj': ['pulmonary','pul', 'lung', 'airspace']
81 | }
--------------------------------------------------------------------------------
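A sketch of the prompt helpers above (not part of the repository; assumes the `chexzero/` directory is on `PYTHONPATH`): placeholder expansion against `template_placeholders`, and flattening a ragged per-sample prompt list into one list plus a validity mask.

```python
import torch
from util.prompt_utils import fill_prompt_templates, flatten_prompts

templates = [
    "There is {indication_noun} of pleural effusion.",
    "The lungs {state_verb} {normal_adj}.",
]
prompts = fill_prompt_templates(templates)
print(len(prompts))   # 30: 6 indication nouns + 8 state verbs x 3 normal adjectives
print(prompts[0])     # "There is signs of pleural effusion."

flat, mask = flatten_prompts([["a", "b"], None, ["c"]], device=torch.device("cpu"))
print(flat)           # ['a', 'b', 'c']
print(mask)           # tensor([[ True,  True], [False, False], [ True, False]])
```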
/cheff/ldm/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | class LambdaWarmUpCosineScheduler:
5 | """
6 | note: use with a base_lr of 1.0
7 | """
8 | def __init__(self, warm_up_steps, lr_min, lr_max, lr_start, max_decay_steps, verbosity_interval=0):
9 | self.lr_warm_up_steps = warm_up_steps
10 | self.lr_start = lr_start
11 | self.lr_min = lr_min
12 | self.lr_max = lr_max
13 | self.lr_max_decay_steps = max_decay_steps
14 | self.last_lr = 0.
15 | self.verbosity_interval = verbosity_interval
16 |
17 | def schedule(self, n, **kwargs):
18 | if self.verbosity_interval > 0:
19 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_lr}")
20 | if n < self.lr_warm_up_steps:
21 | lr = (self.lr_max - self.lr_start) / self.lr_warm_up_steps * n + self.lr_start
22 | self.last_lr = lr
23 | return lr
24 | else:
25 | t = (n - self.lr_warm_up_steps) / (self.lr_max_decay_steps - self.lr_warm_up_steps)
26 | t = min(t, 1.0)
27 | lr = self.lr_min + 0.5 * (self.lr_max - self.lr_min) * (
28 | 1 + np.cos(t * np.pi))
29 | self.last_lr = lr
30 | return lr
31 |
32 | def __call__(self, n, **kwargs):
33 | return self.schedule(n,**kwargs)
34 |
35 |
36 | class LambdaWarmUpCosineScheduler2:
37 | """
38 | supports repeated iterations, configurable via lists
39 | note: use with a base_lr of 1.0.
40 | """
41 | def __init__(self, warm_up_steps, f_min, f_max, f_start, cycle_lengths, verbosity_interval=0):
42 | assert len(warm_up_steps) == len(f_min) == len(f_max) == len(f_start) == len(cycle_lengths)
43 | self.lr_warm_up_steps = warm_up_steps
44 | self.f_start = f_start
45 | self.f_min = f_min
46 | self.f_max = f_max
47 | self.cycle_lengths = cycle_lengths
48 | self.cum_cycles = np.cumsum([0] + list(self.cycle_lengths))
49 | self.last_f = 0.
50 | self.verbosity_interval = verbosity_interval
51 |
52 | def find_in_interval(self, n):
53 | interval = 0
54 | for cl in self.cum_cycles[1:]:
55 | if n <= cl:
56 | return interval
57 | interval += 1
58 |
59 | def schedule(self, n, **kwargs):
60 | cycle = self.find_in_interval(n)
61 | n = n - self.cum_cycles[cycle]
62 | if self.verbosity_interval > 0:
63 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
64 | f"current cycle {cycle}")
65 | if n < self.lr_warm_up_steps[cycle]:
66 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
67 | self.last_f = f
68 | return f
69 | else:
70 | t = (n - self.lr_warm_up_steps[cycle]) / (self.cycle_lengths[cycle] - self.lr_warm_up_steps[cycle])
71 | t = min(t, 1.0)
72 | f = self.f_min[cycle] + 0.5 * (self.f_max[cycle] - self.f_min[cycle]) * (
73 | 1 + np.cos(t * np.pi))
74 | self.last_f = f
75 | return f
76 |
77 | def __call__(self, n, **kwargs):
78 | return self.schedule(n, **kwargs)
79 |
80 |
81 | class LambdaLinearScheduler(LambdaWarmUpCosineScheduler2):
82 |
83 | def schedule(self, n, **kwargs):
84 | cycle = self.find_in_interval(n)
85 | n = n - self.cum_cycles[cycle]
86 | if self.verbosity_interval > 0:
87 | if n % self.verbosity_interval == 0: print(f"current step: {n}, recent lr-multiplier: {self.last_f}, "
88 | f"current cycle {cycle}")
89 |
90 | if n < self.lr_warm_up_steps[cycle]:
91 | f = (self.f_max[cycle] - self.f_start[cycle]) / self.lr_warm_up_steps[cycle] * n + self.f_start[cycle]
92 | self.last_f = f
93 | return f
94 | else:
95 | f = self.f_min[cycle] + (self.f_max[cycle] - self.f_min[cycle]) * (self.cycle_lengths[cycle] - n) / (self.cycle_lengths[cycle])
96 | self.last_f = f
97 | return f
98 |
99 |
--------------------------------------------------------------------------------
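These schedulers return learning-rate multipliers, so they plug into `torch.optim.lr_scheduler.LambdaLR` with a base learning rate of 1.0, as the docstrings note. A short sketch, not part of the repository:

```python
import torch
from torch.optim.lr_scheduler import LambdaLR
from cheff.ldm.lr_scheduler import LambdaWarmUpCosineScheduler

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1.0)  # base_lr of 1.0, per the docstring

schedule = LambdaWarmUpCosineScheduler(
    warm_up_steps=100, lr_min=1e-6, lr_max=1e-4, lr_start=1e-6, max_decay_steps=1000)
sched = LambdaLR(opt, lr_lambda=schedule)

for _ in range(5):
    opt.step()
    sched.step()
print(opt.param_groups[0]["lr"])  # warming up linearly towards lr_max
```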
/models/clip.py:
--------------------------------------------------------------------------------
1 | import chexzero.chexzero.clip
2 | from chexzero.chexzero.clip import tokenize
3 | import torch
4 | from torch.nn import functional as F
5 | from torchvision import transforms
6 |
7 |
8 | class CLIPEvaluator:
9 | """
10 | https://huggingface.co/StanfordAIMI/XrayCLIP__vit-b-16__laion2b-s34b-b88k?library=transformers
11 | """
12 |
13 | def __init__(self, args, model=None, preprocess=None):
14 | self.args = args
15 | self.model_path = args.model.clip.path
16 | self.device = args.device
17 |
18 | self.model = model
19 | self.preprocess = preprocess
20 |
21 | if self.model is None and self.preprocess is None:
22 | print("Loading CLIP model for CLIPEvaluator...")
23 |             self.model, self.preprocess = chexzero.chexzero.clip.load("ViT-B/32", device="cpu", jit=False)
24 | self.model.load_state_dict(torch.load(self.model_path, map_location=self.device))
25 |
26 | self.model = self.model.to(self.device)
27 | self.model.eval()
28 |
29 | self.input_size = self.model.visual.input_resolution
30 |
31 | def clip_transform_for_tensor(self, tensor: torch.Tensor, target_size=None):
32 | if target_size is None:
33 | target_size = self.input_size
34 |
35 | resize = transforms.Resize((target_size, target_size), interpolation=transforms.InterpolationMode.BICUBIC)
36 |
37 | resized = resize(tensor)
38 |
39 | # mean = [0.48145466, 0.4578275, 0.40821073]
40 | # std = [0.26862954, 0.26130258, 0.27577711]
41 |
42 | normalize = transforms.Normalize((101.48761, 101.48761, 101.48761), (83.43944, 83.43944, 83.43944))
43 | normalize_2 = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
44 |
45 | # return resized
46 | return normalize_2(resized)
47 |
48 | def score(self, images: torch.Tensor, texts=None, preprocess_images=True):
49 | assert texts is not None and texts != [], "Texts must be provided for scoring."
50 |
51 | tokenized_text = tokenize(texts).to(self.device)
52 | text_features = self.model.encode_text(tokenized_text)
53 |
54 | if preprocess_images:
55 | images = self.clip_transform_for_tensor(images)
56 |
57 | image_features = self.model.encode_image(images)
58 |
59 | # normalized features
60 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
61 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
62 |
63 | # cosine similarity as logits
64 | logit_scale = self.model.logit_scale.exp()
65 | logits_per_image = logit_scale * image_features @ text_features.t()
66 | logits_per_text = logit_scale * text_features @ image_features.t()
67 |
68 | # shape = [global_batch_size, global_batch_size]
69 | return logits_per_image, logits_per_text
70 |
71 | def encode_text(self, texts):
72 | if isinstance(texts, str):
73 | texts = [texts]
74 | tokenized_text = tokenize(texts).to(self.device)
75 | return self.model.encode_text(tokenized_text)
76 |
77 | def score_single(self, image: torch.Tensor, text: str, preprocess_image=True, text_features=None):
78 | if text_features is None:
79 | text_features = self.encode_text(text)
80 |
81 | # Preprocess the image if necessary
82 | if preprocess_image:
83 | # image already has batch dimension
84 |             image = self.clip_transform_for_tensor(image)  # resize and normalize for CLIP
85 |
86 | # Extract image features
87 | image_features = self.model.encode_image(image)
88 |
89 | # Normalize the features
90 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
91 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
92 |
93 | # Compute cosine similarity between the single image and text
94 | similarity = (image_features @ text_features.T).squeeze()
95 |
96 | # Rescale similarity from [-1, 1] to [0, 1]
97 | score = (similarity + 1) / 2
98 |
99 | return score.item()
100 |
--------------------------------------------------------------------------------
/cheff/machex.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | from typing import Dict, Optional
4 |
5 | import torch
6 | from PIL import Image
7 | from torch.utils.data import Dataset, ConcatDataset
8 | from torchvision.transforms import Compose, ToTensor
9 |
10 |
11 | class ChestXrayDataset(Dataset):
12 | """Class for handling datasets in the MaCheX composition."""
13 |
14 | def __init__(self, root: str, transforms: Optional[Compose] = None) -> None:
15 | """Initialize ChestXrayDataset."""
16 | self.root = root
17 | json_path = os.path.join(self.root, 'index.json')
18 | self.index_dict = ChestXrayDataset._load_json(json_path)
19 |
20 | self.keys = list(self.index_dict.keys())
21 |
22 | if transforms is None:
23 | self.transforms = ToTensor()
24 | else:
25 | self.transforms = transforms
26 |
27 | @staticmethod
28 | def _load_json(file_path: str) -> Dict:
29 | """Load a json file as dictionary."""
30 | with open(file_path, 'r') as f:
31 | return json.load(f)
32 |
33 | def __len__(self):
34 | """Return length of the dataset."""
35 | return len(self.keys)
36 |
37 | def __getitem__(self, idx: int) -> Dict:
38 | """Get dataset element."""
39 | meta = self.index_dict[self.keys[idx]]
40 |
41 | img = Image.open(meta['path'])
42 | img = self.transforms(img)
43 |
44 | return {'img': img}
45 |
46 |
47 | class MaCheXDataset(Dataset):
48 | """Massive chest X-ray dataset."""
49 |
50 | def __init__(self, root: str, transforms: Optional[Compose] = None) -> None:
51 | """Initialize MaCheXDataset"""
52 | self.root = root
53 | sub_dataset_roots = os.listdir(self.root)
54 | datasets = [
55 | ChestXrayDataset(root=os.path.join(root, r), transforms=transforms)
56 | for r in sub_dataset_roots
57 | ]
58 | self.ds = ConcatDataset(datasets)
59 |
60 | def __len__(self):
61 | """Return length of the dataset."""
62 | return len(self.ds)
63 |
64 | def __getitem__(self, idx: int) -> Dict:
65 | """Get dataset element."""
66 | return self.ds[idx]
67 |
68 |
69 | class MimicT2IDataset(ChestXrayDataset):
70 | """Mimic subset with reports."""
71 |
72 | def __init__(self, root: str, transforms: Optional[Compose] = None) -> None:
73 | root = os.path.join(root, 'mimic')
74 | super().__init__(root, transforms)
75 |
76 | def __getitem__(self, idx: int) -> Dict:
77 | """Get dataset element."""
78 | meta = self.index_dict[self.keys[idx]]
79 |
80 | img = Image.open(meta['path'])
81 | img = self.transforms(img)
82 |
83 | return {'img': img, 'caption': meta['report']}
84 |
85 |
86 | class LabelChestXrayDataset(ChestXrayDataset):
87 | """A Chest X-ray dataset that returns class labels."""
88 |
89 | def __init__(self, root: str, transforms: Optional[Compose] = None) -> None:
90 | super().__init__(root, transforms)
91 | keys = []
92 | for key in self.keys:
93 | if self.index_dict[key].get('class_label') is not None:
94 | keys.append(key)
95 | self.keys = keys
96 |
97 | def __getitem__(self, idx: int) -> Dict:
98 | """Get dataset element."""
99 | meta = self.index_dict[self.keys[idx]]
100 |
101 | img = Image.open(meta['path'])
102 | img = self.transforms(img)
103 |
104 | return {'img': img, 'class_label': torch.tensor(meta['class_label']).float()}
105 |
106 |
107 | class CombinedLabelChestXrayDataset(Dataset):
108 | def __init__(self, root: str, transforms: Optional[Compose] = None) -> None:
109 | """Initialize MaCheXDataset"""
110 | self.root = root
111 | sub_dataset_roots = ['mimic', 'chexpert']
112 | datasets = [
113 | LabelChestXrayDataset(root=os.path.join(root, r), transforms=transforms)
114 | for r in sub_dataset_roots
115 | ]
116 | self.ds = ConcatDataset(datasets)
117 |
118 | def __len__(self):
119 | """Return length of the dataset."""
120 | return len(self.ds)
121 |
122 | def __getitem__(self, idx: int) -> Dict:
123 | """Get dataset element."""
124 | return self.ds[idx]
125 |
126 |
--------------------------------------------------------------------------------
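A hedged usage sketch for the MaCheX-style datasets above (not part of the repository). The `root` directory must contain an `index.json` mapping keys to entries with at least a `path` field (plus `report` / `class_label` for the subclasses); the path below is a placeholder:

```python
from torchvision.transforms import Compose, Resize, ToTensor
from cheff.machex import ChestXrayDataset

transforms = Compose([Resize((256, 256)), ToTensor()])
ds = ChestXrayDataset(root="/path/to/machex/chexpert", transforms=transforms)

print(len(ds))
sample = ds[0]
print(sample["img"].shape)  # (C, 256, 256)
```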
/chexzero/components/soft_roi_pool.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from typing import Tuple, Union
3 | from torch import FloatTensor, Tensor, nn
4 | import torch
5 | from torch import autocast
6 |
7 | class SoftRoiPool(nn.Module):
8 | def forward(self, x: Tensor, box_params: Tensor, beta: Union[float, FloatTensor] = 2.) -> Tuple[Tensor, Tensor]:
9 | """
10 |         :param x: Feature map of the image (N x H x W x d)
11 |         :param box_params: Parameters of bounding boxes (N x Q x 4) for Q boxes.
12 |             format: (x_c, y_c, w, h), each in the range [0, 1] (relative to image size)
13 | :return roi_features, roi_maps
14 | - roi_features: Pooled features of each roi (N x Q x d)
15 | - roi_maps: Attention maps of each roi (N x Q x H x W)
16 | """
17 | N, H, W, d = x.shape
18 |
19 | # Compute kernel on sampling grid
20 | sampled_grid = get_sample_grid(H, W, device=x.device, dtype=x.dtype) # (H x W x 2)
21 | roi_patch_map = separable_generalized_gaussian_pdf(box_params.to(dtype=x.dtype), sampled_grid, beta=beta) # (N x Q x H x W)
22 |
23 | # with autocast(device_type='cuda', enabled=False):
24 | # Batched matrix multiplication and normalize
25 | roi_features = torch.einsum('nqhw,nhwd->nqd', roi_patch_map.float(), x.float()) # (N x Q x d)
26 | roi_features = roi_features / H * W
27 |
28 | return roi_features, roi_patch_map
29 |
30 |
31 | def get_sample_grid(H: int, W: int, device, dtype) -> Tensor:
32 | # (H x W)
33 | y, x = torch.meshgrid(torch.arange(H, device=device, dtype=dtype),
34 | torch.arange(W, device=device, dtype=dtype),
35 | indexing='ij')
36 | # (H x W x 2)
37 | sampled_grid = torch.stack([x, y], dim=-1)
38 | # consider pixel centers instead of left-upper position
39 | sampled_grid += 0.5
40 | # normalize positions into range [0, 1]
41 | sampled_grid[:, :, 0] /= W
42 | sampled_grid[:, :, 1] /= H
43 | return sampled_grid
44 |
45 |
46 | def generalized_gauss_1d_log_pdf(mu: Tensor, sigma: Tensor, sampled_grid: Tensor,
47 | beta: Union[float, FloatTensor]) -> Tensor:
48 | """
49 | :param mu: (N x K)
50 | :param sigma: (N x K)
51 | :param sampled_grid: Sampled points (P) where P is the number of sampled points of the Gaussian pdf
52 | :return (N x K x P)
53 | """
54 | assert len(sampled_grid.shape) == 1
55 | assert len(mu.shape) == 2
56 | assert len(sigma.shape) == 2
57 |
58 | if not isinstance(beta, (float, int)):
59 | assert isinstance(beta, Tensor)
60 | if beta.numel() > 1:
61 | assert beta.shape == mu.shape
62 | beta = beta[:, :, None]
63 |
64 |     # (unnormalized) log pdf = -(|x - mu| / sigma)^beta  (generalized Gaussian; beta=2 gives a Gaussian shape)
65 | # log_pdf = - (1 / beta) * (
66 | # (sampled_grid[None, None] - mu[:, :, None]) / sigma[:, :, None]
67 | # ).pow(beta)
68 | log_pdf = -(
69 | (sampled_grid[None, None] - mu[:, :, None]).abs() / sigma[:, :, None]
70 | ).pow(beta)
71 | return log_pdf
72 |
73 |
74 | def separable_generalized_gaussian_pdf(box_params: Tensor, sampled_grid: Tensor,
75 | beta: Union[float, FloatTensor]) -> Tensor:
76 | """
77 | :param box_params: (N x Q x 4)
78 | :param sampled_grid: (... x 2)
79 | :return: (N x Q x ...)
80 | """
81 | N, K, _ = box_params.shape
82 | *dims, _ = sampled_grid.shape
83 | sampled_grid = sampled_grid.view(-1, 2) # (... x 2)
84 | mu = box_params[:, :, :2] # (N x K x 2)
85 | sigma = box_params[:, :, 2:] # (N x K x 2)
86 | # compute x and y Gaussian pdf's independently (in log-space and non-normalized)
87 | log_scores_x = generalized_gauss_1d_log_pdf(mu[..., 0], sigma[..., 0],
88 | sampled_grid[..., 0], beta) # (N x K x ...)
89 | log_scores_y = generalized_gauss_1d_log_pdf(mu[..., 1], sigma[..., 1],
90 | sampled_grid[..., 1], beta) # (N x K x ...)
91 | # combine them in log space (multiplication in prob space)
92 | log_scores = log_scores_x + log_scores_y # (N x K x ...)
93 |
94 | # Normalize to max value = 1
95 | scores = torch.exp(log_scores)
96 | probs = scores / (scores.max(-1, keepdim=True).values + 1e-12)
97 |
98 | # Alternative: convert to probs by applying exp and then normalizing to sum equals 1 == softmax
99 | # probs = torch.softmax(log_scores, dim=-1) # (N x K x ...)
100 |
101 | return probs.view(N, K, *dims)
102 |
103 |
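104 | if __name__ == '__main__':
105 |     # Minimal usage sketch (editor's addition, not part of the original module):
106 |     # pool a dummy feature map with two boxes per image. Shapes follow the
107 |     # forward() docstring; the box values are illustrative.
108 |     feat = torch.randn(1, 14, 14, 64)                   # (N x H x W x d)
109 |     boxes = torch.tensor([[[0.5, 0.5, 0.3, 0.3],        # (x_c, y_c, w, h) in [0, 1]
110 |                            [0.2, 0.7, 0.1, 0.2]]])      # (N x Q x 4)
111 |     roi_feats, roi_maps = SoftRoiPool()(feat, boxes)
112 |     print(roi_feats.shape, roi_maps.shape)              # (1, 2, 64), (1, 2, 14, 14)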
--------------------------------------------------------------------------------
/chexzero/txt_encoder/chexzero_txt_encoder.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import os
3 | from typing import List
4 | from torch import Tensor
5 | from torch import nn
6 | import torch
7 | import torch.nn.functional as F
8 |
9 | # from model.components.mlp import MLP
10 | from chexzero.components.mlp import MLP
11 |
12 | from chexzero.util.model_utils import BaseModel, BaseModelConfig, MainModelConfig
13 | import chexzero.clip
14 |
15 |
16 | @dataclass
17 | class ChexzeroTextEncoderConfig(BaseModelConfig):
18 | model_path: str = os.path.expanduser(
19 | "~/models/third_party/chexzero/CheXzero_Models/best_64_5e-05_original_22000_0.864.pt"
20 | )
21 | frozen_language_model: bool = True
22 |
23 | # Additional projection layers (after the language model projection)
24 | # 0 = no projection, 1 = linear, 2 = one hidden layer
25 | n_projection_layers: int = 0
26 |     # whether to use batch norm in the additional projection layers
27 | projection_bn: bool = False
28 | normalize_projected: bool = False
29 |
30 |
31 | class ChexzeroTextEncoder(BaseModel):
32 | CONFIG_CLS = ChexzeroTextEncoderConfig
33 |     MODIFYABLE_CONFIGS = ("frozen_language_model",)
34 |
35 | def __init__(self, config: ChexzeroTextEncoderConfig, main_config: MainModelConfig):
36 | super().__init__(config)
37 | self.config: ChexzeroTextEncoderConfig
38 |
39 | self.d = main_config.d_model
40 |
41 | model, _ = chexzero.clip.load("ViT-B/32", device="cpu", jit=False)
42 | model.load_state_dict(torch.load(self.config.model_path, map_location="cpu"))
43 | self.d = main_config.d_model
44 |
45 | self.transformer = model.transformer
46 | self.token_embedding = model.token_embedding
47 | self.positional_embedding = model.positional_embedding
48 | self.ln_final = model.ln_final
49 | self.text_projection = model.text_projection
50 |
51 | if self.config.frozen_language_model:
52 | for param in self.parameters():
53 | param.requires_grad = False
54 |
55 | d_backbone = self.text_projection.shape[1]
56 |
57 | self.projection = MLP(
58 | self.config.n_projection_layers,
59 | d_in=d_backbone,
60 | d_out=self.d,
61 | use_bn=self.config.projection_bn,
62 | act=main_config.act,
63 | dropout=main_config.dropout,
64 | )
65 |
66 | self.cached_sentence_embeddings = {}
67 |
68 | @property
69 | def dtype(self):
70 | return self.token_embedding.weight.dtype
71 |
72 | @property
73 | def device(self):
74 | return self.token_embedding.weight.device
75 |
76 | def forward(self, input_ids: torch.Tensor, project: bool = True, **kwargs) -> Tensor:
77 |
78 | input_ids = input_ids.to(device=self.device)
79 |
80 | # Encode image using backbone
81 | with torch.set_grad_enabled(not self.config.frozen_language_model):
82 | x = self.token_embedding(input_ids).type(self.dtype) # [batch_size, n_ctx, d_model]
83 |
84 | x = x + self.positional_embedding.type(self.dtype)
85 | x = x.permute(1, 0, 2) # NLD -> LND
86 | x = self.transformer(x)
87 | x = x.permute(1, 0, 2) # LND -> NLD
88 | x = self.ln_final(x).type(self.dtype)
89 |
90 | # x.shape = [batch_size, n_ctx, transformer.width]
91 | # take features from the eot embedding (eot_token is the highest number in each sequence)
92 | # (N_sentences x d)
93 | features = x[torch.arange(x.shape[0]), input_ids.argmax(dim=-1)] @ self.text_projection
94 |
95 | if project:
96 | features = self.projection(features)
97 | return features
98 |
99 | def encode_sentences(self, sentences: List[str], cache=False, **kwargs) -> Tensor:
100 | if cache and self.config.frozen_language_model and all(s in self.cached_sentence_embeddings for s in sentences):
101 | features = torch.stack([self.cached_sentence_embeddings[s] for s in sentences], dim=0)
102 | else:
103 | input_ids = chexzero.clip.tokenize(sentences, context_length=77)
104 | features = self(input_ids, project=False, **kwargs)
105 |
106 | if cache and self.config.frozen_language_model:
107 | for s, f in zip(sentences, features):
108 | self.cached_sentence_embeddings[s] = f.detach().float()
109 | features = features.to(dtype=self.dtype)
110 | projected_features = self.projection(features)
111 | return projected_features
112 |
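113 | # Usage sketch (editor's addition, not part of the original module; assumes an
114 | # already-instantiated encoder, since BaseModelConfig/MainModelConfig come from
115 | # chexzero.util.model_utils and are not shown here):
116 | #   encoder = ChexzeroTextEncoder(ChexzeroTextEncoderConfig(), main_config)
117 | #   emb = encoder.encode_sentences(["No acute cardiopulmonary process."], cache=True)
118 | #   emb.shape  # (1, main_config.d_model)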
--------------------------------------------------------------------------------
/models/classifier.py:
--------------------------------------------------------------------------------
1 | from torchvision.models import densenet121, resnet152
2 | from chexpert.models.efficientnet import construct_model
3 | from chexpert.models.attn_aug_conv import DenseNet, ResNet, Bottleneck
4 | from chexpert.dataset import ChexpertSmall, extract_patient_ids
5 | from torchvision.models import DenseNet121_Weights, ResNet152_Weights
6 |
7 | import torch
8 | import torch.nn as nn
9 | import numpy as np
10 | import os
11 |
12 |
13 | def load_model(args):
14 |
15 | classifier_args = args.model.classifier
16 | model_name, restore, restore_path, lr, pretrained = (
17 | classifier_args.model.strip().lower(),
18 | classifier_args.restore,
19 | classifier_args.path,
20 | classifier_args.lr,
21 | classifier_args.pretrained,
22 | )
23 |
24 | print("Loading model", model_name)
25 |
26 | # load model
27 | n_classes = len(ChexpertSmall.attr_names)
28 | if model_name == "densenet121":
29 | model = densenet121(weights=(DenseNet121_Weights.DEFAULT if pretrained else None)).to(args.device)
30 | # 1. replace output layer with chexpert number of classes (pretrained loads ImageNet n_classes)
31 | model.classifier = nn.Linear(model.classifier.in_features, out_features=n_classes).to(args.device)
32 | # 2. init output layer with default torchvision init
33 | nn.init.constant_(model.classifier.bias, 0)
34 | # 3. store locations of forward and backward hooks for grad-cam
35 | grad_cam_hooks = {"forward": model.features.norm5, "backward": model.classifier}
36 | # 4. init optimizer and scheduler
37 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
38 | scheduler = None
39 | # optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True)
40 | # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [40000, 60000])
41 | elif model_name == "aadensenet121":
42 | model = DenseNet(
43 | 32,
44 | (6, 12, 24, 16),
45 | 64,
46 | num_classes=n_classes,
47 | attn_params={"k": 0.2, "v": 0.1, "nh": 8, "relative": True, "input_dims": (320, 320)},
48 | ).to(args.device)
49 | grad_cam_hooks = {"forward": model.features, "backward": model.classifier}
50 | attn_hooks = [model.features.transition1.conv, model.features.transition2.conv, model.features.transition3.conv]
51 | optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True)
52 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [40000, 60000])
53 | elif model_name == "resnet152":
54 |         model = resnet152(weights=(ResNet152_Weights.DEFAULT if pretrained else None)).to(args.device)
55 | model.fc = nn.Linear(model.fc.in_features, out_features=n_classes).to(args.device)
56 | grad_cam_hooks = {"forward": model.layer4, "backward": model.fc}
57 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
58 | scheduler = None
59 | elif (
60 | model_name == "aaresnet152"
61 | ): # resnet50 layers [3,4,6,3]; resnet101 layers [3,4,23,3]; resnet 152 layers [3,8,36,3]
62 | model = ResNet(
63 | Bottleneck,
64 | [3, 8, 36, 3],
65 | num_classes=n_classes,
66 | attn_params={"k": 0.2, "v": 0.1, "nh": 8, "relative": True, "input_dims": (320, 320)},
67 | ).to(args.device)
68 | grad_cam_hooks = {"forward": model.layer4, "backward": model.fc}
69 | attn_hooks = (
70 | [model.layer2[i].conv2 for i in range(len(model.layer2))]
71 | + [model.layer3[i].conv2 for i in range(len(model.layer3))]
72 | + [model.layer4[i].conv2 for i in range(len(model.layer4))]
73 | )
74 | optimizer = torch.optim.Adam(model.parameters(), lr=lr)
75 | scheduler = None
76 | elif "efficientnet" in model_name:
77 | model = construct_model(model_name, n_classes=n_classes).to(args.device)
78 | grad_cam_hooks = {"forward": model.head[1], "backward": model.head[-1]}
79 | optimizer = torch.optim.RMSprop(model.parameters(), lr=lr, momentum=0.9, eps=0.001)
80 |         scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=getattr(classifier_args, "lr_decay_factor", 0.97))  # fallback 0.97 is an assumption; lr_decay_factor was previously undefined
81 | else:
82 | raise RuntimeError("Model architecture not supported.")
83 |
84 | if restore and os.path.isfile(restore_path):
85 | print("Restoring model weights from {}".format(restore_path))
86 | model_checkpoint = torch.load(restore_path, map_location=args.device, weights_only=False)
87 | model.load_state_dict(model_checkpoint["state_dict"])
88 | args.model.classifier.step = model_checkpoint["global_step"]
89 | del model_checkpoint
90 |
91 | model = model.to(args.device)
92 |
93 | return model, grad_cam_hooks
94 |
95 |
96 | def define_classifier_model(args):
97 | model, _ = load_model(args)
98 | model = model.to(args.device)
99 |
100 | model.eval()
101 |
102 | classifier_name = args.model.classifier.model.lower()
103 |
104 | print("Using classifier:", classifier_name)
105 |
106 | return model
107 |
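108 | # Usage sketch (editor's addition, not part of the original module): load_model
109 | # expects an OmegaConf-style config; the accesses below mirror those made in
110 | # load_model() and are otherwise assumptions.
111 | #   from omegaconf import OmegaConf
112 | #   args = OmegaConf.load("config.yaml")
113 | #   args.device = "cuda" if torch.cuda.is_available() else "cpu"
114 | #   model, grad_cam_hooks = load_model(args)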
--------------------------------------------------------------------------------
/src/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | from PIL import Image
4 | from torch.utils.data import Dataset
5 | from collections import defaultdict
6 | import numpy as np
7 | import pytorch_lightning as pl
8 | from tqdm import tqdm
9 | from torch.utils.data import random_split
10 | import pandas as pd
11 | from chexpert.dataset import ChexpertSmall
12 | import torchvision.transforms as T
13 |
14 |
15 | class ChestXRayDataset(Dataset):
16 | def __init__(self, opts, data_dir, mode="train", transform=None): # mode="train" or "valid"
17 | self.opts = opts
18 | self.mode = mode
19 | self.target = opts.finetune.target
20 | self.data_dir = data_dir
21 | self.transform = transform
22 | self.csv_path = os.path.join(data_dir, f"{mode}.csv")
23 |
24 | # Read CSV
25 | self.df = pd.read_csv(self.csv_path)
26 | self.df = self.df[ChexpertSmall.attr_names + ["Path"]]
27 | self.df = self.df.replace(-1, 0)
28 | self.df = self.df.fillna(0)
29 |
30 | # Filter out lateral images
31 | self.df = self.df[~self.df["Path"].str.contains("lateral", case=False, na=False)]
32 |
33 | self.diseased_image_paths = self.df[self.df[self.target] == 1]["Path"].tolist()
34 | self.healthy_image_paths = self.df[self.df[ChexpertSmall.attr_names].eq(0).all(axis=1)]["Path"].tolist()
35 |
36 | min_images = min(len(self.diseased_image_paths), len(self.healthy_image_paths))
37 | print(
38 | f"Min images: {min_images}, Diseased: {len(self.diseased_image_paths)}, Healthy: {len(self.healthy_image_paths)}"
39 | )
40 |
41 | self.diseased_image_paths = self.diseased_image_paths[:min_images]
42 | self.healthy_image_paths = self.healthy_image_paths[:min_images]
43 |
44 | def __len__(self):
45 | return len(self.diseased_image_paths)
46 |
47 |     def __getitem__(self, idx):
48 |         return self.__fetch_one(idx)
49 |
50 | def __fetch_one(self, idx):
51 | diseased_path = self.diseased_image_paths[idx]
52 | healthy_path = self.healthy_image_paths[idx]
53 |
54 | diseased_path = os.path.join(self.data_dir, diseased_path.replace(self.opts.data.prefix, ""))
55 | healthy_path = os.path.join(self.data_dir, healthy_path.replace(self.opts.data.prefix, ""))
56 |
57 | diseased_image = Image.open(diseased_path).convert("RGB")
58 | healthy_image = Image.open(healthy_path).convert("RGB")
59 |
60 | if self.transform:
61 | diseased_image = self.transform(diseased_image)
62 | healthy_image = self.transform(healthy_image)
63 |
64 | return diseased_image, healthy_image
65 |
66 |
67 | class ChestXRayDataModule(pl.LightningDataModule):
68 | def __init__(self, opts):
69 | super().__init__()
70 | self.opts = opts
71 | self.transform = T.Compose(
72 | [
73 | T.Resize(self.opts.data.resize) if self.opts.data.resize else T.Lambda(lambda x: x),
74 | T.CenterCrop(320 if not self.opts.data.resize else self.opts.data.resize),
75 | T.ToTensor(),
76 | ]
77 |         )  # note: images are already converted to RGB in ChestXRayDataset, so no channel expansion is needed here
78 |
79 | self.train_dataset = None
80 | self.val_dataset = None
81 | self.test_dataset = None
82 |
83 | def get_train_dataset(self):
84 | ds = ChestXRayDataset(self.opts, self.opts.data.path, mode="train", transform=self.transform)
85 | print(f"Train dataset size: {len(ds)}")
86 | return ds
87 |
88 | def get_val_dataset(self):
89 | return self.get_test_dataset()
90 |
91 | def get_test_dataset(self):
92 | ds = ChestXRayDataset(self.opts, self.opts.data.path, mode="valid", transform=self.transform)
93 | print(f"Test dataset size: {len(ds)}")
94 | return ds
95 |
96 | def setup(self, stage=None):
97 | self.train_dataset = self.get_train_dataset()
98 | self.val_dataset = self.get_val_dataset()
99 | self.test_dataset = self.get_test_dataset()
100 |
101 | def train_dataloader(self):
102 | if self.train_dataset is None:
103 | self.train_dataset = self.get_train_dataset()
104 |
105 | return torch.utils.data.DataLoader(
106 | self.train_dataset,
107 | batch_size=self.opts.data.train_batch_size,
108 | shuffle=False,
109 | num_workers=self.opts.data.num_workers,
110 | pin_memory=True,
111 | prefetch_factor=(
112 | self.opts.data.prefetch_factor
113 | if self.opts.data.prefetch_factor and self.opts.data.num_workers > 0
114 | else None
115 | ),
116 | )
117 |
118 | def val_dataloader(self):
119 | return self.test_dataloader()
120 |
121 | def test_dataloader(self):
122 | if self.test_dataset is None:
123 | self.test_dataset = self.get_test_dataset()
124 |
125 | return torch.utils.data.DataLoader(
126 | self.test_dataset,
127 | batch_size=self.opts.data.test_batch_size,
128 | shuffle=False,
129 | num_workers=self.opts.data.num_workers,
130 | pin_memory=True,
131 | prefetch_factor=(
132 | self.opts.data.prefetch_factor
133 | if self.opts.data.prefetch_factor and self.opts.data.num_workers > 0
134 | else None
135 | ),
136 | )
137 |
138 |
139 | if __name__ == "__main__":
140 |
141 | from omegaconf import OmegaConf
142 |
143 | opts = OmegaConf.load(
144 | "/cluster/home/agoekmen/projects/chexray-diffusion/removal-editing/src/config/finetune_config.yaml"
145 | )
146 | data_module = ChestXRayDataModule(opts)
147 | data_module.setup()
148 | train_loader = data_module.train_dataloader()
149 | val_loader = data_module.val_dataloader()
150 |
151 | for a, b in train_loader:
152 | print(a.shape, b.shape)
153 | break
154 |
--------------------------------------------------------------------------------
/cheff/ldm/modules/losses/contperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | from taming.modules.losses.vqperceptual import * # TODO: taming dependency yes/no?
5 |
6 |
7 | class LPIPSWithDiscriminator(nn.Module):
8 | def __init__(self, disc_start, logvar_init=0.0, kl_weight=1.0, pixelloss_weight=1.0,
9 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
10 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
11 | disc_loss="hinge"):
12 |
13 | super().__init__()
14 | assert disc_loss in ["hinge", "vanilla"]
15 | self.kl_weight = kl_weight
16 | self.pixel_weight = pixelloss_weight
17 | self.perceptual_loss = LPIPS().eval()
18 | self.perceptual_weight = perceptual_weight
19 | # output log variance
20 | self.logvar = nn.Parameter(torch.ones(size=()) * logvar_init)
21 |
22 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
23 | n_layers=disc_num_layers,
24 | use_actnorm=use_actnorm
25 | ).apply(weights_init)
26 | self.discriminator_iter_start = disc_start
27 | self.disc_loss = hinge_d_loss if disc_loss == "hinge" else vanilla_d_loss
28 | self.disc_factor = disc_factor
29 | self.discriminator_weight = disc_weight
30 | self.disc_conditional = disc_conditional
31 |
32 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
33 | if last_layer is not None:
34 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
35 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
36 | else:
37 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
38 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
39 |
40 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
41 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
42 | d_weight = d_weight * self.discriminator_weight
43 | return d_weight
44 |
45 | def forward(self, inputs, reconstructions, posteriors, optimizer_idx,
46 | global_step, last_layer=None, cond=None, split="train",
47 | weights=None):
48 | rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
49 | if self.perceptual_weight > 0:
50 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
51 | rec_loss = rec_loss + self.perceptual_weight * p_loss
52 |
53 | nll_loss = rec_loss / torch.exp(self.logvar) + self.logvar
54 | weighted_nll_loss = nll_loss
55 | if weights is not None:
56 | weighted_nll_loss = weights*nll_loss
57 | weighted_nll_loss = torch.sum(weighted_nll_loss) / weighted_nll_loss.shape[0]
58 | nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
59 | kl_loss = posteriors.kl()
60 | kl_loss = torch.sum(kl_loss) / kl_loss.shape[0]
61 |
62 | # now the GAN part
63 | if optimizer_idx == 0:
64 | # generator update
65 | if cond is None:
66 | assert not self.disc_conditional
67 | logits_fake = self.discriminator(reconstructions.contiguous())
68 | else:
69 | assert self.disc_conditional
70 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
71 | g_loss = -torch.mean(logits_fake)
72 |
73 | if self.disc_factor > 0.0:
74 | try:
75 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
76 | except RuntimeError:
77 | assert not self.training
78 | d_weight = torch.tensor(0.0)
79 | else:
80 | d_weight = torch.tensor(0.0)
81 |
82 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
83 | loss = weighted_nll_loss + self.kl_weight * kl_loss + d_weight * disc_factor * g_loss
84 |
85 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(), "{}/logvar".format(split): self.logvar.detach(),
86 | "{}/kl_loss".format(split): kl_loss.detach().mean(), "{}/nll_loss".format(split): nll_loss.detach().mean(),
87 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
88 | "{}/d_weight".format(split): d_weight.detach(),
89 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
90 | "{}/g_loss".format(split): g_loss.detach().mean(),
91 | }
92 | return loss, log
93 |
94 | if optimizer_idx == 1:
95 | # second pass for discriminator update
96 | if cond is None:
97 | logits_real = self.discriminator(inputs.contiguous().detach())
98 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
99 | else:
100 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
101 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
102 |
103 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
104 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
105 |
106 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
107 | "{}/logits_real".format(split): logits_real.detach().mean(),
108 | "{}/logits_fake".format(split): logits_fake.detach().mean()
109 | }
110 | return d_loss, log
111 |
112 |
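113 | # Usage note (editor's addition, not part of the original module): this loss is
114 | # meant for a two-optimizer training loop (optimizer_idx 0 updates the
115 | # autoencoder/generator, optimizer_idx 1 the discriminator), e.g.:
116 | #   loss_g, log_g = loss_fn(x, x_rec, posterior, optimizer_idx=0,
117 | #                           global_step=step, last_layer=decoder_last_layer_weight)
118 | #   loss_d, log_d = loss_fn(x, x_rec, posterior, optimizer_idx=1, global_step=step)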
--------------------------------------------------------------------------------
/chexzero/components/classification_losses.py:
--------------------------------------------------------------------------------
1 | from functools import partial
2 | from math import prod
3 | from typing import Optional, Tuple
4 | import einops
5 | import torch
6 | import torch.nn.functional as F
7 |
8 | def classify_features(features: torch.FloatTensor, pos_prompt_emb: torch.FloatTensor, neg_prompt_emb: torch.FloatTensor,
9 | normalized: bool = True, temp=1.0, softmax=False,
10 | threshold: Optional[float] = 0.5, return_logits=False) -> Tuple[torch.FloatTensor, torch.BoolTensor]:
11 | """
12 | :param features: (N x ... x d)
13 | :param pos_prompt_emb: (N x ... x d)
14 | :param neg_prompt_emb: (N x ... x d)
15 |     :return: (probs, preds), each of shape (N x ...); with return_logits=True, (logits, probs, preds)
16 | """
17 | assert pos_prompt_emb.ndim == neg_prompt_emb.ndim
18 | assert features.ndim == pos_prompt_emb.ndim, f'{features.ndim} != {pos_prompt_emb.ndim}'
19 | features, pos_prompt_emb, neg_prompt_emb = torch.broadcast_tensors(features, pos_prompt_emb, neg_prompt_emb)
20 |
21 | if normalized:
22 | features = F.normalize(features, dim=-1)
23 | pos_prompt_emb = F.normalize(pos_prompt_emb, dim=-1)
24 | neg_prompt_emb = F.normalize(neg_prompt_emb, dim=-1)
25 | else:
26 | features = features.contiguous()
27 | pos_prompt_emb = pos_prompt_emb.contiguous()
28 | neg_prompt_emb = neg_prompt_emb.contiguous()
29 |
30 |
31 | N, *dims, d = features.shape
32 | n_dims = prod(dims)
33 | features = features.view(N, n_dims, d)
34 | pos_prompt_emb = pos_prompt_emb.view(N, n_dims, d)
35 | neg_prompt_emb = neg_prompt_emb.view(N, n_dims, d)
36 |
37 | pos_logits = torch.einsum('ijd,ijd->ij', features, pos_prompt_emb) / temp # (N x dims)
38 | neg_logits = torch.einsum('ijd,ijd->ij', features, neg_prompt_emb) / temp # (N x dims)
39 |
40 | if softmax:
41 | # (N x dims x 2)
42 | probs = torch.softmax(torch.stack([pos_logits, neg_logits], dim=-1), dim=-1)
43 | probs = probs[..., 0] # only positive probs
44 | else:
45 | probs = torch.sigmoid(pos_logits - neg_logits)
46 | probs = probs.view(N, *dims)
47 |
48 | preds = probs > threshold if threshold is not None else torch.ones_like(probs, dtype=bool)
49 |
50 | if not return_logits:
51 | return probs, preds
52 |
53 | if softmax:
54 | logits = torch.log_softmax(torch.stack([pos_logits, neg_logits], dim=-1), dim=-1)
55 | logits = logits[..., 0] # only positive probs
56 | else:
57 | logits = pos_logits - neg_logits
58 | logits = logits.view(N, *dims)
59 |
60 | return logits, probs, preds
61 |
62 |
63 |
64 | # ----------------------- Binary (and multilabel binary) losses ----------------------- #
65 | def binary_focal_loss_logits(
66 | logits: torch.Tensor,
67 | targets: torch.Tensor,
68 | alpha: float = 0.25,
69 | gamma: float = 2) -> torch.Tensor:
70 | logits, targets = torch.broadcast_tensors(logits, targets)
71 | targets = targets.to(dtype=logits.dtype)
72 |
73 | p = torch.sigmoid(logits)
74 | ce_loss = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
75 | p_t = p * targets + (1 - p) * (1 - targets)
76 | loss = ce_loss * ((1 - p_t) ** gamma)
77 |
78 | if torch.is_tensor(alpha) or alpha >= 0:
79 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
80 | loss = alpha_t * loss
81 | return loss
82 |
83 | def binary_focal_loss_probs(
84 | probs: torch.Tensor,
85 | targets: torch.Tensor,
86 | alpha: float = 0.25,
87 | gamma: float = 2,
88 | eps=1e-7) -> torch.Tensor:
89 | probs, targets = torch.broadcast_tensors(probs, targets)
90 | targets = targets.to(dtype=probs.dtype)
91 |
92 | p = probs
93 | with torch.autocast(device_type='cuda', enabled=False):
94 | ce_loss = F.binary_cross_entropy(probs.float().clamp(min=eps, max=1-eps), targets.float(), reduction="none")
95 | p_t = p * targets + (1 - p) * (1 - targets)
96 | loss = ce_loss * ((1 - p_t) ** gamma)
97 |
98 | if alpha >= 0:
99 | alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
100 | loss = alpha_t * loss
101 | return loss
102 |
103 |
104 | def autocompute_class_weights(targets: torch.Tensor, per_class: bool = False, cls_dim: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
105 | if per_class:
106 | assert cls_dim is not None
107 | dims_before = targets.shape[:cls_dim]
108 | C = targets.shape[cls_dim]
109 | dims_after = targets.shape[cls_dim+1:]
110 | # (... x C x ...) where C is the dim at index cls_dim
111 | targets = targets.view(prod(dims_before), C, prod(dims_after))
112 |
113 | N_pos = targets.sum(dim=0).sum(dim=-1) # (C)
114 | # (1 x ... x 1 x C x 1 x ...)
115 | N_pos = N_pos.view((1,) * len(dims_before) + (C,) + (1,) * len(dims_after))
116 | N = targets.numel() / C # ()
117 | else:
118 | # (...)
119 | targets = targets.view(-1)
120 | N_pos = targets.sum()
121 | N = targets.numel()
122 |
123 | N_neg = N - N_pos # (C) or ()
124 |
125 | weight_pos = (N + 1) / (N_pos + 1) # (C) or ()
126 | weight_neg = (N + 1) / (N_neg + 1) # (C) or ()
127 |
128 | return weight_pos, weight_neg
129 |
130 |
131 | def get_focal_loss(logits: bool = True, auto_weight: bool = False, **kwargs):
132 | focal_loss_fn = binary_focal_loss_logits if logits else binary_focal_loss_probs
133 | if auto_weight:
134 | def _loss_fn(preds, targets):
135 | preds, targets = torch.broadcast_tensors(preds, targets)
136 | weight_pos, weight_neg = autocompute_class_weights(targets, per_class=False)
137 | alpha = weight_pos / (weight_pos + weight_neg)
138 | return focal_loss_fn(preds, targets, alpha=alpha, **kwargs)
139 | else:
140 | _loss_fn = partial(focal_loss_fn, **kwargs)
141 | return _loss_fn
142 |
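143 |
144 | if __name__ == '__main__':
145 |     # Minimal sketch (editor's addition, not part of the original module): focal
146 |     # loss on random multilabel logits with automatic class weighting; the
147 |     # shapes and values below are illustrative.
148 |     logits = torch.randn(8, 14)                   # (N x C)
149 |     targets = (torch.rand(8, 14) > 0.9).float()   # sparse positive labels
150 |     loss_fn = get_focal_loss(logits=True, auto_weight=True, gamma=2)
151 |     print(loss_fn(logits, targets).mean())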
--------------------------------------------------------------------------------
/cheff/sr/schedule.py:
--------------------------------------------------------------------------------
1 | """Classes and functions for variance schedules."""
2 |
3 | from abc import ABC, abstractmethod
4 | from typing import Any, Optional
5 |
6 | import torch
7 | import torch.nn.functional as F
8 | from torch import Tensor
9 |
10 |
11 | class BaseSchedule(ABC):
12 | """Base class for deriving schedules."""
13 |
14 | def __init__(
15 | self, timesteps: int, device: Optional[torch.device] = None, *args, **kwargs
16 | ) -> None:
17 | """Initialize BaseSchedule."""
18 | self.timesteps = timesteps
19 | if device is None:
20 | self.device = torch.device('cpu')
21 | else:
22 | self.device = device
23 |
24 |         self.betas = self._get_betas(timesteps).to(self.device)
25 | self.alphas = 1.0 - self.betas
26 |
27 | self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
28 | self.alphas_cumprod_prev = F.pad(self.alphas_cumprod[:-1], (1, 0), value=1.0)
29 |
30 | self.sqrt_recip_alphas = torch.sqrt(1.0 / self.alphas)
31 | self.sqrt_alphas_cumprod = torch.sqrt(self.alphas_cumprod)
32 | self.sqrt_one_minus_alphas_cumprod = torch.sqrt(1.0 - self.alphas_cumprod)
33 |
34 | self.log_one_minus_alphas_cumprod = torch.log(1.0 - self.alphas_cumprod)
35 | self.sqrt_recip_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod)
36 | self.sqrt_recipm1_alphas_cumprod = torch.sqrt(1.0 / self.alphas_cumprod - 1)
37 |
38 | self.post_var = (
39 | self.betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
40 | )
41 |
42 | self.post_log_var_clipped = torch.log(
43 | torch.maximum(self.post_var, torch.tensor(1e-20))
44 | )
45 | self.post_mean_coef1 = (
46 | self.betas
47 | * torch.sqrt(self.alphas_cumprod_prev)
48 | / (1.0 - self.alphas_cumprod)
49 | )
50 | self.post_mean_coef2 = (
51 | (1.0 - self.alphas_cumprod_prev)
52 | * torch.sqrt(self.alphas)
53 | / (1.0 - self.alphas_cumprod)
54 | )
55 |
56 | @abstractmethod
57 | def _get_betas(self, timesteps: int) -> Tensor:
58 | """Get betas."""
59 | pass
60 |
61 |
62 | class LinearSchedule(BaseSchedule):
63 | """Linear variance schedule."""
64 |
65 | def __init__(
66 | self,
67 | timesteps: int,
68 | device: Optional[torch.device] = None,
69 | beta_start: float = 0.0001,
70 | beta_end: float = 0.02,
71 | *args,
72 | **kwargs
73 | ) -> None:
74 | """Initialize linear beta schedule."""
75 | self.beta_start = beta_start
76 | self.beta_end = beta_end
77 | super().__init__(timesteps, device, *args, **kwargs)
78 |
79 | def _get_betas(self, timesteps: int) -> Tensor:
80 | """Get betas."""
81 | return torch.linspace(self.beta_start, self.beta_end, timesteps)
82 |
83 |
84 | class CosineSchedule(BaseSchedule):
85 | """Cosine variance schedule."""
86 |
87 | def __init__(
88 | self,
89 | timesteps: int,
90 | device: Optional[torch.device] = None,
91 | s: float = 0.008,
92 | *args,
93 | **kwargs
94 | ) -> None:
95 | """Initialize cosine beta schedule."""
96 | self.s = s
97 | super().__init__(timesteps, device, *args, **kwargs)
98 |
99 | def _get_betas(self, timesteps: int) -> Tensor:
100 | """Get betas."""
101 | steps = timesteps + 1
102 | x = torch.linspace(0, timesteps, steps)
103 | alphas_cumprod = (
104 | torch.cos(((x / timesteps) + self.s) / (1 + self.s) * torch.pi * 0.5) ** 2
105 | )
106 | alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
107 | betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
108 | return torch.clip(betas, 0.0001, 0.9999)
109 |
110 |
111 | class QuadraticSchedule(BaseSchedule):
112 | """Quadratic variance schedule."""
113 |
114 | def __init__(
115 | self,
116 | timesteps: int,
117 | device: Optional[torch.device] = None,
118 | beta_start: float = 0.0001,
119 | beta_end: float = 0.02,
120 | *args,
121 | **kwargs
122 | ) -> None:
123 | """Initialize quadratic beta schedule."""
124 | self.beta_start = beta_start
125 | self.beta_end = beta_end
126 | super().__init__(timesteps, device, *args, **kwargs)
127 |
128 | def _get_betas(self, timesteps: int) -> Tensor:
129 | """Get betas."""
130 | return (
131 | torch.linspace(self.beta_start**0.5, self.beta_end**0.5, timesteps) ** 2
132 | )
133 |
134 |
135 | class SigmoidSchedule(BaseSchedule):
136 | """Sigmoid variance schedule."""
137 |
138 | def __init__(
139 | self,
140 | timesteps: int,
141 | device: Optional[torch.device] = None,
142 | beta_start: float = 0.0001,
143 | beta_end: float = 0.02,
144 | *args,
145 | **kwargs
146 | ) -> None:
147 | """Initialize sigmoid beta schedule."""
148 | self.beta_start = beta_start
149 | self.beta_end = beta_end
150 | super().__init__(timesteps, device, *args, **kwargs)
151 |
152 | def _get_betas(self, timesteps: int) -> Tensor:
153 | """Get betas."""
154 | betas = torch.linspace(-6, 6, timesteps)
155 | return (
156 | torch.sigmoid(betas) * (self.beta_end - self.beta_start) + self.beta_start
157 | )
158 |
159 |
160 | class ScheduleFactory:
161 | """Factory wrapper for variance schedules."""
162 |
163 | @staticmethod
164 | def get_schedule(name: str, timesteps: int, *args, **kwargs) -> BaseSchedule:
165 | """Initialize a scheduler by name."""
166 | cls: Any
167 | if name == 'linear':
168 | cls = LinearSchedule
169 | elif name == 'cosine':
170 | cls = CosineSchedule
171 | elif name == 'quadratic':
172 | cls = QuadraticSchedule
173 | elif name == 'sigmoid':
174 | cls = SigmoidSchedule
175 | else:
176 | raise ValueError(
177 | 'There is no matching schedule for name "{}".'.format(name)
178 | )
179 |
180 | return cls(timesteps, *args, **kwargs)
181 |
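182 |
183 | if __name__ == '__main__':
184 |     # Minimal sketch (editor's addition, not part of the original module): build a
185 |     # cosine schedule and inspect a few derived quantities used by the sampler.
186 |     sched = ScheduleFactory.get_schedule('cosine', timesteps=1000)
187 |     print(sched.betas.shape, sched.alphas_cumprod[-1].item(), sched.post_var[:3])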
--------------------------------------------------------------------------------
/chexzero/chexzero/simple_tokenizer.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 OpenAI
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
24 | """
25 | import gzip
26 | import html
27 | import os
28 | from functools import lru_cache
29 |
30 | import ftfy
31 | import regex as re
32 |
33 |
34 | @lru_cache()
35 | def default_bpe():
36 | return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
37 |
38 |
39 | @lru_cache()
40 | def bytes_to_unicode():
41 | """
42 | Returns list of utf-8 byte and a corresponding list of unicode strings.
43 | The reversible bpe codes work on unicode strings.
44 | This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
45 | When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
46 | This is a signficant percentage of your normal, say, 32K bpe vocab.
47 | To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
48 | And avoids mapping to whitespace/control characters the bpe code barfs on.
49 | """
50 | bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
51 | cs = bs[:]
52 | n = 0
53 | for b in range(2**8):
54 | if b not in bs:
55 | bs.append(b)
56 | cs.append(2**8+n)
57 | n += 1
58 | cs = [chr(n) for n in cs]
59 | return dict(zip(bs, cs))
60 |
61 |
62 | def get_pairs(word):
63 | """Return set of symbol pairs in a word.
64 | Word is represented as tuple of symbols (symbols being variable-length strings).
65 | """
66 | pairs = set()
67 | prev_char = word[0]
68 | for char in word[1:]:
69 | pairs.add((prev_char, char))
70 | prev_char = char
71 | return pairs
72 |
73 |
74 | def basic_clean(text):
75 | text = ftfy.fix_text(text)
76 | text = html.unescape(html.unescape(text))
77 | return text.strip()
78 |
79 |
80 | def whitespace_clean(text):
81 | text = re.sub(r'\s+', ' ', text)
82 | text = text.strip()
83 | return text
84 |
85 |
86 | class SimpleTokenizer(object):
87 | def __init__(self, bpe_path: str = default_bpe()):
88 | self.byte_encoder = bytes_to_unicode()
89 | self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
90 | merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
91 | merges = merges[1:49152-256-2+1]
92 | merges = [tuple(merge.split()) for merge in merges]
93 | vocab = list(bytes_to_unicode().values())
94 |         vocab = vocab + [v+'</w>' for v in vocab]
95 | for merge in merges:
96 | vocab.append(''.join(merge))
97 | vocab.extend(['<|startoftext|>', '<|endoftext|>'])
98 | self.encoder = dict(zip(vocab, range(len(vocab))))
99 | self.decoder = {v: k for k, v in self.encoder.items()}
100 | self.bpe_ranks = dict(zip(merges, range(len(merges))))
101 | self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
102 | self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
103 |
104 | def bpe(self, token):
105 | if token in self.cache:
106 | return self.cache[token]
107 |         word = tuple(token[:-1]) + ( token[-1] + '</w>',)
108 | pairs = get_pairs(word)
109 |
110 | if not pairs:
111 |             return token+'</w>'
112 |
113 | while True:
114 | bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
115 | if bigram not in self.bpe_ranks:
116 | break
117 | first, second = bigram
118 | new_word = []
119 | i = 0
120 | while i < len(word):
121 | try:
122 | j = word.index(first, i)
123 | new_word.extend(word[i:j])
124 | i = j
125 | except:
126 | new_word.extend(word[i:])
127 | break
128 |
129 | if word[i] == first and i < len(word)-1 and word[i+1] == second:
130 | new_word.append(first+second)
131 | i += 2
132 | else:
133 | new_word.append(word[i])
134 | i += 1
135 | new_word = tuple(new_word)
136 | word = new_word
137 | if len(word) == 1:
138 | break
139 | else:
140 | pairs = get_pairs(word)
141 | word = ' '.join(word)
142 | self.cache[token] = word
143 | return word
144 |
145 | def encode(self, text):
146 | bpe_tokens = []
147 | text = whitespace_clean(basic_clean(text)).lower()
148 | for token in re.findall(self.pat, text):
149 | token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
150 | bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
151 | return bpe_tokens
152 |
153 | def decode(self, tokens):
154 | text = ''.join([self.decoder[token] for token in tokens])
155 |         text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
156 | return text
157 |
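158 |
159 | if __name__ == "__main__":
160 |     # Round-trip sketch (editor's addition, not part of the original OpenAI file):
161 |     # encode and decode a short report sentence with the bundled BPE vocab.
162 |     tokenizer = SimpleTokenizer()
163 |     ids = tokenizer.encode("mild cardiomegaly without pleural effusion")
164 |     print(ids)
165 |     print(tokenizer.decode(ids))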
--------------------------------------------------------------------------------
/models/cheff.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import os
3 | import torch
4 |
5 | from cheff import CheffLDMImageMaskCond
6 | from cheff import CheffLDMMaskCond
7 | from cheff import CheffLDMImageCond
8 | from cheff import CheffLDM
9 | from cheff.ldm.models.autoencoder import AutoencoderKL
10 | from cheff.ldm.models.diffusion.ddpm import LatentDiffusion
11 | from cheff.ldm.models.diffusion.ddim import DDIMSampler
12 | from cheff.ldm.modules.diffusionmodules.openaimodel import AttentionBlock
13 | from cheff.ldm.modules.diffusionmodules.openaimodel import ResBlock
14 | from src.utils import save_txt
15 |
16 |
17 | def get_cheff_ldm(
18 | args, load_checkpoint=True, load_base_uncond=False, load_image_mask_cond=False, load_image_cond=False
19 | ):
20 | print("Loading cheff ldm...")
21 | cheff_args = args.model.cheff
22 | model_path, ae_path = cheff_args.ldm_path, cheff_args.ae_path
23 |
24 | if load_base_uncond:
25 | print("Loading CheffLDM...")
26 | cheff_ldm = CheffLDM(model_path=model_path, ae_path=ae_path, device=args.device)
27 | return cheff_ldm
28 |
29 | if not load_image_mask_cond and not load_image_cond:
30 | print("Loading CheffLDMMaskCond...")
31 | cheff_ldm = CheffLDMMaskCond(
32 | model_path=model_path, ae_path=ae_path, device=args.device, load_checkpoint=load_checkpoint
33 | )
34 | elif load_image_mask_cond:
35 | print("Loading CheffLDMImageMaskCond...")
36 | cheff_ldm = CheffLDMImageMaskCond(
37 | model_path=model_path, ae_path=ae_path, device=args.device, load_checkpoint=load_checkpoint
38 | )
39 | elif load_image_cond:
40 | print("Loading CheffLDMImageCond...")
41 | cheff_ldm = CheffLDMImageCond(
42 | model_path=model_path, ae_path=ae_path, device=args.device, load_checkpoint=load_checkpoint
43 | )
44 | else:
45 | raise ValueError("Invalid configuration for Cheff LDM")
46 |
47 | print("Cheff LDM loaded: ", load_checkpoint)
48 |
49 | if cheff_args.load_external:
50 | print("Loading external model...")
51 | cheff_ldm.model.load_state_dict(torch.load(cheff_args.external_path, map_location=args.device))
52 |
53 | return cheff_ldm
54 |
55 |
56 | """
57 | The effect of applying the erasure objective (6) depends
58 | on the subset of parameters that is fine-tuned. The main
59 | distinction is between cross-attention parameters and noncross-attention parameters. Cross-attention parameters, illustrated in Figure 3a, serve as a gateway to the prompt, directly
60 | depending on the text of the prompt, while other parameters
61 | (Figure 3b) tend to contribute to a visual concept even if the
62 | concept is not mentioned in the prompt.
63 | Therefore we propose fine tuning the cross attentions,
64 | ESD-x, when the erasure is required to be controlled and
65 | specific to the prompt, such as when a named artistic style
66 | should be erased. Further, we propose fine tuning unconditional layers (non-cross-attention modules), ESD-u, when
67 | the erasure is required to be independent of the text in the
68 | prompt, such as when the global concept of NSFW nudity
69 | should be erased. We refer to cross-attention-only finetuning as ESD-x-η (where η refers to the strength of the
70 | negative guidance), and we refer to the configuration that
71 | tunes only non-cross-attention parameters as ESD-u-η. For
72 | simplicity, we write ESD-x and ESD-u when η = 1
73 | """
74 |
75 | all_train_methods = {
76 | "attn": [AttentionBlock],
77 | "res": [ResBlock],
78 | "attnres": [AttentionBlock, ResBlock],
79 | "all": None,
80 | "notime": None,
81 | }
82 |
83 |
84 | def get_trainable_params(model, args, specific_paths=["model.diffusion_model.input_blocks.0.0"], save_params=False):
85 | parameters = []
86 | parameter_names = []
87 |
88 | train_method = args.model.train.method.lower().strip()
89 |
90 | assert train_method in all_train_methods.keys(), "Unsupported train method"
91 | print("Using train method:", train_method)
92 |
93 | if train_method == "notime":
94 | # Go through all the parameters and only exclude the time embedding ones
95 | for name, param in model.named_parameters():
96 | if "time_embed" not in name and "first_stage_model" not in name:
97 | param.requires_grad = True
98 | parameters.append(param)
99 | parameter_names.append(name)
100 |
101 | elif train_method == "all":
102 | for name, param in model.named_parameters():
103 | if "first_stage_model" not in name:
104 | param.requires_grad = True
105 | parameters.append(param)
106 | parameter_names.append(name)
107 |
108 | else:
109 | train_filters = all_train_methods[train_method]
110 |
111 | def recurse(module, current_path=""):
112 | for name, child in module.named_children():
113 | path = f"{current_path}.{name}" if current_path else name
114 |
115 | if any(isinstance(child, klass) for klass in train_filters) or (
116 | specific_paths and path in specific_paths
117 | ):
118 | for param_name, param in child.named_parameters():
119 | param.requires_grad = True
120 | parameters.append(param)
121 | parameter_names.append(
122 | f"{path} -- {name}.{param_name} -- {child.__class__.__name__} -- {param_name}"
123 | )
124 | else:
125 | recurse(child, path)
126 |
127 | recurse(model)
128 |
129 | if save_params:
130 | save_txt(
131 | args.finetune.log_dir,
132 | "trainable_params",
133 | parameter_names,
134 | )
135 |
136 | # with open(os.path.join(args.out_dir, "non_trainable_params.txt"), "w") as f:
137 | # f.write("\n".join(non_trainable_parameter_names))
138 |
139 | print("Trainable parameters:", len(parameters))
140 |
141 | return parameters
142 |
143 |
144 | def clone_model_for_sample(original_model, args):
145 | cloned_model = copy.deepcopy(original_model)
146 |
147 | parameters = get_trainable_params(cloned_model.model, args, save_params=True)
148 |
149 | cloned_model.model.train()
150 | cloned_model.model.model.diffusion_model.train()
151 | return cloned_model, parameters
152 |
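153 |
154 | # Usage sketch (editor's addition, not part of the original module): selecting
155 | # attention-only parameters, analogous to ESD-x in the passage quoted above.
156 | # The field names on `args` are assumptions mirroring the accesses in this file.
157 | #   args.model.train.method = "attn"
158 | #   cheff_ldm = get_cheff_ldm(args)
159 | #   esd_model, params = clone_model_for_sample(cheff_ldm, args)
160 | #   optimizer = torch.optim.Adam(params, lr=args.finetune.lr)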
--------------------------------------------------------------------------------
/cheff/ldm/util.py:
--------------------------------------------------------------------------------
1 | import importlib
2 |
3 | import torch
4 | import numpy as np
5 | from collections import abc
6 | from einops import rearrange
7 | from functools import partial
8 |
9 | import multiprocessing as mp
10 | from threading import Thread
11 | from queue import Queue
12 |
13 | from inspect import isfunction
14 | from PIL import Image, ImageDraw, ImageFont
15 |
16 |
17 | def log_txt_as_img(wh, xc, size=10):
18 | # wh a tuple of (width, height)
19 | # xc a list of captions to plot
20 | b = len(xc)
21 | txts = list()
22 | for bi in range(b):
23 | txt = Image.new("RGB", wh, color="white")
24 | draw = ImageDraw.Draw(txt)
25 | font = ImageFont.truetype('data/DejaVuSans.ttf', size=size)
26 | nc = int(40 * (wh[0] / 256))
27 | lines = "\n".join(xc[bi][start:start + nc] for start in range(0, len(xc[bi]), nc))
28 |
29 | try:
30 | draw.text((0, 0), lines, fill="black", font=font)
31 | except UnicodeEncodeError:
32 |             print("Can't encode string for logging. Skipping.")
33 |
34 | txt = np.array(txt).transpose(2, 0, 1) / 127.5 - 1.0
35 | txts.append(txt)
36 | txts = np.stack(txts)
37 | txts = torch.tensor(txts)
38 | return txts
39 |
40 |
41 | def ismap(x):
42 | if not isinstance(x, torch.Tensor):
43 | return False
44 | return (len(x.shape) == 4) and (x.shape[1] > 3)
45 |
46 |
47 | def isimage(x):
48 | if not isinstance(x, torch.Tensor):
49 | return False
50 | return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
51 |
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
62 |
63 | def mean_flat(tensor):
64 | """
65 | https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/nn.py#L86
66 | Take the mean over all non-batch dimensions.
67 | """
68 | return tensor.mean(dim=list(range(1, len(tensor.shape))))
69 |
70 |
71 | def count_params(model, verbose=False):
72 | total_params = sum(p.numel() for p in model.parameters())
73 | if verbose:
74 | print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
75 | return total_params
76 |
77 |
78 | def instantiate_from_config(config):
79 | if not "target" in config:
80 | if config == '__is_first_stage__':
81 | return None
82 | elif config == "__is_unconditional__":
83 | return None
84 | raise KeyError("Expected key `target` to instantiate.")
85 | return get_obj_from_str(config["target"])(**config.get("params", dict()))
86 |
87 |
88 | def get_obj_from_str(string, reload=False):
89 | module, cls = string.rsplit(".", 1)
90 | if reload:
91 | module_imp = importlib.import_module(module)
92 | importlib.reload(module_imp)
93 | return getattr(importlib.import_module(module, package=None), cls)
94 |
95 |
96 | def _do_parallel_data_prefetch(func, Q, data, idx, idx_to_fn=False):
97 | # create dummy dataset instance
98 |
99 | # run prefetching
100 | if idx_to_fn:
101 | res = func(data, worker_id=idx)
102 | else:
103 | res = func(data)
104 | Q.put([idx, res])
105 | Q.put("Done")
106 |
107 |
108 | def parallel_data_prefetch(
109 | func: callable, data, n_proc, target_data_type="ndarray", cpu_intensive=True, use_worker_id=False
110 | ):
111 | # if target_data_type not in ["ndarray", "list"]:
112 | # raise ValueError(
113 | # "Data, which is passed to parallel_data_prefetch has to be either of type list or ndarray."
114 | # )
115 | if isinstance(data, np.ndarray) and target_data_type == "list":
116 | raise ValueError("list expected but function got ndarray.")
117 | elif isinstance(data, abc.Iterable):
118 | if isinstance(data, dict):
119 | print(
120 | f'WARNING:"data" argument passed to parallel_data_prefetch is a dict: Using only its values and disregarding keys.'
121 | )
122 | data = list(data.values())
123 | if target_data_type == "ndarray":
124 | data = np.asarray(data)
125 | else:
126 | data = list(data)
127 | else:
128 | raise TypeError(
129 | f"The data, that shall be processed parallel has to be either an np.ndarray or an Iterable, but is actually {type(data)}."
130 | )
131 |
132 | if cpu_intensive:
133 | Q = mp.Queue(1000)
134 | proc = mp.Process
135 | else:
136 | Q = Queue(1000)
137 | proc = Thread
138 | # spawn processes
139 | if target_data_type == "ndarray":
140 | arguments = [
141 | [func, Q, part, i, use_worker_id]
142 | for i, part in enumerate(np.array_split(data, n_proc))
143 | ]
144 | else:
145 | step = (
146 | int(len(data) / n_proc + 1)
147 | if len(data) % n_proc != 0
148 | else int(len(data) / n_proc)
149 | )
150 | arguments = [
151 | [func, Q, part, i, use_worker_id]
152 | for i, part in enumerate(
153 | [data[i: i + step] for i in range(0, len(data), step)]
154 | )
155 | ]
156 | processes = []
157 | for i in range(n_proc):
158 | p = proc(target=_do_parallel_data_prefetch, args=arguments[i])
159 | processes += [p]
160 |
161 | # start processes
162 | print(f"Start prefetching...")
163 | import time
164 |
165 | start = time.time()
166 | gather_res = [[] for _ in range(n_proc)]
167 | try:
168 | for p in processes:
169 | p.start()
170 |
171 | k = 0
172 | while k < n_proc:
173 | # get result
174 | res = Q.get()
175 | if res == "Done":
176 | k += 1
177 | else:
178 | gather_res[res[0]] = res[1]
179 |
180 | except Exception as e:
181 | print("Exception: ", e)
182 | for p in processes:
183 | p.terminate()
184 |
185 | raise e
186 | finally:
187 | for p in processes:
188 | p.join()
189 | print(f"Prefetching complete. [{time.time() - start} sec.]")
190 |
191 | if target_data_type == 'ndarray':
192 | if not isinstance(gather_res[0], np.ndarray):
193 | return np.concatenate([np.asarray(r) for r in gather_res], axis=0)
194 |
195 | # order outputs
196 | return np.concatenate(gather_res, axis=0)
197 | elif target_data_type == 'list':
198 | out = []
199 | for r in gather_res:
200 | out.extend(r)
201 | return out
202 | else:
203 | return gather_res
204 |
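205 |
206 | if __name__ == "__main__":
207 |     # Minimal sketch (editor's addition, not part of the original module):
208 |     # instantiate_from_config builds an object from a {"target", "params"} dict,
209 |     # the same mechanism used throughout the LDM configs.
210 |     layer = instantiate_from_config(
211 |         {"target": "torch.nn.Linear", "params": {"in_features": 8, "out_features": 2}}
212 |     )
213 |     count_params(layer, verbose=True)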
--------------------------------------------------------------------------------
/chexzero/img_encoder/chexzero_img_encoder.py:
--------------------------------------------------------------------------------
1 | from dataclasses import dataclass
2 | import logging
3 | import os
4 | from typing import Optional
5 | import einops
6 | from torch import Tensor
7 | from torch import nn
8 | import torch
9 | import torch.nn.functional as F
10 | from chexzero.model import VisualTransformer
11 | from chexzero.components.mlp import MLP
12 | from chexzero.img_encoder import ImageEncoderOutput
13 |
14 | from util.model_utils import BaseModel, BaseModelConfig, MainModelConfig
15 | import chexzero.clip
16 |
17 | log = logging.getLogger(__name__)
18 |
19 |
20 | @dataclass
21 | class ChexzeroImageEncoderConfig(BaseModelConfig):
22 | model_path: str = os.path.expanduser(
23 | "~/models/third_party/chexzero/CheXzero_Models/best_64_5e-05_original_22000_0.864.pt"
24 | )
25 | frozen_backbone: bool = False
26 | # freeze over full training, i.e. never unfreeze
27 | freeze_backbone_layers: Optional[int] = None
28 |
29 | add_cls_features: bool = False
30 | use_pretrained_projection: bool = True
31 |     # 0 = no additional projection, 1 = linear, 2 = one hidden layer
32 | additional_projection_layers: int = 0
33 | projection_bn: bool = False
34 | normalize_projected: bool = False
35 |
36 |
37 | class ChexzeroImageEncoder(BaseModel):
38 | CONFIG_CLS = ChexzeroImageEncoderConfig
39 | MODIFYABLE_CONFIGS = ("frozen_backbone",)
40 |
41 | def __init__(self, config: ChexzeroImageEncoderConfig, main_config: MainModelConfig):
42 | super().__init__(config)
43 | self.config: ChexzeroImageEncoderConfig
44 |
45 | model, _ = chexzero.clip.load("ViT-B/32", device="cpu", jit=False)
46 | model.load_state_dict(torch.load(self.config.model_path, map_location="cpu"))
47 | self.d = main_config.d_model
48 |
49 | self.backbone: VisualTransformer = model.visual
50 | d_backbone = self.backbone.output_dim if self.config.use_pretrained_projection else self.backbone.proj.shape[0]
51 | self.n_layers = len(self.backbone.transformer.resblocks) + 2
52 | if config.frozen_backbone:
53 | for param in self.backbone.parameters():
54 | param.requires_grad = False
55 | log.info("Freezing backbone for the whole training")
56 | self.n_frozen_layers = self.n_layers
57 | self.n_currently_frozen_layers = self.n_layers
58 | elif config.freeze_backbone_layers is not None and config.freeze_backbone_layers != 0:
59 | n_transformer_layers = len(self.backbone.transformer.resblocks)
60 | if config.freeze_backbone_layers < 0:
61 | config.freeze_backbone_layers = (n_transformer_layers + 2) + config.freeze_backbone_layers
62 | self.freeze_layers(config.freeze_backbone_layers)
63 | self.n_frozen_layers = config.freeze_backbone_layers
64 | self.n_currently_frozen_layers = config.freeze_backbone_layers
65 | log.info(f"Freezing {self.n_frozen_layers}/{self.n_layers} layers of backbone for the whole training")
66 |
67 | self.patch_projection = MLP(
68 | self.config.additional_projection_layers,
69 | d_in=d_backbone,
70 | d_out=self.d,
71 | use_bn=self.config.projection_bn,
72 | act=main_config.act,
73 | dropout=main_config.dropout,
74 | )
75 |
76 | def freeze_layers(self, n_frozen_layers: int):
77 | # first layer (embedding layer)
78 | emb_layer_requires_grad = n_frozen_layers == 0
79 | for param in self.backbone.conv1.parameters():
80 | param.requires_grad = emb_layer_requires_grad
81 | self.backbone.class_embedding.requires_grad = emb_layer_requires_grad
82 | self.backbone.positional_embedding.requires_grad = emb_layer_requires_grad
83 |         self.backbone.ln_pre.requires_grad_(emb_layer_requires_grad)  # Module.requires_grad_ reaches the LayerNorm parameters
84 |
85 | # transformer layers
86 | n_transformer_layers = len(self.backbone.transformer.resblocks)
87 | n_frozen_transformer_layers = n_frozen_layers - 1
88 | for i, resblock in enumerate(self.backbone.transformer.resblocks):
89 | layer_requires_grad = i >= n_frozen_transformer_layers
90 | for param in resblock.parameters():
91 | param.requires_grad = layer_requires_grad
92 |
93 | # last layer (projection layer)
94 |         proj_layer_requires_grad = n_frozen_layers <= n_transformer_layers + 1  # trainable unless all (emb + transformer + proj) layers are frozen
95 | for param in self.backbone.ln_post.parameters():
96 | param.requires_grad = proj_layer_requires_grad
97 | self.backbone.proj.requires_grad = proj_layer_requires_grad
98 |
99 | def forward(self, x: Tensor, **kwargs) -> ImageEncoderOutput:
100 |
101 | if x.ndim == 3:
102 | x = einops.repeat(x, "n h w -> n c h w", c=3)
103 | device = x.device
104 | dtype = x.dtype
105 | N, _, H, W = x.shape
106 | assert H == W, "Only square images are supported"
107 | input_resolution = H
108 | assert (
109 | self.backbone.input_resolution == input_resolution
110 | ), f"Input resolution of backbone ({self.backbone.input_resolution}) does not match input resolution of image ({input_resolution})"
111 |
112 | # Encode image using backbone
113 | with torch.set_grad_enabled(not self.config.frozen_backbone):
114 | x = self.backbone.conv1(x) # shape = [*, width, grid, grid]
115 | N, _, H, W = x.shape
116 | x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
117 | x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
118 | x = torch.cat(
119 | [
120 | self.backbone.class_embedding.to(x.dtype)
121 | + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
122 | x,
123 | ],
124 | dim=1,
125 | ) # shape = [*, grid ** 2 + 1, width]
126 | pos_emb = self.backbone.positional_embedding.to(x.dtype) # shape = [grid ** 2 + 1, width]
127 | x = x + pos_emb
128 | x = self.backbone.ln_pre(x)
129 |
130 | x = x.permute(1, 0, 2) # NLD -> LND
131 | x = self.backbone.transformer(x)
132 | x = x.permute(1, 0, 2) # LND -> NLD
133 |
134 | # (N x d)
135 | cls_features = self.backbone.ln_post(x[:, 0, :])
136 | if self.config.add_cls_features:
137 | patch_features = self.backbone.ln_post(x[:, 1:, :] + x[:, 0, :].unsqueeze(1))
138 | else:
139 | # (N x HW x d)
140 | patch_features = self.backbone.ln_post(x[:, 1:, :])
141 |
142 | if self.config.use_pretrained_projection:
143 | cls_features = cls_features @ self.backbone.proj
144 | patch_features = patch_features @ self.backbone.proj
145 | pos_embeddings = pos_emb[1:] @ self.backbone.proj
146 | else:
147 | pos_embeddings = pos_emb[1:]
148 |
149 | # (N, H, W, d_backbone)
150 | patch_features = einops.rearrange(patch_features, "n (h w) d -> n h w d", h=H, w=W)
151 |             # (HW x d) -> (N x H x W x d)
152 | pos_embeddings = einops.repeat(pos_embeddings, "(h w) d -> n h w d", h=H, w=W, n=N)
153 |
154 | # (N x H x W x d)
155 | projected_patch_features = self.patch_projection(patch_features)
156 | cls_features = self.patch_projection(cls_features.unsqueeze(1)).squeeze(1)
157 | if self.config.additional_projection_layers > 0:
158 | pos_embeddings = self.patch_projection(pos_embeddings)
159 |
160 | if self.config.normalize_projected:
161 | projected_patch_features = F.normalize(projected_patch_features, dim=-1)
162 | cls_features = F.normalize(cls_features, dim=-1)
163 |
164 | return ImageEncoderOutput(
165 | patch_features=projected_patch_features, pos_embeddings=pos_embeddings, global_features=cls_features
166 | )
167 |
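A minimal, editorial sketch (not part of the repository) of the layer-counting convention assumed by `freeze_layers` above: the backbone is treated as `n_transformer_layers + 2` layers (the embedding stage, the transformer blocks, and the final projection), and a negative `freeze_backbone_layers` counts from the end. The helper name `normalize_freeze_layers` is hypothetical.

```python
def normalize_freeze_layers(freeze_backbone_layers: int, n_transformer_layers: int) -> int:
    # hypothetical helper: mirrors the negative-index handling in the constructor above
    n_layers = n_transformer_layers + 2  # embedding + transformer blocks + projection
    if freeze_backbone_layers < 0:
        freeze_backbone_layers = n_layers + freeze_backbone_layers
    assert 0 <= freeze_backbone_layers <= n_layers
    return freeze_backbone_layers

# with 12 transformer blocks: -1 freezes everything except the projection layer
assert normalize_freeze_layers(-1, n_transformer_layers=12) == 13
assert normalize_freeze_layers(3, n_transformer_layers=12) == 3
```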
--------------------------------------------------------------------------------
/chexzero/util/plot_grounding.py:
--------------------------------------------------------------------------------
1 | from matplotlib import gridspec
2 | import matplotlib
3 | import matplotlib.pyplot as plt
4 | import matplotlib.colors as mcolors
5 | import numpy as np
6 | import torch.nn.functional as F
7 | import torch
8 | import textwrap as twp
9 |
10 | import wandb
11 |
12 | from util.plot_utils import plot_img_with_bounding_boxes
13 |
14 |
15 | def plot_grounding(model_output: 'ChEX', max_samples: int = 10):
16 | N = len(model_output.sample_id)
17 | if max_samples is None or max_samples > N:
18 | max_samples = N
19 |
20 | figs = []
21 | for i in range(max_samples):
22 | output_i = model_output[i]
23 | fig = plot_grounding_sample(output_i.x, output_i.sentences,
24 | output_i.encoded_img.patch_features,
25 | output_i.encoded_sentences.sentence_features, output_i.encoded_sentences.sentence_mask,
26 | output_i.grounding.boxes, output_i.grounding.multiboxes,
27 | output_i.grounding.box_features)
28 | figs.append(wandb.Image(fig))
29 | return figs
30 |
31 |
32 | def plot_grounding_sample(img, sentences,
33 | patch_features,
34 | sentence_features, sentence_mask,
35 | boxes, multiboxes, box_features):
36 | n_sentences = len(sentences)
37 | assert sentence_mask.sum() == n_sentences, f'Expected {n_sentences} sentences, got {sentence_mask.sum()}'
38 | img_shape = img.shape
39 | img = img.numpy()
40 |
41 | sentence_features = sentence_features[sentence_mask]
42 | box_features = box_features[sentence_mask]
43 | if multiboxes is not None:
44 | boxes = multiboxes[sentence_mask].numpy()
45 | S, R, _ = boxes.shape
46 | else:
47 | boxes = boxes[sentence_mask].numpy()
48 | S, _ = boxes.shape
49 |
50 | # Similarities and neighbors
51 | # (S x S_r)
52 | sent_region_l2_pairwise = torch.cdist(sentence_features, box_features, p=2)
53 | # (S)
54 | sentence_region_l2 = sent_region_l2_pairwise.diagonal()
55 | # (S x S_r)
56 | sent_region_cos_pairwise = torch.nn.functional.cosine_similarity(sentence_features.unsqueeze(1), box_features.unsqueeze(0), dim=-1)
57 | # (S)
58 | sentence_region_cos = sent_region_cos_pairwise.diagonal()
59 | # (S)
60 | region_rank = (sent_region_cos_pairwise > sentence_region_cos.unsqueeze(-1)).sum(dim=-1) + 1
61 | sentence_rank = (sent_region_cos_pairwise > sentence_region_cos.unsqueeze(0)).sum(dim=0) + 1
62 | # (S)
63 | sentence_region_neighbor = sent_region_cos_pairwise.argmax(dim=-1) + 1
64 | region_sentence_neighbor = sent_region_cos_pairwise.argmax(dim=0) + 1
65 |
66 | *shape_patch, d = patch_features.shape
67 | # (S x H*W)
68 | sentence_patch_cos_pairwise = torch.nn.functional.cosine_similarity(sentence_features.unsqueeze(1), patch_features.view(-1, d).unsqueeze(0), dim=-1)
69 | # (H*W)
70 | patch_sentence_neighbor = sentence_patch_cos_pairwise.argmax(dim=0) + 1
71 | patch_sentence_neighbor = patch_sentence_neighbor.view(*shape_patch)
72 | patch_sentence_neighbor_upsampled = F.interpolate(patch_sentence_neighbor[None, None, ...].float(), size=img_shape, mode='nearest-exact').squeeze()
73 | upsample_factor = np.array(img_shape) / np.array(shape_patch)
74 |
75 | # To numpy
76 | sentence_region_l2 = sentence_region_l2.numpy()
77 | sentence_region_cos = sentence_region_cos.numpy()
78 | sentence_rank = sentence_rank.numpy()
79 | region_rank = region_rank.numpy()
80 | region_sentence_neighbor = region_sentence_neighbor.numpy()
81 | sentence_region_neighbor = sentence_region_neighbor.numpy()
82 | patch_sentence_neighbor = patch_sentence_neighbor.numpy()
83 | patch_sentence_neighbor_upsampled = patch_sentence_neighbor_upsampled.numpy()
84 |
85 | # IDs and colors, text wrap
86 | sentence_ids = list(range(1, n_sentences + 1))
87 | sentence_colors = list(mcolors.TABLEAU_COLORS.values())[:n_sentences] if n_sentences <= 10 else matplotlib.color_sequences['tab20'][:n_sentences]
88 | sentences = [twp.fill(s, 70) for s in sentences]
89 | cmap = mcolors.LinearSegmentedColormap.from_list("cmap_name", sentence_colors, N=n_sentences)
90 |
91 | # Plotting...
92 | fig = plt.figure(figsize=(12, 6))
93 | gs = gridspec.GridSpec(nrows=2, ncols=3, height_ratios=[1, 1])
94 |
95 | ax_boxes = fig.add_subplot(gs[0, 0])
96 | ax_boxes.set_xticks([])
97 | ax_boxes.set_yticks([])
98 | sentence_ids_array = np.array(sentence_ids) # (S)
99 | if multiboxes is not None:
100 | # (S x R)
101 | sentence_ids_array = sentence_ids_array[:, None].repeat(R, axis=1)
102 | boxes_with_ids = np.concatenate([boxes, sentence_ids_array[:, None] - 1], axis=1) if multiboxes is None \
103 | else np.concatenate([boxes, sentence_ids_array[:, :, None] - 1], axis=2)
104 |
105 | if multiboxes is not None:
106 | boxes_with_ids = boxes_with_ids.reshape(-1, 5)
107 | plot_img_with_bounding_boxes(ax_boxes, class_names=sentence_ids,
108 | img=img, target_list=boxes_with_ids, plot_pred=False,
109 | class_cmap=sentence_colors)
110 |
111 | boxes_upper_left = (boxes_with_ids[:, :2] - boxes_with_ids[:, 2:4] / 2) * img_shape[::-1]
112 | boxes_lower_right = (boxes_with_ids[:, :2] + boxes_with_ids[:, 2:4] / 2) * img_shape[::-1]
113 | ax_boxes.set_xlim(min(0, boxes_upper_left[:, 0].min()), max(img_shape[1], boxes_lower_right[:, 0].max()))
114 | ax_boxes.set_ylim(max(img_shape[0], boxes_lower_right[:, 1].max()), min(0, boxes_upper_left[:, 1].min()))
115 |
116 | ax_patch_neighbors = fig.add_subplot(gs[0, 1])
117 | ax_patch_neighbors.imshow(img, cmap='gray')
118 | ax_patch_neighbors.imshow(patch_sentence_neighbor_upsampled, cmap=cmap, vmin=1, vmax=n_sentences, alpha=0.6)
119 | ax_patch_neighbors.set_xticks([])
120 | ax_patch_neighbors.set_yticks([])
121 | for y, neighbors in enumerate(patch_sentence_neighbor):
122 | y = (y + 0.5) * upsample_factor[0] - 1
123 | for x, neighbor in enumerate(neighbors):
124 | x = (x + 0.5) * upsample_factor[1] - 1
125 | ax_patch_neighbors.text(x, y, neighbor, ha='center', va='center', color='w')
126 |
127 | ax_sent_barplots = fig.add_subplot(gs[0, -1])
128 | sentence_ids = np.array(sentence_ids)
129 | bar_pos = np.linspace(-0.3, 0.3, n_sentences)
130 |
131 | ax_sent_barplots.bar(bar_pos, sentence_region_cos, color=sentence_colors, width=0.6 / n_sentences)
132 | ax_sent_barplots.set_xticks([0])
133 | ax_sent_barplots.set_xticklabels(['cos'])
134 |
135 | tab_data = [
136 | *zip(sentence_ids, sentences, sentence_rank, region_rank, sentence_region_l2.round(2), sentence_region_cos.round(2), region_sentence_neighbor, sentence_region_neighbor)
137 | ]
138 | tab_colors = [
139 | [col, col, col, col, col, col, sentence_colors[neighbor_s - 1], sentence_colors[neighbor_r - 1]]
140 | for col, neighbor_s, neighbor_r in zip(sentence_colors, region_sentence_neighbor, sentence_region_neighbor)
141 | ]
142 |
143 | ax_table = fig.add_subplot(gs[1, :])
144 | ax_table.axis('off')
145 | ax_table.axis('tight')
146 | table = ax_table.table(
147 | cellText=tab_data,
148 | cellLoc='center',
149 | colLoc='center',
150 | loc='center',
151 | cellColours=tab_colors,
152 | colLabels=['ID', 'Sentence', 'Rank s', 'Rank r', 'L2', 'cos', '1-NN s', '1-NN r'],
153 | colWidths=[0.04, 0.7, 0.05, 0.05, 0.05, 0.05, 0.05, 0.05]
154 | )
155 | table.auto_set_font_size(False)
156 | table.set_fontsize(10)
157 | table.scale(1, 1.5)
158 | n_cols = 8
159 |
160 | # Apply custom cell renderer to each cell
161 | for i in range(len(sentences) + 1):
162 | cell = table.get_celld()[(i, 1)]
163 | cell.set_text_props(horizontalalignment='left')
164 | cell.PAD = 0.01
165 | lines = len(cell.get_text().get_text().splitlines())
166 | for j in range(n_cols):
167 | cell = table.get_celld()[(i, j)]
168 | cell.set_height(lines * 0.11)
169 |
170 | return fig
171 |
172 |
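A small, editorial sketch (not part of the repository) of the similarity and rank computation used in `plot_grounding_sample` above: each sentence is paired with the region at the same index, and its rank counts how many regions are more similar than the matched one.

```python
import torch
import torch.nn.functional as F

S, d = 4, 32
sentence_features = torch.randn(S, d)   # toy sentence embeddings
box_features = torch.randn(S, d)        # toy region embeddings, paired by index

# (S x S) pairwise cosine similarity, as in plot_grounding_sample
pairwise_cos = F.cosine_similarity(sentence_features.unsqueeze(1), box_features.unsqueeze(0), dim=-1)
matched_cos = pairwise_cos.diagonal()   # (S,) similarity of each sentence to its own region
# rank of the matched region among all regions for each sentence (1 = closest)
region_rank = (pairwise_cos > matched_cos.unsqueeze(-1)).sum(dim=-1) + 1
print(region_rank)                      # e.g. tensor([1, 3, 1, 2])
```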
--------------------------------------------------------------------------------
/src/loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torchvision import transforms
5 |
6 | from chexzero.chexzero.clip import tokenize
7 | from models.clip import CLIPEvaluator
8 |
9 |
10 | class CLIPLoss(nn.Module):
11 | def __init__(self, args, pos_prompt=None, neg_prompt=None, clip_model=None, clip_preprocess=None):
12 | super().__init__()
13 | self.clip_model = CLIPEvaluator(args, model=clip_model, preprocess=clip_preprocess)
14 | self.pos_prompt = pos_prompt
15 | self.neg_prompt = neg_prompt
16 |
17 | # print(f"CLIPLoss: Positive prompt {pos_prompt}, Negative prompt {neg_prompt}")
18 |
19 | def forward_images(self, original, edited, preprocess_images):
20 | if preprocess_images:
21 | original = self.clip_model.clip_transform_for_tensor(original)
22 | edited = self.clip_model.clip_transform_for_tensor(edited)
23 |
24 | original_features = self.clip_model.model.encode_image(original)
25 | edited_features = self.clip_model.model.encode_image(edited)
26 |
27 | return original_features, edited_features
28 |
29 | def forward_prompts(self, pos_prompt=None, neg_prompt=None):
30 | if pos_prompt is None:
31 | pos_prompt = self.pos_prompt
32 | if neg_prompt is None:
33 | neg_prompt = self.neg_prompt
34 |
35 | pos_features = self.clip_model.encode_text(pos_prompt)
36 | neg_features = self.clip_model.encode_text(neg_prompt)
37 |
38 | return pos_features, neg_features
39 |
40 | def forward(self, generated_images, pos_prompt=None, neg_prompt=None, return_logits_per_image=False):
41 | if pos_prompt is None:
42 | pos_prompt = self.pos_prompt
43 | if neg_prompt is None:
44 | neg_prompt = self.neg_prompt
45 |
46 | # Get logits
47 | logits_per_image, _ = self.clip_model.score(generated_images, [pos_prompt, neg_prompt])
48 |
49 | if return_logits_per_image:
50 | return logits_per_image
51 |
52 | # Apply softmax to get probabilities
53 | probs = F.softmax(logits_per_image, dim=1)
54 |
55 | # Get probabilities for positive and negative prompts
56 | pos_prob = probs[:, 0] # Positive prompt is first
57 | neg_prob = probs[:, 1] # Negative prompt is second
58 |
59 | # print("Sample probabilities:")
60 | # print(probs)
61 |
62 | eps = 1e-6
63 | # neg_prob = torch.clamp(neg_prob, min=eps)
64 | # pos_prob = torch.clamp(pos_prob, min=eps)
65 |
66 | # loss = -torch.log(neg_prob + eps) + torch.log(pos_prob + eps)
67 | probs = torch.clamp(probs, min=eps)
68 | target = torch.zeros_like(probs)
69 | target[:, 1] = 1.0
70 | loss = F.binary_cross_entropy_with_logits(probs, target)
71 |
72 | # print("Sample loss:", loss, loss.mean(), loss.shape, loss.requires_grad, torch.any(torch.isnan(loss)))
73 |
74 | return loss, probs
75 |
76 |
77 | class DirectionLoss(nn.Module):
78 |
79 | def __init__(self, loss_type="mse"):
80 | super(DirectionLoss, self).__init__()
81 |
82 | self.loss_type = loss_type
83 |
84 | self.loss_func = {"mse": torch.nn.MSELoss, "cosine": torch.nn.CosineSimilarity, "mae": torch.nn.L1Loss}[
85 | loss_type
86 | ]()
87 |
88 | def forward(self, x, y):
89 | if self.loss_type == "cosine":
90 | return 1.0 - self.loss_func(x, y)
91 |
92 | return self.loss_func(x, y)
93 |
94 |
95 | class CLIPDirectionLoss(nn.Module):
96 |
97 | def __init__(self, args, pos_prompt=None, neg_prompt=None, clip_model=None, clip_preprocess=None):
98 | super().__init__()
99 |
100 | self.clip_model = CLIPEvaluator(args, model=clip_model, preprocess=clip_preprocess)
101 | self.pos_prompt = pos_prompt
102 | self.neg_prompt = neg_prompt
103 |
104 | self.device = args.device
105 |
106 | self.loss = DirectionLoss(loss_type="cosine")
107 |
108 | self.target_direction = self.compute_text_direction()
109 |
110 |         print(f"CLIPDirectionLoss: Positive prompt {pos_prompt}, Negative prompt {neg_prompt}")
111 |
112 | def encode_text(self, texts):
113 | tokenized_text = tokenize(texts).to(self.device)
114 | text_features = self.clip_model.model.encode_text(tokenized_text)
115 |
116 | return text_features
117 |
118 | def encode_image(self, images):
119 | images = self.clip_model.clip_transform_for_tensor(images)
120 |
121 | image_features = self.clip_model.model.encode_image(images)
122 |
123 | return image_features
124 |
125 | def get_text_features(self, texts):
126 | text_features = self.encode_text(texts).detach()
127 |
128 | text_features = text_features / text_features.norm(dim=-1, keepdim=True)
129 |
130 | return text_features
131 |
132 | def get_image_features(self, images):
133 | image_features = self.encode_image(images)
134 |
135 | image_features = image_features / image_features.norm(dim=-1, keepdim=True)
136 |
137 | return image_features
138 |
139 | def compute_text_direction(self, source_class=None, target_class=None) -> torch.Tensor:
140 | print("Computing text direction")
141 | if source_class is None:
142 | source_class = self.pos_prompt
143 |
144 | if target_class is None:
145 | target_class = self.neg_prompt
146 |
147 | text_features = self.get_text_features([source_class, target_class])
148 | source_features, target_features = text_features[0], text_features[1]
149 |
150 | text_direction = target_features - source_features # .mean(axis=0, keepdim=True)
151 | text_direction /= text_direction.norm(dim=-1, keepdim=True)
152 | print("Text direction", text_direction.shape)
153 |
154 | return text_direction
155 |
156 | def forward(self, src_images, target_images, pos_prompt=None, neg_prompt=None):
157 | if pos_prompt is None:
158 | pos_prompt = self.pos_prompt
159 | if neg_prompt is None:
160 | neg_prompt = self.neg_prompt
161 |
162 | if self.target_direction is None:
163 | self.target_direction = self.compute_text_direction()
164 |
165 | src_encoding = self.get_image_features(src_images)
166 | target_encoding = self.get_image_features(target_images)
167 |
168 | edit_direction = target_encoding - src_encoding
169 | edit_direction /= edit_direction.clone().norm(dim=-1, keepdim=True) + 1e-7
170 |
171 |         # also return the MSE between src_encoding and target_encoding; this distance should be maximized
172 |
173 | return self.loss(edit_direction, self.target_direction).mean(), F.mse_loss(src_encoding, target_encoding)
174 |
175 |
176 | def custom_cross_entropy(pred_logits, target_logits):
177 | pred_probs = F.softmax(pred_logits, dim=-1)
178 |
179 | # Create a mask for non-inf values in target_logits
180 | mask = ~torch.isinf(target_logits)
181 |
182 | # Normalize the non-inf part of target_logits
183 | target_probs = torch.where(mask, F.softmax(target_logits, dim=-1), torch.zeros_like(target_logits))
184 |
185 | # Compute cross-entropy loss only for non-inf positions
186 | loss = -torch.sum(target_probs * torch.log(pred_probs + 1e-8), dim=-1)
187 |
188 | # Add a large penalty for predicting non-zero probability where target is -inf
189 | inf_penalty = torch.where(~mask, pred_probs, torch.zeros_like(pred_probs)).sum(dim=-1) * 1e2
190 |
191 | return (loss + inf_penalty).mean()
192 |
193 |
194 | def kl_loss(pred_logits, target_logits):
195 | pred_log_softmax = F.log_softmax(pred_logits, dim=-1)
196 | target_softmax = F.softmax(target_logits, dim=-1)
197 |
198 | kl_loss = F.kl_div(pred_log_softmax, target_softmax, reduction="batchmean")
199 | return kl_loss
200 |
201 |
202 | def masked_mse_loss(pred, target, mask):
203 | """Take in the mask for the region to be inpainted"""
204 | if pred.shape[-1] != mask.shape[-1]:
205 | mask = F.interpolate(mask, size=(pred.shape[-1], pred.shape[-1]), mode="bilinear", align_corners=False)
206 |
207 | return torch.mean(((pred - target) * mask) ** 2)
208 |
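A brief, editorial usage sketch (not part of the repository) for the two standalone helpers above; the import path is an assumption and depends on how `src/` is placed on `sys.path`.

```python
import torch

from loss import kl_loss, masked_mse_loss  # assumed import path

pred_logits = torch.randn(8, 2)
target_logits = torch.randn(8, 2)
# batch-mean KL divergence KL(target || pred) between the two softmax distributions
print(kl_loss(pred_logits, target_logits))

# masked_mse_loss resizes the inpainting mask to the prediction resolution if needed
pred = torch.randn(1, 1, 64, 64)
target = torch.randn(1, 1, 64, 64)
mask = torch.ones(1, 1, 256, 256)  # mask given at full image resolution
print(masked_mse_loss(pred, target, mask))
```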
--------------------------------------------------------------------------------
/cheff/ldm/modules/encoders/modules.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from functools import partial
4 | import clip
5 | from einops import repeat
6 | import kornia
7 |
8 |
9 | from cheff.ldm.modules.x_transformer import Encoder, TransformerWrapper  # TODO: can we directly rely on lucidrains code and simply add this as a requirement? --> test
10 |
11 |
12 | class AbstractEncoder(nn.Module):
13 | def __init__(self):
14 | super().__init__()
15 |
16 | def encode(self, *args, **kwargs):
17 | raise NotImplementedError
18 |
19 |
20 |
21 | class ClassEmbedder(nn.Module):
22 | def __init__(self, embed_dim, n_classes=1000, key='class'):
23 | super().__init__()
24 | self.key = key
25 | self.embedding = nn.Embedding(n_classes, embed_dim)
26 |
27 | def forward(self, batch, key=None):
28 | if key is None:
29 | key = self.key
30 | # this is for use in crossattn
31 | c = batch[key][:, None]
32 | c = self.embedding(c)
33 | return c
34 |
35 |
36 | class MultiClassEmbedder(nn.Module):
37 | def __init__(self, embed_dim, n_classes=1000, key='class'):
38 | super().__init__()
39 | self.key = key
40 | self.embedding = nn.Sequential(
41 | nn.Linear(in_features=n_classes, out_features=embed_dim),
42 | nn.GELU(),
43 | nn.Linear(in_features=embed_dim, out_features=embed_dim),
44 | nn.GELU(),
45 | nn.Linear(in_features=embed_dim, out_features=embed_dim),
46 | )
47 |
48 | def forward(self, batch, key=None):
49 | if key is None:
50 | key = self.key
51 | c = batch[key]
52 | c = self.embedding(c)
53 | c = c.unsqueeze(1)
54 | return c
55 |
56 |
57 | class TransformerEmbedder(AbstractEncoder):
58 | """Some transformer encoder layers"""
59 | def __init__(self, n_embed, n_layer, vocab_size, max_seq_len=150, device="cuda"):
60 | super().__init__()
61 | self.device = device
62 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
63 | attn_layers=Encoder(dim=n_embed, depth=n_layer))
64 |
65 | def forward(self, tokens):
66 | tokens = tokens.to(self.device) # meh
67 | z = self.transformer(tokens, return_embeddings=True)
68 | return z
69 |
70 | def encode(self, x):
71 | return self(x)
72 |
73 |
74 | class BERTTokenizer(AbstractEncoder):
75 | """ Uses a pretrained BERT tokenizer by huggingface. Vocab size: 30522 (?)"""
76 | def __init__(self, device="cuda", vq_interface=True, max_length=150):
77 | super().__init__()
78 |         from transformers import BertTokenizerFast  # TODO: add to requirements
79 | self.tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
80 | self.device = device
81 | self.vq_interface = vq_interface
82 | self.max_length = max_length
83 |
84 | def forward(self, text):
85 | batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
86 | return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
87 | tokens = batch_encoding["input_ids"].to(self.device)
88 | return tokens
89 |
90 | @torch.no_grad()
91 | def encode(self, text):
92 | tokens = self(text)
93 | if not self.vq_interface:
94 | return tokens
95 | return None, None, [None, None, tokens]
96 |
97 | def decode(self, text):
98 | return text
99 |
100 |
101 | class BERTEmbedder(AbstractEncoder):
102 |     """Uses the BERT tokenizer and adds some transformer encoder layers"""
103 | def __init__(self, n_embed, n_layer, vocab_size=30522, max_seq_len=150,
104 | device="cuda",use_tokenizer=True, embedding_dropout=0.0):
105 | super().__init__()
106 | self.use_tknz_fn = use_tokenizer
107 | if self.use_tknz_fn:
108 | self.tknz_fn = BERTTokenizer(vq_interface=False, max_length=max_seq_len, device=device)
109 | self.device = device
110 | self.transformer = TransformerWrapper(num_tokens=vocab_size, max_seq_len=max_seq_len,
111 | attn_layers=Encoder(dim=n_embed, depth=n_layer),
112 | emb_dropout=embedding_dropout)
113 |
114 | def forward(self, text):
115 | if self.use_tknz_fn:
116 | tokens = self.tknz_fn(text)#.to(self.device)
117 | else:
118 | tokens = text
119 | z = self.transformer(tokens, return_embeddings=True)
120 | return z
121 |
122 | def encode(self, text):
123 |         # output of length max_seq_len
124 | return self(text)
125 |
126 |
127 | class SpatialRescaler(nn.Module):
128 | def __init__(self,
129 | n_stages=1,
130 | method='bilinear',
131 | multiplier=0.5,
132 | in_channels=3,
133 | out_channels=None,
134 | bias=False):
135 | super().__init__()
136 | self.n_stages = n_stages
137 | assert self.n_stages >= 0
138 | assert method in ['nearest','linear','bilinear','trilinear','bicubic','area']
139 | self.multiplier = multiplier
140 | self.interpolator = partial(torch.nn.functional.interpolate, mode=method)
141 | self.remap_output = out_channels is not None
142 | if self.remap_output:
143 | print(f'Spatial Rescaler mapping from {in_channels} to {out_channels} channels after resizing.')
144 | self.channel_mapper = nn.Conv2d(in_channels,out_channels,1,bias=bias)
145 |
146 | def forward(self,x):
147 | for stage in range(self.n_stages):
148 | x = self.interpolator(x, scale_factor=self.multiplier)
149 |
150 |
151 | if self.remap_output:
152 | x = self.channel_mapper(x)
153 | return x
154 |
155 | def encode(self, x):
156 | return self(x)
157 |
158 |
159 | class FrozenCLIPTextEmbedder(nn.Module):
160 | """
161 | Uses the CLIP transformer encoder for text.
162 | """
163 | def __init__(self, version='ViT-L/14', device="cuda", max_length=77, n_repeat=1, normalize=True):
164 | super().__init__()
165 | self.model, _ = clip.load(version, jit=False, device="cpu")
166 | self.device = device
167 | self.max_length = max_length
168 | self.n_repeat = n_repeat
169 | self.normalize = normalize
170 |
171 | def freeze(self):
172 | self.model = self.model.eval()
173 | for param in self.parameters():
174 | param.requires_grad = False
175 |
176 | def forward(self, text):
177 | tokens = clip.tokenize(text).to(self.device)
178 | z = self.model.encode_text(tokens)
179 | if self.normalize:
180 | z = z / torch.linalg.norm(z, dim=1, keepdim=True)
181 | return z
182 |
183 | def encode(self, text):
184 | z = self(text)
185 | if z.ndim==2:
186 | z = z[:, None, :]
187 | z = repeat(z, 'b 1 d -> b k d', k=self.n_repeat)
188 | return z
189 |
190 |
191 | class FrozenClipImageEmbedder(nn.Module):
192 | """
193 | Uses the CLIP image encoder.
194 | """
195 | def __init__(
196 | self,
197 | model,
198 | jit=False,
199 | device='cuda' if torch.cuda.is_available() else 'cpu',
200 | antialias=False,
201 | ):
202 | super().__init__()
203 | self.model, _ = clip.load(name=model, device=device, jit=jit)
204 |
205 | self.antialias = antialias
206 |
207 | self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
208 | self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
209 |
210 | def preprocess(self, x):
211 |         # resize to CLIP resolution, then map from [-1,1] to [0,1]
212 | x = kornia.geometry.resize(x, (224, 224),
213 | interpolation='bicubic',align_corners=True,
214 | antialias=self.antialias)
215 | x = (x + 1.) / 2.
216 | # renormalize according to clip
217 | x = kornia.enhance.normalize(x, self.mean, self.std)
218 | return x
219 |
220 | def forward(self, x):
221 | # x is assumed to be in range [-1,1]
222 | return self.model.encode_image(self.preprocess(x))
223 |
224 |
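An editorial usage sketch (not part of the repository) for the class-conditioning embedders above; the import path is an assumption.

```python
import torch

from cheff.ldm.modules.encoders.modules import ClassEmbedder, MultiClassEmbedder  # assumed path

# single integer class id per sample -> one conditioning token
emb = ClassEmbedder(embed_dim=512, n_classes=1000)
print(emb({"class": torch.tensor([3, 7])}).shape)        # torch.Size([2, 1, 512])

# multi-label vector per sample (e.g. 14 findings) -> one conditioning token
memb = MultiClassEmbedder(embed_dim=512, n_classes=14)
print(memb({"class": torch.rand(2, 14)}).shape)          # torch.Size([2, 1, 512])
```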
--------------------------------------------------------------------------------
/cheff/ldm/modules/losses/vqperceptual.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | import torch.nn.functional as F
4 | from einops import repeat
5 |
6 | from taming.modules.discriminator.model import NLayerDiscriminator, weights_init
7 | from taming.modules.losses.lpips import LPIPS
8 | from taming.modules.losses.vqperceptual import hinge_d_loss, vanilla_d_loss
9 |
10 |
11 | def hinge_d_loss_with_exemplar_weights(logits_real, logits_fake, weights):
12 | assert weights.shape[0] == logits_real.shape[0] == logits_fake.shape[0]
13 | loss_real = torch.mean(F.relu(1. - logits_real), dim=[1,2,3])
14 | loss_fake = torch.mean(F.relu(1. + logits_fake), dim=[1,2,3])
15 | loss_real = (weights * loss_real).sum() / weights.sum()
16 | loss_fake = (weights * loss_fake).sum() / weights.sum()
17 | d_loss = 0.5 * (loss_real + loss_fake)
18 | return d_loss
19 |
20 | def adopt_weight(weight, global_step, threshold=0, value=0.):
21 | if global_step < threshold:
22 | weight = value
23 | return weight
24 |
25 |
26 | def measure_perplexity(predicted_indices, n_embed):
27 | # src: https://github.com/karpathy/deep-vector-quantization/blob/main/model.py
28 | # eval cluster perplexity. when perplexity == num_embeddings then all clusters are used exactly equally
29 | encodings = F.one_hot(predicted_indices, n_embed).float().reshape(-1, n_embed)
30 | avg_probs = encodings.mean(0)
31 | perplexity = (-(avg_probs * torch.log(avg_probs + 1e-10)).sum()).exp()
32 | cluster_use = torch.sum(avg_probs > 0)
33 | return perplexity, cluster_use
34 |
35 | def l1(x, y):
36 | return torch.abs(x-y)
37 |
38 |
39 | def l2(x, y):
40 | return torch.pow((x-y), 2)
41 |
42 |
43 | class VQLPIPSWithDiscriminator(nn.Module):
44 | def __init__(self, disc_start, codebook_weight=1.0, pixelloss_weight=1.0,
45 | disc_num_layers=3, disc_in_channels=3, disc_factor=1.0, disc_weight=1.0,
46 | perceptual_weight=1.0, use_actnorm=False, disc_conditional=False,
47 | disc_ndf=64, disc_loss="hinge", n_classes=None, perceptual_loss="lpips",
48 | pixel_loss="l1"):
49 | super().__init__()
50 | assert disc_loss in ["hinge", "vanilla"]
51 | assert perceptual_loss in ["lpips", "clips", "dists"]
52 | assert pixel_loss in ["l1", "l2"]
53 | self.codebook_weight = codebook_weight
54 | self.pixel_weight = pixelloss_weight
55 | if perceptual_loss == "lpips":
56 | print(f"{self.__class__.__name__}: Running with LPIPS.")
57 | self.perceptual_loss = LPIPS().eval()
58 | else:
59 | raise ValueError(f"Unknown perceptual loss: >> {perceptual_loss} <<")
60 | self.perceptual_weight = perceptual_weight
61 |
62 | if pixel_loss == "l1":
63 | self.pixel_loss = l1
64 | else:
65 | self.pixel_loss = l2
66 |
67 | self.discriminator = NLayerDiscriminator(input_nc=disc_in_channels,
68 | n_layers=disc_num_layers,
69 | use_actnorm=use_actnorm,
70 | ndf=disc_ndf
71 | ).apply(weights_init)
72 | self.discriminator_iter_start = disc_start
73 | if disc_loss == "hinge":
74 | self.disc_loss = hinge_d_loss
75 | elif disc_loss == "vanilla":
76 | self.disc_loss = vanilla_d_loss
77 | else:
78 | raise ValueError(f"Unknown GAN loss '{disc_loss}'.")
79 | print(f"VQLPIPSWithDiscriminator running with {disc_loss} loss.")
80 | self.disc_factor = disc_factor
81 | self.discriminator_weight = disc_weight
82 | self.disc_conditional = disc_conditional
83 | self.n_classes = n_classes
84 |
85 | def calculate_adaptive_weight(self, nll_loss, g_loss, last_layer=None):
86 | if last_layer is not None:
87 | nll_grads = torch.autograd.grad(nll_loss, last_layer, retain_graph=True)[0]
88 | g_grads = torch.autograd.grad(g_loss, last_layer, retain_graph=True)[0]
89 | else:
90 | nll_grads = torch.autograd.grad(nll_loss, self.last_layer[0], retain_graph=True)[0]
91 | g_grads = torch.autograd.grad(g_loss, self.last_layer[0], retain_graph=True)[0]
92 |
93 | d_weight = torch.norm(nll_grads) / (torch.norm(g_grads) + 1e-4)
94 | d_weight = torch.clamp(d_weight, 0.0, 1e4).detach()
95 | d_weight = d_weight * self.discriminator_weight
96 | return d_weight
97 |
98 | def forward(self, codebook_loss, inputs, reconstructions, optimizer_idx,
99 | global_step, last_layer=None, cond=None, split="train", predicted_indices=None):
100 |         if codebook_loss is None:
101 | codebook_loss = torch.tensor([0.]).to(inputs.device)
102 | #rec_loss = torch.abs(inputs.contiguous() - reconstructions.contiguous())
103 | rec_loss = self.pixel_loss(inputs.contiguous(), reconstructions.contiguous())
104 | if self.perceptual_weight > 0:
105 | p_loss = self.perceptual_loss(inputs.contiguous(), reconstructions.contiguous())
106 | rec_loss = rec_loss + self.perceptual_weight * p_loss
107 | else:
108 | p_loss = torch.tensor([0.0])
109 |
110 | nll_loss = rec_loss
111 | #nll_loss = torch.sum(nll_loss) / nll_loss.shape[0]
112 | nll_loss = torch.mean(nll_loss)
113 |
114 | # now the GAN part
115 | if optimizer_idx == 0:
116 | # generator update
117 | if cond is None:
118 | assert not self.disc_conditional
119 | logits_fake = self.discriminator(reconstructions.contiguous())
120 | else:
121 | assert self.disc_conditional
122 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous(), cond), dim=1))
123 | g_loss = -torch.mean(logits_fake)
124 |
125 | try:
126 | d_weight = self.calculate_adaptive_weight(nll_loss, g_loss, last_layer=last_layer)
127 | except RuntimeError:
128 | assert not self.training
129 | d_weight = torch.tensor(0.0)
130 |
131 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
132 | loss = nll_loss + d_weight * disc_factor * g_loss + self.codebook_weight * codebook_loss.mean()
133 |
134 | log = {"{}/total_loss".format(split): loss.clone().detach().mean(),
135 | "{}/quant_loss".format(split): codebook_loss.detach().mean(),
136 | "{}/nll_loss".format(split): nll_loss.detach().mean(),
137 | "{}/rec_loss".format(split): rec_loss.detach().mean(),
138 | "{}/p_loss".format(split): p_loss.detach().mean(),
139 | "{}/d_weight".format(split): d_weight.detach(),
140 | "{}/disc_factor".format(split): torch.tensor(disc_factor),
141 | "{}/g_loss".format(split): g_loss.detach().mean(),
142 | }
143 | if predicted_indices is not None:
144 | assert self.n_classes is not None
145 | with torch.no_grad():
146 | perplexity, cluster_usage = measure_perplexity(predicted_indices, self.n_classes)
147 | log[f"{split}/perplexity"] = perplexity
148 | log[f"{split}/cluster_usage"] = cluster_usage
149 | return loss, log
150 |
151 | if optimizer_idx == 1:
152 | # second pass for discriminator update
153 | if cond is None:
154 | logits_real = self.discriminator(inputs.contiguous().detach())
155 | logits_fake = self.discriminator(reconstructions.contiguous().detach())
156 | else:
157 | logits_real = self.discriminator(torch.cat((inputs.contiguous().detach(), cond), dim=1))
158 | logits_fake = self.discriminator(torch.cat((reconstructions.contiguous().detach(), cond), dim=1))
159 |
160 | disc_factor = adopt_weight(self.disc_factor, global_step, threshold=self.discriminator_iter_start)
161 | d_loss = disc_factor * self.disc_loss(logits_real, logits_fake)
162 |
163 | log = {"{}/disc_loss".format(split): d_loss.clone().detach().mean(),
164 | "{}/logits_real".format(split): logits_real.detach().mean(),
165 | "{}/logits_fake".format(split): logits_fake.detach().mean()
166 | }
167 | return d_loss, log
168 |
--------------------------------------------------------------------------------
/chexzero/chexzero/eval.py:
--------------------------------------------------------------------------------
1 | import subprocess
2 | import numpy as np
3 | import os
4 | import pandas as pd
5 | from PIL import Image
6 | import matplotlib.pyplot as plt
7 | from typing import List, Callable
8 |
9 | import torch
10 | from torch.utils import data
11 | from tqdm.notebook import tqdm
12 | import torch.nn as nn
13 | from torchvision.transforms import Compose, Normalize, Resize
14 |
15 | import sklearn
16 | from sklearn.metrics import matthews_corrcoef, confusion_matrix, accuracy_score, auc, roc_auc_score, roc_curve, classification_report
17 | from sklearn.metrics import precision_recall_curve, f1_score
18 | from sklearn.metrics import average_precision_score
19 | from sklearn.utils import resample
20 |
21 | import scipy
22 | import scipy.stats
23 |
24 | import sys
25 | sys.path.append('../..')
26 |
27 | import chexzero.clip
28 | from chexzero.model import CLIP
29 |
30 | def compute_mean(stats, is_df=True):
31 | spec_labels = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]
32 | if is_df:
33 | spec_df = stats[spec_labels]
34 | res = np.mean(spec_df.iloc[0])
35 | else:
36 | # cis is df, within bootstrap
37 | vals = [stats[spec_label][0] for spec_label in spec_labels]
38 | res = np.mean(vals)
39 | return res
40 |
41 | def accuracy(output, target, topk=(1,)):
42 | pred = output.topk(max(topk), 1, True, True)[1].t()
43 | print('pred: ', pred)
44 |
45 | expand = target.expand(-1, max(topk))
46 | print('expand: ', expand)
47 |
48 | correct = pred.eq(expand)
49 | print('correct: ', correct)
50 | return [float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) for k in topk]
51 |
52 | def sigmoid(x):
53 | z = 1/(1 + np.exp(-x))
54 | return z
55 |
56 | ''' ROC CURVE '''
57 | def plot_roc(y_pred, y_true, roc_name, plot=False):
58 | # given the test_ground_truth, and test_predictions
59 | fpr, tpr, thresholds = roc_curve(y_true, y_pred)
60 |
61 | roc_auc = auc(fpr, tpr)
62 |
63 | if plot:
64 | plt.figure(dpi=100)
65 | plt.title(roc_name)
66 | plt.plot(fpr, tpr, 'b', label = 'AUC = %0.2f' % roc_auc)
67 | plt.legend(loc = 'lower right')
68 | plt.plot([0, 1], [0, 1],'r--')
69 | plt.xlim([0, 1])
70 | plt.ylim([0, 1])
71 | plt.ylabel('True Positive Rate')
72 | plt.xlabel('False Positive Rate')
73 | plt.show()
74 | return fpr, tpr, thresholds, roc_auc
75 |
76 | # J = TP/(TP+FN) + TN/(TN+FP) - 1 = tpr - fpr
77 | def choose_operating_point(fpr, tpr, thresholds):
78 | sens = 0
79 | spec = 0
80 | J = 0
81 | for _fpr, _tpr in zip(fpr, tpr):
82 | if _tpr - _fpr > J:
83 | sens = _tpr
84 | spec = 1-_fpr
85 | J = _tpr - _fpr
86 | return sens, spec
87 |
88 | ''' PRECISION-RECALL CURVE '''
89 | def plot_pr(y_pred, y_true, pr_name, plot=False):
90 | precision, recall, thresholds = precision_recall_curve(y_true, y_pred)
91 | pr_auc = auc(recall, precision)
92 | # plot the precision-recall curves
93 | baseline = len(y_true[y_true==1]) / len(y_true)
94 |
95 | if plot:
96 | plt.figure(dpi=20)
97 | plt.title(pr_name)
98 | plt.plot(recall, precision, 'b', label='AUC = %0.2f' % pr_auc)
99 | # axis labels
100 | plt.legend(loc = 'lower right')
101 | plt.plot([0, 1], [baseline, baseline],'r--')
102 | plt.xlim([0, 1])
103 | plt.ylim([0, 1])
104 | plt.xlabel('Recall')
105 | plt.ylabel('Precision')
106 | # show the plot
107 | plt.show()
108 | return precision, recall, thresholds
109 |
110 | def evaluate(y_pred, y_true, cxr_labels,
111 | roc_name='Receiver Operating Characteristic', pr_name='Precision-Recall Curve', label_idx_map=None):
112 |
113 | '''
114 | We expect `y_pred` and `y_true` to be numpy arrays, both of shape (num_samples, num_classes)
115 |
116 | `y_pred` is a numpy array consisting of probability scores with all values in range 0-1.
117 |
118 | `y_true` is a numpy array consisting of binary values representing if a class is present in
119 | the cxr.
120 |
121 | This function provides all relevant evaluation information, ROC, AUROC, Sensitivity, Specificity,
122 | PR-Curve, Precision, Recall for each class.
123 | '''
124 | import warnings
125 | warnings.filterwarnings('ignore')
126 |
127 | num_classes = y_pred.shape[-1] # number of total labels
128 |
129 | dataframes = []
130 | for i in range(num_classes):
131 | # print('{}.'.format(cxr_labels[i]))
132 |
133 | if label_idx_map is None:
134 | y_pred_i = y_pred[:, i] # (num_samples,)
135 | y_true_i = y_true[:, i] # (num_samples,)
136 |
137 | else:
138 | y_pred_i = y_pred[:, i] # (num_samples,)
139 |
140 | true_index = label_idx_map[cxr_labels[i]]
141 | y_true_i = y_true[:, true_index] # (num_samples,)
142 |
143 | cxr_label = cxr_labels[i]
144 |
145 | ''' ROC CURVE '''
146 | roc_name = cxr_label + ' ROC Curve'
147 | fpr, tpr, thresholds, roc_auc = plot_roc(y_pred_i, y_true_i, roc_name)
148 |
149 | sens, spec = choose_operating_point(fpr, tpr, thresholds)
150 |
151 | results = [[roc_auc]]
152 | df = pd.DataFrame(results, columns=[cxr_label+'_auc'])
153 | dataframes.append(df)
154 |
155 | ''' PRECISION-RECALL CURVE '''
156 | pr_name = cxr_label + ' Precision-Recall Curve'
157 | precision, recall, thresholds = plot_pr(y_pred_i, y_true_i, pr_name)
158 |
159 | dfs = pd.concat(dataframes, axis=1)
160 | return dfs
161 |
162 | ''' Bootstrap and Confidence Intervals '''
163 | def compute_cis(data, confidence_level=0.05):
164 | """
165 | FUNCTION: compute_cis
166 | ------------------------------------------------------
167 | Given a Pandas dataframe of (n, labels), return another
168 | Pandas dataframe that is (3, labels).
169 |
170 |     The rows are the mean, lower bound, and upper bound of the
171 |     confidence interval implied by `confidence_level`.
172 |
173 | Args:
174 | * data - Pandas Dataframe, of shape (num_bootstrap_samples, num_labels)
175 |         * confidence_level (optional) - significance level alpha of the interval (0.05 gives a 95% CI)
176 |
177 | Returns:
178 | * Pandas Dataframe, of shape (3, labels), representing mean, lower, upper
179 | """
180 | data_columns = list(data)
181 | intervals = []
182 | for i in data_columns:
183 | series = data[i]
184 | sorted_perfs = series.sort_values()
185 | lower_index = int(confidence_level/2 * len(sorted_perfs)) - 1
186 | upper_index = int((1 - confidence_level/2) * len(sorted_perfs)) - 1
187 | lower = sorted_perfs.iloc[lower_index].round(4)
188 | upper = sorted_perfs.iloc[upper_index].round(4)
189 | mean = round(sorted_perfs.mean(), 4)
190 | interval = pd.DataFrame({i : [mean, lower, upper]})
191 | intervals.append(interval)
192 | intervals_df = pd.concat(intervals, axis=1)
193 | intervals_df.index = ['mean', 'lower', 'upper']
194 | return intervals_df
195 |
196 | def bootstrap(y_pred, y_true, cxr_labels, n_samples=1000, label_idx_map=None):
197 | '''
198 | This function will randomly sample with replacement
199 | from y_pred and y_true then evaluate `n` times
200 | and obtain AUROC scores for each.
201 |
202 | You can specify the number of samples that should be
203 | used with the `n_samples` parameter.
204 |
205 | Confidence intervals will be generated from each
206 | of the samples.
207 |
208 | Note:
209 | * n_total_labels >= n_cxr_labels
210 | `n_total_labels` is greater iff alternative labels are being tested
211 | '''
212 | np.random.seed(97)
213 |     # y_pred: (500, n_total_labels)
214 |     # y_true: (500, n_cxr_labels)
215 |
216 | idx = np.arange(len(y_true))
217 |
218 | boot_stats = []
219 | for i in tqdm(range(n_samples)):
220 | sample = resample(idx, replace=True, random_state=i)
221 | y_pred_sample = y_pred[sample]
222 | y_true_sample = y_true[sample]
223 |
224 | sample_stats = evaluate(y_pred_sample, y_true_sample, cxr_labels, label_idx_map=label_idx_map)
225 | boot_stats.append(sample_stats)
226 |
227 | boot_stats = pd.concat(boot_stats) # pandas array of evaluations for each sample
228 | return boot_stats, compute_cis(boot_stats)
229 |
--------------------------------------------------------------------------------
/chexzero/chexzero/clip.py:
--------------------------------------------------------------------------------
1 | """
2 | MIT License
3 |
4 | Copyright (c) 2021 OpenAI
5 |
6 | Permission is hereby granted, free of charge, to any person obtaining a copy
7 | of this software and associated documentation files (the "Software"), to deal
8 | in the Software without restriction, including without limitation the rights
9 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10 | copies of the Software, and to permit persons to whom the Software is
11 | furnished to do so, subject to the following conditions:
12 |
13 | The above copyright notice and this permission notice shall be included in all
14 | copies or substantial portions of the Software.
15 |
16 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 | SOFTWARE.
23 |
24 | """
25 |
26 | import hashlib
27 | import logging
28 | import os
29 | import urllib
30 | import warnings
31 | from typing import Union, List
32 |
33 | import torch
34 | from PIL import Image
35 | from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
36 | from tqdm import tqdm
37 |
38 | from chexzero.chexzero.model import build_model
39 | from chexzero.chexzero.simple_tokenizer import SimpleTokenizer as _Tokenizer
40 |
41 |
42 | __all__ = ["available_models", "load", "tokenize"]
43 | _tokenizer = _Tokenizer()
44 |
45 | _MODELS = {
46 | "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
47 | "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
48 | "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
49 | "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
50 | }
51 |
52 | log = logging.getLogger(__name__)
53 |
54 |
55 | def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
56 | os.makedirs(root, exist_ok=True)
57 | filename = os.path.basename(url)
58 |
59 | expected_sha256 = url.split("/")[-2]
60 | download_target = os.path.join(root, filename)
61 |
62 | if os.path.exists(download_target) and not os.path.isfile(download_target):
63 | raise RuntimeError(f"{download_target} exists and is not a regular file")
64 |
65 | if os.path.isfile(download_target):
66 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
67 | return download_target
68 | else:
69 | warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
70 |
71 | with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
72 | with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit="iB", unit_scale=True) as loop:
73 | while True:
74 | buffer = source.read(8192)
75 | if not buffer:
76 | break
77 |
78 | output.write(buffer)
79 | loop.update(len(buffer))
80 |
81 | if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
82 |         raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
83 |
84 | return download_target
85 |
86 |
87 | def _transform(n_px):
88 | return Compose(
89 | [
90 | Resize(n_px, interpolation=Image.BICUBIC),
91 | CenterCrop(n_px),
92 | lambda image: image.convert("RGB"),
93 | ToTensor(),
94 | Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
95 | ]
96 | )
97 |
98 |
99 | def available_models() -> List[str]:
100 | """Returns the names of available CLIP models"""
101 | return list(_MODELS.keys())
102 |
103 |
104 | def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
105 | """Load a CLIP model
106 |
107 | Parameters
108 | ----------
109 | name : str
110 | A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
111 |
112 | device : Union[str, torch.device]
113 | The device to put the loaded model
114 |
115 | jit : bool
116 | Whether to load the optimized JIT model (default) or more hackable non-JIT model.
117 |
118 | Returns
119 | -------
120 | model : torch.nn.Module
121 | The CLIP model
122 |
123 | preprocess : Callable[[PIL.Image], torch.Tensor]
124 | A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
125 | """
126 | if name in _MODELS:
127 | model_path = _download(_MODELS[name])
128 | elif os.path.isfile(name):
129 | model_path = name
130 | else:
131 | raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
132 |
133 | try:
134 | # loading JIT archive
135 | model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
136 | state_dict = None
137 | except RuntimeError:
138 | # loading saved state dict
139 | if jit:
140 | warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
141 | jit = False
142 | state_dict = torch.load(model_path, map_location="cpu")
143 |
144 | if not jit:
145 | model = build_model(state_dict or model.state_dict()).to(device)
146 | if str(device) == "cpu":
147 | model.float()
148 | return model, _transform(model.visual.input_resolution)
149 |
150 | # patch the device names
151 | device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
152 | device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
153 |
154 | def patch_device(module):
155 | graphs = [module.graph] if hasattr(module, "graph") else []
156 | if hasattr(module, "forward1"):
157 | graphs.append(module.forward1.graph)
158 |
159 | for graph in graphs:
160 | for node in graph.findAllNodes("prim::Constant"):
161 | if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
162 | node.copyAttributes(device_node)
163 |
164 | model.apply(patch_device)
165 | patch_device(model.encode_image)
166 | patch_device(model.encode_text)
167 |
168 | # patch dtype to float32 on CPU
169 | if str(device) == "cpu":
170 | float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
171 | float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
172 | float_node = float_input.node()
173 |
174 | def patch_float(module):
175 | graphs = [module.graph] if hasattr(module, "graph") else []
176 | if hasattr(module, "forward1"):
177 | graphs.append(module.forward1.graph)
178 |
179 | for graph in graphs:
180 | for node in graph.findAllNodes("aten::to"):
181 | inputs = list(node.inputs())
182 | for i in [1, 2]: # dtype can be the second or third argument to aten::to()
183 | if inputs[i].node()["value"] == 5:
184 | inputs[i].node().copyAttributes(float_node)
185 |
186 | model.apply(patch_float)
187 | patch_float(model.encode_image)
188 | patch_float(model.encode_text)
189 |
190 | model.float()
191 |
192 | return model, _transform(model.input_resolution.item())
193 |
194 |
195 | def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor:
196 | """
197 | Returns the tokenized representation of given input string(s)
198 |
199 | Parameters
200 | ----------
201 | texts : Union[str, List[str]]
202 | An input string or a list of input strings to tokenize
203 |
204 | context_length : int
205 | The context length to use; all CLIP models use 77 as the context length
206 |
207 | Returns
208 | -------
209 | A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length]
210 | """
211 | if isinstance(texts, str):
212 | texts = [texts]
213 |
214 | sot_token = _tokenizer.encoder["<|startoftext|>"]
215 | eot_token = _tokenizer.encoder["<|endoftext|>"]
216 | all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts]
217 | result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
218 |
219 | for i, tokens in enumerate(all_tokens):
220 | if len(tokens) > context_length:
221 | tokens = tokens[: context_length - 1] + [tokens[-1]]
222 | log.warning(f"Input {texts[i]} is too long for context length {context_length}")
223 | # raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
224 | result[i, : len(tokens)] = torch.tensor(tokens)
225 |
226 | return result
227 |
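An editorial usage sketch (not part of the repository) for `tokenize`; it assumes the function is in scope (e.g. `from chexzero.chexzero.clip import tokenize`).

```python
texts = ["no acute cardiopulmonary process", "small right pleural effusion"]
tokens = tokenize(texts, context_length=77)
print(tokens.shape, tokens.dtype)  # torch.Size([2, 77]) torch.int64
# each row starts with <|startoftext|>, the prompt ends with <|endoftext|>, and is zero-padded after
```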
--------------------------------------------------------------------------------
/cheff/ldm/modules/attention.py:
--------------------------------------------------------------------------------
1 | from inspect import isfunction
2 | import math
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, einsum
6 | from einops import rearrange, repeat
7 |
8 | from cheff.ldm.modules.diffusionmodules.util import checkpoint
9 |
10 |
11 | def exists(val):
12 | return val is not None
13 |
14 |
15 | def uniq(arr):
16 |     return {el: True for el in arr}.keys()
17 |
18 |
19 | def default(val, d):
20 | if exists(val):
21 | return val
22 | return d() if isfunction(d) else d
23 |
24 |
25 | def max_neg_value(t):
26 | return -torch.finfo(t.dtype).max
27 |
28 |
29 | def init_(tensor):
30 | dim = tensor.shape[-1]
31 | std = 1 / math.sqrt(dim)
32 | tensor.uniform_(-std, std)
33 | return tensor
34 |
35 |
36 | # feedforward
37 | class GEGLU(nn.Module):
38 | def __init__(self, dim_in, dim_out):
39 | super().__init__()
40 | self.proj = nn.Linear(dim_in, dim_out * 2)
41 |
42 | def forward(self, x):
43 | x, gate = self.proj(x).chunk(2, dim=-1)
44 | return x * F.gelu(gate)
45 |
46 |
47 | class FeedForward(nn.Module):
48 | def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.):
49 | super().__init__()
50 | inner_dim = int(dim * mult)
51 | dim_out = default(dim_out, dim)
52 | project_in = nn.Sequential(
53 | nn.Linear(dim, inner_dim),
54 | nn.GELU()
55 | ) if not glu else GEGLU(dim, inner_dim)
56 |
57 | self.net = nn.Sequential(
58 | project_in,
59 | nn.Dropout(dropout),
60 | nn.Linear(inner_dim, dim_out)
61 | )
62 |
63 | def forward(self, x):
64 | return self.net(x)
65 |
66 |
67 | def zero_module(module):
68 | """
69 | Zero out the parameters of a module and return it.
70 | """
71 | for p in module.parameters():
72 | p.detach().zero_()
73 | return module
74 |
75 |
76 | def Normalize(in_channels):
77 | return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
78 |
79 |
80 | class LinearAttention(nn.Module):
81 | def __init__(self, dim, heads=4, dim_head=32):
82 | super().__init__()
83 | self.heads = heads
84 | hidden_dim = dim_head * heads
85 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False)
86 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
87 |
88 | def forward(self, x):
89 | b, c, h, w = x.shape
90 | qkv = self.to_qkv(x)
91 | q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
92 | k = k.softmax(dim=-1)
93 | context = torch.einsum('bhdn,bhen->bhde', k, v)
94 | out = torch.einsum('bhde,bhdn->bhen', context, q)
95 | out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
96 | return self.to_out(out)
97 |
98 |
99 | class SpatialSelfAttention(nn.Module):
100 | def __init__(self, in_channels):
101 | super().__init__()
102 | self.in_channels = in_channels
103 |
104 | self.norm = Normalize(in_channels)
105 | self.q = torch.nn.Conv2d(in_channels,
106 | in_channels,
107 | kernel_size=1,
108 | stride=1,
109 | padding=0)
110 | self.k = torch.nn.Conv2d(in_channels,
111 | in_channels,
112 | kernel_size=1,
113 | stride=1,
114 | padding=0)
115 | self.v = torch.nn.Conv2d(in_channels,
116 | in_channels,
117 | kernel_size=1,
118 | stride=1,
119 | padding=0)
120 | self.proj_out = torch.nn.Conv2d(in_channels,
121 | in_channels,
122 | kernel_size=1,
123 | stride=1,
124 | padding=0)
125 |
126 | def forward(self, x):
127 | h_ = x
128 | h_ = self.norm(h_)
129 | q = self.q(h_)
130 | k = self.k(h_)
131 | v = self.v(h_)
132 |
133 | # compute attention
134 | b,c,h,w = q.shape
135 | q = rearrange(q, 'b c h w -> b (h w) c')
136 | k = rearrange(k, 'b c h w -> b c (h w)')
137 | w_ = torch.einsum('bij,bjk->bik', q, k)
138 |
139 | w_ = w_ * (int(c)**(-0.5))
140 | w_ = torch.nn.functional.softmax(w_, dim=2)
141 |
142 | # attend to values
143 | v = rearrange(v, 'b c h w -> b c (h w)')
144 | w_ = rearrange(w_, 'b i j -> b j i')
145 | h_ = torch.einsum('bij,bjk->bik', v, w_)
146 | h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h)
147 | h_ = self.proj_out(h_)
148 |
149 | return x+h_
150 |
151 |
152 | class CrossAttention(nn.Module):
153 | def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
154 | super().__init__()
155 | inner_dim = dim_head * heads
156 | context_dim = default(context_dim, query_dim)
157 |
158 | self.scale = dim_head ** -0.5
159 | self.heads = heads
160 |
161 | self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
162 | self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
163 | self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
164 |
165 | self.to_out = nn.Sequential(
166 | nn.Linear(inner_dim, query_dim),
167 | nn.Dropout(dropout)
168 | )
169 |
170 | def forward(self, x, context=None, mask=None):
171 | h = self.heads
172 |
173 | q = self.to_q(x)
174 | context = default(context, x)
175 | k = self.to_k(context)
176 | v = self.to_v(context)
177 |
178 | q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
179 |
180 | sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
181 |
182 | if exists(mask):
183 | mask = rearrange(mask, 'b ... -> b (...)')
184 | max_neg_value = -torch.finfo(sim.dtype).max
185 | mask = repeat(mask, 'b j -> (b h) () j', h=h)
186 | sim.masked_fill_(~mask, max_neg_value)
187 |
188 | # attention, what we cannot get enough of
189 | attn = sim.softmax(dim=-1)
190 |
191 | out = einsum('b i j, b j d -> b i d', attn, v)
192 | out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
193 | return self.to_out(out)
194 |
195 |
196 | class BasicTransformerBlock(nn.Module):
197 | def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True):
198 | super().__init__()
199 | self.attn1 = CrossAttention(query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout) # is a self-attention
200 | self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
201 | self.attn2 = CrossAttention(query_dim=dim, context_dim=context_dim,
202 | heads=n_heads, dim_head=d_head, dropout=dropout) # is self-attn if context is none
203 | self.norm1 = nn.LayerNorm(dim)
204 | self.norm2 = nn.LayerNorm(dim)
205 | self.norm3 = nn.LayerNorm(dim)
206 | self.checkpoint = checkpoint
207 |
208 | def forward(self, x, context=None):
209 | return checkpoint(self._forward, (x, context), self.parameters(), self.checkpoint)
210 |
211 | def _forward(self, x, context=None):
212 | x = self.attn1(self.norm1(x)) + x
213 | x = self.attn2(self.norm2(x), context=context) + x
214 | x = self.ff(self.norm3(x)) + x
215 | return x
216 |
217 |
218 | class SpatialTransformer(nn.Module):
219 | """
220 | Transformer block for image-like data.
221 | First, project the input (aka embedding)
222 | and reshape to b, t, d.
223 | Then apply standard transformer action.
224 | Finally, reshape to image
225 | """
226 | def __init__(self, in_channels, n_heads, d_head,
227 | depth=1, dropout=0., context_dim=None):
228 | super().__init__()
229 | self.in_channels = in_channels
230 | inner_dim = n_heads * d_head
231 | self.norm = Normalize(in_channels)
232 |
233 | self.proj_in = nn.Conv2d(in_channels,
234 | inner_dim,
235 | kernel_size=1,
236 | stride=1,
237 | padding=0)
238 |
239 | self.transformer_blocks = nn.ModuleList(
240 | [BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
241 | for d in range(depth)]
242 | )
243 |
244 | self.proj_out = zero_module(nn.Conv2d(inner_dim,
245 | in_channels,
246 | kernel_size=1,
247 | stride=1,
248 | padding=0))
249 |
250 | def forward(self, x, context=None):
251 | # note: if no context is given, cross-attention defaults to self-attention
252 | b, c, h, w = x.shape
253 | x_in = x
254 | x = self.norm(x)
255 | x = self.proj_in(x)
256 | x = rearrange(x, 'b c h w -> b (h w) c')
257 | for block in self.transformer_blocks:
258 | x = block(x, context=context)
259 | x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
260 | x = self.proj_out(x)
261 | return x + x_in
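An editorial usage sketch (not part of the repository) for `CrossAttention` above, with a text-like context; with `context=None` the same module performs plain self-attention over `x`.

```python
import torch

attn = CrossAttention(query_dim=64, context_dim=512, heads=4, dim_head=16)
x = torch.randn(2, 256, 64)    # e.g. a 16x16 feature map flattened to 256 tokens
ctx = torch.randn(2, 77, 512)  # e.g. 77 text tokens from a frozen text encoder
out = attn(x, context=ctx)
print(out.shape)               # torch.Size([2, 256, 64])
```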
--------------------------------------------------------------------------------
/chexzero/chexzero/README.md:
--------------------------------------------------------------------------------
1 | # Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning
2 |
3 |
4 |
5 | Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning, Nat. Biomed. Eng (2022).
6 | [Paper]
7 |
Ekin Tiu, Ellie Talius, Pujan Patel, Curtis P. Langlotz, Andrew Y. Ng, Pranav Rajpurkar
8 |
9 |
10 | ```bash
11 | Tiu, E., Talius, E., Patel, P. et al. Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning. Nat. Biomed. Eng (2022). https://doi.org/10.1038/s41551-022-00936-9
12 | ```
13 |
14 |
15 |
16 |
17 | This repository contains code to train a self-supervised learning model on chest X-ray images that lack explicit annotations and to evaluate this model's performance on pathology-classification tasks.
18 |
19 |
20 |
21 | ## Main Findings
22 |
23 |
24 | 1. **Automatically detecting pathologies in chest x-rays without explicit annotations:** Our method learns directly from the combination of images and unstructured radiology reports, thereby avoiding time-consuming labeling efforts. Our deep learning method is capable of predicting multiple pathologies and differential diagnoses that it had not explicitly seen during training.
25 | 2. **Matching radiologist performance on different tasks on an external test set:** Our method performed on par with human performance when evaluated on an external validation set (CheXpert) of chest x-ray images labeled for the presence of 14 different conditions by multiple radiologists.
26 | 3. **Outperforming approaches that train on explicitly labeled data on an external test set:** Using no labels, we outperformed a fully supervised approach (100% of labels) on 3 out of the 8 selected pathologies on a dataset (PadChest) collected in a different country. We further demonstrated high performance (AUC > 0.9) on 14 findings and an AUC of at least 0.700 on 53 of the 107 radiographic findings that the method had not seen during training.
27 |
28 |
29 |
30 | ## Dependencies
31 | To clone all files:
32 |
33 | ```git clone https://github.com/rajpurkarlab/CheXzero.git```
34 |
35 | To install Python dependencies:
36 |
37 | ```pip install -r requirements.txt```
38 |
39 | ## Data
40 | ### Training Dataset
41 | 1. Download images from [MIMIC-CXR-JPG](https://physionet.org/content/mimic-cxr-jpg/2.0.0/) and reports from the [MIMIC-CXR Database](https://physionet.org/content/mimic-cxr/2.0.0/). Note: in order to gain access to the data, you must be a credentialed user as defined on [PhysioNet](https://physionet.org/settings/credentialing/).
42 | 2. Copy the dataset into the `data/` directory.
43 | 3. Run `python run_preprocess.py`
44 | 4. This preprocesses the chest x-ray images into a Hierarchical Data Format (HDF5) file used for training, stored at `data/cxr.h5`, and extracts the impressions section from the corresponding chest x-ray radiology reports, stored at `data/mimic_impressions.csv`. A short sanity-check sketch follows below.
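To sanity-check the preprocessed outputs before training, something like the following can be used. This is a minimal sketch, not part of the original pipeline; it assumes `h5py` and `pandas` are installed and does not assume a particular dataset name inside the `.h5` file:

```python
import h5py
import pandas as pd

# Inspect the preprocessed image archive without assuming a dataset name.
with h5py.File("data/cxr.h5", "r") as f:
    keys = list(f.keys())
    print("datasets:", keys)
    images = f[keys[0]]                  # first dataset, whatever it is called
    print("images:", images.shape, images.dtype)

# Inspect the extracted impression sections.
impressions = pd.read_csv("data/mimic_impressions.csv")
print("columns:", impressions.columns.tolist())
print(impressions.head())
```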
45 |
46 | ### Evaluation Dataset
47 |
48 | #### CheXpert Dataset
49 | The CheXpert dataset consists of chest radiographic examinations from Stanford Hospital, performed between October 2002
50 | and July 2017 in both inpatient and outpatient centers. Population-level characteristics are unavailable for the CheXpert test
51 | dataset, as they are used for official evaluation on the CheXpert leaderboard.
52 |
53 | The main data (CheXpert data) supporting the results of this study are available at https://aimi.stanford.edu/chexpert-chest-x-rays.
54 |
55 | The CheXpert **test** dataset has recently been made public, and can be found by following the steps in the [cheXpert-test-set-labels](https://github.com/rajpurkarlab/cheXpert-test-set-labels) repository.
56 |
57 | #### PadChest Dataset
58 | The PadChest dataset contains chest X-rays that were interpreted by 18 radiologists at the Hospital Universitario de San Juan,
59 | Alicante, Spain, from January 2009 to December 2017. The dataset contains 109,931 image studies and 168,861 images.
60 | PadChest also contains 206,222 study reports.
61 |
62 | The [PadChest](https://arxiv.org/abs/1901.07441) dataset is publicly available at https://bimcv.cipf.es/bimcv-projects/padchest. Those who would like to use PadChest for experimentation should request access via the same [link](https://bimcv.cipf.es/bimcv-projects/padchest).
63 |
64 | ### Model Checkpoints
65 | Model checkpoints of CheXzero pre-trained on MIMIC-CXR are publicly available at the following [link](https://drive.google.com/drive/folders/1makFLiEMbSleYltaRxw81aBhEDMpVwno?usp=sharing). Download files and save them in the `./checkpoints/chexzero_weights` directory.
66 |
67 | ## Running Training
68 | Run the following command to perform CheXzero pretraining.
69 | ```bash
70 | python run_train.py --cxr_filepath "./data/cxr.h5" --txt_filepath "data/mimic_impressions.csv"
71 | ```
72 |
73 | ### Arguments
74 | * `--cxr_filepath` Path to the `.h5` file to load chest x-ray image data from.
75 | * `--txt_filepath` Path to the `.csv` file to load radiology report impressions text from.
76 |
77 | Use the `-h` flag to see all optional arguments.
78 |
79 | ## Zero-Shot Inference
80 | See the following [notebook](https://github.com/rajpurkarlab/CheXzero/blob/main/notebooks/zero_shot.ipynb) for an example of how to use CheXzero to perform zero-shot inference on a chest x-ray dataset. The example shows how to output predictions from the model ensemble and evaluate performance of the model if ground truth labels are available.
81 |
82 | ```python
83 | import zero_shot
84 |
85 | # compute predictions for a set of images; returns an np array of probabilities for each pathology
86 | predictions, y_pred_avg = zero_shot.ensemble_models(
87 | model_paths=model_paths,
88 | cxr_filepath=cxr_filepath,
89 | cxr_labels=cxr_labels,
90 | cxr_pair_template=cxr_pair_template,
91 | cache_dir=cache_dir,
92 | )
93 | ```
94 | ### Arguments
95 | * `model_paths: List[str]`: List of paths to all checkpoints to be used in the ensemble. To run on a single model, input a list containing a single path.
96 | * `cxr_filepath: str`: Path to images `.h5` file
97 | * `cxr_labels: List[str]`: List of pathologies to query in each image
98 | * `cxr_pair_template: Tuple[str, str]`: contrasting templates used to query the model (see Figure 1 in the article for a visual explanation).
99 | * `cache_dir: str`: Directory in which to cache the predictions of each checkpoint; use it to avoid recomputing predictions.
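A minimal sketch of how these arguments might be populated before calling `zero_shot.ensemble_models`; the checkpoint file pattern, pathology list, and prompt templates below are illustrative assumptions rather than prescribed values:

```python
from pathlib import Path

# Illustrative values only; adjust paths and pathologies to your setup.
model_dir = Path("checkpoints/chexzero_weights")
model_paths = [str(p) for p in sorted(model_dir.glob("*.pt"))]   # one entry per checkpoint

cxr_filepath = "data/chexpert_test.h5"      # images packed into a single .h5 file
cxr_labels = ["Atelectasis", "Cardiomegaly", "Consolidation", "Edema", "Pleural Effusion"]

# Contrasting prompt pair: positive template first, negative template second.
cxr_pair_template = ("{}", "no {}")

cache_dir = "cache/"                        # per-checkpoint predictions are cached here
```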
100 |
101 | In order to use CheXzero for zero-shot inference, ensure the following requirements are met:
102 | * All input *`images`* must be stored in a single `.h5` file (Hierarchical Data Format). See the [`img_to_h5`](https://github.com/rajpurkarlab/CheXzero/blob/main/preprocess_padchest.py#L156) function in [preprocess_padchest.py](https://github.com/rajpurkarlab/internal-chexzero/blob/cleanversion/preprocess_padchest.py) for an example of how to convert a list of paths to `.png` files into a valid `.h5` file (a simplified sketch also follows this list).
103 | * The *ground truth `labels`* must be in a `.csv` dataframe where each row represents an image sample and each column contains the binary label for a particular pathology.
104 | * Ensure all [model checkpoints](https://drive.google.com/drive/folders/1makFLiEMbSleYltaRxw81aBhEDMpVwno?usp=sharing) are stored in `checkpoints/chexzero_weights/`, or the `model_dir` that is specified in the notebook.
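The following is a simplified stand-in for the linked `img_to_h5` function, showing the general idea of packing a list of `.png` paths into one `.h5` file. It assumes `h5py`, `numpy`, and `Pillow` are available; the dataset name `"cxr"` and the 320x320 resolution are illustrative assumptions, not the values used by the original script:

```python
import h5py
import numpy as np
from PIL import Image

def pngs_to_h5(png_paths, out_path="data/cxr.h5", size=320):
    """Pack grayscale chest x-rays into one HDF5 dataset (illustrative sketch only)."""
    with h5py.File(out_path, "w") as f:
        dset = f.create_dataset("cxr", shape=(len(png_paths), size, size), dtype="float32")
        for i, path in enumerate(png_paths):
            img = Image.open(path).convert("L").resize((size, size))  # grayscale, fixed size
            dset[i] = np.asarray(img, dtype=np.float32)
```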
105 |
106 | ## Evaluation
107 | Given a numpy array of predictions (obtained from zero-shot inference) and a numpy array of ground truth labels, one can evaluate the performance of the model using the following code:
108 | ```python
109 | import zero_shot
110 | import eval
111 |
112 | # loads in ground truth labels into memory
113 | test_pred = y_pred_avg
114 | test_true = zero_shot.make_true_labels(cxr_true_labels_path=cxr_true_labels_path, cxr_labels=cxr_labels)
115 |
116 | # evaluate model, no bootstrap
117 | cxr_results: pd.DataFrame = eval.evaluate(test_pred, test_true, cxr_labels) # eval on full test dataset
118 |
119 | # bootstrap evaluations for 95% confidence intervals
120 | bootstrap_results: Tuple[pd.DataFrame, pd.DataFrame] = eval.bootstrap(test_pred, test_true, cxr_labels) # (df of results for each bootstrap, df of CI)
121 |
122 | # print results with confidence intervals
123 | print(bootstrap_results[1])
124 | ```
125 | The results are represented as a `pd.DataFrame` which can be saved as a `.csv`.
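Note that the snippet above assumes `import pandas as pd` and `from typing import Tuple` have been run. Persisting the tables is then a one-line continuation of that snippet; the output filenames below are illustrative:

```python
# Continues the snippet above: write the per-pathology results and the 95% CIs to disk.
cxr_results.to_csv("zero_shot_results.csv")
bootstrap_results[1].to_csv("zero_shot_ci.csv")
```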
126 |
127 | ### CheXpert Test Dataset
128 | In order to replicate the results in the paper, zero-shot inference and evaluation can be performed on the now publicly available CheXpert test dataset.
129 | 1) Download labels at [cheXpert-test-set-labels](https://github.com/rajpurkarlab/cheXpert-test-set-labels/blob/main/groundtruth.csv) and image files from [Stanford AIMI](https://stanfordaimi.azurewebsites.net/datasets/23c56a0d-15de-405b-87c8-99c30138950c), and save them in the `./data` directory of `CheXzero/`. The test dataset images should have the following directory structure:
130 | ```
131 | data/
132 | ├─ CheXpert/
133 | │ ├─ test/
134 | │ │ ├─ patient64741/
135 | │ │ │ ├─ study1/
136 | │ │ │ │ ├─ view1_frontal.jpg
137 | │ │ ├─ .../
138 | ```
139 |
140 | 2) Run `run_preprocess.py` script with the following arguments:
141 | ```bash
142 | python run_preprocess.py --dataset_type "chexpert-test" --cxr_out_path "./data/chexpert_test.h5" --chest_x_ray_path "./data/CheXpert/test/"
143 | ```
144 | This should save a `.h5` version of the test dataset images which can be used for evaluation.
145 |
146 | 3) Open sample zero-shot [notebook](https://github.com/rajpurkarlab/CheXzero/blob/main/notebooks/zero_shot.ipynb) and run all cells. If the directory structure is set up correctly, then all cells should run without errors.
147 |
148 | ## Issues
149 | Please open a new issue thread describing the problem with the codebase, or report issues directly to ekintiu@stanford.edu.
150 |
151 | ## Citation
152 | ```bash
153 | Tiu, E., Talius, E., Patel, P. et al. Expert-level detection of pathologies from unannotated chest X-ray images via self-supervised learning. Nat. Biomed. Eng (2022). https://doi.org/10.1038/s41551-022-00936-9
154 | ```
155 |
156 | ## License
157 | The source code for this repository is licensed under the MIT license, which you can find in the `LICENSE` file. Also see `NOTICE.md` for attributions to third-party sources.
158 |
--------------------------------------------------------------------------------
/cheff/sr/model.py:
--------------------------------------------------------------------------------
1 | """Classes and functions for neural networks."""
2 | import math
3 | from functools import partial
4 | from typing import Optional, Tuple
5 |
6 | import torch
7 | from einops import rearrange
8 | from torch import (
9 | Tensor,
10 | einsum,
11 | nn,
12 | )
13 |
14 |
15 | class Residual(nn.Module):
16 | """Wrapper for residual connection of a function."""
17 |
18 | def __init__(self, fn: nn.Module) -> None:
19 | """Initialize residual connection."""
20 | super().__init__()
21 | self.fn = fn
22 |
23 | def forward(self, x: Tensor, *args, **kwargs) -> Tensor:
24 | """Pass a tensor through the module."""
25 | return self.fn(x, *args, **kwargs) + x
26 |
27 |
28 | def get_upsample_conv(dim: int) -> nn.ConvTranspose2d:
29 | """Initialize transposed convolution layer."""
30 | return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
31 |
32 |
33 | def get_downsample_conv(dim: int) -> nn.Conv2d:
34 | """Initialize convolution layer."""
35 | return nn.Conv2d(dim, dim, 4, 2, 1)
36 |
37 |
38 | class SinusoidalPositionEmbeddings(nn.Module):
39 | """Class for sinusoidal embeddings."""
40 |
41 | def __init__(self, dim: int) -> None:
42 | """Initialize SinusoidalPositionEmbeddings."""
43 | super().__init__()
44 | self.dim = dim
45 |
46 |     def forward(self, time: Tensor) -> Tensor:
47 | """Pass a tensor through the module."""
48 | device = time.device
49 | half_dim = self.dim // 2
50 | embeddings = math.log(10000) / (half_dim - 1)
51 | embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
52 | embeddings = time[:, None] * embeddings[None, :]
53 | embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)
54 | return embeddings
55 |
56 |
57 | class Block(nn.Module):
58 | """Neural block with convolutions, norm and activations."""
59 |
60 | def __init__(self, dim: int, dim_out: int, groups: int = 8) -> None:
61 | """Initialize block."""
62 | super().__init__()
63 | self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
64 | self.norm = nn.GroupNorm(groups, dim_out)
65 | self.act = nn.SiLU()
66 |
67 | def forward(self, x: Tensor) -> Tensor:
68 | """Pass a tensor through the module."""
69 | x = self.proj(x)
70 | x = self.norm(x)
71 | x = self.act(x)
72 | return x
73 |
74 |
75 | class ResnetBlock(nn.Module):
76 | """Residual block."""
77 |
78 | def __init__(
79 | self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8
80 | ) -> None:
81 | """Initialize a residual block."""
82 | super().__init__()
83 | self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out))
84 |
85 | self.block1 = Block(dim, dim_out, groups=groups)
86 | self.block2 = Block(dim_out, dim_out, groups=groups)
87 |
88 | if dim != dim_out:
89 | self.res_conv = nn.Conv2d(dim, dim_out, 1)
90 | else:
91 | self.res_conv = nn.Identity()
92 |
93 | def forward(self, x: Tensor, t: Tensor) -> Tensor:
94 | """Pass a tensor through the module."""
95 | h = self.block1(x)
96 |
97 | time_emb = self.mlp(t)
98 | h += rearrange(time_emb, 'b c -> b c 1 1')
99 |
100 | h = self.block2(h)
101 |
102 | return h + self.res_conv(x)
103 |
104 |
105 | class Attention(nn.Module):
106 | """Attention module."""
107 |
108 | def __init__(self, dim: int, heads: int = 4, dim_head: int = 32) -> None:
109 | """Initialize Attention."""
110 | super().__init__()
111 | self.scale = dim_head**-0.5
112 | self.heads = heads
113 | hidden_dim = dim_head * heads
114 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
115 | self.to_out = nn.Conv2d(hidden_dim, dim, 1)
116 |
117 | def forward(self, x: Tensor) -> Tensor:
118 | """Pass a tensor through the module."""
119 | b, c, h, w = x.shape
120 | qkv = self.to_qkv(x).chunk(3, dim=1)
121 | q, k, v = map(
122 | lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv
123 | )
124 | q = q * self.scale
125 |
126 | sim = einsum('b h d i, b h d j -> b h i j', q, k)
127 | sim = sim - sim.amax(dim=-1, keepdim=True).detach()
128 | attn = sim.softmax(dim=-1)
129 |
130 | out = einsum('b h i j, b h d j -> b h i d', attn, v)
131 | out = rearrange(out, 'b h (x y) d -> b (h d) x y', x=h, y=w)
132 | return self.to_out(out)
133 |
134 |
135 | class LinearAttention(nn.Module):
136 | """Linear attention module."""
137 |
138 | def __init__(self, dim: int, heads: int = 4, dim_head: int = 32) -> None:
139 | """Initialize linear attention module."""
140 | super().__init__()
141 | self.scale = dim_head**-0.5
142 | self.heads = heads
143 | hidden_dim = dim_head * heads
144 | self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
145 |
146 | self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), nn.GroupNorm(1, dim))
147 |
148 | def forward(self, x: Tensor) -> Tensor:
149 | """Pass a tensor through the module."""
150 | b, c, h, w = x.shape
151 | qkv = self.to_qkv(x).chunk(3, dim=1)
152 | q, k, v = map(
153 | lambda t: rearrange(t, 'b (h c) x y -> b h c (x y)', h=self.heads), qkv
154 | )
155 |
156 | q = q.softmax(dim=-2)
157 | k = k.softmax(dim=-1)
158 |
159 | q = q * self.scale
160 | context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
161 |
162 | out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
163 | out = rearrange(out, 'b h c (x y) -> b (h c) x y', h=self.heads, x=h, y=w)
164 | return self.to_out(out)
165 |
166 |
167 | class PreNorm(nn.Module):
168 | """PreNorm Module."""
169 |
170 | def __init__(self, dim: int, fn: nn.Module) -> None:
171 | """Initialize PreNorm."""
172 | super().__init__()
173 | self.fn = fn
174 | self.norm = nn.GroupNorm(1, dim)
175 |
176 | def forward(self, x: Tensor) -> Tensor:
177 | """Pass a tensor through the module."""
178 | x = self.norm(x)
179 | return self.fn(x)
180 |
181 |
182 | class Unet(nn.Module):
183 | """U-net architecture."""
184 |
185 | def __init__(
186 | self,
187 | dim: int,
188 | init_dim: Optional[int] = None,
189 | out_dim: Optional[int] = None,
190 | dim_mults: Tuple = (1, 2, 4, 8),
191 | num_attention_layer: int = 5,
192 | channels: int = 3,
193 | block_groups: int = 8,
194 | ) -> None:
195 | """Initialize U-net."""
196 | super().__init__()
197 |
198 | self.channels = channels
199 |
200 | init_dim = init_dim if init_dim is not None else dim // 3 * 2
201 | self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
202 |
203 | dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
204 | in_out = list(zip(dims[:-1], dims[1:]))
205 |
206 | block_klass = partial(ResnetBlock, groups=block_groups)
207 |
208 | time_dim = dim * 4
209 | self.time_mlp = nn.Sequential(
210 | SinusoidalPositionEmbeddings(dim),
211 | nn.Linear(dim, time_dim),
212 | nn.GELU(),
213 | nn.Linear(time_dim, time_dim),
214 | )
215 |
216 | # layers
217 | self.downs = nn.ModuleList([])
218 | self.ups = nn.ModuleList([])
219 | num_resolutions = len(in_out)
220 |
221 | for ind, (dim_in, dim_out) in enumerate(in_out):
222 | is_last = ind >= (num_resolutions - 1)
223 | has_att = ind >= (num_resolutions - num_attention_layer + 1)
224 |
225 | self.downs.append(
226 | nn.ModuleList(
227 | [
228 | block_klass(dim_in, dim_out, time_emb_dim=time_dim),
229 | block_klass(dim_out, dim_out, time_emb_dim=time_dim),
230 | Residual(PreNorm(dim_out, LinearAttention(dim_out)))
231 | if has_att
232 | else nn.Identity(),
233 | get_downsample_conv(dim_out) if not is_last else nn.Identity(),
234 | ]
235 | )
236 | )
237 |
238 | mid_dim = dims[-1]
239 | self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
240 | if num_attention_layer >= 1:
241 | self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
242 | else:
243 | self.mid_attn = nn.Identity()
244 | self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
245 |
246 | for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
247 | is_last = ind >= (num_resolutions - 1)
248 | has_att = ind >= (num_resolutions - num_attention_layer + 1)
249 |
250 | self.ups.append(
251 | nn.ModuleList(
252 | [
253 | block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
254 | block_klass(dim_in, dim_in, time_emb_dim=time_dim),
255 | Residual(PreNorm(dim_in, LinearAttention(dim_in)))
256 | if has_att
257 | else nn.Identity(),
258 | get_upsample_conv(dim_in) if not is_last else nn.Identity(),
259 | ]
260 | )
261 | )
262 |
263 | out_dim = out_dim if out_dim is not None else channels
264 | self.final_conv = nn.Sequential(
265 | Residual(
266 | nn.Sequential(
267 | Block(dim, dim, groups=block_groups),
268 | Block(dim, dim, groups=block_groups),
269 | )
270 | ),
271 | nn.Conv2d(dim, out_dim, 1),
272 | )
273 |
274 |     def forward(self, x: Tensor, t: Optional[Tensor] = None) -> Tensor:
275 | """Pass a tensor through the module."""
276 | x = self.init_conv(x)
277 | t = self.time_mlp(t)
278 |
279 | h = []
280 |
281 | # Downsampling
282 | for block1, block2, attn, downsample in self.downs:
283 | x = block1(x, t)
284 | x = block2(x, t)
285 | x = attn(x)
286 | h.append(x)
287 | x = downsample(x)
288 |
289 | # Bottleneck
290 | x = self.mid_block1(x, t)
291 | x = self.mid_attn(x)
292 | x = self.mid_block2(x, t)
293 |
294 | # Upsampling
295 | for block1, block2, attn, upsample in self.ups:
296 | x = torch.cat((x, h.pop()), dim=1)
297 | x = block1(x, t)
298 | x = block2(x, t)
299 | x = attn(x)
300 | x = upsample(x)
301 |
302 | return self.final_conv(x)
303 |
--------------------------------------------------------------------------------
/chexzero/components/bbox_losses.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | from math import prod
4 | from typing import Tuple
5 | import einops
6 | from torch import Tensor
7 | import torch
8 | import torch.nn.functional as F
9 | from scipy.optimize import linear_sum_assignment
10 |
11 |
12 | def bbox_l1_ccost(boxes1, boxes2):
13 | """
14 | :param boxes1: (... x R_s x 4) in (xc, yc, w, h) format
15 | :param boxes2: (... x R_t x 4) in (xc, yc, w, h) format
16 | :return: (... x R_s x R_t)
17 | """
18 | assert boxes1.ndim == boxes2.ndim
19 | assert boxes1.ndim >= 3
20 | assert boxes1.shape[-1] >= 4 and boxes2.shape[-1] >= 4
21 | *batch_dims1, R_s, _ = boxes1.shape
22 | *batch_dims2, R_t, _ = boxes2.shape
23 | dims = prod(batch_dims1)
24 | assert batch_dims1 == batch_dims2
25 | # (... x R_s x 4)
26 | boxes1 = boxes1[..., :4].reshape(dims, R_s, 4)
27 | # (... x R_t x 4)
28 | boxes2 = boxes2[..., :4].reshape(dims, R_t, 4).to(dtype=boxes1.dtype)
29 | # (... x R_s x R_t)
30 | l1_cost = torch.cdist(boxes1, boxes2, p=1)
31 | return l1_cost.view(*batch_dims1, R_s, R_t)
32 |
33 |
34 | def bbox_l1_pcost(boxes1, boxes2):
35 | """
36 | :param boxes1: (... x 4) in (xc, yc, w, h) format
37 | :param boxes2: (... x 4) in (xc, yc, w, h) format
38 | :return: (...)
39 | """
40 | boxes1 = boxes1[..., :4]
41 | boxes2 = boxes2[..., :4]
42 | assert boxes1.shape == boxes2.shape
43 | return F.l1_loss(boxes1, boxes2, reduction='none').sum(dim=-1)
44 |
45 |
46 | def bbox_giou_ccost(boxes1, boxes2):
47 | """
48 | :param boxes1: (... x R_s x 4) in (xc, yc, w, h) format
49 | :param boxes2: (... x R_t x 4) in (xc, yc, w, h) format
50 | :return: (... x R_s x R_t)
51 | """
52 | assert boxes1.ndim == boxes2.ndim
53 | assert boxes1.ndim >= 3
54 | assert boxes1.shape[-1] >= 4 and boxes2.shape[-1] >= 4
55 | *batch_dims1, R_s, _ = boxes1.shape
56 | *batch_dims2, R_t, _ = boxes2.shape
57 | assert batch_dims1 == batch_dims2
58 | dims = prod(batch_dims1)
59 | # (... x R_s x 4)
60 | boxes1 = center_to_corners_format(boxes1[..., :4].reshape(dims, R_s, 4))
61 | boxes2 = center_to_corners_format(boxes2[..., :4].reshape(dims, R_t, 4).to(dtype=boxes1.dtype))
62 | # (... x R_s x R_t)
63 | giou_cost = 1. - box_giou(boxes1, boxes2)
64 |
65 | return giou_cost.view(*batch_dims1, R_s, R_t)
66 |
67 | def bbox_giou_pcost(boxes1, boxes2):
68 | """
69 | :param boxes1: (... x 4) in (xc, yc, w, h) format
70 | :param boxes2: (... x 4) in (xc, yc, w, h) format
71 | :return: (...)
72 | """
73 | boxes1 = center_to_corners_format(boxes1[..., :4])
74 | boxes2 = center_to_corners_format(boxes2[..., :4])
75 | assert boxes1.shape == boxes2.shape
76 | *batch_dims, _ = boxes1.shape
77 | boxes1 = boxes1.view(-1, 1, 1, 4)
78 | boxes2 = boxes2.view(-1, 1, 1, 4)
79 | # (..., 1, 1)
80 | giou_cost = 1. - box_giou(boxes1, boxes2)
81 | return giou_cost.view(*batch_dims)
82 |
83 | def box_giou(boxes1, boxes2):
84 | """
85 | https://giou.stanford.edu/
86 | :param boxes1: (... x N x M x 4) in (x0, y0, x1, y1) format
87 | :param boxes2: (... x N x M x 4) in (x0, y0, x1, y1) format
88 | :return: (... x N x M) giou
89 | """
90 | # degenerate boxes gives inf / nan results
91 | # so do an early check
92 | if not (boxes1[..., 2:] >= boxes1[..., :2]).all():
93 | raise ValueError(f"boxes1 must be in [x0, y0, x1, y1] (corner) format, but got {boxes1}")
94 | if not (boxes2[..., 2:] >= boxes2[..., :2]).all():
95 | raise ValueError(f"boxes2 must be in [x0, y0, x1, y1] (corner) format, but got {boxes2}")
96 |
97 | # (... x N x M)
98 | iou, union = box_iou(boxes1, boxes2)
99 |
100 | top_left = torch.min(boxes1[..., :, None, :2], boxes2[..., None, :, :2])
101 | bottom_right = torch.max(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:])
102 |
103 | width_height = (bottom_right - top_left).clamp(min=0) # [...,N,M,2]
104 | area = width_height[..., :, :, 0] * width_height[..., :, :, 1]
105 | return iou - (area - union) / area.clamp(min=1e-7)
106 |
107 | def box_iou(boxes1, boxes2):
108 | """
109 | :param boxes1: (... x N x 4) in (x0, y0, x1, y1) format
110 | :param boxes2: (... x M x 4) in (x0, y0, x1, y1) format
111 | :return: (... x N x M) iou and union
112 | """
113 | area1 = box_area(boxes1)
114 | area2 = box_area(boxes2)
115 |
116 | left_top = torch.max(boxes1[..., :, None, :2], boxes2[..., None, :, :2]) # [...,N,M,2]
117 | right_bottom = torch.min(boxes1[..., :, None, 2:], boxes2[..., None, :, 2:]) # [...,N,M,2]
118 |
119 | width_height = (right_bottom - left_top).clamp(min=0) # [...,N,M,2]
120 | inter = width_height[..., :, :, 0] * width_height[..., :, :, 1] # [...,N,M]
121 | # [...,N,M]
122 | union = area1[..., :, None] + area2[..., None, :] - inter
123 |
124 | iou = inter / union.clamp(min=1e-7)
125 | return iou, union
126 |
127 |
128 | def box_area(boxes: Tensor) -> Tensor:
129 | """
130 | :param boxes: (..., 4) in (x0, y0, x1, y1) format
131 | :return: (...)
132 | """
133 | boxes = _upcast(boxes)
134 | return (boxes[..., 2] - boxes[..., 0]) * (boxes[..., 3] - boxes[..., 1])
135 |
136 |
137 | @torch.no_grad()
138 | def match_multiregions(cost, mask, non_matched_region_mode: str, greedy_match=False):
139 | """
140 | :param cost: (N x C x R_s x R_t)
141 | :param mask: (N x C x R_t)
142 | """
143 | N, C, R_s, R_t = cost.shape
144 | cost = cost.flatten(0, 1) # (N*C x R_s x R_t)
145 | mask = mask.flatten(0, 1) # (N*C x R_t)
146 | # List (N*C) of (R_s, R_t') tensors where R_t' is the number of target boxes for the class and sample
147 | matches_list, assign_mask_list = zip(*[compute_matching(c, m, non_matched_region_mode, greedy_match=greedy_match) for c, m in zip(cost, mask)])
148 |
149 | # (N x C x R_s)
150 | matches = einops.rearrange(torch.stack(matches_list), '(n c) r_s -> n c r_s', n=N, c=C)
151 | # (N x C x R_s x R_t)
152 | assign_mask = einops.rearrange(torch.stack(assign_mask_list), '(n c) r_s r_t -> n c r_s r_t', n=N, c=C)
153 |
154 | return matches, assign_mask
155 |
156 |
157 | def compute_matching(cost: torch.FloatTensor, mask: torch.BoolTensor, non_matched_region_mode: str, greedy_match=False):
158 | device = cost.device
159 | R_t_all = mask.shape[-1]
160 | cost = cost[:, mask]
161 | if not greedy_match:
162 | cost = cost.cpu()
163 |
164 | R_s, R_t = cost.shape
165 | if R_t == 0:
166 | return torch.zeros((R_s,), dtype=torch.bool, device=device), torch.zeros((R_s, R_t_all), dtype=torch.bool, device=device)
167 |
168 | if non_matched_region_mode == 'balanced_match':
169 | cost = extend_cost_for_balanced_match(cost)
170 | else:
171 | assert non_matched_region_mode == 'ignore', f'Unknown non_matched_region_mode {non_matched_region_mode}'
172 |
173 | if greedy_match:
174 | indices_s, indices_t = _do_matching_greedy(cost)
175 | else:
176 | indices_s, indices_t = _do_matching(cost)
177 | # (N_match)
178 | matches = indices_s[indices_t < R_t]
179 | # (N_match)
180 | indices_t = indices_t % R_t # for balanced match
181 |
182 | # (R_s)
183 | matches: torch.BoolTensor = torch.zeros((R_s,), dtype=torch.bool, device=device if greedy_match else 'cpu').scatter_(0, matches, True)
184 | # (R_s x R_t)
185 | # based on both, indices_s and indices_t
186 | assign_mask = torch.zeros((R_s, R_t), dtype=torch.bool, device=device if greedy_match else 'cpu')
187 | assign_mask[indices_s, indices_t] = True
188 |
189 | if not greedy_match:
190 | matches = matches.to(device, non_blocking=True)
191 | assign_mask = assign_mask.to(device, non_blocking=True)
192 |
193 | assign_mask_all = torch.zeros((R_s, R_t_all), dtype=torch.bool, device=device)
194 | assign_mask_all[:, mask] = assign_mask
195 |
196 | return matches, assign_mask_all
197 |
198 |
199 | def _do_matching(cost: torch.FloatTensor) -> Tuple[torch.LongTensor, torch.LongTensor]:
200 | indices_s, indices_t = linear_sum_assignment(cost.numpy())
201 | indices_s = torch.from_numpy(indices_s)
202 | indices_t = torch.from_numpy(indices_t)
203 | return indices_s, indices_t
204 |
205 |
206 | def _do_matching_greedy(cost: torch.FloatTensor) -> Tuple[torch.LongTensor, torch.LongTensor]:
207 | """
208 |     :param cost: (R_s x R_t) cost matrix
209 |     :return: (N_match), (N_match) matched source and target indices
210 | """
211 | indices_s, indices_t = [], []
212 | R_s, R_t = cost.shape
213 | cost = cost.clone()
214 |
215 | for _ in range(min(R_s, R_t)):
216 | mins_R_t, argmins_R_t = cost.min(dim=1)
217 | s = mins_R_t.argmin()
218 | t = argmins_R_t[s]
219 | indices_s.append(s)
220 | indices_t.append(t)
221 | cost[s, :] = float('inf')
222 | cost[:, t] = float('inf')
223 |
224 | indices_s = torch.stack(indices_s)
225 | indices_t = torch.stack(indices_t)
226 | return indices_s, indices_t
227 |
228 |
229 | def extend_cost_for_balanced_match(cost):
230 | # cost: (R_s x R_t')
231 | R_s, R_t = cost.shape
232 | max_cost_diff = cost.max() - cost.min()
233 |
234 | if R_s > R_t:
235 | n_copies = R_s // R_t + (1 if R_s % R_t != 0 else 0)
236 | R_t_total = R_t * n_copies
237 | assert R_t_total >= R_s
238 |
239 | # (n_copies - 1)
240 | extra_costs = torch.full((n_copies - 1,), max_cost_diff + 1.,
241 | dtype=torch.float, device=cost.device).cumsum(0)
242 | # (R_t_total)
243 | extra_costs = torch.cat([
244 | torch.zeros((R_t,), dtype=torch.float, device=cost.device),
245 | extra_costs.repeat_interleave(R_t)
246 | ])
247 | # (R_s x R_t_total)
248 | cost = cost.repeat(1, n_copies) + extra_costs[None, :]
249 |
250 | return cost
251 |
252 |
253 | # ----------------- Conversion utils -----------------
254 | def center_to_corners_format(x):
255 | """
256 | Converts a PyTorch tensor of bounding boxes of center format (center_x, center_y, width, height) to corners format
257 | (x_0, y_0, x_1, y_1).
258 | """
259 | x_c, y_c, w, h, *args = x.unbind(-1)
260 | b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h), *args]
261 | return torch.stack(b, dim=-1)
262 |
263 | def corners_to_center_format(x):
264 | """
265 | Converts a PyTorch tensor of bounding boxes of corners format (x_0, y_0, x_1, y_1) to center format
266 | (center_x, center_y, width, height).
267 | """
268 | x_0, y_0, x_1, y_1, *args = x.unbind(-1)
269 | b = [(x_0 + x_1) / 2, (y_0 + y_1) / 2, (x_1 - x_0), (y_1 - y_0), *args]
270 | return torch.stack(b, dim=-1)
271 |
272 | def _upcast(t: Tensor) -> Tensor:
273 | # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type
274 | if t.is_floating_point():
275 | return t if t.dtype in (torch.float32, torch.float64) else t.float()
276 | else:
277 | return t if t.dtype in (torch.int32, torch.int64) else t.int()
278 |
--------------------------------------------------------------------------------
/cheff/ldm/models/diffusion/classifier.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import pytorch_lightning as pl
4 | from omegaconf import OmegaConf
5 | from torch.nn import functional as F
6 | from torch.optim import AdamW
7 | from torch.optim.lr_scheduler import LambdaLR
8 | from copy import deepcopy
9 | from einops import rearrange
10 | from glob import glob
11 | from natsort import natsorted
12 |
13 | from cheff.ldm.modules.diffusionmodules.openaimodel import EncoderUNetModel, UNetModel
14 | from cheff.ldm.util import log_txt_as_img, default, ismap, instantiate_from_config
15 |
16 | __models__ = {
17 | 'class_label': EncoderUNetModel,
18 | 'segmentation': UNetModel
19 | }
20 |
21 |
22 | def disabled_train(self, mode=True):
23 | """Overwrite model.train with this function to make sure train/eval mode
24 | does not change anymore."""
25 | return self
26 |
27 |
28 | class NoisyLatentImageClassifier(pl.LightningModule):
29 |
30 | def __init__(self,
31 | diffusion_path,
32 | num_classes,
33 | ckpt_path=None,
34 | pool='attention',
35 | label_key=None,
36 | diffusion_ckpt_path=None,
37 | scheduler_config=None,
38 | weight_decay=1.e-2,
39 | log_steps=10,
40 | monitor='val/loss',
41 | *args,
42 | **kwargs):
43 | super().__init__(*args, **kwargs)
44 | self.num_classes = num_classes
45 | # get latest config of sr model
46 | diffusion_config = natsorted(glob(os.path.join(diffusion_path, 'configs', '*-project.yaml')))[-1]
47 | self.diffusion_config = OmegaConf.load(diffusion_config).model
48 | self.diffusion_config.params.ckpt_path = diffusion_ckpt_path
49 | self.load_diffusion()
50 |
51 | self.monitor = monitor
52 | self.numd = self.diffusion_model.first_stage_model.encoder.num_resolutions - 1
53 | self.log_time_interval = self.diffusion_model.num_timesteps // log_steps
54 | self.log_steps = log_steps
55 |
56 | self.label_key = label_key if not hasattr(self.diffusion_model, 'cond_stage_key') \
57 | else self.diffusion_model.cond_stage_key
58 |
59 | assert self.label_key is not None, 'label_key neither in sr model nor in model.params'
60 |
61 | if self.label_key not in __models__:
62 | raise NotImplementedError()
63 |
64 | self.load_classifier(ckpt_path, pool)
65 |
66 | self.scheduler_config = scheduler_config
67 | self.use_scheduler = self.scheduler_config is not None
68 | self.weight_decay = weight_decay
69 |
70 | def init_from_ckpt(self, path, ignore_keys=list(), only_model=False):
71 | sd = torch.load(path, map_location="cpu")
72 | if "state_dict" in list(sd.keys()):
73 | sd = sd["state_dict"]
74 | keys = list(sd.keys())
75 | for k in keys:
76 | for ik in ignore_keys:
77 | if k.startswith(ik):
78 | print("Deleting key {} from state_dict.".format(k))
79 | del sd[k]
80 | missing, unexpected = self.load_state_dict(sd, strict=False) if not only_model else self.model.load_state_dict(
81 | sd, strict=False)
82 | print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys")
83 | if len(missing) > 0:
84 | print(f"Missing Keys: {missing}")
85 | if len(unexpected) > 0:
86 | print(f"Unexpected Keys: {unexpected}")
87 |
88 | def load_diffusion(self):
89 | model = instantiate_from_config(self.diffusion_config)
90 | self.diffusion_model = model.eval()
91 | self.diffusion_model.train = disabled_train
92 | for param in self.diffusion_model.parameters():
93 | param.requires_grad = False
94 |
95 | def load_classifier(self, ckpt_path, pool):
96 | model_config = deepcopy(self.diffusion_config.params.unet_config.params)
97 | model_config.in_channels = self.diffusion_config.params.unet_config.params.out_channels
98 | model_config.out_channels = self.num_classes
99 | if self.label_key == 'class_label':
100 | model_config.pool = pool
101 |
102 | self.model = __models__[self.label_key](**model_config)
103 | if ckpt_path is not None:
104 | print('#####################################################################')
105 | print(f'load from ckpt "{ckpt_path}"')
106 | print('#####################################################################')
107 | self.init_from_ckpt(ckpt_path)
108 |
109 | @torch.no_grad()
110 | def get_x_noisy(self, x, t, noise=None):
111 | noise = default(noise, lambda: torch.randn_like(x))
112 | continuous_sqrt_alpha_cumprod = None
113 | if self.diffusion_model.use_continuous_noise:
114 | continuous_sqrt_alpha_cumprod = self.diffusion_model.sample_continuous_noise_level(x.shape[0], t + 1)
115 | # todo: make sure t+1 is correct here
116 |
117 | return self.diffusion_model.q_sample(x_start=x, t=t, noise=noise,
118 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod)
119 |
120 | def forward(self, x_noisy, t, *args, **kwargs):
121 | return self.model(x_noisy, t)
122 |
123 | @torch.no_grad()
124 | def get_input(self, batch, k):
125 | x = batch[k]
126 | if len(x.shape) == 3:
127 | x = x[..., None]
128 | x = rearrange(x, 'b h w c -> b c h w')
129 | x = x.to(memory_format=torch.contiguous_format).float()
130 | return x
131 |
132 | @torch.no_grad()
133 | def get_conditioning(self, batch, k=None):
134 | if k is None:
135 | k = self.label_key
136 | assert k is not None, 'Needs to provide label key'
137 |
138 | targets = batch[k].to(self.device)
139 |
140 | if self.label_key == 'segmentation':
141 | targets = rearrange(targets, 'b h w c -> b c h w')
142 | for down in range(self.numd):
143 | h, w = targets.shape[-2:]
144 | targets = F.interpolate(targets, size=(h // 2, w // 2), mode='nearest')
145 |
146 | # targets = rearrange(targets,'b c h w -> b h w c')
147 |
148 | return targets
149 |
150 | def compute_top_k(self, logits, labels, k, reduction="mean"):
151 | _, top_ks = torch.topk(logits, k, dim=1)
152 | if reduction == "mean":
153 | return (top_ks == labels[:, None]).float().sum(dim=-1).mean().item()
154 | elif reduction == "none":
155 | return (top_ks == labels[:, None]).float().sum(dim=-1)
156 |
157 | def on_train_epoch_start(self):
158 | # save some memory
159 | self.diffusion_model.model.to('cpu')
160 |
161 | @torch.no_grad()
162 | def write_logs(self, loss, logits, targets):
163 | log_prefix = 'train' if self.training else 'val'
164 | log = {}
165 | log[f"{log_prefix}/loss"] = loss.mean()
166 | log[f"{log_prefix}/acc@1"] = self.compute_top_k(
167 | logits, targets, k=1, reduction="mean"
168 | )
169 | log[f"{log_prefix}/acc@5"] = self.compute_top_k(
170 | logits, targets, k=5, reduction="mean"
171 | )
172 |
173 | self.log_dict(log, prog_bar=False, logger=True, on_step=self.training, on_epoch=True)
174 | self.log('loss', log[f"{log_prefix}/loss"], prog_bar=True, logger=False)
175 | self.log('global_step', self.global_step, logger=False, on_epoch=False, prog_bar=True)
176 | lr = self.optimizers().param_groups[0]['lr']
177 | self.log('lr_abs', lr, on_step=True, logger=True, on_epoch=False, prog_bar=True)
178 |
179 | def shared_step(self, batch, t=None):
180 | x, *_ = self.diffusion_model.get_input(batch, k=self.diffusion_model.first_stage_key)
181 | targets = self.get_conditioning(batch)
182 | if targets.dim() == 4:
183 | targets = targets.argmax(dim=1)
184 | if t is None:
185 | t = torch.randint(0, self.diffusion_model.num_timesteps, (x.shape[0],), device=self.device).long()
186 | else:
187 | t = torch.full(size=(x.shape[0],), fill_value=t, device=self.device).long()
188 | x_noisy = self.get_x_noisy(x, t)
189 | logits = self(x_noisy, t)
190 |
191 | loss = F.cross_entropy(logits, targets, reduction='none')
192 |
193 | self.write_logs(loss.detach(), logits.detach(), targets.detach())
194 |
195 | loss = loss.mean()
196 | return loss, logits, x_noisy, targets
197 |
198 | def training_step(self, batch, batch_idx):
199 | loss, *_ = self.shared_step(batch)
200 | return loss
201 |
202 | def reset_noise_accs(self):
203 | self.noisy_acc = {t: {'acc@1': [], 'acc@5': []} for t in
204 | range(0, self.diffusion_model.num_timesteps, self.diffusion_model.log_every_t)}
205 |
206 | def on_validation_start(self):
207 | self.reset_noise_accs()
208 |
209 | @torch.no_grad()
210 | def validation_step(self, batch, batch_idx):
211 | loss, *_ = self.shared_step(batch)
212 |
213 | for t in self.noisy_acc:
214 | _, logits, _, targets = self.shared_step(batch, t)
215 | self.noisy_acc[t]['acc@1'].append(self.compute_top_k(logits, targets, k=1, reduction='mean'))
216 | self.noisy_acc[t]['acc@5'].append(self.compute_top_k(logits, targets, k=5, reduction='mean'))
217 |
218 | return loss
219 |
220 | def configure_optimizers(self):
221 | optimizer = AdamW(self.model.parameters(), lr=self.learning_rate, weight_decay=self.weight_decay)
222 |
223 | if self.use_scheduler:
224 | scheduler = instantiate_from_config(self.scheduler_config)
225 |
226 | print("Setting up LambdaLR scheduler...")
227 | scheduler = [
228 | {
229 | 'scheduler': LambdaLR(optimizer, lr_lambda=scheduler.schedule),
230 | 'interval': 'step',
231 | 'frequency': 1
232 | }]
233 | return [optimizer], scheduler
234 |
235 | return optimizer
236 |
237 | @torch.no_grad()
238 | def log_images(self, batch, N=8, *args, **kwargs):
239 | log = dict()
240 | x = self.get_input(batch, self.diffusion_model.first_stage_key)
241 | log['inputs'] = x
242 |
243 | y = self.get_conditioning(batch)
244 |
245 | if self.label_key == 'class_label':
246 | y = log_txt_as_img((x.shape[2], x.shape[3]), batch["human_label"])
247 | log['labels'] = y
248 |
249 | if ismap(y):
250 | log['labels'] = self.diffusion_model.to_rgb(y)
251 |
252 | for step in range(self.log_steps):
253 | current_time = step * self.log_time_interval
254 |
255 | _, logits, x_noisy, _ = self.shared_step(batch, t=current_time)
256 |
257 | log[f'inputs@t{current_time}'] = x_noisy
258 |
259 | pred = F.one_hot(logits.argmax(dim=1), num_classes=self.num_classes)
260 | pred = rearrange(pred, 'b h w c -> b c h w')
261 |
262 | log[f'pred@t{current_time}'] = self.diffusion_model.to_rgb(pred)
263 |
264 | for key in log:
265 | log[key] = log[key][:N]
266 |
267 | return log
268 |
--------------------------------------------------------------------------------