├── docs
│   ├── requirements.txt
│   ├── figures
│   │   ├── structure.png
│   │   └── data-loading.png
│   ├── dist_metric.rst
│   ├── trainers.rst
│   ├── evaluators.rst
│   ├── datasets.rst
│   ├── models.rst
│   ├── Makefile
│   ├── make.bat
│   ├── index.rst
│   ├── notes
│   │   ├── overview.rst
│   │   ├── evaluation_metrics.rst
│   │   └── data_modules.rst
│   ├── _static
│   │   └── css
│   │       └── openreid_theme.css
│   ├── examples
│   │   ├── benchmarks.rst
│   │   └── training_id.rst
│   └── conf.py
├── setup.cfg
├── reid
│   ├── utils
│   │   ├── data
│   │   │   ├── __init__.py
│   │   │   ├── preprocessor.py
│   │   │   ├── sampler.py
│   │   │   ├── transforms.py
│   │   │   └── dataset.py
│   │   ├── osutils.py
│   │   ├── meters.py
│   │   ├── __init__.py
│   │   ├── logging.py
│   │   └── serialization.py
│   ├── evaluation_metrics
│   │   ├── __init__.py
│   │   ├── classification.py
│   │   └── ranking.py
│   ├── loss
│   │   ├── __init__.py
│   │   ├── triplet.py
│   │   └── oim.py
│   ├── feature_extraction
│   │   ├── __init__.py
│   │   ├── cnn.py
│   │   └── database.py
│   ├── __init__.py
│   ├── metric_learning
│   │   ├── euclidean.py
│   │   ├── __init__.py
│   │   └── kissme.py
│   ├── dist_metric.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── cuhk01.py
│   │   ├── viper.py
│   │   ├── market1501.py
│   │   ├── dukemtmc.py
│   │   └── cuhk03.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── resnet.py
│   │   └── inception.py
│   ├── trainers.py
│   └── evaluators.py
├── setup.py
├── test
│   ├── models
│   │   └── test_inception.py
│   ├── loss
│   │   └── test_oim.py
│   ├── feature_extraction
│   │   └── test_database.py
│   ├── datasets
│   │   ├── test_cuhk01.py
│   │   ├── test_viper.py
│   │   ├── test_cuhk03.py
│   │   ├── test_dukemtmc.py
│   │   └── test_market1501.py
│   ├── utils
│   │   └── data
│   │       └── test_preprocessor.py
│   └── evaluation_metrics
│       └── test_cmc.py
├── README.md
├── LICENSE
├── .gitignore
└── examples
    ├── softmax_loss.py
    ├── triplet_loss.py
    └── oim_loss.py
/docs/requirements.txt:
--------------------------------------------------------------------------------
1 | sphinx
2 | sphinx-rtd-theme
3 |
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [metadata]
2 | description-file = README.md
3 |
--------------------------------------------------------------------------------
/docs/figures/structure.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cysu/open-reid/HEAD/docs/figures/structure.png
--------------------------------------------------------------------------------
/docs/figures/data-loading.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Cysu/open-reid/HEAD/docs/figures/data-loading.png
--------------------------------------------------------------------------------
/reid/utils/data/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .dataset import Dataset
4 | from .preprocessor import Preprocessor
5 |
--------------------------------------------------------------------------------
/docs/dist_metric.rst:
--------------------------------------------------------------------------------
1 | ================
2 | reid.dist_metric
3 | ================
4 |
5 | .. automodule:: reid.dist_metric
6 | .. currentmodule:: reid.dist_metric
7 |
8 | .. autoclass:: reid.dist_metric.DistanceMetric
9 | :members:
--------------------------------------------------------------------------------
/reid/evaluation_metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .classification import accuracy
4 | from .ranking import cmc, mean_ap
5 |
6 | __all__ = [
7 | 'accuracy',
8 | 'cmc',
9 | 'mean_ap',
10 | ]
11 |
--------------------------------------------------------------------------------
/reid/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .oim import oim, OIM, OIMLoss
4 | from .triplet import TripletLoss
5 |
6 | __all__ = [
7 | 'oim',
8 | 'OIM',
9 | 'OIMLoss',
10 | 'TripletLoss',
11 | ]
12 |
--------------------------------------------------------------------------------
/docs/trainers.rst:
--------------------------------------------------------------------------------
1 | =============
2 | reid.trainers
3 | =============
4 |
5 | .. automodule:: reid.trainers
6 | .. currentmodule:: reid.trainers
7 |
8 | .. autoclass:: BaseTrainer
9 | :members:
10 |
11 | .. autoclass:: Trainer
12 | :members:
13 |
--------------------------------------------------------------------------------
/reid/feature_extraction/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .cnn import extract_cnn_feature
4 | from .database import FeatureDatabase
5 |
6 | __all__ = [
7 | 'extract_cnn_feature',
8 | 'FeatureDatabase',
9 | ]
10 |
--------------------------------------------------------------------------------
/reid/utils/osutils.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import errno
4 |
5 |
6 | def mkdir_if_missing(dir_path):
7 | try:
8 | os.makedirs(dir_path)
9 | except OSError as e:
10 | if e.errno != errno.EEXIST:
11 | raise
12 |
--------------------------------------------------------------------------------
/docs/evaluators.rst:
--------------------------------------------------------------------------------
1 | ===============
2 | reid.evaluators
3 | ===============
4 |
5 | .. automodule:: reid.evaluators
6 | .. currentmodule:: reid.evaluators
7 |
8 | .. autofunction:: extract_features
9 | .. autofunction:: pairwise_distance
10 | .. autofunction:: evaluate_all
11 | .. autoclass:: Evaluator
12 | :members:
--------------------------------------------------------------------------------
/docs/datasets.rst:
--------------------------------------------------------------------------------
1 | =============
2 | reid.datasets
3 | =============
4 |
5 | .. automodule:: reid.datasets
6 | .. currentmodule:: reid.datasets
7 |
8 | .. autofunction:: create
9 |
10 | .. autoclass:: CUHK01
11 | .. autoclass:: CUHK03
12 | .. autoclass:: DukeMTMC
13 | .. autoclass:: Market1501
14 | .. autoclass:: VIPeR
15 |
--------------------------------------------------------------------------------
/reid/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from . import datasets
4 | from . import evaluation_metrics
5 | from . import feature_extraction
6 | from . import loss
7 | from . import metric_learning
8 | from . import models
9 | from . import utils
10 | from . import dist_metric
11 | from . import evaluators
12 | from . import trainers
13 |
14 | __version__ = '0.2.0'
15 |
--------------------------------------------------------------------------------
/docs/models.rst:
--------------------------------------------------------------------------------
1 | ===========
2 | reid.models
3 | ===========
4 |
5 | .. automodule:: reid.models
6 | .. currentmodule:: reid.models
7 |
8 | .. autofunction:: create
9 | .. autofunction:: inception
10 | .. autofunction:: resnet18
11 | .. autofunction:: resnet34
12 | .. autofunction:: resnet50
13 | .. autofunction:: resnet101
14 | .. autofunction:: resnet152
15 |
16 | .. autoclass:: InceptionNet
17 | .. autoclass:: ResNet
18 |
--------------------------------------------------------------------------------
/reid/metric_learning/euclidean.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import numpy as np
4 | from metric_learn.base_metric import BaseMetricLearner
5 |
6 |
7 | class Euclidean(BaseMetricLearner):
8 | def __init__(self):
9 | self.M_ = None
10 |
11 | def metric(self):
12 | return self.M_
13 |
14 | def fit(self, X):
15 | self.M_ = np.eye(X.shape[1])
16 | self.X_ = X
17 |
18 | def transform(self, X=None):
19 | if X is None:
20 | return self.X_
21 | return X
22 |
--------------------------------------------------------------------------------
/reid/utils/meters.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 |
4 | class AverageMeter(object):
5 | """Computes and stores the average and current value"""
6 |
7 | def __init__(self):
8 | self.val = 0
9 | self.avg = 0
10 | self.sum = 0
11 | self.count = 0
12 |
13 | def reset(self):
14 | self.val = 0
15 | self.avg = 0
16 | self.sum = 0
17 | self.count = 0
18 |
19 | def update(self, val, n=1):
20 | self.val = val
21 | self.sum += val * n
22 | self.count += n
23 | self.avg = self.sum / self.count
24 |
--------------------------------------------------------------------------------
/reid/evaluation_metrics/classification.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from ..utils import to_torch
4 |
5 |
6 | def accuracy(output, target, topk=(1,)):
7 | output, target = to_torch(output), to_torch(target)
8 | maxk = max(topk)
9 | batch_size = target.size(0)
10 |
11 | _, pred = output.topk(maxk, 1, True, True)
12 | pred = pred.t()
13 | correct = pred.eq(target.view(1, -1).expand_as(pred))
14 |
15 | ret = []
16 | for k in topk:
17 | correct_k = correct[:k].view(-1).float().sum(dim=0, keepdim=True)
18 | ret.append(correct_k.mul_(1. / batch_size))
19 | return ret
20 |
--------------------------------------------------------------------------------
/docs/Makefile:
--------------------------------------------------------------------------------
1 | # Minimal makefile for Sphinx documentation
2 | #
3 |
4 | # You can set these variables from the command line.
5 | SPHINXOPTS =
6 | SPHINXBUILD = sphinx-build
7 | SPHINXPROJ = Open-ReID
8 | SOURCEDIR = .
9 | BUILDDIR = _build
10 |
11 | # Put it first so that "make" without argument is like "make help".
12 | help:
13 | @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
14 |
15 | .PHONY: help Makefile
16 |
17 | # Catch-all target: route all unknown targets to Sphinx using the new
18 | # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
19 | %: Makefile
20 | @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
--------------------------------------------------------------------------------
/reid/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 |
5 |
6 | def to_numpy(tensor):
7 | if torch.is_tensor(tensor):
8 | return tensor.cpu().numpy()
9 | elif type(tensor).__module__ != 'numpy':
10 | raise ValueError("Cannot convert {} to numpy array"
11 | .format(type(tensor)))
12 | return tensor
13 |
14 |
15 | def to_torch(ndarray):
16 | if type(ndarray).__module__ == 'numpy':
17 | return torch.from_numpy(ndarray)
18 | elif not torch.is_tensor(ndarray):
19 | raise ValueError("Cannot convert {} to torch tensor"
20 | .format(type(ndarray)))
21 | return ndarray
22 |
--------------------------------------------------------------------------------
/reid/metric_learning/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from metric_learn import (ITML_Supervised, LMNN, LSML_Supervised,
4 | SDML_Supervised, NCA, LFDA, RCA_Supervised)
5 |
6 | from .euclidean import Euclidean
7 | from .kissme import KISSME
8 |
9 | __factory = {
10 | 'euclidean': Euclidean,
11 | 'kissme': KISSME,
12 | 'itml': ITML_Supervised,
13 | 'lmnn': LMNN,
14 | 'lsml': LSML_Supervised,
15 | 'sdml': SDML_Supervised,
16 | 'nca': NCA,
17 | 'lfda': LFDA,
18 | 'rca': RCA_Supervised,
19 | }
20 |
21 |
22 | def get_metric(algorithm, *args, **kwargs):
23 | if algorithm not in __factory:
24 | raise KeyError("Unknown metric:", algorithm)
25 | return __factory[algorithm](*args, **kwargs)
26 |
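27 | 
28 | # Illustrative usage sketch (an assumption-laden example, not part of the API):
29 | #
30 | #     metric = get_metric('kissme')
31 | #     metric.fit(features, labels)  # features: (n, d) ndarray, labels: (n,) ids
32 | #     projected = metric.transform(features)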
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup, find_packages
2 |
3 |
4 | setup(name='open-reid',
5 | version='0.2.0',
6 | description='Deep Learning Library for Person Re-identification',
7 | author='Tong Xiao',
8 | author_email='st.cysu@gmail.com',
9 | url='https://github.com/Cysu/open-reid',
10 | license='MIT',
11 | install_requires=[
12 | 'numpy', 'scipy', 'torch', 'torchvision',
13 | 'six', 'h5py', 'Pillow',
14 | 'scikit-learn', 'metric-learn'],
15 | extras_require={
16 | 'docs': ['sphinx', 'sphinx_rtd_theme'],
17 | },
18 | packages=find_packages(),
19 | keywords=[
20 | 'Person Re-identification',
21 | 'Computer Vision',
22 | 'Deep Learning',
23 | ])
24 |
--------------------------------------------------------------------------------
/test/models/test_inception.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 |
4 | class TestInception(TestCase):
5 | def test_forward(self):
6 | import torch
7 | from torch.autograd import Variable
8 | from reid.models.inception import InceptionNet
9 |
10 | # model = Inception(num_classes=5, num_features=256, dropout=0.5)
11 | # x = Variable(torch.randn(10, 3, 144, 56), requires_grad=False)
12 | # y = model(x)
13 | # self.assertEquals(y.size(), (10, 5))
14 |
15 | model = InceptionNet(num_features=8, norm=True, dropout=0)
16 | x = Variable(torch.randn(10, 3, 144, 56), requires_grad=False)
17 | y = model(x)
18 | self.assertEquals(y.size(), (10, 8))
19 | self.assertEquals(y.norm(2, 1).max(), 1)
20 | self.assertEquals(y.norm(2, 1).min(), 1)
21 |
--------------------------------------------------------------------------------
/test/loss/test_oim.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 |
4 | class TestOIMLoss(TestCase):
5 | def test_forward_backward(self):
6 | import torch
7 | import torch.nn.functional as F
8 | from torch.autograd import Variable
9 | from reid.loss import OIMLoss
10 | criterion = OIMLoss(3, 3, scalar=1.0, size_average=False)
11 | criterion.lut = torch.eye(3)
12 | x = Variable(torch.randn(3, 3), requires_grad=True)
13 | y = Variable(torch.range(0, 2).long())
14 | loss = criterion(x, y)
15 | loss.backward()
16 | probs = F.softmax(x)
17 | grads = probs.data - torch.eye(3)
18 | abs_diff = torch.abs(grads - x.grad.data)
19 | self.assertEquals(torch.log(probs).diag().sum(), -loss)
20 | self.assertTrue(torch.max(abs_diff) < 1e-6)
21 |
--------------------------------------------------------------------------------
/reid/feature_extraction/cnn.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import OrderedDict
3 |
4 | from torch.autograd import Variable
5 |
6 | from ..utils import to_torch
7 |
8 |
9 | def extract_cnn_feature(model, inputs, modules=None):
10 | model.eval()
11 | inputs = to_torch(inputs)
12 | inputs = Variable(inputs, volatile=True)
13 | if modules is None:
14 | outputs = model(inputs)
15 | outputs = outputs.data.cpu()
16 | return outputs
17 | # Register forward hook for each module
18 | outputs = OrderedDict()
19 | handles = []
20 | for m in modules:
21 | outputs[id(m)] = None
22 | def func(m, i, o): outputs[id(m)] = o.data.cpu()
23 | handles.append(m.register_forward_hook(func))
24 | model(inputs)
25 | for h in handles:
26 | h.remove()
27 | return list(outputs.values())
28 |
--------------------------------------------------------------------------------
/test/feature_extraction/test_database.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 | import numpy as np
4 |
5 | from reid.feature_extraction.database import FeatureDatabase
6 |
7 |
8 | class TestFeatureDatabase(TestCase):
9 | def test_all(self):
10 | with FeatureDatabase('/tmp/open-reid/test.h5', 'w') as db:
11 | db['img1'] = np.random.rand(3, 8, 8).astype(np.float32)
12 | db['img2'] = np.arange(10)
13 | db['img2'] = np.arange(10).reshape(2, 5).astype(np.float32)
14 | with FeatureDatabase('/tmp/open-reid/test.h5', 'r') as db:
15 | self.assertTrue('img1' in db)
16 | self.assertTrue('img2' in db)
17 | self.assertEquals(db['img1'].shape, (3, 8, 8))
18 | x = db['img2']
19 | self.assertEquals(x.shape, (2, 5))
20 | self.assertTrue(np.all(x == np.arange(10).reshape(2, 5)))
21 |
22 |
--------------------------------------------------------------------------------
/test/datasets/test_cuhk01.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 |
4 | class TestCUHK01(TestCase):
5 | def test_init(self):
6 | import os.path as osp
7 | from reid.datasets import CUHK01
8 | from reid.utils.serialization import read_json
9 |
10 | root, split_id, num_val = '/tmp/open-reid/cuhk01', 0, 100
11 | dataset = CUHK01(root, split_id=split_id, num_val=num_val, download=True)
12 |
13 | self.assertTrue(osp.isfile(osp.join(root, 'meta.json')))
14 | self.assertTrue(osp.isfile(osp.join(root, 'splits.json')))
15 | meta = read_json(osp.join(root, 'meta.json'))
16 | self.assertEquals(len(meta['identities']), 971)
17 | splits = read_json(osp.join(root, 'splits.json'))
18 | self.assertEquals(len(splits), 10)
19 |
20 | self.assertDictEqual(meta, dataset.meta)
21 | self.assertDictEqual(splits[split_id], dataset.split)
22 |
--------------------------------------------------------------------------------
/test/datasets/test_viper.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 |
4 | class TestVIPeR(TestCase):
5 | def test_init(self):
6 | import os.path as osp
7 | from reid.datasets.viper import VIPeR
8 | from reid.utils.serialization import read_json
9 |
10 | root, split_id, num_val = '/tmp/open-reid/viper', 0, 100
11 | dataset = VIPeR(root, split_id=split_id, num_val=num_val, download=True)
12 |
13 | self.assertTrue(osp.isfile(osp.join(root, 'meta.json')))
14 | self.assertTrue(osp.isfile(osp.join(root, 'splits.json')))
15 | meta = read_json(osp.join(root, 'meta.json'))
16 | self.assertEquals(len(meta['identities']), 632)
17 | splits = read_json(osp.join(root, 'splits.json'))
18 | self.assertEquals(len(splits), 10)
19 |
20 | self.assertDictEqual(meta, dataset.meta)
21 | self.assertDictEqual(splits[split_id], dataset.split)
22 |
--------------------------------------------------------------------------------
/test/datasets/test_cuhk03.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 |
4 | class TestCUHK03(TestCase):
5 | def test_init(self):
6 | import os.path as osp
7 | from reid.datasets.cuhk03 import CUHK03
8 | from reid.utils.serialization import read_json
9 |
10 | root, split_id, num_val = '/tmp/open-reid/cuhk03', 0, 100
11 | dataset = CUHK03(root, split_id=split_id, num_val=num_val, download=True)
12 |
13 | self.assertTrue(osp.isfile(osp.join(root, 'meta.json')))
14 | self.assertTrue(osp.isfile(osp.join(root, 'splits.json')))
15 | meta = read_json(osp.join(root, 'meta.json'))
16 | self.assertEquals(len(meta['identities']), 1467)
17 | splits = read_json(osp.join(root, 'splits.json'))
18 | self.assertEquals(len(splits), 20)
19 |
20 | self.assertDictEqual(meta, dataset.meta)
21 | self.assertDictEqual(splits[split_id], dataset.split)
22 |
--------------------------------------------------------------------------------
/test/datasets/test_dukemtmc.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 |
4 | class TestDukeMTMC(TestCase):
5 | def test_all(self):
6 | import os.path as osp
7 | from reid.datasets import DukeMTMC
8 | from reid.utils.serialization import read_json
9 |
10 | root, split_id, num_val = '/tmp/open-reid/dukemtmc', 0, 100
11 | dataset = DukeMTMC(root, split_id=split_id, num_val=num_val,
12 | download=True)
13 |
14 | self.assertTrue(osp.isfile(osp.join(root, 'meta.json')))
15 | self.assertTrue(osp.isfile(osp.join(root, 'splits.json')))
16 | meta = read_json(osp.join(root, 'meta.json'))
17 | self.assertEquals(len(meta['identities']), 1812)
18 | splits = read_json(osp.join(root, 'splits.json'))
19 | self.assertEquals(len(splits), 1)
20 |
21 | self.assertDictEqual(meta, dataset.meta)
22 | self.assertDictEqual(splits[split_id], dataset.split)
--------------------------------------------------------------------------------
/docs/make.bat:
--------------------------------------------------------------------------------
1 | @ECHO OFF
2 |
3 | pushd %~dp0
4 |
5 | REM Command file for Sphinx documentation
6 |
7 | if "%SPHINXBUILD%" == "" (
8 | set SPHINXBUILD=sphinx-build
9 | )
10 | set SOURCEDIR=.
11 | set BUILDDIR=_build
12 | set SPHINXPROJ=Open-ReID
13 |
14 | if "%1" == "" goto help
15 |
16 | %SPHINXBUILD% >NUL 2>NUL
17 | if errorlevel 9009 (
18 | echo.
19 | echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
20 | echo.installed, then set the SPHINXBUILD environment variable to point
21 | echo.to the full path of the 'sphinx-build' executable. Alternatively you
22 | echo.may add the Sphinx directory to PATH.
23 | echo.
24 | echo.If you don't have Sphinx installed, grab it from
25 | echo.http://sphinx-doc.org/
26 | exit /b 1
27 | )
28 |
29 | %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
30 | goto end
31 |
32 | :help
33 | %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
34 |
35 | :end
36 | popd
37 |
--------------------------------------------------------------------------------
/test/datasets/test_market1501.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 |
4 | class TestMarket1501(TestCase):
5 | def test_init(self):
6 | import os.path as osp
7 | from reid.datasets.market1501 import Market1501
8 | from reid.utils.serialization import read_json
9 |
10 | root, split_id, num_val = '/tmp/open-reid/market1501', 0, 100
11 | dataset = Market1501(root, split_id=split_id, num_val=num_val,
12 | download=True)
13 |
14 | self.assertTrue(osp.isfile(osp.join(root, 'meta.json')))
15 | self.assertTrue(osp.isfile(osp.join(root, 'splits.json')))
16 | meta = read_json(osp.join(root, 'meta.json'))
17 | self.assertEquals(len(meta['identities']), 1502)
18 | splits = read_json(osp.join(root, 'splits.json'))
19 | self.assertEquals(len(splits), 1)
20 |
21 | self.assertDictEqual(meta, dataset.meta)
22 | self.assertDictEqual(splits[split_id], dataset.split)
23 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Open-ReID
2 |
3 | Open-ReID is a lightweight person re-identification library for research
4 | purposes. It aims to provide a uniform interface for different datasets, a full
5 | set of models and evaluation metrics, as well as examples to reproduce (near)
6 | state-of-the-art results.
7 |
8 | ## Installation
9 |
10 | Install [PyTorch](http://pytorch.org/) (version >= 0.2.0). Although we support
11 | both Python 2 and Python 3, we recommend Python 3 for better performance.
12 |
13 | ```shell
14 | git clone https://github.com/Cysu/open-reid.git
15 | cd open-reid
16 | python setup.py install
17 | ```
18 |
19 | ## Examples
20 |
21 | ```shell
22 | python examples/softmax_loss.py -d viper -b 64 -j 2 -a resnet50 --logs-dir logs/softmax-loss/viper-resnet50
23 | ```
24 |
25 | This is just a quick example. The VIPeR dataset may not be large enough to train a deep neural network.
26 |
27 | Check out more [examples](https://cysu.github.io/open-reid/examples/training_id.html)
28 | and [benchmarks](https://cysu.github.io/open-reid/examples/benchmarks.html).
29 |
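30 | ## Quick API sketch
31 | 
32 | The snippet below is a minimal, illustrative sketch of the high-level factories
33 | (`reid.datasets.create` and `reid.models.create`); the class count passed to
34 | the model is a placeholder and should be set to the number of training
35 | identities of your split.
36 | 
37 | ```python
38 | from reid import datasets, models
39 | 
40 | # Uniform interfaces: create a dataset and a model by name
41 | dataset = datasets.create('viper', '/tmp/open-reid/viper', download=True)
42 | model = models.create('resnet50', num_features=128,
43 |                       num_classes=316)  # placeholder identity count
44 | ```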
--------------------------------------------------------------------------------
/reid/utils/logging.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os
3 | import sys
4 |
5 | from .osutils import mkdir_if_missing
6 |
7 |
8 | class Logger(object):
9 | def __init__(self, fpath=None):
10 | self.console = sys.stdout
11 | self.file = None
12 | if fpath is not None:
13 | mkdir_if_missing(os.path.dirname(fpath))
14 | self.file = open(fpath, 'w')
15 |
16 | def __del__(self):
17 | self.close()
18 |
19 | def __enter__(self):
20 |         return self
21 |
22 | def __exit__(self, *args):
23 | self.close()
24 |
25 | def write(self, msg):
26 | self.console.write(msg)
27 | if self.file is not None:
28 | self.file.write(msg)
29 |
30 | def flush(self):
31 | self.console.flush()
32 | if self.file is not None:
33 | self.file.flush()
34 | os.fsync(self.file.fileno())
35 |
36 | def close(self):
37 | self.console.close()
38 | if self.file is not None:
39 | self.file.close()
40 |
--------------------------------------------------------------------------------
/reid/dist_metric.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 |
5 | from .evaluators import extract_features
6 | from .metric_learning import get_metric
7 |
8 |
9 | class DistanceMetric(object):
10 | def __init__(self, algorithm='euclidean', *args, **kwargs):
11 | super(DistanceMetric, self).__init__()
12 | self.algorithm = algorithm
13 | self.metric = get_metric(algorithm, *args, **kwargs)
14 |
15 | def train(self, model, data_loader):
16 | if self.algorithm == 'euclidean': return
17 | features, labels = extract_features(model, data_loader)
18 |         features = torch.stack(list(features.values())).numpy()
19 | labels = torch.Tensor(list(labels.values())).numpy()
20 | self.metric.fit(features, labels)
21 |
22 | def transform(self, X):
23 | if torch.is_tensor(X):
24 | X = X.numpy()
25 | X = self.metric.transform(X)
26 | X = torch.from_numpy(X)
27 | else:
28 | X = self.metric.transform(X)
29 | return X
30 |
31 |
--------------------------------------------------------------------------------
/reid/utils/data/preprocessor.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import os.path as osp
3 |
4 | from PIL import Image
5 |
6 |
7 | class Preprocessor(object):
8 | def __init__(self, dataset, root=None, transform=None):
9 | super(Preprocessor, self).__init__()
10 | self.dataset = dataset
11 | self.root = root
12 | self.transform = transform
13 |
14 | def __len__(self):
15 | return len(self.dataset)
16 |
17 | def __getitem__(self, indices):
18 | if isinstance(indices, (tuple, list)):
19 | return [self._get_single_item(index) for index in indices]
20 | return self._get_single_item(indices)
21 |
22 | def _get_single_item(self, index):
23 | fname, pid, camid = self.dataset[index]
24 | fpath = fname
25 | if self.root is not None:
26 | fpath = osp.join(self.root, fname)
27 | img = Image.open(fpath).convert('RGB')
28 | if self.transform is not None:
29 | img = self.transform(img)
30 | return img, fname, pid, camid
31 |
--------------------------------------------------------------------------------
/docs/index.rst:
--------------------------------------------------------------------------------
1 | .. Open-ReID documentation master file, created by
2 | sphinx-quickstart on Sun Mar 19 15:49:51 2017.
3 | You can adapt this file completely to your liking, but it should at least
4 | contain the root `toctree` directive.
5 |
6 | Open-ReID documentation
7 | =======================
8 |
9 | Open-ReID is a deep learning library for person re-identification based on PyTorch.
10 |
11 | .. toctree::
12 | :glob:
13 | :maxdepth: 1
14 | :caption: Notes
15 |
16 | notes/overview
17 | notes/data_modules
18 | notes/evaluation_metrics
19 |
20 | .. toctree::
21 | :glob:
22 | :maxdepth: 1
23 | :caption: Examples
24 |
25 | examples/training_id
26 | examples/benchmarks
27 |
28 | .. toctree::
29 | :maxdepth: 1
30 | :caption: SDK Level Reference
31 |
32 | trainers
33 | evaluators
34 | dist_metric
35 |
36 | .. toctree::
37 | :maxdepth: 1
38 | :caption: API Level Reference
39 |
40 | datasets
41 | models
42 |
43 | Indices and tables
44 | ==================
45 |
46 | * :ref:`genindex`
47 | * :ref:`modindex`
48 |
--------------------------------------------------------------------------------
/test/utils/data/test_preprocessor.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 |
3 |
4 | class TestPreprocessor(TestCase):
5 | def test_getitem(self):
6 | import torchvision.transforms as t
7 | from reid.datasets.viper import VIPeR
8 | from reid.utils.data.preprocessor import Preprocessor
9 |
10 | root, split_id, num_val = '/tmp/open-reid/viper', 0, 100
11 | dataset = VIPeR(root, split_id=split_id, num_val=num_val, download=True)
12 |
13 | preproc = Preprocessor(dataset.train, root=dataset.images_dir,
14 | transform=t.Compose([
15 | t.Scale(256),
16 | t.CenterCrop(224),
17 | t.ToTensor(),
18 | t.Normalize(mean=[0.485, 0.456, 0.406],
19 | std=[0.229, 0.224, 0.225])
20 | ]))
21 | self.assertEquals(len(preproc), len(dataset.train))
22 |         img, fname, pid, camid = preproc[0]
23 | self.assertEquals(img.size(), (3, 224, 224))
24 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2017 Tong Xiao
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/reid/utils/data/sampler.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 | import torch
6 | from torch.utils.data.sampler import (
7 | Sampler, SequentialSampler, RandomSampler, SubsetRandomSampler,
8 | WeightedRandomSampler)
9 |
10 |
11 | class RandomIdentitySampler(Sampler):
12 | def __init__(self, data_source, num_instances=1):
13 | self.data_source = data_source
14 | self.num_instances = num_instances
15 | self.index_dic = defaultdict(list)
16 | for index, (_, pid, _) in enumerate(data_source):
17 | self.index_dic[pid].append(index)
18 | self.pids = list(self.index_dic.keys())
19 | self.num_samples = len(self.pids)
20 |
21 | def __len__(self):
22 | return self.num_samples * self.num_instances
23 |
24 | def __iter__(self):
25 | indices = torch.randperm(self.num_samples)
26 | ret = []
27 | for i in indices:
28 | pid = self.pids[i]
29 | t = self.index_dic[pid]
30 | if len(t) >= self.num_instances:
31 | t = np.random.choice(t, size=self.num_instances, replace=False)
32 | else:
33 | t = np.random.choice(t, size=self.num_instances, replace=True)
34 | ret.extend(t)
35 | return iter(ret)
36 |
--------------------------------------------------------------------------------
/test/evaluation_metrics/test_cmc.py:
--------------------------------------------------------------------------------
1 | from unittest import TestCase
2 | import numpy as np
3 |
4 | from reid.evaluation_metrics import cmc
5 |
6 |
7 | class TestCMC(TestCase):
8 | def test_only_distmat(self):
9 | distmat = np.array([[0, 1, 2, 3, 4],
10 | [1, 0, 2, 3, 4],
11 | [0, 1, 2, 3, 4],
12 | [0, 1, 2, 3, 4],
13 | [1, 2, 3, 4, 0]])
14 | ret = cmc(distmat)
15 | self.assertTrue(np.all(ret[:5] == [0.6, 0.6, 0.8, 1.0, 1.0]))
16 |
17 | def test_duplicate_ids(self):
18 | distmat = np.tile(np.arange(4), (4, 1))
19 | query_ids = [0, 0, 1, 1]
20 | gallery_ids = [0, 0, 1, 1]
21 | ret = cmc(distmat, query_ids=query_ids, gallery_ids=gallery_ids, topk=4,
22 | separate_camera_set=False, single_gallery_shot=False)
23 | self.assertTrue(np.all(ret == [0.5, 0.5, 1, 1]))
24 |
25 | def test_duplicate_cams(self):
26 | distmat = np.tile(np.arange(5), (5, 1))
27 | query_ids = [0,0,0,1,1]
28 | gallery_ids = [0,0,0,1,1]
29 | query_cams = [0,0,0,0,0]
30 | gallery_cams = [0,1,1,1,1]
31 | ret = cmc(distmat, query_ids=query_ids, gallery_ids=gallery_ids,
32 | query_cams=query_cams, gallery_cams=gallery_cams, topk=5,
33 | separate_camera_set=False, single_gallery_shot=False)
34 | self.assertTrue(np.all(ret == [0.6, 0.6, 0.6, 1, 1]))
35 |
--------------------------------------------------------------------------------
/reid/loss/triplet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 | from torch.autograd import Variable
6 |
7 |
8 | class TripletLoss(nn.Module):
9 | def __init__(self, margin=0):
10 | super(TripletLoss, self).__init__()
11 | self.margin = margin
12 | self.ranking_loss = nn.MarginRankingLoss(margin=margin)
13 |
14 | def forward(self, inputs, targets):
15 | n = inputs.size(0)
16 | # Compute pairwise distance, replace by the official when merged
17 | dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
18 | dist = dist + dist.t()
19 | dist.addmm_(1, -2, inputs, inputs.t())
20 | dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
21 | # For each anchor, find the hardest positive and negative
22 | mask = targets.expand(n, n).eq(targets.expand(n, n).t())
23 | dist_ap, dist_an = [], []
24 | for i in range(n):
25 | dist_ap.append(dist[i][mask[i]].max())
26 | dist_an.append(dist[i][mask[i] == 0].min())
27 | dist_ap = torch.cat(dist_ap)
28 | dist_an = torch.cat(dist_an)
29 | # Compute ranking hinge loss
30 | y = dist_an.data.new()
31 | y.resize_as_(dist_an.data)
32 | y.fill_(1)
33 | y = Variable(y)
34 | loss = self.ranking_loss(dist_an, dist_ap, y)
35 | prec = (dist_an.data > dist_ap.data).sum() * 1. / y.size(0)
36 | return loss, prec
37 |
--------------------------------------------------------------------------------
/reid/feature_extraction/database.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import h5py
4 | import numpy as np
5 | from torch.utils.data import Dataset
6 |
7 |
8 | class FeatureDatabase(Dataset):
9 | def __init__(self, *args, **kwargs):
10 | super(FeatureDatabase, self).__init__()
11 | self.fid = h5py.File(*args, **kwargs)
12 |
13 | def __enter__(self):
14 | return self
15 |
16 | def __exit__(self, exc_type, exc_val, exc_tb):
17 | self.close()
18 |
19 | def __getitem__(self, keys):
20 | if isinstance(keys, (tuple, list)):
21 | return [self._get_single_item(k) for k in keys]
22 | return self._get_single_item(keys)
23 |
24 | def _get_single_item(self, key):
25 | return np.asarray(self.fid[key])
26 |
27 | def __setitem__(self, key, value):
28 | if key in self.fid:
29 | if self.fid[key].shape == value.shape and \
30 | self.fid[key].dtype == value.dtype:
31 | self.fid[key][...] = value
32 | else:
33 | del self.fid[key]
34 | self.fid.create_dataset(key, data=value)
35 | else:
36 | self.fid.create_dataset(key, data=value)
37 |
38 | def __delitem__(self, key):
39 | del self.fid[key]
40 |
41 | def __len__(self):
42 | return len(self.fid)
43 |
44 | def __iter__(self):
45 | return iter(self.fid)
46 |
47 | def flush(self):
48 | self.fid.flush()
49 |
50 | def close(self):
51 | self.fid.close()
52 |
--------------------------------------------------------------------------------
/reid/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | import warnings
3 |
4 | from .cuhk01 import CUHK01
5 | from .cuhk03 import CUHK03
6 | from .dukemtmc import DukeMTMC
7 | from .market1501 import Market1501
8 | from .viper import VIPeR
9 |
10 |
11 | __factory = {
12 | 'viper': VIPeR,
13 | 'cuhk01': CUHK01,
14 | 'cuhk03': CUHK03,
15 | 'market1501': Market1501,
16 | 'dukemtmc': DukeMTMC,
17 | }
18 |
19 |
20 | def names():
21 | return sorted(__factory.keys())
22 |
23 |
24 | def create(name, root, *args, **kwargs):
25 | """
26 | Create a dataset instance.
27 |
28 | Parameters
29 | ----------
30 | name : str
31 | The dataset name. Can be one of 'viper', 'cuhk01', 'cuhk03',
32 | 'market1501', and 'dukemtmc'.
33 | root : str
34 | The path to the dataset directory.
35 | split_id : int, optional
36 | The index of data split. Default: 0
37 | num_val : int or float, optional
38 | When int, it means the number of validation identities. When float,
39 |         it means the proportion of validation identities in the trainval set. Default: 100
40 | download : bool, optional
41 | If True, will download the dataset. Default: False
42 | """
43 | if name not in __factory:
44 | raise KeyError("Unknown dataset:", name)
45 | return __factory[name](root, *args, **kwargs)
46 |
47 |
48 | def get_dataset(name, root, *args, **kwargs):
49 | warnings.warn("get_dataset is deprecated. Use create instead.")
50 | return create(name, root, *args, **kwargs)
51 |
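52 | 
53 | # Illustrative usage sketch (not part of the API surface):
54 | #
55 | #     from reid import datasets
56 | #     dataset = datasets.create('market1501', '/data/market1501',
57 | #                               split_id=0, num_val=100, download=True)
58 | #     print(datasets.names())  # sorted names of the registered datasets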
--------------------------------------------------------------------------------
/docs/notes/overview.rst:
--------------------------------------------------------------------------------
1 | =====================
2 | Overview of Open-ReID
3 | =====================
4 |
5 | Open-ReID is a lightweight person re-identification library for research
6 | purposes. It aims to provide a uniform interface for different datasets, a full
7 | set of models and evaluation metrics, as well as examples to reproduce (near)
8 | state-of-the-art results. Open-ReID is mainly based on `PyTorch
9 | <http://pytorch.org/>`_.
10 |
11 | ---------
12 | Structure
13 | ---------
14 |
15 | Open-ReID is structured into three levels, as shown in the figure below.
16 |
17 | .. _fig-structure:
18 | .. figure:: /figures/structure.png
19 |
20 | API Level
21 |     At the bottom are decoupled modules, each providing a unit of functionality.
22 |     For example, the ``datasets`` module has a uniform interface for many popular
23 |     datasets, while commonly used evaluation metrics, such as CMC and mean AP,
24 |     are implemented in the ``evaluation_metrics`` module, which accepts both
25 |     ``torch.Tensor`` and ``numpy.ndarray`` as inputs.
26 | 
27 | SDK Level
28 |     In the middle, several classes interact with the underlying APIs to provide
29 |     routines for standard tasks. For example, the ``Trainer`` can be used to
30 |     train a deep model on the training set, and the ``Evaluator`` can evaluate
31 |     the model on the validation and test sets.
32 | 
33 | Application Level
34 |     At the top, we provide several examples that use Open-ReID. For example, one
35 |     can easily train a CNN with different kinds of loss functions on different
36 |     datasets, to achieve certain baselines or state-of-the-art results.
37 |
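38 | As a concrete illustration of how the levels fit together, the sketch below
39 | wires API-level factories into an SDK-level ``Trainer``. It assumes a CUDA
40 | machine and a cross-entropy criterion; building the data loader and optimizer
41 | is elided, and the class count is a placeholder:
42 | 
43 | .. code-block:: python
44 | 
45 |     import torch
46 |     from reid import datasets, models
47 |     from reid.trainers import Trainer
48 | 
49 |     # API level: uniform factories
50 |     dataset = datasets.create('viper', '/tmp/open-reid/viper', download=True)
51 |     model = models.create('resnet50', num_classes=316).cuda()
52 | 
53 |     # SDK level: Trainer encapsulates the standard training loop
54 |     criterion = torch.nn.CrossEntropyLoss().cuda()
55 |     trainer = Trainer(model, criterion)
56 |     # trainer.train(epoch, data_loader, optimizer)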
--------------------------------------------------------------------------------
/reid/utils/data/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from torchvision.transforms import *
4 | from PIL import Image
5 | import random
6 | import math
7 |
8 |
9 | class RectScale(object):
10 | def __init__(self, height, width, interpolation=Image.BILINEAR):
11 | self.height = height
12 | self.width = width
13 | self.interpolation = interpolation
14 |
15 | def __call__(self, img):
16 | w, h = img.size
17 | if h == self.height and w == self.width:
18 | return img
19 | return img.resize((self.width, self.height), self.interpolation)
20 |
21 |
22 | class RandomSizedRectCrop(object):
23 | def __init__(self, height, width, interpolation=Image.BILINEAR):
24 | self.height = height
25 | self.width = width
26 | self.interpolation = interpolation
27 |
28 | def __call__(self, img):
29 | for attempt in range(10):
30 | area = img.size[0] * img.size[1]
31 | target_area = random.uniform(0.64, 1.0) * area
32 | aspect_ratio = random.uniform(2, 3)
33 |
34 | h = int(round(math.sqrt(target_area * aspect_ratio)))
35 | w = int(round(math.sqrt(target_area / aspect_ratio)))
36 |
37 | if w <= img.size[0] and h <= img.size[1]:
38 | x1 = random.randint(0, img.size[0] - w)
39 | y1 = random.randint(0, img.size[1] - h)
40 |
41 | img = img.crop((x1, y1, x1 + w, y1 + h))
42 | assert(img.size == (w, h))
43 |
44 | return img.resize((self.width, self.height), self.interpolation)
45 |
46 | # Fallback
47 | scale = RectScale(self.height, self.width,
48 | interpolation=self.interpolation)
49 | return scale(img)
50 |
--------------------------------------------------------------------------------
/reid/metric_learning/kissme.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import numpy as np
4 | from metric_learn.base_metric import BaseMetricLearner
5 |
6 |
7 | def validate_cov_matrix(M):
8 | M = (M + M.T) * 0.5
9 | k = 0
10 | I = np.eye(M.shape[0])
11 | while True:
12 | try:
13 | _ = np.linalg.cholesky(M)
14 | break
15 | except np.linalg.LinAlgError:
16 | # Find the nearest positive definite matrix for M. Modified from
17 | # http://www.mathworks.com/matlabcentral/fileexchange/42885-nearestspd
18 | # Might take several minutes
19 | k += 1
20 | w, v = np.linalg.eig(M)
21 |             min_eig = w.min()
22 | M += (-min_eig * k * k + np.spacing(min_eig)) * I
23 | return M
24 |
25 |
26 | class KISSME(BaseMetricLearner):
27 | def __init__(self):
28 | self.M_ = None
29 |
30 | def metric(self):
31 | return self.M_
32 |
33 | def fit(self, X, y=None):
34 | n = X.shape[0]
35 | if y is None:
36 | y = np.arange(n)
37 | X1, X2 = np.meshgrid(np.arange(n), np.arange(n))
38 | X1, X2 = X1[X1 < X2], X2[X1 < X2]
39 | matches = (y[X1] == y[X2])
40 | num_matches = matches.sum()
41 | num_non_matches = len(matches) - num_matches
42 | idxa = X1[matches]
43 | idxb = X2[matches]
44 | S = X[idxa] - X[idxb]
45 | C1 = S.transpose().dot(S) / num_matches
46 | p = np.random.choice(num_non_matches, num_matches, replace=False)
47 | idxa = X1[~matches]
48 | idxb = X2[~matches]
49 | idxa = idxa[p]
50 | idxb = idxb[p]
51 | S = X[idxa] - X[idxb]
52 | C0 = S.transpose().dot(S) / num_matches
53 | self.M_ = np.linalg.inv(C1) - np.linalg.inv(C0)
54 | self.M_ = validate_cov_matrix(self.M_)
55 | self.X_ = X
56 |
--------------------------------------------------------------------------------
/reid/loss/oim.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch import nn, autograd
6 |
7 |
8 | class OIM(autograd.Function):
9 | def __init__(self, lut, momentum=0.5):
10 | super(OIM, self).__init__()
11 | self.lut = lut
12 | self.momentum = momentum
13 |
14 | def forward(self, inputs, targets):
15 | self.save_for_backward(inputs, targets)
16 | outputs = inputs.mm(self.lut.t())
17 | return outputs
18 |
19 | def backward(self, grad_outputs):
20 | inputs, targets = self.saved_tensors
21 | grad_inputs = None
22 | if self.needs_input_grad[0]:
23 | grad_inputs = grad_outputs.mm(self.lut)
24 | for x, y in zip(inputs, targets):
25 | self.lut[y] = self.momentum * self.lut[y] + (1. - self.momentum) * x
26 | self.lut[y] /= self.lut[y].norm()
27 | return grad_inputs, None
28 |
29 |
30 | def oim(inputs, targets, lut, momentum=0.5):
31 | return OIM(lut, momentum=momentum)(inputs, targets)
32 |
33 |
34 | class OIMLoss(nn.Module):
35 | def __init__(self, num_features, num_classes, scalar=1.0, momentum=0.5,
36 | weight=None, size_average=True):
37 | super(OIMLoss, self).__init__()
38 | self.num_features = num_features
39 | self.num_classes = num_classes
40 | self.momentum = momentum
41 | self.scalar = scalar
42 | self.weight = weight
43 | self.size_average = size_average
44 |
45 | self.register_buffer('lut', torch.zeros(num_classes, num_features))
46 |
47 | def forward(self, inputs, targets):
48 | inputs = oim(inputs, targets, self.lut, momentum=self.momentum)
49 | inputs *= self.scalar
50 | loss = F.cross_entropy(inputs, targets, weight=self.weight,
51 | size_average=self.size_average)
52 | return loss, inputs
53 |
--------------------------------------------------------------------------------
/reid/utils/serialization.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import json
3 | import os.path as osp
4 | import shutil
5 |
6 | import torch
7 | from torch.nn import Parameter
8 |
9 | from .osutils import mkdir_if_missing
10 |
11 |
12 | def read_json(fpath):
13 | with open(fpath, 'r') as f:
14 | obj = json.load(f)
15 | return obj
16 |
17 |
18 | def write_json(obj, fpath):
19 | mkdir_if_missing(osp.dirname(fpath))
20 | with open(fpath, 'w') as f:
21 | json.dump(obj, f, indent=4, separators=(',', ': '))
22 |
23 |
24 | def save_checkpoint(state, is_best, fpath='checkpoint.pth.tar'):
25 | mkdir_if_missing(osp.dirname(fpath))
26 | torch.save(state, fpath)
27 | if is_best:
28 | shutil.copy(fpath, osp.join(osp.dirname(fpath), 'model_best.pth.tar'))
29 |
30 |
31 | def load_checkpoint(fpath):
32 | if osp.isfile(fpath):
33 | checkpoint = torch.load(fpath)
34 | print("=> Loaded checkpoint '{}'".format(fpath))
35 | return checkpoint
36 | else:
37 | raise ValueError("=> No checkpoint found at '{}'".format(fpath))
38 |
39 |
40 | def copy_state_dict(state_dict, model, strip=None):
41 | tgt_state = model.state_dict()
42 | copied_names = set()
43 | for name, param in state_dict.items():
44 | if strip is not None and name.startswith(strip):
45 | name = name[len(strip):]
46 | if name not in tgt_state:
47 | continue
48 | if isinstance(param, Parameter):
49 | param = param.data
50 | if param.size() != tgt_state[name].size():
51 | print('mismatch:', name, param.size(), tgt_state[name].size())
52 | continue
53 | tgt_state[name].copy_(param)
54 | copied_names.add(name)
55 |
56 | missing = set(tgt_state.keys()) - copied_names
57 | if len(missing) > 0:
58 | print("missing keys in state_dict:", missing)
59 |
60 | return model
61 |
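62 | 
63 | # Illustrative usage sketch (not part of the API surface):
64 | #
65 | #     save_checkpoint({'state_dict': model.state_dict(), 'epoch': epoch},
66 | #                     is_best, fpath='logs/checkpoint.pth.tar')
67 | #     checkpoint = load_checkpoint('logs/checkpoint.pth.tar')
68 | #     copy_state_dict(checkpoint['state_dict'], model)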
--------------------------------------------------------------------------------
/reid/models/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from .inception import *
4 | from .resnet import *
5 |
6 |
7 | __factory = {
8 | 'inception': inception,
9 | 'resnet18': resnet18,
10 | 'resnet34': resnet34,
11 | 'resnet50': resnet50,
12 | 'resnet101': resnet101,
13 | 'resnet152': resnet152,
14 | }
15 |
16 |
17 | def names():
18 | return sorted(__factory.keys())
19 |
20 |
21 | def create(name, *args, **kwargs):
22 | """
23 | Create a model instance.
24 |
25 | Parameters
26 | ----------
27 | name : str
28 | Model name. Can be one of 'inception', 'resnet18', 'resnet34',
29 | 'resnet50', 'resnet101', and 'resnet152'.
30 | pretrained : bool, optional
31 | Only applied for 'resnet*' models. If True, will use ImageNet pretrained
32 | model. Default: True
33 | cut_at_pooling : bool, optional
34 | If True, will cut the model before the last global pooling layer and
35 | ignore the remaining kwargs. Default: False
36 | num_features : int, optional
37 | If positive, will append a Linear layer after the global pooling layer,
38 | with this number of output units, followed by a BatchNorm layer.
39 | Otherwise these layers will not be appended. Default: 256 for
40 | 'inception', 0 for 'resnet*'
41 | norm : bool, optional
42 | If True, will normalize the feature to be unit L2-norm for each sample.
43 | Otherwise will append a ReLU layer after the above Linear layer if
44 | num_features > 0. Default: False
45 | dropout : float, optional
46 | If positive, will append a Dropout layer with this dropout rate.
47 | Default: 0
48 | num_classes : int, optional
49 | If positive, will append a Linear layer at the end as the classifier
50 | with this number of output units. Default: 0
51 | """
52 | if name not in __factory:
53 | raise KeyError("Unknown model:", name)
54 | return __factory[name](*args, **kwargs)
55 |
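56 | 
57 | # Illustrative usage sketch (not part of the API surface):
58 | #
59 | #     from reid import models
60 | #     model = models.create('resnet50', num_features=128, dropout=0.5,
61 | #                           num_classes=751)  # e.g. 751 training identities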
--------------------------------------------------------------------------------
/docs/notes/evaluation_metrics.rst:
--------------------------------------------------------------------------------
1 | .. _evaluation-metrics:
2 |
3 | ==================
4 | Evaluation Metrics
5 | ==================
6 |
7 | Cumulative Matching Characteristics (CMC) curves are the most popular evaluation
8 | metrics for person re-identification methods. Consider a simple
9 | *single-gallery-shot* setting, where each gallery identity has only one
10 | instance. For each query, an algorithm will rank all the gallery samples
11 | according to their distances to the query from small to large, and the CMC top-k
12 | accuracy is
13 |
14 | .. math::
15 | Acc_k = \begin{cases}
16 | 1 & \text{if top-$k$ ranked gallery samples contain the query identity} \\
17 | 0 & \text{otherwise}
18 | \end{cases},
19 |
20 | which is a shifted `step function <https://en.wikipedia.org/wiki/Step_function>`_.
21 | The final CMC curve is computed by averaging the shifted step functions over all
22 | the queries.
23 |
24 | While the *single-gallery-shot* CMC is well defined, there is no common
25 | agreement when it comes to the *multi-gallery-shot* setting, where each gallery
26 | identity could have multiple instances. For example,
27 | `CUHK03 <http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html>`_ and
28 | `Market-1501 <http://www.liangzheng.org/Project/project_reid.html>`_
29 | calculate the CMC curves and CMC top-k accuracy quite differently. To be
30 | specific,
31 | 
32 | - `CUHK03 <http://www.ee.cuhk.edu.hk/~xgwang/CUHK_identification.html>`_:
33 |   Query and gallery sets are from different camera views. For each query, they
34 |   randomly sample one instance for each gallery identity, and compute a CMC
35 |   curve in the *single-gallery-shot* setting. The random sampling is repeated
36 |   :math:`N` times and the expected CMC curve is reported.
37 | 
38 | - `Market-1501 <http://www.liangzheng.org/Project/project_reid.html>`_:
39 |   Query and gallery sets could share the same camera views, but for each query
40 |   identity, his/her gallery samples from the same camera are excluded. They do
41 |   not randomly sample only one instance for each gallery identity. This means
42 |   the query will always match the "easiest" positive sample in the gallery,
43 |   while the harder positive samples are ignored when computing CMC.
44 |
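45 | As a worked illustration of the *single-gallery-shot* definition above, the
46 | NumPy sketch below averages the shifted step functions over all queries. It is
47 | a simplified stand-in for ``reid.evaluation_metrics.cmc``, not its exact
48 | implementation:
49 | 
50 | .. code-block:: python
51 | 
52 |     import numpy as np
53 | 
54 |     def cmc_single_gallery_shot(distmat, query_ids, gallery_ids, topk=5):
55 |         # distmat[i, j] is the distance between query i and gallery j
56 |         gallery_ids = np.asarray(gallery_ids)
57 |         ret = np.zeros(topk)
58 |         for i, qid in enumerate(query_ids):
59 |             order = np.argsort(distmat[i])       # rank gallery by distance
60 |             rank = np.where(gallery_ids[order] == qid)[0][0]
61 |             if rank < topk:
62 |                 ret[rank:] += 1                  # shifted step function
63 |         return ret / len(query_ids)             # average over all queries
64 | 
65 |     # Three queries against four single-instance gallery identities
66 |     distmat = np.array([[0.1, 0.5, 0.9, 0.7],
67 |                         [0.6, 0.2, 0.8, 0.4],
68 |                         [0.9, 0.3, 0.2, 0.5]])
69 |     print(cmc_single_gallery_shot(distmat, [0, 1, 3], [0, 1, 2, 3]))
70 |     # -> [0.667 0.667 1. 1. 1.]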
--------------------------------------------------------------------------------
/docs/_static/css/openreid_theme.css:
--------------------------------------------------------------------------------
1 | body {
2 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
3 | }
4 |
5 | /* Default header fonts are ugly */
6 | h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend, p.caption {
7 | font-family: "Lato","proxima-nova","Helvetica Neue",Arial,sans-serif;
8 | }
9 |
10 | /* Use white for docs background */
11 | .wy-side-nav-search {
12 | background-color: #fff;
13 | }
14 |
15 | .wy-side-nav-search > a {
16 | color: #F05732;
17 | }
18 |
19 | .wy-nav-content-wrap, .wy-menu li.current > a {
20 | background-color: #fff;
21 | }
22 |
23 | @media screen and (min-width: 1400px) {
24 | .wy-nav-content-wrap {
25 | background-color: rgba(0, 0, 0, 0.0470588);
26 | }
27 |
28 | .wy-nav-content {
29 | background-color: #fff;
30 | }
31 | }
32 |
33 |
34 | /* Fixes for mobile */
35 | .wy-nav-top {
36 | padding: 0;
37 | margin: 0.4045em 0.809em;
38 | color: #333;
39 | }
40 |
41 | .wy-nav-top > a {
42 | display: none;
43 | }
44 |
45 | @media screen and (max-width: 768px) {
46 | .wy-side-nav-search>a img.logo {
47 | height: 60px;
48 | }
49 | }
50 |
51 | /* This is needed to ensure that logo above search scales properly */
52 | .wy-side-nav-search a {
53 | display: block;
54 | }
55 |
56 | /* This ensures that multiple constructors will remain in separate lines. */
57 | .rst-content dl:not(.docutils) dt {
58 | display: table;
59 | }
60 |
61 | /* Use our red for literals (it's very similar to the original color) */
62 | .rst-content tt.literal, .rst-content code.literal {
63 | color: #F05732;
64 | }
65 |
66 | .rst-content tt.xref, a .rst-content tt, .rst-content tt.xref,
67 | .rst-content code.xref, a .rst-content tt, a .rst-content code {
68 | color: #404040;
69 | }
70 |
71 | /* Change link colors (except for the menu) */
72 |
73 | a {
74 | color: #F05732;
75 | }
76 |
77 | a:hover {
78 | color: #F05732;
79 | }
80 |
81 |
82 | a:visited {
83 | color: #D44D2C;
84 | }
85 |
86 | .wy-menu a {
87 | color: #b3b3b3;
88 | }
89 |
90 | .wy-menu a:hover {
91 | color: #b3b3b3;
92 | }
93 |
94 | /* Default footer text is quite big */
95 | footer {
96 | font-size: 80%;
97 | }
98 |
99 | footer .rst-footer-buttons {
100 | font-size: 125%; /* revert footer settings - 1/80% = 125% */
101 | }
102 |
103 | footer p {
104 | font-size: 100%;
105 | }
106 |
107 | /* For hidden headers that appear in TOC tree */
108 | /* see http://stackoverflow.com/a/32363545/3343043 */
109 | .rst-content .hidden-section {
110 | display: none;
111 | }
112 |
113 | nav .hidden-section {
114 | display: inherit;
115 | }
116 |
117 | .rst-content .citation ol {
118 | margin-bottom: 0;
119 | }
--------------------------------------------------------------------------------
/reid/trainers.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import time
3 |
4 | import torch
5 | from torch.autograd import Variable
6 |
7 | from .evaluation_metrics import accuracy
8 | from .loss import OIMLoss, TripletLoss
9 | from .utils.meters import AverageMeter
10 |
11 |
12 | class BaseTrainer(object):
13 | def __init__(self, model, criterion):
14 | super(BaseTrainer, self).__init__()
15 | self.model = model
16 | self.criterion = criterion
17 |
18 | def train(self, epoch, data_loader, optimizer, print_freq=1):
19 | self.model.train()
20 |
21 | batch_time = AverageMeter()
22 | data_time = AverageMeter()
23 | losses = AverageMeter()
24 | precisions = AverageMeter()
25 |
26 | end = time.time()
27 | for i, inputs in enumerate(data_loader):
28 | data_time.update(time.time() - end)
29 |
30 | inputs, targets = self._parse_data(inputs)
31 | loss, prec1 = self._forward(inputs, targets)
32 |
33 | losses.update(loss.data[0], targets.size(0))
34 | precisions.update(prec1, targets.size(0))
35 |
36 | optimizer.zero_grad()
37 | loss.backward()
38 | optimizer.step()
39 |
40 | batch_time.update(time.time() - end)
41 | end = time.time()
42 |
43 | if (i + 1) % print_freq == 0:
44 | print('Epoch: [{}][{}/{}]\t'
45 | 'Time {:.3f} ({:.3f})\t'
46 | 'Data {:.3f} ({:.3f})\t'
47 | 'Loss {:.3f} ({:.3f})\t'
48 | 'Prec {:.2%} ({:.2%})\t'
49 | .format(epoch, i + 1, len(data_loader),
50 | batch_time.val, batch_time.avg,
51 | data_time.val, data_time.avg,
52 | losses.val, losses.avg,
53 | precisions.val, precisions.avg))
54 |
55 | def _parse_data(self, inputs):
56 | raise NotImplementedError
57 |
58 | def _forward(self, inputs, targets):
59 | raise NotImplementedError
60 |
61 |
62 | class Trainer(BaseTrainer):
63 | def _parse_data(self, inputs):
64 | imgs, _, pids, _ = inputs
65 | inputs = [Variable(imgs)]
66 | targets = Variable(pids.cuda())
67 | return inputs, targets
68 |
69 | def _forward(self, inputs, targets):
70 | outputs = self.model(*inputs)
71 | if isinstance(self.criterion, torch.nn.CrossEntropyLoss):
72 | loss = self.criterion(outputs, targets)
73 | prec, = accuracy(outputs.data, targets.data)
74 | prec = prec[0]
75 | elif isinstance(self.criterion, OIMLoss):
76 | loss, outputs = self.criterion(outputs, targets)
77 | prec, = accuracy(outputs.data, targets.data)
78 | prec = prec[0]
79 | elif isinstance(self.criterion, TripletLoss):
80 | loss, prec = self.criterion(outputs, targets)
81 | else:
82 | raise ValueError("Unsupported loss:", self.criterion)
83 | return loss, prec
84 |
--------------------------------------------------------------------------------
/docs/notes/data_modules.rst:
--------------------------------------------------------------------------------
1 | .. _data-modules:
2 |
3 | ============
4 | Data Modules
5 | ============
6 |
7 | This note introduces the unified dataset interface defined by Open-ReID and
8 | the data loading system that efficiently samples mini-batches from it.
9 |
10 | .. _unified-data-format:
11 |
12 | -------------------
13 | Unified Data Format
14 | -------------------
15 |
16 | There are many existing person re-identification datasets, each with its own
17 | data format and split protocol. This makes conducting experiments across
18 | these datasets tedious and error-prone. To solve this problem, Open-ReID
19 | defines a unified dataset interface. Converting each raw dataset into this
20 | unified format significantly simplifies the code for training and
21 | evaluation.
22 |
23 | Each dataset is organized as a directory like
24 |
25 | .. code-block:: shell
26 |
27 | cuhk03
28 | ├── raw/
29 | ├── images/
30 | ├── meta.json
31 | └── splits.json
32 |
33 | where ``raw/`` stores the original dataset files, and ``images/`` contains all
34 | the renamed image files, following the naming format::
35 |
36 | '{:08d}_{:02d}_{:04d}.jpg'.format(person_id, camera_id, image_id)
37 |
38 | where all the ids are indexed from 0.
39 |
40 | ``meta.json`` stores all the person identities of the dataset as a nested list with the structure::
41 |
42 | "identities": [
43 | [ # the first identity, person_id = 0
44 | [ # camera_id = 0
45 | "00000000_00_0000.jpg",
46 | "00000000_00_0001.jpg"
47 | ],
48 | [ # camera_id = 1
49 | "00000000_01_0000.jpg",
50 | "00000000_01_0001.jpg",
51 | "00000000_01_0002.jpg"
52 | ]
53 | ],
54 | [ # the second identity, person_id = 1
55 | [ # camera_id = 0
56 | "00000001_00_0000.jpg"
57 | ],
58 | [ # camera_id = 1
59 | "00000001_01_0000.jpg",
60 | "00000001_01_0001.jpg",
61 | ]
62 | ],
63 | ...
64 | ]
65 |
66 | Each dataset may define multiple training / test data splits. They are listed in
67 | ``splits.json``, where each split defines three subsets of person identities::
68 |
69 | {
70 |       "trainval": [0, 1, 3, ...], # person_ids for training and validation
71 |       "gallery": [2, 4, 5, ...],  # test gallery; no overlap with trainval
72 |       "query": [2, 4, ...],       # test queries; a subset of the gallery
73 | }
74 |
75 | .. _data-loading-system:
76 |
77 | -------------------
78 | Data Loading System
79 | -------------------
80 |
81 | The objective of the data loading system is to sample mini-batches efficiently from the dataset. In our design, it consists of four components, namely ``Dataset``, ``Sampler``, ``Preprocessor``, and ``Data Loader``. Their relations are depicted in the figure below.
82 |
83 | .. _fig-data-loading:
84 | .. figure:: /figures/data-loading.png
85 | :figwidth: 80 %
86 | :align: center
87 |
88 | A ``Dataset`` is a list of items ``(filename, person_id, camera_id)``. A ``Sampler`` is an iterator that yields an index at each step. A ``Preprocessor`` takes an index as input, loads the corresponding image (applying the configured transformations), and returns a tuple of ``(image, filename, person_id, camera_id)``. At the top, we adopt ``torch.utils.data.DataLoader`` to query the ``Preprocessor`` at the sampled indices and collate mini-batches with multiprocessing.
89 |
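90 | Putting the pieces together, the sketch below builds this pipeline for VIPeR.
91 | It assumes that ``Preprocessor`` accepts the item list, an image root
92 | directory, and a torchvision transform (see
93 | ``reid/utils/data/preprocessor.py`` for the actual signature); since all VIPeR
94 | images share one size, no resizing transform is needed for batching.
95 |
96 | .. code-block:: python
97 |
98 |    from torch.utils.data import DataLoader
99 |    from torchvision import transforms
100 |
101 |    from reid.datasets.viper import VIPeR
102 |    from reid.utils.data import Preprocessor
103 |
104 |    dataset = VIPeR('examples/data/viper', download=True)
105 |    # dataset.train is a list of (filename, person_id, camera_id) items
106 |    loader = DataLoader(
107 |        Preprocessor(dataset.train, root=dataset.images_dir,
108 |                     transform=transforms.ToTensor()),
109 |        batch_size=64, shuffle=True, num_workers=2)
110 |    imgs, fnames, pids, camids = next(iter(loader))  # one mini-batch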
--------------------------------------------------------------------------------
/reid/datasets/cuhk01.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 |
4 | import numpy as np
5 |
6 | from ..utils.data import Dataset
7 | from ..utils.osutils import mkdir_if_missing
8 | from ..utils.serialization import write_json
9 |
10 |
11 | class CUHK01(Dataset):
12 | url = 'https://docs.google.com/spreadsheet/viewform?formkey=dF9pZ1BFZkNiMG1oZUdtTjZPalR0MGc6MA'
13 | md5 = 'e6d55c0da26d80cda210a2edeb448e98'
14 |
15 | def __init__(self, root, split_id=0, num_val=100, download=True):
16 | super(CUHK01, self).__init__(root, split_id=split_id)
17 |
18 | if download:
19 | self.download()
20 |
21 | if not self._check_integrity():
22 | raise RuntimeError("Dataset not found or corrupted. " +
23 | "You can use download=True to download it.")
24 |
25 | self.load(num_val)
26 |
27 | def download(self):
28 | if self._check_integrity():
29 | print("Files already downloaded and verified")
30 | return
31 |
32 | import hashlib
33 | import shutil
34 | from glob import glob
35 | from zipfile import ZipFile
36 |
37 | raw_dir = osp.join(self.root, 'raw')
38 | mkdir_if_missing(raw_dir)
39 |
40 | # Download the raw zip file
41 | fpath = osp.join(raw_dir, 'CUHK01.zip')
42 | if osp.isfile(fpath) and \
43 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5:
44 | print("Using downloaded file: " + fpath)
45 | else:
46 | raise RuntimeError("Please download the dataset manually from {} "
47 | "to {}".format(self.url, fpath))
48 |
49 | # Extract the file
50 | exdir = osp.join(raw_dir, 'campus')
51 | if not osp.isdir(exdir):
52 | print("Extracting zip file")
53 | with ZipFile(fpath) as z:
54 | z.extractall(path=raw_dir)
55 |
56 | # Format
57 | images_dir = osp.join(self.root, 'images')
58 | mkdir_if_missing(images_dir)
59 |
60 | identities = [[[] for _ in range(2)] for _ in range(971)]
61 |
62 | files = sorted(glob(osp.join(exdir, '*.png')))
63 | for fpath in files:
64 | fname = osp.basename(fpath)
65 | pid, cam = int(fname[:4]), int(fname[4:7])
66 | assert 1 <= pid <= 971
67 | assert 1 <= cam <= 4
68 | pid, cam = pid - 1, (cam - 1) // 2
69 | fname = ('{:08d}_{:02d}_{:04d}.png'
70 | .format(pid, cam, len(identities[pid][cam])))
71 | identities[pid][cam].append(fname)
72 | shutil.copy(fpath, osp.join(images_dir, fname))
73 |
74 | # Save meta information into a json file
75 | meta = {'name': 'cuhk01', 'shot': 'multiple', 'num_cameras': 2,
76 | 'identities': identities}
77 | write_json(meta, osp.join(self.root, 'meta.json'))
78 |
79 |         # Randomly create ten training / test splits
80 | num = len(identities)
81 | splits = []
82 | for _ in range(10):
83 | pids = np.random.permutation(num).tolist()
84 | trainval_pids = sorted(pids[:num // 2])
85 | test_pids = sorted(pids[num // 2:])
86 | split = {'trainval': trainval_pids,
87 | 'query': test_pids,
88 | 'gallery': test_pids}
89 | splits.append(split)
90 | write_json(splits, osp.join(self.root, 'splits.json'))
91 |
--------------------------------------------------------------------------------
/reid/datasets/viper.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 |
4 | import numpy as np
5 |
6 | from ..utils.data import Dataset
7 | from ..utils.osutils import mkdir_if_missing
8 | from ..utils.serialization import write_json
9 |
10 |
11 | class VIPeR(Dataset):
12 | url = 'http://users.soe.ucsc.edu/~manduchi/VIPeR.v1.0.zip'
13 | md5 = '1c2d9fc1cc800332567a0da25a1ce68c'
14 |
15 | def __init__(self, root, split_id=0, num_val=100, download=True):
16 | super(VIPeR, self).__init__(root, split_id=split_id)
17 |
18 | if download:
19 | self.download()
20 |
21 | if not self._check_integrity():
22 | raise RuntimeError("Dataset not found or corrupted. " +
23 | "You can use download=True to download it.")
24 |
25 | self.load(num_val)
26 |
27 | def download(self):
28 | if self._check_integrity():
29 | print("Files already downloaded and verified")
30 | return
31 |
32 | import hashlib
33 | from glob import glob
34 | from scipy.misc import imsave, imread
35 | from six.moves import urllib
36 | from zipfile import ZipFile
37 |
38 | raw_dir = osp.join(self.root, 'raw')
39 | mkdir_if_missing(raw_dir)
40 |
41 | # Download the raw zip file
42 | fpath = osp.join(raw_dir, 'VIPeR.v1.0.zip')
43 | if osp.isfile(fpath) and \
44 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5:
45 | print("Using downloaded file: " + fpath)
46 | else:
47 | print("Downloading {} to {}".format(self.url, fpath))
48 | urllib.request.urlretrieve(self.url, fpath)
49 |
50 | # Extract the file
51 | exdir = osp.join(raw_dir, 'VIPeR')
52 | if not osp.isdir(exdir):
53 | print("Extracting zip file")
54 | with ZipFile(fpath) as z:
55 | z.extractall(path=raw_dir)
56 |
57 | # Format
58 | images_dir = osp.join(self.root, 'images')
59 | mkdir_if_missing(images_dir)
60 | cameras = [sorted(glob(osp.join(exdir, 'cam_a', '*.bmp'))),
61 | sorted(glob(osp.join(exdir, 'cam_b', '*.bmp')))]
62 | assert len(cameras[0]) == len(cameras[1])
63 | identities = []
64 | for pid, (cam1, cam2) in enumerate(zip(*cameras)):
65 | images = []
66 | # view-0
67 | fname = '{:08d}_{:02d}_{:04d}.jpg'.format(pid, 0, 0)
68 | imsave(osp.join(images_dir, fname), imread(cam1))
69 | images.append([fname])
70 | # view-1
71 | fname = '{:08d}_{:02d}_{:04d}.jpg'.format(pid, 1, 0)
72 | imsave(osp.join(images_dir, fname), imread(cam2))
73 | images.append([fname])
74 | identities.append(images)
75 |
76 | # Save meta information into a json file
77 | meta = {'name': 'VIPeR', 'shot': 'single', 'num_cameras': 2,
78 | 'identities': identities}
79 | write_json(meta, osp.join(self.root, 'meta.json'))
80 |
81 |         # Randomly create ten training / test splits
82 | num = len(identities)
83 | splits = []
84 | for _ in range(10):
85 | pids = np.random.permutation(num).tolist()
86 | trainval_pids = sorted(pids[:num // 2])
87 | test_pids = sorted(pids[num // 2:])
88 | split = {'trainval': trainval_pids,
89 | 'query': test_pids,
90 | 'gallery': test_pids}
91 | splits.append(split)
92 | write_json(splits, osp.join(self.root, 'splits.json'))
93 |
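94 | # Usage sketch (paths illustrative): instantiating the class downloads the
95 | # zip, converts it into the unified format described in
96 | # docs/notes/data_modules.rst, and loads the requested split.
97 | #
98 | #   dataset = VIPeR('examples/data/viper', split_id=0, num_val=100)
99 | #   dataset.train[0]   # -> (filename, relabeled person_id, camera_id)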
--------------------------------------------------------------------------------
/reid/utils/data/dataset.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 | import os.path as osp
3 |
4 | import numpy as np
5 |
6 | from ..serialization import read_json
7 |
8 |
9 | def _pluck(identities, indices, relabel=False):
10 | ret = []
11 | for index, pid in enumerate(indices):
12 | pid_images = identities[pid]
13 | for camid, cam_images in enumerate(pid_images):
14 | for fname in cam_images:
15 | name = osp.splitext(fname)[0]
16 | x, y, _ = map(int, name.split('_'))
17 | assert pid == x and camid == y
18 | if relabel:
19 | ret.append((fname, index, camid))
20 | else:
21 | ret.append((fname, pid, camid))
22 | return ret
23 |
24 |
25 | class Dataset(object):
26 | def __init__(self, root, split_id=0):
27 | self.root = root
28 | self.split_id = split_id
29 | self.meta = None
30 | self.split = None
31 | self.train, self.val, self.trainval = [], [], []
32 | self.query, self.gallery = [], []
33 | self.num_train_ids, self.num_val_ids, self.num_trainval_ids = 0, 0, 0
34 |
35 | @property
36 | def images_dir(self):
37 | return osp.join(self.root, 'images')
38 |
39 | def load(self, num_val=0.3, verbose=True):
40 | splits = read_json(osp.join(self.root, 'splits.json'))
41 | if self.split_id >= len(splits):
42 | raise ValueError("split_id exceeds total splits {}"
43 | .format(len(splits)))
44 | self.split = splits[self.split_id]
45 |
46 | # Randomly split train / val
47 | trainval_pids = np.asarray(self.split['trainval'])
48 | np.random.shuffle(trainval_pids)
49 | num = len(trainval_pids)
50 | if isinstance(num_val, float):
51 | num_val = int(round(num * num_val))
52 |         if num_val >= num or num_val < 0:
53 |             raise ValueError("num_val {} out of range for {} trainval ids"
54 |                              .format(num_val, num))
55 |         train_pids = sorted(trainval_pids[:num - num_val])  # avoids empty [:-0]
56 |         val_pids = sorted(trainval_pids[num - num_val:])
57 |
58 | self.meta = read_json(osp.join(self.root, 'meta.json'))
59 | identities = self.meta['identities']
60 | self.train = _pluck(identities, train_pids, relabel=True)
61 | self.val = _pluck(identities, val_pids, relabel=True)
62 | self.trainval = _pluck(identities, trainval_pids, relabel=True)
63 | self.query = _pluck(identities, self.split['query'])
64 | self.gallery = _pluck(identities, self.split['gallery'])
65 | self.num_train_ids = len(train_pids)
66 | self.num_val_ids = len(val_pids)
67 | self.num_trainval_ids = len(trainval_pids)
68 |
69 | if verbose:
70 | print(self.__class__.__name__, "dataset loaded")
71 | print(" subset | # ids | # images")
72 | print(" ---------------------------")
73 | print(" train | {:5d} | {:8d}"
74 | .format(self.num_train_ids, len(self.train)))
75 | print(" val | {:5d} | {:8d}"
76 | .format(self.num_val_ids, len(self.val)))
77 | print(" trainval | {:5d} | {:8d}"
78 | .format(self.num_trainval_ids, len(self.trainval)))
79 | print(" query | {:5d} | {:8d}"
80 | .format(len(self.split['query']), len(self.query)))
81 | print(" gallery | {:5d} | {:8d}"
82 | .format(len(self.split['gallery']), len(self.gallery)))
83 |
84 | def _check_integrity(self):
85 | return osp.isdir(osp.join(self.root, 'images')) and \
86 | osp.isfile(osp.join(self.root, 'meta.json')) and \
87 | osp.isfile(osp.join(self.root, 'splits.json'))
88 |
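89 | # Note on `relabel`: train / val / trainval replace the global person_id with a
90 | # zero-based index into the given id list, so a classifier head can be sized by
91 | # num_train_ids; query / gallery keep the original ids for evaluation. Sketch:
92 | #
93 | #   identities = [[['00000000_00_0000.jpg'], []],
94 | #                 [[], ['00000001_01_0000.jpg']]]
95 | #   _pluck(identities, [1], relabel=True)
96 | #   # -> [('00000001_01_0000.jpg', 0, 1)], i.e. pid 1 relabeled to 0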
--------------------------------------------------------------------------------
/reid/datasets/market1501.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 |
4 | from ..utils.data import Dataset
5 | from ..utils.osutils import mkdir_if_missing
6 | from ..utils.serialization import write_json
7 |
8 |
9 | class Market1501(Dataset):
10 | url = 'https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view'
11 | md5 = '65005ab7d12ec1c44de4eeafe813e68a'
12 |
13 | def __init__(self, root, split_id=0, num_val=100, download=True):
14 | super(Market1501, self).__init__(root, split_id=split_id)
15 |
16 | if download:
17 | self.download()
18 |
19 | if not self._check_integrity():
20 | raise RuntimeError("Dataset not found or corrupted. " +
21 | "You can use download=True to download it.")
22 |
23 | self.load(num_val)
24 |
25 | def download(self):
26 | if self._check_integrity():
27 | print("Files already downloaded and verified")
28 | return
29 |
30 | import re
31 | import hashlib
32 | import shutil
33 | from glob import glob
34 | from zipfile import ZipFile
35 |
36 | raw_dir = osp.join(self.root, 'raw')
37 | mkdir_if_missing(raw_dir)
38 |
39 | # Download the raw zip file
40 | fpath = osp.join(raw_dir, 'Market-1501-v15.09.15.zip')
41 | if osp.isfile(fpath) and \
42 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5:
43 | print("Using downloaded file: " + fpath)
44 | else:
45 | raise RuntimeError("Please download the dataset manually from {} "
46 | "to {}".format(self.url, fpath))
47 |
48 | # Extract the file
49 | exdir = osp.join(raw_dir, 'Market-1501-v15.09.15')
50 | if not osp.isdir(exdir):
51 | print("Extracting zip file")
52 | with ZipFile(fpath) as z:
53 | z.extractall(path=raw_dir)
54 |
55 | # Format
56 | images_dir = osp.join(self.root, 'images')
57 | mkdir_if_missing(images_dir)
58 |
59 | # 1501 identities (+1 for background) with 6 camera views each
60 | identities = [[[] for _ in range(6)] for _ in range(1502)]
61 |
62 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')):
63 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg')))
64 | pids = set()
65 | for fpath in fpaths:
66 | fname = osp.basename(fpath)
67 | pid, cam = map(int, pattern.search(fname).groups())
68 | if pid == -1: continue # junk images are just ignored
69 | assert 0 <= pid <= 1501 # pid == 0 means background
70 | assert 1 <= cam <= 6
71 | cam -= 1
72 | pids.add(pid)
73 | fname = ('{:08d}_{:02d}_{:04d}.jpg'
74 | .format(pid, cam, len(identities[pid][cam])))
75 | identities[pid][cam].append(fname)
76 | shutil.copy(fpath, osp.join(images_dir, fname))
77 | return pids
78 |
79 | trainval_pids = register('bounding_box_train')
80 | gallery_pids = register('bounding_box_test')
81 | query_pids = register('query')
82 | assert query_pids <= gallery_pids
83 | assert trainval_pids.isdisjoint(gallery_pids)
84 |
85 | # Save meta information into a json file
86 | meta = {'name': 'Market1501', 'shot': 'multiple', 'num_cameras': 6,
87 | 'identities': identities}
88 | write_json(meta, osp.join(self.root, 'meta.json'))
89 |
90 | # Save the only training / test split
91 | splits = [{
92 | 'trainval': sorted(list(trainval_pids)),
93 | 'query': sorted(list(query_pids)),
94 | 'gallery': sorted(list(gallery_pids))}]
95 | write_json(splits, osp.join(self.root, 'splits.json'))
96 |
--------------------------------------------------------------------------------
/reid/datasets/dukemtmc.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 |
4 | from ..utils.data import Dataset
5 | from ..utils.osutils import mkdir_if_missing
6 | from ..utils.serialization import write_json
7 |
8 |
9 | class DukeMTMC(Dataset):
10 | url = 'https://drive.google.com/uc?id=0B0VOCNYh8HeRdnBPa2ZWaVBYSVk'
11 | md5 = '2f93496f9b516d1ee5ef51c1d5e7d601'
12 |
13 | def __init__(self, root, split_id=0, num_val=100, download=True):
14 | super(DukeMTMC, self).__init__(root, split_id=split_id)
15 |
16 | if download:
17 | self.download()
18 |
19 | if not self._check_integrity():
20 | raise RuntimeError("Dataset not found or corrupted. " +
21 | "You can use download=True to download it.")
22 |
23 | self.load(num_val)
24 |
25 | def download(self):
26 | if self._check_integrity():
27 | print("Files already downloaded and verified")
28 | return
29 |
30 | import re
31 | import hashlib
32 | import shutil
33 | from glob import glob
34 | from zipfile import ZipFile
35 |
36 | raw_dir = osp.join(self.root, 'raw')
37 | mkdir_if_missing(raw_dir)
38 |
39 | # Download the raw zip file
40 | fpath = osp.join(raw_dir, 'DukeMTMC-reID.zip')
41 | if osp.isfile(fpath) and \
42 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5:
43 | print("Using downloaded file: " + fpath)
44 | else:
45 | raise RuntimeError("Please download the dataset manually from {} "
46 | "to {}".format(self.url, fpath))
47 |
48 | # Extract the file
49 | exdir = osp.join(raw_dir, 'DukeMTMC-reID')
50 | if not osp.isdir(exdir):
51 | print("Extracting zip file")
52 | with ZipFile(fpath) as z:
53 | z.extractall(path=raw_dir)
54 |
55 | # Format
56 | images_dir = osp.join(self.root, 'images')
57 | mkdir_if_missing(images_dir)
58 |
59 | identities = []
60 | all_pids = {}
61 |
62 | def register(subdir, pattern=re.compile(r'([-\d]+)_c(\d)')):
63 | fpaths = sorted(glob(osp.join(exdir, subdir, '*.jpg')))
64 | pids = set()
65 | for fpath in fpaths:
66 | fname = osp.basename(fpath)
67 | pid, cam = map(int, pattern.search(fname).groups())
68 | assert 1 <= cam <= 8
69 | cam -= 1
70 | if pid not in all_pids:
71 | all_pids[pid] = len(all_pids)
72 | pid = all_pids[pid]
73 | pids.add(pid)
74 | if pid >= len(identities):
75 | assert pid == len(identities)
76 | identities.append([[] for _ in range(8)]) # 8 camera views
77 | fname = ('{:08d}_{:02d}_{:04d}.jpg'
78 | .format(pid, cam, len(identities[pid][cam])))
79 | identities[pid][cam].append(fname)
80 | shutil.copy(fpath, osp.join(images_dir, fname))
81 | return pids
82 |
83 | trainval_pids = register('bounding_box_train')
84 | gallery_pids = register('bounding_box_test')
85 | query_pids = register('query')
86 | assert query_pids <= gallery_pids
87 | assert trainval_pids.isdisjoint(gallery_pids)
88 |
89 | # Save meta information into a json file
90 | meta = {'name': 'DukeMTMC', 'shot': 'multiple', 'num_cameras': 8,
91 | 'identities': identities}
92 | write_json(meta, osp.join(self.root, 'meta.json'))
93 |
94 | # Save the only training / test split
95 | splits = [{
96 | 'trainval': sorted(list(trainval_pids)),
97 | 'query': sorted(list(query_pids)),
98 | 'gallery': sorted(list(gallery_pids))}]
99 | write_json(splits, osp.join(self.root, 'splits.json'))
100 |
--------------------------------------------------------------------------------
/reid/models/resnet.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | from torch import nn
4 | from torch.nn import functional as F
5 | from torch.nn import init
6 | import torchvision
7 |
8 |
9 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
10 | 'resnet152']
11 |
12 |
13 | class ResNet(nn.Module):
14 | __factory = {
15 | 18: torchvision.models.resnet18,
16 | 34: torchvision.models.resnet34,
17 | 50: torchvision.models.resnet50,
18 | 101: torchvision.models.resnet101,
19 | 152: torchvision.models.resnet152,
20 | }
21 |
22 | def __init__(self, depth, pretrained=True, cut_at_pooling=False,
23 | num_features=0, norm=False, dropout=0, num_classes=0):
24 | super(ResNet, self).__init__()
25 |
26 | self.depth = depth
27 | self.pretrained = pretrained
28 | self.cut_at_pooling = cut_at_pooling
29 |
30 | # Construct base (pretrained) resnet
31 | if depth not in ResNet.__factory:
32 |             raise KeyError("Unsupported depth: {}".format(depth))
33 | self.base = ResNet.__factory[depth](pretrained=pretrained)
34 |
35 | if not self.cut_at_pooling:
36 | self.num_features = num_features
37 | self.norm = norm
38 | self.dropout = dropout
39 | self.has_embedding = num_features > 0
40 | self.num_classes = num_classes
41 |
42 | out_planes = self.base.fc.in_features
43 |
44 | # Append new layers
45 | if self.has_embedding:
46 | self.feat = nn.Linear(out_planes, self.num_features)
47 | self.feat_bn = nn.BatchNorm1d(self.num_features)
48 | init.kaiming_normal(self.feat.weight, mode='fan_out')
49 | init.constant(self.feat.bias, 0)
50 | init.constant(self.feat_bn.weight, 1)
51 | init.constant(self.feat_bn.bias, 0)
52 | else:
53 | # Change the num_features to CNN output channels
54 | self.num_features = out_planes
55 | if self.dropout > 0:
56 | self.drop = nn.Dropout(self.dropout)
57 | if self.num_classes > 0:
58 | self.classifier = nn.Linear(self.num_features, self.num_classes)
59 | init.normal(self.classifier.weight, std=0.001)
60 | init.constant(self.classifier.bias, 0)
61 |
62 | if not self.pretrained:
63 | self.reset_params()
64 |
65 | def forward(self, x):
66 | for name, module in self.base._modules.items():
67 | if name == 'avgpool':
68 | break
69 | x = module(x)
70 |
71 | if self.cut_at_pooling:
72 | return x
73 |
74 | x = F.avg_pool2d(x, x.size()[2:])
75 | x = x.view(x.size(0), -1)
76 |
77 | if self.has_embedding:
78 | x = self.feat(x)
79 | x = self.feat_bn(x)
80 | if self.norm:
81 | x = F.normalize(x)
82 | elif self.has_embedding:
83 | x = F.relu(x)
84 | if self.dropout > 0:
85 | x = self.drop(x)
86 | if self.num_classes > 0:
87 | x = self.classifier(x)
88 | return x
89 |
90 | def reset_params(self):
91 | for m in self.modules():
92 | if isinstance(m, nn.Conv2d):
93 | init.kaiming_normal(m.weight, mode='fan_out')
94 | if m.bias is not None:
95 | init.constant(m.bias, 0)
96 | elif isinstance(m, nn.BatchNorm2d):
97 | init.constant(m.weight, 1)
98 | init.constant(m.bias, 0)
99 | elif isinstance(m, nn.Linear):
100 | init.normal(m.weight, std=0.001)
101 | if m.bias is not None:
102 | init.constant(m.bias, 0)
103 |
104 |
105 | def resnet18(**kwargs):
106 | return ResNet(18, **kwargs)
107 |
108 |
109 | def resnet34(**kwargs):
110 | return ResNet(34, **kwargs)
111 |
112 |
113 | def resnet50(**kwargs):
114 | return ResNet(50, **kwargs)
115 |
116 |
117 | def resnet101(**kwargs):
118 | return ResNet(101, **kwargs)
119 |
120 |
121 | def resnet152(**kwargs):
122 | return ResNet(152, **kwargs)
123 |
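124 | # Construction sketch (numbers illustrative): a pretrained ResNet-50 trunk with
125 | # a 128-d embedding, dropout, and a classifier head sized to the training ids.
126 | #
127 | #   model = resnet50(num_features=128, dropout=0.5, num_classes=751)
128 | #   # For feature extraction only, omit the classifier and L2-normalize:
129 | #   extractor = resnet50(num_features=128, norm=True, num_classes=0)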
--------------------------------------------------------------------------------
/reid/datasets/cuhk03.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import os.path as osp
3 |
4 | import numpy as np
5 |
6 | from ..utils.data import Dataset
7 | from ..utils.osutils import mkdir_if_missing
8 | from ..utils.serialization import write_json
9 |
10 |
11 | class CUHK03(Dataset):
12 | url = 'https://docs.google.com/spreadsheet/viewform?usp=drive_web&formkey=dHRkMkFVSUFvbTJIRkRDLWRwZWpONnc6MA#gid=0'
13 | md5 = '728939e58ad9f0ff53e521857dd8fb43'
14 |
15 | def __init__(self, root, split_id=0, num_val=100, download=True):
16 | super(CUHK03, self).__init__(root, split_id=split_id)
17 |
18 | if download:
19 | self.download()
20 |
21 | if not self._check_integrity():
22 | raise RuntimeError("Dataset not found or corrupted. " +
23 | "You can use download=True to download it.")
24 |
25 | self.load(num_val)
26 |
27 | def download(self):
28 | if self._check_integrity():
29 | print("Files already downloaded and verified")
30 | return
31 |
32 | import h5py
33 | import hashlib
34 | from scipy.misc import imsave
35 | from zipfile import ZipFile
36 |
37 | raw_dir = osp.join(self.root, 'raw')
38 | mkdir_if_missing(raw_dir)
39 |
40 | # Download the raw zip file
41 | fpath = osp.join(raw_dir, 'cuhk03_release.zip')
42 | if osp.isfile(fpath) and \
43 | hashlib.md5(open(fpath, 'rb').read()).hexdigest() == self.md5:
44 | print("Using downloaded file: " + fpath)
45 | else:
46 | raise RuntimeError("Please download the dataset manually from {} "
47 | "to {}".format(self.url, fpath))
48 |
49 | # Extract the file
50 | exdir = osp.join(raw_dir, 'cuhk03_release')
51 | if not osp.isdir(exdir):
52 | print("Extracting zip file")
53 | with ZipFile(fpath) as z:
54 | z.extractall(path=raw_dir)
55 |
56 | # Format
57 | images_dir = osp.join(self.root, 'images')
58 | mkdir_if_missing(images_dir)
59 | matdata = h5py.File(osp.join(exdir, 'cuhk-03.mat'), 'r')
60 |
61 | def deref(ref):
62 | return matdata[ref][:].T
63 |
64 | def dump_(refs, pid, cam, fnames):
65 | for ref in refs:
66 | img = deref(ref)
67 |                 if img.size == 0 or img.ndim < 2: break  # trailing empty cells mark the end
68 | fname = '{:08d}_{:02d}_{:04d}.jpg'.format(pid, cam, len(fnames))
69 | imsave(osp.join(images_dir, fname), img)
70 | fnames.append(fname)
71 |
72 | identities = []
73 | for labeled, detected in zip(
74 | matdata['labeled'][0], matdata['detected'][0]):
75 | labeled, detected = deref(labeled), deref(detected)
76 | assert labeled.shape == detected.shape
77 | for i in range(labeled.shape[0]):
78 | pid = len(identities)
79 | images = [[], []]
80 | dump_(labeled[i, :5], pid, 0, images[0])
81 | dump_(detected[i, :5], pid, 0, images[0])
82 | dump_(labeled[i, 5:], pid, 1, images[1])
83 | dump_(detected[i, 5:], pid, 1, images[1])
84 | identities.append(images)
85 |
86 | # Save meta information into a json file
87 | meta = {'name': 'cuhk03', 'shot': 'multiple', 'num_cameras': 2,
88 | 'identities': identities}
89 | write_json(meta, osp.join(self.root, 'meta.json'))
90 |
91 | # Save training and test splits
92 | splits = []
93 | view_counts = [deref(ref).shape[0] for ref in matdata['labeled'][0]]
94 | vid_offsets = np.r_[0, np.cumsum(view_counts)]
95 | for ref in matdata['testsets'][0]:
96 | test_info = deref(ref).astype(np.int32)
97 | test_pids = sorted(
98 | [int(vid_offsets[i-1] + j - 1) for i, j in test_info])
99 | trainval_pids = list(set(range(vid_offsets[-1])) - set(test_pids))
100 | split = {'trainval': trainval_pids,
101 | 'query': test_pids,
102 | 'gallery': test_pids}
103 | splits.append(split)
104 | write_json(splits, osp.join(self.root, 'splits.json'))
105 |
--------------------------------------------------------------------------------
/reid/evaluation_metrics/ranking.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 | from collections import defaultdict
3 |
4 | import numpy as np
5 | from sklearn.metrics import average_precision_score
6 |
7 | from ..utils import to_numpy
8 |
9 |
10 | def _unique_sample(ids_dict, num):
11 | mask = np.zeros(num, dtype=np.bool)
12 | for _, indices in ids_dict.items():
13 | i = np.random.choice(indices)
14 | mask[i] = True
15 | return mask
16 |
17 |
18 | def cmc(distmat, query_ids=None, gallery_ids=None,
19 | query_cams=None, gallery_cams=None, topk=100,
20 | separate_camera_set=False,
21 | single_gallery_shot=False,
22 | first_match_break=False):
23 | distmat = to_numpy(distmat)
24 | m, n = distmat.shape
25 | # Fill up default values
26 | if query_ids is None:
27 | query_ids = np.arange(m)
28 | if gallery_ids is None:
29 | gallery_ids = np.arange(n)
30 | if query_cams is None:
31 | query_cams = np.zeros(m).astype(np.int32)
32 | if gallery_cams is None:
33 | gallery_cams = np.ones(n).astype(np.int32)
34 | # Ensure numpy array
35 | query_ids = np.asarray(query_ids)
36 | gallery_ids = np.asarray(gallery_ids)
37 | query_cams = np.asarray(query_cams)
38 | gallery_cams = np.asarray(gallery_cams)
39 | # Sort and find correct matches
40 | indices = np.argsort(distmat, axis=1)
41 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
42 | # Compute CMC for each query
43 | ret = np.zeros(topk)
44 | num_valid_queries = 0
45 | for i in range(m):
46 | # Filter out the same id and same camera
47 | valid = ((gallery_ids[indices[i]] != query_ids[i]) |
48 | (gallery_cams[indices[i]] != query_cams[i]))
49 | if separate_camera_set:
50 | # Filter out samples from same camera
51 | valid &= (gallery_cams[indices[i]] != query_cams[i])
52 | if not np.any(matches[i, valid]): continue
53 | if single_gallery_shot:
54 |             repeat = 10  # average over 10 random single-shot samplings
55 | gids = gallery_ids[indices[i][valid]]
56 | inds = np.where(valid)[0]
57 | ids_dict = defaultdict(list)
58 | for j, x in zip(inds, gids):
59 | ids_dict[x].append(j)
60 | else:
61 | repeat = 1
62 | for _ in range(repeat):
63 | if single_gallery_shot:
64 | # Randomly choose one instance for each id
65 | sampled = (valid & _unique_sample(ids_dict, len(valid)))
66 | index = np.nonzero(matches[i, sampled])[0]
67 | else:
68 | index = np.nonzero(matches[i, valid])[0]
69 | delta = 1. / (len(index) * repeat)
70 | for j, k in enumerate(index):
71 | if k - j >= topk: break
72 | if first_match_break:
73 | ret[k - j] += 1
74 | break
75 | ret[k - j] += delta
76 | num_valid_queries += 1
77 | if num_valid_queries == 0:
78 | raise RuntimeError("No valid query")
79 | return ret.cumsum() / num_valid_queries
80 |
81 |
82 | def mean_ap(distmat, query_ids=None, gallery_ids=None,
83 | query_cams=None, gallery_cams=None):
84 | distmat = to_numpy(distmat)
85 | m, n = distmat.shape
86 | # Fill up default values
87 | if query_ids is None:
88 | query_ids = np.arange(m)
89 | if gallery_ids is None:
90 | gallery_ids = np.arange(n)
91 | if query_cams is None:
92 | query_cams = np.zeros(m).astype(np.int32)
93 | if gallery_cams is None:
94 | gallery_cams = np.ones(n).astype(np.int32)
95 | # Ensure numpy array
96 | query_ids = np.asarray(query_ids)
97 | gallery_ids = np.asarray(gallery_ids)
98 | query_cams = np.asarray(query_cams)
99 | gallery_cams = np.asarray(gallery_cams)
100 | # Sort and find correct matches
101 | indices = np.argsort(distmat, axis=1)
102 | matches = (gallery_ids[indices] == query_ids[:, np.newaxis])
103 | # Compute AP for each query
104 | aps = []
105 | for i in range(m):
106 | # Filter out the same id and same camera
107 | valid = ((gallery_ids[indices[i]] != query_ids[i]) |
108 | (gallery_cams[indices[i]] != query_cams[i]))
109 | y_true = matches[i, valid]
110 | y_score = -distmat[i][indices[i]][valid]
111 | if not np.any(y_true): continue
112 | aps.append(average_precision_score(y_true, y_score))
113 | if len(aps) == 0:
114 | raise RuntimeError("No valid query")
115 | return np.mean(aps)
116 |
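117 | # Worked micro-example (values illustrative): two queries against a three-image
118 | # gallery. Query cameras default to zeros and gallery cameras to ones, so no
119 | # pair is filtered out by the same-id-same-camera rule.
120 | #
121 | #   distmat = np.array([[0.1, 0.9, 0.8],
122 | #                       [0.7, 0.2, 0.6]])
123 | #   cmc(distmat, query_ids=[0, 1], gallery_ids=[0, 1, 2], topk=3)
124 | #   # -> array([1., 1., 1.]); both queries rank their true match first
125 | #   mean_ap(distmat, query_ids=[0, 1], gallery_ids=[0, 1, 2])
126 | #   # -> 1.0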
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *~
2 |
3 | # temporary files which can be created if a process still has a handle open of a deleted file
4 | .fuse_hidden*
5 |
6 | # KDE directory preferences
7 | .directory
8 |
9 | # Linux trash folder which might appear on any partition or disk
10 | .Trash-*
11 |
12 | # .nfs files are created when an open file is removed but is still being accessed
13 | .nfs*
14 |
15 |
16 | *.DS_Store
17 | .AppleDouble
18 | .LSOverride
19 |
20 | # Icon must end with two \r
21 | Icon
22 |
23 |
24 | # Thumbnails
25 | ._*
26 |
27 | # Files that might appear in the root of a volume
28 | .DocumentRevisions-V100
29 | .fseventsd
30 | .Spotlight-V100
31 | .TemporaryItems
32 | .Trashes
33 | .VolumeIcon.icns
34 | .com.apple.timemachine.donotpresent
35 |
36 | # Directories potentially created on remote AFP share
37 | .AppleDB
38 | .AppleDesktop
39 | Network Trash Folder
40 | Temporary Items
41 | .apdisk
42 |
43 |
44 | # swap
45 | [._]*.s[a-v][a-z]
46 | [._]*.sw[a-p]
47 | [._]s[a-v][a-z]
48 | [._]sw[a-p]
49 | # session
50 | Session.vim
51 | # temporary
52 | .netrwhist
53 | *~
54 | # auto-generated tag files
55 | tags
56 |
57 |
58 | # cache files for sublime text
59 | *.tmlanguage.cache
60 | *.tmPreferences.cache
61 | *.stTheme.cache
62 |
63 | # workspace files are user-specific
64 | *.sublime-workspace
65 |
66 | # project files should be checked into the repository, unless a significant
67 | # proportion of contributors will probably not be using SublimeText
68 | # *.sublime-project
69 |
70 | # sftp configuration file
71 | sftp-config.json
72 |
73 | # Package control specific files
74 | Package Control.last-run
75 | Package Control.ca-list
76 | Package Control.ca-bundle
77 | Package Control.system-ca-bundle
78 | Package Control.cache/
79 | Package Control.ca-certs/
80 | Package Control.merged-ca-bundle
81 | Package Control.user-ca-bundle
82 | oscrypto-ca-bundle.crt
83 | bh_unicode_properties.cache
84 |
85 | # Sublime-github package stores a github token in this file
86 | # https://packagecontrol.io/packages/sublime-github
87 | GitHub.sublime-settings
88 |
89 |
90 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and Webstorm
91 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
92 |
93 | # User-specific stuff:
94 | .idea
95 | .idea/**/workspace.xml
96 | .idea/**/tasks.xml
97 |
98 | # Sensitive or high-churn files:
99 | .idea/**/dataSources/
100 | .idea/**/dataSources.ids
101 | .idea/**/dataSources.xml
102 | .idea/**/dataSources.local.xml
103 | .idea/**/sqlDataSources.xml
104 | .idea/**/dynamic.xml
105 | .idea/**/uiDesigner.xml
106 |
107 | # Gradle:
108 | .idea/**/gradle.xml
109 | .idea/**/libraries
110 |
111 | # Mongo Explorer plugin:
112 | .idea/**/mongoSettings.xml
113 |
114 | ## File-based project format:
115 | *.iws
116 |
117 | ## Plugin-specific files:
118 |
119 | # IntelliJ
120 | /out/
121 |
122 | # mpeltonen/sbt-idea plugin
123 | .idea_modules/
124 |
125 | # JIRA plugin
126 | atlassian-ide-plugin.xml
127 |
128 | # Crashlytics plugin (for Android Studio and IntelliJ)
129 | com_crashlytics_export_strings.xml
130 | crashlytics.properties
131 | crashlytics-build.properties
132 | fabric.properties
133 |
134 |
135 | # Byte-compiled / optimized / DLL files
136 | __pycache__/
137 | *.py[cod]
138 | *$py.class
139 |
140 | # C extensions
141 | *.so
142 |
143 | # Distribution / packaging
144 | .Python
145 | env/
146 | build/
147 | develop-eggs/
148 | dist/
149 | downloads/
150 | eggs/
151 | .eggs/
152 | lib/
153 | lib64/
154 | parts/
155 | sdist/
156 | var/
157 | *.egg-info/
158 | .installed.cfg
159 | *.egg
160 |
161 | # PyInstaller
162 | # Usually these files are written by a python script from a template
163 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
164 | *.manifest
165 | *.spec
166 |
167 | # Installer logs
168 | pip-log.txt
169 | pip-delete-this-directory.txt
170 |
171 | # Unit test / coverage reports
172 | htmlcov/
173 | .tox/
174 | .coverage
175 | .coverage.*
176 | .cache
177 | nosetests.xml
178 | coverage.xml
179 | *,cover
180 | .hypothesis/
181 |
182 | # Translations
183 | *.mo
184 | *.pot
185 |
186 | # Django stuff:
187 | *.log
188 | local_settings.py
189 |
190 | # Flask stuff:
191 | instance/
192 | .webassets-cache
193 |
194 | # Scrapy stuff:
195 | .scrapy
196 |
197 | # Sphinx documentation
198 | docs/_build/
199 |
200 | # PyBuilder
201 | target/
202 |
203 | # IPython Notebook
204 | .ipynb_checkpoints
205 |
206 | # pyenv
207 | .python-version
208 |
209 | # celery beat schedule file
210 | celerybeat-schedule
211 |
212 | # dotenv
213 | .env
214 |
215 | # virtualenv
216 | venv/
217 | ENV/
218 |
219 | # Spyder project settings
220 | .spyderproject
221 |
222 | # Rope project settings
223 | .ropeproject
224 |
225 |
226 | # Project specific
227 | examples/data
228 | examples/logs
229 |
230 |
--------------------------------------------------------------------------------
/reid/evaluators.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import time
3 | from collections import OrderedDict
4 |
5 | import torch
6 |
7 | from .evaluation_metrics import cmc, mean_ap
8 | from .feature_extraction import extract_cnn_feature
9 | from .utils.meters import AverageMeter
10 |
11 |
12 | def extract_features(model, data_loader, print_freq=1, metric=None):
13 | model.eval()
14 | batch_time = AverageMeter()
15 | data_time = AverageMeter()
16 |
17 | features = OrderedDict()
18 | labels = OrderedDict()
19 |
20 | end = time.time()
21 | for i, (imgs, fnames, pids, _) in enumerate(data_loader):
22 | data_time.update(time.time() - end)
23 |
24 | outputs = extract_cnn_feature(model, imgs)
25 | for fname, output, pid in zip(fnames, outputs, pids):
26 | features[fname] = output
27 | labels[fname] = pid
28 |
29 | batch_time.update(time.time() - end)
30 | end = time.time()
31 |
32 | if (i + 1) % print_freq == 0:
33 | print('Extract Features: [{}/{}]\t'
34 | 'Time {:.3f} ({:.3f})\t'
35 | 'Data {:.3f} ({:.3f})\t'
36 | .format(i + 1, len(data_loader),
37 | batch_time.val, batch_time.avg,
38 | data_time.val, data_time.avg))
39 |
40 | return features, labels
41 |
42 |
43 | def pairwise_distance(features, query=None, gallery=None, metric=None):
44 | if query is None and gallery is None:
45 | n = len(features)
46 | x = torch.cat(list(features.values()))
47 | x = x.view(n, -1)
48 | if metric is not None:
49 | x = metric.transform(x)
50 |         dist = torch.pow(x, 2).sum(dim=1, keepdim=True)
51 |         dist = dist.expand(n, n) + dist.expand(n, n).t() - 2 * torch.mm(x, x.t())  # ||xi||^2 + ||xj||^2 - 2*xi.xj
52 | return dist
53 |
54 | x = torch.cat([features[f].unsqueeze(0) for f, _, _ in query], 0)
55 | y = torch.cat([features[f].unsqueeze(0) for f, _, _ in gallery], 0)
56 | m, n = x.size(0), y.size(0)
57 | x = x.view(m, -1)
58 | y = y.view(n, -1)
59 | if metric is not None:
60 | x = metric.transform(x)
61 | y = metric.transform(y)
62 | dist = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n) + \
63 | torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
64 | dist.addmm_(1, -2, x, y.t())
65 | return dist
66 |
67 |
68 | def evaluate_all(distmat, query=None, gallery=None,
69 | query_ids=None, gallery_ids=None,
70 | query_cams=None, gallery_cams=None,
71 | cmc_topk=(1, 5, 10)):
72 | if query is not None and gallery is not None:
73 | query_ids = [pid for _, pid, _ in query]
74 | gallery_ids = [pid for _, pid, _ in gallery]
75 | query_cams = [cam for _, _, cam in query]
76 | gallery_cams = [cam for _, _, cam in gallery]
77 | else:
78 | assert (query_ids is not None and gallery_ids is not None
79 | and query_cams is not None and gallery_cams is not None)
80 |
81 | # Compute mean AP
82 | mAP = mean_ap(distmat, query_ids, gallery_ids, query_cams, gallery_cams)
83 | print('Mean AP: {:4.1%}'.format(mAP))
84 |
85 | # Compute all kinds of CMC scores
86 | cmc_configs = {
87 | 'allshots': dict(separate_camera_set=False,
88 | single_gallery_shot=False,
89 | first_match_break=False),
90 | 'cuhk03': dict(separate_camera_set=True,
91 | single_gallery_shot=True,
92 | first_match_break=False),
93 | 'market1501': dict(separate_camera_set=False,
94 | single_gallery_shot=False,
95 | first_match_break=True)}
96 | cmc_scores = {name: cmc(distmat, query_ids, gallery_ids,
97 | query_cams, gallery_cams, **params)
98 | for name, params in cmc_configs.items()}
99 |
100 | print('CMC Scores{:>12}{:>12}{:>12}'
101 | .format('allshots', 'cuhk03', 'market1501'))
102 | for k in cmc_topk:
103 | print(' top-{:<4}{:12.1%}{:12.1%}{:12.1%}'
104 | .format(k, cmc_scores['allshots'][k - 1],
105 | cmc_scores['cuhk03'][k - 1],
106 | cmc_scores['market1501'][k - 1]))
107 |
108 | # Use the allshots cmc top-1 score for validation criterion
109 | return cmc_scores['allshots'][0]
110 |
111 |
112 | class Evaluator(object):
113 | def __init__(self, model):
114 | super(Evaluator, self).__init__()
115 | self.model = model
116 |
117 | def evaluate(self, data_loader, query, gallery, metric=None):
118 | features, _ = extract_features(self.model, data_loader)
119 | distmat = pairwise_distance(features, query, gallery, metric=metric)
120 | return evaluate_all(distmat, query=query, gallery=gallery)
121 |
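122 | # Usage sketch (names illustrative): `test_loader` should iterate over the
123 | # union of query and gallery items, so each feature is extracted exactly once.
124 | #
125 | #   evaluator = Evaluator(model)
126 | #   top1 = evaluator.evaluate(test_loader, dataset.query, dataset.gallery)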
--------------------------------------------------------------------------------
/reid/models/inception.py:
--------------------------------------------------------------------------------
1 | from __future__ import absolute_import
2 |
3 | import torch
4 | from torch import nn
5 | from torch.nn import functional as F
6 | from torch.nn import init
7 |
8 |
9 | __all__ = ['InceptionNet', 'inception']
10 |
11 |
12 | def _make_conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1,
13 | bias=False):
14 | conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size,
15 | stride=stride, padding=padding, bias=bias)
16 | bn = nn.BatchNorm2d(out_planes)
17 | relu = nn.ReLU(inplace=True)
18 | return nn.Sequential(conv, bn, relu)
19 |
20 |
21 | class Block(nn.Module):
22 | def __init__(self, in_planes, out_planes, pool_method, stride):
23 | super(Block, self).__init__()
24 | self.branches = nn.ModuleList([
25 | nn.Sequential(
26 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0),
27 | _make_conv(out_planes, out_planes, stride=stride)
28 | ),
29 | nn.Sequential(
30 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0),
31 | _make_conv(out_planes, out_planes),
32 | _make_conv(out_planes, out_planes, stride=stride))
33 | ])
34 |
35 | if pool_method == 'Avg':
36 | assert stride == 1
37 | self.branches.append(
38 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0))
39 | self.branches.append(nn.Sequential(
40 | nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
41 | _make_conv(in_planes, out_planes, kernel_size=1, padding=0)))
42 | else:
43 | self.branches.append(
44 | nn.MaxPool2d(kernel_size=3, stride=stride, padding=1))
45 |
46 | def forward(self, x):
47 | return torch.cat([b(x) for b in self.branches], 1)
48 |
49 |
50 | class InceptionNet(nn.Module):
51 | def __init__(self, cut_at_pooling=False, num_features=256, norm=False,
52 | dropout=0, num_classes=0):
53 | super(InceptionNet, self).__init__()
54 | self.cut_at_pooling = cut_at_pooling
55 |
56 | self.conv1 = _make_conv(3, 32)
57 | self.conv2 = _make_conv(32, 32)
58 | self.conv3 = _make_conv(32, 32)
59 | self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2, padding=1)
60 | self.in_planes = 32
61 | self.inception4a = self._make_inception(64, 'Avg', 1)
62 | self.inception4b = self._make_inception(64, 'Max', 2)
63 | self.inception5a = self._make_inception(128, 'Avg', 1)
64 | self.inception5b = self._make_inception(128, 'Max', 2)
65 | self.inception6a = self._make_inception(256, 'Avg', 1)
66 | self.inception6b = self._make_inception(256, 'Max', 2)
67 |
68 | if not self.cut_at_pooling:
69 | self.num_features = num_features
70 | self.norm = norm
71 | self.dropout = dropout
72 | self.has_embedding = num_features > 0
73 | self.num_classes = num_classes
74 |
75 | self.avgpool = nn.AdaptiveAvgPool2d(1)
76 |
77 | if self.has_embedding:
78 | self.feat = nn.Linear(self.in_planes, self.num_features)
79 | self.feat_bn = nn.BatchNorm1d(self.num_features)
80 | else:
81 | # Change the num_features to CNN output channels
82 | self.num_features = self.in_planes
83 | if self.dropout > 0:
84 | self.drop = nn.Dropout(self.dropout)
85 | if self.num_classes > 0:
86 | self.classifier = nn.Linear(self.num_features, self.num_classes)
87 |
88 | self.reset_params()
89 |
90 | def forward(self, x):
91 | x = self.conv1(x)
92 | x = self.conv2(x)
93 | x = self.conv3(x)
94 | x = self.pool3(x)
95 | x = self.inception4a(x)
96 | x = self.inception4b(x)
97 | x = self.inception5a(x)
98 | x = self.inception5b(x)
99 | x = self.inception6a(x)
100 | x = self.inception6b(x)
101 |
102 | if self.cut_at_pooling:
103 | return x
104 |
105 | x = self.avgpool(x)
106 | x = x.view(x.size(0), -1)
107 |
108 | if self.has_embedding:
109 | x = self.feat(x)
110 | x = self.feat_bn(x)
111 | if self.norm:
112 | x = F.normalize(x)
113 | elif self.has_embedding:
114 | x = F.relu(x)
115 | if self.dropout > 0:
116 | x = self.drop(x)
117 | if self.num_classes > 0:
118 | x = self.classifier(x)
119 | return x
120 |
121 | def _make_inception(self, out_planes, pool_method, stride):
122 | block = Block(self.in_planes, out_planes, pool_method, stride)
123 | self.in_planes = (out_planes * 4 if pool_method == 'Avg' else
124 |                           out_planes * 2 + self.in_planes)  # 'Avg': 4 conv branches; 'Max': 2 conv + pooled input
125 | return block
126 |
127 | def reset_params(self):
128 | for m in self.modules():
129 | if isinstance(m, nn.Conv2d):
130 | init.kaiming_normal(m.weight, mode='fan_out')
131 | if m.bias is not None:
132 | init.constant(m.bias, 0)
133 | elif isinstance(m, nn.BatchNorm2d):
134 | init.constant(m.weight, 1)
135 | init.constant(m.bias, 0)
136 | elif isinstance(m, nn.Linear):
137 | init.normal(m.weight, std=0.001)
138 | if m.bias is not None:
139 | init.constant(m.bias, 0)
140 |
141 |
142 | def inception(**kwargs):
143 | return InceptionNet(**kwargs)
144 |
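145 | # Shape sketch: pool3 and the three 'Max' blocks each halve the spatial size
146 | # (16x total downsampling), and AdaptiveAvgPool2d(1) absorbs the remainder, so
147 | # any reasonable input size works. Numbers below are illustrative:
148 | #
149 | #   from torch.autograd import Variable
150 | #   net = inception(num_features=256, dropout=0.5, num_classes=128)
151 | #   out = net(Variable(torch.randn(2, 3, 256, 128)))
152 | #   out.size()  # -> torch.Size([2, 128])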
--------------------------------------------------------------------------------
/docs/examples/benchmarks.rst:
--------------------------------------------------------------------------------
1 | ==========
2 | Benchmarks
3 | ==========
4 |
5 | Benchmarks for different models and loss functions on various datasets.
6 |
7 | All the experiments are conducted under the following settings:
8 |
9 | - 4 GPUs for training, meaning that ``CUDA_VISIBLE_DEVICES=0,1,2,3`` is set for the training scripts
10 | - Total effective batch size of 256. Consider reducing the batch size and learning rate if you only have one GPU. See :ref:`gpu-options` for more details.
11 | - Use the default dataset split ``--split 0``, but combine the training and validation sets for training (``--combine-trainval``)
12 | - Use the default random seed ``--seed 1``
13 | - Use Euclidean distance directly for evaluation
14 | - Use single-query and single-crop evaluation
15 | - The full set of evaluation metrics is reported. See :ref:`evaluation-metrics` for more explanations.
16 |
17 | .. _cuhk03-benchmark:
18 |
19 | ^^^^^^
20 | CUHK03
21 | ^^^^^^
22 |
23 | ========= ============ ======== ============ ========== ============== ===============
24 | Net Loss Mean AP CMC allshots CMC cuhk03 CMC market1501 Training Script
25 | ========= ============ ======== ============ ========== ============== ===============
26 | Inception Triplet N/A N/A N/A N/A N/A
27 | Inception Softmax 65.8 48.6 73.2 71.0 ``python examples/softmax_loss.py -d cuhk03 -a inception --combine-trainval --epochs 70 --logs-dir examples/logs/softmax-loss/cuhk03-inception``
28 | Inception OIM 71.4 56.0 77.7 76.5 ``python examples/oim_loss.py -d cuhk03 -a inception --combine-trainval --oim-scalar 20 --epochs 70 --logs-dir examples/logs/oim-loss/cuhk03-inception``
29 | ResNet-50 Triplet **80.7** **67.9** **84.3** **85.0** ``python examples/triplet_loss.py -d cuhk03 -a resnet50 --combine-trainval --logs-dir examples/logs/triplet-loss/cuhk03-resnet50``
30 | ResNet-50 Softmax 62.7 44.6 70.8 69.0 ``python examples/softmax_loss.py -d cuhk03 -a resnet50 --combine-trainval --logs-dir examples/logs/softmax-loss/cuhk03-resnet50``
31 | ResNet-50 OIM 72.5 58.2 77.5 79.2 ``python examples/oim_loss.py -d cuhk03 -a resnet50 --combine-trainval --oim-scalar 30 --logs-dir examples/logs/oim-loss/cuhk03-resnet50``
32 | ========= ============ ======== ============ ========== ============== ===============
33 |
34 | .. _market1501-benchmark:
35 |
36 | ^^^^^^^^^^
37 | Market1501
38 | ^^^^^^^^^^
39 |
40 | ========= ============ ======== ============ ========== ============== ===============
41 | Net Loss Mean AP CMC allshots CMC cuhk03 CMC market1501 Training Script
42 | ========= ============ ======== ============ ========== ============== ===============
43 | Inception Triplet N/A N/A N/A N/A N/A
44 | Inception Softmax 51.8 26.8 57.1 75.8 ``python examples/softmax_loss.py -d market1501 -a inception --combine-trainval --epochs 70 --logs-dir examples/logs/softmax-loss/market1501-inception``
45 | Inception OIM 54.3 30.1 58.3 77.9 ``python examples/oim_loss.py -d market1501 -a inception --combine-trainval --oim-scalar 20 --epochs 70 --logs-dir examples/logs/oim-loss/market1501-inception``
46 | ResNet-50 Triplet **67.9** **42.9** **70.5** **85.1** ``python examples/triplet_loss.py -d market1501 -a resnet50 --combine-trainval --logs-dir examples/logs/triplet-loss/market1501-resnet50``
47 | ResNet-50 Softmax 59.8 35.5 62.8 81.4 ``python examples/softmax_loss.py -d market1501 -a resnet50 --combine-trainval --logs-dir examples/logs/softmax-loss/market1501-resnet50``
48 | ResNet-50 OIM 60.9 37.3 63.6 82.1 ``python examples/oim_loss.py -d market1501 -a resnet50 --combine-trainval --oim-scalar 20 --logs-dir examples/logs/oim-loss/market1501-resnet50``
49 | ========= ============ ======== ============ ========== ============== ===============
50 |
51 | .. _dukemtmc-benchmark:
52 |
53 | ^^^^^^^^
54 | DukeMTMC
55 | ^^^^^^^^
56 |
57 | ========= ============ ======== ============ ========== ============== ===============
58 | Net Loss Mean AP CMC allshots CMC cuhk03 CMC market1501 Training Script
59 | ========= ============ ======== ============ ========== ============== ===============
60 | Inception Triplet N/A N/A N/A N/A N/A
61 | Inception Softmax 34.0 17.4 39.2 54.4 ``python examples/softmax_loss.py -d dukemtmc -a inception --combine-trainval --epochs 70 --logs-dir examples/logs/softmax-loss/dukemtmc-inception``
62 | Inception OIM 40.6 22.4 45.3 61.7 ``python examples/oim_loss.py -d dukemtmc -a inception --combine-trainval --oim-scalar 30 --epochs 70 --logs-dir examples/logs/oim-loss/dukemtmc-inception``
63 | ResNet-50 Triplet **54.6** **34.6** **57.5** **73.1** ``python examples/triplet_loss.py -d dukemtmc -a resnet50 --combine-trainval --logs-dir examples/logs/triplet-loss/dukemtmc-resnet50``
64 | ResNet-50 Softmax 40.7 23.7 44.3 62.5 ``python examples/softmax_loss.py -d dukemtmc -a resnet50 --combine-trainval --logs-dir examples/logs/softmax-loss/dukemtmc-resnet50``
65 | ResNet-50 OIM 47.4 29.2 50.4 68.1 ``python examples/oim_loss.py -d dukemtmc -a resnet50 --combine-trainval --oim-scalar 30 --logs-dir examples/logs/oim-loss/dukemtmc-resnet50``
66 | ========= ============ ======== ============ ========== ============== ===============
67 |
68 | .. ATTENTION::
69 | No test-time augmentation is used. We have fixed a bug in the learning rate
70 | scheduler for the Triplet loss. Now the result of ResNet-50 with Triplet loss
71 | is slightly better than the "original" setting in [hermans2017in]_ (Table 4).
72 |
--------------------------------------------------------------------------------
/docs/conf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Open-ReID documentation build configuration file, created by
5 | # sphinx-quickstart on Sun Mar 19 15:49:51 2017.
6 | #
7 | # This file is execfile()d with the current directory set to its
8 | # containing dir.
9 | #
10 | # Note that not all possible configuration values are present in this
11 | # autogenerated file.
12 | #
13 | # All configuration values have a default; values that are commented out
14 | # serve to show the default.
15 |
16 | # If extensions (or modules to document with autodoc) are in another directory,
17 | # add these directories to sys.path here. If the directory is relative to the
18 | # documentation root, use os.path.abspath to make it absolute, like shown here.
19 | #
20 | import os
21 | import sys
22 |
23 | sys.path.insert(0, os.path.abspath('.'))
24 | sys.path.insert(0, os.path.abspath('..'))
25 |
26 | import sphinx_rtd_theme
27 |
28 | # -- General configuration ------------------------------------------------
29 |
30 | # If your documentation needs a minimal Sphinx version, state it here.
31 | #
32 | # needs_sphinx = '1.0'
33 |
34 | # Add any Sphinx extension module names here, as strings. They can be
35 | # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
36 | # ones.
37 | extensions = [
38 | 'sphinx.ext.autodoc',
39 | 'sphinx.ext.autosummary',
40 | 'sphinx.ext.doctest',
41 | 'sphinx.ext.intersphinx',
42 | 'sphinx.ext.todo',
43 | 'sphinx.ext.coverage',
44 | 'sphinx.ext.mathjax',
45 | 'sphinx.ext.napoleon',
46 | 'sphinx.ext.viewcode',
47 | 'sphinx.ext.githubpages',
48 | ]
49 |
50 | # Add any paths that contain templates here, relative to this directory.
51 | templates_path = ['_templates']
52 |
53 | # The suffix(es) of source filenames.
54 | # You can specify multiple suffix as a list of string:
55 | #
56 | # source_suffix = ['.rst', '.md']
57 | source_suffix = '.rst'
58 |
59 | # The master toctree document.
60 | master_doc = 'index'
61 |
62 | # General information about the project.
63 | project = 'Open-ReID'
64 | copyright = '2017, Tong Xiao'
65 | author = 'Tong Xiao'
66 |
67 | # The version info for the project you're documenting, acts as replacement for
68 | # |version| and |release|, also used in various other places throughout the
69 | # built documents.
70 | #
71 | # The short X.Y version.
72 | version = ''
73 | # The full version, including alpha/beta/rc tags.
74 | release = ''
75 |
76 | # The language for content autogenerated by Sphinx. Refer to documentation
77 | # for a list of supported languages.
78 | #
79 | # This is also used if you do content translation via gettext catalogs.
80 | # Usually you set "language" from the command line for these cases.
81 | language = None
82 |
83 | # List of patterns, relative to source directory, that match files and
84 | # directories to ignore when looking for source files.
85 | # This patterns also effect to html_static_path and html_extra_path
86 | exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']
87 |
88 | # The name of the Pygments (syntax highlighting) style to use.
89 | pygments_style = 'sphinx'
90 |
91 | # If true, `todo` and `todoList` produce output, else they produce nothing.
92 | todo_include_todos = True
93 |
94 | # -- Options for HTML output ----------------------------------------------
95 |
96 | # The theme to use for HTML and HTML Help pages. See the documentation for
97 | # a list of builtin themes.
98 | # Use a custom css file. See https://blog.deimos.fr/2014/10/02/sphinxdoc-and-readthedocs-theme-tricks-2/
99 |
100 | html_theme = 'sphinx_rtd_theme'
101 | html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
102 |
103 | html_context = {
104 | 'css_files': [
105 | 'https://fonts.googleapis.com/css?family=Lato',
106 | '_static/css/openreid_theme.css',
107 | ]
108 | }
109 |
110 | # Theme options are theme-specific and customize the look and feel of a theme
111 | # further. For a list of options available for each theme, see the
112 | # documentation.
113 | #
114 | html_theme_options = {
115 | 'collapse_navigation': False,
116 | 'display_version': False,
117 | }
118 |
119 | # Add any paths that contain custom static files (such as style sheets) here,
120 | # relative to this directory. They are copied after the builtin static files,
121 | # so a file named "default.css" will overwrite the builtin "default.css".
122 | html_static_path = ['_static']
123 |
124 | # -- Options for HTMLHelp output ------------------------------------------
125 |
126 | # Output file base name for HTML help builder.
127 | htmlhelp_basename = 'Open-ReIDdoc'
128 |
129 | # -- Options for LaTeX output ---------------------------------------------
130 |
131 | latex_elements = {
132 | # The paper size ('letterpaper' or 'a4paper').
133 | #
134 | # 'papersize': 'letterpaper',
135 |
136 | # The font size ('10pt', '11pt' or '12pt').
137 | #
138 | # 'pointsize': '10pt',
139 |
140 | # Additional stuff for the LaTeX preamble.
141 | #
142 | # 'preamble': '',
143 |
144 | # Latex figure (float) alignment
145 | #
146 | # 'figure_align': 'htbp',
147 | }
148 |
149 | # Grouping the document tree into LaTeX files. List of tuples
150 | # (source start file, target name, title,
151 | # author, documentclass [howto, manual, or own class]).
152 | latex_documents = [
153 | (master_doc, 'Open-ReID.tex', 'Open-ReID Documentation',
154 | 'Tong Xiao', 'manual'),
155 | ]
156 |
157 | # -- Options for manual page output ---------------------------------------
158 |
159 | # One entry per manual page. List of tuples
160 | # (source start file, name, description, authors, manual section).
161 | man_pages = [
162 | (master_doc, 'open-reid', 'Open-ReID Documentation',
163 | [author], 1)
164 | ]
165 |
166 | # -- Options for Texinfo output -------------------------------------------
167 |
168 | # Grouping the document tree into Texinfo files. List of tuples
169 | # (source start file, target name, title, author,
170 | # dir menu entry, description, category)
171 | texinfo_documents = [
172 | (master_doc, 'Open-ReID', 'Open-ReID Documentation',
173 | author, 'Open-ReID', 'One line description of project.',
174 | 'Miscellaneous'),
175 | ]
176 |
177 | # Example configuration for intersphinx: refer to the Python standard library.
178 | intersphinx_mapping = {
179 | 'python': ('https://docs.python.org/', None),
180 | 'numpy': ('http://docs.scipy.org/doc/numpy/', None),
181 | 'pytorch': ('http://pytorch.org/docs/master/', None),
182 | }
183 |
--------------------------------------------------------------------------------
/docs/examples/training_id.rst:
--------------------------------------------------------------------------------
1 | ============================
2 | Training Identification Nets
3 | ============================
4 |
5 | This example shows how to train networks with an identification loss on popular
6 | datasets.
7 |
8 | The objective of training an identification net is to learn a good feature
9 | representation for persons. If the features of the same person are similar,
10 | while the features of different people are dissimilar, then querying a target
11 | person in a gallery database becomes easy.
12 |
13 | Different loss functions can be adopted for this purpose, for example (see the
14 | construction sketch after this list),
15 | - Softmax cross entropy loss [zheng2016person]_ [xiao2016learning]_
16 | - Triplet loss [hermans2017in]_
17 | - Online instance matching (OIM) loss [xiaoli2017joint]_
18 |
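  | A sketch of how the example scripts construct these criteria; the feature
  | dimension (128) and identity count (751) below are illustrative only:
  | 
  | .. code-block:: python
  | 
  |     from torch import nn
  |     from reid.loss import TripletLoss, OIMLoss
  | 
  |     # Pick one of the three criteria used by the example scripts
  |     criterion = nn.CrossEntropyLoss().cuda()
  |     criterion = TripletLoss(margin=0.5).cuda()
  |     criterion = OIMLoss(128, 751, scalar=30, momentum=0.5).cuda()
  | 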
19 |
20 | .. _head-first-example:
21 |
22 | ------------------
23 | Head First Example
24 | ------------------
25 |
26 | After cloning the repository, we can start by training an Inception net on
27 | VIPeR with Softmax loss from scratch
28 |
29 | .. code-block:: shell
30 |
31 | python examples/softmax_loss.py -d viper -b 64 -j 2 -a inception --logs-dir logs/softmax-loss/viper-inception
32 |
33 | This script automatically downloads the VIPeR dataset and starts training, with
34 | a batch size of 64 and two processes for data loading. Softmax cross entropy is
35 | used as the loss function. The training log is printed to the screen as well as
36 | saved to ``logs/softmax-loss/viper-inception/log.txt``. When training ends, the
37 | script will evaluate the best model (the one with the best validation
38 | performance) on the test set, and report several commonly used metrics.
39 |
40 |
41 | .. _training-options:
42 |
43 | ----------------
44 | Training Options
45 | ----------------
46 |
47 | Many training options are available through command line arguments. List all of
48 | them with ``python examples/softmax_loss.py -h``. Here we elaborate on several
49 | commonly used options.
50 |
51 | .. _data-options:
52 |
53 | ^^^^^^^^
54 | Datasets
55 | ^^^^^^^^
56 |
57 | Specify the dataset by ``-d name``, where ``name`` can be one of ``cuhk03``,
58 | ``cuhk01``, ``market1501``, ``dukemtmc``, and ``viper`` currently. For some
59 | datasets that cannot be downloaded automatically, running the script will raise
60 | an error with a link to the dataset. One may need to download it manually and
61 | put it in the directory indicated by the error message.
62 |
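  | The datasets can also be created programmatically with the same call the
  | example scripts use; a minimal sketch (the root path is illustrative):
  | 
  | .. code-block:: python
  | 
  |     from reid import datasets
  | 
  |     # Downloads VIPeR if needed, then loads the first train/val/test split
  |     dataset = datasets.create('viper', 'examples/data/viper', split_id=0)
  |     print(dataset.num_train_ids, len(dataset.query), len(dataset.gallery))
  | 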
63 | .. _model-options:
64 |
65 | ^^^^^^^^^^^^^^^^^^^
66 | Model Architectures
67 | ^^^^^^^^^^^^^^^^^^^
68 |
69 | Specify the model architecture by ``-a name``, where ``name`` can be one of
70 | ``resnet18``, ``resnet34``, ``resnet50``, ``resnet101``, ``resnet152``, and
71 | ``inception`` currently. For ``resnet*``, the scripts automatically download an
72 | ImageNet-pretrained model and fine-tune from it. For ``inception``, the scripts
73 | train the net from scratch.
74 |
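  | The same factory is available in code; a minimal sketch (the identity count,
  | 751 here, depends on the training set):
  | 
  | .. code-block:: python
  | 
  |     from reid import models
  | 
  |     # ResNet-50 with a 128-d embedding, dropout, and a 751-way classifier
  |     model = models.create('resnet50', num_features=128, dropout=0.5,
  |                           num_classes=751)
  | 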
75 | .. _gpu-options:
76 |
77 | ^^^^^^^^^^^^^^^^^^^^^^^^
78 | Multi-GPU and Batch Size
79 | ^^^^^^^^^^^^^^^^^^^^^^^^
80 |
81 | All the examples support data parallel training on multiple GPUs. By default,
82 | the program will use all the GPUs listed in ``nvidia-smi``. To control which
83 | GPUs are used, one needs to set the environment variable
84 | ``CUDA_VISIBLE_DEVICES`` before running the python script. For example,
85 |
86 | .. code-block:: shell
87 |
88 | # 4 GPUs, with effective batch size of 256
89 | CUDA_VISIBLE_DEVICES=0,1,2,3 python examples/softmax_loss.py -d viper -b 256 --lr 0.1 ...
90 |
91 | # 1 GPU, reduce the batch size to 64, lr to 0.025
92 | CUDA_VISIBLE_DEVICES=0 python examples/softmax_loss.py -d viper -b 64 --lr 0.025 ...
93 |
94 | Note that the effective batch size specified by the ``-b`` option is divided
95 | by the number of GPUs automatically. For example, 4 GPUs
96 | with ``-b 256`` will have 64 minibatch samples on each one.
97 |
98 | In the second command above, we reduce the batch size and initial learning rate
99 | to one quarter, in order to adapt the original 4-GPU setting to a single GPU.
100 |
101 | .. _resume-options:
102 |
103 | ^^^^^^^^^^^^^^^^^^^^^^^
104 | Resume from Checkpoints
105 | ^^^^^^^^^^^^^^^^^^^^^^^
106 |
107 | After each training epoch, the script saves the latest ``checkpoint.pth.tar``
108 | in the specified logs directory, and updates ``model_best.pth.tar`` if the
109 | model achieves the best validation performance so far. To resume from this
110 | checkpoint, just run the script with ``--resume /path/to/checkpoint.pth.tar``.
111 |
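  | For example, to resume the head-first run above (paths follow that example):
  | 
  | .. code-block:: shell
  | 
  |     python examples/softmax_loss.py -d viper -a inception \
  |         --logs-dir logs/softmax-loss/viper-inception \
  |         --resume logs/softmax-loss/viper-inception/checkpoint.pth.tar
  | 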
112 | .. _eval-options:
113 |
114 | ^^^^^^^^^^^^^^^^^^^^^^^^
115 | Evaluate a Trained Model
116 | ^^^^^^^^^^^^^^^^^^^^^^^^
117 |
118 | To evaluate a trained model, just run the script with ``--resume
119 | /path/to/model_best.pth.tar --evaluate``. Note that different evaluation
120 | metrics, especially different versions of CMC, can lead to drastically
121 | different numbers.
122 |
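  | For example, continuing the head-first run above:
  | 
  | .. code-block:: shell
  | 
  |     python examples/softmax_loss.py -d viper -a inception \
  |         --resume logs/softmax-loss/viper-inception/model_best.pth.tar --evaluate
  | 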
123 |
124 | .. _tips-and-tricks:
125 |
126 | ---------------
127 | Tips and Tricks
128 | ---------------
129 |
130 | Training a baseline network can be tricky. Many options and parameters can
131 | significantly affect the reported performance numbers. Here we list some tips
132 | and tricks for experiments.
133 |
134 | Combine train and val
135 | One can first use separate training and validation sets to tune the
136 | hyperparameters, then fix the hyperparameters and combine both sets together
137 | to train a final model. This can be done by appending an option
138 | ``--combine-trainval``, and could lead to much better performance on the
139 | test set.
140 |
141 | Input size
142 | A larger input image size can improve performance, depending on the network
143 | architecture. You may specify it by ``--height`` and ``--width``. By
144 | default, we use ``256x128`` for ``resnet*`` and ``144x56`` for ``inception``.
145 |
146 | Multi-scale multi-crop test
147 | Using multi-scale multi-crop testing typically yields a performance gain.
148 | However, it sacrifices the running speed significantly. We have not
149 | implemented this yet.
150 |
151 | Classifier initialization for softmax cross entropy loss
152 | We found that initializing the softmax classifier weights from a normal
153 | distribution with ``std=0.001`` generally leads to better performance. It is
154 | also important to use a larger learning rate for the classifier if the
155 | underlying CNN is already pretrained. See the sketch below.
156 |
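  | A minimal PyTorch sketch of both tricks; the layer sizes, identity count, and
  | learning rates are illustrative, mirroring the ``lr_mult`` grouping in the
  | example scripts:
  | 
  | .. code-block:: python
  | 
  |     import torch
  |     from torch import nn
  | 
  |     backbone = nn.Linear(2048, 128)   # stand-in for a pretrained CNN
  |     classifier = nn.Linear(128, 751)  # one logit per training identity
  | 
  |     # Initialize the classifier weights from N(0, 0.001^2)
  |     nn.init.normal_(classifier.weight, std=0.001)
  |     nn.init.constant_(classifier.bias, 0)
  | 
  |     # The pretrained part trains with a 10x smaller learning rate
  |     optimizer = torch.optim.SGD(
  |         [{'params': backbone.parameters(), 'lr': 0.01},
  |          {'params': classifier.parameters(), 'lr': 0.1}],
  |         momentum=0.9, weight_decay=5e-4, nesterov=True)
  | 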
157 |
158 | ----------
159 | References
160 | ----------
161 |
162 | .. [zheng2016person] L. Zheng, Y. Yang, and A.G. Hauptmann. Person Re-identification: Past, Present and Future. *arXiv:1610.02984*, 2016.
163 | .. [xiao2016learning] T. Xiao, H. Li, W. Ouyang, and X. Wang. Learning deep feature representations with domain guided dropout for person re-identification. In *CVPR*, 2016.
164 | .. [xiaoli2017joint] T. Xiao\*, S. Li\*, B. Wang, L. Lin, and X. Wang. Joint Detection and Identification Feature Learning for Person Search. In *CVPR*, 2017.
165 | .. [hermans2017in] A. Hermans, L. Beyer, and B. Leibe. In Defense of the Triplet Loss for Person Re-Identification. *arXiv:1703.07737*, 2017.
--------------------------------------------------------------------------------
/examples/softmax_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import argparse
3 | import os.path as osp
4 |
5 | import numpy as np
6 | import sys
7 | import torch
8 | from torch import nn
9 | from torch.backends import cudnn
10 | from torch.utils.data import DataLoader
11 |
12 | from reid import datasets
13 | from reid import models
14 | from reid.dist_metric import DistanceMetric
15 | from reid.trainers import Trainer
16 | from reid.evaluators import Evaluator
17 | from reid.utils.data import transforms as T
18 | from reid.utils.data.preprocessor import Preprocessor
19 | from reid.utils.logging import Logger
20 | from reid.utils.serialization import load_checkpoint, save_checkpoint
21 |
22 |
23 | def get_data(name, split_id, data_dir, height, width, batch_size, workers,
24 | combine_trainval):
25 | root = osp.join(data_dir, name)
26 |
27 | dataset = datasets.create(name, root, split_id=split_id)
28 |
29 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
30 | std=[0.229, 0.224, 0.225])
31 |
32 | train_set = dataset.trainval if combine_trainval else dataset.train
33 | num_classes = (dataset.num_trainval_ids if combine_trainval
34 | else dataset.num_train_ids)
35 |
36 | train_transformer = T.Compose([
37 | T.RandomSizedRectCrop(height, width),
38 | T.RandomHorizontalFlip(),
39 | T.ToTensor(),
40 | normalizer,
41 | ])
42 |
43 | test_transformer = T.Compose([
44 | T.RectScale(height, width),
45 | T.ToTensor(),
46 | normalizer,
47 | ])
48 |
49 | train_loader = DataLoader(
50 | Preprocessor(train_set, root=dataset.images_dir,
51 | transform=train_transformer),
52 | batch_size=batch_size, num_workers=workers,
53 | shuffle=True, pin_memory=True, drop_last=True)
54 |
55 | val_loader = DataLoader(
56 | Preprocessor(dataset.val, root=dataset.images_dir,
57 | transform=test_transformer),
58 | batch_size=batch_size, num_workers=workers,
59 | shuffle=False, pin_memory=True)
60 |
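  |     # Query and gallery may overlap; deduplicate before extracting features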
61 | test_loader = DataLoader(
62 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
63 | root=dataset.images_dir, transform=test_transformer),
64 | batch_size=batch_size, num_workers=workers,
65 | shuffle=False, pin_memory=True)
66 |
67 | return dataset, num_classes, train_loader, val_loader, test_loader
68 |
69 |
70 | def main(args):
71 | np.random.seed(args.seed)
72 | torch.manual_seed(args.seed)
73 | cudnn.benchmark = True
74 |
75 | # Redirect print to both console and log file
76 | if not args.evaluate:
77 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
78 |
79 | # Create data loaders
80 | if args.height is None or args.width is None:
81 | args.height, args.width = (144, 56) if args.arch == 'inception' else \
82 | (256, 128)
83 | dataset, num_classes, train_loader, val_loader, test_loader = \
84 | get_data(args.dataset, args.split, args.data_dir, args.height,
85 | args.width, args.batch_size, args.workers,
86 | args.combine_trainval)
87 |
88 | # Create model
89 | model = models.create(args.arch, num_features=args.features,
90 | dropout=args.dropout, num_classes=num_classes)
91 |
92 | # Load from checkpoint
93 | start_epoch = best_top1 = 0
94 | if args.resume:
95 | checkpoint = load_checkpoint(args.resume)
96 | model.load_state_dict(checkpoint['state_dict'])
97 | start_epoch = checkpoint['epoch']
98 | best_top1 = checkpoint['best_top1']
99 | print("=> Start epoch {} best top1 {:.1%}"
100 | .format(start_epoch, best_top1))
101 | model = nn.DataParallel(model).cuda()
102 |
103 | # Distance metric
104 | metric = DistanceMetric(algorithm=args.dist_metric)
105 |
106 | # Evaluator
107 | evaluator = Evaluator(model)
108 | if args.evaluate:
109 | metric.train(model, train_loader)
110 | print("Validation:")
111 | evaluator.evaluate(val_loader, dataset.val, dataset.val, metric)
112 | print("Test:")
113 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
114 | return
115 |
116 | # Criterion
117 | criterion = nn.CrossEntropyLoss().cuda()
118 |
119 | # Optimizer
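  |     # Give the ImageNet-pretrained base a 10x smaller learning rate
  |     # (lr_mult=0.1) than the newly added layers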
120 | if hasattr(model.module, 'base'):
121 | base_param_ids = set(map(id, model.module.base.parameters()))
122 | new_params = [p for p in model.parameters() if
123 | id(p) not in base_param_ids]
124 | param_groups = [
125 | {'params': model.module.base.parameters(), 'lr_mult': 0.1},
126 | {'params': new_params, 'lr_mult': 1.0}]
127 | else:
128 | param_groups = model.parameters()
129 | optimizer = torch.optim.SGD(param_groups, lr=args.lr,
130 | momentum=args.momentum,
131 | weight_decay=args.weight_decay,
132 | nesterov=True)
133 |
134 | # Trainer
135 | trainer = Trainer(model, criterion)
136 |
137 | # Schedule learning rate
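  |     # Step decay: multiply the base lr by 0.1 every step_size epochs,
  |     # scaled per parameter group by its lr_mult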
138 | def adjust_lr(epoch):
139 | step_size = 60 if args.arch == 'inception' else 40
140 | lr = args.lr * (0.1 ** (epoch // step_size))
141 | for g in optimizer.param_groups:
142 | g['lr'] = lr * g.get('lr_mult', 1)
143 |
144 | # Start training
145 | for epoch in range(start_epoch, args.epochs):
146 | adjust_lr(epoch)
147 | trainer.train(epoch, train_loader, optimizer)
148 | if epoch < args.start_save:
149 | continue
150 | top1 = evaluator.evaluate(val_loader, dataset.val, dataset.val)
151 |
152 | is_best = top1 > best_top1
153 | best_top1 = max(top1, best_top1)
154 | save_checkpoint({
155 | 'state_dict': model.module.state_dict(),
156 | 'epoch': epoch + 1,
157 | 'best_top1': best_top1,
158 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))
159 |
160 | print('\n * Finished epoch {:3d} top1: {:5.1%} best: {:5.1%}{}\n'.
161 | format(epoch, top1, best_top1, ' *' if is_best else ''))
162 |
163 | # Final test
164 | print('Test with best model:')
165 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
166 | model.module.load_state_dict(checkpoint['state_dict'])
167 | metric.train(model, train_loader)
168 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
169 |
170 |
171 | if __name__ == '__main__':
172 | parser = argparse.ArgumentParser(description="Softmax loss classification")
173 | # data
174 | parser.add_argument('-d', '--dataset', type=str, default='cuhk03',
175 | choices=datasets.names())
176 | parser.add_argument('-b', '--batch-size', type=int, default=256)
177 | parser.add_argument('-j', '--workers', type=int, default=4)
178 | parser.add_argument('--split', type=int, default=0)
179 | parser.add_argument('--height', type=int,
180 | help="input height, default: 256 for resnet*, "
181 | "144 for inception")
182 | parser.add_argument('--width', type=int,
183 | help="input width, default: 128 for resnet*, "
184 | "56 for inception")
185 | parser.add_argument('--combine-trainval', action='store_true',
186 | help="train and val sets together for training, "
187 | "val set alone for validation")
188 | # model
189 | parser.add_argument('-a', '--arch', type=str, default='resnet50',
190 | choices=models.names())
191 | parser.add_argument('--features', type=int, default=128)
192 | parser.add_argument('--dropout', type=float, default=0.5)
193 | # optimizer
194 | parser.add_argument('--lr', type=float, default=0.1,
195 | help="learning rate of new parameters, for pretrained "
196 | "parameters it is 10 times smaller than this")
197 | parser.add_argument('--momentum', type=float, default=0.9)
198 | parser.add_argument('--weight-decay', type=float, default=5e-4)
199 | # training configs
200 | parser.add_argument('--resume', type=str, default='', metavar='PATH')
201 | parser.add_argument('--evaluate', action='store_true',
202 | help="evaluation only")
203 | parser.add_argument('--epochs', type=int, default=50)
204 | parser.add_argument('--start_save', type=int, default=0,
205 | help="start saving checkpoints after specific epoch")
206 | parser.add_argument('--seed', type=int, default=1)
207 | parser.add_argument('--print-freq', type=int, default=1)
208 | # metric learning
209 | parser.add_argument('--dist-metric', type=str, default='euclidean',
210 | choices=['euclidean', 'kissme'])
211 | # misc
212 | working_dir = osp.dirname(osp.abspath(__file__))
213 | parser.add_argument('--data-dir', type=str, metavar='PATH',
214 | default=osp.join(working_dir, 'data'))
215 | parser.add_argument('--logs-dir', type=str, metavar='PATH',
216 | default=osp.join(working_dir, 'logs'))
217 | main(parser.parse_args())
218 |
--------------------------------------------------------------------------------
/examples/triplet_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import argparse
3 | import os.path as osp
4 |
5 | import numpy as np
6 | import sys
7 | import torch
8 | from torch import nn
9 | from torch.backends import cudnn
10 | from torch.utils.data import DataLoader
11 |
12 | from reid import datasets
13 | from reid import models
14 | from reid.dist_metric import DistanceMetric
15 | from reid.loss import TripletLoss
16 | from reid.trainers import Trainer
17 | from reid.evaluators import Evaluator
18 | from reid.utils.data import transforms as T
19 | from reid.utils.data.preprocessor import Preprocessor
20 | from reid.utils.data.sampler import RandomIdentitySampler
21 | from reid.utils.logging import Logger
22 | from reid.utils.serialization import load_checkpoint, save_checkpoint
23 |
24 |
25 | def get_data(name, split_id, data_dir, height, width, batch_size, num_instances,
26 | workers, combine_trainval):
27 | root = osp.join(data_dir, name)
28 |
29 | dataset = datasets.create(name, root, split_id=split_id)
30 |
31 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
32 | std=[0.229, 0.224, 0.225])
33 |
34 | train_set = dataset.trainval if combine_trainval else dataset.train
35 | num_classes = (dataset.num_trainval_ids if combine_trainval
36 | else dataset.num_train_ids)
37 |
38 | train_transformer = T.Compose([
39 | T.RandomSizedRectCrop(height, width),
40 | T.RandomHorizontalFlip(),
41 | T.ToTensor(),
42 | normalizer,
43 | ])
44 |
45 | test_transformer = T.Compose([
46 | T.RectScale(height, width),
47 | T.ToTensor(),
48 | normalizer,
49 | ])
50 |
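  |     # PK sampling: each batch holds (batch_size // num_instances) identities
  |     # with num_instances images each, as the triplet loss requires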
51 | train_loader = DataLoader(
52 | Preprocessor(train_set, root=dataset.images_dir,
53 | transform=train_transformer),
54 | batch_size=batch_size, num_workers=workers,
55 | sampler=RandomIdentitySampler(train_set, num_instances),
56 | pin_memory=True, drop_last=True)
57 |
58 | val_loader = DataLoader(
59 | Preprocessor(dataset.val, root=dataset.images_dir,
60 | transform=test_transformer),
61 | batch_size=batch_size, num_workers=workers,
62 | shuffle=False, pin_memory=True)
63 |
64 | test_loader = DataLoader(
65 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
66 | root=dataset.images_dir, transform=test_transformer),
67 | batch_size=batch_size, num_workers=workers,
68 | shuffle=False, pin_memory=True)
69 |
70 | return dataset, num_classes, train_loader, val_loader, test_loader
71 |
72 |
73 | def main(args):
74 | np.random.seed(args.seed)
75 | torch.manual_seed(args.seed)
76 | cudnn.benchmark = True
77 |
78 | # Redirect print to both console and log file
79 | if not args.evaluate:
80 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
81 |
82 | # Create data loaders
83 | assert args.num_instances > 1, "num_instances should be greater than 1"
84 | assert args.batch_size % args.num_instances == 0, \
85 | 'num_instances should divide batch_size'
86 | if args.height is None or args.width is None:
87 | args.height, args.width = (144, 56) if args.arch == 'inception' else \
88 | (256, 128)
89 | dataset, num_classes, train_loader, val_loader, test_loader = \
90 | get_data(args.dataset, args.split, args.data_dir, args.height,
91 | args.width, args.batch_size, args.num_instances, args.workers,
92 | args.combine_trainval)
93 |
94 | # Create model
95 | # Hacking here to let the classifier be the last feature embedding layer
96 | # Net structure: avgpool -> FC(1024) -> FC(args.features)
97 | model = models.create(args.arch, num_features=1024,
98 | dropout=args.dropout, num_classes=args.features)
99 |
100 | # Load from checkpoint
101 | start_epoch = best_top1 = 0
102 | if args.resume:
103 | checkpoint = load_checkpoint(args.resume)
104 | model.load_state_dict(checkpoint['state_dict'])
105 | start_epoch = checkpoint['epoch']
106 | best_top1 = checkpoint['best_top1']
107 | print("=> Start epoch {} best top1 {:.1%}"
108 | .format(start_epoch, best_top1))
109 | model = nn.DataParallel(model).cuda()
110 |
111 | # Distance metric
112 | metric = DistanceMetric(algorithm=args.dist_metric)
113 |
114 | # Evaluator
115 | evaluator = Evaluator(model)
116 | if args.evaluate:
117 | metric.train(model, train_loader)
118 | print("Validation:")
119 | evaluator.evaluate(val_loader, dataset.val, dataset.val, metric)
120 | print("Test:")
121 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
122 | return
123 |
124 | # Criterion
125 | criterion = TripletLoss(margin=args.margin).cuda()
126 |
127 | # Optimizer
128 | optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
129 | weight_decay=args.weight_decay)
130 |
131 | # Trainer
132 | trainer = Trainer(model, criterion)
133 |
134 | # Schedule learning rate
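  |     # Keep lr constant for the first 100 epochs, then decay it
  |     # exponentially (down to 0.001x over the next 50 epochs)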
135 | def adjust_lr(epoch):
136 | lr = args.lr if epoch <= 100 else \
137 | args.lr * (0.001 ** ((epoch - 100) / 50.0))
138 | for g in optimizer.param_groups:
139 | g['lr'] = lr * g.get('lr_mult', 1)
140 |
141 | # Start training
142 | for epoch in range(start_epoch, args.epochs):
143 | adjust_lr(epoch)
144 | trainer.train(epoch, train_loader, optimizer)
145 | if epoch < args.start_save:
146 | continue
147 | top1 = evaluator.evaluate(val_loader, dataset.val, dataset.val)
148 |
149 | is_best = top1 > best_top1
150 | best_top1 = max(top1, best_top1)
151 | save_checkpoint({
152 | 'state_dict': model.module.state_dict(),
153 | 'epoch': epoch + 1,
154 | 'best_top1': best_top1,
155 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))
156 |
157 | print('\n * Finished epoch {:3d} top1: {:5.1%} best: {:5.1%}{}\n'.
158 | format(epoch, top1, best_top1, ' *' if is_best else ''))
159 |
160 | # Final test
161 | print('Test with best model:')
162 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
163 | model.module.load_state_dict(checkpoint['state_dict'])
164 | metric.train(model, train_loader)
165 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
166 |
167 |
168 | if __name__ == '__main__':
169 | parser = argparse.ArgumentParser(description="Triplet loss training")
170 | # data
171 | parser.add_argument('-d', '--dataset', type=str, default='cuhk03',
172 | choices=datasets.names())
173 | parser.add_argument('-b', '--batch-size', type=int, default=256)
174 | parser.add_argument('-j', '--workers', type=int, default=4)
175 | parser.add_argument('--split', type=int, default=0)
176 | parser.add_argument('--height', type=int,
177 | help="input height, default: 256 for resnet*, "
178 | "144 for inception")
179 | parser.add_argument('--width', type=int,
180 | help="input width, default: 128 for resnet*, "
181 | "56 for inception")
182 | parser.add_argument('--combine-trainval', action='store_true',
183 | help="train and val sets together for training, "
184 | "val set alone for validation")
185 | parser.add_argument('--num-instances', type=int, default=4,
186 | help="each minibatch consist of "
187 | "(batch_size // num_instances) identities, and "
188 | "each identity has num_instances instances, "
189 | "default: 4")
190 | # model
191 | parser.add_argument('-a', '--arch', type=str, default='resnet50',
192 | choices=models.names())
193 | parser.add_argument('--features', type=int, default=128)
194 | parser.add_argument('--dropout', type=float, default=0)
195 | # loss
196 | parser.add_argument('--margin', type=float, default=0.5,
197 | help="margin of the triplet loss, default: 0.5")
198 | # optimizer
199 | parser.add_argument('--lr', type=float, default=0.0002,
200 | help="learning rate of all parameters")
201 | parser.add_argument('--weight-decay', type=float, default=5e-4)
202 | # training configs
203 | parser.add_argument('--resume', type=str, default='', metavar='PATH')
204 | parser.add_argument('--evaluate', action='store_true',
205 | help="evaluation only")
206 | parser.add_argument('--epochs', type=int, default=150)
207 | parser.add_argument('--start_save', type=int, default=0,
208 | help="start saving checkpoints after specific epoch")
209 | parser.add_argument('--seed', type=int, default=1)
210 | parser.add_argument('--print-freq', type=int, default=1)
211 | # metric learning
212 | parser.add_argument('--dist-metric', type=str, default='euclidean',
213 | choices=['euclidean', 'kissme'])
214 | # misc
215 | working_dir = osp.dirname(osp.abspath(__file__))
216 | parser.add_argument('--data-dir', type=str, metavar='PATH',
217 | default=osp.join(working_dir, 'data'))
218 | parser.add_argument('--logs-dir', type=str, metavar='PATH',
219 | default=osp.join(working_dir, 'logs'))
220 | main(parser.parse_args())
221 |
--------------------------------------------------------------------------------
/examples/oim_loss.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, absolute_import
2 | import argparse
3 | import os.path as osp
4 |
5 | import numpy as np
6 | import sys
7 | import torch
8 | from torch import nn
9 | from torch.backends import cudnn
10 | from torch.utils.data import DataLoader
11 |
12 | from reid import datasets
13 | from reid import models
14 | from reid.dist_metric import DistanceMetric
15 | from reid.loss import OIMLoss
16 | from reid.trainers import Trainer
17 | from reid.evaluators import Evaluator
18 | from reid.utils.data import transforms as T
19 | from reid.utils.data.preprocessor import Preprocessor
20 | from reid.utils.logging import Logger
21 | from reid.utils.serialization import load_checkpoint, save_checkpoint
22 |
23 |
24 | def get_data(name, split_id, data_dir, height, width, batch_size, workers,
25 | combine_trainval):
26 | root = osp.join(data_dir, name)
27 |
28 | dataset = datasets.create(name, root, split_id=split_id)
29 |
30 | normalizer = T.Normalize(mean=[0.485, 0.456, 0.406],
31 | std=[0.229, 0.224, 0.225])
32 |
33 | train_set = dataset.trainval if combine_trainval else dataset.train
34 | num_classes = (dataset.num_trainval_ids if combine_trainval
35 | else dataset.num_train_ids)
36 |
37 | train_transformer = T.Compose([
38 | T.RandomSizedRectCrop(height, width),
39 | T.RandomHorizontalFlip(),
40 | T.ToTensor(),
41 | normalizer,
42 | ])
43 |
44 | test_transformer = T.Compose([
45 | T.RectScale(height, width),
46 | T.ToTensor(),
47 | normalizer,
48 | ])
49 |
50 | train_loader = DataLoader(
51 | Preprocessor(train_set, root=dataset.images_dir,
52 | transform=train_transformer),
53 | batch_size=batch_size, num_workers=workers,
54 | shuffle=True, pin_memory=True, drop_last=True)
55 |
56 | val_loader = DataLoader(
57 | Preprocessor(dataset.val, root=dataset.images_dir,
58 | transform=test_transformer),
59 | batch_size=batch_size, num_workers=workers,
60 | shuffle=False, pin_memory=True)
61 |
62 | test_loader = DataLoader(
63 | Preprocessor(list(set(dataset.query) | set(dataset.gallery)),
64 | root=dataset.images_dir, transform=test_transformer),
65 | batch_size=batch_size, num_workers=workers,
66 | shuffle=False, pin_memory=True)
67 |
68 | return dataset, num_classes, train_loader, val_loader, test_loader
69 |
70 |
71 | def main(args):
72 | np.random.seed(args.seed)
73 | torch.manual_seed(args.seed)
74 | cudnn.benchmark = True
75 |
76 | # Redirect print to both console and log file
77 | if not args.evaluate:
78 | sys.stdout = Logger(osp.join(args.logs_dir, 'log.txt'))
79 |
80 | # Create data loaders
81 | if args.height is None or args.width is None:
82 | args.height, args.width = (144, 56) if args.arch == 'inception' else \
83 | (256, 128)
84 | dataset, num_classes, train_loader, val_loader, test_loader = \
85 | get_data(args.dataset, args.split, args.data_dir, args.height,
86 | args.width, args.batch_size, args.workers,
87 | args.combine_trainval)
88 |
89 | # Create model
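  |     # norm=True L2-normalizes the output features; no classifier head is
  |     # needed since OIM maintains its own lookup table of identity features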
90 | model = models.create(args.arch, num_features=args.features, norm=True,
91 | dropout=args.dropout)
92 |
93 | # Load from checkpoint
94 | start_epoch = best_top1 = 0
95 | if args.resume:
96 | checkpoint = load_checkpoint(args.resume)
97 | model.load_state_dict(checkpoint['state_dict'])
98 | start_epoch = checkpoint['epoch']
99 | best_top1 = checkpoint['best_top1']
100 | print("=> Start epoch {} best top1 {:.1%}"
101 | .format(start_epoch, best_top1))
102 | model = nn.DataParallel(model).cuda()
103 |
104 | # Distance metric
105 | metric = DistanceMetric(algorithm=args.dist_metric)
106 |
107 | # Evaluator
108 | evaluator = Evaluator(model)
109 | if args.evaluate:
110 | metric.train(model, train_loader)
111 | print("Validation:")
112 | evaluator.evaluate(val_loader, dataset.val, dataset.val, metric)
113 | print("Test:")
114 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
115 | return
116 |
117 | # Criterion
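  |     # OIM compares features against a momentum-updated lookup table (LUT);
  |     # scalar is the reciprocal of the softmax temperature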
118 | criterion = OIMLoss(model.module.num_features, num_classes,
119 | scalar=args.oim_scalar,
120 | momentum=args.oim_momentum).cuda()
121 |
122 | # Optimizer
123 | if hasattr(model.module, 'base'):
124 | base_param_ids = set(map(id, model.module.base.parameters()))
125 | new_params = [p for p in model.parameters() if
126 | id(p) not in base_param_ids]
127 | param_groups = [
128 | {'params': model.module.base.parameters(), 'lr_mult': 0.1},
129 | {'params': new_params, 'lr_mult': 1.0}]
130 | else:
131 | param_groups = model.parameters()
132 | optimizer = torch.optim.SGD(param_groups, lr=args.lr,
133 | momentum=args.momentum,
134 | weight_decay=args.weight_decay,
135 | nesterov=True)
136 |
137 | # Trainer
138 | trainer = Trainer(model, criterion)
139 |
140 | # Schedule learning rate
141 | def adjust_lr(epoch):
142 | step_size = 60 if args.arch == 'inception' else 40
143 | lr = args.lr * (0.1 ** (epoch // step_size))
144 | for g in optimizer.param_groups:
145 | g['lr'] = lr * g.get('lr_mult', 1)
146 |
147 | # Start training
148 | for epoch in range(start_epoch, args.epochs):
149 | adjust_lr(epoch)
150 | trainer.train(epoch, train_loader, optimizer)
151 | if epoch < args.start_save:
152 | continue
153 | top1 = evaluator.evaluate(val_loader, dataset.val, dataset.val)
154 |
155 | is_best = top1 > best_top1
156 | best_top1 = max(top1, best_top1)
157 | save_checkpoint({
158 | 'state_dict': model.module.state_dict(),
159 | 'epoch': epoch + 1,
160 | 'best_top1': best_top1,
161 | }, is_best, fpath=osp.join(args.logs_dir, 'checkpoint.pth.tar'))
162 |
163 | print('\n * Finished epoch {:3d} top1: {:5.1%} best: {:5.1%}{}\n'.
164 | format(epoch, top1, best_top1, ' *' if is_best else ''))
165 |
166 | # Final test
167 | print('Test with best model:')
168 | checkpoint = load_checkpoint(osp.join(args.logs_dir, 'model_best.pth.tar'))
169 | model.module.load_state_dict(checkpoint['state_dict'])
170 | metric.train(model, train_loader)
171 | evaluator.evaluate(test_loader, dataset.query, dataset.gallery, metric)
172 |
173 |
174 | if __name__ == '__main__':
175 | parser = argparse.ArgumentParser(description="OIM loss")
176 | # data
177 | parser.add_argument('-d', '--dataset', type=str, default='cuhk03',
178 | choices=datasets.names())
179 | parser.add_argument('-b', '--batch-size', type=int, default=256)
180 | parser.add_argument('-j', '--workers', type=int, default=4)
181 | parser.add_argument('--split', type=int, default=0)
182 | parser.add_argument('--height', type=int,
183 | help="input height, default: 256 for resnet*, "
184 | "144 for inception")
185 | parser.add_argument('--width', type=int,
186 | help="input width, default: 128 for resnet*, "
187 | "56 for inception")
188 | parser.add_argument('--combine-trainval', action='store_true',
189 | help="train and val sets together for training, "
190 | "val set alone for validation")
191 | # model
192 | parser.add_argument('-a', '--arch', type=str, default='resnet50',
193 | choices=models.names())
194 | parser.add_argument('--features', type=int, default=128)
195 | parser.add_argument('--dropout', type=float, default=0.5)
196 | # loss
197 | parser.add_argument('--oim-scalar', type=float, default=30,
198 | help='reciprocal of the temperature in OIM loss')
199 | parser.add_argument('--oim-momentum', type=float, default=0.5,
200 | help='momentum for updating the LUT in OIM loss')
201 | # optimizer
202 | parser.add_argument('--lr', type=float, default=0.1,
203 | help="learning rate of new parameters, for pretrained "
204 | "parameters it is 10 times smaller than this")
205 | parser.add_argument('--momentum', type=float, default=0.9)
206 | parser.add_argument('--weight-decay', type=float, default=5e-4)
207 | # training configs
208 | parser.add_argument('--resume', type=str, default='', metavar='PATH')
209 | parser.add_argument('--evaluate', action='store_true',
210 | help="evaluation only")
211 | parser.add_argument('--epochs', type=int, default=50)
212 | parser.add_argument('--start_save', type=int, default=0,
213 | help="start saving checkpoints after specific epoch")
214 | parser.add_argument('--seed', type=int, default=1)
215 | parser.add_argument('--print-freq', type=int, default=1)
216 | # metric learning
217 | parser.add_argument('--dist-metric', type=str, default='euclidean',
218 | choices=['euclidean', 'kissme'])
219 | # misc
220 | working_dir = osp.dirname(osp.abspath(__file__))
221 | parser.add_argument('--data-dir', type=str, metavar='PATH',
222 | default=osp.join(working_dir, 'data'))
223 | parser.add_argument('--logs-dir', type=str, metavar='PATH',
224 | default=osp.join(working_dir, 'logs'))
225 | main(parser.parse_args())
226 |
--------------------------------------------------------------------------------