├── .build └── .keep ├── .gitignore ├── imet ├── __init__.py ├── make_submission.py ├── make_folds.py ├── transforms.py ├── dataset.py ├── models.py ├── utils.py └── main.py ├── setup.py ├── README.rst ├── requirements.txt └── script_template.py /.build/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | /.build/script.py 2 | /data 3 | -------------------------------------------------------------------------------- /imet/__init__.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | 3 | 4 | cv2.setNumThreads(0) # fix potential pytorch worker issues 5 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | 3 | setup( 4 | name='imet', 5 | packages=['imet'], 6 | ) 7 | -------------------------------------------------------------------------------- /README.rst: -------------------------------------------------------------------------------- 1 | iMet Collection 2019 - FGVC6 2 | ============================ 3 | 4 | This is a PyTorch baseline for https://www.kaggle.com/c/imet-2019-fgvc6/ 5 | 6 | License is MIT. 
def run(command):
    """Run a shell command with ``/kaggle/working`` appended to PYTHONPATH.

    The command string is executed through the shell, exactly as before.
    The exit code is now returned (and a failure is reported on stdout)
    instead of being silently discarded, so a broken pipeline step is
    visible in the kernel log.  Callers that ignore the return value keep
    working unchanged.

    command: shell command line to execute.
    Returns the exit code of the command (0 on success).
    """
    import subprocess  # local import: the template only imports ``os``

    env = dict(os.environ)
    # Same value the shell would build from ``${PYTHONPATH}:/kaggle/working``
    # (an unset variable expands to the empty string).
    env['PYTHONPATH'] = env.get('PYTHONPATH', '') + ':/kaggle/working'
    completed = subprocess.run(command, shell=True, env=env)
    if completed.returncode != 0:
        print(f'Command failed with exit code {completed.returncode}: '
              f'{command}')
    return completed.returncode
def make_folds(n_folds: int) -> pd.DataFrame:
    """Assign every row of ``train.csv`` to one of ``n_folds`` folds.

    Rows are visited in a fixed shuffled order (``random_state=42``).  For
    each row the rarest of its classes (by count over the whole file) is
    looked up, and the row goes to a fold that currently holds the fewest
    items of that class; ties are broken by a generator seeded with the row
    index, so the result is reproducible.

    n_folds: number of folds to distribute the rows over.
    Returns the frame read from ``train.csv`` with an added ``fold`` column.
    """
    df = pd.read_csv(DATA_ROOT / 'train.csv')
    cls_counts = Counter(cls for classes in df['attribute_ids'].str.split()
                         for cls in classes)
    # (fold, class) -> number of rows of that class already placed in fold
    fold_cls_counts = defaultdict(int)
    folds = [-1] * len(df)
    for item in tqdm.tqdm(df.sample(frac=1, random_state=42).itertuples(),
                          total=len(df)):
        item_classes = item.attribute_ids.split()
        rarest = min(item_classes, key=lambda cls: cls_counts[cls])
        fold_counts = [(f, fold_cls_counts[f, rarest])
                       for f in range(n_folds)]
        min_count = min(count for _, count in fold_counts)
        candidates = [f for f, count in fold_counts if count == min_count]
        # A private generator seeded with the row index makes the same
        # choice as ``random.seed(idx); random.choice(...)`` did, without
        # resetting the process-wide RNG state on every row.
        fold = random.Random(item.Index).choice(candidates)
        folds[item.Index] = fold
        for cls in item_classes:
            fold_cls_counts[fold, cls] += 1
    df['fold'] = folds
    return df
class RandomSizedCrop:
    """Crop a random region of a PIL image and resize it to ``size`` x ``size``.

    Up to ten candidates are drawn: a target area (a fraction of the image
    area between ``min_area`` and ``max_area``) and an aspect ratio (between
    ``min_aspect`` and ``max_aspect``), with width and height swapped half
    of the time.  The first candidate that fits inside the image is cropped
    at a random offset and resized to ``(size, size)``.  When none fits, the
    image is passed through ``Resize(size)`` and then ``CenterCrop(size)``.

    size: target size handed to resize / Resize / CenterCrop
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR,
                 min_aspect=4/5, max_aspect=5/4,
                 min_area=0.25, max_area=1):
        self.size, self.interpolation = size, interpolation
        self.min_aspect, self.max_aspect = min_aspect, max_aspect
        self.min_area, self.max_area = min_area, max_area

    def __call__(self, img):
        img_w, img_h = img.size[0], img.size[1]
        full_area = img_w * img_h
        for _ in range(10):
            # Keep the order of the random draws: area, aspect, swap, x, y.
            crop_area = random.uniform(self.min_area, self.max_area) * full_area
            ratio = random.uniform(self.min_aspect, self.max_aspect)

            crop_w = int(round(math.sqrt(crop_area * ratio)))
            crop_h = int(round(math.sqrt(crop_area / ratio)))
            if random.random() < 0.5:
                crop_w, crop_h = crop_h, crop_w

            if crop_w > img_w or crop_h > img_h:
                continue  # candidate does not fit into the image, retry

            left = random.randint(0, img_w - crop_w)
            top = random.randint(0, img_h - crop_h)
            cropped = img.crop((left, top, left + crop_w, top + crop_h))
            assert(cropped.size == (crop_w, crop_h))
            return cropped.resize((self.size, self.size), self.interpolation)

        # Fallback
        resize = Resize(self.size, interpolation=self.interpolation)
        center_crop = CenterCrop(self.size)
        return center_crop(resize(img))
def get_ids(root: Path) -> List[str]:
    """Return the sorted, de-duplicated image ids found in ``root``.

    An id is the part of a ``*.png`` file name before the first underscore,
    without the extension.  Taking it from ``p.stem`` rather than ``p.name``
    keeps ``.png`` out of ids of files that have no underscore in their
    name, so the ids line up with the ``f'{item.id}.png'`` naming used when
    images are loaded.

    root: directory that is searched (non-recursively) for ``*.png`` files.
    """
    return sorted({p.stem.split('_')[0] for p in root.glob('*.png')})
class DenseNet(nn.Module):
    """DenseNet backbone whose classifier is replaced by a new linear head.

    num_classes: output size of the new ``nn.Linear`` classifier.
    pretrained: forwarded to ``create_net`` when building the backbone.
    net_cls: constructor of the backbone (``M.densenet121`` by default).
    """

    def __init__(self, num_classes,
                 pretrained=False, net_cls=M.densenet121):
        super().__init__()
        self.net = create_net(net_cls, pretrained=pretrained)
        self.avg_pool = AvgPool()
        # Swap the stock classifier for a head with ``num_classes`` outputs,
        # keeping its input width.
        head_inputs = self.net.classifier.in_features
        self.net.classifier = nn.Linear(head_inputs, num_classes)

    def fresh_params(self):
        """Parameters of the newly created classifier head."""
        head = self.net.classifier
        return head.parameters()

    def forward(self, x):
        """Backbone features -> ReLU -> AvgPool -> flatten -> classifier."""
        activated = F.relu(self.net.features(x), inplace=True)
        pooled = self.avg_pool(activated)
        flat = pooled.view(activated.size(0), -1)
        return self.net.classifier(flat)
def load_model(model: nn.Module, path: Path) -> Dict:
    """Load a checkpoint from ``path`` into ``model`` and return it.

    The checkpoint is expected to be a dict with at least the keys
    ``model`` (a state dict), ``epoch`` and ``step``; the latter two are
    printed.

    model: module that receives the stored weights.
    path: location of the checkpoint file.
    Returns the full checkpoint dict.
    """
    # Deserialize onto the CPU so a checkpoint written on a GPU machine can
    # also be opened where CUDA is unavailable; ``load_state_dict`` then
    # copies the values into the model's own parameters.
    state = torch.load(str(path), map_location='cpu')
    model.load_state_dict(state['model'])
    print('Loaded model from epoch {epoch}, step {step:,}'.format(**state))
    return state
87 | paths.extend(glob.glob(x)) 88 | else: 89 | paths.append(x) 90 | else: 91 | keys.append(x) 92 | if print_paths: 93 | print('Found paths: {}'.format(' '.join(sorted(paths)))) 94 | if newfigure: 95 | plt.figure(figsize=(12, 8)) 96 | keys = keys or ['loss', 'valid_loss'] 97 | 98 | ylim_kw = {} 99 | if ymin is not None: 100 | ylim_kw['bottom'] = ymin 101 | if ymax is not None: 102 | ylim_kw['top'] = ymax 103 | if ylim_kw: 104 | plt.ylim(**ylim_kw) 105 | 106 | xlim_kw = {} 107 | if xmin is not None: 108 | xlim_kw['left'] = xmin 109 | if xmax is not None: 110 | xlim_kw['right'] = xmax 111 | if xlim_kw: 112 | plt.xlim(**xlim_kw) 113 | all_keys = set() 114 | for path in sorted(paths): 115 | path = Path(path) 116 | with json_lines.open(path / 'train.log', broken=True) as f: 117 | events = list(f) 118 | all_keys.update(k for e in events for k in e) 119 | for key in sorted(keys): 120 | xs, ys, ys_err = [], [], [] 121 | for e in events: 122 | if key in e: 123 | xs.append(e['step'] * x_scale) 124 | ys.append(e[key]) 125 | std_key = key + '_std' 126 | if std_key in e: 127 | ys_err.append(e[std_key]) 128 | if xs: 129 | if np.isnan(ys).any(): 130 | print('Warning: NaN {} for {}'.format(key, path)) 131 | if len(xs) > 2 * max_points: 132 | indices = (np.arange(0, len(xs) - 1, len(xs) / max_points) 133 | .astype(np.int32)) 134 | xs = np.array(xs)[indices[1:]] 135 | ys = _smooth(ys, indices) 136 | if ys_err: 137 | ys_err = _smooth(ys_err, indices) 138 | label = '{}: {}'.format(path, key) 139 | if label.startswith('_'): 140 | label = ' ' + label 141 | if ys_err: 142 | ys_err = 1.96 * np.array(ys_err) 143 | plt.errorbar(xs, ys, yerr=ys_err, 144 | fmt='-o', capsize=5, capthick=2, 145 | label=label) 146 | else: 147 | plt.plot(xs, ys, label=label) 148 | plt.legend() 149 | if newfigure: 150 | plt.grid() 151 | if legend: 152 | plt.legend() 153 | if title: 154 | plt.title(title) 155 | if print_keys: 156 | print('Found keys: {}' 157 | .format(', '.join(sorted(all_keys - {'step', 'dt'})))) 
def main():
    """Command-line entry point: train, validate or predict with one model.

    Positional arguments are the mode (``train``, ``validate``,
    ``predict_valid`` or ``predict_test``) and the run directory.  Folds are
    read from ``folds.csv``; rows whose ``fold`` equals ``--fold`` form the
    validation split, the rest the training split.

    * ``train`` writes ``params.json`` into the run directory and trains;
      with ``--pretrained`` one epoch is first run on the fresh parameters
      only, and training on all parameters follows if that call succeeded.
    * ``validate`` loads ``model.pt`` and runs validation.
    * ``predict_valid`` / ``predict_test`` load ``best-model.pt`` and write
      ``val.h5`` / ``test.h5`` into the run directory.
    """
    parser = argparse.ArgumentParser()
    arg = parser.add_argument
    arg('mode', choices=['train', 'validate', 'predict_valid', 'predict_test'])
    arg('run_root')
    arg('--model', default='resnet50')
    arg('--pretrained', type=int, default=1)
    arg('--batch-size', type=int, default=64)
    arg('--step', type=int, default=1)
    arg('--workers', type=int, default=2 if ON_KAGGLE else 4)
    arg('--lr', type=float, default=1e-4)
    arg('--patience', type=int, default=4)
    arg('--clean', action='store_true')
    arg('--n-epochs', type=int, default=100)
    arg('--epoch-size', type=int)
    arg('--tta', type=int, default=4)
    arg('--use-sample', action='store_true', help='use a sample of the dataset')
    arg('--debug', action='store_true')
    arg('--limit', type=int)
    arg('--fold', type=int, default=0)
    args = parser.parse_args()

    run_root = Path(args.run_root)
    folds = pd.read_csv('folds.csv')
    train_root = DATA_ROOT / ('train_sample' if args.use_sample else 'train')
    if args.use_sample:
        # The id column is lowercase everywhere else (``item.id`` when
        # loading images, ``ss['id']`` below); ``'Id'`` raised KeyError.
        folds = folds[folds['id'].isin(set(get_ids(train_root)))]
    train_fold = folds[folds['fold'] != args.fold]
    valid_fold = folds[folds['fold'] == args.fold]
    if args.limit:
        train_fold = train_fold[:args.limit]
        valid_fold = valid_fold[:args.limit]

    def make_loader(df: pd.DataFrame, image_transform) -> DataLoader:
        # Shared by the train and validation splits; both are shuffled.
        return DataLoader(
            TrainDataset(train_root, df, image_transform, debug=args.debug),
            shuffle=True,
            batch_size=args.batch_size,
            num_workers=args.workers,
        )
    criterion = nn.BCEWithLogitsLoss(reduction='none')
    model = getattr(models, args.model)(
        num_classes=N_CLASSES, pretrained=args.pretrained)
    use_cuda = cuda.is_available()
    # Collect the parameter lists before the model is moved to the GPU.
    fresh_params = list(model.fresh_params())
    all_params = list(model.parameters())
    if use_cuda:
        model = model.cuda()

    if args.mode == 'train':
        if run_root.exists() and args.clean:
            shutil.rmtree(run_root)
        run_root.mkdir(exist_ok=True, parents=True)
        (run_root / 'params.json').write_text(
            json.dumps(vars(args), indent=4, sort_keys=True))

        train_loader = make_loader(train_fold, train_transform)
        valid_loader = make_loader(valid_fold, test_transform)
        print(f'{len(train_loader.dataset):,} items in train, '
              f'{len(valid_loader.dataset):,} in valid')

        train_kwargs = dict(
            args=args,
            model=model,
            criterion=criterion,
            train_loader=train_loader,
            valid_loader=valid_loader,
            patience=args.patience,
            init_optimizer=lambda params, lr: Adam(params, lr),
            use_cuda=use_cuda,
        )

        if args.pretrained:
            # Warm up the new head for one epoch, then train everything
            # unless the first stage returned False.
            if train(params=fresh_params, n_epochs=1, **train_kwargs):
                train(params=all_params, **train_kwargs)
        else:
            train(params=all_params, **train_kwargs)

    elif args.mode == 'validate':
        valid_loader = make_loader(valid_fold, test_transform)
        load_model(model, run_root / 'model.pt')
        validation(model, criterion, tqdm.tqdm(valid_loader, desc='Validation'),
                   use_cuda=use_cuda)

    elif args.mode.startswith('predict'):
        load_model(model, run_root / 'best-model.pt')
        predict_kwargs = dict(
            batch_size=args.batch_size,
            tta=args.tta,
            use_cuda=use_cuda,
            workers=args.workers,
        )
        if args.mode == 'predict_valid':
            predict(model, df=valid_fold, root=train_root,
                    out_path=run_root / 'val.h5',
                    **predict_kwargs)
        elif args.mode == 'predict_test':
            test_root = DATA_ROOT / (
                'test_sample' if args.use_sample else 'test')
            ss = pd.read_csv(DATA_ROOT / 'sample_submission.csv')
            if args.use_sample:
                ss = ss[ss['id'].isin(set(get_ids(test_root)))]
            if args.limit:
                ss = ss[:args.limit]
            predict(model, df=ss, root=test_root,
                    out_path=run_root / 'test.h5',
                    **predict_kwargs)
train_loader, valid_loader, init_optimizer, use_cuda, 163 | n_epochs=None, patience=2, max_lr_changes=2) -> bool: 164 | lr = args.lr 165 | n_epochs = n_epochs or args.n_epochs 166 | params = list(params) 167 | optimizer = init_optimizer(params, lr) 168 | 169 | run_root = Path(args.run_root) 170 | model_path = run_root / 'model.pt' 171 | best_model_path = run_root / 'best-model.pt' 172 | if model_path.exists(): 173 | state = load_model(model, model_path) 174 | epoch = state['epoch'] 175 | step = state['step'] 176 | best_valid_loss = state['best_valid_loss'] 177 | else: 178 | epoch = 1 179 | step = 0 180 | best_valid_loss = float('inf') 181 | lr_changes = 0 182 | 183 | save = lambda ep: torch.save({ 184 | 'model': model.state_dict(), 185 | 'epoch': ep, 186 | 'step': step, 187 | 'best_valid_loss': best_valid_loss 188 | }, str(model_path)) 189 | 190 | report_each = 10 191 | log = run_root.joinpath('train.log').open('at', encoding='utf8') 192 | valid_losses = [] 193 | lr_reset_epoch = epoch 194 | for epoch in range(epoch, n_epochs + 1): 195 | model.train() 196 | tq = tqdm.tqdm(total=(args.epoch_size or 197 | len(train_loader) * args.batch_size)) 198 | tq.set_description(f'Epoch {epoch}, lr {lr}') 199 | losses = [] 200 | tl = train_loader 201 | if args.epoch_size: 202 | tl = islice(tl, args.epoch_size // args.batch_size) 203 | try: 204 | mean_loss = 0 205 | for i, (inputs, targets) in enumerate(tl): 206 | if use_cuda: 207 | inputs, targets = inputs.cuda(), targets.cuda() 208 | outputs = model(inputs) 209 | loss = _reduce_loss(criterion(outputs, targets)) 210 | batch_size = inputs.size(0) 211 | (batch_size * loss).backward() 212 | if (i + 1) % args.step == 0: 213 | optimizer.step() 214 | optimizer.zero_grad() 215 | step += 1 216 | tq.update(batch_size) 217 | losses.append(loss.item()) 218 | mean_loss = np.mean(losses[-report_each:]) 219 | tq.set_postfix(loss=f'{mean_loss:.3f}') 220 | if i and i % report_each == 0: 221 | write_event(log, step, loss=mean_loss) 222 | 
write_event(log, step, loss=mean_loss) 223 | tq.close() 224 | save(epoch + 1) 225 | valid_metrics = validation(model, criterion, valid_loader, use_cuda) 226 | write_event(log, step, **valid_metrics) 227 | valid_loss = valid_metrics['valid_loss'] 228 | valid_losses.append(valid_loss) 229 | if valid_loss < best_valid_loss: 230 | best_valid_loss = valid_loss 231 | shutil.copy(str(model_path), str(best_model_path)) 232 | elif (patience and epoch - lr_reset_epoch > patience and 233 | min(valid_losses[-patience:]) > best_valid_loss): 234 | # "patience" epochs without improvement 235 | lr_changes +=1 236 | if lr_changes > max_lr_changes: 237 | break 238 | lr /= 5 239 | print(f'lr updated to {lr}') 240 | lr_reset_epoch = epoch 241 | optimizer = init_optimizer(params, lr) 242 | except KeyboardInterrupt: 243 | tq.close() 244 | print('Ctrl+C, saving snapshot') 245 | save(epoch) 246 | print('done.') 247 | return False 248 | return True 249 | 250 | 251 | def validation( 252 | model: nn.Module, criterion, valid_loader, use_cuda, 253 | ) -> Dict[str, float]: 254 | model.eval() 255 | all_losses, all_predictions, all_targets = [], [], [] 256 | with torch.no_grad(): 257 | for inputs, targets in valid_loader: 258 | all_targets.append(targets.numpy().copy()) 259 | if use_cuda: 260 | inputs, targets = inputs.cuda(), targets.cuda() 261 | outputs = model(inputs) 262 | loss = criterion(outputs, targets) 263 | all_losses.append(_reduce_loss(loss).item()) 264 | predictions = torch.sigmoid(outputs) 265 | all_predictions.append(predictions.cpu().numpy()) 266 | all_predictions = np.concatenate(all_predictions) 267 | all_targets = np.concatenate(all_targets) 268 | 269 | def get_score(y_pred): 270 | with warnings.catch_warnings(): 271 | warnings.simplefilter('ignore', category=UndefinedMetricWarning) 272 | return fbeta_score( 273 | all_targets, y_pred, beta=2, average='samples') 274 | 275 | metrics = {} 276 | argsorted = all_predictions.argsort(axis=1) 277 | for threshold in [0.05, 0.10, 0.15, 
0.20]: 278 | metrics[f'valid_f2_th_{threshold:.2f}'] = get_score( 279 | binarize_prediction(all_predictions, threshold, argsorted)) 280 | metrics['valid_loss'] = np.mean(all_losses) 281 | print(' | '.join(f'{k} {v:.3f}' for k, v in sorted( 282 | metrics.items(), key=lambda kv: -kv[1]))) 283 | 284 | return metrics 285 | 286 | 287 | def binarize_prediction(probabilities, threshold: float, argsorted=None, 288 | min_labels=1, max_labels=10): 289 | """ Return matrix of 0/1 predictions, same shape as probabilities. 290 | """ 291 | assert probabilities.shape[1] == N_CLASSES 292 | if argsorted is None: 293 | argsorted = probabilities.argsort(axis=1) 294 | max_mask = _make_mask(argsorted, max_labels) 295 | min_mask = _make_mask(argsorted, min_labels) 296 | prob_mask = probabilities > threshold 297 | return (max_mask & prob_mask) | min_mask 298 | 299 | 300 | def _make_mask(argsorted, top_n: int): 301 | mask = np.zeros_like(argsorted, dtype=np.uint8) 302 | col_indices = argsorted[:, -top_n:].reshape(-1) 303 | row_indices = [i // top_n for i in range(len(col_indices))] 304 | mask[row_indices, col_indices] = 1 305 | return mask 306 | 307 | 308 | def _reduce_loss(loss): 309 | return loss.sum() / loss.shape[0] 310 | 311 | 312 | if __name__ == '__main__': 313 | main() 314 | --------------------------------------------------------------------------------