├── pyproject.toml ├── input ├── species.npy ├── individual_id.npy └── README.md ├── Dockerfile ├── requirements.txt ├── config ├── debug.yaml ├── efficientnet_b5.yaml ├── efficientnet_b6.yaml ├── efficientnet_b7.yaml ├── efficientnet_v2l.yaml ├── efficientnet_v2m.yaml ├── config.py └── default.yaml ├── src ├── utils.py ├── metric_learning.py ├── tune.py ├── ensemble.py ├── dataset.py └── train.py ├── README.md └── .gitignore /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 120 -------------------------------------------------------------------------------- /input/species.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/knshnb/kaggle-happywhale-1st-place/HEAD/input/species.npy -------------------------------------------------------------------------------- /input/individual_id.npy: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/knshnb/kaggle-happywhale-1st-place/HEAD/input/individual_id.npy -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvcr.io/nvidia/pytorch:22.02-py3 2 | 3 | COPY requirements.txt . 4 | RUN pip install -r requirements.txt 5 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | albumentations==1.1.0 2 | opencv-python-headless==4.5.5.64 3 | optuna==3.0.4 4 | pandas==1.3.5 5 | pytorch-lightning==1.5.10 6 | timm==0.5.4 7 | wandb==0.12.16 8 | -------------------------------------------------------------------------------- /config/debug.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 128 2 | image_size: 3 | - 128 4 | - 128 5 | max_epochs: 5 6 | model_name: resnet18d 7 | out_indices: 8 | - 3 9 | - 4 10 | n_splits: 5 11 | n_data: 1000 12 | pseudo_label: 13 | pseudo_conf_threshold: 0.0 14 | -------------------------------------------------------------------------------- /config/efficientnet_b5.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 8 2 | image_size: 3 | - 1024 4 | - 1024 5 | max_epochs: 30 6 | model_name: tf_efficientnet_b5_ns 7 | out_indices: 8 | - 3 9 | - 4 10 | n_splits: -1 11 | pseudo_label: pseudo_labels/round2.csv 12 | pseudo_conf_threshold: 0.6 13 | -------------------------------------------------------------------------------- /config/efficientnet_b6.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 6 2 | image_size: 3 | - 1024 4 | - 1024 5 | max_epochs: 30 6 | model_name: tf_efficientnet_b6_ns 7 | out_indices: 8 | - 3 9 | - 4 10 | n_splits: -1 11 | pseudo_label: pseudo_labels/round2.csv 12 | pseudo_conf_threshold: 0.6 13 | -------------------------------------------------------------------------------- /config/efficientnet_b7.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 4 2 | image_size: 3 | - 1024 4 | - 1024 5 | max_epochs: 30 6 | model_name: tf_efficientnet_b7_ns 7 | out_indices: 8 | - 3 9 | - 4 10 | n_splits: -1 11 | pseudo_label: pseudo_labels/round2.csv 12 | pseudo_conf_threshold: 0.6 13 | -------------------------------------------------------------------------------- /config/efficientnet_v2l.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 4 2 | image_size: 3 | - 1024 4 | - 1024 5 | max_epochs: 30 6 | model_name: tf_efficientnet_b7_ns 7 | out_indices: 8 | - 3 9 | - 4 10 | n_splits: -1 11 | pseudo_label: pseudo_labels/round2.csv 12 | pseudo_conf_threshold: 0.6 13 | -------------------------------------------------------------------------------- /config/efficientnet_v2m.yaml: -------------------------------------------------------------------------------- 1 | batch_size: 8 2 | image_size: 3 | - 1024 4 | - 1024 5 | max_epochs: 30 6 | model_name: tf_efficientnet_b7_ns 7 | out_indices: 8 | - 3 9 | - 4 10 | n_splits: -1 11 | pseudo_label: pseudo_labels/round2.csv 12 | pseudo_conf_threshold: 0.6 13 | -------------------------------------------------------------------------------- /config/config.py: -------------------------------------------------------------------------------- 1 | from typing import Optional 2 | 3 | import yaml 4 | 5 | 6 | class Config(dict): 7 | def __getattr__(self, key): 8 | try: 9 | val = self[key] 10 | except KeyError: 11 | return super().__getattr__(key) 12 | if isinstance(val, dict): 13 | return Config(val) 14 | return val 15 | 16 | 17 | def load_config(path: str, default_path: Optional[str]) -> Config: 18 | with open(path) as f: 19 | cfg = Config(yaml.full_load(f)) 20 | if default_path is not None: 21 | # set keys not included in `path` by default 22 | with open(default_path) as f: 23 | default_cfg = Config(yaml.full_load(f)) 24 | for key, val in default_cfg.items(): 25 | if key not in cfg: 26 | print(f"used default config {key}: {val}") 27 | cfg[key] = val 28 | return cfg 29 | -------------------------------------------------------------------------------- /input/README.md: -------------------------------------------------------------------------------- 1 | # Dataset 2 | ## Containing files 3 | 4 | ### fullbody_train_charm.csv/fullbody_test_charm.csv 5 | Made from [Jan Bre's notebook](https://www.kaggle.com/code/jpbremer/backfin-detection-with-yolov5) with DIM=1024, EPOCH=50. 6 | We used fullbody_annotations.csv from [Jan Bre's dataset](https://www.kaggle.com/datasets/jpbremer/fullbodywhaleannotations) as training data. 7 | 8 | ### species.npy/individual_id.npy 9 | Arrays of label encoders used in charmq's pipeline. 10 | 11 | ### pseudo_labels/ 12 | List of pseudo labels we created in each round. 13 | 14 | 15 | ## Files you need to download from Kaggle 16 | ### train.csv/sample_submission.csv/train_images/test_images 17 | [Competition dataset](https://www.kaggle.com/competitions/happy-whale-and-dolphin/data). 18 | 19 | ### fullbody_train.csv/fullbody_test.csv 20 | [Jan Bre's dataset](https://www.kaggle.com/datasets/jpbremer/fullbodywhaleannotations). 21 | 22 | ### train_backfin.csv/test_backfin.csv 23 | [Jan Bre's notebook](https://www.kaggle.com/code/jpbremer/backfin-detection-with-yolov5) (copy of train.csv/test.csv). 24 | 25 | ### train2.csv/test2.csv 26 | [phalanx's dataset](https://www.kaggle.com/datasets/phalanx/whale2-cropped-dataset). Bboxes of Detic. 27 | 28 | ### yolov5_train.csv/yolov5_test.csv 29 | [Awsaf's notebook](https://www.kaggle.com/code/awsaf49/happywhale-cropped-dataset-yolov5) (copy of train.csv/test.csv). 30 | -------------------------------------------------------------------------------- /config/default.yaml: -------------------------------------------------------------------------------- 1 | lr_backbone: 1.6e-3 2 | lr_head: 1.6e-2 3 | lr_decay_scale: 1.0e-2 4 | batch_size: 8 5 | image_size: 6 | - 768 7 | - 768 8 | max_epochs: 30 9 | model_name: 10 | out_indices: 11 | - 3 12 | - 4 13 | n_splits: 5 14 | num_classes: 15587 15 | num_species_classes: 26 16 | pretrained: true 17 | warmup_steps_ratio: 0.2 18 | val_bbox: fullbody 19 | test_bboxes: 20 | - fullbody 21 | - fullbody_charm 22 | bboxes: 23 | fullbody_charm: 0.15 24 | fullbody: 0.60 25 | backfin: 0.15 26 | detic: 0.05 27 | none: 0.05 28 | bbox_conf_threshold: 0.01 29 | n_data: -1 30 | global_pool: 31 | arch: GeM 32 | p: 3 33 | train: false 34 | normalization: batchnorm 35 | optimizer: AdamW 36 | loss_fn: CrossEntropy 37 | pseudo_label: 38 | pseudo_conf_threshold: 0.0 39 | 40 | loss_id_ratio: 0.437338 41 | margin_coef_id: 0.27126 42 | margin_coef_species: 0.226253 43 | margin_power_id: -0.364399 44 | margin_power_species: -0.720133 45 | s_id: 20.9588 46 | s_species: 33.1383 47 | 48 | margin_cons_id: 0.05 49 | margin_cons_species: 0.05 50 | n_center_id: 2 51 | n_center_species: 2 52 | 53 | aug: 54 | rotate: 15 55 | translate: 0.25 56 | shear: 3 57 | p_affine: 0.5 58 | crop_scale: 0.9 59 | crop_l: 0.75 60 | crop_r: 1.3333333333333333 61 | p_gray: 0.1 62 | p_blur: 0.05 63 | p_noise: 0.05 64 | p_downscale: 0.0 65 | p_shuffle: 0.3 66 | p_posterize: 0.2 67 | p_bright_contrast: 0.5 68 | p_cutout: 0.05 69 | p_snow: 0.1 70 | p_rain: 0.05 71 | -------------------------------------------------------------------------------- /src/utils.py: -------------------------------------------------------------------------------- 1 | import math 2 | from typing import Optional 3 | 4 | import torch 5 | 6 | 7 | class WarmupCosineLambda: 8 | def __init__(self, warmup_steps: int, cycle_steps: int, decay_scale: float, exponential_warmup: bool = False): 9 | self.warmup_steps = warmup_steps 10 | self.cycle_steps = cycle_steps 11 | self.decay_scale = decay_scale 12 | self.exponential_warmup = exponential_warmup 13 | 14 | def __call__(self, epoch: int): 15 | if epoch < self.warmup_steps: 16 | if self.exponential_warmup: 17 | return self.decay_scale * pow(self.decay_scale, -epoch / self.warmup_steps) 18 | ratio = epoch / self.warmup_steps 19 | else: 20 | ratio = (1 + math.cos(math.pi * (epoch - self.warmup_steps) / self.cycle_steps)) / 2 21 | return self.decay_scale + (1 - self.decay_scale) * ratio 22 | 23 | 24 | def topk_average_precision(output: torch.Tensor, y: torch.Tensor, k: int): 25 | score_array = torch.tensor([1.0 / i for i in range(1, k + 1)], device=output.device) 26 | topk = output.topk(k)[1] 27 | match_mat = topk == y[:, None].expand(topk.shape) 28 | return (match_mat * score_array).sum(dim=1) 29 | 30 | 31 | def calc_map5(output: torch.Tensor, y: torch.Tensor, threshold: Optional[float]): 32 | if threshold is not None: 33 | output = torch.cat([output, torch.full((output.shape[0], 1), threshold, device=output.device)], dim=1) 34 | return topk_average_precision(output, y, 5).mean().detach() 35 | 36 | 37 | def map_dict(output: torch.Tensor, y: torch.Tensor, prefix: str): 38 | d = {f"{prefix}/acc": topk_average_precision(output, y, 1).mean().detach()} 39 | for threshold in [None, 0.3, 0.4, 0.5, 0.6, 0.7]: 40 | d[f"{prefix}/map{threshold}"] = calc_map5(output, y, threshold) 41 | return d 42 | -------------------------------------------------------------------------------- /src/metric_learning.py: -------------------------------------------------------------------------------- 1 | import math 2 | 3 | import numpy as np 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | 9 | class GeM(nn.Module): 10 | def __init__(self, p=3, eps=1e-6, requires_grad=False): 11 | super().__init__() 12 | self.p = nn.Parameter(torch.ones(1) * p, requires_grad=requires_grad) 13 | self.eps = eps 14 | 15 | def forward(self, x: torch.Tensor): 16 | return x.clamp(min=self.eps).pow(self.p).mean((-2, -1)).pow(1.0 / self.p) 17 | 18 | 19 | # Copied and modified from 20 | # https://github.com/ChristofHenkel/kaggle-landmark-2021-1st-place/blob/034a7d8665bb4696981698348c9370f2d4e61e35/models/ch_mdl_dolg_efficientnet.py 21 | class ArcMarginProductSubcenter(nn.Module): 22 | def __init__(self, in_features: int, out_features: int, k: int = 3): 23 | super().__init__() 24 | self.weight = nn.Parameter(torch.FloatTensor(out_features * k, in_features)) 25 | self.reset_parameters() 26 | self.k = k 27 | self.out_features = out_features 28 | 29 | def reset_parameters(self): 30 | stdv = 1.0 / math.sqrt(self.weight.size(1)) 31 | self.weight.data.uniform_(-stdv, stdv) 32 | 33 | def forward(self, features: torch.Tensor) -> torch.Tensor: 34 | cosine_all = F.linear(F.normalize(features), F.normalize(self.weight)) 35 | cosine_all = cosine_all.view(-1, self.out_features, self.k) 36 | cosine, _ = torch.max(cosine_all, dim=2) 37 | return cosine 38 | 39 | 40 | class ArcFaceLossAdaptiveMargin(nn.modules.Module): 41 | def __init__(self, margins: np.ndarray, n_classes: int, s: float = 30.0): 42 | super().__init__() 43 | self.s = s 44 | self.margins = margins 45 | self.out_dim = n_classes 46 | 47 | def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: 48 | ms = self.margins[labels.cpu().numpy()] 49 | cos_m = torch.from_numpy(np.cos(ms)).float().cuda() 50 | sin_m = torch.from_numpy(np.sin(ms)).float().cuda() 51 | th = torch.from_numpy(np.cos(math.pi - ms)).float().cuda() 52 | mm = torch.from_numpy(np.sin(math.pi - ms) * ms).float().cuda() 53 | labels = F.one_hot(labels, self.out_dim).float() 54 | logits = logits.float() 55 | cosine = logits 56 | sine = torch.sqrt(1.0 - torch.pow(cosine, 2)) 57 | phi = cosine * cos_m.view(-1, 1) - sine * sin_m.view(-1, 1) 58 | phi = torch.where(cosine > th.view(-1, 1), phi, cosine - mm.view(-1, 1)) 59 | return ((labels * phi) + ((1.0 - labels) * cosine)) * self.s 60 | -------------------------------------------------------------------------------- /src/tune.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import copy 3 | 4 | import optuna 5 | 6 | from config.config import load_config 7 | from src.dataset import load_df 8 | from src.train import train 9 | 10 | 11 | def parse(): 12 | parser = argparse.ArgumentParser(description="Hyperparameter Tuning for HappyWhale") 13 | parser.add_argument("--out_base_dir", default="result") 14 | parser.add_argument("--in_base_dir", default="input") 15 | parser.add_argument("--exp_name", default="tune_tmp") 16 | parser.add_argument("--load_snapshot", action="store_true") 17 | parser.add_argument("--save_checkpoint", action="store_true") 18 | parser.add_argument("--wandb_logger", action="store_true") 19 | parser.add_argument("--config_path", default="config/debug.yaml") 20 | parser.add_argument("--rdb_url", default="sqlite:///tmp.db") 21 | return parser.parse_args() 22 | 23 | 24 | def main(): 25 | base_args = parse() 26 | base_cfg = load_config(base_args.config_path, "config/default.yaml") 27 | df = load_df(base_args.in_base_dir, base_cfg, "train.csv", True) 28 | 29 | def objective(trial: optuna.trial.Trial) -> float: 30 | args = copy.deepcopy(base_args) 31 | args.exp_name = f"{args.exp_name}/{trial.number}" 32 | 33 | cfg = copy.deepcopy(base_cfg) 34 | # dynamic arcface parameters 35 | cfg["s_id"] = trial.suggest_float("s_id", 10.0, 80.0) 36 | cfg["s_species"] = trial.suggest_float("s_species", 10.0, 80.0) 37 | cfg["loss_id_ratio"] = trial.suggest_float("loss_id_ratio", 0.2, 1.0) 38 | cfg["margin_power_id"] = trial.suggest_float("margin_power_id", -0.8, -0.05) 39 | cfg["margin_power_species"] = trial.suggest_float("margin_power_species", -0.8, -0.05) 40 | cfg["margin_coef_id"] = trial.suggest_float("margin_coef_id", 0.2, 1.0) 41 | cfg["margin_coef_species"] = trial.suggest_float("margin_coef_species", 0.2, 1.0) 42 | 43 | score = train(df, args, cfg, 0, optuna_trial=trial) 44 | assert score is not None 45 | return score 46 | 47 | storage = optuna.storages.RDBStorage( 48 | url=base_args.rdb_url, 49 | heartbeat_interval=60, 50 | grace_period=120, 51 | failed_trial_callback=optuna.storages.RetryFailedTrialCallback(), 52 | ) 53 | study = optuna.create_study( 54 | direction="maximize", 55 | storage=storage, 56 | study_name=base_args.exp_name, 57 | load_if_exists=True, 58 | pruner=optuna.pruners.NopPruner(), 59 | ) 60 | study.optimize(objective, callbacks=[optuna.study.MaxTrialsCallback(500)]) 61 | 62 | 63 | if __name__ == "__main__": 64 | main() 65 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # 1st Place Solution of Kaggle Happywhale Competition 2 | This is the knshnb's part of the Preferred Dolphin's solution for [Happywhale - Whale and Dolphin Identification](https://www.kaggle.com/competitions/happy-whale-and-dolphin). 3 | 4 | ## Dataset 5 | Please prepare dataset according to [input/README.md](input/README.md) and place under `input/`. 6 | ``` 7 | $ ls -F input 8 | fullbody_test_charm.csv pseudo_labels/ test_backfin.csv* train_images/ 9 | fullbody_test.csv README.md test_images/ yolov5_test.csv 10 | fullbody_train_charm.csv sample_submission.csv* train2.csv yolov5_train.csv 11 | fullbody_train.csv species.npy* train_backfin.csv 12 | individual_id.npy* test2.csv train.csv 13 | ``` 14 | 15 | ## Reproducing the winning score 16 | Before the final training round, we repeated 2 rounds of Step 1-2 for pseudo labeling. 17 | By default, `input/pseudo_labels/round2.csv` (the pseudo labels we created) is specified in the config file so that you can skip the first two rounds. 18 | You can train from scratch by setting `None` in `pseudo_label` field in config files. 19 | 20 | ### Step 1: Training and inference 21 | By `src/train.py`, we 22 | 1. train model by whole train data. 23 | 2. inference test data and save results under `result/{exp_name}/-1/`. 24 | 25 | Several examples of config files are located in `config/`. 26 | 27 | Example: Training and inference efficientnet_b6 and efficientnet_b7 28 | ``` 29 | python -m src.train --config_path config/efficientnet_b6.yaml --exp_name b6 30 | python -m src.train --config_path config/efficientnet_b7.yaml --exp_name b7 31 | ``` 32 | 33 | ### Step 2: Postprocess and ensemble 34 | By `src/ensemble.py`, we 35 | 1. calculate mean of the predictions by knn and logit for each model. 36 | 2. ensemble predictions of the models specified by `--model_dirs`. 37 | 3. save prediction as `submission/{out_prefix}-{new_ratio}-{threshold}.csv`. 38 | 4. save pseudo label as `submission/pseudo_label_{out_prefix}.csv`. 39 | 40 | Predictions generated by charmq's repository are saved in the same format, so you can ensemble them by just specifying paths to model directories. 41 | 42 | Example: Ensemble b6 and b7 43 | ``` 44 | python -m src.ensemble --model_dirs result/b6/-1 result/b7/-1 --out_prefix b6-b7 45 | ``` 46 | 47 | In our post submission, single model (efficientnet_b7) achieved a score that could rank 3rd place in the final leaderboard. 48 | We also confirmed that ensemble of only two models (efficientnet_b6 and efficientnet_b7) could win 1st place. 49 | Ensembling more backbones and charmq's modesl can achieve even better results. 50 | 51 | ## Citation 52 | ``` 53 | @article{patton2023deep, 54 | title={A deep learning approach to photo--identification demonstrates high performance on two dozen cetacean species}, 55 | author={Patton, Philip T and Cheeseman, Ted and Abe, Kenshin and Yamaguchi, Taiki and Reade, Walter and Southerland, Ken and Howard, Addison and Oleson, Erin M and Allen, Jason B and Ashe, Erin and others}, 56 | journal={Methods in ecology and evolution}, 57 | volume={14}, 58 | number={10}, 59 | pages={2611--2625}, 60 | year={2023}, 61 | publisher={Wiley Online Library} 62 | } 63 | ``` 64 | 65 | ## Links 66 | - For an overview of our key ideas and detailed explanation, please also refer to [1st Place Solution](https://www.kaggle.com/competitions/happy-whale-and-dolphin/discussion/320192) in Kaggle discussion. 67 | - My teammate [charmq's repository](https://github.com/tyamaguchi17/kaggle-happywhale-1st-place-solution-charmq). 68 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | result/** 2 | submission/** 3 | 4 | # Byte-compiled / optimized / DLL files 5 | __pycache__/ 6 | *.py[cod] 7 | *$py.class 8 | 9 | # C extensions 10 | *.so 11 | 12 | # Distribution / packaging 13 | .Python 14 | build/ 15 | develop-eggs/ 16 | dist/ 17 | downloads/ 18 | eggs/ 19 | .eggs/ 20 | lib/ 21 | lib64/ 22 | parts/ 23 | sdist/ 24 | var/ 25 | wheels/ 26 | share/python-wheels/ 27 | *.egg-info/ 28 | .installed.cfg 29 | *.egg 30 | MANIFEST 31 | 32 | # PyInstaller 33 | # Usually these files are written by a python script from a template 34 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 35 | *.manifest 36 | *.spec 37 | 38 | # Installer logs 39 | pip-log.txt 40 | pip-delete-this-directory.txt 41 | 42 | # Unit test / coverage reports 43 | htmlcov/ 44 | .tox/ 45 | .nox/ 46 | .coverage 47 | .coverage.* 48 | .cache 49 | nosetests.xml 50 | coverage.xml 51 | *.cover 52 | *.py,cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | cover/ 56 | 57 | # Translations 58 | *.mo 59 | *.pot 60 | 61 | # Django stuff: 62 | *.log 63 | local_settings.py 64 | db.sqlite3 65 | db.sqlite3-journal 66 | 67 | # Flask stuff: 68 | instance/ 69 | .webassets-cache 70 | 71 | # Scrapy stuff: 72 | .scrapy 73 | 74 | # Sphinx documentation 75 | docs/_build/ 76 | 77 | # PyBuilder 78 | .pybuilder/ 79 | target/ 80 | 81 | # Jupyter Notebook 82 | .ipynb_checkpoints 83 | 84 | # IPython 85 | profile_default/ 86 | ipython_config.py 87 | 88 | # pyenv 89 | # For a library or package, you might want to ignore these files since the code is 90 | # intended to run in multiple environments; otherwise, check them in: 91 | # .python-version 92 | 93 | # pipenv 94 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 95 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 96 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 97 | # install all needed dependencies. 98 | #Pipfile.lock 99 | 100 | # poetry 101 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 102 | # This is especially recommended for binary packages to ensure reproducibility, and is more 103 | # commonly ignored for libraries. 104 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 105 | #poetry.lock 106 | 107 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 108 | __pypackages__/ 109 | 110 | # Celery stuff 111 | celerybeat-schedule 112 | celerybeat.pid 113 | 114 | # SageMath parsed files 115 | *.sage.py 116 | 117 | # Environments 118 | .env 119 | .venv 120 | env/ 121 | venv/ 122 | ENV/ 123 | env.bak/ 124 | venv.bak/ 125 | 126 | # Spyder project settings 127 | .spyderproject 128 | .spyproject 129 | 130 | # Rope project settings 131 | .ropeproject 132 | 133 | # mkdocs documentation 134 | /site 135 | 136 | # mypy 137 | .mypy_cache/ 138 | .dmypy.json 139 | dmypy.json 140 | 141 | # Pyre type checker 142 | .pyre/ 143 | 144 | # pytype static type analyzer 145 | .pytype/ 146 | 147 | # Cython debug symbols 148 | cython_debug/ 149 | 150 | # PyCharm 151 | # JetBrains specific template is maintainted in a separate JetBrains.gitignore that can 152 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 153 | # and can be added to the global gitignore or merged into this file. For a more nuclear 154 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 155 | #.idea/ -------------------------------------------------------------------------------- /src/ensemble.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | from typing import List 4 | 5 | import numpy as np 6 | import pandas as pd 7 | import scipy.sparse 8 | import torch 9 | from sklearn import preprocessing 10 | from sklearn.neighbors import NearestNeighbors 11 | 12 | from config.config import Config, load_config 13 | 14 | 15 | def parse() -> argparse.Namespace: 16 | parser = argparse.ArgumentParser(description="Make submission csv for HappyWhale") 17 | parser.add_argument("--in_base_dir", default="input") 18 | parser.add_argument("--model_dirs", nargs="+", required=True) 19 | parser.add_argument("--out_prefix", default="test") 20 | return parser.parse_args() 21 | 22 | 23 | def load_results(results_path: str): 24 | results = np.load(results_path) 25 | # reorder all values so that ret["original_index"] = [0, 1, 2, ..., n_data - 1] 26 | n_data = results["original_index"].max() + 1 27 | ord = np.full(n_data, -1, dtype=int) 28 | ord[results["original_index"]] = np.arange(len(results["original_index"])) 29 | ret = {key: results[key][ord] for key in results.files if key != "file_name"} 30 | assert np.array_equal(ret["original_index"], np.arange(n_data)) 31 | return ret 32 | 33 | 34 | def restore_all_pred(n_class: int, pred: np.ndarray, pred_idx: np.ndarray): 35 | n_data = pred.shape[0] 36 | all_pred = np.zeros((n_data, n_class)) 37 | for i in range(n_data): 38 | all_pred[i][pred_idx[i]] = pred[i] 39 | return all_pred 40 | 41 | 42 | def knn_all_pred(n_class: int, n_train: int, test_feat: np.ndarray, train_feat: np.ndarray, train_label: np.ndarray): 43 | neigh = NearestNeighbors(n_neighbors=500, metric="cosine") 44 | neigh.fit(train_feat) 45 | test_dist, test_cosine_idx = neigh.kneighbors(test_feat, return_distance=True) # [n_val, 1000], [n_val, 1000] 46 | test_cosine = 1 - test_dist 47 | test_cosine_idx %= n_train 48 | test_all_knn = np.zeros((len(test_feat), n_class)) 49 | for i, (cosines, idx) in enumerate(zip(test_cosine, test_cosine_idx)): 50 | pred_ids = train_label[idx] 51 | for cosine, pred_id in zip(cosines, pred_ids): 52 | test_all_knn[i][pred_id] = max(test_all_knn[i][pred_id], cosine) 53 | return test_all_knn 54 | 55 | 56 | def knn_both_feat(n_class, train_feat1, train_feat2, test_feat1, test_feat2, train_label): 57 | # knn with both feats for train 58 | train_feat12 = np.concatenate([train_feat1, train_feat2], axis=0) 59 | knn1_both_mat = knn_all_pred(n_class, len(train_label), test_feat1, train_feat12, train_label) 60 | knn2_both_mat = knn_all_pred(n_class, len(train_label), test_feat2, train_feat12, train_label) 61 | return (knn1_both_mat + knn2_both_mat) / 2 62 | 63 | 64 | def binary_search_threshold(n_class: int, mat: torch.Tensor, new_ratio: float) -> float: 65 | ok, ng = 0.0, 1.0 66 | for _ in range(30): 67 | mid = (ok + ng) / 2 68 | out_new = torch.cat([mat, torch.full((mat.shape[0], 1), mid, device=mat.device)], dim=1) 69 | if (out_new.argmax(1) == n_class).to(float).mean() <= new_ratio: 70 | ok = mid 71 | else: 72 | ng = mid 73 | return ok 74 | 75 | 76 | def make_submission( 77 | train_paths: List[str], 78 | test_paths: List[str], 79 | args: argparse.Namespace, 80 | cfg: Config, 81 | new_ratios: List[float], 82 | knn_ratio: float = 0.5, 83 | ): 84 | os.makedirs("submission", exist_ok=True) 85 | # ensemble 86 | csr_sum = None 87 | species_mats = [] 88 | for train_path, test_path in zip(train_paths, test_paths): 89 | print(train_path, test_path) 90 | train_results, test_results = load_results(train_path), load_results(test_path) 91 | knn_csr = scipy.sparse.csr_matrix( 92 | knn_both_feat( 93 | cfg.num_classes, 94 | train_results["embed_features1"], 95 | train_results["embed_features2"], 96 | test_results["embed_features1"], 97 | test_results["embed_features2"], 98 | train_results["label"], 99 | ) 100 | ) 101 | logit_mat_csr = scipy.sparse.csr_matrix( 102 | restore_all_pred(cfg.num_classes, test_results["pred_logit"], test_results["pred_idx"]) 103 | ) 104 | mat_csr = knn_csr * knn_ratio + logit_mat_csr * (1 - knn_ratio) 105 | csr_sum = mat_csr if csr_sum is None else csr_sum + mat_csr 106 | species_mats.append(test_results["pred_species"]) 107 | ensembled_mat = (csr_sum / len(train_paths)).todense() 108 | species_mat = np.stack(species_mats).mean(0) 109 | out = torch.tensor(ensembled_mat) 110 | 111 | # make submission from ensembled prediction 112 | for new_ratio in new_ratios: 113 | threshold = binary_search_threshold(cfg.num_classes, out, new_ratio) 114 | print(f"new_ratio: {new_ratio}, selected threshold: {threshold}") 115 | out_new = torch.cat([out, torch.full((out.shape[0], 1), threshold, device=out.device)], dim=1) 116 | top5 = out_new.topk(5)[1] 117 | 118 | # make csv 119 | label_encoder = preprocessing.LabelEncoder() 120 | label_encoder.classes_ = np.load(f"{args.in_base_dir}/individual_id.npy", allow_pickle=True) 121 | assert cfg.num_classes == len(label_encoder.classes_) 122 | 123 | def make_str(id_list): 124 | return " ".join( 125 | "new_individual" if x == cfg.num_classes else label_encoder.inverse_transform([x])[0] for x in id_list 126 | ) 127 | 128 | df = pd.read_csv(f"{args.in_base_dir}/sample_submission.csv") 129 | df["predictions"] = [make_str(id_list) for id_list in top5] 130 | df.to_csv(f"submission/{args.out_prefix}-{new_ratio}-{threshold}.csv", index=False, columns=["image", "predictions"]) 131 | 132 | # generate pseudo label 133 | df = pd.read_csv(f"{args.in_base_dir}/sample_submission.csv") 134 | top1_conf, top1 = out.max(1) 135 | df["individual_id"] = label_encoder.inverse_transform(top1) 136 | df["conf"] = top1_conf 137 | le_species = preprocessing.LabelEncoder() 138 | le_species.classes_ = np.load(f"{args.in_base_dir}/species.npy", allow_pickle=True) 139 | df["species"] = le_species.inverse_transform(species_mat.argmax(1)) 140 | df.to_csv( 141 | f"submission/pseudo_label_{args.out_prefix}.csv", index=False, columns=["image", "individual_id", "conf", "species"] 142 | ) 143 | 144 | 145 | if __name__ == "__main__": 146 | args = parse() 147 | cfg = load_config("config/default.yaml", "config/default.yaml") 148 | train_paths, test_paths = [], [] 149 | for path in args.model_dirs: 150 | for bbox_name in cfg.test_bboxes: 151 | train_paths.append(f"{path}/train_{bbox_name}_results.npz") 152 | test_paths.append(f"{path}/test_{bbox_name}_results.npz") 153 | make_submission(train_paths, test_paths, args, cfg, [0.165], knn_ratio=0.5) 154 | -------------------------------------------------------------------------------- /src/dataset.py: -------------------------------------------------------------------------------- 1 | import albumentations as A 2 | import cv2 3 | import numpy as np 4 | import pandas as pd 5 | from albumentations.pytorch import ToTensorV2 6 | from sklearn import preprocessing 7 | from torch.utils.data import Dataset 8 | 9 | from config.config import Config 10 | 11 | 12 | class WhaleDataset(Dataset): 13 | def __init__( 14 | self, 15 | df: pd.DataFrame, 16 | cfg: Config, 17 | image_dir: str, 18 | val_bbox_name: str, 19 | data_aug: bool, 20 | ): 21 | super().__init__() 22 | self.index = df.index 23 | self.x_paths = np.array(df.image) 24 | self.ids = np.array(df.individual_id, dtype=int) if hasattr(df, "individual_id") else np.full(len(df), -1) 25 | self.species = np.array(df.species, dtype=int) if hasattr(df, "species") else np.full(len(df), -1) 26 | self.cfg = cfg 27 | self.image_dir = image_dir 28 | self.df = df 29 | self.val_bbox_name = val_bbox_name 30 | self.data_aug = data_aug 31 | augments = [] 32 | if data_aug: 33 | aug = cfg.aug 34 | augments = [ 35 | A.Affine( 36 | rotate=(-aug.rotate, aug.rotate), 37 | translate_percent=(0.0, aug.translate), 38 | shear=(-aug.shear, aug.shear), 39 | p=aug.p_affine, 40 | ), 41 | A.RandomResizedCrop( 42 | self.cfg.image_size[0], 43 | self.cfg.image_size[1], 44 | scale=(aug.crop_scale, 1.0), 45 | ratio=(aug.crop_l, aug.crop_r), 46 | ), 47 | A.ToGray(p=aug.p_gray), 48 | A.GaussianBlur(blur_limit=(3, 7), p=aug.p_blur), 49 | A.GaussNoise(p=aug.p_noise), 50 | A.Downscale(scale_min=0.5, scale_max=0.5, p=aug.p_downscale), 51 | A.RandomGridShuffle(grid=(2, 2), p=aug.p_shuffle), 52 | A.Posterize(p=aug.p_posterize), 53 | A.RandomBrightnessContrast(p=aug.p_bright_contrast), 54 | A.Cutout(p=aug.p_cutout), 55 | A.RandomSnow(p=aug.p_snow), 56 | A.RandomRain(p=aug.p_rain), 57 | A.HorizontalFlip(p=0.5), 58 | ] 59 | augments.append(A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])) 60 | augments.append(ToTensorV2()) # HWC to CHW 61 | self.transform = A.Compose(augments) 62 | 63 | def __len__(self): 64 | return len(self.ids) 65 | 66 | def get_original_image(self, i: int): 67 | bgr = cv2.imread(f"{self.image_dir}/{self.x_paths[i]}") 68 | rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB) 69 | return rgb 70 | 71 | def __getitem__(self, i: int): 72 | image = self.get_original_image(i) 73 | # crop 74 | if self.data_aug: 75 | bbox_name = np.random.choice(list(self.cfg.bboxes.keys()), p=list(self.cfg.bboxes.values())) 76 | else: 77 | bbox_name = self.val_bbox_name 78 | bbox = None if bbox_name == "none" else self.df[bbox_name].iloc[i] 79 | if bbox is not None: 80 | xmin, ymin, xmax, ymax = bbox 81 | image = image[ymin:ymax, xmin:xmax] 82 | # resize 83 | image = cv2.resize(image, self.cfg.image_size, interpolation=cv2.INTER_CUBIC) 84 | # data augmentation 85 | augmented = self.transform(image=image)["image"] 86 | return { 87 | "original_index": self.index[i], 88 | "image": augmented, 89 | "label": self.ids[i], 90 | "label_species": self.species[i], 91 | } 92 | 93 | 94 | def load_bbox(cfg: Config, in_base_dir: str, bbox_name: str, is_train: bool) -> pd.Series: 95 | if bbox_name == "detic": 96 | filename = "train2.csv" if is_train else "test2.csv" 97 | tmp_df = pd.read_csv(f"{in_base_dir}/{filename}") 98 | low_conf = pd.Series([False for _ in range(len(tmp_df))]) 99 | bbox = tmp_df.box.map(lambda s: list(map(int, s.split())) if s == s else None) 100 | elif bbox_name == "fullbody": 101 | filename = "fullbody_train.csv" if is_train else "fullbody_test.csv" 102 | tmp_df = pd.read_csv(f"{in_base_dir}/{filename}") 103 | low_conf = tmp_df.conf.map(lambda s: float(s[1:-1]) if s == s else -1) < cfg.bbox_conf_threshold 104 | bbox = tmp_df.bbox.map(lambda s: list(map(int, s[2:-2].split()))) 105 | elif bbox_name == "fullbody_charm": 106 | filename = "fullbody_train_charm.csv" if is_train else "fullbody_test_charm.csv" 107 | tmp_df = pd.read_csv(f"{in_base_dir}/{filename}") 108 | low_conf = tmp_df.conf.map(lambda s: float(s[1:-1]) if s == s else -1) < cfg.bbox_conf_threshold 109 | bbox = tmp_df.bbox.map(lambda s: list(map(int, s[2:-2].split())) if s == s else None) 110 | elif bbox_name == "backfin": 111 | filename = "train_backfin.csv" if is_train else "test_backfin.csv" 112 | tmp_df = pd.read_csv(f"{in_base_dir}/{filename}") 113 | low_conf = tmp_df.conf.map(lambda s: float(s[1:-1]) if s == s else -1) < cfg.bbox_conf_threshold 114 | bbox = tmp_df.bbox.map(lambda s: list(map(int, s[2:-2].split())) if s == s else None) 115 | else: 116 | raise AssertionError() 117 | print(f"{bbox_name} low conf: {low_conf.sum()} / {len(tmp_df)}") 118 | bbox[low_conf] = None 119 | return bbox 120 | 121 | 122 | def load_df(in_base_dir: str, cfg: Config, filename: str, is_train: bool) -> pd.DataFrame: 123 | df = pd.read_csv(f"{in_base_dir}/{filename}") 124 | 125 | # bbox 126 | for bbox_name in ["detic", "fullbody", "fullbody_charm", "backfin"]: 127 | df[bbox_name] = load_bbox(cfg, in_base_dir, bbox_name, is_train) 128 | 129 | # label encoder 130 | if hasattr(df, "individual_id"): 131 | label_encoder = preprocessing.LabelEncoder() 132 | label_encoder.classes_ = np.load(f"{in_base_dir}/individual_id.npy", allow_pickle=True) 133 | df.individual_id = label_encoder.transform(df.individual_id) 134 | assert cfg.num_classes == len(label_encoder.classes_) 135 | if hasattr(df, "species"): 136 | df.species.replace( 137 | { 138 | "globis": "short_finned_pilot_whale", 139 | "pilot_whale": "short_finned_pilot_whale", 140 | "kiler_whale": "killer_whale", 141 | "bottlenose_dolpin": "bottlenose_dolphin", 142 | }, 143 | inplace=True, 144 | ) # https://www.kaggle.com/c/happy-whale-and-dolphin/discussion/305574 145 | label_encoder_species = preprocessing.LabelEncoder() 146 | label_encoder_species.classes_ = np.load(f"{in_base_dir}/species.npy", allow_pickle=True) 147 | df.species = label_encoder_species.transform(df.species) 148 | assert cfg.num_species_classes == len(label_encoder_species.classes_) 149 | return df 150 | -------------------------------------------------------------------------------- /src/train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import warnings 4 | from typing import Dict, List, Optional, Tuple 5 | 6 | import numpy as np 7 | import optuna 8 | import pandas as pd 9 | import timm 10 | import torch 11 | import wandb 12 | from optuna.integration import PyTorchLightningPruningCallback 13 | from pytorch_lightning import LightningDataModule, LightningModule, Trainer 14 | from pytorch_lightning import loggers as pl_loggers 15 | from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint 16 | from sklearn.model_selection import StratifiedKFold 17 | from torch.utils.data import ConcatDataset, DataLoader 18 | 19 | from config.config import Config, load_config 20 | from src.dataset import WhaleDataset, load_df 21 | from src.metric_learning import ArcFaceLossAdaptiveMargin, ArcMarginProductSubcenter, GeM 22 | from src.utils import WarmupCosineLambda, map_dict, topk_average_precision 23 | 24 | 25 | def parse(): 26 | parser = argparse.ArgumentParser(description="Training for HappyWhale") 27 | parser.add_argument("--out_base_dir", default="result") 28 | parser.add_argument("--in_base_dir", default="input") 29 | parser.add_argument("--exp_name", default="tmp") 30 | parser.add_argument("--load_snapshot", action="store_true") 31 | parser.add_argument("--save_checkpoint", action="store_true") 32 | parser.add_argument("--wandb_logger", action="store_true") 33 | parser.add_argument("--config_path", default="config/debug.yaml") 34 | return parser.parse_args() 35 | 36 | 37 | class WhaleDataModule(LightningDataModule): 38 | def __init__( 39 | self, 40 | df: pd.DataFrame, 41 | cfg: Config, 42 | image_dir: str, 43 | val_bbox_name: str, 44 | fold: int, 45 | additional_dataset: WhaleDataset = None, 46 | ): 47 | super().__init__() 48 | self.cfg = cfg 49 | self.image_dir = image_dir 50 | self.val_bbox_name = val_bbox_name 51 | self.additional_dataset = additional_dataset 52 | if cfg.n_data != -1: 53 | df = df.iloc[: cfg.n_data] 54 | self.all_df = df 55 | if fold == -1: 56 | self.train_df = df 57 | else: 58 | skf = StratifiedKFold(n_splits=cfg.n_splits, shuffle=True, random_state=0) 59 | train_idx, val_idx = list(skf.split(df, df.individual_id))[fold] 60 | self.train_df = df.iloc[train_idx].copy() 61 | self.val_df = df.iloc[val_idx].copy() 62 | # relabel ids not included in training data as "new individual" 63 | new_mask = ~self.val_df.individual_id.isin(self.train_df.individual_id) 64 | self.val_df.individual_id.mask(new_mask, cfg.num_classes, inplace=True) 65 | print(f"new: {(self.val_df.individual_id == cfg.num_classes).sum()} / {len(self.val_df)}") 66 | 67 | def get_dataset(self, df, data_aug): 68 | return WhaleDataset(df, self.cfg, self.image_dir, self.val_bbox_name, data_aug) 69 | 70 | def train_dataloader(self): 71 | dataset = self.get_dataset(self.train_df, True) 72 | if self.additional_dataset is not None: 73 | dataset = ConcatDataset([dataset, self.additional_dataset]) 74 | return DataLoader( 75 | dataset, 76 | batch_size=self.cfg.batch_size, 77 | shuffle=True, 78 | num_workers=2, 79 | pin_memory=True, 80 | drop_last=True, 81 | ) 82 | 83 | def val_dataloader(self): 84 | if self.cfg.n_splits == -1: 85 | return None 86 | return DataLoader( 87 | self.get_dataset(self.val_df, False), 88 | batch_size=self.cfg.batch_size, 89 | shuffle=False, 90 | num_workers=2, 91 | pin_memory=True, 92 | ) 93 | 94 | def all_dataloader(self): 95 | return DataLoader( 96 | self.get_dataset(self.all_df, False), 97 | batch_size=self.cfg.batch_size, 98 | shuffle=False, 99 | num_workers=2, 100 | pin_memory=True, 101 | ) 102 | 103 | 104 | class SphereClassifier(LightningModule): 105 | def __init__(self, cfg: dict, id_class_nums=None, species_class_nums=None): 106 | super().__init__() 107 | if not isinstance(cfg, Config): 108 | cfg = Config(cfg) 109 | self.save_hyperparameters(cfg, ignore=["id_class_nums", "species_class_nums"]) 110 | self.test_results_fp = None 111 | 112 | # NN architecture 113 | self.backbone = timm.create_model( 114 | cfg.model_name, 115 | in_chans=3, 116 | pretrained=cfg.pretrained, 117 | num_classes=0, 118 | features_only=True, 119 | out_indices=cfg.out_indices, 120 | ) 121 | feature_dims = self.backbone.feature_info.channels() 122 | print(f"feature dims: {feature_dims}") 123 | self.global_pools = torch.nn.ModuleList( 124 | [GeM(p=cfg.global_pool.p, requires_grad=cfg.global_pool.train) for _ in cfg.out_indices] 125 | ) 126 | self.mid_features = np.sum(feature_dims) 127 | if cfg.normalization == "batchnorm": 128 | self.neck = torch.nn.BatchNorm1d(self.mid_features) 129 | elif cfg.normalization == "layernorm": 130 | self.neck = torch.nn.LayerNorm(self.mid_features) 131 | self.head_id = ArcMarginProductSubcenter(self.mid_features, cfg.num_classes, cfg.n_center_id) 132 | self.head_species = ArcMarginProductSubcenter(self.mid_features, cfg.num_species_classes, cfg.n_center_species) 133 | if id_class_nums is not None and species_class_nums is not None: 134 | margins_id = np.power(id_class_nums, cfg.margin_power_id) * cfg.margin_coef_id + cfg.margin_cons_id 135 | margins_species = ( 136 | np.power(species_class_nums, cfg.margin_power_species) * cfg.margin_coef_species 137 | + cfg.margin_cons_species 138 | ) 139 | print("margins_id", margins_id) 140 | print("margins_species", margins_species) 141 | self.margin_fn_id = ArcFaceLossAdaptiveMargin(margins_id, cfg.num_classes, cfg.s_id) 142 | self.margin_fn_species = ArcFaceLossAdaptiveMargin(margins_species, cfg.num_species_classes, cfg.s_species) 143 | self.loss_fn_id = torch.nn.CrossEntropyLoss() 144 | self.loss_fn_species = torch.nn.CrossEntropyLoss() 145 | 146 | def get_feat(self, x: torch.Tensor) -> torch.Tensor: 147 | ms = self.backbone(x) 148 | h = torch.cat([global_pool(m) for m, global_pool in zip(ms, self.global_pools)], dim=1) 149 | return self.neck(h) 150 | 151 | def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: 152 | feat = self.get_feat(x) 153 | return self.head_id(feat), self.head_species(feat) 154 | 155 | def training_step(self, batch, batch_idx): 156 | x, ids, species = batch["image"], batch["label"], batch["label_species"] 157 | logits_ids, logits_species = self(x) 158 | margin_logits_ids = self.margin_fn_id(logits_ids, ids) 159 | loss_ids = self.loss_fn_id(margin_logits_ids, ids) 160 | loss_species = self.loss_fn_species(self.margin_fn_species(logits_species, species), species) 161 | self.log_dict({"train/loss_ids": loss_ids.detach()}, on_step=False, on_epoch=True) 162 | self.log_dict({"train/loss_species": loss_species.detach()}, on_step=False, on_epoch=True) 163 | with torch.no_grad(): 164 | self.log_dict(map_dict(logits_ids, ids, "train"), on_step=False, on_epoch=True) 165 | self.log_dict( 166 | {"train/acc_species": topk_average_precision(logits_species, species, 1).mean().detach()}, 167 | on_step=False, 168 | on_epoch=True, 169 | ) 170 | return loss_ids * self.hparams.loss_id_ratio + loss_species * (1 - self.hparams.loss_id_ratio) 171 | 172 | def validation_step(self, batch, batch_idx): 173 | x, ids, species = batch["image"], batch["label"], batch["label_species"] 174 | out1, out_species1 = self(x) 175 | out2, out_species2 = self(x.flip(3)) 176 | output, output_species = (out1 + out2) / 2, (out_species1 + out_species2) / 2 177 | self.log_dict(map_dict(output, ids, "val"), on_step=False, on_epoch=True) 178 | self.log_dict( 179 | {"val/acc_species": topk_average_precision(output_species, species, 1).mean().detach()}, 180 | on_step=False, 181 | on_epoch=True, 182 | ) 183 | 184 | def configure_optimizers(self): 185 | backbone_params = list(self.backbone.parameters()) + list(self.global_pools.parameters()) 186 | head_params = ( 187 | list(self.neck.parameters()) + list(self.head_id.parameters()) + list(self.head_species.parameters()) 188 | ) 189 | params = [ 190 | {"params": backbone_params, "lr": self.hparams.lr_backbone}, 191 | {"params": head_params, "lr": self.hparams.lr_head}, 192 | ] 193 | if self.hparams.optimizer == "Adam": 194 | optimizer = torch.optim.Adam(params) 195 | elif self.hparams.optimizer == "AdamW": 196 | optimizer = torch.optim.AdamW(params) 197 | elif self.hparams.optimizer == "RAdam": 198 | optimizer = torch.optim.RAdam(params) 199 | 200 | warmup_steps = self.hparams.max_epochs * self.hparams.warmup_steps_ratio 201 | cycle_steps = self.hparams.max_epochs - warmup_steps 202 | lr_lambda = WarmupCosineLambda(warmup_steps, cycle_steps, self.hparams.lr_decay_scale) 203 | scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) 204 | return [optimizer], [scheduler] 205 | 206 | def test_step(self, batch, batch_idx): 207 | x = batch["image"] 208 | feat1 = self.get_feat(x) 209 | out1, out_species1 = self.head_id(feat1), self.head_species(feat1) 210 | feat2 = self.get_feat(x.flip(3)) 211 | out2, out_species2 = self.head_id(feat2), self.head_species(feat2) 212 | pred_logit, pred_idx = ((out1 + out2) / 2).cpu().sort(descending=True) 213 | return { 214 | "original_index": batch["original_index"], 215 | "label": batch["label"], 216 | "label_species": batch["label_species"], 217 | "pred_logit": pred_logit[:, :1000], 218 | "pred_idx": pred_idx[:, :1000], 219 | "pred_species": ((out_species1 + out_species2) / 2).cpu(), 220 | "embed_features1": feat1.cpu(), 221 | "embed_features2": feat2.cpu(), 222 | } 223 | 224 | def test_epoch_end(self, outputs: List[Dict[str, torch.Tensor]]): 225 | outputs = self.all_gather(outputs) 226 | if self.trainer.global_rank == 0: 227 | epoch_results: Dict[str, np.ndarray] = {} 228 | for key in outputs[0].keys(): 229 | if torch.cuda.device_count() > 1: 230 | result = torch.cat([x[key] for x in outputs], dim=1).flatten(end_dim=1) 231 | else: 232 | result = torch.cat([x[key] for x in outputs], dim=0) 233 | epoch_results[key] = result.detach().cpu().numpy() 234 | np.savez_compressed(self.test_results_fp, **epoch_results) 235 | 236 | 237 | def train( 238 | df: pd.DataFrame, 239 | args: argparse.Namespace, 240 | cfg: Config, 241 | fold: int, 242 | do_inference: bool = False, 243 | additional_dataset: WhaleDataset = None, 244 | optuna_trial: Optional[optuna.Trial] = None, 245 | ) -> Optional[float]: 246 | out_dir = f"{args.out_base_dir}/{args.exp_name}/{fold}" 247 | id_class_nums = df.individual_id.value_counts().sort_index().values 248 | species_class_nums = df.species.value_counts().sort_index().values 249 | model = SphereClassifier(cfg, id_class_nums=id_class_nums, species_class_nums=species_class_nums) 250 | data_module = WhaleDataModule( 251 | df, cfg, f"{args.in_base_dir}/train_images", cfg.val_bbox, fold, additional_dataset=additional_dataset 252 | ) 253 | loggers = [pl_loggers.CSVLogger(out_dir)] 254 | if args.wandb_logger: 255 | loggers.append( 256 | pl_loggers.WandbLogger( 257 | project="kaggle-happywhale", group=args.exp_name, name=f"{args.exp_name}/{fold}", save_dir=out_dir 258 | ) 259 | ) 260 | callbacks = [LearningRateMonitor("epoch")] 261 | if optuna_trial is not None: 262 | callbacks.append(PyTorchLightningPruningCallback(optuna_trial, "val/mapNone")) 263 | if args.save_checkpoint: 264 | callbacks.append(ModelCheckpoint(out_dir, save_last=True, save_top_k=0)) 265 | trainer = Trainer( 266 | gpus=torch.cuda.device_count(), 267 | max_epochs=cfg["max_epochs"], 268 | logger=loggers, 269 | callbacks=callbacks, 270 | checkpoint_callback=args.save_checkpoint, 271 | precision=16, 272 | sync_batchnorm=True, 273 | ) 274 | ckpt_path = f"{out_dir}/last.ckpt" 275 | if not os.path.exists(ckpt_path) or not args.load_snapshot: 276 | ckpt_path = None 277 | trainer.fit(model, ckpt_path=ckpt_path, datamodule=data_module) 278 | if do_inference: 279 | for test_bbox in cfg.test_bboxes: 280 | # all train data 281 | model.test_results_fp = f"{out_dir}/train_{test_bbox}_results.npz" 282 | trainer.test(model, data_module.all_dataloader()) 283 | # test data 284 | model.test_results_fp = f"{out_dir}/test_{test_bbox}_results.npz" 285 | df_test = load_df(args.in_base_dir, cfg, "sample_submission.csv", False) 286 | test_data_module = WhaleDataModule(df_test, cfg, f"{args.in_base_dir}/test_images", test_bbox, -1) 287 | trainer.test(model, test_data_module.all_dataloader()) 288 | 289 | if args.wandb_logger: 290 | wandb.finish() 291 | if optuna_trial is not None: 292 | return trainer.callback_metrics["val/mapNone"].item() 293 | else: 294 | return None 295 | 296 | 297 | def main(): 298 | args = parse() 299 | warnings.filterwarnings("ignore", ".*does not have many workers.*") 300 | cfg = load_config(args.config_path, "config/default.yaml") 301 | print(cfg) 302 | df = load_df(args.in_base_dir, cfg, "train.csv", True) 303 | pseudo_dataset = None 304 | if cfg.pseudo_label is not None: 305 | pseudo_df = load_df(args.in_base_dir, cfg, cfg.pseudo_label, False) 306 | pseudo_dataset = WhaleDataset( 307 | pseudo_df[pseudo_df.conf > cfg.pseudo_conf_threshold], cfg, f"{args.in_base_dir}/test_images", "", True 308 | ) 309 | if cfg["n_splits"] == -1: 310 | train(df, args, cfg, -1, do_inference=True, additional_dataset=pseudo_dataset) 311 | else: 312 | train(df, args, cfg, 0, do_inference=True, additional_dataset=pseudo_dataset) 313 | 314 | 315 | if __name__ == "__main__": 316 | main() 317 | --------------------------------------------------------------------------------