├── src ├── __init__.py ├── loss.py ├── model.py ├── config.yaml ├── util.py ├── datamodule.py ├── dataset.py ├── pl_module.py ├── timm_3d_decoder3.py └── metric.py ├── misc └── model.png ├── requirements.txt ├── docker-compose.yml ├── LICENSE ├── 01_create_dataset.py ├── 03_create_solution.py ├── docker └── Dockerfile ├── .gitignore ├── 11_train.py ├── README.md └── 02_create_mask.py /src/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /misc/model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yu4u/kaggle-czii-4th/HEAD/misc/model.png -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | numpy 3 | pandas 4 | opencv-python 5 | scikit-learn 6 | pytorch-lightning 7 | omegaconf 8 | wandb 9 | jupyter 10 | matplotlib 11 | pydantic 12 | timm 13 | albumentations 14 | pyarrow 15 | einops 16 | joblib 17 | h5py 18 | ttach 19 | openmim 20 | torchmetrics 21 | segmentation-models-pytorch 22 | lightgbm 23 | xgboost 24 | optuna 25 | zarr 26 | ome-zarr 27 | monai 28 | -------------------------------------------------------------------------------- /docker-compose.yml: -------------------------------------------------------------------------------- 1 | version: "3.8" 2 | services: 3 | dev: 4 | shm_size: "128gb" 5 | build: 6 | context: . 7 | dockerfile: ./docker/Dockerfile 8 | args: 9 | UID: ${UID} 10 | tty: true 11 | volumes: 12 | - .:/work 13 | working_dir: /work 14 | deploy: 15 | resources: 16 | reservations: 17 | devices: 18 | - capabilities: [ gpu ] 19 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2025 Yusuke Uchida 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /01_create_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import numpy as np 4 | import zarr 5 | 6 | 7 | def get_args(): 8 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 9 | parser.add_argument("--mode", type=str, default="train") # train or test 10 | args = parser.parse_args() 11 | return args 12 | 13 | 14 | def main(): 15 | args = get_args() 16 | mode = args.mode # train or test 17 | root_dir = Path(__file__).parent.joinpath("input") 18 | output_root = Path(__file__).parent.joinpath("output") 19 | output_dir = output_root.joinpath(f"{mode}_imgs") 20 | output_dir.mkdir(exist_ok=True, parents=True) 21 | 22 | for exp_dir in root_dir.joinpath(mode, "static", "ExperimentRuns").iterdir(): 23 | if not exp_dir.is_dir(): 24 | continue 25 | 26 | print(exp_dir) 27 | zarr_path = exp_dir.joinpath("VoxelSpacing10.000", "denoised.zarr") 28 | zarr_file = zarr.open(str(zarr_path)) 29 | tomogram = zarr_file["0"][:] 30 | output_path = output_dir.joinpath(f"{exp_dir.stem}.npy") 31 | np.save(str(output_path), tomogram) 32 | 33 | 34 | if __name__ == '__main__': 35 | main() 36 | -------------------------------------------------------------------------------- /src/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from segmentation_models_pytorch.losses import FocalLoss, DiceLoss 5 | 6 | def get_loss(cfg): 7 | if cfg.model.arch == "timm3d6": 8 | return MyLossForDeepSupervision(cfg) 9 | else: 10 | return MyLoss(cfg) 11 | 12 | 13 | class MyLoss(nn.Module): 14 | def __init__(self, cfg): 15 | super().__init__() 16 | self.cfg = cfg 17 | 18 | if cfg.loss.name == "mse": 19 | self.loss = nn.MSELoss(reduction="none") 20 | elif cfg.loss.name == "focal": 21 | self.loss = FocalLoss(mode="multilabel") 22 | elif cfg.loss.name == "dice": 23 | self.loss = DiceLoss(mode="multilabel") 24 | else: 25 | raise NotImplementedError(f"loss {cfg.loss.name} not implemented") 26 | 27 | def forward(self, y_pred, y_true): 28 | return_dict = dict() 29 | loss = self.loss(y_pred, y_true) 30 | neg_weight = self.cfg.loss.neg_weight 31 | loss = (loss * (y_true + neg_weight)).mean() 32 | return_dict["loss"] = loss 33 | return return_dict 34 | 35 | 36 | def main(): 37 | pass 38 | 39 | 40 | if __name__ == '__main__': 41 | main() 42 | -------------------------------------------------------------------------------- /03_create_solution.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import json 3 | import pandas as pd 4 | 5 | from src.metric import particle_types 6 | 7 | 8 | def main(): 9 | root_dir = Path(__file__).parent.joinpath("input") 10 | output_dir = Path(__file__).parent.joinpath("output") 11 | output_dir.mkdir(exist_ok=True, parents=True) 12 | 13 | json_names = [particle_type + ".json" for particle_type in particle_types] 14 | rows = [] 15 | 16 | for exp_dir in root_dir.joinpath("train", "overlay", "ExperimentRuns").iterdir(): 17 | if not exp_dir.is_dir(): 18 | continue 19 | 20 | json_dir = exp_dir.joinpath("Picks") 21 | experiment = exp_dir.stem 22 | 23 | for i, json_name in enumerate(json_names): 24 | json_path = json_dir.joinpath(json_name) 25 | particle_type = json_name.split(".")[0] 26 | 27 | with open(json_path) as f: 28 | 
picks = json.load(f) 29 | 30 | for point in picks["points"]: 31 | x = point["location"]["x"] 32 | y = point["location"]["y"] 33 | z = point["location"]["z"] 34 | rows.append([experiment, particle_type, x, y, z]) 35 | 36 | df = pd.DataFrame(rows, columns=["experiment", "particle_type", "x", "y", "z"]) 37 | df.index.name = "id" 38 | df.to_csv(output_dir.joinpath("solution.csv")) 39 | 40 | 41 | if __name__ == '__main__': 42 | main() 43 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04 2 | # https://hub.docker.com/r/nvidia/cuda/tags 3 | 4 | ENV DEBIAN_FRONTEND=noninteractive 5 | 6 | RUN apt-get update -y && \ 7 | apt-get install -y --no-install-recommends \ 8 | tzdata \ 9 | ca-certificates \ 10 | sudo \ 11 | git \ 12 | vim \ 13 | # for pyenv \ 14 | build-essential libssl-dev zlib1g-dev \ 15 | libbz2-dev libreadline-dev libsqlite3-dev curl \ 16 | libncursesw5-dev xz-utils tk-dev libxml2-dev libxmlsec1-dev libffi-dev liblzma-dev \ 17 | # for opencv \ 18 | libgl1-mesa-dev \ 19 | && rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/* 20 | 21 | ENV TZ Asia/Tokyo 22 | 23 | ARG UID 24 | RUN useradd docker -l -u ${UID} -G sudo -s /bin/bash -m 25 | RUN echo 'Defaults visiblepw' >> /etc/sudoers 26 | RUN echo 'docker ALL=(ALL) NOPASSWD:ALL' >> /etc/sudoers 27 | 28 | USER docker 29 | 30 | ARG PYTHON_VERSION=3.11.6 31 | ENV PYENV_ROOT /home/docker/.pyenv 32 | ENV PATH $PYENV_ROOT/bin:$PYENV_ROOT/shims:$PATH 33 | 34 | RUN curl -L https://github.com/pyenv/pyenv-installer/raw/master/bin/pyenv-installer | bash && \ 35 | pyenv install ${PYTHON_VERSION} && \ 36 | pyenv global ${PYTHON_VERSION} && \ 37 | pip install --upgrade pip 38 | 39 | COPY requirements.txt /tmp 40 | RUN pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121 41 | RUN pip install -r /tmp/requirements.txt 42 | RUN pip install -U timm 43 | -------------------------------------------------------------------------------- /src/model.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import torch 4 | import torch.nn as nn 5 | 6 | from src.timm_3d_decoder3 import Timm3DDecoder3 7 | 8 | 9 | def get_model_from_cfg(cfg, resume_path=None): 10 | if cfg.model.arch == "timm3d3": 11 | model = Timm3DDecoder3(cfg) 12 | else: 13 | raise NotImplementedError 14 | 15 | if resume_path: 16 | print(f"loading model from {str(resume_path)}") 17 | checkpoint = torch.load(str(resume_path), map_location="cpu") 18 | 19 | if np.any([k.startswith("model_ema.") for k in checkpoint["state_dict"].keys()]): 20 | print(f"loading from model_ema") 21 | state_dict = {k[17:]: v for k, v in checkpoint["state_dict"].items() if k.startswith("model_ema.")} 22 | else: 23 | state_dict = {k[6:]: v for k, v in checkpoint["state_dict"].items() if k.startswith("model.")} 24 | 25 | model.load_state_dict(state_dict, strict=True) 26 | 27 | return model 28 | 29 | 30 | class EnsembleModel(nn.Module): 31 | def __init__(self, cfg): 32 | super().__init__() 33 | self.cfg = cfg 34 | 35 | if Path(cfg.model.resume_path).is_dir(): 36 | resume_paths = Path(cfg.model.resume_path).rglob("*.ckpt") 37 | else: 38 | resume_paths = [Path(cfg.model.resume_path)] 39 | 40 | self.models = nn.ModuleList() 41 | 42 | for resume_path in resume_paths: 43 | model = get_model_from_cfg(cfg, resume_path) 44 | 
self.models.append(model) 45 | 46 | def __call__(self, x): 47 | outputs = [model(x) for model in self.models] 48 | x = torch.mean(torch.stack(outputs), dim=0) 49 | return x 50 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ### https://raw.github.com/github/gitignore/f57304e9762876ae4c9b02867ed0cb887316387e/Python.gitignore 2 | 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | env/ 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | *.egg-info/ 27 | .installed.cfg 28 | *.egg 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *,cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | local_settings.py 58 | 59 | # Flask stuff: 60 | instance/ 61 | .webassets-cache 62 | 63 | # Scrapy stuff: 64 | .scrapy 65 | 66 | # Sphinx documentation 67 | docs/_build/ 68 | 69 | # PyBuilder 70 | target/ 71 | 72 | # Jupyter Notebook 73 | .ipynb_checkpoints 74 | 75 | # pyenv 76 | .python-version 77 | 78 | # celery beat schedule file 79 | celerybeat-schedule 80 | 81 | # SageMath parsed files 82 | *.sage.py 83 | 84 | # dotenv 85 | .env 86 | 87 | # virtualenv 88 | .venv 89 | venv/ 90 | ENV/ 91 | 92 | # Spyder project settings 93 | .spyderproject 94 | 95 | # Rope project settings 96 | .ropeproject 97 | 98 | # mkdocs documentation 99 | /site 100 | 101 | /.idea/ 102 | 103 | .DS_Store 104 | 105 | -------------------------------------------------------------------------------- /src/config.yaml: -------------------------------------------------------------------------------- 1 | task: # task specific config 2 | debug: false 3 | seed: 42 4 | dirname: train_cropped_images 5 | 6 | model: 7 | arch: "unet" 8 | backbone: "tu-tf_efficientnetv2_s.in21k_ft_in1k" 9 | resume_path: null 10 | ema: false 11 | ema_decay: 0.999 12 | ema_update_after_step: 0 13 | swa: false 14 | freeze_backbone: false 15 | freeze_end_epoch: 16 16 | drop_path_rate: 0.0 17 | drop_rate: 0.0 18 | attn_drop_rate: 0.0 19 | img_size: 128 20 | img_depth: 16 21 | in_channels: 3 22 | d_model: 128 23 | num_layers: 2 24 | kernel_size: 5 25 | use_lstm: false 26 | use_attn: false 27 | depth: 50 28 | pool: "avg" # avg, gem 29 | with_pool2: true 30 | normalize_patch: false 31 | depth_flip: false 32 | stride: "pool" # pool, conv 33 | class_num: 5 34 | train_stride: 0.5 35 | use_intermediate_conv: true 36 | 37 | data: 38 | fold_num: 5 39 | fold_id: 0 40 | num_workers: 0 41 | batch_size: 2 42 | train_all: false 43 | 44 | trainer: 45 | max_epochs: 32 46 | devices: "auto" # list or str, -1 to indicate all available devices 47 | strategy: "auto" # ddp 48 | check_val_every_n_epoch: 1 49 | sync_batchnorm: false 50 | accelerator: "cpu" # cpu, gpu, tpu, ipu, hpu, mps, auto 51 | precision: 32 # 16, 32, 64, bf16 52 | gradient_clip_val: null 53 | accumulate_grad_batches: 1 54 | 
deterministic: true 55 | 56 | test: 57 | mode: val # test or val 58 | output_dir: preds_results 59 | tta: false 60 | target: axial 61 | dirname: null 62 | 63 | opt: 64 | opt: "AdamW" # SGD, Adam, AdamW... 65 | lr: 1e-4 66 | weight_decay: 0.01 67 | 68 | scheduler: 69 | sched: "cosine" 70 | min_lr: 0.0 71 | warmup_epochs: 0 72 | 73 | loss: 74 | name: "mse" # bce, focal 75 | alpha: 0.25 76 | mixup: 0.0 77 | cutmix: 0.0 78 | neg_weight: 0.1 79 | 80 | wandb: 81 | project: kaggle-czii 82 | name: null 83 | fast_dev_run: false 84 | -------------------------------------------------------------------------------- /src/util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class SlidingSlicer(nn.Module): 7 | def __init__(self, slice_size=3, stride=1): 8 | super(SlidingSlicer, self).__init__() 9 | 10 | # Create convolution layer to simulate the sliding slice operation 11 | self.conv = nn.Conv3d(1, slice_size, kernel_size=(slice_size, 1, 1), stride=(stride, 1, 1), 12 | bias=False, padding=(slice_size // 2, 0, 0)) 13 | 14 | # Set weights to simulate identity operation and bias to 0 15 | with torch.no_grad(): 16 | self.conv.weight.data.fill_(0) 17 | for i in range(slice_size): 18 | self.conv.weight.data[i, 0, i] = 1 19 | 20 | for param in self.conv.parameters(): 21 | param.requires_grad = False 22 | 23 | def forward(self, x): 24 | out = self.conv(x) 25 | out = out.transpose(1, 2) 26 | return out 27 | 28 | 29 | def mixup(data, targets, alpha=1.0): 30 | indices = torch.randperm(data.size(0)) 31 | shuffled_data = data[indices] 32 | shuffled_targets = targets[indices] 33 | lam = np.random.beta(alpha, alpha) 34 | data = data * lam + shuffled_data * (1 - lam) 35 | return data, targets, shuffled_targets, lam 36 | 37 | 38 | def rand_bbox(size, lam): 39 | W = size[2] 40 | H = size[3] 41 | cut_rat = np.sqrt(1. 
- lam) 42 | cut_w = int(W * cut_rat) 43 | cut_h = int(H * cut_rat) 44 | 45 | # uniform 46 | cx = np.random.randint(W) 47 | cy = np.random.randint(H) 48 | 49 | bbx1 = np.clip(cx - cut_w // 2, 0, W) 50 | bby1 = np.clip(cy - cut_h // 2, 0, H) 51 | bbx2 = np.clip(cx + cut_w // 2, 0, W) 52 | bby2 = np.clip(cy + cut_h // 2, 0, H) 53 | 54 | return bbx1, bby1, bbx2, bby2 55 | 56 | 57 | def cutmix(data, targets, alpha=1.0): 58 | indices = torch.randperm(data.size(0)) 59 | shuffled_data = data[indices] 60 | shuffled_targets = targets[indices] 61 | lam = np.random.beta(alpha, alpha) 62 | bbx1, bby1, bbx2, bby2 = rand_bbox(data.size(), lam) 63 | data[:, :, bbx1:bbx2, bby1:bby2] = data[indices, :, bbx1:bbx2, bby1:bby2] 64 | # adjust lambda to exactly match pixel ratio 65 | lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (data.size()[-1] * data.size()[-2])) 66 | 67 | return data, targets, shuffled_targets, lam 68 | 69 | 70 | def get_augment_policy(cfg): 71 | p_mixup = cfg.loss.mixup 72 | p_cutmix = cfg.loss.cutmix 73 | p_nothing = 1 - p_mixup - p_cutmix 74 | return np.random.choice(["nothing", "mixup", "cutmix"], p=[p_nothing, p_mixup, p_cutmix], size=1)[0] 75 | 76 | 77 | def main(): 78 | pass 79 | 80 | 81 | if __name__ == '__main__': 82 | main() -------------------------------------------------------------------------------- /11_train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pytorch_lightning import Trainer, seed_everything 3 | from pytorch_lightning.loggers import WandbLogger 4 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, StochasticWeightAveraging, Callback 5 | from omegaconf import OmegaConf 6 | 7 | from src.datamodule import MyDataModule 8 | from src.pl_module import MyModel 9 | 10 | 11 | def get_args(): 12 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 13 | parser.add_argument("--config", type=str, default="src/config.yaml") 14 | parser.add_argument("--resume", type=str, default=None) 15 | parser.add_argument("opts", default=[], nargs=argparse.REMAINDER, 16 | help="Modify config options using the command-line") 17 | args = parser.parse_args() 18 | return args 19 | 20 | 21 | class BaseModelFreezeCallback(Callback): 22 | def on_train_epoch_start(self, trainer, pl_module): 23 | current_epoch = pl_module.current_epoch 24 | freeze_end_epoch = pl_module.cfg.model.freeze_end_epoch 25 | target_model = pl_module.model.backbone 26 | 27 | if current_epoch < freeze_end_epoch: 28 | for param in target_model.parameters(): 29 | param.requires_grad = False 30 | 31 | target_model.eval() 32 | else: 33 | for param in target_model.parameters(): 34 | param.requires_grad = True 35 | 36 | target_model.train() 37 | 38 | 39 | def main(): 40 | args = get_args() 41 | cfg = OmegaConf.load(args.config) 42 | cfg = OmegaConf.merge(cfg, OmegaConf.from_cli(args.opts)) 43 | seed_everything(cfg.task.seed, workers=True) 44 | model = MyModel(cfg, mode="train") 45 | print(OmegaConf.to_yaml(cfg)) 46 | dm = MyDataModule(cfg) 47 | train_args = dict(cfg.trainer) 48 | 49 | if cfg.wandb: 50 | wandb_logger = WandbLogger(project=cfg.wandb.project, name=cfg.wandb.name, log_model=False) 51 | train_args["logger"] = wandb_logger 52 | 53 | lr_monitor = LearningRateMonitor() 54 | dirpath = "saved_models/" + cfg.wandb.name if cfg.wandb.name else None 55 | finename = f"{cfg.wandb.name}" 56 | save_on_train_epoch_end = True if cfg.data.train_all else None 57 | checkpoint_callback = 
ModelCheckpoint(dirpath=dirpath, monitor="total_score", save_last=False, mode="max", 58 | filename=finename + "_{epoch:03d}_{total_score:.8f}", save_weights_only=True, 59 | save_on_train_epoch_end=save_on_train_epoch_end) 60 | 61 | # if cfg.wandb.name: 62 | # checkpoint_callback.CHECKPOINT_NAME_LAST = finename 63 | 64 | callbacks = [lr_monitor, checkpoint_callback] 65 | 66 | if cfg.model.swa: 67 | eta_min = cfg.opt.lr_min if cfg.opt.lr_min else cfg.opt.lr / 10 68 | swa_callback = StochasticWeightAveraging(swa_lrs=eta_min, annealing_strategy="linear") 69 | callbacks = [swa_callback] + callbacks # follow _configure_swa_callbacks 70 | 71 | if cfg.model.freeze_backbone: 72 | callbacks.append(BaseModelFreezeCallback()) 73 | 74 | trainer = Trainer(**train_args, callbacks=callbacks) 75 | trainer.fit(model, datamodule=dm) 76 | 77 | 78 | if __name__ == '__main__': 79 | main() 80 | -------------------------------------------------------------------------------- /src/datamodule.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | from torch.utils.data import DataLoader 4 | from pytorch_lightning import LightningDataModule 5 | 6 | from .dataset import MyDataset 7 | 8 | 9 | class MyDataModule(LightningDataModule): 10 | def __init__(self, cfg): 11 | super().__init__() 12 | self.cfg = cfg 13 | self.train_dataset = None 14 | self.val_dataset = None 15 | self.test_dataset = None 16 | self.input_root = Path(__file__).parents[1].joinpath("input") 17 | self.output_root = Path(__file__).parents[1].joinpath("output") 18 | 19 | def prepare_data(self): 20 | pass 21 | 22 | def setup(self, stage=None): 23 | if stage == "fit" or (stage == "predict" and self.cfg.test.mode == "val"): 24 | img_dir = self.output_root.joinpath("train_imgs") 25 | train_imgs = [] 26 | train_masks = [] 27 | val_imgs = [] 28 | val_masks = [] 29 | 30 | for i, npy_path in enumerate(sorted(img_dir.glob("*.npy"))): 31 | img = np.load(npy_path) 32 | 33 | if self.cfg.model.normalize_patch: 34 | img = (img - img.mean()) / img.std() 35 | else: 36 | img = (img - 5.2577832e-08) / 7.199929e-06 37 | 38 | if self.cfg.model.img_depth == 16: 39 | img = np.pad(img, ((0, 0), (0, 10), (0, 10)), mode="constant") 40 | mask = np.load(str(npy_path).replace("train_imgs", "train_masks")) 41 | mask = np.pad(mask, ((0, 0), (0, 10), (0, 10), (0, 0)), mode="constant") 42 | elif self.cfg.model.img_depth == 32: 43 | img = np.pad(img, ((4, 4), (0, 10), (0, 10)), mode="constant") 44 | mask = np.load(str(npy_path).replace("train_imgs", "train_masks")) 45 | mask = np.pad(mask, ((4, 4), (0, 10), (0, 10), (0, 0)), mode="constant") 46 | else: 47 | raise ValueError(f"unknown img depth {self.cfg.model.img_depth}") 48 | 49 | mask = (mask / 255.0).astype(np.float32) 50 | 51 | if i == self.cfg.data.fold_id: 52 | val_imgs.append(img) 53 | val_masks.append(mask) 54 | else: 55 | train_imgs.append(img) 56 | train_masks.append(mask) 57 | 58 | self.train_dataset = MyDataset(self.cfg, train_imgs, train_masks, "train") 59 | self.val_dataset = MyDataset(self.cfg, val_imgs, val_masks, "val") 60 | self.test_dataset = MyDataset(self.cfg, val_imgs, val_masks, "test") 61 | else: 62 | raise ValueError(f"unknown stage {stage}") 63 | 64 | def train_dataloader(self): 65 | return DataLoader(self.train_dataset, batch_size=self.cfg.data.batch_size, collate_fn=None, 66 | shuffle=True, drop_last=True, num_workers=self.cfg.data.num_workers) 67 | 68 | def val_dataloader(self): 69 | return DataLoader(self.val_dataset, 
batch_size=self.cfg.data.batch_size, collate_fn=None, 70 | shuffle=False, drop_last=False, num_workers=self.cfg.data.num_workers) 71 | 72 | def test_dataloader(self): 73 | return DataLoader(self.test_dataset, batch_size=self.cfg.data.batch_size, collate_fn=None, 74 | shuffle=False, drop_last=False, num_workers=self.cfg.data.num_workers) 75 | 76 | def predict_dataloader(self): 77 | return DataLoader(self.test_dataset, batch_size=self.cfg.data.batch_size, collate_fn=None, 78 | shuffle=False, drop_last=False, num_workers=self.cfg.data.num_workers) 79 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # CZII - CryoET Object Identification - 4th Place Solution 2 | This is the implementation of the 4th place solution (yu4u's part) for [CZII - CryoET Object Identification](https://www.kaggle.com/competitions/czii-cryo-et-object-identification) at Kaggle. 3 | The source code for tattaka's model can be found [here](https://github.com/tattaka/czii-cryo-et-object-identification-public). 4 | The overall solution is described in [this discussion](https://www.kaggle.com/competitions/czii-cryo-et-object-identification/discussion/561401). 5 | 6 | Our technical paper is also [available](https://arxiv.org/abs/2502.13484). 7 | 8 | ![yu4u pipeline](misc/model.png) 9 | 10 | ## Requirements 11 | - Less than 10GB of VRAM is required (trained on a GeForce RTX 3090 in my case). 12 | 13 | ## Preparation 14 | - Download the competition dataset from [here](https://www.kaggle.com/competitions/czii-cryo-et-object-identification/data) and put it in the `input` directory. 15 | ```shell 16 | unzip czii-cryo-et-object-identification.zip -d input 17 | ``` 18 | - Install Docker/NVIDIA Container Toolkit. 
19 | - Build the Docker image and enter the Docker container: 20 | ```shell 21 | export UID=$(id -u) 22 | docker compose up -d 23 | docker compose exec dev /bin/bash 24 | ``` 25 | - Login to wandb: 26 | ```shell 27 | wandb login 28 | ``` 29 | 30 | ## Training Process 31 | 32 | ```sh 33 | # prepare dataset 34 | python 01_create_dataset.py 35 | 36 | # create ground-truth heatmap 37 | python 02_create_mask.py --class_num 6 38 | 39 | # create ground-truth solution for validation 40 | python 03_create_solution.py 41 | 42 | # train models, required 3.5 hours per fold on GeForce RTX 3090 x 1 43 | python 11_train.py trainer.accelerator=gpu trainer.devices=[0] data.batch_size=32 trainer.precision=16 data.num_workers=8 trainer.max_epochs=64 wandb.name=timm3d3_convnext_nano_class6_fold2 opt.lr=1e-3 opt.opt=AdamW opt.weight_decay=0 scheduler.min_lr=0 model.img_depth=16 model.img_size=128 model.arch=timm3d3 model.ema=True model.ema_decay=0.999 model.in_channels=5 loss.mixup=0.5 model.backbone=convnext_nano.in12k_ft_in1k scheduler.warmup_epochs=4 model.class_num=6 trainer.deterministic=warn data.fold_id=2 44 | python 11_train.py trainer.accelerator=gpu trainer.devices=[0] data.batch_size=32 trainer.precision=16 data.num_workers=8 trainer.max_epochs=64 wandb.name=timm3d3_convnext_nano_class6_fold3 opt.lr=1e-3 opt.opt=AdamW opt.weight_decay=0 scheduler.min_lr=0 model.img_depth=16 model.img_size=128 model.arch=timm3d3 model.ema=True model.ema_decay=0.999 model.in_channels=5 loss.mixup=0.5 model.backbone=convnext_nano.in12k_ft_in1k scheduler.warmup_epochs=4 model.class_num=6 trainer.deterministic=warn data.fold_id=3 45 | python 11_train.py trainer.accelerator=gpu trainer.devices=[0] data.batch_size=32 trainer.precision=16 data.num_workers=8 trainer.max_epochs=64 wandb.name=timm3d3_convnext_nano_class6_fold4 opt.lr=1e-3 opt.opt=AdamW opt.weight_decay=0 scheduler.min_lr=0 model.img_depth=16 model.img_size=128 model.arch=timm3d3 model.ema=True model.ema_decay=0.999 model.in_channels=5 loss.mixup=0.5 model.backbone=convnext_nano.in12k_ft_in1k scheduler.warmup_epochs=4 model.class_num=6 trainer.deterministic=warn data.fold_id=4 46 | python 11_train.py trainer.accelerator=gpu trainer.devices=[0] data.batch_size=32 trainer.precision=16 data.num_workers=8 trainer.max_epochs=64 wandb.name=timm3d3_convnext_nano_class6_fold6 opt.lr=1e-3 opt.opt=AdamW opt.weight_decay=0 scheduler.min_lr=0 model.img_depth=16 model.img_size=128 model.arch=timm3d3 model.ema=True model.ema_decay=0.999 model.in_channels=5 loss.mixup=0.5 model.backbone=convnext_nano.in12k_ft_in1k scheduler.warmup_epochs=4 model.class_num=6 trainer.deterministic=warn data.fold_id=6 47 | ``` 48 | 49 | Trained models are saved in `saved_models/timm3d3_convnext_nano_class6_fold[2|3|4|6]` directories. 50 | Our checkpoint can be found [here](https://www.kaggle.com/datasets/ren4yu/czii-models8). 51 | Convert these checkpoints into TensorRT engines using [this notebook](https://www.kaggle.com/code/ren4yu/czii-tensorrt-convert-1/notebook). 52 | The outputs from this notebook can be used in [the submission notebook](https://www.kaggle.com/code/ren4yu/czii-ensemble-tensorrt-xy-stride-th?scriptVersionId=220758003). 
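For a quick local check of a trained fold without the TensorRT conversion, the checkpoints can also be run through the same Lightning prediction path used during validation. Below is a minimal sketch (a hypothetical helper, not part of this repository): the resume path and the config overrides are assumptions and must match the hyper-parameters of the training command that produced the checkpoint, and it assumes the `01`/`02`/`03` preprocessing scripts have already been run.

```python
# local_predict.py -- hypothetical helper, not part of the original repository
import numpy as np
from omegaconf import OmegaConf
from pytorch_lightning import Trainer

from src.datamodule import MyDataModule
from src.pl_module import MyModel

cfg = OmegaConf.load("src/config.yaml")
# these overrides must match the training command that produced the checkpoint
cfg.model.arch = "timm3d3"
cfg.model.backbone = "convnext_nano.in12k_ft_in1k"
cfg.model.in_channels = 5
cfg.model.class_num = 6
cfg.data.fold_id = 2
cfg.test.mode = "val"
# EnsembleModel in src/model.py averages every *.ckpt found under this directory
cfg.model.resume_path = "saved_models/timm3d3_convnext_nano_class6_fold2"

model = MyModel(cfg, mode="test")  # wraps EnsembleModel; predict_step accumulates patch-weighted heatmaps
dm = MyDataModule(cfg)
trainer = Trainer(accelerator="gpu", devices=1)
trainer.predict(model, datamodule=dm)

# mirror on_validation_epoch_end: normalize the accumulated predictions by the patch weights
preds = model.preds / np.maximum(model.weights, 1.0)
print(preds.shape)  # (6, 184, 640, 640) for class_num=6 and img_depth=16
```

This mirrors what `on_validation_epoch_end` does before peak extraction, so `get_experiment_score` from `src/metric.py` can then be applied to `preds` to reproduce the validation score for that fold.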
53 | 54 | ## Citation 55 | 56 | ```bibtex 57 | @article{uchida2025unet, 58 | title={2.5D U-Net with Depth Reduction for 3D CryoET Object Identification}, 59 | author={Uchida, Yusuke and Fukui, Takaaki}, 60 | journal={arXiv preprint arXiv:2502.13484}, 61 | year={2025} 62 | } 63 | ``` 64 | -------------------------------------------------------------------------------- /02_create_mask.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | from pathlib import Path 3 | import json 4 | import numpy as np 5 | 6 | from src.metric import particle_types 7 | 8 | 9 | def get_args(): 10 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 11 | parser.add_argument("--mode", type=str, default="train") # train or test 12 | parser.add_argument("--class_num", type=int, default=5) # 5 or 6 13 | args = parser.parse_args() 14 | return args 15 | 16 | 17 | def create_gaussian_patch(patch_size, sigma): 18 | """ 19 | Create a 3D Gaussian patch with the specified size and standard deviation. 20 | 21 | :param patch_size: Size of the cubic patch (patch_size, patch_size, patch_size). 22 | :param sigma: Standard deviation of the Gaussian. 23 | :return: 3D Gaussian patch as a 3D numpy array. 24 | """ 25 | center = patch_size // 2 26 | x = np.arange(patch_size) - center 27 | y = np.arange(patch_size) - center 28 | z = np.arange(patch_size) - center 29 | x, y, z = np.meshgrid(x, y, z, indexing='ij') 30 | gaussian = np.exp(-(x ** 2 + y ** 2 + z ** 2) / (2 * sigma ** 2)) 31 | gaussian *= 255.0 / np.max(gaussian) 32 | return gaussian.astype(np.uint8) 33 | 34 | 35 | 36 | def place_patch(volume, patch, x, y, z): 37 | """ 38 | Place a 3D Gaussian patch on the volume at the specified position, taking the maximum of the existing values and the patch. 39 | 40 | :param volume: The base 3D volume where the patch will be placed. 41 | :param patch: The 3D Gaussian patch to be placed. 42 | :param x: X-coordinate where the patch will be centered. 43 | :param y: Y-coordinate where the patch will be centered. 44 | :param z: Z-coordinate where the patch will be centered. 
45 | """ 46 | patch_size = patch.shape[0] 47 | half_size = patch_size // 2 48 | 49 | # Determine the region of the volume to place the patch 50 | x_start = max(0, x - half_size) 51 | x_end = min(volume.shape[2], x + half_size) 52 | y_start = max(0, y - half_size) 53 | y_end = min(volume.shape[1], y + half_size) 54 | z_start = max(0, z - half_size) 55 | z_end = min(volume.shape[0], z + half_size) 56 | 57 | # Calculate the corresponding region in the patch 58 | patch_x_start = half_size - (x - x_start) 59 | patch_x_end = half_size + (x_end - x) 60 | patch_y_start = half_size - (y - y_start) 61 | patch_y_end = half_size + (y_end - y) 62 | patch_z_start = half_size - (z - z_start) 63 | patch_z_end = half_size + (z_end - z) 64 | 65 | # Take the maximum of the existing values and the patch 66 | volume[z_start:z_end, y_start:y_end, x_start:x_end] = np.maximum( 67 | volume[z_start:z_end, y_start:y_end, x_start:x_end], 68 | patch[patch_z_start:patch_z_end, patch_y_start:patch_y_end, patch_x_start:patch_x_end] 69 | ) 70 | 71 | 72 | def main(): 73 | args = get_args() 74 | mode = args.mode # train or test 75 | class_num = args.class_num 76 | root_dir = Path(__file__).parent.joinpath("input") 77 | output_root = Path(__file__).parent.joinpath("output") 78 | output_dir = output_root.joinpath(f"{mode}_masks") 79 | output_dir.mkdir(exist_ok=True, parents=True) 80 | json_names = [particle_type + ".json" for particle_type in particle_types[:class_num]] 81 | 82 | gaussian_patch = create_gaussian_patch(45, 6) 83 | 84 | for exp_dir in root_dir.joinpath(mode, "overlay", "ExperimentRuns").iterdir(): 85 | if not exp_dir.is_dir(): 86 | continue 87 | 88 | print(exp_dir) 89 | json_dir = exp_dir.joinpath("Picks") 90 | output_mask = np.zeros((184, 630, 630, class_num), dtype=np.uint8) 91 | d, h, w, _ = output_mask.shape 92 | 93 | for i, json_name in enumerate(json_names): 94 | json_path = json_dir.joinpath(json_name) 95 | 96 | with open(json_path) as f: 97 | picks = json.load(f) 98 | 99 | # [10.012444196428572, 10.012444196428572, 10.012444537618887] 100 | 101 | for point in picks["points"]: 102 | x = int(point["location"]["x"] / 10.012444537618887 + 1) 103 | y = int(point["location"]["y"] / 10.012444196428572 + 1) 104 | z = int(point["location"]["z"] / 10.012444196428572 + 1) 105 | assert 0 <= x < w 106 | assert 0 <= y < h 107 | assert 0 <= z < d 108 | 109 | place_patch(output_mask[:, :, :, i], gaussian_patch, x, y, z) 110 | 111 | output_path = output_dir.joinpath(f"{exp_dir.stem}.npy") 112 | np.save(output_path, output_mask) 113 | 114 | 115 | if __name__ == '__main__': 116 | main() 117 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import torch 4 | from torch.utils.data import Dataset 5 | import albumentations as A 6 | from albumentations.pytorch.transforms import ToTensorV2 7 | 8 | 9 | class MyDataset(Dataset): 10 | def __init__(self, cfg, x, y=None, mode="train", img_dir=None): 11 | assert mode in ["train", "val", "test"] 12 | self.cfg = cfg 13 | self.x = x # 184, 630, 630 14 | self.y = y # 184, 630, 630, 5 15 | self.mode = mode 16 | self.indices = self.get_indices() 17 | self.transforms = get_train_transforms(cfg) if mode == "train" else get_val_transforms(cfg) 18 | 19 | @staticmethod 20 | def get_stride(img_w, img_h, img_d, tile_size_x, tile_size_y, tile_size_z, stride_scale=0.5): 21 | tmp_stride_x = tile_size_x * stride_scale 22 | 
tmp_stride_y = tile_size_y * stride_scale 23 | tmp_stride_z = tile_size_z * stride_scale 24 | tile_num_x = max(round((img_w - tile_size_x) / tmp_stride_x + 1), 1) 25 | tile_num_y = max(round((img_h - tile_size_y) / tmp_stride_y + 1), 1) 26 | tile_num_z = max(round((img_d - tile_size_z) / tmp_stride_z + 1), 1) 27 | stride_x = (img_w - tile_size_x) // (tile_num_x - 1) if tile_num_x - 1 > 0 else 0 28 | stride_y = (img_h - tile_size_y) // (tile_num_y - 1) if tile_num_y - 1 > 0 else 0 29 | stride_z = (img_d - tile_size_z) // (tile_num_z - 1) if tile_num_z - 1 > 0 else 0 30 | return stride_x, stride_y, stride_z, tile_num_x, tile_num_y, tile_num_z 31 | 32 | def get_indices(self): 33 | img_size = self.cfg.model.img_size 34 | img_depth = self.cfg.model.img_depth 35 | tile_size_x = img_size 36 | tile_size_y = img_size 37 | tile_size_z = img_depth 38 | stride_scale = self.cfg.model.train_stride if self.mode == "train" else 0.5 39 | indices = [] 40 | 41 | for i, x_i in enumerate(self.x): 42 | img_d, img_h, img_w = x_i.shape 43 | s_x, s_y, s_z, tile_x, tile_y, tile_z = self.get_stride(img_w, img_h, img_d, tile_size_x, tile_size_y, 44 | tile_size_z, stride_scale=stride_scale) 45 | for iz in range(tile_z): 46 | for iy in range(tile_y): 47 | for ix in range(tile_x): 48 | sx = ix * s_x 49 | sy = iy * s_y 50 | sz = iz * s_z 51 | indices.append((i, sx, sy, sz)) 52 | return np.array(indices) 53 | 54 | def __len__(self): 55 | return len(self.indices) 56 | 57 | def __getitem__(self, idx): 58 | img_size = self.cfg.model.img_size 59 | img_depth = self.cfg.model.img_depth 60 | class_num = self.cfg.model.class_num 61 | exp_id, sx, sy, sz = self.indices[idx] 62 | x = self.x[exp_id] # d, h, w 63 | 64 | if self.mode == "train": 65 | # add random offset 66 | sx += np.random.randint(-img_size // 4, img_size // 4) 67 | sy += np.random.randint(-img_size // 4, img_size // 4) 68 | sz += np.random.randint(-img_depth // 4, img_depth // 4) 69 | sx = np.clip(sx, 0, x.shape[2] - img_size) 70 | sy = np.clip(sy, 0, x.shape[1] - img_size) 71 | sz = np.clip(sz, 0, x.shape[0] - img_depth) 72 | 73 | x = x[sz:sz + img_depth, sy:sy + img_size, sx:sx + img_size] 74 | x = x.transpose(1, 2, 0) # h, w, d 75 | 76 | # y[exp_id]: 184, 630, 630, 5 77 | y = self.y[exp_id][sz:sz + img_depth, sy:sy + img_size, sx:sx + img_size] if self.y is not None else -1 78 | y = y.transpose(1, 2, 0, 3) # h, w, d, c 79 | y = y.reshape(img_size, img_size, -1) 80 | 81 | sample = self.transforms(image=x, mask=y) 82 | x = sample["image"] 83 | x = x.unsqueeze(0) 84 | y = sample["mask"] 85 | y = y.reshape(img_depth, class_num, img_size, img_size) 86 | y = y.permute(1, 0, 2, 3) 87 | # y = (y / 255.0).float() 88 | 89 | if self.mode == "train" and self.cfg.model.depth_flip: 90 | if np.random.rand() > 0.5: 91 | x = torch.flip(x, [2]) 92 | y = torch.flip(y, [2]) 93 | 94 | return x, y, (exp_id, sx, sy, sz) 95 | 96 | 97 | def get_train_transforms(cfg): 98 | return A.Compose( 99 | [ 100 | # A.Resize(height=cfg.task.img_size, width=cfg.task.img_size, p=1), 101 | # A.CenterCrop(height=cfg.task.img_size, width=cfg.task.img_size, p=1), 102 | A.ShiftScaleRotate(p=0.5, border_mode=cv2.BORDER_CONSTANT, shift_limit=0.05, scale_limit=0.1, value=0, 103 | rotate_limit=180, mask_value=0), 104 | # A.RandomScale(scale_limit=(0.8, 1.2), p=1), 105 | # A.PadIfNeeded(min_height=cfg.model.img_size, min_width=cfg.model.img_size, p=1.0, 106 | # border_mode=cv2.BORDER_CONSTANT, value=0), 107 | # A.RandomCrop(height=self.cfg.data.train_img_h, width=self.cfg.data.train_img_w, p=1.0), 108 | # 
A.MultiplicativeNoise(multiplier=(0.9, 1.1), elementwise=True, p=0.5), 109 | # A.RandomRotate90(p=1.0), 110 | A.HorizontalFlip(p=0.5), 111 | # A.VerticalFlip(p=0.5), 112 | # A.RandomBrightnessContrast(p=0.5, brightness_limit=0.1, contrast_limit=0.1), 113 | # A.HueSaturationValue(p=0.5), 114 | # A.ToGray(p=0.3), 115 | # A.GaussNoise(var_limit=(0.0, 0.05), p=0.5), 116 | # A.GaussianBlur(p=0.5), 117 | # normalize with imagenet statis 118 | # A.Normalize(p=1.0, mean=5.2577832e-08, std=7.199929e-06, max_pixel_value=1.0), 119 | A.RandomBrightnessContrast( 120 | brightness_limit=0.3, contrast_limit=0.3, p=0.3 121 | ), 122 | ToTensorV2(p=1.0, transpose_mask=True), 123 | ], 124 | p=1.0, 125 | ) 126 | 127 | 128 | def get_val_transforms(cfg): 129 | return A.Compose( 130 | [ 131 | # A.Resize(height=cfg.task.img_size, width=cfg.task.img_size, p=1), 132 | # A.RandomScale(scale_limit=(1.0, 1.0), p=1), 133 | # A.PadIfNeeded(min_height=cfg.model.img_size, min_width=cfg.model.img_size, p=1.0, 134 | # border_mode=cv2.BORDER_CONSTANT, value=0), 135 | # A.Crop(y_max=self.cfg.data.val_img_h, x_max=self.cfg.data.val_img_w, p=1.0), 136 | # A.Normalize(p=1.0, mean=5.2577832e-08, std=7.199929e-06, max_pixel_value=1.0), 137 | ToTensorV2(p=1.0, transpose_mask=True), 138 | ], 139 | p=1.0, 140 | ) 141 | -------------------------------------------------------------------------------- /src/pl_module.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Any 3 | import sklearn.metrics 4 | import numpy as np 5 | import torch 6 | from pytorch_lightning.core.module import LightningModule 7 | from timm.utils import ModelEmaV3 8 | from timm.optim import create_optimizer_v2 9 | from timm.scheduler import create_scheduler_v2 10 | 11 | from .model import get_model_from_cfg, EnsembleModel 12 | from .loss import get_loss 13 | from .util import mixup, get_augment_policy 14 | from .metric import get_experiment_score, particle_types, particle_to_weights 15 | 16 | 17 | def get_patch_weight(img_size, depth): 18 | s = np.linspace(0, 1, img_size) 19 | t = np.linspace(0, 1, depth) 20 | a = np.minimum(s, 1 - s).reshape(1, -1, 1) 21 | b = np.minimum(s, 1 - s).reshape(1, 1, -1) 22 | c = np.minimum(t, 1 - t).reshape(-1, 1, 1) 23 | patch_weight = np.minimum(a, b) * c 24 | patch_weight = patch_weight / patch_weight.max() 25 | return patch_weight 26 | 27 | 28 | class MyModel(LightningModule): 29 | def __init__(self, cfg, mode="train"): 30 | super().__init__() 31 | self.preds = None 32 | self.weights = None 33 | self.gts = None 34 | self.cfg = cfg 35 | 36 | img_size = self.cfg.model.img_size 37 | img_depth = self.cfg.model.img_depth 38 | self.patch_weight = get_patch_weight(img_size, img_depth) 39 | 40 | if mode == "test": 41 | self.model = EnsembleModel(cfg) 42 | else: 43 | self.model = get_model_from_cfg(cfg, cfg.model.resume_path) 44 | 45 | if mode != "test" and cfg.model.ema: 46 | self.model_ema = ModelEmaV3( 47 | self.model, 48 | decay=cfg.model.ema_decay, 49 | update_after_step=cfg.model.ema_update_after_step, 50 | ) 51 | 52 | self.loss = get_loss(cfg) 53 | 54 | def forward(self, x): 55 | return self.model(x) 56 | 57 | def training_step(self, batch, batch_idx): 58 | x, y, (exp_id, sx, sy, sz) = batch 59 | augment_policy = get_augment_policy(self.cfg) 60 | 61 | if augment_policy == "mixup": 62 | x, targets1, targets2, lam = mixup(x, y) 63 | elif augment_policy == "nothing": 64 | pass 65 | else: 66 | raise ValueError(f"unknown augment policy {augment_policy}") 67 
| 68 | output = self.model(x) 69 | 70 | if augment_policy == "nothing": 71 | loss_dict = {k: v if k == "loss" else v.detach() for k, v in self.loss(output, y).items()} 72 | else: 73 | loss_dict1 = self.loss(output, targets1) 74 | loss_dict2 = self.loss(output, targets2) 75 | loss_dict = {k: lam * loss_dict1[k] + (1 - lam) * loss_dict2[k] for k in loss_dict1.keys()} 76 | loss_dict = {k: v if k == "loss" else v.detach() for k, v in loss_dict.items()} 77 | 78 | self.log_dict(loss_dict, on_epoch=True, sync_dist=True) 79 | return loss_dict 80 | 81 | def on_train_batch_end(self, out, batch, batch_idx): 82 | if self.cfg.model.ema: 83 | self.model_ema.update(self.model) 84 | 85 | def on_validation_epoch_start(self) -> None: 86 | if self.cfg.model.img_depth == 16: 87 | self.preds = np.zeros((self.cfg.model.class_num, 184, 640, 640), dtype=np.float32) 88 | self.weights = np.zeros((1, 184, 640, 640), dtype=np.float32) 89 | elif self.cfg.model.img_depth == 32: 90 | self.preds = np.zeros((self.cfg.model.class_num, 192, 640, 640), dtype=np.float32) 91 | self.weights = np.zeros((1, 192, 640, 640), dtype=np.float32) 92 | else: 93 | raise ValueError(f"unknown img_depth {self.cfg.model.img_depth}") 94 | 95 | def validation_step(self, batch, batch_idx): 96 | x, y, (exp_id, sx, sy, sz) = batch 97 | 98 | if self.cfg.model.ema: 99 | output = self.model_ema.module(x) 100 | else: 101 | output = self.model(x) 102 | 103 | loss_dict = self.loss(output, y) 104 | log_dict = {"val_" + k: v for k, v in loss_dict.items()} 105 | self.log_dict(log_dict, on_epoch=True, sync_dist=True) 106 | 107 | img_size = self.cfg.model.img_size 108 | img_depth = self.cfg.model.img_depth 109 | 110 | if self.cfg.model.arch == "timm3d6": 111 | output = output[-1] 112 | 113 | output = output.cpu().numpy() 114 | 115 | for pred, x, y, z in zip(output, sx, sy, sz): 116 | self.preds[:, z:z + img_depth, y:y + img_size, x:x + img_size] += pred * self.patch_weight 117 | self.weights[:, z:z + img_depth, y:y + img_size, x:x + img_size] += self.patch_weight 118 | 119 | def on_validation_epoch_end(self): 120 | preds = self.preds 121 | weights = np.maximum(self.weights, 1.0) 122 | preds = preds / weights 123 | 124 | if self.cfg.model.img_depth == 16: 125 | pass 126 | elif self.cfg.model.img_depth == 32: 127 | preds = preds[:, 4:-4] 128 | else: 129 | raise ValueError(f"unknown img_depth {self.cfg.model.img_depth}") 130 | 131 | fold_id = self.cfg.data.fold_id 132 | scores = get_experiment_score(preds, fold_id) 133 | type_to_scores = dict(zip(particle_types, scores)) 134 | self.log_dict(type_to_scores, on_epoch=True, sync_dist=True) 135 | 136 | total_score = 0.0 137 | total_weights = 0.0 138 | 139 | for particle_type, score in type_to_scores.items(): 140 | weight = particle_to_weights[particle_type] 141 | total_score += score * weight 142 | total_weights += weight 143 | 144 | total_score = total_score / total_weights 145 | self.log("total_score", total_score, on_epoch=True, sync_dist=True) 146 | 147 | def on_test_start(self): 148 | pass 149 | 150 | def test_step(self, batch, batch_idx): 151 | pass 152 | 153 | def on_test_epoch_end(self): 154 | pass 155 | 156 | def on_predict_start(self): 157 | if self.cfg.model.img_depth == 16: 158 | self.preds = np.zeros((self.cfg.model.class_num, 184, 640, 640), dtype=np.float32) 159 | self.weights = np.zeros((1, 184, 640, 640), dtype=np.float32) 160 | elif self.cfg.model.img_depth == 32: 161 | self.preds = np.zeros((self.cfg.model.class_num, 192, 640, 640), dtype=np.float32) 162 | self.weights = np.zeros((1, 192, 640, 
640), dtype=np.float32) 163 | else: 164 | raise ValueError(f"unknown img_depth {self.cfg.model.img_depth}") 165 | 166 | 167 | def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: int = 0): 168 | x, _, (exp_id, sx, sy, sz) = batch 169 | img_size = self.cfg.model.img_size 170 | img_depth = self.cfg.model.img_depth 171 | output = self.model(x) 172 | 173 | if self.cfg.model.arch == "timm3d6": 174 | output = output[-1] 175 | 176 | output = output.cpu().numpy() 177 | 178 | for pred, x, y, z in zip(output, sx, sy, sz): 179 | self.preds[:, z:z + img_depth, y:y + img_size, x:x + img_size] += pred * self.patch_weight 180 | self.weights[:, z:z + img_depth, y:y + img_size, x:x + img_size] += self.patch_weight 181 | 182 | def configure_optimizers(self): 183 | optimizer = create_optimizer_v2(model_or_params=self.model, **self.cfg.opt) 184 | batch_size = self.cfg.data.batch_size 185 | updates_per_epoch = len(self.trainer.datamodule.train_dataset) // batch_size // self.trainer.num_devices 186 | scheduler, num_epochs = create_scheduler_v2(optimizer=optimizer, num_epochs=self.cfg.trainer.max_epochs, 187 | warmup_lr=0, **self.cfg.scheduler, 188 | step_on_epochs=False, updates_per_epoch=updates_per_epoch) 189 | lr_dict = dict( 190 | scheduler=scheduler, 191 | interval="step", 192 | frequency=1, # same as default 193 | ) 194 | return dict(optimizer=optimizer, lr_scheduler=lr_dict) 195 | 196 | def lr_scheduler_step(self, scheduler, metric): 197 | scheduler.step_update(num_updates=self.global_step) 198 | -------------------------------------------------------------------------------- /src/timm_3d_decoder3.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import timm 5 | from timm.models import FeatureListNet 6 | 7 | from .util import SlidingSlicer 8 | 9 | 10 | class Conv3dReLU(nn.Module): 11 | def __init__(self, in_channels, out_channels, kernel_size=3, padding=1, use_batchnorm=True): 12 | super().__init__() 13 | if use_batchnorm: 14 | self.block = nn.Sequential( 15 | nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=False), 16 | nn.BatchNorm3d(out_channels), 17 | nn.ReLU(inplace=True) 18 | ) 19 | else: 20 | self.block = nn.Sequential( 21 | nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding), 22 | nn.ReLU(inplace=True) 23 | ) 24 | 25 | def forward(self, x): 26 | return self.block(x) 27 | 28 | 29 | class DecoderBlock(nn.Module): 30 | def __init__( 31 | self, 32 | in_channels, 33 | skip_channels, 34 | out_channels, 35 | i, 36 | use_batchnorm=True, 37 | ): 38 | super().__init__() 39 | self.conv1 = Conv3dReLU( 40 | in_channels + skip_channels, 41 | out_channels, 42 | kernel_size=3, 43 | padding=1, 44 | use_batchnorm=use_batchnorm, 45 | ) 46 | self.conv2 = Conv3dReLU( 47 | out_channels, 48 | out_channels, 49 | kernel_size=3, 50 | padding=1, 51 | use_batchnorm=use_batchnorm, 52 | ) 53 | self.k = 1 if i == 0 else 2 54 | 55 | def forward(self, x, skip=None): 56 | x = F.interpolate(x, scale_factor=(self.k, 2, 2), mode="nearest") 57 | if skip is not None: 58 | x = torch.cat([x, skip], dim=1) 59 | x = self.conv1(x) 60 | x = self.conv2(x) 61 | return x 62 | 63 | 64 | class UnetDecoder(nn.Module): 65 | def __init__( 66 | self, 67 | encoder_channels, 68 | decoder_channels, 69 | n_blocks=5, 70 | use_batchnorm=True, 71 | ): 72 | super().__init__() 73 | 74 | if n_blocks != len(decoder_channels): 75 | raise ValueError( 76 | "Model depth is {}, but you provide 
`decoder_channels` for {} blocks.".format( 77 | n_blocks, len(decoder_channels) 78 | ) 79 | ) 80 | 81 | # remove first skip with same spatial resolution 82 | encoder_channels = encoder_channels[1:] 83 | # reverse channels to start from head of encoder 84 | encoder_channels = encoder_channels[::-1] 85 | 86 | # computing blocks input and output channels 87 | head_channels = encoder_channels[0] 88 | in_channels = [head_channels] + list(decoder_channels[:-1]) 89 | skip_channels = list(encoder_channels[1:]) + [0] 90 | out_channels = decoder_channels 91 | 92 | # combine decoder keyword arguments 93 | kwargs = dict(use_batchnorm=use_batchnorm) 94 | blocks = [ 95 | DecoderBlock(in_ch, skip_ch, out_ch, i, **kwargs) 96 | for i, (in_ch, skip_ch, out_ch) in enumerate(zip(in_channels, skip_channels, out_channels)) 97 | ] 98 | self.blocks = nn.ModuleList(blocks) 99 | 100 | def forward(self, *features): 101 | 102 | features = features[1:] # remove first skip with same spatial resolution 103 | features = features[::-1] # reverse channels to start from head of encoder 104 | 105 | x = features[0] 106 | skips = features[1:] 107 | 108 | for i, decoder_block in enumerate(self.blocks): 109 | skip = skips[i] if i < len(skips) else None 110 | x = decoder_block(x, skip) 111 | 112 | return x 113 | 114 | 115 | def channel_to_spatial_3d(input_tensor, upscale_factor=2): 116 | N, C, D, H, W = input_tensor.shape 117 | factor_cubed = upscale_factor ** 3 118 | new_C = C // factor_cubed 119 | output_tensor = input_tensor.view( 120 | N, new_C, upscale_factor, upscale_factor, upscale_factor, D, H, W 121 | ) 122 | output_tensor = output_tensor.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous() 123 | output_tensor = output_tensor.view( 124 | N, new_C, D * upscale_factor, H * upscale_factor, W * upscale_factor 125 | ) 126 | 127 | return output_tensor 128 | 129 | 130 | class Timm3DDecoder3(torch.nn.Module): 131 | def __init__(self, cfg): 132 | super().__init__() 133 | self.cfg = cfg 134 | pretrained = True if cfg.model.resume_path is None else False 135 | 136 | if cfg.model.in_channels > 1: 137 | self.stem = SlidingSlicer(slice_size=cfg.model.in_channels) 138 | 139 | self.out_ch = cfg.model.class_num 140 | model = timm.create_model( 141 | cfg.model.backbone, 142 | in_chans=cfg.model.in_channels, 143 | pretrained=pretrained, 144 | drop_path_rate=cfg.model.drop_path_rate, 145 | ) 146 | out_channels = [fi["num_chs"] for fi in model.feature_info] 147 | 148 | try: 149 | self.backbone = FeatureListNet(model, out_indices=tuple(range(len(out_channels))), flatten_sequential=True) 150 | except AssertionError: 151 | self.backbone = FeatureListNet(model, out_indices=tuple(range(len(out_channels))), flatten_sequential=False) 152 | 153 | self.backbone.out_channels = [cfg.model.in_channels] + out_channels 154 | 155 | self.decoder = UnetDecoder( 156 | encoder_channels=self.backbone.out_channels, 157 | decoder_channels=(1024, 768, 256), 158 | n_blocks=3, 159 | use_batchnorm=True, 160 | ) 161 | 162 | k = 3 163 | conv3ds = [ 164 | torch.nn.Sequential( 165 | Conv3dReLU(ch, ch, k, k // 2, use_batchnorm=True), 166 | Conv3dReLU(ch, ch, k, k // 2, use_batchnorm=True) 167 | ) 168 | for ch in self.backbone.out_channels[1:] 169 | ] 170 | self.conv3ds = torch.nn.ModuleList(conv3ds) 171 | self.segmentation_head = nn.Conv3d(256, 64 * self.out_ch, 1, padding=0) 172 | 173 | def _to2d(self, conv3d_block: torch.nn.Module, feature: torch.Tensor, b) -> torch.Tensor: 174 | total_batch, ch, H, W = feature.shape # b * d, ch, H, W 175 | feat_3d = feature.reshape(b, 
total_batch // b, ch, H, W).transpose(1, 2) 176 | feat_3d = conv3d_block(feat_3d) # b, ch, d, H, W 177 | return feat_3d 178 | 179 | def forward(self, x: torch.Tensor) -> torch.Tensor: 180 | b, _, d, h, w = x.shape 181 | 182 | if self.cfg.model.in_channels > 1: 183 | x = self.stem(x) # b, d, c, h, w 184 | 185 | x = x.reshape(b * d, self.cfg.model.in_channels, h, w) # b * d, c, h, w 186 | features = [x] 187 | pooled_cnt = 0 188 | 189 | for i, (name, module) in enumerate(self.backbone.items()): 190 | x = module(x) 191 | 192 | if name in self.backbone.return_layers: 193 | total_batch, ch, h, w = x.shape 194 | x = x.reshape(b, total_batch // b, ch, h, w).transpose(1, 2) # b, ch, d, h, w 195 | 196 | if pooled_cnt >= 4: 197 | k = 1 198 | elif self.cfg.model.img_size // w > self.cfg.model.img_depth * 2 // (total_batch // b): 199 | k = 4 200 | pooled_cnt += 2 201 | else: 202 | k = 2 203 | pooled_cnt += 1 204 | 205 | x = F.avg_pool3d(x, kernel_size=(k, 1, 1), stride=(k, 1, 1), padding=0) # b, ch, d // 2, h, w 206 | x = x.transpose(1, 2).reshape(total_batch // k, ch, h, w) # b * d // 2, ch, h, w 207 | features.append(x) 208 | 209 | 210 | features[1:] = [self._to2d(conv3d, feature, b) for conv3d, feature in zip(self.conv3ds, features[1:])] 211 | decoder_output = self.decoder(*features) 212 | masks = self.segmentation_head(decoder_output) 213 | masks = channel_to_spatial_3d(masks, upscale_factor=4) 214 | return masks 215 | 216 | def set_grad_checkpointing(self, enable: bool = True): 217 | self.backbone.encoder.model.set_grad_checkpointing(enable) 218 | -------------------------------------------------------------------------------- /src/metric.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | import numpy as np 3 | import pandas as pd 4 | from scipy.spatial import KDTree 5 | from scipy.ndimage import maximum_filter 6 | 7 | 8 | experiments = ["TS_5_4", "TS_69_2", "TS_6_4", "TS_6_6", "TS_73_6", "TS_86_3", "TS_99_9"] 9 | 10 | 11 | particle_radius = { 12 | 'apo-ferritin': 60, 13 | 'beta-amylase': 65, 14 | 'beta-galactosidase': 90, 15 | 'ribosome': 150, 16 | 'thyroglobulin': 130, 17 | 'virus-like-particle': 135, 18 | } 19 | 20 | 21 | particle_to_weights = { 22 | 'apo-ferritin': 1, 23 | 'beta-amylase': 0, 24 | 'beta-galactosidase': 2, 25 | 'ribosome': 1, 26 | 'thyroglobulin': 2, 27 | 'virus-like-particle': 1, 28 | } 29 | 30 | particle_types = [ 31 | "apo-ferritin", 32 | "beta-galactosidase", 33 | "ribosome", 34 | "thyroglobulin", 35 | "virus-like-particle", 36 | "beta-amylase", 37 | ] 38 | 39 | 40 | def get_score(reference_points, reference_radius, candidate_points): 41 | beta = 4 42 | 43 | if len(reference_points) == 0: 44 | reference_points = np.array([]) 45 | reference_radius = 1 46 | 47 | if len(candidate_points) == 0: 48 | candidate_points = np.array([]) 49 | 50 | tp, fp, fn = compute_metrics(reference_points, reference_radius, candidate_points) 51 | precision = tp / (tp + fp) if tp + fp > 0 else 0 52 | recall = tp / (tp + fn) if tp + fn > 0 else 0 53 | fbeta = (1 + beta ** 2) * (precision * recall) / (beta ** 2 * precision + recall) if ( 54 | precision + recall) > 0 else 0.0 55 | return fbeta, tp, fp, fn 56 | 57 | 58 | class ParticipantVisibleError(Exception): 59 | pass 60 | 61 | 62 | def compute_metrics(reference_points, reference_radius, candidate_points): 63 | num_reference_particles = len(reference_points) 64 | num_candidate_particles = len(candidate_points) 65 | 66 | if len(reference_points) == 0: 67 | return 0, 
num_candidate_particles, 0 68 | 69 | if len(candidate_points) == 0: 70 | return 0, 0, num_reference_particles 71 | 72 | ref_tree = KDTree(reference_points) 73 | candidate_tree = KDTree(candidate_points) 74 | raw_matches = candidate_tree.query_ball_tree(ref_tree, r=reference_radius) 75 | matches_within_threshold = [] 76 | for match in raw_matches: 77 | matches_within_threshold.extend(match) 78 | # Prevent submitting multiple matches per particle. 79 | # This won't be be strictly correct in the (extremely rare) case where true particles 80 | # are very close to each other. 81 | matches_within_threshold = set(matches_within_threshold) 82 | tp = int(len(matches_within_threshold)) 83 | fp = int(num_candidate_particles - tp) 84 | fn = int(num_reference_particles - tp) 85 | return tp, fp, fn 86 | 87 | 88 | def score( 89 | solution: pd.DataFrame, 90 | submission: pd.DataFrame, 91 | distance_multiplier: float, 92 | beta: int, 93 | use_weight: bool = True) -> float: 94 | ''' 95 | F_beta 96 | - a true positive occurs when 97 | - (a) the predicted location is within a threshold of the particle radius, and 98 | - (b) the correct `particle_type` is specified 99 | - raw results (TP, FP, FN) are aggregated across all experiments for each particle type 100 | - f_beta is calculated for each particle type 101 | - individual f_beta scores are weighted by particle type for final score 102 | ''' 103 | _particle_radius = {k: v * distance_multiplier for k, v in particle_radius.items()} 104 | 105 | # Filter submission to only contain experiments found in the solution split 106 | split_experiments = set(solution['experiment'].unique()) 107 | submission = submission.loc[submission['experiment'].isin(split_experiments)] 108 | 109 | # Only allow known particle types 110 | if not set(submission['particle_type'].unique()).issubset(set(particle_to_weights.keys())): 111 | raise ParticipantVisibleError('Unrecognized `particle_type`.') 112 | 113 | assert solution.duplicated(subset=['experiment', 'x', 'y', 'z']).sum() == 0 114 | assert _particle_radius.keys() == particle_to_weights.keys() 115 | 116 | results = {} 117 | for particle_type in solution['particle_type'].unique(): 118 | results[particle_type] = { 119 | 'total_tp': 0, 120 | 'total_fp': 0, 121 | 'total_fn': 0, 122 | } 123 | 124 | for experiment in split_experiments: 125 | for particle_type in solution['particle_type'].unique(): 126 | reference_radius = _particle_radius[particle_type] 127 | select = (solution['experiment'] == experiment) & (solution['particle_type'] == particle_type) 128 | reference_points = solution.loc[select, ['x', 'y', 'z']].values 129 | 130 | select = (submission['experiment'] == experiment) & (submission['particle_type'] == particle_type) 131 | candidate_points = submission.loc[select, ['x', 'y', 'z']].values 132 | 133 | if len(reference_points) == 0: 134 | reference_points = np.array([]) 135 | reference_radius = 1 136 | 137 | if len(candidate_points) == 0: 138 | candidate_points = np.array([]) 139 | 140 | tp, fp, fn = compute_metrics(reference_points, reference_radius, candidate_points) 141 | 142 | results[particle_type]['total_tp'] += tp 143 | results[particle_type]['total_fp'] += fp 144 | results[particle_type]['total_fn'] += fn 145 | # print(results) 146 | aggregate_fbeta = 0.0 147 | for particle_type, totals in results.items(): 148 | tp = totals['total_tp'] 149 | fp = totals['total_fp'] 150 | fn = totals['total_fn'] 151 | 152 | precision = tp / (tp + fp) if tp + fp > 0 else 0 153 | recall = tp / (tp + fn) if tp + fn > 0 else 0 154 | fbeta 
= (1 + beta**2) * (precision * recall) / (beta**2 * precision + recall) if (precision + recall) > 0 else 0.0 155 | 156 | if use_weight: 157 | aggregate_fbeta += fbeta * particle_to_weights.get(particle_type, 1.0) 158 | else: 159 | aggregate_fbeta += fbeta 160 | 161 | if use_weight: 162 | aggregate_fbeta = aggregate_fbeta / sum(particle_to_weights.values()) 163 | else: 164 | aggregate_fbeta = aggregate_fbeta / len(results) 165 | return aggregate_fbeta 166 | 167 | 168 | def find_local_maxima(arrs, threshold, a): 169 | all_coordinates = [] 170 | all_scores = [] 171 | 172 | for i in range(arrs.shape[0]): 173 | # use a maximum filter to obtain the maximum value in each neighborhood 174 | arr = arrs[i] 175 | max_filtered = maximum_filter(arr, size=a, mode='constant') 176 | 177 | # find locations where the input array equals the max-filter output (local maxima) 178 | local_maxima = (arr == max_filtered) & (arr >= threshold) 179 | 180 | # get the indices of the local maxima 181 | coordinates = np.argwhere(local_maxima) 182 | all_coordinates.append(coordinates) 183 | scores = arr[local_maxima] 184 | all_scores.append(scores) 185 | 186 | return all_coordinates, all_scores 187 | 188 | 189 | def get_experiment_score(preds, fold_id): 190 | threshold = 0.1 # detection threshold 191 | a = 3 # maximum filter size in pixels 192 | coordinates, scores = find_local_maxima(preds, threshold, a) 193 | best_ss = [] 194 | 195 | csv_path = Path(__file__).parents[1].joinpath("output", "solution.csv") 196 | solution = pd.read_csv(csv_path, index_col=0) 197 | experiment = experiments[fold_id] 198 | solution = solution[solution["experiment"] == experiment] 199 | solution.head() 200 | 201 | for i, coordinate in enumerate(coordinates): 202 | particle_type = particle_types[i] 203 | score_i = scores[i] 204 | 205 | best_s = 0 206 | best_th = 0 207 | best_stat = None 208 | keep = (solution["particle_type"] == particle_type) & (solution["experiment"] == experiment) 209 | reference_points = solution[keep][["x", "y", "z"]].values 210 | reference_radius = particle_radius[particle_type] * 0.5 211 | 212 | for th in np.linspace(0.05, 0.95, 19): 213 | candidate_points = [] 214 | 215 | for score_j, (z, y, x) in zip(score_i, coordinate): 216 | if score_j < th: 217 | continue 218 | 219 | point = ((x + 0.5 - 1) * 10.012444537618887, (y + 0.5 - 1) * 10.012444196428572, (z + 0.5 - 1) * 10.012444196428572) 220 | candidate_points.append(point) 221 | 222 | s, tp, fp, fn = get_score(reference_points, reference_radius, candidate_points) 223 | stat = tp, fp, fn 224 | 225 | if s > best_s: 226 | best_s = s 227 | best_th = th 228 | best_stat = stat 229 | 230 | best_ss.append(best_s) 231 | 232 | return best_ss 233 | --------------------------------------------------------------------------------
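As a quick sanity check of the metric defined in `src/metric.py`, `score()` can be called directly on a toy solution/submission pair with the same settings used by `get_experiment_score` (radius scaled by 0.5, F-beta with beta=4). The coordinates below are made up purely for illustration.

```python
# toy example for src/metric.py -- coordinates are made up for illustration only
import pandas as pd
from src.metric import score

solution = pd.DataFrame(
    [
        ["TS_5_4", "ribosome", 1000.0, 1000.0, 500.0],
        ["TS_5_4", "apo-ferritin", 2000.0, 1500.0, 800.0],
    ],
    columns=["experiment", "particle_type", "x", "y", "z"],
)

submission = pd.DataFrame(
    [
        # within 0.5 * 150 Angstrom of the ground-truth ribosome -> true positive
        ["TS_5_4", "ribosome", 1020.0, 990.0, 510.0],
        # far from the ground-truth apo-ferritin -> false positive, and the missed pick counts as a false negative
        ["TS_5_4", "apo-ferritin", 5000.0, 5000.0, 1500.0],
    ],
    columns=["experiment", "particle_type", "x", "y", "z"],
)

# distance_multiplier=0.5 and beta=4 follow the competition metric
print(score(solution, submission, distance_multiplier=0.5, beta=4))
# -> roughly 0.143: only the ribosome (weight 1) scores, normalized by the total weight of 7
```

With the default `particle_to_weights`, beta-amylase contributes nothing to the aggregate score, while beta-galactosidase and thyroglobulin count double.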