├── docs ├── requirements.txt ├── figures │ ├── structure.png │ └── data-loading.png ├── dist_metric.rst ├── trainers.rst ├── evaluators.rst ├── datasets.rst ├── models.rst ├── Makefile ├── make.bat ├── index.rst ├── notes │ ├── overview.rst │ ├── evaluation_metrics.rst │ └── data_modules.rst ├── _static │ └── css │ │ └── openreid_theme.css ├── examples │ ├── benchmarks.rst │ └── training_id.rst └── conf.py ├── setup.cfg ├── reid ├── utils │ ├── data │ │ ├── __init__.py │ │ ├── preprocessor.py │ │ ├── sampler.py │ │ ├── transforms.py │ │ └── dataset.py │ ├── osutils.py │ ├── meters.py │ ├── __init__.py │ ├── logging.py │ └── serialization.py ├── evaluation_metrics │ ├── __init__.py │ ├── classification.py │ └── ranking.py ├── loss │ ├── __init__.py │ ├── triplet.py │ └── oim.py ├── feature_extraction │ ├── __init__.py │ ├── cnn.py │ └── database.py ├── __init__.py ├── metric_learning │ ├── euclidean.py │ ├── __init__.py │ └── kissme.py ├── dist_metric.py ├── datasets │ ├── __init__.py │ ├── cuhk01.py │ ├── viper.py │ ├── market1501.py │ ├── dukemtmc.py │ └── cuhk03.py ├── models │ ├── __init__.py │ ├── resnet.py │ └── inception.py ├── trainers.py └── evaluators.py ├── setup.py ├── test ├── models │ └── test_inception.py ├── loss │ └── test_oim.py ├── feature_extraction │ └── test_database.py ├── datasets │ ├── test_cuhk01.py │ ├── test_viper.py │ ├── test_cuhk03.py │ ├── test_dukemtmc.py │ └── test_market1501.py ├── utils │ └── data │ │ └── test_preprocessor.py └── evaluation_metrics │ └── test_cmc.py ├── README.md ├── LICENSE ├── .gitignore └── examples ├── softmax_loss.py ├── triplet_loss.py └── oim_loss.py /docs/requirements.txt: -------------------------------------------------------------------------------- 1 | sphinx 2 | sphinx-rtd-theme 3 | -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [metadata] 2 | description-file = README.md 3 | -------------------------------------------------------------------------------- /docs/figures/structure.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cysu/open-reid/HEAD/docs/figures/structure.png -------------------------------------------------------------------------------- /docs/figures/data-loading.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Cysu/open-reid/HEAD/docs/figures/data-loading.png -------------------------------------------------------------------------------- /reid/utils/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .dataset import Dataset 4 | from .preprocessor import Preprocessor 5 | -------------------------------------------------------------------------------- /docs/dist_metric.rst: -------------------------------------------------------------------------------- 1 | ================ 2 | reid.dist_metric 3 | ================ 4 | 5 | .. automodule:: reid.dist_metric 6 | .. currentmodule:: reid.dist_metric 7 | 8 | .. 
autoclass:: reid.dist_metric.DistanceMetric 9 | :members: -------------------------------------------------------------------------------- /reid/evaluation_metrics/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .classification import accuracy 4 | from .ranking import cmc, mean_ap 5 | 6 | __all__ = [ 7 | 'accuracy', 8 | 'cmc', 9 | 'mean_ap', 10 | ] 11 | -------------------------------------------------------------------------------- /reid/loss/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .oim import oim, OIM, OIMLoss 4 | from .triplet import TripletLoss 5 | 6 | __all__ = [ 7 | 'oim', 8 | 'OIM', 9 | 'OIMLoss', 10 | 'TripletLoss', 11 | ] 12 | -------------------------------------------------------------------------------- /docs/trainers.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | reid.trainers 3 | ============= 4 | 5 | .. automodule:: reid.trainers 6 | .. currentmodule:: reid.trainers 7 | 8 | .. autoclass:: BaseTrainer 9 | :members: 10 | 11 | .. autoclass:: Trainer 12 | :members: 13 | -------------------------------------------------------------------------------- /reid/feature_extraction/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .cnn import extract_cnn_feature 4 | from .database import FeatureDatabase 5 | 6 | __all__ = [ 7 | 'extract_cnn_feature', 8 | 'FeatureDatabase', 9 | ] 10 | -------------------------------------------------------------------------------- /reid/utils/osutils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import errno 4 | 5 | 6 | def mkdir_if_missing(dir_path): 7 | try: 8 | os.makedirs(dir_path) 9 | except OSError as e: 10 | if e.errno != errno.EEXIST: 11 | raise 12 | -------------------------------------------------------------------------------- /docs/evaluators.rst: -------------------------------------------------------------------------------- 1 | =============== 2 | reid.evaluators 3 | =============== 4 | 5 | .. automodule:: reid.evaluators 6 | .. currentmodule:: reid.evaluators 7 | 8 | .. autofunction:: extract_features 9 | .. autofunction:: pairwise_distance 10 | .. autofunction:: evaluate_all 11 | .. autoclass:: Evaluator 12 | :members: -------------------------------------------------------------------------------- /docs/datasets.rst: -------------------------------------------------------------------------------- 1 | ============= 2 | reid.datasets 3 | ============= 4 | 5 | .. automodule:: reid.datasets 6 | .. currentmodule:: reid.datasets 7 | 8 | .. autofunction:: create 9 | 10 | .. autoclass:: CUHK01 11 | .. autoclass:: CUHK03 12 | .. autoclass:: DukeMTMC 13 | .. autoclass:: Market1501 14 | .. autoclass:: VIPeR 15 | -------------------------------------------------------------------------------- /reid/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from . import datasets 4 | from . import evaluation_metrics 5 | from . import feature_extraction 6 | from . import loss 7 | from . import metric_learning 8 | from . import models 9 | from . import utils 10 | from . import dist_metric 11 | from . 
import evaluators 12 | from . import trainers 13 | 14 | __version__ = '0.2.0' 15 | -------------------------------------------------------------------------------- /docs/models.rst: -------------------------------------------------------------------------------- 1 | =========== 2 | reid.models 3 | =========== 4 | 5 | .. automodule:: reid.models 6 | .. currentmodule:: reid.models 7 | 8 | .. autofunction:: create 9 | .. autofunction:: inception 10 | .. autofunction:: resnet18 11 | .. autofunction:: resnet34 12 | .. autofunction:: resnet50 13 | .. autofunction:: resnet101 14 | .. autofunction:: resnet152 15 | 16 | .. autoclass:: InceptionNet 17 | .. autoclass:: ResNet 18 | -------------------------------------------------------------------------------- /reid/metric_learning/euclidean.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | from metric_learn.base_metric import BaseMetricLearner 5 | 6 | 7 | class Euclidean(BaseMetricLearner): 8 | def __init__(self): 9 | self.M_ = None 10 | 11 | def metric(self): 12 | return self.M_ 13 | 14 | def fit(self, X): 15 | self.M_ = np.eye(X.shape[1]) 16 | self.X_ = X 17 | 18 | def transform(self, X=None): 19 | if X is None: 20 | return self.X_ 21 | return X 22 | -------------------------------------------------------------------------------- /reid/utils/meters.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | 4 | class AverageMeter(object): 5 | """Computes and stores the average and current value""" 6 | 7 | def __init__(self): 8 | self.val = 0 9 | self.avg = 0 10 | self.sum = 0 11 | self.count = 0 12 | 13 | def reset(self): 14 | self.val = 0 15 | self.avg = 0 16 | self.sum = 0 17 | self.count = 0 18 | 19 | def update(self, val, n=1): 20 | self.val = val 21 | self.sum += val * n 22 | self.count += n 23 | self.avg = self.sum / self.count 24 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/classification.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from ..utils import to_torch 4 | 5 | 6 | def accuracy(output, target, topk=(1,)): 7 | output, target = to_torch(output), to_torch(target) 8 | maxk = max(topk) 9 | batch_size = target.size(0) 10 | 11 | _, pred = output.topk(maxk, 1, True, True) 12 | pred = pred.t() 13 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 14 | 15 | ret = [] 16 | for k in topk: 17 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True) 18 | ret.append(correct_k.mul_(1. / batch_size)) 19 | return ret 20 | -------------------------------------------------------------------------------- /docs/Makefile: -------------------------------------------------------------------------------- 1 | # Minimal makefile for Sphinx documentation 2 | # 3 | 4 | # You can set these variables from the command line. 5 | SPHINXOPTS = 6 | SPHINXBUILD = sphinx-build 7 | SPHINXPROJ = Open-ReID 8 | SOURCEDIR = . 9 | BUILDDIR = _build 10 | 11 | # Put it first so that "make" without argument is like "make help". 12 | help: 13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) 14 | 15 | .PHONY: help Makefile 16 | 17 | # Catch-all target: route all unknown targets to Sphinx using the new 18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
19 | %: Makefile 20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) -------------------------------------------------------------------------------- /reid/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | 6 | def to_numpy(tensor): 7 | if torch.is_tensor(tensor): 8 | return tensor.cpu().numpy() 9 | elif type(tensor).__module__ != 'numpy': 10 | raise ValueError("Cannot convert {} to numpy array" 11 | .format(type(tensor))) 12 | return tensor 13 | 14 | 15 | def to_torch(ndarray): 16 | if type(ndarray).__module__ == 'numpy': 17 | return torch.from_numpy(ndarray) 18 | elif not torch.is_tensor(ndarray): 19 | raise ValueError("Cannot convert {} to torch tensor" 20 | .format(type(ndarray))) 21 | return ndarray 22 | -------------------------------------------------------------------------------- /reid/metric_learning/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from metric_learn import (ITML_Supervised, LMNN, LSML_Supervised, 4 | SDML_Supervised, NCA, LFDA, RCA_Supervised) 5 | 6 | from .euclidean import Euclidean 7 | from .kissme import KISSME 8 | 9 | __factory = { 10 | 'euclidean': Euclidean, 11 | 'kissme': KISSME, 12 | 'itml': ITML_Supervised, 13 | 'lmnn': LMNN, 14 | 'lsml': LSML_Supervised, 15 | 'sdml': SDML_Supervised, 16 | 'nca': NCA, 17 | 'lfda': LFDA, 18 | 'rca': RCA_Supervised, 19 | } 20 | 21 | 22 | def get_metric(algorithm, *args, **kwargs): 23 | if algorithm not in __factory: 24 | raise KeyError("Unknown metric:", algorithm) 25 | return __factory[algorithm](*args, **kwargs) 26 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup, find_packages 2 | 3 | 4 | setup(name='open-reid', 5 | version='0.2.0', 6 | description='Deep Learning Library for Person Re-identification', 7 | author='Tong Xiao', 8 | author_email='st.cysu@gmail.com', 9 | url='https://github.com/Cysu/open-reid', 10 | license='MIT', 11 | install_requires=[ 12 | 'numpy', 'scipy', 'torch', 'torchvision', 13 | 'six', 'h5py', 'Pillow', 14 | 'scikit-learn', 'metric-learn'], 15 | extras_require={ 16 | 'docs': ['sphinx', 'sphinx_rtd_theme'], 17 | }, 18 | packages=find_packages(), 19 | keywords=[ 20 | 'Person Re-identification', 21 | 'Computer Vision', 22 | 'Deep Learning', 23 | ]) 24 | -------------------------------------------------------------------------------- /test/models/test_inception.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | 4 | class TestInception(TestCase): 5 | def test_forward(self): 6 | import torch 7 | from torch.autograd import Variable 8 | from reid.models.inception import InceptionNet 9 | 10 | # model = Inception(num_classes=5, num_features=256, dropout=0.5) 11 | # x = Variable(torch.randn(10, 3, 144, 56), requires_grad=False) 12 | # y = model(x) 13 | # self.assertEquals(y.size(), (10, 5)) 14 | 15 | model = InceptionNet(num_features=8, norm=True, dropout=0) 16 | x = Variable(torch.randn(10, 3, 144, 56), requires_grad=False) 17 | y = model(x) 18 | self.assertEquals(y.size(), (10, 8)) 19 | self.assertEquals(y.norm(2, 1).max(), 1) 20 | self.assertEquals(y.norm(2, 1).min(), 1) 21 | 
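    # A usage sketch added for illustration (not one of the original tests):
    # according to the factory documented in reid/models/__init__.py, the same
    # network should also be constructible through models.create with the
    # keyword arguments below.
    def test_factory_create(self):
        import torch
        from torch.autograd import Variable
        from reid import models

        model = models.create('inception', num_features=8, norm=True, dropout=0)
        x = Variable(torch.randn(10, 3, 144, 56), requires_grad=False)
        y = model(x)
        self.assertEquals(y.size(), (10, 8))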
-------------------------------------------------------------------------------- /test/loss/test_oim.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | 4 | class TestOIMLoss(TestCase): 5 | def test_forward_backward(self): 6 | import torch 7 | import torch.nn.functional as F 8 | from torch.autograd import Variable 9 | from reid.loss import OIMLoss 10 | criterion = OIMLoss(3, 3, scalar=1.0, size_average=False) 11 | criterion.lut = torch.eye(3) 12 | x = Variable(torch.randn(3, 3), requires_grad=True) 13 | y = Variable(torch.range(0, 2).long()) 14 | loss = criterion(x, y) 15 | loss.backward() 16 | probs = F.softmax(x) 17 | grads = probs.data - torch.eye(3) 18 | abs_diff = torch.abs(grads - x.grad.data) 19 | self.assertEquals(torch.log(probs).diag().sum(), -loss) 20 | self.assertTrue(torch.max(abs_diff) < 1e-6) 21 | -------------------------------------------------------------------------------- /reid/feature_extraction/cnn.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import OrderedDict 3 | 4 | from torch.autograd import Variable 5 | 6 | from ..utils import to_torch 7 | 8 | 9 | def extract_cnn_feature(model, inputs, modules=None): 10 | model.eval() 11 | inputs = to_torch(inputs) 12 | inputs = Variable(inputs, volatile=True) 13 | if modules is None: 14 | outputs = model(inputs) 15 | outputs = outputs.data.cpu() 16 | return outputs 17 | # Register forward hook for each module 18 | outputs = OrderedDict() 19 | handles = [] 20 | for m in modules: 21 | outputs[id(m)] = None 22 | def func(m, i, o): outputs[id(m)] = o.data.cpu() 23 | handles.append(m.register_forward_hook(func)) 24 | model(inputs) 25 | for h in handles: 26 | h.remove() 27 | return list(outputs.values()) 28 | -------------------------------------------------------------------------------- /test/feature_extraction/test_database.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | import numpy as np 4 | 5 | from reid.feature_extraction.database import FeatureDatabase 6 | 7 | 8 | class TestFeatureDatabase(TestCase): 9 | def test_all(self): 10 | with FeatureDatabase('/tmp/open-reid/test.h5', 'w') as db: 11 | db['img1'] = np.random.rand(3, 8, 8).astype(np.float32) 12 | db['img2'] = np.arange(10) 13 | db['img2'] = np.arange(10).reshape(2, 5).astype(np.float32) 14 | with FeatureDatabase('/tmp/open-reid/test.h5', 'r') as db: 15 | self.assertTrue('img1' in db) 16 | self.assertTrue('img2' in db) 17 | self.assertEquals(db['img1'].shape, (3, 8, 8)) 18 | x = db['img2'] 19 | self.assertEquals(x.shape, (2, 5)) 20 | self.assertTrue(np.all(x == np.arange(10).reshape(2, 5))) 21 | 22 | -------------------------------------------------------------------------------- /test/datasets/test_cuhk01.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | 4 | class TestCUHK01(TestCase): 5 | def test_init(self): 6 | import os.path as osp 7 | from reid.datasets import CUHK01 8 | from reid.utils.serialization import read_json 9 | 10 | root, split_id, num_val = '/tmp/open-reid/cuhk01', 0, 100 11 | dataset = CUHK01(root, split_id=split_id, num_val=num_val, download=True) 12 | 13 | self.assertTrue(osp.isfile(osp.join(root, 'meta.json'))) 14 | self.assertTrue(osp.isfile(osp.join(root, 'splits.json'))) 15 | meta = read_json(osp.join(root, 'meta.json')) 
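# CUHK01 labels 971 identities, each captured by a pair of camera views (see reid/datasets/cuhk01.py).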
16 | self.assertEquals(len(meta['identities']), 971) 17 | splits = read_json(osp.join(root, 'splits.json')) 18 | self.assertEquals(len(splits), 10) 19 | 20 | self.assertDictEqual(meta, dataset.meta) 21 | self.assertDictEqual(splits[split_id], dataset.split) 22 | -------------------------------------------------------------------------------- /test/datasets/test_viper.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | 4 | class TestVIPeR(TestCase): 5 | def test_init(self): 6 | import os.path as osp 7 | from reid.datasets.viper import VIPeR 8 | from reid.utils.serialization import read_json 9 | 10 | root, split_id, num_val = '/tmp/open-reid/viper', 0, 100 11 | dataset = VIPeR(root, split_id=split_id, num_val=num_val, download=True) 12 | 13 | self.assertTrue(osp.isfile(osp.join(root, 'meta.json'))) 14 | self.assertTrue(osp.isfile(osp.join(root, 'splits.json'))) 15 | meta = read_json(osp.join(root, 'meta.json')) 16 | self.assertEquals(len(meta['identities']), 632) 17 | splits = read_json(osp.join(root, 'splits.json')) 18 | self.assertEquals(len(splits), 10) 19 | 20 | self.assertDictEqual(meta, dataset.meta) 21 | self.assertDictEqual(splits[split_id], dataset.split) 22 | -------------------------------------------------------------------------------- /test/datasets/test_cuhk03.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | 4 | class TestCUHK03(TestCase): 5 | def test_init(self): 6 | import os.path as osp 7 | from reid.datasets.cuhk03 import CUHK03 8 | from reid.utils.serialization import read_json 9 | 10 | root, split_id, num_val = '/tmp/open-reid/cuhk03', 0, 100 11 | dataset = CUHK03(root, split_id=split_id, num_val=num_val, download=True) 12 | 13 | self.assertTrue(osp.isfile(osp.join(root, 'meta.json'))) 14 | self.assertTrue(osp.isfile(osp.join(root, 'splits.json'))) 15 | meta = read_json(osp.join(root, 'meta.json')) 16 | self.assertEquals(len(meta['identities']), 1467) 17 | splits = read_json(osp.join(root, 'splits.json')) 18 | self.assertEquals(len(splits), 20) 19 | 20 | self.assertDictEqual(meta, dataset.meta) 21 | self.assertDictEqual(splits[split_id], dataset.split) 22 | -------------------------------------------------------------------------------- /test/datasets/test_dukemtmc.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | 4 | class TestDukeMTMC(TestCase): 5 | def test_all(self): 6 | import os.path as osp 7 | from reid.datasets import DukeMTMC 8 | from reid.utils.serialization import read_json 9 | 10 | root, split_id, num_val = '/tmp/open-reid/dukemtmc', 0, 100 11 | dataset = DukeMTMC(root, split_id=split_id, num_val=num_val, 12 | download=True) 13 | 14 | self.assertTrue(osp.isfile(osp.join(root, 'meta.json'))) 15 | self.assertTrue(osp.isfile(osp.join(root, 'splits.json'))) 16 | meta = read_json(osp.join(root, 'meta.json')) 17 | self.assertEquals(len(meta['identities']), 1812) 18 | splits = read_json(osp.join(root, 'splits.json')) 19 | self.assertEquals(len(splits), 1) 20 | 21 | self.assertDictEqual(meta, dataset.meta) 22 | self.assertDictEqual(splits[split_id], dataset.split) -------------------------------------------------------------------------------- /docs/make.bat: -------------------------------------------------------------------------------- 1 | @ECHO OFF 2 | 3 | pushd %~dp0 4 | 5 | REM Command file for Sphinx documentation 6 | 7 | if 
"%SPHINXBUILD%" == "" ( 8 | set SPHINXBUILD=sphinx-build 9 | ) 10 | set SOURCEDIR=. 11 | set BUILDDIR=_build 12 | set SPHINXPROJ=Open-ReID 13 | 14 | if "%1" == "" goto help 15 | 16 | %SPHINXBUILD% >NUL 2>NUL 17 | if errorlevel 9009 ( 18 | echo. 19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx 20 | echo.installed, then set the SPHINXBUILD environment variable to point 21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you 22 | echo.may add the Sphinx directory to PATH. 23 | echo. 24 | echo.If you don't have Sphinx installed, grab it from 25 | echo.http://sphinx-doc.org/ 26 | exit /b 1 27 | ) 28 | 29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 30 | goto end 31 | 32 | :help 33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% 34 | 35 | :end 36 | popd 37 | -------------------------------------------------------------------------------- /test/datasets/test_market1501.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | 4 | class TestMarket1501(TestCase): 5 | def test_init(self): 6 | import os.path as osp 7 | from reid.datasets.market1501 import Market1501 8 | from reid.utils.serialization import read_json 9 | 10 | root, split_id, num_val = '/tmp/open-reid/market1501', 0, 100 11 | dataset = Market1501(root, split_id=split_id, num_val=num_val, 12 | download=True) 13 | 14 | self.assertTrue(osp.isfile(osp.join(root, 'meta.json'))) 15 | self.assertTrue(osp.isfile(osp.join(root, 'splits.json'))) 16 | meta = read_json(osp.join(root, 'meta.json')) 17 | self.assertEquals(len(meta['identities']), 1502) 18 | splits = read_json(osp.join(root, 'splits.json')) 19 | self.assertEquals(len(splits), 1) 20 | 21 | self.assertDictEqual(meta, dataset.meta) 22 | self.assertDictEqual(splits[split_id], dataset.split) 23 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Open-ReID 2 | 3 | Open-ReID is a lightweight library of person re-identification for research 4 | purpose. It aims to provide a uniform interface for different datasets, a full 5 | set of models and evaluation metrics, as well as examples to reproduce (near) 6 | state-of-the-art results. 7 | 8 | ## Installation 9 | 10 | Install [PyTorch](http://pytorch.org/) (version >= 0.2.0). Although we support 11 | both python2 and python3, we recommend python3 for better performance. 12 | 13 | ```shell 14 | git clone https://github.com/Cysu/open-reid.git 15 | cd open-reid 16 | python setup.py install 17 | ``` 18 | 19 | ## Examples 20 | 21 | ```shell 22 | python examples/softmax_loss.py -d viper -b 64 -j 2 -a resnet50 --logs-dir logs/softmax-loss/viper-resnet50 23 | ``` 24 | 25 | This is just a quick example. VIPeR dataset may not be large enough to train a deep neural network. 26 | 27 | Check about more [examples](https://cysu.github.io/open-reid/examples/training_id.html) 28 | and [benchmarks](https://cysu.github.io/open-reid/examples/benchmarks.html). 
29 | -------------------------------------------------------------------------------- /reid/utils/logging.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os 3 | import sys 4 | 5 | from .osutils import mkdir_if_missing 6 | 7 | 8 | class Logger(object): 9 | def __init__(self, fpath=None): 10 | self.console = sys.stdout 11 | self.file = None 12 | if fpath is not None: 13 | mkdir_if_missing(os.path.dirname(fpath)) 14 | self.file = open(fpath, 'w') 15 | 16 | def __del__(self): 17 | self.close() 18 | 19 | def __enter__(self): 20 | return self 21 | 22 | def __exit__(self, *args): 23 | self.close() 24 | 25 | def write(self, msg): 26 | self.console.write(msg) 27 | if self.file is not None: 28 | self.file.write(msg) 29 | 30 | def flush(self): 31 | self.console.flush() 32 | if self.file is not None: 33 | self.file.flush() 34 | os.fsync(self.file.fileno()) 35 | 36 | def close(self): 37 | # Only close the log file; self.console is sys.stdout and should stay open. 38 | if self.file is not None: 39 | self.file.close() 40 | -------------------------------------------------------------------------------- /reid/dist_metric.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | 5 | from .evaluators import extract_features 6 | from .metric_learning import get_metric 7 | 8 | 9 | class DistanceMetric(object): 10 | def __init__(self, algorithm='euclidean', *args, **kwargs): 11 | super(DistanceMetric, self).__init__() 12 | self.algorithm = algorithm 13 | self.metric = get_metric(algorithm, *args, **kwargs) 14 | 15 | def train(self, model, data_loader): 16 | if self.algorithm == 'euclidean': return 17 | features, labels = extract_features(model, data_loader) 18 | features = torch.stack(list(features.values())).numpy()  # list() for Python 3 dict views 19 | labels = torch.Tensor(list(labels.values())).numpy() 20 | self.metric.fit(features, labels) 21 | 22 | def transform(self, X): 23 | if torch.is_tensor(X): 24 | X = X.numpy() 25 | X = self.metric.transform(X) 26 | X = torch.from_numpy(X) 27 | else: 28 | X = self.metric.transform(X) 29 | return X 30 | 31 | -------------------------------------------------------------------------------- /reid/utils/data/preprocessor.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import os.path as osp 3 | 4 | from PIL import Image 5 | 6 | 7 | class Preprocessor(object): 8 | def __init__(self, dataset, root=None, transform=None): 9 | super(Preprocessor, self).__init__() 10 | self.dataset = dataset 11 | self.root = root 12 | self.transform = transform 13 | 14 | def __len__(self): 15 | return len(self.dataset) 16 | 17 | def __getitem__(self, indices): 18 | if isinstance(indices, (tuple, list)): 19 | return [self._get_single_item(index) for index in indices] 20 | return self._get_single_item(indices) 21 | 22 | def _get_single_item(self, index): 23 | fname, pid, camid = self.dataset[index] 24 | fpath = fname 25 | if self.root is not None: 26 | fpath = osp.join(self.root, fname) 27 | img = Image.open(fpath).convert('RGB') 28 | if self.transform is not None: 29 | img = self.transform(img) 30 | return img, fname, pid, camid 31 | -------------------------------------------------------------------------------- docs/index.rst: -------------------------------------------------------------------------------- 1 | .. Open-ReID documentation master file, created by 2 | sphinx-quickstart on Sun Mar 19 15:49:51 2017.
3 | You can adapt this file completely to your liking, but it should at least 4 | contain the root `toctree` directive. 5 | 6 | Open-ReID documentation 7 | ======================= 8 | 9 | Open-ReID is a deep learning library for person re-identification based on PyTorch. 10 | 11 | .. toctree:: 12 | :glob: 13 | :maxdepth: 1 14 | :caption: Notes 15 | 16 | notes/overview 17 | notes/data_modules 18 | notes/evaluation_metrics 19 | 20 | .. toctree:: 21 | :glob: 22 | :maxdepth: 1 23 | :caption: Examples 24 | 25 | examples/training_id 26 | examples/benchmarks 27 | 28 | .. toctree:: 29 | :maxdepth: 1 30 | :caption: SDK Level Reference 31 | 32 | trainers 33 | evaluators 34 | dist_metric 35 | 36 | .. toctree:: 37 | :maxdepth: 1 38 | :caption: API Level Reference 39 | 40 | datasets 41 | models 42 | 43 | Indices and tables 44 | ================== 45 | 46 | * :ref:`genindex` 47 | * :ref:`modindex` 48 | -------------------------------------------------------------------------------- /test/utils/data/test_preprocessor.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | 3 | 4 | class TestPreprocessor(TestCase): 5 | def test_getitem(self): 6 | import torchvision.transforms as t 7 | from reid.datasets.viper import VIPeR 8 | from reid.utils.data.preprocessor import Preprocessor 9 | 10 | root, split_id, num_val = '/tmp/open-reid/viper', 0, 100 11 | dataset = VIPeR(root, split_id=split_id, num_val=num_val, download=True) 12 | 13 | preproc = Preprocessor(dataset.train, root=dataset.images_dir, 14 | transform=t.Compose([ 15 | t.Scale(256), 16 | t.CenterCrop(224), 17 | t.ToTensor(), 18 | t.Normalize(mean=[0.485, 0.456, 0.406], 19 | std=[0.229, 0.224, 0.225]) 20 | ])) 21 | self.assertEquals(len(preproc), len(dataset.train)) 22 | img, fname, pid, camid = preproc[0]  # Preprocessor returns a 4-tuple 23 | self.assertEquals(img.size(), (3, 224, 224)) 24 | -------------------------------------------------------------------------------- LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2017 Tong Xiao 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE.
22 | -------------------------------------------------------------------------------- /reid/utils/data/sampler.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | import torch 6 | from torch.utils.data.sampler import ( 7 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler, 8 | WeightedRandomSampler) 9 | 10 | 11 | class RandomIdentitySampler(Sampler): 12 | def __init__(self, data_source, num_instances=1): 13 | self.data_source = data_source 14 | self.num_instances = num_instances 15 | self.index_dic = defaultdict(list) 16 | for index, (_, pid, _) in enumerate(data_source): 17 | self.index_dic[pid].append(index) 18 | self.pids = list(self.index_dic.keys()) 19 | self.num_samples = len(self.pids) 20 | 21 | def __len__(self): 22 | return self.num_samples * self.num_instances 23 | 24 | def __iter__(self): 25 | indices = torch.randperm(self.num_samples) 26 | ret = [] 27 | for i in indices: 28 | pid = self.pids[i] 29 | t = self.index_dic[pid] 30 | if len(t) >= self.num_instances: 31 | t = np.random.choice(t, size=self.num_instances, replace=False) 32 | else: 33 | t = np.random.choice(t, size=self.num_instances, replace=True) 34 | ret.extend(t) 35 | return iter(ret) 36 | -------------------------------------------------------------------------------- /test/evaluation_metrics/test_cmc.py: -------------------------------------------------------------------------------- 1 | from unittest import TestCase 2 | import numpy as np 3 | 4 | from reid.evaluation_metrics import cmc 5 | 6 | 7 | class TestCMC(TestCase): 8 | def test_only_distmat(self): 9 | distmat = np.array([[0, 1, 2, 3, 4], 10 | [1, 0, 2, 3, 4], 11 | [0, 1, 2, 3, 4], 12 | [0, 1, 2, 3, 4], 13 | [1, 2, 3, 4, 0]]) 14 | ret = cmc(distmat) 15 | self.assertTrue(np.all(ret[:5] == [0.6, 0.6, 0.8, 1.0, 1.0])) 16 | 17 | def test_duplicate_ids(self): 18 | distmat = np.tile(np.arange(4), (4, 1)) 19 | query_ids = [0, 0, 1, 1] 20 | gallery_ids = [0, 0, 1, 1] 21 | ret = cmc(distmat, query_ids=query_ids, gallery_ids=gallery_ids, topk=4, 22 | separate_camera_set=False, single_gallery_shot=False) 23 | self.assertTrue(np.all(ret == [0.5, 0.5, 1, 1])) 24 | 25 | def test_duplicate_cams(self): 26 | distmat = np.tile(np.arange(5), (5, 1)) 27 | query_ids = [0,0,0,1,1] 28 | gallery_ids = [0,0,0,1,1] 29 | query_cams = [0,0,0,0,0] 30 | gallery_cams = [0,1,1,1,1] 31 | ret = cmc(distmat, query_ids=query_ids, gallery_ids=gallery_ids, 32 | query_cams=query_cams, gallery_cams=gallery_cams, topk=5, 33 | separate_camera_set=False, single_gallery_shot=False) 34 | self.assertTrue(np.all(ret == [0.6, 0.6, 0.6, 1, 1])) 35 | -------------------------------------------------------------------------------- /reid/loss/triplet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.autograd import Variable 6 | 7 | 8 | class TripletLoss(nn.Module): 9 | def __init__(self, margin=0): 10 | super(TripletLoss, self).__init__() 11 | self.margin = margin 12 | self.ranking_loss = nn.MarginRankingLoss(margin=margin) 13 | 14 | def forward(self, inputs, targets): 15 | n = inputs.size(0) 16 | # Compute pairwise distance, replace by the official when merged 17 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) 18 | dist = dist + dist.t() 19 | dist.addmm_(1, -2, inputs, inputs.t()) 20 | dist = 
dist.clamp(min=1e-12).sqrt() # for numerical stability 21 | # For each anchor, find the hardest positive and negative 22 | mask = targets.expand(n, n).eq(targets.expand(n, n).t()) 23 | dist_ap, dist_an = [], [] 24 | for i in range(n): 25 | dist_ap.append(dist[i][mask[i]].max()) 26 | dist_an.append(dist[i][mask[i] == 0].min()) 27 | dist_ap = torch.cat(dist_ap) 28 | dist_an = torch.cat(dist_an) 29 | # Compute ranking hinge loss 30 | y = dist_an.data.new() 31 | y.resize_as_(dist_an.data) 32 | y.fill_(1) 33 | y = Variable(y) 34 | loss = self.ranking_loss(dist_an, dist_ap, y) 35 | prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0) 36 | return loss, prec 37 | -------------------------------------------------------------------------------- /reid/feature_extraction/database.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import h5py 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | 7 | 8 | class FeatureDatabase(Dataset): 9 | def __init__(self, *args, **kwargs): 10 | super(FeatureDatabase, self).__init__() 11 | self.fid = h5py.File(*args, **kwargs) 12 | 13 | def __enter__(self): 14 | return self 15 | 16 | def __exit__(self, exc_type, exc_val, exc_tb): 17 | self.close() 18 | 19 | def __getitem__(self, keys): 20 | if isinstance(keys, (tuple, list)): 21 | return [self._get_single_item(k) for k in keys] 22 | return self._get_single_item(keys) 23 | 24 | def _get_single_item(self, key): 25 | return np.asarray(self.fid[key]) 26 | 27 | def __setitem__(self, key, value): 28 | if key in self.fid: 29 | if self.fid[key].shape == value.shape and \ 30 | self.fid[key].dtype == value.dtype: 31 | self.fid[key][...] = value 32 | else: 33 | del self.fid[key] 34 | self.fid.create_dataset(key, data=value) 35 | else: 36 | self.fid.create_dataset(key, data=value) 37 | 38 | def __delitem__(self, key): 39 | del self.fid[key] 40 | 41 | def __len__(self): 42 | return len(self.fid) 43 | 44 | def __iter__(self): 45 | return iter(self.fid) 46 | 47 | def flush(self): 48 | self.fid.flush() 49 | 50 | def close(self): 51 | self.fid.close() 52 | -------------------------------------------------------------------------------- /reid/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | import warnings 3 | 4 | from .cuhk01 import CUHK01 5 | from .cuhk03 import CUHK03 6 | from .dukemtmc import DukeMTMC 7 | from .market1501 import Market1501 8 | from .viper import VIPeR 9 | 10 | 11 | __factory = { 12 | 'viper': VIPeR, 13 | 'cuhk01': CUHK01, 14 | 'cuhk03': CUHK03, 15 | 'market1501': Market1501, 16 | 'dukemtmc': DukeMTMC, 17 | } 18 | 19 | 20 | def names(): 21 | return sorted(__factory.keys()) 22 | 23 | 24 | def create(name, root, *args, **kwargs): 25 | """ 26 | Create a dataset instance. 27 | 28 | Parameters 29 | ---------- 30 | name : str 31 | The dataset name. Can be one of 'viper', 'cuhk01', 'cuhk03', 32 | 'market1501', and 'dukemtmc'. 33 | root : str 34 | The path to the dataset directory. 35 | split_id : int, optional 36 | The index of data split. Default: 0 37 | num_val : int or float, optional 38 | When int, it means the number of validation identities. When float, 39 | it means the proportion of validation to all the trainval. Default: 100 40 | download : bool, optional 41 | If True, will download the dataset. 
Default: True 42 | """ 43 | if name not in __factory: 44 | raise KeyError("Unknown dataset:", name) 45 | return __factory[name](root, *args, **kwargs) 46 | 47 | 48 | def get_dataset(name, root, *args, **kwargs): 49 | warnings.warn("get_dataset is deprecated. Use create instead.") 50 | return create(name, root, *args, **kwargs) 51 | -------------------------------------------------------------------------------- docs/notes/overview.rst: -------------------------------------------------------------------------------- 1 | ===================== 2 | Overview of Open-ReID 3 | ===================== 4 | 5 | Open-ReID is a lightweight person re-identification library for research 6 | purposes. It aims to provide a uniform interface for different datasets, a full 7 | set of models and evaluation metrics, as well as examples to reproduce (near) 8 | state-of-the-art results. Open-ReID is mainly based on `PyTorch 9 | <http://pytorch.org/>`_. 10 | 11 | --------- 12 | Structure 13 | --------- 14 | 15 | Open-ReID is structured into three levels, as shown in the figure below. 16 | 17 | .. _fig-structure: 18 | .. figure:: /figures/structure.png 19 | 20 | API Level 21 | At the bottom, there are decoupled modules, each providing a unit of functionality. For 22 | example, the ``datasets`` module has a uniform interface for many popular 23 | datasets, while commonly used evaluation metrics, such as CMC and mean AP, are 24 | implemented in the ``evaluation_metrics`` module, which accepts both 25 | ``torch.Tensor`` and ``numpy.ndarray`` as inputs. 26 | 27 | SDK Level 28 | In the middle, several classes interact with the underlying APIs to provide 29 | routines for standard tasks. For example, the ``Trainer`` can be used to 30 | train a deep model on the training set, and the ``Evaluator`` can evaluate the model 31 | on validation and test sets. 32 | 33 | Application Level 34 | At the top, we provide several examples that use Open-ReID. For example, one can 35 | easily train a CNN with different kinds of loss functions on different 36 | datasets to achieve certain baselines or state-of-the-art results.
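The sketch below, which is illustrative rather than a complete script, shows how
the levels compose. The data loader construction is omitted; ``train_loader`` is
assumed to be a ``torch.utils.data.DataLoader`` over the training set, and the
hyper-parameters are placeholders:

.. code-block:: python

    from torch import nn, optim
    from reid import models
    from reid.trainers import Trainer

    # API level: build a model from the factory.
    model = models.create('resnet50', num_classes=751)

    # SDK level: wrap it with the standard training routine. Note that
    # Trainer moves targets onto the GPU (see reid/trainers.py), so a
    # CUDA-enabled setup is assumed here.
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
    trainer = Trainer(model, criterion)
    for epoch in range(50):
        trainer.train(epoch, train_loader, optimizer)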
37 | -------------------------------------------------------------------------------- /reid/utils/data/transforms.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torchvision.transforms import * 4 | from PIL import Image 5 | import random 6 | import math 7 | 8 | 9 | class RectScale(object): 10 | def __init__(self, height, width, interpolation=Image.BILINEAR): 11 | self.height = height 12 | self.width = width 13 | self.interpolation = interpolation 14 | 15 | def __call__(self, img): 16 | w, h = img.size 17 | if h == self.height and w == self.width: 18 | return img 19 | return img.resize((self.width, self.height), self.interpolation) 20 | 21 | 22 | class RandomSizedRectCrop(object): 23 | def __init__(self, height, width, interpolation=Image.BILINEAR): 24 | self.height = height 25 | self.width = width 26 | self.interpolation = interpolation 27 | 28 | def __call__(self, img): 29 | for attempt in range(10): 30 | area = img.size[0] * img.size[1] 31 | target_area = random.uniform(0.64, 1.0) * area 32 | aspect_ratio = random.uniform(2, 3) 33 | 34 | h = int(round(math.sqrt(target_area * aspect_ratio))) 35 | w = int(round(math.sqrt(target_area / aspect_ratio))) 36 | 37 | if w <= img.size[0] and h <= img.size[1]: 38 | x1 = random.randint(0, img.size[0] - w) 39 | y1 = random.randint(0, img.size[1] - h) 40 | 41 | img = img.crop((x1, y1, x1 + w, y1 + h)) 42 | assert(img.size == (w, h)) 43 | 44 | return img.resize((self.width, self.height), self.interpolation) 45 | 46 | # Fallback 47 | scale = RectScale(self.height, self.width, 48 | interpolation=self.interpolation) 49 | return scale(img) 50 | -------------------------------------------------------------------------------- /reid/metric_learning/kissme.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import numpy as np 4 | from metric_learn.base_metric import BaseMetricLearner 5 | 6 | 7 | def validate_cov_matrix(M): 8 | M = (M + M.T) * 0.5 9 | k = 0 10 | I = np.eye(M.shape[0]) 11 | while True: 12 | try: 13 | _ = np.linalg.cholesky(M) 14 | break 15 | except np.linalg.LinAlgError: 16 | # Find the nearest positive definite matrix for M. 
Modified from 17 | # http://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd 18 | # Might take several minutes 19 | k += 1 20 | w, v = np.linalg.eig(M) 21 | min_eig = w.min()  # w holds the eigenvalues (v the eigenvectors) 22 | M += (-min_eig * k * k + np.spacing(min_eig)) * I 23 | return M 24 | 25 | 26 | class KISSME(BaseMetricLearner): 27 | def __init__(self): 28 | self.M_ = None 29 | 30 | def metric(self): 31 | return self.M_ 32 | 33 | def fit(self, X, y=None): 34 | n = X.shape[0] 35 | if y is None: 36 | y = np.arange(n) 37 | X1, X2 = np.meshgrid(np.arange(n), np.arange(n)) 38 | X1, X2 = X1[X1 < X2], X2[X1 < X2] 39 | matches = (y[X1] == y[X2]) 40 | num_matches = matches.sum() 41 | num_non_matches = len(matches) - num_matches 42 | idxa = X1[matches] 43 | idxb = X2[matches] 44 | S = X[idxa] - X[idxb] 45 | C1 = S.transpose().dot(S) / num_matches 46 | p = np.random.choice(num_non_matches, num_matches, replace=False) 47 | idxa = X1[~matches] 48 | idxb = X2[~matches] 49 | idxa = idxa[p] 50 | idxb = idxb[p] 51 | S = X[idxa] - X[idxb] 52 | C0 = S.transpose().dot(S) / num_matches 53 | self.M_ = np.linalg.inv(C1) - np.linalg.inv(C0) 54 | self.M_ = validate_cov_matrix(self.M_) 55 | self.X_ = X 56 | -------------------------------------------------------------------------------- /reid/loss/oim.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch import nn, autograd 6 | 7 | 8 | class OIM(autograd.Function): 9 | def __init__(self, lut, momentum=0.5): 10 | super(OIM, self).__init__() 11 | self.lut = lut 12 | self.momentum = momentum 13 | 14 | def forward(self, inputs, targets): 15 | self.save_for_backward(inputs, targets) 16 | outputs = inputs.mm(self.lut.t()) 17 | return outputs 18 | 19 | def backward(self, grad_outputs): 20 | inputs, targets = self.saved_tensors 21 | grad_inputs = None 22 | if self.needs_input_grad[0]: 23 | grad_inputs = grad_outputs.mm(self.lut) 24 | for x, y in zip(inputs, targets): 25 | self.lut[y] = self.momentum * self.lut[y] + (1. 
- self.momentum) * x 26 | self.lut[y] /= self.lut[y].norm() 27 | return grad_inputs, None 28 | 29 | 30 | def oim(inputs, targets, lut, momentum=0.5): 31 | return OIM(lut, momentum=momentum)(inputs, targets) 32 | 33 | 34 | class OIMLoss(nn.Module): 35 | def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5, 36 | weight=None, size_average=True): 37 | super(OIMLoss, self).__init__() 38 | self.num_features = num_features 39 | self.num_classes = num_classes 40 | self.momentum = momentum 41 | self.scalar = scalar 42 | self.weight = weight 43 | self.size_average = size_average 44 | 45 | self.register_buffer('lut', torch.zeros(num_classes, num_features)) 46 | 47 | def forward(self, inputs, targets): 48 | inputs = oim(inputs, targets, self.lut, momentum=self.momentum) 49 | inputs *= self.scalar 50 | loss = F.cross_entropy(inputs, targets, weight=self.weight, 51 | size_average=self.size_average) 52 | return loss, inputs 53 | -------------------------------------------------------------------------------- /reid/utils/serialization.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import json 3 | import os.path as osp 4 | import shutil 5 | 6 | import torch 7 | from torch.nn import Parameter 8 | 9 | from .osutils import mkdir_if_missing 10 | 11 | 12 | def read_json(fpath): 13 | with open(fpath, 'r') as f: 14 | obj = json.load(f) 15 | return obj 16 | 17 | 18 | def write_json(obj, fpath): 19 | mkdir_if_missing(osp.dirname(fpath)) 20 | with open(fpath, 'w') as f: 21 | json.dump(obj, f, indent=4, separators=(',', ': ')) 22 | 23 | 24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'): 25 | mkdir_if_missing(osp.dirname(fpath)) 26 | torch.save(state, fpath) 27 | if is_best: 28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar')) 29 | 30 | 31 | def load_checkpoint(fpath): 32 | if osp.isfile(fpath): 33 | checkpoint = torch.load(fpath) 34 | print("=> Loaded checkpoint '{}'".format(fpath)) 35 | return checkpoint 36 | else: 37 | raise ValueError("=> No checkpoint found at '{}'".format(fpath)) 38 | 39 | 40 | def copy_state_dict(state_dict, model, strip=None): 41 | tgt_state = model.state_dict() 42 | copied_names = set() 43 | for name, param in state_dict.items(): 44 | if strip is not None and name.startswith(strip): 45 | name = name[len(strip):] 46 | if name not in tgt_state: 47 | continue 48 | if isinstance(param, Parameter): 49 | param = param.data 50 | if param.size() != tgt_state[name].size(): 51 | print('mismatch:', name, param.size(), tgt_state[name].size()) 52 | continue 53 | tgt_state[name].copy_(param) 54 | copied_names.add(name) 55 | 56 | missing = set(tgt_state.keys()) - copied_names 57 | if len(missing) > 0: 58 | print("missing keys in state_dict:", missing) 59 | 60 | return model 61 | -------------------------------------------------------------------------------- /reid/models/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from .inception import * 4 | from .resnet import * 5 | 6 | 7 | __factory = { 8 | 'inception': inception, 9 | 'resnet18': resnet18, 10 | 'resnet34': resnet34, 11 | 'resnet50': resnet50, 12 | 'resnet101': resnet101, 13 | 'resnet152': resnet152, 14 | } 15 | 16 | 17 | def names(): 18 | return sorted(__factory.keys()) 19 | 20 | 21 | def create(name, *args, **kwargs): 22 | """ 23 | Create a model instance. 
24 | 25 | Parameters 26 | ---------- 27 | name : str 28 | Model name. Can be one of 'inception', 'resnet18', 'resnet34', 29 | 'resnet50', 'resnet101', and 'resnet152'. 30 | pretrained : bool, optional 31 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained 32 | model. Default: True 33 | cut_at_pooling : bool, optional 34 | If True, will cut the model before the last global pooling layer and 35 | ignore the remaining kwargs. Default: False 36 | num_features : int, optional 37 | If positive, will append a Linear layer after the global pooling layer, 38 | with this number of output units, followed by a BatchNorm layer. 39 | Otherwise these layers will not be appended. Default: 256 for 40 | 'inception', 0 for 'resnet*' 41 | norm : bool, optional 42 | If True, will normalize the feature to be unit L2-norm for each sample. 43 | Otherwise will append a ReLU layer after the above Linear layer if 44 | num_features > 0. Default: False 45 | dropout : float, optional 46 | If positive, will append a Dropout layer with this dropout rate. 47 | Default: 0 48 | num_classes : int, optional 49 | If positive, will append a Linear layer at the end as the classifier 50 | with this number of output units. Default: 0 51 | """ 52 | if name not in __factory: 53 | raise KeyError("Unknown model:", name) 54 | return __factory[name](*args, **kwargs) 55 | -------------------------------------------------------------------------------- docs/notes/evaluation_metrics.rst: -------------------------------------------------------------------------------- 1 | .. _evaluation-metrics: 2 | 3 | ================== 4 | Evaluation Metrics 5 | ================== 6 | 7 | Cumulative Matching Characteristics (CMC) curves are the most popular evaluation 8 | metrics for person re-identification methods. Consider a simple 9 | *single-gallery-shot* setting, where each gallery identity has only one 10 | instance. For each query, an algorithm will rank all the gallery samples 11 | according to their distances to the query from small to large, and the CMC top-k 12 | accuracy is 13 | 14 | .. math:: 15 | Acc_k = \begin{cases} 16 | 1 & \text{if top-$k$ ranked gallery samples contain the query identity} \\ 17 | 0 & \text{otherwise} 18 | \end{cases}, 19 | 20 | which is a shifted `step function <https://en.wikipedia.org/wiki/Step_function>`_. 21 | The final CMC curve is computed by averaging the shifted step functions over all 22 | the queries. 23 | 24 | While the *single-gallery-shot* CMC is well defined, there is no common 25 | agreement on the *multi-gallery-shot* setting, where each gallery 26 | identity could have multiple instances. For example, 27 | `CUHK03 <http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html>`_ and 28 | `Market-1501 <http://www.liangzheng.org/Project/project_reid.html>`_ 29 | calculate the CMC curves and CMC top-k accuracy quite differently. To be 30 | specific, 31 | 32 | - `CUHK03 <http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html>`_: 33 | Query and gallery sets are from different camera views. For each query, they 34 | randomly sample one instance for each gallery identity, and compute a CMC 35 | curve in the *single-gallery-shot* setting. The random sampling is repeated 36 | :math:`N` times and the expected CMC curve is reported. 37 | 38 | - `Market-1501 <http://www.liangzheng.org/Project/project_reid.html>`_: 39 | Query and gallery sets could share the same camera views, but for each individual 40 | query identity, his/her gallery samples from the same camera are excluded. 41 | They do not randomly sample only one instance for each gallery identity. This 42 | means the query will always match the "easiest" positive sample in the gallery, 43 | while harder positive samples are ignored when computing CMC.
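The snippet below, adapted from ``test/evaluation_metrics/test_cmc.py``, shows how
these protocol choices surface in ``reid.evaluation_metrics.cmc``:
``single_gallery_shot=True`` samples one instance per gallery identity (the
CUHK03-style evaluation above), while the setting shown keeps all gallery
instances and only excludes same-camera matches of each query:

.. code-block:: python

    import numpy as np
    from reid.evaluation_metrics import cmc

    distmat = np.tile(np.arange(5), (5, 1))  # 5 queries x 5 gallery samples
    query_ids, gallery_ids = [0, 0, 0, 1, 1], [0, 0, 0, 1, 1]
    query_cams, gallery_cams = [0, 0, 0, 0, 0], [0, 1, 1, 1, 1]

    ret = cmc(distmat, query_ids=query_ids, gallery_ids=gallery_ids,
              query_cams=query_cams, gallery_cams=gallery_cams, topk=5,
              separate_camera_set=False, single_gallery_shot=False)
    print(ret)  # top-1..top-5 accuracies: [0.6, 0.6, 0.6, 1.0, 1.0]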
44 | -------------------------------------------------------------------------------- /docs/_static/css/openreid_theme.css: -------------------------------------------------------------------------------- 1 | body { 2 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 3 | } 4 | 5 | /* Default header fonts are ugly */ 6 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption { 7 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif; 8 | } 9 | 10 | /* Use white for docs background */ 11 | .wy-side-nav-search { 12 | background-color: #fff; 13 | } 14 | 15 | .wy-side-nav-search > a { 16 | color: #F05732; 17 | } 18 | 19 | .wy-nav-content-wrap, .wy-menu li.current > a { 20 | background-color: #fff; 21 | } 22 | 23 | @media screen and (min-width: 1400px) { 24 | .wy-nav-content-wrap { 25 | background-color: rgba(0, 0, 0, 0.0470588); 26 | } 27 | 28 | .wy-nav-content { 29 | background-color: #fff; 30 | } 31 | } 32 | 33 | 34 | /* Fixes for mobile */ 35 | .wy-nav-top { 36 | padding: 0; 37 | margin: 0.4045em 0.809em; 38 | color: #333; 39 | } 40 | 41 | .wy-nav-top > a { 42 | display: none; 43 | } 44 | 45 | @media screen and (max-width: 768px) { 46 | .wy-side-nav-search>a img.logo { 47 | height: 60px; 48 | } 49 | } 50 | 51 | /* This is needed to ensure that logo above search scales properly */ 52 | .wy-side-nav-search a { 53 | display: block; 54 | } 55 | 56 | /* This ensures that multiple constructors will remain in separate lines. */ 57 | .rst-content dl:not(.docutils) dt { 58 | display: table; 59 | } 60 | 61 | /* Use our red for literals (it's very similar to the original color) */ 62 | .rst-content tt.literal, .rst-content tt.literal, .rst-content code.literal { 63 | color: #F05732; 64 | } 65 | 66 | .rst-content tt.xref, a .rst-content tt, .rst-content tt.xref, 67 | .rst-content code.xref, a .rst-content tt, a .rst-content code { 68 | color: #404040; 69 | } 70 | 71 | /* Change link colors (except for the menu) */ 72 | 73 | a { 74 | color: #F05732; 75 | } 76 | 77 | a:hover { 78 | color: #F05732; 79 | } 80 | 81 | 82 | a:visited { 83 | color: #D44D2C; 84 | } 85 | 86 | .wy-menu a { 87 | color: #b3b3b3; 88 | } 89 | 90 | .wy-menu a:hover { 91 | color: #b3b3b3; 92 | } 93 | 94 | /* Default footer text is quite big */ 95 | footer { 96 | font-size: 80%; 97 | } 98 | 99 | footer .rst-footer-buttons { 100 | font-size: 125%; /* revert footer settings - 1/80% = 125% */ 101 | } 102 | 103 | footer p { 104 | font-size: 100%; 105 | } 106 | 107 | /* For hidden headers that appear in TOC tree */ 108 | /* see http://stackoverflow.com/a/32363545/3343043 */ 109 | .rst-content .hidden-section { 110 | display: none; 111 | } 112 | 113 | nav .hidden-section { 114 | display: inherit; 115 | } 116 | 117 | .rst-content .citation ol { 118 | margin-bottom: 0; 119 | } -------------------------------------------------------------------------------- /reid/trainers.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import time 3 | 4 | import torch 5 | from torch.autograd import Variable 6 | 7 | from .evaluation_metrics import accuracy 8 | from .loss import OIMLoss, TripletLoss 9 | from .utils.meters import AverageMeter 10 | 11 | 12 | class BaseTrainer(object): 13 | def __init__(self, model, criterion): 14 | super(BaseTrainer, self).__init__() 15 | self.model = model 16 | self.criterion = criterion 17 | 18 | def train(self, epoch, data_loader, optimizer, print_freq=1): 19 | 
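"""Run one training epoch over ``data_loader``: for each mini-batch, obtain
inputs and targets via the subclass hooks ``_parse_data`` and ``_forward``,
take an optimizer step, and print timing / loss / precision meters every
``print_freq`` iterations."""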
self.model.train() 20 | 21 | batch_time = AverageMeter() 22 | data_time = AverageMeter() 23 | losses = AverageMeter() 24 | precisions = AverageMeter() 25 | 26 | end = time.time() 27 | for i, inputs in enumerate(data_loader): 28 | data_time.update(time.time() - end) 29 | 30 | inputs, targets = self._parse_data(inputs) 31 | loss, prec1 = self._forward(inputs, targets) 32 | 33 | losses.update(loss.data[0], targets.size(0)) 34 | precisions.update(prec1, targets.size(0)) 35 | 36 | optimizer.zero_grad() 37 | loss.backward() 38 | optimizer.step() 39 | 40 | batch_time.update(time.time() - end) 41 | end = time.time() 42 | 43 | if (i + 1) % print_freq == 0: 44 | print('Epoch: [{}][{}/{}]\t' 45 | 'Time {:.3f} ({:.3f})\t' 46 | 'Data {:.3f} ({:.3f})\t' 47 | 'Loss {:.3f} ({:.3f})\t' 48 | 'Prec {:.2%} ({:.2%})\t' 49 | .format(epoch, i + 1, len(data_loader), 50 | batch_time.val, batch_time.avg, 51 | data_time.val, data_time.avg, 52 | losses.val, losses.avg, 53 | precisions.val, precisions.avg)) 54 | 55 | def _parse_data(self, inputs): 56 | raise NotImplementedError 57 | 58 | def _forward(self, inputs, targets): 59 | raise NotImplementedError 60 | 61 | 62 | class Trainer(BaseTrainer): 63 | def _parse_data(self, inputs): 64 | imgs, _, pids, _ = inputs 65 | inputs = [Variable(imgs)] 66 | targets = Variable(pids.cuda()) 67 | return inputs, targets 68 | 69 | def _forward(self, inputs, targets): 70 | outputs = self.model(*inputs) 71 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss): 72 | loss = self.criterion(outputs, targets) 73 | prec, = accuracy(outputs.data, targets.data) 74 | prec = prec[0] 75 | elif isinstance(self.criterion, OIMLoss): 76 | loss, outputs = self.criterion(outputs, targets) 77 | prec, = accuracy(outputs.data, targets.data) 78 | prec = prec[0] 79 | elif isinstance(self.criterion, TripletLoss): 80 | loss, prec = self.criterion(outputs, targets) 81 | else: 82 | raise ValueError("Unsupported loss:", self.criterion) 83 | return loss, prec 84 | -------------------------------------------------------------------------------- docs/notes/data_modules.rst: -------------------------------------------------------------------------------- 1 | .. _data-modules: 2 | 3 | ============ 4 | Data Modules 5 | ============ 6 | 7 | This note introduces the unified dataset interface defined by Open-ReID, and 8 | the data loading system that samples mini-batches from it efficiently. 9 | 10 | .. _unified-data-format: 11 | 12 | ------------------- 13 | Unified Data Format 14 | ------------------- 15 | 16 | There are many existing person re-identification datasets, each with its own 17 | data format and split protocols. This makes conducting experiments on these 18 | datasets tedious and error-prone. To solve this problem, Open-ReID 19 | defines a unified dataset interface. By converting the raw dataset into this 20 | unified format, we can significantly simplify the code for training and 21 | evaluation with the formatted data. 22 | 23 | Every dataset will be organized as a directory like 24 | 25 | .. code-block:: shell 26 | 27 | cuhk03 28 | ├── raw/ 29 | ├── images/ 30 | ├── meta.json 31 | └── splits.json 32 | 33 | where ``raw/`` stores the original dataset files, ``images/`` contains all the 34 | renamed image files in the format of:: 35 | 36 | '{:08d}_{:02d}_{:04d}.jpg'.format(person_id, camera_id, image_id) 37 | 38 | where all the ids are indexed from 0.
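For example, the 24th image (``image_id = 23``) of person 7 from camera 1 is named::

    >>> '{:08d}_{:02d}_{:04d}.jpg'.format(7, 1, 23)
    '00000007_01_0023.jpg'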
39 | 
40 | ``meta.json`` contains all the person identities of the dataset, stored as a nested list with the structure::
41 | 
42 |     "identities": [
43 |         [  # the first identity, person_id = 0
44 |             [  # camera_id = 0
45 |                 "00000000_00_0000.jpg",
46 |                 "00000000_00_0001.jpg"
47 |             ],
48 |             [  # camera_id = 1
49 |                 "00000000_01_0000.jpg",
50 |                 "00000000_01_0001.jpg",
51 |                 "00000000_01_0002.jpg"
52 |             ]
53 |         ],
54 |         [  # the second identity, person_id = 1
55 |             [  # camera_id = 0
56 |                 "00000001_00_0000.jpg"
57 |             ],
58 |             [  # camera_id = 1
59 |                 "00000001_01_0000.jpg",
60 |                 "00000001_01_0001.jpg"
61 |             ]
62 |         ],
63 |         ...
64 |     ]
65 | 
66 | Each dataset may define multiple training / test data splits. They are listed in
67 | ``splits.json``, where each split defines three subsets of person identities::
68 | 
69 |     {
70 |         "trainval": [0, 1, 3, ...],  # person_ids for training and validation
71 |         "gallery": [2, 4, 5, ...],   # for test gallery, no overlap with trainval
72 |         "query": [2, 4, ...],        # for test query, a subset of gallery
73 |     }
74 | 
75 | .. _data-loading-system:
76 | 
77 | -------------------
78 | Data Loading System
79 | -------------------
80 | 
81 | The objective of the data loading system is to sample mini-batches efficiently from the dataset. In our design, it consists of four components, namely ``Dataset``, ``Sampler``, ``Preprocessor``, and ``Data Loader``. Their relationships are depicted in the figure below.
82 | 
83 | .. _fig-data-loading:
84 | .. figure:: /figures/data-loading.png
85 |    :figwidth: 80 %
86 |    :align: center
87 | 
88 | A ``Dataset`` is a list of items ``(filename, person_id, camera_id)``. A ``Sampler`` is an iterator that yields one index at a time. At the top, we adopt the ``torch.utils.data.DataLoader`` to load mini-batches using multi-processing. It queries the data at a given index from a ``Preprocessor``, which takes the index as input, loads the corresponding image (applying any transformations), and returns a tuple of ``(image, filename, person_id, camera_id)``.
89 | -------------------------------------------------------------------------------- /reid/datasets/cuhk01.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 | 
4 | import numpy as np
5 | 
6 | from ..utils.data import Dataset
7 | from ..utils.osutils import mkdir_if_missing
8 | from ..utils.serialization import write_json
9 | 
10 | 
11 | class CUHK01(Dataset):
12 |     url = 'https://docs.google.com/spreadsheet/viewform?formkey=dF9pZ1BFZkNiMG1oZUdtTjZPalR0MGc6MA'
13 |     md5 = 'e6d55c0da26d80cda210a2edeb448e98'
14 | 
15 |     def __init__(self, root, split_id=0, num_val=100, download=True):
16 |         super(CUHK01, self).__init__(root, split_id=split_id)
17 | 
18 |         if download:
19 |             self.download()
20 | 
21 |         if not self._check_integrity():
22 |             raise RuntimeError("Dataset not found or corrupted. 
" + 23 | "You can use download=True to download it.") 24 | 25 | self.load(num_val) 26 | 27 | def download(self): 28 | if self._check_integrity(): 29 | print("Files already downloaded and verified") 30 | return 31 | 32 | import hashlib 33 | import shutil 34 | from glob import glob 35 | from zipfile import ZipFile 36 | 37 | raw_dir = osp.join(self.root, 'raw') 38 | mkdir_if_missing(raw_dir) 39 | 40 | # Download the raw zip file 41 | fpath = osp.join(raw_dir, 'CUHK01.zip') 42 | if osp.isfile(fpath) and \ 43 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 44 | print("Using downloaded file: " + fpath) 45 | else: 46 | raise RuntimeError("Please download the dataset manually from {} " 47 | "to {}".format(self.url, fpath)) 48 | 49 | # Extract the file 50 | exdir = osp.join(raw_dir, 'campus') 51 | if not osp.isdir(exdir): 52 | print("Extracting zip file") 53 | with ZipFile(fpath) as z: 54 | z.extractall(path=raw_dir) 55 | 56 | # Format 57 | images_dir = osp.join(self.root, 'images') 58 | mkdir_if_missing(images_dir) 59 | 60 | identities = [[[] for _ in range(2)] for _ in range(971)] 61 | 62 | files = sorted(glob(osp.join(exdir, '*.png'))) 63 | for fpath in files: 64 | fname = osp.basename(fpath) 65 | pid, cam = int(fname[:4]), int(fname[4:7]) 66 | assert 1 <= pid <= 971 67 | assert 1 <= cam <= 4 68 | pid, cam = pid - 1, (cam - 1) // 2 69 | fname = ('{:08d}_{:02d}_{:04d}.png' 70 | .format(pid, cam, len(identities[pid][cam]))) 71 | identities[pid][cam].append(fname) 72 | shutil.copy(fpath, osp.join(images_dir, fname)) 73 | 74 | # Save meta information into a json file 75 | meta = {'name': 'cuhk01', 'shot': 'multiple', 'num_cameras': 2, 76 | 'identities': identities} 77 | write_json(meta, osp.join(self.root, 'meta.json')) 78 | 79 | # Randomly create ten training and test split 80 | num = len(identities) 81 | splits = [] 82 | for _ in range(10): 83 | pids = np.random.permutation(num).tolist() 84 | trainval_pids = sorted(pids[:num // 2]) 85 | test_pids = sorted(pids[num // 2:]) 86 | split = {'trainval': trainval_pids, 87 | 'query': test_pids, 88 | 'gallery': test_pids} 89 | splits.append(split) 90 | write_json(splits, osp.join(self.root, 'splits.json')) 91 | -------------------------------------------------------------------------------- /reid/datasets/viper.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..utils.data import Dataset 7 | from ..utils.osutils import mkdir_if_missing 8 | from ..utils.serialization import write_json 9 | 10 | 11 | class VIPeR(Dataset): 12 | url = 'http://users.soe.ucsc.edu/~manduchi/VIPeR.v1.0.zip' 13 | md5 = '1c2d9fc1cc800332567a0da25a1ce68c' 14 | 15 | def __init__(self, root, split_id=0, num_val=100, download=True): 16 | super(VIPeR, self).__init__(root, split_id=split_id) 17 | 18 | if download: 19 | self.download() 20 | 21 | if not self._check_integrity(): 22 | raise RuntimeError("Dataset not found or corrupted. 
" + 23 | "You can use download=True to download it.") 24 | 25 | self.load(num_val) 26 | 27 | def download(self): 28 | if self._check_integrity(): 29 | print("Files already downloaded and verified") 30 | return 31 | 32 | import hashlib 33 | from glob import glob 34 | from scipy.misc import imsave, imread 35 | from six.moves import urllib 36 | from zipfile import ZipFile 37 | 38 | raw_dir = osp.join(self.root, 'raw') 39 | mkdir_if_missing(raw_dir) 40 | 41 | # Download the raw zip file 42 | fpath = osp.join(raw_dir, 'VIPeR.v1.0.zip') 43 | if osp.isfile(fpath) and \ 44 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 45 | print("Using downloaded file: " + fpath) 46 | else: 47 | print("Downloading {} to {}".format(self.url, fpath)) 48 | urllib.request.urlretrieve(self.url, fpath) 49 | 50 | # Extract the file 51 | exdir = osp.join(raw_dir, 'VIPeR') 52 | if not osp.isdir(exdir): 53 | print("Extracting zip file") 54 | with ZipFile(fpath) as z: 55 | z.extractall(path=raw_dir) 56 | 57 | # Format 58 | images_dir = osp.join(self.root, 'images') 59 | mkdir_if_missing(images_dir) 60 | cameras = [sorted(glob(osp.join(exdir, 'cam_a', '*.bmp'))), 61 | sorted(glob(osp.join(exdir, 'cam_b', '*.bmp')))] 62 | assert len(cameras[0]) == len(cameras[1]) 63 | identities = [] 64 | for pid, (cam1, cam2) in enumerate(zip(*cameras)): 65 | images = [] 66 | # view-0 67 | fname = '{:08d}_{:02d}_{:04d}.jpg'.format(pid, 0, 0) 68 | imsave(osp.join(images_dir, fname), imread(cam1)) 69 | images.append([fname]) 70 | # view-1 71 | fname = '{:08d}_{:02d}_{:04d}.jpg'.format(pid, 1, 0) 72 | imsave(osp.join(images_dir, fname), imread(cam2)) 73 | images.append([fname]) 74 | identities.append(images) 75 | 76 | # Save meta information into a json file 77 | meta = {'name': 'VIPeR', 'shot': 'single', 'num_cameras': 2, 78 | 'identities': identities} 79 | write_json(meta, osp.join(self.root, 'meta.json')) 80 | 81 | # Randomly create ten training and test split 82 | num = len(identities) 83 | splits = [] 84 | for _ in range(10): 85 | pids = np.random.permutation(num).tolist() 86 | trainval_pids = sorted(pids[:num // 2]) 87 | test_pids = sorted(pids[num // 2:]) 88 | split = {'trainval': trainval_pids, 89 | 'query': test_pids, 90 | 'gallery': test_pids} 91 | splits.append(split) 92 | write_json(splits, osp.join(self.root, 'splits.json')) 93 | -------------------------------------------------------------------------------- /reid/utils/data/dataset.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..serialization import read_json 7 | 8 | 9 | def _pluck(identities, indices, relabel=False): 10 | ret = [] 11 | for index, pid in enumerate(indices): 12 | pid_images = identities[pid] 13 | for camid, cam_images in enumerate(pid_images): 14 | for fname in cam_images: 15 | name = osp.splitext(fname)[0] 16 | x, y, _ = map(int, name.split('_')) 17 | assert pid == x and camid == y 18 | if relabel: 19 | ret.append((fname, index, camid)) 20 | else: 21 | ret.append((fname, pid, camid)) 22 | return ret 23 | 24 | 25 | class Dataset(object): 26 | def __init__(self, root, split_id=0): 27 | self.root = root 28 | self.split_id = split_id 29 | self.meta = None 30 | self.split = None 31 | self.train, self.val, self.trainval = [], [], [] 32 | self.query, self.gallery = [], [] 33 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0 34 | 35 | @property 36 | def images_dir(self): 37 | return 
osp.join(self.root, 'images') 38 | 39 | def load(self, num_val=0.3, verbose=True): 40 | splits = read_json(osp.join(self.root, 'splits.json')) 41 | if self.split_id >= len(splits): 42 | raise ValueError("split_id exceeds total splits {}" 43 | .format(len(splits))) 44 | self.split = splits[self.split_id] 45 | 46 | # Randomly split train / val 47 | trainval_pids = np.asarray(self.split['trainval']) 48 | np.random.shuffle(trainval_pids) 49 | num = len(trainval_pids) 50 | if isinstance(num_val, float): 51 | num_val = int(round(num * num_val)) 52 | if num_val >= num or num_val < 0: 53 | raise ValueError("num_val exceeds total identities {}" 54 | .format(num)) 55 | train_pids = sorted(trainval_pids[:-num_val]) 56 | val_pids = sorted(trainval_pids[-num_val:]) 57 | 58 | self.meta = read_json(osp.join(self.root, 'meta.json')) 59 | identities = self.meta['identities'] 60 | self.train = _pluck(identities, train_pids, relabel=True) 61 | self.val = _pluck(identities, val_pids, relabel=True) 62 | self.trainval = _pluck(identities, trainval_pids, relabel=True) 63 | self.query = _pluck(identities, self.split['query']) 64 | self.gallery = _pluck(identities, self.split['gallery']) 65 | self.num_train_ids = len(train_pids) 66 | self.num_val_ids = len(val_pids) 67 | self.num_trainval_ids = len(trainval_pids) 68 | 69 | if verbose: 70 | print(self.__class__.__name__, "dataset loaded") 71 | print(" subset | # ids | # images") 72 | print(" ---------------------------") 73 | print(" train | {:5d} | {:8d}" 74 | .format(self.num_train_ids, len(self.train))) 75 | print(" val | {:5d} | {:8d}" 76 | .format(self.num_val_ids, len(self.val))) 77 | print(" trainval | {:5d} | {:8d}" 78 | .format(self.num_trainval_ids, len(self.trainval))) 79 | print(" query | {:5d} | {:8d}" 80 | .format(len(self.split['query']), len(self.query))) 81 | print(" gallery | {:5d} | {:8d}" 82 | .format(len(self.split['gallery']), len(self.gallery))) 83 | 84 | def _check_integrity(self): 85 | return osp.isdir(osp.join(self.root, 'images')) and \ 86 | osp.isfile(osp.join(self.root, 'meta.json')) and \ 87 | osp.isfile(osp.join(self.root, 'splits.json')) 88 | -------------------------------------------------------------------------------- /reid/datasets/market1501.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | from ..utils.data import Dataset 5 | from ..utils.osutils import mkdir_if_missing 6 | from ..utils.serialization import write_json 7 | 8 | 9 | class Market1501(Dataset): 10 | url = 'https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view' 11 | md5 = '65005ab7d12ec1c44de4eeafe813e68a' 12 | 13 | def __init__(self, root, split_id=0, num_val=100, download=True): 14 | super(Market1501, self).__init__(root, split_id=split_id) 15 | 16 | if download: 17 | self.download() 18 | 19 | if not self._check_integrity(): 20 | raise RuntimeError("Dataset not found or corrupted. 
" + 21 | "You can use download=True to download it.") 22 | 23 | self.load(num_val) 24 | 25 | def download(self): 26 | if self._check_integrity(): 27 | print("Files already downloaded and verified") 28 | return 29 | 30 | import re 31 | import hashlib 32 | import shutil 33 | from glob import glob 34 | from zipfile import ZipFile 35 | 36 | raw_dir = osp.join(self.root, 'raw') 37 | mkdir_if_missing(raw_dir) 38 | 39 | # Download the raw zip file 40 | fpath = osp.join(raw_dir, 'Market-1501-v15.09.15.zip') 41 | if osp.isfile(fpath) and \ 42 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 43 | print("Using downloaded file: " + fpath) 44 | else: 45 | raise RuntimeError("Please download the dataset manually from {} " 46 | "to {}".format(self.url, fpath)) 47 | 48 | # Extract the file 49 | exdir = osp.join(raw_dir, 'Market-1501-v15.09.15') 50 | if not osp.isdir(exdir): 51 | print("Extracting zip file") 52 | with ZipFile(fpath) as z: 53 | z.extractall(path=raw_dir) 54 | 55 | # Format 56 | images_dir = osp.join(self.root, 'images') 57 | mkdir_if_missing(images_dir) 58 | 59 | # 1501 identities (+1 for background) with 6 camera views each 60 | identities = [[[] for _ in range(6)] for _ in range(1502)] 61 | 62 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 63 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg'))) 64 | pids = set() 65 | for fpath in fpaths: 66 | fname = osp.basename(fpath) 67 | pid, cam = map(int, pattern.search(fname).groups()) 68 | if pid == -1: continue # junk images are just ignored 69 | assert 0 <= pid <= 1501 # pid == 0 means background 70 | assert 1 <= cam <= 6 71 | cam -= 1 72 | pids.add(pid) 73 | fname = ('{:08d}_{:02d}_{:04d}.jpg' 74 | .format(pid, cam, len(identities[pid][cam]))) 75 | identities[pid][cam].append(fname) 76 | shutil.copy(fpath, osp.join(images_dir, fname)) 77 | return pids 78 | 79 | trainval_pids = register('bounding_box_train') 80 | gallery_pids = register('bounding_box_test') 81 | query_pids = register('query') 82 | assert query_pids <= gallery_pids 83 | assert trainval_pids.isdisjoint(gallery_pids) 84 | 85 | # Save meta information into a json file 86 | meta = {'name': 'Market1501', 'shot': 'multiple', 'num_cameras': 6, 87 | 'identities': identities} 88 | write_json(meta, osp.join(self.root, 'meta.json')) 89 | 90 | # Save the only training / test split 91 | splits = [{ 92 | 'trainval': sorted(list(trainval_pids)), 93 | 'query': sorted(list(query_pids)), 94 | 'gallery': sorted(list(gallery_pids))}] 95 | write_json(splits, osp.join(self.root, 'splits.json')) 96 | -------------------------------------------------------------------------------- /reid/datasets/dukemtmc.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | from ..utils.data import Dataset 5 | from ..utils.osutils import mkdir_if_missing 6 | from ..utils.serialization import write_json 7 | 8 | 9 | class DukeMTMC(Dataset): 10 | url = 'https://drive.google.com/uc?id=0B0VOCNYh8HeRdnBPa2ZWaVBYSVk' 11 | md5 = '2f93496f9b516d1ee5ef51c1d5e7d601' 12 | 13 | def __init__(self, root, split_id=0, num_val=100, download=True): 14 | super(DukeMTMC, self).__init__(root, split_id=split_id) 15 | 16 | if download: 17 | self.download() 18 | 19 | if not self._check_integrity(): 20 | raise RuntimeError("Dataset not found or corrupted. 
" + 21 | "You can use download=True to download it.") 22 | 23 | self.load(num_val) 24 | 25 | def download(self): 26 | if self._check_integrity(): 27 | print("Files already downloaded and verified") 28 | return 29 | 30 | import re 31 | import hashlib 32 | import shutil 33 | from glob import glob 34 | from zipfile import ZipFile 35 | 36 | raw_dir = osp.join(self.root, 'raw') 37 | mkdir_if_missing(raw_dir) 38 | 39 | # Download the raw zip file 40 | fpath = osp.join(raw_dir, 'DukeMTMC-reID.zip') 41 | if osp.isfile(fpath) and \ 42 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 43 | print("Using downloaded file: " + fpath) 44 | else: 45 | raise RuntimeError("Please download the dataset manually from {} " 46 | "to {}".format(self.url, fpath)) 47 | 48 | # Extract the file 49 | exdir = osp.join(raw_dir, 'DukeMTMC-reID') 50 | if not osp.isdir(exdir): 51 | print("Extracting zip file") 52 | with ZipFile(fpath) as z: 53 | z.extractall(path=raw_dir) 54 | 55 | # Format 56 | images_dir = osp.join(self.root, 'images') 57 | mkdir_if_missing(images_dir) 58 | 59 | identities = [] 60 | all_pids = {} 61 | 62 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')): 63 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg'))) 64 | pids = set() 65 | for fpath in fpaths: 66 | fname = osp.basename(fpath) 67 | pid, cam = map(int, pattern.search(fname).groups()) 68 | assert 1 <= cam <= 8 69 | cam -= 1 70 | if pid not in all_pids: 71 | all_pids[pid] = len(all_pids) 72 | pid = all_pids[pid] 73 | pids.add(pid) 74 | if pid >= len(identities): 75 | assert pid == len(identities) 76 | identities.append([[] for _ in range(8)]) # 8 camera views 77 | fname = ('{:08d}_{:02d}_{:04d}.jpg' 78 | .format(pid, cam, len(identities[pid][cam]))) 79 | identities[pid][cam].append(fname) 80 | shutil.copy(fpath, osp.join(images_dir, fname)) 81 | return pids 82 | 83 | trainval_pids = register('bounding_box_train') 84 | gallery_pids = register('bounding_box_test') 85 | query_pids = register('query') 86 | assert query_pids <= gallery_pids 87 | assert trainval_pids.isdisjoint(gallery_pids) 88 | 89 | # Save meta information into a json file 90 | meta = {'name': 'DukeMTMC', 'shot': 'multiple', 'num_cameras': 8, 91 | 'identities': identities} 92 | write_json(meta, osp.join(self.root, 'meta.json')) 93 | 94 | # Save the only training / test split 95 | splits = [{ 96 | 'trainval': sorted(list(trainval_pids)), 97 | 'query': sorted(list(query_pids)), 98 | 'gallery': sorted(list(gallery_pids))}] 99 | write_json(splits, osp.join(self.root, 'splits.json')) 100 | -------------------------------------------------------------------------------- /reid/models/resnet.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | from torch import nn 4 | from torch.nn import functional as F 5 | from torch.nn import init 6 | import torchvision 7 | 8 | 9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 10 | 'resnet152'] 11 | 12 | 13 | class ResNet(nn.Module): 14 | __factory = { 15 | 18: torchvision.models.resnet18, 16 | 34: torchvision.models.resnet34, 17 | 50: torchvision.models.resnet50, 18 | 101: torchvision.models.resnet101, 19 | 152: torchvision.models.resnet152, 20 | } 21 | 22 | def __init__(self, depth, pretrained=True, cut_at_pooling=False, 23 | num_features=0, norm=False, dropout=0, num_classes=0): 24 | super(ResNet, self).__init__() 25 | 26 | self.depth = depth 27 | self.pretrained = pretrained 28 | self.cut_at_pooling = cut_at_pooling 
29 | 30 | # Construct base (pretrained) resnet 31 | if depth not in ResNet.__factory: 32 | raise KeyError("Unsupported depth:", depth) 33 | self.base = ResNet.__factory[depth](pretrained=pretrained) 34 | 35 | if not self.cut_at_pooling: 36 | self.num_features = num_features 37 | self.norm = norm 38 | self.dropout = dropout 39 | self.has_embedding = num_features > 0 40 | self.num_classes = num_classes 41 | 42 | out_planes = self.base.fc.in_features 43 | 44 | # Append new layers 45 | if self.has_embedding: 46 | self.feat = nn.Linear(out_planes, self.num_features) 47 | self.feat_bn = nn.BatchNorm1d(self.num_features) 48 | init.kaiming_normal(self.feat.weight, mode='fan_out') 49 | init.constant(self.feat.bias, 0) 50 | init.constant(self.feat_bn.weight, 1) 51 | init.constant(self.feat_bn.bias, 0) 52 | else: 53 | # Change the num_features to CNN output channels 54 | self.num_features = out_planes 55 | if self.dropout > 0: 56 | self.drop = nn.Dropout(self.dropout) 57 | if self.num_classes > 0: 58 | self.classifier = nn.Linear(self.num_features, self.num_classes) 59 | init.normal(self.classifier.weight, std=0.001) 60 | init.constant(self.classifier.bias, 0) 61 | 62 | if not self.pretrained: 63 | self.reset_params() 64 | 65 | def forward(self, x): 66 | for name, module in self.base._modules.items(): 67 | if name == 'avgpool': 68 | break 69 | x = module(x) 70 | 71 | if self.cut_at_pooling: 72 | return x 73 | 74 | x = F.avg_pool2d(x, x.size()[2:]) 75 | x = x.view(x.size(0), -1) 76 | 77 | if self.has_embedding: 78 | x = self.feat(x) 79 | x = self.feat_bn(x) 80 | if self.norm: 81 | x = F.normalize(x) 82 | elif self.has_embedding: 83 | x = F.relu(x) 84 | if self.dropout > 0: 85 | x = self.drop(x) 86 | if self.num_classes > 0: 87 | x = self.classifier(x) 88 | return x 89 | 90 | def reset_params(self): 91 | for m in self.modules(): 92 | if isinstance(m, nn.Conv2d): 93 | init.kaiming_normal(m.weight, mode='fan_out') 94 | if m.bias is not None: 95 | init.constant(m.bias, 0) 96 | elif isinstance(m, nn.BatchNorm2d): 97 | init.constant(m.weight, 1) 98 | init.constant(m.bias, 0) 99 | elif isinstance(m, nn.Linear): 100 | init.normal(m.weight, std=0.001) 101 | if m.bias is not None: 102 | init.constant(m.bias, 0) 103 | 104 | 105 | def resnet18(**kwargs): 106 | return ResNet(18, **kwargs) 107 | 108 | 109 | def resnet34(**kwargs): 110 | return ResNet(34, **kwargs) 111 | 112 | 113 | def resnet50(**kwargs): 114 | return ResNet(50, **kwargs) 115 | 116 | 117 | def resnet101(**kwargs): 118 | return ResNet(101, **kwargs) 119 | 120 | 121 | def resnet152(**kwargs): 122 | return ResNet(152, **kwargs) 123 | -------------------------------------------------------------------------------- /reid/datasets/cuhk03.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import os.path as osp 3 | 4 | import numpy as np 5 | 6 | from ..utils.data import Dataset 7 | from ..utils.osutils import mkdir_if_missing 8 | from ..utils.serialization import write_json 9 | 10 | 11 | class CUHK03(Dataset): 12 | url = 'https://docs.google.com/spreadsheet/viewform?usp=drive_web&formkey=dHRkMkFVSUFvbTJIRkRDLWRwZWpONnc6MA#gid=0' 13 | md5 = '728939e58ad9f0ff53e521857dd8fb43' 14 | 15 | def __init__(self, root, split_id=0, num_val=100, download=True): 16 | super(CUHK03, self).__init__(root, split_id=split_id) 17 | 18 | if download: 19 | self.download() 20 | 21 | if not self._check_integrity(): 22 | raise RuntimeError("Dataset not found or corrupted. 
" + 23 | "You can use download=True to download it.") 24 | 25 | self.load(num_val) 26 | 27 | def download(self): 28 | if self._check_integrity(): 29 | print("Files already downloaded and verified") 30 | return 31 | 32 | import h5py 33 | import hashlib 34 | from scipy.misc import imsave 35 | from zipfile import ZipFile 36 | 37 | raw_dir = osp.join(self.root, 'raw') 38 | mkdir_if_missing(raw_dir) 39 | 40 | # Download the raw zip file 41 | fpath = osp.join(raw_dir, 'cuhk03_release.zip') 42 | if osp.isfile(fpath) and \ 43 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5: 44 | print("Using downloaded file: " + fpath) 45 | else: 46 | raise RuntimeError("Please download the dataset manually from {} " 47 | "to {}".format(self.url, fpath)) 48 | 49 | # Extract the file 50 | exdir = osp.join(raw_dir, 'cuhk03_release') 51 | if not osp.isdir(exdir): 52 | print("Extracting zip file") 53 | with ZipFile(fpath) as z: 54 | z.extractall(path=raw_dir) 55 | 56 | # Format 57 | images_dir = osp.join(self.root, 'images') 58 | mkdir_if_missing(images_dir) 59 | matdata = h5py.File(osp.join(exdir, 'cuhk-03.mat'), 'r') 60 | 61 | def deref(ref): 62 | return matdata[ref][:].T 63 | 64 | def dump_(refs, pid, cam, fnames): 65 | for ref in refs: 66 | img = deref(ref) 67 | if img.size == 0 or img.ndim < 2: break 68 | fname = '{:08d}_{:02d}_{:04d}.jpg'.format(pid, cam, len(fnames)) 69 | imsave(osp.join(images_dir, fname), img) 70 | fnames.append(fname) 71 | 72 | identities = [] 73 | for labeled, detected in zip( 74 | matdata['labeled'][0], matdata['detected'][0]): 75 | labeled, detected = deref(labeled), deref(detected) 76 | assert labeled.shape == detected.shape 77 | for i in range(labeled.shape[0]): 78 | pid = len(identities) 79 | images = [[], []] 80 | dump_(labeled[i, :5], pid, 0, images[0]) 81 | dump_(detected[i, :5], pid, 0, images[0]) 82 | dump_(labeled[i, 5:], pid, 1, images[1]) 83 | dump_(detected[i, 5:], pid, 1, images[1]) 84 | identities.append(images) 85 | 86 | # Save meta information into a json file 87 | meta = {'name': 'cuhk03', 'shot': 'multiple', 'num_cameras': 2, 88 | 'identities': identities} 89 | write_json(meta, osp.join(self.root, 'meta.json')) 90 | 91 | # Save training and test splits 92 | splits = [] 93 | view_counts = [deref(ref).shape[0] for ref in matdata['labeled'][0]] 94 | vid_offsets = np.r_[0, np.cumsum(view_counts)] 95 | for ref in matdata['testsets'][0]: 96 | test_info = deref(ref).astype(np.int32) 97 | test_pids = sorted( 98 | [int(vid_offsets[i-1] + j - 1) for i, j in test_info]) 99 | trainval_pids = list(set(range(vid_offsets[-1])) - set(test_pids)) 100 | split = {'trainval': trainval_pids, 101 | 'query': test_pids, 102 | 'gallery': test_pids} 103 | splits.append(split) 104 | write_json(splits, osp.join(self.root, 'splits.json')) 105 | -------------------------------------------------------------------------------- /reid/evaluation_metrics/ranking.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | from collections import defaultdict 3 | 4 | import numpy as np 5 | from sklearn.metrics import average_precision_score 6 | 7 | from ..utils import to_numpy 8 | 9 | 10 | def _unique_sample(ids_dict, num): 11 | mask = np.zeros(num, dtype=np.bool) 12 | for _, indices in ids_dict.items(): 13 | i = np.random.choice(indices) 14 | mask[i] = True 15 | return mask 16 | 17 | 18 | def cmc(distmat, query_ids=None, gallery_ids=None, 19 | query_cams=None, gallery_cams=None, topk=100, 20 | separate_camera_set=False, 
21 | single_gallery_shot=False, 22 | first_match_break=False): 23 | distmat = to_numpy(distmat) 24 | m, n = distmat.shape 25 | # Fill up default values 26 | if query_ids is None: 27 | query_ids = np.arange(m) 28 | if gallery_ids is None: 29 | gallery_ids = np.arange(n) 30 | if query_cams is None: 31 | query_cams = np.zeros(m).astype(np.int32) 32 | if gallery_cams is None: 33 | gallery_cams = np.ones(n).astype(np.int32) 34 | # Ensure numpy array 35 | query_ids = np.asarray(query_ids) 36 | gallery_ids = np.asarray(gallery_ids) 37 | query_cams = np.asarray(query_cams) 38 | gallery_cams = np.asarray(gallery_cams) 39 | # Sort and find correct matches 40 | indices = np.argsort(distmat, axis=1) 41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 42 | # Compute CMC for each query 43 | ret = np.zeros(topk) 44 | num_valid_queries = 0 45 | for i in range(m): 46 | # Filter out the same id and same camera 47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 48 | (gallery_cams[indices[i]] != query_cams[i])) 49 | if separate_camera_set: 50 | # Filter out samples from same camera 51 | valid &= (gallery_cams[indices[i]] != query_cams[i]) 52 | if not np.any(matches[i, valid]): continue 53 | if single_gallery_shot: 54 | repeat = 10 55 | gids = gallery_ids[indices[i][valid]] 56 | inds = np.where(valid)[0] 57 | ids_dict = defaultdict(list) 58 | for j, x in zip(inds, gids): 59 | ids_dict[x].append(j) 60 | else: 61 | repeat = 1 62 | for _ in range(repeat): 63 | if single_gallery_shot: 64 | # Randomly choose one instance for each id 65 | sampled = (valid & _unique_sample(ids_dict, len(valid))) 66 | index = np.nonzero(matches[i, sampled])[0] 67 | else: 68 | index = np.nonzero(matches[i, valid])[0] 69 | delta = 1. / (len(index) * repeat) 70 | for j, k in enumerate(index): 71 | if k - j >= topk: break 72 | if first_match_break: 73 | ret[k - j] += 1 74 | break 75 | ret[k - j] += delta 76 | num_valid_queries += 1 77 | if num_valid_queries == 0: 78 | raise RuntimeError("No valid query") 79 | return ret.cumsum() / num_valid_queries 80 | 81 | 82 | def mean_ap(distmat, query_ids=None, gallery_ids=None, 83 | query_cams=None, gallery_cams=None): 84 | distmat = to_numpy(distmat) 85 | m, n = distmat.shape 86 | # Fill up default values 87 | if query_ids is None: 88 | query_ids = np.arange(m) 89 | if gallery_ids is None: 90 | gallery_ids = np.arange(n) 91 | if query_cams is None: 92 | query_cams = np.zeros(m).astype(np.int32) 93 | if gallery_cams is None: 94 | gallery_cams = np.ones(n).astype(np.int32) 95 | # Ensure numpy array 96 | query_ids = np.asarray(query_ids) 97 | gallery_ids = np.asarray(gallery_ids) 98 | query_cams = np.asarray(query_cams) 99 | gallery_cams = np.asarray(gallery_cams) 100 | # Sort and find correct matches 101 | indices = np.argsort(distmat, axis=1) 102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis]) 103 | # Compute AP for each query 104 | aps = [] 105 | for i in range(m): 106 | # Filter out the same id and same camera 107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) | 108 | (gallery_cams[indices[i]] != query_cams[i])) 109 | y_true = matches[i, valid] 110 | y_score = -distmat[i][indices[i]][valid] 111 | if not np.any(y_true): continue 112 | aps.append(average_precision_score(y_true, y_score)) 113 | if len(aps) == 0: 114 | raise RuntimeError("No valid query") 115 | return np.mean(aps) 116 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 
| *~ 2 | 3 | # temporary files which can be created if a process still has a handle open of a deleted file 4 | .fuse_hidden* 5 | 6 | # KDE directory preferences 7 | .directory 8 | 9 | # Linux trash folder which might appear on any partition or disk 10 | .Trash-* 11 | 12 | # .nfs files are created when an open file is removed but is still being accessed 13 | .nfs* 14 | 15 | 16 | *.DS_Store 17 | .AppleDouble 18 | .LSOverride 19 | 20 | # Icon must end with two \r 21 | Icon 22 | 23 | 24 | # Thumbnails 25 | ._* 26 | 27 | # Files that might appear in the root of a volume 28 | .DocumentRevisions-V100 29 | .fseventsd 30 | .Spotlight-V100 31 | .TemporaryItems 32 | .Trashes 33 | .VolumeIcon.icns 34 | .com.apple.timemachine.donotpresent 35 | 36 | # Directories potentially created on remote AFP share 37 | .AppleDB 38 | .AppleDesktop 39 | Network Trash Folder 40 | Temporary Items 41 | .apdisk 42 | 43 | 44 | # swap 45 | [._]*.s[a-v][a-z] 46 | [._]*.sw[a-p] 47 | [._]s[a-v][a-z] 48 | [._]sw[a-p] 49 | # session 50 | Session.vim 51 | # temporary 52 | .netrwhist 53 | *~ 54 | # auto-generated tag files 55 | tags 56 | 57 | 58 | # cache files for sublime text 59 | *.tmlanguage.cache 60 | *.tmPreferences.cache 61 | *.stTheme.cache 62 | 63 | # workspace files are user-specific 64 | *.sublime-workspace 65 | 66 | # project files should be checked into the repository, unless a significant 67 | # proportion of contributors will probably not be using SublimeText 68 | # *.sublime-project 69 | 70 | # sftp configuration file 71 | sftp-config.json 72 | 73 | # Package control specific files 74 | Package Control.last-run 75 | Package Control.ca-list 76 | Package Control.ca-bundle 77 | Package Control.system-ca-bundle 78 | Package Control.cache/ 79 | Package Control.ca-certs/ 80 | Package Control.merged-ca-bundle 81 | Package Control.user-ca-bundle 82 | oscrypto-ca-bundle.crt 83 | bh_unicode_properties.cache 84 | 85 | # Sublime-github package stores a github token in this file 86 | # https://packagecontrol.io/packages/sublime-github 87 | GitHub.sublime-settings 88 | 89 | 90 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm 91 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 92 | 93 | # User-specific stuff: 94 | .idea 95 | .idea/**/workspace.xml 96 | .idea/**/tasks.xml 97 | 98 | # Sensitive or high-churn files: 99 | .idea/**/dataSources/ 100 | .idea/**/dataSources.ids 101 | .idea/**/dataSources.xml 102 | .idea/**/dataSources.local.xml 103 | .idea/**/sqlDataSources.xml 104 | .idea/**/dynamic.xml 105 | .idea/**/uiDesigner.xml 106 | 107 | # Gradle: 108 | .idea/**/gradle.xml 109 | .idea/**/libraries 110 | 111 | # Mongo Explorer plugin: 112 | .idea/**/mongoSettings.xml 113 | 114 | ## File-based project format: 115 | *.iws 116 | 117 | ## Plugin-specific files: 118 | 119 | # IntelliJ 120 | /out/ 121 | 122 | # mpeltonen/sbt-idea plugin 123 | .idea_modules/ 124 | 125 | # JIRA plugin 126 | atlassian-ide-plugin.xml 127 | 128 | # Crashlytics plugin (for Android Studio and IntelliJ) 129 | com_crashlytics_export_strings.xml 130 | crashlytics.properties 131 | crashlytics-build.properties 132 | fabric.properties 133 | 134 | 135 | # Byte-compiled / optimized / DLL files 136 | __pycache__/ 137 | *.py[cod] 138 | *$py.class 139 | 140 | # C extensions 141 | *.so 142 | 143 | # Distribution / packaging 144 | .Python 145 | env/ 146 | build/ 147 | develop-eggs/ 148 | dist/ 149 | downloads/ 150 | eggs/ 151 | .eggs/ 152 | lib/ 153 | lib64/ 154 | parts/ 155 | 
sdist/
156 | var/
157 | *.egg-info/
158 | .installed.cfg
159 | *.egg
160 | 
161 | # PyInstaller
162 | # Usually these files are written by a python script from a template
163 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
164 | *.manifest
165 | *.spec
166 | 
167 | # Installer logs
168 | pip-log.txt
169 | pip-delete-this-directory.txt
170 | 
171 | # Unit test / coverage reports
172 | htmlcov/
173 | .tox/
174 | .coverage
175 | .coverage.*
176 | .cache
177 | nosetests.xml
178 | coverage.xml
179 | *,cover
180 | .hypothesis/
181 | 
182 | # Translations
183 | *.mo
184 | *.pot
185 | 
186 | # Django stuff:
187 | *.log
188 | local_settings.py
189 | 
190 | # Flask stuff:
191 | instance/
192 | .webassets-cache
193 | 
194 | # Scrapy stuff:
195 | .scrapy
196 | 
197 | # Sphinx documentation
198 | docs/_build/
199 | 
200 | # PyBuilder
201 | target/
202 | 
203 | # IPython Notebook
204 | .ipynb_checkpoints
205 | 
206 | # pyenv
207 | .python-version
208 | 
209 | # celery beat schedule file
210 | celerybeat-schedule
211 | 
212 | # dotenv
213 | .env
214 | 
215 | # virtualenv
216 | venv/
217 | ENV/
218 | 
219 | # Spyder project settings
220 | .spyderproject
221 | 
222 | # Rope project settings
223 | .ropeproject
224 | 
225 | 
226 | # Project specific
227 | examples/data
228 | examples/logs
229 | 
230 | -------------------------------------------------------------------------------- /reid/evaluators.py: --------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import time
3 | from collections import OrderedDict
4 | 
5 | import torch
6 | 
7 | from .evaluation_metrics import cmc, mean_ap
8 | from .feature_extraction import extract_cnn_feature
9 | from .utils.meters import AverageMeter
10 | 
11 | 
12 | def extract_features(model, data_loader, print_freq=1, metric=None):
13 |     model.eval()
14 |     batch_time = AverageMeter()
15 |     data_time = AverageMeter()
16 | 
17 |     features = OrderedDict()
18 |     labels = OrderedDict()
19 | 
20 |     end = time.time()
21 |     for i, (imgs, fnames, pids, _) in enumerate(data_loader):
22 |         data_time.update(time.time() - end)
23 | 
24 |         outputs = extract_cnn_feature(model, imgs)
25 |         for fname, output, pid in zip(fnames, outputs, pids):
26 |             features[fname] = output
27 |             labels[fname] = pid
28 | 
29 |         batch_time.update(time.time() - end)
30 |         end = time.time()
31 | 
32 |         if (i + 1) % print_freq == 0:
33 |             print('Extract Features: [{}/{}]\t'
34 |                   'Time {:.3f} ({:.3f})\t'
35 |                   'Data {:.3f} ({:.3f})\t'
36 |                   .format(i + 1, len(data_loader),
37 |                           batch_time.val, batch_time.avg,
38 |                           data_time.val, data_time.avg))
39 | 
40 |     return features, labels
41 | 
42 | 
43 | def pairwise_distance(features, query=None, gallery=None, metric=None):
44 |     if query is None and gallery is None:
45 |         n = len(features)
46 |         x = torch.cat(list(features.values()))
47 |         x = x.view(n, -1)
48 |         if metric is not None:
49 |             x = metric.transform(x)
50 |         sq = torch.pow(x, 2).sum(dim=1, keepdim=True)  # squared norm of each row
51 |         dist = sq.expand(n, n) + sq.expand(n, n).t() - 2 * torch.mm(x, x.t())  # ||xi||^2 + ||xj||^2 - 2 xi.xj, as in the query/gallery branch below
52 |         return dist
53 | 
54 |     x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0)
55 |     y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0)
56 |     m, n = x.size(0), y.size(0)
57 |     x = x.view(m, -1)
58 |     y = y.view(n, -1)
59 |     if metric is not None:
60 |         x = metric.transform(x)
61 |         y = metric.transform(y)
62 |     dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
63 |            torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
64 |     dist.addmm_(1, -2, x, 
y.t()) 65 | return dist 66 | 67 | 68 | def evaluate_all(distmat, query=None, gallery=None, 69 | query_ids=None, gallery_ids=None, 70 | query_cams=None, gallery_cams=None, 71 | cmc_topk=(1, 5, 10)): 72 | if query is not None and gallery is not None: 73 | query_ids = [pid for _, pid, _ in query] 74 | gallery_ids = [pid for _, pid, _ in gallery] 75 | query_cams = [cam for _, _, cam in query] 76 | gallery_cams = [cam for _, _, cam in gallery] 77 | else: 78 | assert (query_ids is not None and gallery_ids is not None 79 | and query_cams is not None and gallery_cams is not None) 80 | 81 | # Compute mean AP 82 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams) 83 | print('Mean AP: {:4.1%}'.format(mAP)) 84 | 85 | # Compute all kinds of CMC scores 86 | cmc_configs = { 87 | 'allshots': dict(separate_camera_set=False, 88 | single_gallery_shot=False, 89 | first_match_break=False), 90 | 'cuhk03': dict(separate_camera_set=True, 91 | single_gallery_shot=True, 92 | first_match_break=False), 93 | 'market1501': dict(separate_camera_set=False, 94 | single_gallery_shot=False, 95 | first_match_break=True)} 96 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids, 97 | query_cams, gallery_cams, **params) 98 | for name, params in cmc_configs.items()} 99 | 100 | print('CMC Scores{:>12}{:>12}{:>12}' 101 | .format('allshots', 'cuhk03', 'market1501')) 102 | for k in cmc_topk: 103 | print(' top-{:<4}{:12.1%}{:12.1%}{:12.1%}' 104 | .format(k, cmc_scores['allshots'][k - 1], 105 | cmc_scores['cuhk03'][k - 1], 106 | cmc_scores['market1501'][k - 1])) 107 | 108 | # Use the allshots cmc top-1 score for validation criterion 109 | return cmc_scores['allshots'][0] 110 | 111 | 112 | class Evaluator(object): 113 | def __init__(self, model): 114 | super(Evaluator, self).__init__() 115 | self.model = model 116 | 117 | def evaluate(self, data_loader, query, gallery, metric=None): 118 | features, _ = extract_features(self.model, data_loader) 119 | distmat = pairwise_distance(features, query, gallery, metric=metric) 120 | return evaluate_all(distmat, query=query, gallery=gallery) 121 | -------------------------------------------------------------------------------- /reid/models/inception.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import 2 | 3 | import torch 4 | from torch import nn 5 | from torch.nn import functional as F 6 | from torch.nn import init 7 | 8 | 9 | __all__ = ['InceptionNet', 'inception'] 10 | 11 | 12 | def _make_conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, 13 | bias=False): 14 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, 15 | stride=stride, padding=padding, bias=bias) 16 | bn = nn.BatchNorm2d(out_planes) 17 | relu = nn.ReLU(inplace=True) 18 | return nn.Sequential(conv, bn, relu) 19 | 20 | 21 | class Block(nn.Module): 22 | def __init__(self, in_planes, out_planes, pool_method, stride): 23 | super(Block, self).__init__() 24 | self.branches = nn.ModuleList([ 25 | nn.Sequential( 26 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0), 27 | _make_conv(out_planes, out_planes, stride=stride) 28 | ), 29 | nn.Sequential( 30 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0), 31 | _make_conv(out_planes, out_planes), 32 | _make_conv(out_planes, out_planes, stride=stride)) 33 | ]) 34 | 35 | if pool_method == 'Avg': 36 | assert stride == 1 37 | self.branches.append( 38 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0)) 39 | 
self.branches.append(nn.Sequential( 40 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1), 41 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0))) 42 | else: 43 | self.branches.append( 44 | nn.MaxPool2d(kernel_size=3, stride=stride, padding=1)) 45 | 46 | def forward(self, x): 47 | return torch.cat([b(x) for b in self.branches], 1) 48 | 49 | 50 | class InceptionNet(nn.Module): 51 | def __init__(self, cut_at_pooling=False, num_features=256, norm=False, 52 | dropout=0, num_classes=0): 53 | super(InceptionNet, self).__init__() 54 | self.cut_at_pooling = cut_at_pooling 55 | 56 | self.conv1 = _make_conv(3, 32) 57 | self.conv2 = _make_conv(32, 32) 58 | self.conv3 = _make_conv(32, 32) 59 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1) 60 | self.in_planes = 32 61 | self.inception4a = self._make_inception(64, 'Avg', 1) 62 | self.inception4b = self._make_inception(64, 'Max', 2) 63 | self.inception5a = self._make_inception(128, 'Avg', 1) 64 | self.inception5b = self._make_inception(128, 'Max', 2) 65 | self.inception6a = self._make_inception(256, 'Avg', 1) 66 | self.inception6b = self._make_inception(256, 'Max', 2) 67 | 68 | if not self.cut_at_pooling: 69 | self.num_features = num_features 70 | self.norm = norm 71 | self.dropout = dropout 72 | self.has_embedding = num_features > 0 73 | self.num_classes = num_classes 74 | 75 | self.avgpool = nn.AdaptiveAvgPool2d(1) 76 | 77 | if self.has_embedding: 78 | self.feat = nn.Linear(self.in_planes, self.num_features) 79 | self.feat_bn = nn.BatchNorm1d(self.num_features) 80 | else: 81 | # Change the num_features to CNN output channels 82 | self.num_features = self.in_planes 83 | if self.dropout > 0: 84 | self.drop = nn.Dropout(self.dropout) 85 | if self.num_classes > 0: 86 | self.classifier = nn.Linear(self.num_features, self.num_classes) 87 | 88 | self.reset_params() 89 | 90 | def forward(self, x): 91 | x = self.conv1(x) 92 | x = self.conv2(x) 93 | x = self.conv3(x) 94 | x = self.pool3(x) 95 | x = self.inception4a(x) 96 | x = self.inception4b(x) 97 | x = self.inception5a(x) 98 | x = self.inception5b(x) 99 | x = self.inception6a(x) 100 | x = self.inception6b(x) 101 | 102 | if self.cut_at_pooling: 103 | return x 104 | 105 | x = self.avgpool(x) 106 | x = x.view(x.size(0), -1) 107 | 108 | if self.has_embedding: 109 | x = self.feat(x) 110 | x = self.feat_bn(x) 111 | if self.norm: 112 | x = F.normalize(x) 113 | elif self.has_embedding: 114 | x = F.relu(x) 115 | if self.dropout > 0: 116 | x = self.drop(x) 117 | if self.num_classes > 0: 118 | x = self.classifier(x) 119 | return x 120 | 121 | def _make_inception(self, out_planes, pool_method, stride): 122 | block = Block(self.in_planes, out_planes, pool_method, stride) 123 | self.in_planes = (out_planes * 4 if pool_method == 'Avg' else 124 | out_planes * 2 + self.in_planes) 125 | return block 126 | 127 | def reset_params(self): 128 | for m in self.modules(): 129 | if isinstance(m, nn.Conv2d): 130 | init.kaiming_normal(m.weight, mode='fan_out') 131 | if m.bias is not None: 132 | init.constant(m.bias, 0) 133 | elif isinstance(m, nn.BatchNorm2d): 134 | init.constant(m.weight, 1) 135 | init.constant(m.bias, 0) 136 | elif isinstance(m, nn.Linear): 137 | init.normal(m.weight, std=0.001) 138 | if m.bias is not None: 139 | init.constant(m.bias, 0) 140 | 141 | 142 | def inception(**kwargs): 143 | return InceptionNet(**kwargs) 144 | -------------------------------------------------------------------------------- /docs/examples/benchmarks.rst: 
-------------------------------------------------------------------------------- 1 | ========== 2 | Benchmarks 3 | ========== 4 | 5 | Benchmarks for different models and loss functions on various datasets. 6 | 7 | All the experiments are conducted under the settings of: 8 | 9 | - 4 GPUs for training, meaning that ``CUDA_VISIBLE_DEVICES=0,1,2,3`` is set for the training scripts 10 | - Total effective batch size of 256. Consider reducing batch size and learning rate if you only have one GPU. See :ref:`gpu-options` for more details. 11 | - Use the default dataset split ``--split 0``, but combine training and validation sets for training models ``--combine-trainval`` 12 | - Use the default random seed ``--seed 1`` 13 | - Use Euclidean distance directly for evaluation 14 | - Use single-query and single-crop for evaluation 15 | - Full set of evaluation metrics are reported. See :ref:`evaluation-metrics` for more explanations. 16 | 17 | .. _cuhk03-benchmark: 18 | 19 | ^^^^^^ 20 | CUHK03 21 | ^^^^^^ 22 | 23 | ========= ============ ======== ============ ========== ============== =============== 24 | Net Loss Mean AP CMC allshots CMC cuhk03 CMC market1501 Training Script 25 | ========= ============ ======== ============ ========== ============== =============== 26 | Inception Triplet N/A N/A N/A N/A N/A 27 | Inception Softmax 65.8 48.6 73.2 71.0 ``python examples/softmax_loss.py -d cuhk03 -a inception --combine-trainval --epochs 70 --logs-dir examples/logs/softmax-loss/cuhk03-inception`` 28 | Inception OIM 71.4 56.0 77.7 76.5 ``python examples/oim_loss.py -d cuhk03 -a inception --combine-trainval --oim-scalar 20 --epochs 70 --logs-dir examples/logs/oim-loss/cuhk03-inception`` 29 | ResNet-50 Triplet **80.7** **67.9** **84.3** **85.0** ``python examples/triplet_loss.py -d cuhk03 -a resnet50 --combine-trainval --logs-dir examples/logs/triplet-loss/cuhk03-resnet50`` 30 | ResNet-50 Softmax 62.7 44.6 70.8 69.0 ``python examples/softmax_loss.py -d cuhk03 -a resnet50 --combine-trainval --logs-dir examples/logs/softmax-loss/cuhk03-resnet50`` 31 | ResNet-50 OIM 72.5 58.2 77.5 79.2 ``python examples/oim_loss.py -d cuhk03 -a resnet50 --combine-trainval --oim-scalar 30 --logs-dir examples/logs/oim-loss/cuhk03-resnet50`` 32 | ========= ============ ======== ============ ========== ============== =============== 33 | 34 | .. 
_market1501-benchmark: 35 | 36 | ^^^^^^^^^^ 37 | Market1501 38 | ^^^^^^^^^^ 39 | 40 | ========= ============ ======== ============ ========== ============== =============== 41 | Net Loss Mean AP CMC allshots CMC cuhk03 CMC market1501 Training Script 42 | ========= ============ ======== ============ ========== ============== =============== 43 | Inception Triplet N/A N/A N/A N/A N/A 44 | Inception Softmax 51.8 26.8 57.1 75.8 ``python examples/softmax_loss.py -d market1501 -a inception --combine-trainval --epochs 70 --logs-dir examples/logs/softmax-loss/market1501-inception`` 45 | Inception OIM 54.3 30.1 58.3 77.9 ``python examples/oim_loss.py -d market1501 -a inception --combine-trainval --oim-scalar 20 --epochs 70 --logs-dir examples/logs/oim-loss/market1501-inception`` 46 | ResNet-50 Triplet **67.9** **42.9** **70.5** **85.1** ``python examples/triplet_loss.py -d market1501 -a resnet50 --combine-trainval --logs-dir examples/logs/triplet-loss/market1501-resnet50`` 47 | ResNet-50 Softmax 59.8 35.5 62.8 81.4 ``python examples/softmax_loss.py -d market1501 -a resnet50 --combine-trainval --logs-dir examples/logs/softmax-loss/market1501-resnet50`` 48 | ResNet-50 OIM 60.9 37.3 63.6 82.1 ``python examples/oim_loss.py -d market1501 -a resnet50 --combine-trainval --oim-scalar 20 --logs-dir examples/logs/oim-loss/market1501-resnet50`` 49 | ========= ============ ======== ============ ========== ============== =============== 50 | 51 | .. _dukemtmc-benchmark: 52 | 53 | ^^^^^^^^ 54 | DukeMTMC 55 | ^^^^^^^^ 56 | 57 | ========= ============ ======== ============ ========== ============== =============== 58 | Net Loss Mean AP CMC allshots CMC cuhk03 CMC market1501 Training Script 59 | ========= ============ ======== ============ ========== ============== =============== 60 | Inception Triplet N/A N/A N/A N/A N/A 61 | Inception Softmax 34.0 17.4 39.2 54.4 ``python examples/softmax_loss.py -d dukemtmc -a inception --combine-trainval --epochs 70 --logs-dir examples/logs/softmax-loss/dukemtmc-inception`` 62 | Inception OIM 40.6 22.4 45.3 61.7 ``python examples/oim_loss.py -d dukemtmc -a inception --combine-trainval --oim-scalar 30 --epochs 70 --logs-dir examples/logs/oim-loss/dukemtmc-inception`` 63 | ResNet-50 Triplet **54.6** **34.6** **57.5** **73.1** ``python examples/triplet_loss.py -d dukemtmc -a resnet50 --combine-trainval --logs-dir examples/logs/triplet-loss/dukemtmc-resnet50`` 64 | ResNet-50 Softmax 40.7 23.7 44.3 62.5 ``python examples/softmax_loss.py -d dukemtmc -a resnet50 --combine-trainval --logs-dir examples/logs/softmax-loss/dukemtmc-resnet50`` 65 | ResNet-50 OIM 47.4 29.2 50.4 68.1 ``python examples/oim_loss.py -d dukemtmc -a resnet50 --combine-trainval --oim-scalar 30 --logs-dir examples/logs/oim-loss/dukemtmc-resnet50`` 66 | ========= ============ ======== ============ ========== ============== =============== 67 | 68 | .. ATTENTION:: 69 | No test-time augmentation is used. We have fixed a bug in the learning rate 70 | scheduler for the Triplet loss. Now the result of ResNet-50 with Triplet loss 71 | is slightly better than the "original" setting in [hermans2017in]_ (Table 4). 72 | -------------------------------------------------------------------------------- /docs/conf.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Open-ReID documentation build configuration file, created by 5 | # sphinx-quickstart on Sun Mar 19 15:49:51 2017. 
6 | # 7 | # This file is execfile()d with the current directory set to its 8 | # containing dir. 9 | # 10 | # Note that not all possible configuration values are present in this 11 | # autogenerated file. 12 | # 13 | # All configuration values have a default; values that are commented out 14 | # serve to show the default. 15 | 16 | # If extensions (or modules to document with autodoc) are in another directory, 17 | # add these directories to sys.path here. If the directory is relative to the 18 | # documentation root, use os.path.abspath to make it absolute, like shown here. 19 | # 20 | import os 21 | import sys 22 | 23 | sys.path.insert(0, os.path.abspath('.')) 24 | sys.path.insert(0, os.path.abspath('..')) 25 | 26 | import sphinx_rtd_theme 27 | 28 | # -- General configuration ------------------------------------------------ 29 | 30 | # If your documentation needs a minimal Sphinx version, state it here. 31 | # 32 | # needs_sphinx = '1.0' 33 | 34 | # Add any Sphinx extension module names here, as strings. They can be 35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom 36 | # ones. 37 | extensions = [ 38 | 'sphinx.ext.autodoc', 39 | 'sphinx.ext.autosummary', 40 | 'sphinx.ext.doctest', 41 | 'sphinx.ext.intersphinx', 42 | 'sphinx.ext.todo', 43 | 'sphinx.ext.coverage', 44 | 'sphinx.ext.mathjax', 45 | 'sphinx.ext.napoleon', 46 | 'sphinx.ext.viewcode', 47 | 'sphinx.ext.githubpages', 48 | ] 49 | 50 | # Add any paths that contain templates here, relative to this directory. 51 | templates_path = ['_templates'] 52 | 53 | # The suffix(es) of source filenames. 54 | # You can specify multiple suffix as a list of string: 55 | # 56 | # source_suffix = ['.rst', '.md'] 57 | source_suffix = '.rst' 58 | 59 | # The master toctree document. 60 | master_doc = 'index' 61 | 62 | # General information about the project. 63 | project = 'Open-ReID' 64 | copyright = '2017, Tong Xiao' 65 | author = 'Tong Xiao' 66 | 67 | # The version info for the project you're documenting, acts as replacement for 68 | # |version| and |release|, also used in various other places throughout the 69 | # built documents. 70 | # 71 | # The short X.Y version. 72 | version = '' 73 | # The full version, including alpha/beta/rc tags. 74 | release = '' 75 | 76 | # The language for content autogenerated by Sphinx. Refer to documentation 77 | # for a list of supported languages. 78 | # 79 | # This is also used if you do content translation via gettext catalogs. 80 | # Usually you set "language" from the command line for these cases. 81 | language = None 82 | 83 | # List of patterns, relative to source directory, that match files and 84 | # directories to ignore when looking for source files. 85 | # This patterns also effect to html_static_path and html_extra_path 86 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] 87 | 88 | # The name of the Pygments (syntax highlighting) style to use. 89 | pygments_style = 'sphinx' 90 | 91 | # If true, `todo` and `todoList` produce output, else they produce nothing. 92 | todo_include_todos = True 93 | 94 | # -- Options for HTML output ---------------------------------------------- 95 | 96 | # The theme to use for HTML and HTML Help pages. See the documentation for 97 | # a list of builtin themes. 98 | # Use a custom css file. 
See https://blog.deimos.fr/2014/10/02/sphinxdoc-and-readthedocs-theme-tricks-2/ 99 | 100 | html_theme = 'sphinx_rtd_theme' 101 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] 102 | 103 | html_context = { 104 | 'css_files': [ 105 | 'https://fonts.googleapis.com/css?family=Lato', 106 | '_static/css/openreid_theme.css', 107 | ] 108 | } 109 | 110 | # Theme options are theme-specific and customize the look and feel of a theme 111 | # further. For a list of options available for each theme, see the 112 | # documentation. 113 | # 114 | html_theme_options = { 115 | 'collapse_navigation': False, 116 | 'display_version': False, 117 | } 118 | 119 | # Add any paths that contain custom static files (such as style sheets) here, 120 | # relative to this directory. They are copied after the builtin static files, 121 | # so a file named "default.css" will overwrite the builtin "default.css". 122 | html_static_path = ['_static'] 123 | 124 | # -- Options for HTMLHelp output ------------------------------------------ 125 | 126 | # Output file base name for HTML help builder. 127 | htmlhelp_basename = 'Open-ReIDdoc' 128 | 129 | # -- Options for LaTeX output --------------------------------------------- 130 | 131 | latex_elements = { 132 | # The paper size ('letterpaper' or 'a4paper'). 133 | # 134 | # 'papersize': 'letterpaper', 135 | 136 | # The font size ('10pt', '11pt' or '12pt'). 137 | # 138 | # 'pointsize': '10pt', 139 | 140 | # Additional stuff for the LaTeX preamble. 141 | # 142 | # 'preamble': '', 143 | 144 | # Latex figure (float) alignment 145 | # 146 | # 'figure_align': 'htbp', 147 | } 148 | 149 | # Grouping the document tree into LaTeX files. List of tuples 150 | # (source start file, target name, title, 151 | # author, documentclass [howto, manual, or own class]). 152 | latex_documents = [ 153 | (master_doc, 'Open-ReID.tex', 'Open-ReID Documentation', 154 | 'Tong Xiao', 'manual'), 155 | ] 156 | 157 | # -- Options for manual page output --------------------------------------- 158 | 159 | # One entry per manual page. List of tuples 160 | # (source start file, name, description, authors, manual section). 161 | man_pages = [ 162 | (master_doc, 'open-reid', 'Open-ReID Documentation', 163 | [author], 1) 164 | ] 165 | 166 | # -- Options for Texinfo output ------------------------------------------- 167 | 168 | # Grouping the document tree into Texinfo files. List of tuples 169 | # (source start file, target name, title, author, 170 | # dir menu entry, description, category) 171 | texinfo_documents = [ 172 | (master_doc, 'Open-ReID', 'Open-ReID Documentation', 173 | author, 'Open-ReID', 'One line description of project.', 174 | 'Miscellaneous'), 175 | ] 176 | 177 | # Example configuration for intersphinx: refer to the Python standard library. 178 | intersphinx_mapping = { 179 | 'python': ('https://docs.python.org/', None), 180 | 'numpy': ('http://docs.scipy.org/doc/numpy/', None), 181 | 'pytorch': ('http://pytorch.org/docs/master/', None), 182 | } 183 | -------------------------------------------------------------------------------- /docs/examples/training_id.rst: -------------------------------------------------------------------------------- 1 | ============================ 2 | Training Identification Nets 3 | ============================ 4 | 5 | This example will present how to train nets with identification loss on popular 6 | datasets. 7 | 8 | The objective of training an identification net is to learn good feature 9 | representation for persons. 
41 | .. _training-options:
42 | 
43 | ----------------
44 | Training Options
45 | ----------------
46 | 
47 | Many training options are available through command-line arguments. See all of
48 | them with ``python examples/softmax_loss.py -h``. Here we elaborate on several
49 | commonly used options.
50 | 
51 | .. _data-options:
52 | 
53 | ^^^^^^^^
54 | Datasets
55 | ^^^^^^^^
56 | 
57 | Specify the dataset by ``-d name``, where ``name`` can currently be one of
58 | ``cuhk03``, ``cuhk01``, ``market1501``, ``dukemtmc``, and ``viper``. Some
59 | datasets cannot be downloaded automatically; for them, running the script will
60 | raise an error with a link to the dataset. One may need to download it manually
61 | and put it in the directory indicated by the error message.
62 | 
63 | .. _model-options:
64 | 
65 | ^^^^^^^^^^^^^^^^^^^
66 | Model Architectures
67 | ^^^^^^^^^^^^^^^^^^^
68 | 
69 | Specify the model architecture by ``-a name``, where ``name`` can currently be
70 | one of ``resnet18``, ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``,
71 | and ``inception``. For ``resnet*``, the scripts will download an ImageNet
72 | pretrained model automatically and finetune from it. For ``inception``, the
73 | scripts train the net from scratch.
74 | 
75 | .. _gpu-options:
76 | 
77 | ^^^^^^^^^^^^^^^^^^^^^^^^
78 | Multi-GPU and Batch Size
79 | ^^^^^^^^^^^^^^^^^^^^^^^^
80 | 
81 | All the examples support data-parallel training on multiple GPUs. By default,
82 | the program will use all the GPUs listed in ``nvidia-smi``. To control which
83 | GPUs are used, one needs to set the environment variable
84 | ``CUDA_VISIBLE_DEVICES`` before running the Python script. For example,
85 | 
86 | .. code-block:: shell
87 | 
88 |     # 4 GPUs, with effective batch size of 256
89 |     CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/softmax_loss.py -d viper -b 256 --lr 0.1 ...
90 | 
91 |     # 1 GPU, reduce the batch size to 64, lr to 0.025
92 |     CUDA_VISIBLE_DEVICES=0 python examples/softmax_loss.py -d viper -b 64 --lr 0.025 ...
93 | 
94 | Note that the effective batch size specified by the ``-b`` option is divided
95 | evenly across the GPUs automatically. For example, 4 GPUs with ``-b 256`` will
96 | have 64 minibatch samples on each one.
97 | 
98 | In the second command above, we reduce the batch size and the initial learning
99 | rate to 1/4, in order to adapt the original 4-GPU setting to a single GPU.
100 | 
101 | .. _resume-options:
102 | 
103 | ^^^^^^^^^^^^^^^^^^^^^^^
104 | Resume from Checkpoints
105 | ^^^^^^^^^^^^^^^^^^^^^^^
106 | 
107 | After each training epoch, the script saves the latest ``checkpoint.pth.tar``
108 | in the specified logs directory, and updates ``model_best.pth.tar`` whenever
109 | the model achieves the best validation performance so far. To resume from a
110 | checkpoint, just run the script with ``--resume /path/to/checkpoint.pth.tar``.
111 | 
112 | .. _eval-options:
113 | 
114 | ^^^^^^^^^^^^^^^^^^^^^^^^
115 | Evaluate a Trained Model
116 | ^^^^^^^^^^^^^^^^^^^^^^^^
117 | 
118 | To evaluate a trained model, just run the script with ``--resume
119 | /path/to/model_best.pth.tar --evaluate``. Note that different evaluation
120 | metrics, especially different versions of CMC, can lead to drastically
121 | different numbers.
122 | 
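For instance, applied to the head-first example above, resuming and evaluating
might look like the following (a sketch, assuming the logs directory used
earlier; adjust the paths to your own setup):

.. code-block:: shell

    # Continue training from the latest checkpoint
    python examples/softmax_loss.py -d viper -b 64 -j 2 -a inception \
        --logs-dir logs/softmax-loss/viper-inception \
        --resume logs/softmax-loss/viper-inception/checkpoint.pth.tar

    # Evaluate the best model on the validation and test sets
    python examples/softmax_loss.py -d viper -b 64 -j 2 -a inception \
        --resume logs/softmax-loss/viper-inception/model_best.pth.tar --evaluate
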
123 | 
124 | .. _tips-and-tricks:
125 | 
126 | ---------------
127 | Tips and Tricks
128 | ---------------
129 | 
130 | Training a baseline network can be tricky. Many options and parameters can
131 | (significantly) affect the reported performance numbers. Here we list some tips
132 | and tricks for experiments.
133 | 
134 | Combine train and val
135 |     One can first use separate training and validation sets to tune the
136 |     hyperparameters, then fix the hyperparameters and combine both sets
137 |     together to train a final model. This can be done by appending the option
138 |     ``--combine-trainval``, and can lead to much better performance on the
139 |     test set.
140 | 
141 | Input size
142 |     A larger input image size can improve performance, depending on the network
143 |     architecture. You may specify it with ``--height`` and ``--width``. By
144 |     default, we use ``256x128`` for ``resnet*`` and ``144x56`` for ``inception``.
145 | 
146 | Multi-scale multi-crop test
147 |     Using multiple scales and crops at test time usually yields a performance
148 |     gain, but it slows down inference significantly. We have not implemented
149 |     this yet.
150 | 
151 | Classifier initialization for softmax cross entropy loss
152 |     We found that initializing the softmax classifier weights from a normal
153 |     distribution with ``std=0.001`` generally leads to better performance. It is
154 |     also important to use a larger learning rate for the classifier when the
155 |     underlying CNN is already pretrained.
156 | 
157 | 
158 | ----------
159 | References
160 | ----------
161 | 
162 | .. [zheng2016person] L. Zheng, Y. Yang, and A.G. Hauptmann. Person Re-identification: Past, Present and Future. *arXiv:1610.02984*, 2016.
163 | .. [xiao2016learning] T. Xiao, H. Li, W. Ouyang, and X. Wang. Learning Deep Feature Representations with Domain Guided Dropout for Person Re-identification. In *CVPR*, 2016.
164 | .. [xiaoli2017joint] T. Xiao\*, S. Li\*, B. Wang, L. Lin, and X. Wang. Joint Detection and Identification Feature Learning for Person Search. In *CVPR*, 2017.
165 | .. [hermans2017in] A. Hermans, L. Beyer, and B. Leibe. In Defense of the Triplet Loss for Person Re-Identification. *arXiv:1703.07737*, 2017.
-------------------------------------------------------------------------------- /examples/softmax_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import sys 7 | import torch 8 | from torch import nn 9 | from torch.backends import cudnn 10 | from torch.utils.data import DataLoader 11 | 12 | from reid import datasets 13 | from reid import models 14 | from reid.dist_metric import DistanceMetric 15 | from reid.trainers import Trainer 16 | from reid.evaluators import Evaluator 17 | from reid.utils.data import transforms as T 18 | from reid.utils.data.preprocessor import Preprocessor 19 | from reid.utils.logging import Logger 20 | from reid.utils.serialization import load_checkpoint, save_checkpoint 21 | 22 | 23 | def get_data(name, split_id, data_dir, height, width, batch_size, workers, 24 | combine_trainval): 25 | root = osp.join(data_dir, name) 26 | 27 | dataset = datasets.create(name, root, split_id=split_id) 28 | 29 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 30 | std=[0.229, 0.224, 0.225]) 31 | 32 | train_set = dataset.trainval if combine_trainval else dataset.train 33 | num_classes = (dataset.num_trainval_ids if combine_trainval 34 | else dataset.num_train_ids) 35 | 36 | train_transformer = T.Compose([ 37 | T.RandomSizedRectCrop(height, width), 38 | T.RandomHorizontalFlip(), 39 | T.ToTensor(), 40 | normalizer, 41 | ]) 42 | 43 | test_transformer = T.Compose([ 44 | T.RectScale(height, width), 45 | T.ToTensor(), 46 | normalizer, 47 | ]) 48 | 49 | train_loader = DataLoader( 50 | Preprocessor(train_set, root=dataset.images_dir, 51 | transform=train_transformer), 52 | batch_size=batch_size, num_workers=workers, 53 | shuffle=True, pin_memory=True, drop_last=True) 54 | 55 | val_loader = DataLoader( 56 | Preprocessor(dataset.val, root=dataset.images_dir, 57 | transform=test_transformer), 58 | batch_size=batch_size, num_workers=workers, 59 | shuffle=False, pin_memory=True) 60 | 61 | test_loader = DataLoader( 62 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)), 63 | root=dataset.images_dir, transform=test_transformer), 64 | batch_size=batch_size, num_workers=workers, 65 | shuffle=False, pin_memory=True) 66 | 67 | return dataset, num_classes, train_loader, val_loader, test_loader 68 | 69 | 70 | def main(args): 71 | np.random.seed(args.seed) 72 | torch.manual_seed(args.seed) 73 | cudnn.benchmark = True 74 | 75 | # Redirect print to both console and log file 76 | if not args.evaluate: 77 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 78 | 79 | # Create data loaders 80 | if args.height is None or args.width is None: 81 | args.height, args.width = (144, 56) if args.arch == 'inception' else \ 82 | (256, 128) 83 | dataset, num_classes, train_loader, val_loader, test_loader = \ 84 | get_data(args.dataset, args.split, args.data_dir, args.height, 85 | args.width, args.batch_size, args.workers, 86 | args.combine_trainval) 87 | 88 | # Create model 89 | model = models.create(args.arch, num_features=args.features, 90 | dropout=args.dropout, num_classes=num_classes) 91 | 92 | # Load from checkpoint 93 | start_epoch = best_top1 = 0 94 | if args.resume: 95 | checkpoint = load_checkpoint(args.resume) 96 | model.load_state_dict(checkpoint['state_dict']) 97 | start_epoch = checkpoint['epoch'] 98 | best_top1 = checkpoint['best_top1'] 99 | print("=> Start epoch {} best top1 {:.1%}" 100 | .format(start_epoch, 
best_top1)) 101 | model = nn.DataParallel(model).cuda() 102 | 103 | # Distance metric 104 | metric = DistanceMetric(algorithm=args.dist_metric) 105 | 106 | # Evaluator 107 | evaluator = Evaluator(model) 108 | if args.evaluate: 109 | metric.train(model, train_loader) 110 | print("Validation:") 111 | evaluator.evaluate(val_loader, dataset.val, dataset.val, metric) 112 | print("Test:") 113 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric) 114 | return 115 | 116 | # Criterion 117 | criterion = nn.CrossEntropyLoss().cuda() 118 | 119 | # Optimizer 120 | if hasattr(model.module, 'base'): 121 | base_param_ids = set(map(id, model.module.base.parameters())) 122 | new_params = [p for p in model.parameters() if 123 | id(p) not in base_param_ids] 124 | param_groups = [ 125 | {'params': model.module.base.parameters(), 'lr_mult': 0.1}, 126 | {'params': new_params, 'lr_mult': 1.0}] 127 | else: 128 | param_groups = model.parameters() 129 | optimizer = torch.optim.SGD(param_groups, lr=args.lr, 130 | momentum=args.momentum, 131 | weight_decay=args.weight_decay, 132 | nesterov=True) 133 | 134 | # Trainer 135 | trainer = Trainer(model, criterion) 136 | 137 | # Schedule learning rate 138 | def adjust_lr(epoch): 139 | step_size = 60 if args.arch == 'inception' else 40 140 | lr = args.lr * (0.1 ** (epoch // step_size)) 141 | for g in optimizer.param_groups: 142 | g['lr'] = lr * g.get('lr_mult', 1) 143 | 144 | # Start training 145 | for epoch in range(start_epoch, args.epochs): 146 | adjust_lr(epoch) 147 | trainer.train(epoch, train_loader, optimizer) 148 | if epoch < args.start_save: 149 | continue 150 | top1 = evaluator.evaluate(val_loader, dataset.val, dataset.val) 151 | 152 | is_best = top1 > best_top1 153 | best_top1 = max(top1, best_top1) 154 | save_checkpoint({ 155 | 'state_dict': model.module.state_dict(), 156 | 'epoch': epoch + 1, 157 | 'best_top1': best_top1, 158 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 159 | 160 | print('\n * Finished epoch {:3d} top1: {:5.1%} best: {:5.1%}{}\n'. 
161 | format(epoch, top1, best_top1, ' *' if is_best else '')) 162 | 163 | # Final test 164 | print('Test with best model:') 165 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar')) 166 | model.module.load_state_dict(checkpoint['state_dict']) 167 | metric.train(model, train_loader) 168 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric) 169 | 170 | 171 | if __name__ == '__main__': 172 | parser = argparse.ArgumentParser(description="Softmax loss classification") 173 | # data 174 | parser.add_argument('-d', '--dataset', type=str, default='cuhk03', 175 | choices=datasets.names()) 176 | parser.add_argument('-b', '--batch-size', type=int, default=256) 177 | parser.add_argument('-j', '--workers', type=int, default=4) 178 | parser.add_argument('--split', type=int, default=0) 179 | parser.add_argument('--height', type=int, 180 | help="input height, default: 256 for resnet*, " 181 | "144 for inception") 182 | parser.add_argument('--width', type=int, 183 | help="input width, default: 128 for resnet*, " 184 | "56 for inception") 185 | parser.add_argument('--combine-trainval', action='store_true', 186 | help="train and val sets together for training, " 187 | "val set alone for validation") 188 | # model 189 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 190 | choices=models.names()) 191 | parser.add_argument('--features', type=int, default=128) 192 | parser.add_argument('--dropout', type=float, default=0.5) 193 | # optimizer 194 | parser.add_argument('--lr', type=float, default=0.1, 195 | help="learning rate of new parameters, for pretrained " 196 | "parameters it is 10 times smaller than this") 197 | parser.add_argument('--momentum', type=float, default=0.9) 198 | parser.add_argument('--weight-decay', type=float, default=5e-4) 199 | # training configs 200 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 201 | parser.add_argument('--evaluate', action='store_true', 202 | help="evaluation only") 203 | parser.add_argument('--epochs', type=int, default=50) 204 | parser.add_argument('--start_save', type=int, default=0, 205 | help="start saving checkpoints after specific epoch") 206 | parser.add_argument('--seed', type=int, default=1) 207 | parser.add_argument('--print-freq', type=int, default=1) 208 | # metric learning 209 | parser.add_argument('--dist-metric', type=str, default='euclidean', 210 | choices=['euclidean', 'kissme']) 211 | # misc 212 | working_dir = osp.dirname(osp.abspath(__file__)) 213 | parser.add_argument('--data-dir', type=str, metavar='PATH', 214 | default=osp.join(working_dir, 'data')) 215 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 216 | default=osp.join(working_dir, 'logs')) 217 | main(parser.parse_args()) 218 | -------------------------------------------------------------------------------- /examples/triplet_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import sys 7 | import torch 8 | from torch import nn 9 | from torch.backends import cudnn 10 | from torch.utils.data import DataLoader 11 | 12 | from reid import datasets 13 | from reid import models 14 | from reid.dist_metric import DistanceMetric 15 | from reid.loss import TripletLoss 16 | from reid.trainers import Trainer 17 | from reid.evaluators import Evaluator 18 | from reid.utils.data import transforms as T 19 | from 
reid.utils.data.preprocessor import Preprocessor 20 | from reid.utils.data.sampler import RandomIdentitySampler 21 | from reid.utils.logging import Logger 22 | from reid.utils.serialization import load_checkpoint, save_checkpoint 23 | 24 | 25 | def get_data(name, split_id, data_dir, height, width, batch_size, num_instances, 26 | workers, combine_trainval): 27 | root = osp.join(data_dir, name) 28 | 29 | dataset = datasets.create(name, root, split_id=split_id) 30 | 31 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 32 | std=[0.229, 0.224, 0.225]) 33 | 34 | train_set = dataset.trainval if combine_trainval else dataset.train 35 | num_classes = (dataset.num_trainval_ids if combine_trainval 36 | else dataset.num_train_ids) 37 | 38 | train_transformer = T.Compose([ 39 | T.RandomSizedRectCrop(height, width), 40 | T.RandomHorizontalFlip(), 41 | T.ToTensor(), 42 | normalizer, 43 | ]) 44 | 45 | test_transformer = T.Compose([ 46 | T.RectScale(height, width), 47 | T.ToTensor(), 48 | normalizer, 49 | ]) 50 | 51 | train_loader = DataLoader( 52 | Preprocessor(train_set, root=dataset.images_dir, 53 | transform=train_transformer), 54 | batch_size=batch_size, num_workers=workers, 55 | sampler=RandomIdentitySampler(train_set, num_instances), 56 | pin_memory=True, drop_last=True) 57 | 58 | val_loader = DataLoader( 59 | Preprocessor(dataset.val, root=dataset.images_dir, 60 | transform=test_transformer), 61 | batch_size=batch_size, num_workers=workers, 62 | shuffle=False, pin_memory=True) 63 | 64 | test_loader = DataLoader( 65 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)), 66 | root=dataset.images_dir, transform=test_transformer), 67 | batch_size=batch_size, num_workers=workers, 68 | shuffle=False, pin_memory=True) 69 | 70 | return dataset, num_classes, train_loader, val_loader, test_loader 71 | 72 | 73 | def main(args): 74 | np.random.seed(args.seed) 75 | torch.manual_seed(args.seed) 76 | cudnn.benchmark = True 77 | 78 | # Redirect print to both console and log file 79 | if not args.evaluate: 80 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt')) 81 | 82 | # Create data loaders 83 | assert args.num_instances > 1, "num_instances should be greater than 1" 84 | assert args.batch_size % args.num_instances == 0, \ 85 | 'num_instances should divide batch_size' 86 | if args.height is None or args.width is None: 87 | args.height, args.width = (144, 56) if args.arch == 'inception' else \ 88 | (256, 128) 89 | dataset, num_classes, train_loader, val_loader, test_loader = \ 90 | get_data(args.dataset, args.split, args.data_dir, args.height, 91 | args.width, args.batch_size, args.num_instances, args.workers, 92 | args.combine_trainval) 93 | 94 | # Create model 95 | # Hacking here to let the classifier be the last feature embedding layer 96 | # Net structure: avgpool -> FC(1024) -> FC(args.features) 97 | model = models.create(args.arch, num_features=1024, 98 | dropout=args.dropout, num_classes=args.features) 99 | 100 | # Load from checkpoint 101 | start_epoch = best_top1 = 0 102 | if args.resume: 103 | checkpoint = load_checkpoint(args.resume) 104 | model.load_state_dict(checkpoint['state_dict']) 105 | start_epoch = checkpoint['epoch'] 106 | best_top1 = checkpoint['best_top1'] 107 | print("=> Start epoch {} best top1 {:.1%}" 108 | .format(start_epoch, best_top1)) 109 | model = nn.DataParallel(model).cuda() 110 | 111 | # Distance metric 112 | metric = DistanceMetric(algorithm=args.dist_metric) 113 | 114 | # Evaluator 115 | evaluator = Evaluator(model) 116 | if args.evaluate: 117 | 
metric.train(model, train_loader)
118 |         print("Validation:")
119 |         evaluator.evaluate(val_loader, dataset.val, dataset.val, metric)
120 |         print("Test:")
121 |         evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
122 |         return
123 | 
124 |     # Criterion
125 |     criterion = TripletLoss(margin=args.margin).cuda()
126 | 
127 |     # Optimizer
128 |     optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
129 |                                  weight_decay=args.weight_decay)
130 | 
131 |     # Trainer
132 |     trainer = Trainer(model, criterion)
133 | 
134 |     # Schedule learning rate
135 |     def adjust_lr(epoch):
136 |         lr = args.lr if epoch <= 100 else \
137 |             args.lr * (0.001 ** ((epoch - 100) / 50.0))
138 |         for g in optimizer.param_groups:
139 |             g['lr'] = lr * g.get('lr_mult', 1)
140 | 
141 |     # Start training
142 |     for epoch in range(start_epoch, args.epochs):
143 |         adjust_lr(epoch)
144 |         trainer.train(epoch, train_loader, optimizer)
145 |         if epoch < args.start_save:
146 |             continue
147 |         top1 = evaluator.evaluate(val_loader, dataset.val, dataset.val)
148 | 
149 |         is_best = top1 > best_top1
150 |         best_top1 = max(top1, best_top1)
151 |         save_checkpoint({
152 |             'state_dict': model.module.state_dict(),
153 |             'epoch': epoch + 1,
154 |             'best_top1': best_top1,
155 |         }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))
156 | 
157 |         print('\n * Finished epoch {:3d} top1: {:5.1%} best: {:5.1%}{}\n'.
158 |               format(epoch, top1, best_top1, ' *' if is_best else ''))
159 | 
160 |     # Final test
161 |     print('Test with best model:')
162 |     checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
163 |     model.module.load_state_dict(checkpoint['state_dict'])
164 |     metric.train(model, train_loader)
165 |     evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
166 | 
167 | 
168 | if __name__ == '__main__':
169 |     parser = argparse.ArgumentParser(description="Triplet loss classification")
170 |     # data
171 |     parser.add_argument('-d', '--dataset', type=str, default='cuhk03',
172 |                         choices=datasets.names())
173 |     parser.add_argument('-b', '--batch-size', type=int, default=256)
174 |     parser.add_argument('-j', '--workers', type=int, default=4)
175 |     parser.add_argument('--split', type=int, default=0)
176 |     parser.add_argument('--height', type=int,
177 |                         help="input height, default: 256 for resnet*, "
178 |                              "144 for inception")
179 |     parser.add_argument('--width', type=int,
180 |                         help="input width, default: 128 for resnet*, "
181 |                              "56 for inception")
182 |     parser.add_argument('--combine-trainval', action='store_true',
183 |                         help="train and val sets together for training, "
184 |                              "val set alone for validation")
185 |     parser.add_argument('--num-instances', type=int, default=4,
186 |                         help="each minibatch consists of "
187 |                              "(batch_size // num_instances) identities, and "
188 |                              "each identity has num_instances instances, "
189 |                              "default: 4")
190 |     # model
191 |     parser.add_argument('-a', '--arch', type=str, default='resnet50',
192 |                         choices=models.names())
193 |     parser.add_argument('--features', type=int, default=128)
194 |     parser.add_argument('--dropout', type=float, default=0)
195 |     # loss
196 |     parser.add_argument('--margin', type=float, default=0.5,
197 |                         help="margin of the triplet loss, default: 0.5")
198 |     # optimizer
199 |     parser.add_argument('--lr', type=float, default=0.0002,
200 |                         help="learning rate of all parameters")
201 |     parser.add_argument('--weight-decay', type=float, default=5e-4)
202 |     # training configs
203 |     parser.add_argument('--resume', type=str, default='', metavar='PATH')
204 | 
parser.add_argument('--evaluate', action='store_true', 205 | help="evaluation only") 206 | parser.add_argument('--epochs', type=int, default=150) 207 | parser.add_argument('--start_save', type=int, default=0, 208 | help="start saving checkpoints after specific epoch") 209 | parser.add_argument('--seed', type=int, default=1) 210 | parser.add_argument('--print-freq', type=int, default=1) 211 | # metric learning 212 | parser.add_argument('--dist-metric', type=str, default='euclidean', 213 | choices=['euclidean', 'kissme']) 214 | # misc 215 | working_dir = osp.dirname(osp.abspath(__file__)) 216 | parser.add_argument('--data-dir', type=str, metavar='PATH', 217 | default=osp.join(working_dir, 'data')) 218 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 219 | default=osp.join(working_dir, 'logs')) 220 | main(parser.parse_args()) 221 | -------------------------------------------------------------------------------- /examples/oim_loss.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, absolute_import 2 | import argparse 3 | import os.path as osp 4 | 5 | import numpy as np 6 | import sys 7 | import torch 8 | from torch import nn 9 | from torch.backends import cudnn 10 | from torch.utils.data import DataLoader 11 | 12 | from reid import datasets 13 | from reid import models 14 | from reid.dist_metric import DistanceMetric 15 | from reid.loss import OIMLoss 16 | from reid.trainers import Trainer 17 | from reid.evaluators import Evaluator 18 | from reid.utils.data import transforms as T 19 | from reid.utils.data.preprocessor import Preprocessor 20 | from reid.utils.logging import Logger 21 | from reid.utils.serialization import load_checkpoint, save_checkpoint 22 | 23 | 24 | def get_data(name, split_id, data_dir, height, width, batch_size, workers, 25 | combine_trainval): 26 | root = osp.join(data_dir, name) 27 | 28 | dataset = datasets.create(name, root, split_id=split_id) 29 | 30 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406], 31 | std=[0.229, 0.224, 0.225]) 32 | 33 | train_set = dataset.trainval if combine_trainval else dataset.train 34 | num_classes = (dataset.num_trainval_ids if combine_trainval 35 | else dataset.num_train_ids) 36 | 37 | train_transformer = T.Compose([ 38 | T.RandomSizedRectCrop(height, width), 39 | T.RandomHorizontalFlip(), 40 | T.ToTensor(), 41 | normalizer, 42 | ]) 43 | 44 | test_transformer = T.Compose([ 45 | T.RectScale(height, width), 46 | T.ToTensor(), 47 | normalizer, 48 | ]) 49 | 50 | train_loader = DataLoader( 51 | Preprocessor(train_set, root=dataset.images_dir, 52 | transform=train_transformer), 53 | batch_size=batch_size, num_workers=workers, 54 | shuffle=True, pin_memory=True, drop_last=True) 55 | 56 | val_loader = DataLoader( 57 | Preprocessor(dataset.val, root=dataset.images_dir, 58 | transform=test_transformer), 59 | batch_size=batch_size, num_workers=workers, 60 | shuffle=False, pin_memory=True) 61 | 62 | test_loader = DataLoader( 63 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)), 64 | root=dataset.images_dir, transform=test_transformer), 65 | batch_size=batch_size, num_workers=workers, 66 | shuffle=False, pin_memory=True) 67 | 68 | return dataset, num_classes, train_loader, val_loader, test_loader 69 | 70 | 71 | def main(args): 72 | np.random.seed(args.seed) 73 | torch.manual_seed(args.seed) 74 | cudnn.benchmark = True 75 | 76 | # Redirect print to both console and log file 77 | if not args.evaluate: 78 | sys.stdout = Logger(osp.join(args.logs_dir, 
'log.txt')) 79 | 80 | # Create data loaders 81 | if args.height is None or args.width is None: 82 | args.height, args.width = (144, 56) if args.arch == 'inception' else \ 83 | (256, 128) 84 | dataset, num_classes, train_loader, val_loader, test_loader = \ 85 | get_data(args.dataset, args.split, args.data_dir, args.height, 86 | args.width, args.batch_size, args.workers, 87 | args.combine_trainval) 88 | 89 | # Create model 90 | model = models.create(args.arch, num_features=args.features, norm=True, 91 | dropout=args.dropout) 92 | 93 | # Load from checkpoint 94 | start_epoch = best_top1 = 0 95 | if args.resume: 96 | checkpoint = load_checkpoint(args.resume) 97 | model.load_state_dict(checkpoint['state_dict']) 98 | start_epoch = checkpoint['epoch'] 99 | best_top1 = checkpoint['best_top1'] 100 | print("=> Start epoch {} best top1 {:.1%}" 101 | .format(start_epoch, best_top1)) 102 | model = nn.DataParallel(model).cuda() 103 | 104 | # Distance metric 105 | metric = DistanceMetric(algorithm=args.dist_metric) 106 | 107 | # Evaluator 108 | evaluator = Evaluator(model) 109 | if args.evaluate: 110 | metric.train(model, train_loader) 111 | print("Validation:") 112 | evaluator.evaluate(val_loader, dataset.val, dataset.val, metric) 113 | print("Test:") 114 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric) 115 | return 116 | 117 | # Criterion 118 | criterion = OIMLoss(model.module.num_features, num_classes, 119 | scalar=args.oim_scalar, 120 | momentum=args.oim_momentum).cuda() 121 | 122 | # Optimizer 123 | if hasattr(model.module, 'base'): 124 | base_param_ids = set(map(id, model.module.base.parameters())) 125 | new_params = [p for p in model.parameters() if 126 | id(p) not in base_param_ids] 127 | param_groups = [ 128 | {'params': model.module.base.parameters(), 'lr_mult': 0.1}, 129 | {'params': new_params, 'lr_mult': 1.0}] 130 | else: 131 | param_groups = model.parameters() 132 | optimizer = torch.optim.SGD(param_groups, lr=args.lr, 133 | momentum=args.momentum, 134 | weight_decay=args.weight_decay, 135 | nesterov=True) 136 | 137 | # Trainer 138 | trainer = Trainer(model, criterion) 139 | 140 | # Schedule learning rate 141 | def adjust_lr(epoch): 142 | step_size = 60 if args.arch == 'inception' else 40 143 | lr = args.lr * (0.1 ** (epoch // step_size)) 144 | for g in optimizer.param_groups: 145 | g['lr'] = lr * g.get('lr_mult', 1) 146 | 147 | # Start training 148 | for epoch in range(start_epoch, args.epochs): 149 | adjust_lr(epoch) 150 | trainer.train(epoch, train_loader, optimizer) 151 | if epoch < args.start_save: 152 | continue 153 | top1 = evaluator.evaluate(val_loader, dataset.val, dataset.val) 154 | 155 | is_best = top1 > best_top1 156 | best_top1 = max(top1, best_top1) 157 | save_checkpoint({ 158 | 'state_dict': model.module.state_dict(), 159 | 'epoch': epoch + 1, 160 | 'best_top1': best_top1, 161 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar')) 162 | 163 | print('\n * Finished epoch {:3d} top1: {:5.1%} best: {:5.1%}{}\n'. 
164 | format(epoch, top1, best_top1, ' *' if is_best else '')) 165 | 166 | # Final test 167 | print('Test with best model:') 168 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar')) 169 | model.module.load_state_dict(checkpoint['state_dict']) 170 | metric.train(model, train_loader) 171 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric) 172 | 173 | 174 | if __name__ == '__main__': 175 | parser = argparse.ArgumentParser(description="OIM loss") 176 | # data 177 | parser.add_argument('-d', '--dataset', type=str, default='cuhk03', 178 | choices=datasets.names()) 179 | parser.add_argument('-b', '--batch-size', type=int, default=256) 180 | parser.add_argument('-j', '--workers', type=int, default=4) 181 | parser.add_argument('--split', type=int, default=0) 182 | parser.add_argument('--height', type=int, 183 | help="input height, default: 256 for resnet*, " 184 | "144 for inception") 185 | parser.add_argument('--width', type=int, 186 | help="input width, default: 128 for resnet*, " 187 | "56 for inception") 188 | parser.add_argument('--combine-trainval', action='store_true', 189 | help="train and val sets together for training, " 190 | "val set alone for validation") 191 | # model 192 | parser.add_argument('-a', '--arch', type=str, default='resnet50', 193 | choices=models.names()) 194 | parser.add_argument('--features', type=int, default=128) 195 | parser.add_argument('--dropout', type=float, default=0.5) 196 | # loss 197 | parser.add_argument('--oim-scalar', type=float, default=30, 198 | help='reciprocal of the temperature in OIM loss') 199 | parser.add_argument('--oim-momentum', type=float, default=0.5, 200 | help='momentum for updating the LUT in OIM loss') 201 | # optimizer 202 | parser.add_argument('--lr', type=float, default=0.1, 203 | help="learning rate of new parameters, for pretrained " 204 | "parameters it is 10 times smaller than this") 205 | parser.add_argument('--momentum', type=float, default=0.9) 206 | parser.add_argument('--weight-decay', type=float, default=5e-4) 207 | # training configs 208 | parser.add_argument('--resume', type=str, default='', metavar='PATH') 209 | parser.add_argument('--evaluate', action='store_true', 210 | help="evaluation only") 211 | parser.add_argument('--epochs', type=int, default=50) 212 | parser.add_argument('--start_save', type=int, default=0, 213 | help="start saving checkpoints after specific epoch") 214 | parser.add_argument('--seed', type=int, default=1) 215 | parser.add_argument('--print-freq', type=int, default=1) 216 | # metric learning 217 | parser.add_argument('--dist-metric', type=str, default='euclidean', 218 | choices=['euclidean', 'kissme']) 219 | # misc 220 | working_dir = osp.dirname(osp.abspath(__file__)) 221 | parser.add_argument('--data-dir', type=str, metavar='PATH', 222 | default=osp.join(working_dir, 'data')) 223 | parser.add_argument('--logs-dir', type=str, metavar='PATH', 224 | default=osp.join(working_dir, 'logs')) 225 | main(parser.parse_args()) 226 | --------------------------------------------------------------------------------