├── doc
├── test.png
├── graph.png
├── default.png
├── tl_model.png
├── base_model.png
├── tl_formular.png
├── loss_timeline.png
└── tl_algorithm.png
├── bases
├── __init__.py
├── trainer_base.py
├── data_loader_base.py
├── infer_base.py
└── model_base.py
├── utils
├── __init__.py
├── np_utils.py
├── utils.py
└── config_utils.py
├── models
├── __init__.py
└── triplet_model.py
├── trainers
├── __init__.py
└── triplet_trainer.py
├── data_loaders
├── __init__.py
└── triplet_dl.py
├── requirements.txt
├── configs
└── triplet_config.json
├── infers
├── __init__.py
└── triplet_infer.py
├── root_dir.py
├── main_test.py
├── main_train.py
├── .gitignore
└── README.md
/doc/test.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/test.png
--------------------------------------------------------------------------------
/doc/graph.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/graph.png
--------------------------------------------------------------------------------
/doc/default.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/default.png
--------------------------------------------------------------------------------
/doc/tl_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/tl_model.png
--------------------------------------------------------------------------------
/doc/base_model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/base_model.png
--------------------------------------------------------------------------------
/doc/tl_formular.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/tl_formular.png
--------------------------------------------------------------------------------
/doc/loss_timeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/loss_timeline.png
--------------------------------------------------------------------------------
/doc/tl_algorithm.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/tl_algorithm.png
--------------------------------------------------------------------------------
/bases/__init__.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """
3 | Copyright (c) 2018. All rights reserved.
4 | Created by C. L. Wang on 2018/4/18
5 | """
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """
3 | Copyright (c) 2018. All rights reserved.
4 | Created by C. L. Wang on 2018/4/18
5 | """
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """
3 | Copyright (c) 2018. All rights reserved.
4 | Created by C. L. Wang on 2018/4/18
5 | """
--------------------------------------------------------------------------------
/trainers/__init__.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """
3 | Copyright (c) 2018. All rights reserved.
4 | Created by C. L. Wang on 2018/4/18
5 | """
--------------------------------------------------------------------------------
/data_loaders/__init__.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """
3 | Copyright (c) 2018. All rights reserved.
4 | Created by C. L. Wang on 2018/4/18
5 | """
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.4.0
2 | keras==2.1.5
3 | bunch==1.0.1
4 | h5py==2.7.1
5 | pydot==1.2.4
6 | numpy==1.13.0
7 | scikit-learn==0.19.1
8 |
--------------------------------------------------------------------------------
/configs/triplet_config.json:
--------------------------------------------------------------------------------
1 | {
2 | "exp_name": "triplet_mnist",
3 | "lr": 0.001,
4 | "decay": 0.0,
5 | "num_epochs": 10,
6 | "batch_size": 64
7 | }
--------------------------------------------------------------------------------
/infers/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -- coding: utf-8 --
3 | """
4 | Copyright (c) 2018. All rights reserved.
5 | Created by C. L. Wang on 2018/4/18
6 | """
--------------------------------------------------------------------------------
/root_dir.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -- coding: utf-8 --
3 | """
4 | Copyright (c) 2018. All rights reserved.
5 | Created by C. L. Wang on 2018/4/23
6 | """
7 |
8 | import os
9 |
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))  # absolute path of the directory that contains this file (the project directory)
11 |
--------------------------------------------------------------------------------
/bases/trainer_base.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """
3 | Copyright (c) 2018. All rights reserved.
4 | Created by C. L. Wang on 2018/4/18
5 | """
6 |
7 |
class TrainerBase(object):
    """Abstract base class for trainers.

    Stores the model, the data and the configuration handed in by the
    caller; concrete trainers implement :meth:`train`.
    """

    def __init__(self, model, data, config):
        # Keep references to everything a concrete trainer works with.
        self.config = config  # configuration
        self.data = data  # data
        self.model = model  # model

    def train(self):
        """Run the training logic; subclasses must override this."""
        raise NotImplementedError()
23 |
--------------------------------------------------------------------------------
/bases/data_loader_base.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """
3 | Copyright (c) 2018. All rights reserved.
4 | Created by C. L. Wang on 2018/4/18
5 | """
6 |
7 |
class DataLoaderBase(object):
    """Abstract base class for data loaders.

    Stores the configuration; concrete loaders implement
    :meth:`get_train_data` and :meth:`get_test_data`.
    """

    def __init__(self, config):
        # Configuration made available to subclasses.
        self.config = config

    def get_train_data(self):
        """Return the training data; subclasses must override this."""
        raise NotImplementedError()

    def get_test_data(self):
        """Return the test data; subclasses must override this."""
        raise NotImplementedError()
27 |
--------------------------------------------------------------------------------
/bases/infer_base.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -- coding: utf-8 --
3 | """
4 | Copyright (c) 2018. All rights reserved.
5 | Created by C. L. Wang on 2018/4/18
6 | """
7 |
8 |
class InferBase(object):
    """Abstract base class for inference helpers.

    Stores the configuration; concrete classes implement
    :meth:`load_model` and :meth:`predict`.
    """

    def __init__(self, config):
        # Configuration made available to subclasses.
        self.config = config

    def load_model(self, name):
        """Load the model identified by ``name``; subclasses must override this."""
        raise NotImplementedError()

    def predict(self, data):
        """Return predictions for ``data``; subclasses must override this."""
        raise NotImplementedError()
28 |
--------------------------------------------------------------------------------
/utils/np_utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -- coding: utf-8 --
3 | """
4 | Copyright (c) 2018. All rights reserved.
5 | Created by C. L. Wang on 2018/4/18
6 |
7 | 参考: https://juejin.im/post/5acfef976fb9a028db5918b5
8 | """
9 |
10 | import numpy as np
11 |
12 |
def prp_2_oh_array(arr):
    """Convert a probability matrix into a one-hot matrix.

    Every row is replaced by a one-hot row whose single 1 sits at the
    position of that row's maximum value.

    Example input: np.array([[0.1, 0.5, 0.4], [0.2, 0.1, 0.6]])

    :param arr: probability matrix (second axis indexes the classes)
    :return: one-hot matrix
    """
    n_classes = arr.shape[1]  # number of classes (columns)
    winners = np.argmax(arr, axis=1)  # column index of each row's maximum
    # Selecting rows of the identity matrix yields the one-hot rows.
    return np.take(np.identity(n_classes), winners, axis=0)
24 |
--------------------------------------------------------------------------------
/utils/utils.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -- coding: utf-8 --
3 | """
4 | Copyright (c) 2018. All rights reserved.
5 | Created by C. L. Wang on 2018/4/18
6 | """
7 | import os
8 | import shutil
9 |
10 |
def mkdir_if_not_exist(dir, is_delete=False):
    """Create a directory (including missing parents) if it does not exist.

    :param dir: path of the directory to create
    :param is_delete: when True, an existing directory is removed together
        with its contents first, so the result is a fresh empty directory
    :return: True on success, False if removing or creating the directory
        raised an exception
    """
    try:
        if is_delete and os.path.exists(dir):
            shutil.rmtree(dir)
            print(u'[INFO] 文件夹 "%s" 存在, 删除文件夹.' % dir)

        if not os.path.exists(dir):
            os.makedirs(dir)
            print(u'[INFO] 文件夹 "%s" 不存在, 创建文件夹.' % dir)
        return True
    except Exception as e:
        # Best-effort helper: report the problem and signal failure to the
        # caller instead of propagating the exception.
        print('[Exception] %s' % e)
        return False
31 |
--------------------------------------------------------------------------------
/main_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -- coding: utf-8 --
3 | """
4 | Copyright (c) 2018. All rights reserved.
5 | Created by C. L. Wang on 2018/4/18
6 | """
7 | from time import time
8 |
9 | import datetime
10 | import numpy as np
11 |
12 | from data_loaders.triplet_dl import TripletDL
13 | from infers.triplet_infer import TripletInfer
14 | from utils.config_utils import process_config, get_test_args
15 |
16 |
def main_test():
    """Entry point of the evaluation script.

    Parses the command line and the JSON configuration, then runs
    ``TripletInfer.default_dist`` and ``TripletInfer.test_dist``.

    If the configuration cannot be parsed, the usage help is printed and the
    process exits with status 1.
    """
    import sys  # local import: only needed for the failure exit below

    print('[INFO] 解析配置...')
    parser = None
    config = None

    try:
        args, parser = get_test_args()
        config = process_config(args.config)
    except Exception as e:
        print('[Exception] 配置无效, %s' % e)
        if parser:
            parser.print_help()
        print('[Exception] 参考: python main_test.py -c configs/triplet_config.json')
        # An unusable configuration is a failure, so report a non-zero
        # status to the calling shell (the previous code exited with 0).
        sys.exit(1)

    print('[INFO] 加载数据...')

    print('[INFO] 预测数据...')
    infer = TripletInfer(config=config)
    infer.default_dist()
    infer.test_dist()

    print('[INFO] 预测完成...')


if __name__ == '__main__':
    main_test()
44 |
--------------------------------------------------------------------------------
/bases/model_base.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """
3 | Copyright (c) 2018. All rights reserved.
4 | Created by C. L. Wang on 2018/4/18
5 | """
6 |
7 |
class ModelBase(object):
    """Base class for models: keeps the config and handles weight checkpoints."""

    def __init__(self, config):
        self.model = None  # underlying model; None until something assigns it
        self.config = config  # configuration

    def _require_model(self):
        """Raise if no model has been assigned to ``self.model`` yet."""
        if self.model is None:
            raise Exception("[Exception] You have to build the model first.")

    def save(self, checkpoint_path):
        """Write the model weights to ``checkpoint_path``.

        :param checkpoint_path: destination passed to ``save_weights``
        :raises Exception: if ``self.model`` is still None
        """
        self._require_model()

        print("[INFO] Saving model...")
        self.model.save_weights(checkpoint_path)
        print("[INFO] Model saved")

    def load(self, checkpoint_path):
        """Read the model weights from ``checkpoint_path``.

        :param checkpoint_path: source passed to ``load_weights``
        :raises Exception: if ``self.model`` is still None
        """
        self._require_model()

        print("[INFO] Loading model checkpoint {} ...\n".format(checkpoint_path))
        self.model.load_weights(checkpoint_path)
        print("[INFO] Model loaded")

    def build_model(self):
        """Construct the model; subclasses must override this."""
        raise NotImplementedError()
44 |
--------------------------------------------------------------------------------
/data_loaders/triplet_dl.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """
3 | Copyright (c) 2018. All rights reserved.
4 | Created by C. L. Wang on 2018/4/18
5 | """
6 | from keras.datasets import mnist
7 | from keras.utils import to_categorical
8 |
9 | from bases.data_loader_base import DataLoaderBase
10 |
11 |
class TripletDL(DataLoaderBase):
    """Data loader serving MNIST as 28x28x1 images with one-hot labels.

    The dataset is loaded through ``keras.datasets.mnist`` when the loader
    is constructed.
    """

    def __init__(self, config=None):
        super(TripletDL, self).__init__(config)
        train_set, test_set = mnist.load_data()
        self.X_train, self.y_train = train_set
        self.X_test, self.y_test = test_set

        # Flattened variant (for RNNs / perceptrons), kept for reference:
        # self.X_train = self.X_train.reshape((-1, 28 * 28))
        # self.X_test = self.X_test.reshape((-1, 28 * 28))

        # Image layout for CNNs: append a single channel axis.
        image_shape = (28, 28, 1)
        self.X_train = self.X_train.reshape((self.X_train.shape[0],) + image_shape)
        self.X_test = self.X_test.reshape((self.X_test.shape[0],) + image_shape)

        # Integer labels -> one-hot rows.
        self.y_train = to_categorical(self.y_train)
        self.y_test = to_categorical(self.y_test)

        print("[INFO] X_train.shape: %s, y_train.shape: %s"
              % (str(self.X_train.shape), str(self.y_train.shape)))
        print("[INFO] X_test.shape: %s, y_test.shape: %s"
              % (str(self.X_test.shape), str(self.y_test.shape)))

    def get_train_data(self):
        """Return the tuple ``(X_train, y_train)``."""
        return self.X_train, self.y_train

    def get_test_data(self):
        """Return the tuple ``(X_test, y_test)``."""
        return self.X_test, self.y_test
38 |
--------------------------------------------------------------------------------
/main_train.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -- coding: utf-8 --
3 | """
4 | Copyright (c) 2018. All rights reserved.
5 | Created by C. L. Wang on 2018/4/18
6 |
7 | 参考:
8 | NumPy FutureWarning
9 | https://stackoverflow.com/questions/48340392/futurewarning-conversion-of-the-second-argument-of-issubdtype-from-float-to
10 | """
11 |
12 | from data_loaders.triplet_dl import TripletDL
13 | from infers.triplet_infer import TripletInfer
14 | from models.triplet_model import TripletModel
15 | from trainers.triplet_trainer import TripletTrainer
16 | from utils.config_utils import process_config, get_train_args
17 | import numpy as np
18 |
19 |
def main_train():
    """Entry point of the training script.

    Parses the command line and the JSON configuration, loads the data,
    builds the triplet model and trains it.

    If the configuration cannot be parsed, the usage help is printed and the
    process exits with status 1.

    :return: None
    """
    import sys  # local import: only needed for the failure exit below

    print('[INFO] 解析配置...')

    parser = None
    config = None

    try:
        args, parser = get_train_args()
        config = process_config(args.config)
    except Exception as e:
        print('[Exception] 配置无效, %s' % e)
        if parser:
            parser.print_help()
        print('[Exception] 参考: python main_train.py -c configs/triplet_config.json')
        # An unusable configuration is a failure, so report a non-zero
        # status to the calling shell (the previous code exited with 0).
        sys.exit(1)
    # config = process_config('configs/triplet_config.json')

    print('[INFO] 加载数据...')
    dl = TripletDL(config=config)

    print('[INFO] 构造网络...')
    model = TripletModel(config=config)

    print('[INFO] 训练网络...')
    # The trainer receives the underlying Keras model and the data as
    # [train tuple, test tuple].
    trainer = TripletTrainer(
        model=model.model,
        data=[dl.get_train_data(), dl.get_test_data()],
        config=config)
    trainer.train()
    print('[INFO] 训练完成...')


if __name__ == '__main__':
    main_train()
59 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Created by .ignore support plugin (hsz.mobi)
2 | ### Python template
3 | # Byte-compiled / optimized / DLL files
4 | __pycache__/
5 | *.py[cod]
6 | *$py.class
7 |
8 | # C extensions
9 | *.so
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .coverage
44 | .coverage.*
45 | .cache
46 | nosetests.xml
47 | coverage.xml
48 | *.cover
49 | .hypothesis/
50 |
51 | # Translations
52 | *.mo
53 | *.pot
54 |
55 | # Django stuff:
56 | *.log
57 | .static_storage/
58 | .media/
59 | local_settings.py
60 |
61 | # Flask stuff:
62 | instance/
63 | .webassets-cache
64 |
65 | # Scrapy stuff:
66 | .scrapy
67 |
68 | # Sphinx documentation
69 | docs/_build/
70 |
71 | # PyBuilder
72 | target/
73 |
74 | # Jupyter Notebook
75 | .ipynb_checkpoints
76 |
77 | # pyenv
78 | .python-version
79 |
80 | # celery beat schedule file
81 | celerybeat-schedule
82 |
83 | # SageMath parsed files
84 | *.sage.py
85 |
86 | # Environments
87 | .env
88 | .venv
89 | env/
90 | venv/
91 | ENV/
92 | env.bak/
93 | venv.bak/
94 |
95 | # Spyder project settings
96 | .spyderproject
97 | .spyproject
98 |
99 | # Rope project settings
100 | .ropeproject
101 |
102 | # mkdocs documentation
103 | /site
104 |
105 | # mypy
106 | .mypy_cache/
107 |
108 | .idea/
109 |
110 | # data
111 | experiments/
112 |
113 |
--------------------------------------------------------------------------------
/utils/config_utils.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """
3 | Copyright (c) 2018. All rights reserved.
4 | Created by C. L. Wang on 2018/4/18
5 | """
6 | import argparse
7 | import json
8 |
9 | import os
10 | from bunch import Bunch
11 |
12 | from root_dir import ROOT_DIR
13 | from utils import mkdir_if_not_exist
14 |
15 |
def get_config_from_json(json_file):
    """Load a JSON configuration file and wrap it in a ``Bunch``.

    :param json_file: path of the JSON configuration file
    :return: tuple ``(config, config_dict)``: the ``Bunch`` wrapper and the
        dictionary parsed from the file
    """
    with open(json_file, 'r') as config_file:
        raw_text = config_file.read()

    config_dict = json.loads(raw_text)  # configuration as a plain dict
    return Bunch(config_dict), config_dict
26 |
27 |
def process_config(json_file):
    """Parse the JSON configuration and prepare the experiment directories.

    Adds ``tb_dir``, ``cp_dir`` and ``img_dir`` to the configuration, all
    located under ``<ROOT_DIR>/experiments/<exp_name>/``, and creates them.

    :param json_file: path of the JSON configuration file
    :return: the configuration object
    """
    config = get_config_from_json(json_file)[0]

    exp_dir = os.path.join(ROOT_DIR, "experiments")
    mkdir_if_not_exist(exp_dir)

    run_dir = os.path.join(exp_dir, config.exp_name)
    config.tb_dir = os.path.join(run_dir, "logs/")  # logs
    config.cp_dir = os.path.join(run_dir, "checkpoints/")  # saved models
    config.img_dir = os.path.join(run_dir, "images/")  # network diagrams

    for out_dir in (config.tb_dir, config.cp_dir, config.img_dir):
        mkdir_if_not_exist(out_dir)
    return config
47 |
48 |
def get_train_args():
    """Parse the command-line arguments of the training script.

    The only option is ``-c`` / ``--cfg``, stored as ``args.config``. When
    the option is omitted its value is the literal string ``'None'`` (not
    the ``None`` object).

    :return: tuple ``(args, parser)``
    """
    cfg_option = dict(
        dest='config',
        metavar='path',
        default='None',
        help='add a configuration file',
    )
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('-c', '--cfg', **cfg_option)
    return parser.parse_args(), parser
62 |
63 |
def get_test_args():
    """Parse the command-line arguments of the test script.

    The only option is ``-c`` / ``--cfg``, stored as ``args.config``. When
    the option is omitted its value is the literal string ``'None'`` (not
    the ``None`` object).

    :return: tuple ``(args, parser)``
    """
    cfg_option = dict(
        dest='config',
        metavar='C',
        default='None',
        help='add a configuration file',
    )
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('-c', '--cfg', **cfg_option)
    return parser.parse_args(), parser
78 |
--------------------------------------------------------------------------------
/models/triplet_model.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """
3 | Copyright (c) 2018. All rights reserved.
4 | Created by C. L. Wang on 2018/4/18
5 | """
6 | import os
7 | from keras import Input, Model
8 | from keras.layers import Dense, Conv2D, MaxPooling2D, Dropout, Flatten, K, merge, Concatenate
9 | from keras.optimizers import Adam
10 | from keras.utils import plot_model
11 |
12 | from bases.model_base import ModelBase
13 |
14 |
class TripletModel(ModelBase):
    """Triplet-loss model: one shared network applied to three inputs.

    The anchor, positive and negative images go through the same base
    network; the three embeddings are concatenated into a single output that
    :meth:`triplet_loss` splits apart again.
    """

    MARGIN = 1.0  # hyper-parameter: margin added to (pos_dist - neg_dist)
    EMBEDDING_DIM = 128  # width of one embedding; used by base_model and triplet_loss

    def __init__(self, config):
        super(TripletModel, self).__init__(config)
        self.build_model()

    def build_model(self):
        """Build the model trained with the triplet loss and store it in ``self.model``."""
        self.model = self.triplet_loss_model()

    def triplet_loss_model(self):
        """Build and compile the three-input model.

        Also writes a diagram of the model to
        ``<config.img_dir>/triplet_loss_model.png``.

        :return: compiled Keras model with inputs ``anc_input``,
            ``pos_input`` and ``neg_input``
        """
        anc_input = Input(shape=(28, 28, 1), name='anc_input')  # anchor
        pos_input = Input(shape=(28, 28, 1), name='pos_input')  # positive
        neg_input = Input(shape=(28, 28, 1), name='neg_input')  # negative

        shared_model = self.base_model()  # one network shared by all three inputs

        std_out = shared_model(anc_input)
        pos_out = shared_model(pos_input)
        neg_out = shared_model(neg_input)

        print("[INFO] model - 锚shape: %s" % str(std_out.get_shape()))
        print("[INFO] model - 正shape: %s" % str(pos_out.get_shape()))
        print("[INFO] model - 负shape: %s" % str(neg_out.get_shape()))

        # Order matters: triplet_loss slices anchor, positive, negative in this order.
        output = Concatenate()([std_out, pos_out, neg_out])
        model = Model(inputs=[anc_input, pos_input, neg_input], outputs=output)

        plot_model(model, to_file=os.path.join(self.config.img_dir, "triplet_loss_model.png"),
                   show_shapes=True)  # draw the model diagram

        # Use the configured learning rate / decay when present. The fallbacks
        # are the Keras Adam defaults, so a config without these keys behaves
        # as before.
        optimizer = Adam(lr=getattr(self.config, 'lr', 0.001),
                         decay=getattr(self.config, 'decay', 0.0))
        model.compile(loss=self.triplet_loss, optimizer=optimizer)

        return model

    @staticmethod
    def triplet_loss(y_true, y_pred):
        """Triplet loss: ``max(pos_dist - neg_dist + MARGIN, 0)``.

        :param y_true: unused; present only because Keras losses take two arguments
        :param y_pred: concatenated [anchor | positive | negative] embeddings,
            each ``EMBEDDING_DIM`` wide
        :return: loss tensor with one value per sample
        """
        dim = TripletModel.EMBEDDING_DIM
        anc, pos, neg = y_pred[:, 0:dim], y_pred[:, dim:2 * dim], y_pred[:, 2 * dim:]

        # Squared Euclidean distances
        pos_dist = K.sum(K.square(anc - pos), axis=-1, keepdims=True)
        neg_dist = K.sum(K.square(anc - neg), axis=-1, keepdims=True)
        basic_loss = pos_dist - neg_dist + TripletModel.MARGIN

        loss = K.maximum(basic_loss, 0.0)

        print("[INFO] model - triplet_loss shape: %s" % str(loss.shape))
        return loss

    def base_model(self):
        """Build the shared embedding network (can be swapped for another architecture).

        Also writes a diagram of the network to
        ``<config.img_dir>/base_model.png``.

        :return: Keras model mapping a (28, 28, 1) image to an
            ``EMBEDDING_DIM``-wide vector
        """
        ins_input = Input(shape=(28, 28, 1))
        x = Conv2D(32, kernel_size=(3, 3), kernel_initializer='random_uniform', activation='relu')(ins_input)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Conv2D(32, kernel_size=(3, 3), kernel_initializer='random_uniform', activation='relu')(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        x = Dropout(0.25)(x)
        x = Flatten()(x)
        out = Dense(self.EMBEDDING_DIM, activation='relu')(x)

        model = Model(ins_input, out)
        plot_model(model, to_file=os.path.join(self.config.img_dir, "base_model.png"),
                   show_shapes=True)  # draw the model diagram
        return model
87 |
--------------------------------------------------------------------------------
/infers/triplet_infer.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """
3 | Copyright (c) 2018. All rights reserved.
4 | Created by C. L. Wang on 2018/4/18
5 | """
6 | import os
7 |
8 | import datetime
9 | from keras.datasets import mnist
10 | from tensorflow.contrib.tensorboard.plugins import projector
11 | import tensorflow as tf
12 | import numpy as np
13 |
14 | from bases.infer_base import InferBase
15 | from keras.models import load_model
16 |
17 | from models.triplet_model import TripletModel
18 | from root_dir import ROOT_DIR
19 | from utils.utils import mkdir_if_not_exist
20 |
21 |
class TripletInfer(InferBase):
    """Inference helper: loads the trained triplet model and exports
    embeddings for the TensorBoard projector."""

    def __init__(self, name=None, config=None):
        """
        :param name: optional model file name inside ``config.cp_dir``; when
            given, the model is loaded immediately
        :param config: configuration object (must provide ``cp_dir`` and ``tb_dir``)
        """
        super(TripletInfer, self).__init__(config)
        # Always define the attribute so it exists even when no name is given.
        self.model = None
        if name:
            self.model = self.load_model(name)

    def load_model(self, name):
        """Load a saved model from ``config.cp_dir``.

        :param name: file name of the saved model
        :return: the loaded Keras model
        """
        model_path = os.path.join(self.config.cp_dir, name)
        # The saved model references the custom loss, so Keras needs it to deserialize.
        return load_model(model_path, custom_objects={'triplet_loss': TripletModel.triplet_loss})

    def predict(self, data):
        """Return ``self.model.predict(data)``."""
        return self.model.predict(data)

    def test_dist(self):
        """Embed the MNIST test images with the trained model and export them.

        Loads ``triplet_loss_model.h5`` from ``config.cp_dir``, prints the
        throughput and writes the anchor embeddings to ``<tb_dir>/test``.
        """
        model_path = os.path.join(self.config.cp_dir, "triplet_loss_model.h5")

        model = load_model(model_path, custom_objects={'triplet_loss': TripletModel.triplet_loss})
        model.summary()

        (_, _), (X_test, y_test) = mnist.load_data()
        anchor = np.reshape(X_test, (-1, 28, 28, 1))
        # Only the anchor branch matters here; the other two inputs are zero-filled.
        X = {
            'anc_input': anchor,
            'pos_input': np.zeros(anchor.shape),
            'neg_input': np.zeros(anchor.shape)
        }
        s_time = datetime.datetime.now()
        res = model.predict(X)
        e_time = datetime.datetime.now() - s_time
        # total_seconds() covers the whole interval; the .microseconds field
        # alone drops whole seconds and under-reports longer runs.
        micro = e_time.total_seconds() * 1000000.0 / float(len(y_test))
        print("TPS: %s (%s ms)" % ((1000000.0 / micro), (micro / 1000.0)))
        data = res[:, :TripletModel.EMBEDDING_DIM] if hasattr(TripletModel, 'EMBEDDING_DIM') else res[:, :128]  # anchor part of the concatenated output
        print("验证结果结构: %s" % str(data.shape))
        log_dir = os.path.join(ROOT_DIR, self.config.tb_dir, "test")
        mkdir_if_not_exist(log_dir)
        self.tb_projector(data, y_test, log_dir)

    def default_dist(self):
        """Export the raw (flattened) MNIST test pixels to ``<tb_dir>/default``."""
        (_, _), (X_test, y_test) = mnist.load_data()
        X_test = np.reshape(X_test, (-1, 28 * 28))
        log_dir = os.path.join(ROOT_DIR, self.config.tb_dir, 'default')
        mkdir_if_not_exist(log_dir)
        self.tb_projector(X_test, y_test, log_dir)

    @staticmethod
    def tb_projector(X_test, y_test, log_dir):
        """Write data and labels in the form the TensorBoard projector reads.

        :param X_test: data, one row per sample
        :param y_test: labels, numeric
        :param log_dir: output directory
        :return: None; files are written into ``log_dir``
        """
        print("展示数据: %s" % str(X_test.shape))
        print("展示标签: %s" % str(y_test.shape))
        print("日志目录: %s" % str(log_dir))

        metadata = os.path.join(log_dir, 'metadata.tsv')

        images = tf.Variable(X_test)

        # Write one label per line into the metadata file
        with open(metadata, 'w') as metadata_file:
            for row in y_test:
                metadata_file.write('%d\n' % row)

        with tf.Session() as sess:
            saver = tf.train.Saver([images])  # saver for the data variable

            sess.run(images.initializer)  # initialise the variable
            saver.save(sess, os.path.join(log_dir, 'images.ckpt'))  # stored as images.ckpt

            config = projector.ProjectorConfig()
            # One can add multiple embeddings.
            embedding = config.embeddings.add()
            embedding.tensor_name = images.name
            # Link this tensor to its metadata file (e.g. labels).
            embedding.metadata_path = metadata
            # Saves a config file that TensorBoard will read during startup.
            projector.visualize_embeddings(tf.summary.FileWriter(log_dir), config)
102 |
--------------------------------------------------------------------------------
/trainers/triplet_trainer.py:
--------------------------------------------------------------------------------
1 | # -- coding: utf-8 --
2 | """
3 | Copyright (c) 2018. All rights reserved.
4 | Created by C. L. Wang on 2018/4/18
5 | """
6 | import os
7 | import random
8 | import warnings
9 |
10 | import numpy as np
11 | from keras.callbacks import TensorBoard, ModelCheckpoint, Callback
12 | from sklearn.exceptions import UndefinedMetricWarning
13 | from sklearn.metrics import precision_recall_fscore_support
14 |
15 | from bases.trainer_base import TrainerBase
16 | from root_dir import ROOT_DIR
17 | from utils.np_utils import prp_2_oh_array
18 | from utils.utils import mkdir_if_not_exist
19 |
20 |
class TripletTrainer(TrainerBase):
    """Trainer that fits the triplet model on (anchor, positive, negative) triplets."""

    def __init__(self, model, data, config):
        """
        :param model: compiled Keras model with inputs ``anc_input``,
            ``pos_input`` and ``neg_input``
        :param data: ``[(x_train, y_train), (x_test, y_test)]`` with one-hot labels
        :param config: configuration object (must provide ``tb_dir`` and ``cp_dir``)
        """
        super(TripletTrainer, self).__init__(model, data, config)
        self.callbacks = []  # Keras callbacks passed to fit()
        # History containers; nothing in this class fills them.
        self.loss = []
        self.acc = []
        self.val_loss = []
        self.val_acc = []
        self.init_callbacks()

    def init_callbacks(self):
        """Register the TensorBoard callback, logging to ``<tb_dir>/train``."""
        train_dir = os.path.join(ROOT_DIR, self.config.tb_dir, "train")
        mkdir_if_not_exist(train_dir)
        self.callbacks.append(
            TensorBoard(
                log_dir=train_dir,
                write_images=True,
                write_graph=True,
            )
        )

        # Optional per-epoch F1 / precision / recall reporting (disabled):
        # self.callbacks.append(FPRMetric())
        # self.callbacks.append(FPRMetricDetail())

    def train(self):
        """Build triplets, fit the model, save it and print per-class accuracy.

        The model is saved to ``<cp_dir>/triplet_loss_model.h5``.
        """
        x_train = self.data[0][0]
        y_train = np.argmax(self.data[0][1], axis=1)  # one-hot -> class index
        x_test = self.data[1][0]
        y_test = np.argmax(self.data[1][1], axis=1)

        clz_size = len(np.unique(y_train))
        print("[INFO] trainer - 类别数: %s" % clz_size)
        # Sample indices grouped by class; the class count comes from the
        # data rather than a fixed 10 so it matches what create_pairs gets.
        digit_indices = [np.where(y_train == i)[0] for i in range(clz_size)]
        tr_pairs = self.create_pairs(x_train, digit_indices, clz_size)

        digit_indices = [np.where(y_test == i)[0] for i in range(clz_size)]
        te_pairs = self.create_pairs(x_test, digit_indices, clz_size)

        anc_ins = tr_pairs[:, 0]
        pos_ins = tr_pairs[:, 1]
        neg_ins = tr_pairs[:, 2]

        X = {
            'anc_input': anc_ins,
            'pos_input': pos_ins,
            'neg_input': neg_ins
        }

        anc_ins_te = te_pairs[:, 0]
        pos_ins_te = te_pairs[:, 1]
        neg_ins_te = te_pairs[:, 2]

        X_te = {
            'anc_input': anc_ins_te,
            'pos_input': pos_ins_te,
            'neg_input': neg_ins_te
        }

        # The targets are placeholders of ones. NOTE: batch size and epoch
        # count are fixed here; config.batch_size / config.num_epochs are
        # not read.
        self.model.fit(
            X, np.ones(len(anc_ins)),
            batch_size=32,
            epochs=2,
            validation_data=[X_te, np.ones(len(anc_ins_te))],
            verbose=1,
            callbacks=self.callbacks)

        self.model.save(os.path.join(self.config.cp_dir, "triplet_loss_model.h5"))  # save the model

        y_pred = self.model.predict(X_te)  # evaluate on the test triplets
        # Floor division keeps the per-class count an int (needed for slicing)
        # on Python 3 as well as Python 2.
        self.show_acc_facets(y_pred, y_pred.shape[0] // clz_size, clz_size)

    @staticmethod
    def show_acc_facets(y_pred, n, clz_size):
        """Print per-class distance statistics and accuracy.

        Rows of ``y_pred`` are expected to be grouped by class, ``n`` rows
        per class (the order produced by :meth:`create_pairs`). A triplet
        counts as correct when the anchor is closer to the positive than to
        the negative.

        :param y_pred: model output, [anchor | positive | negative] embeddings per row
        :param n: number of rows per class
        :param clz_size: number of classes
        :return: None; results are printed
        """
        print("[INFO] trainer - n_clz: %s" % n)
        for i in range(clz_size):
            print("[INFO] trainer - clz %s" % i)
            final = y_pred[n * i:n * (i + 1), :]
            anchor, positive, negative = final[:, 0:128], final[:, 128:256], final[:, 256:]

            pos_dist = np.sum(np.square(anchor - positive), axis=-1, keepdims=True)
            neg_dist = np.sum(np.square(anchor - negative), axis=-1, keepdims=True)
            basic_loss = pos_dist - neg_dist
            r_count = basic_loss[np.where(basic_loss < 0)].shape[0]  # correctly ordered triplets
            print("[INFO] trainer - distance - min: %s, max: %s, avg: %s" % (
                np.min(basic_loss), np.max(basic_loss), np.average(basic_loss)))
            print("[INFO] acc: %s" % (float(r_count) / float(n)))
            print("")

    @staticmethod
    def create_pairs(x, digit_indices, num_classes):
        """Create (anchor, positive, negative) triplets.

        For every class ``d`` and position ``i`` the anchor and positive are
        the ``i``-th and ``(i + 1)``-th samples of class ``d``; the negative
        is the ``i``-th sample of a randomly chosen other class.

        :param x: data
        :param digit_indices: per-class lists of sample indices into ``x``
        :param num_classes: number of classes
        :return: array of triplets, grouped by class, to feed the model
        """

        pairs = []
        # Size of the smallest class minus one, so index i + 1 is always valid.
        n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1
        for d in range(num_classes):
            for i in range(n):
                z1, z2 = digit_indices[d][i], digit_indices[d][i + 1]
                inc = random.randrange(1, num_classes)  # offset 1..num_classes-1 -> never class d itself
                dn = (d + inc) % num_classes
                z3 = digit_indices[dn][i]
                pairs.append([x[z1], x[z2], x[z3]])
        return np.array(pairs)
136 |
137 |
class FPRMetric(Callback):
    """Callback printing macro-averaged F1, precision and recall after each epoch."""

    def on_epoch_end(self, batch, logs=None):
        """Score the validation data and print the macro-averaged metrics.

        Reads inputs and targets from ``self.validation_data[0]`` and
        ``self.validation_data[1]``.
        """
        inputs = self.validation_data[0]
        targets = self.validation_data[1]

        # Turn the predicted probabilities into one-hot rows before scoring.
        predicted = prp_2_oh_array(np.asarray(self.model.predict(inputs)))

        warnings.filterwarnings("ignore", category=UndefinedMetricWarning)
        scores = precision_recall_fscore_support(targets, predicted, average='macro')
        precision, recall, f_score = scores[0], scores[1], scores[2]
        print(" — val_f1: % 0.4f — val_pre: % 0.4f — val_rec % 0.4f" % (f_score, precision, recall))
153 |
154 |
class FPRMetricDetail(Callback):
    """Callback printing per-class F1, precision, recall and support after each epoch."""

    def on_epoch_end(self, batch, logs=None):
        """Score the validation data and print one line per class.

        Reads inputs and targets from ``self.validation_data[0]`` and
        ``self.validation_data[1]``.
        """
        inputs = self.validation_data[0]
        targets = self.validation_data[1]

        # Turn the predicted probabilities into one-hot rows before scoring.
        predicted = prp_2_oh_array(np.asarray(self.model.predict(inputs)))

        warnings.filterwarnings("ignore", category=UndefinedMetricWarning)
        per_class = precision_recall_fscore_support(targets, predicted)

        # per_class holds (precision, recall, f_score, support) arrays;
        # zipping them yields one tuple per class.
        for p, r, f, s in zip(*per_class):
            print(" — val_f1: % 0.4f — val_pre: % 0.4f — val_rec % 0.4f - ins %s" % (f, p, r, s))
171 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Triplet Loss 损失函数
2 |
3 | [Triplet Loss](https://arxiv.org/pdf/1503.03832.pdf)是深度学习中的一种损失函数,用于训练**差异性较小**的样本,如人脸等。输入(Feed)数据包括锚(Anchor)示例、正(Positive)示例、负(Negative)示例;通过优化,使锚示例与正示例的距离**小于**锚示例与负示例的距离,从而实现样本的相似性计算。
4 |
5 |
6 |
7 | 数据集:[MNIST](http://yann.lecun.com/exdb/mnist/)
8 |
9 | 框架:[DL-Project-Template](https://github.com/SpikeKing/DL-Project-Template)
10 |
11 | 目标:通过Triplet Loss训练模型,实现手写图像的相似性计算。
12 |
13 | [工程](https://github.com/SpikeKing/triplet-loss-mnist):https://github.com/SpikeKing/triplet-loss-mnist
14 |
15 | ---
16 |
17 | ## 模型
18 |
19 | Triplet Loss的核心是锚示例、正示例、负示例共享模型,通过模型,将锚示例与正示例聚类,远离负示例。
20 |
21 | **Triplet Loss Model**的结构如下:
22 |
![Triplet Loss Model 结构](doc/tl_model.png)
24 |
25 | - 输入:三个输入,即锚示例、正示例、负示例,不同示例的**结构**相同;
26 | - 模型:一个共享模型,支持替换为**任意**网络结构;
27 | - 输出:一个输出,即三个模型输出的拼接。
28 |
29 | **Shared Model**选择常用的卷积模型,输出为全连接的128维数据:
30 |
![Shared Model 结构](doc/base_model.png)
32 |
33 | Triplet Loss **损失函数**的计算公式如下:
34 |
![Triplet Loss 计算公式](doc/tl_formular.png)
36 |
37 | ---
38 |
39 | ## 训练
40 |
41 | 模型参数:
42 |
43 | - batch_size:32
44 | - epochs:2
45 |
46 | 超参数:
47 |
48 | - 边界Margin的值设置为``1``。
49 |
50 | 训练命令:
51 |
52 | ```text
53 | python main_train.py -c configs/triplet_config.json
54 | ```
55 |
56 | 训练日志:
57 |
58 | ```text
59 | Using TensorFlow backend.
60 | [INFO] 解析配置...
61 | [INFO] 加载数据...
62 | [INFO] X_train.shape: (60000, 28, 28, 1), y_train.shape: (60000, 10)
63 | [INFO] X_test.shape: (10000, 28, 28, 1), y_test.shape: (10000, 10)
64 | [INFO] 构造网络...
65 | [INFO] model - 锚shape: (?, 128)
66 | [INFO] model - 正shape: (?, 128)
67 | [INFO] model - 负shape: (?, 128)
68 | [INFO] model - triplet_loss shape: (?, 1)
69 | [INFO] 训练网络...
70 | [INFO] trainer - 类别数: 10
71 | Train on 54200 samples, validate on 8910 samples
72 | 2018-04-24 10:59:06.130952: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
73 | Epoch 1/2
74 | 54200/54200 [==============================] - 84s 2ms/step - loss: 0.6083 - val_loss: 0.0515
75 | Epoch 2/2
76 | 54200/54200 [==============================] - 83s 2ms/step - loss: 0.0869 - val_loss: 0.0314
77 | ```
78 |
79 | 算法收敛较好,Loss线性下降:
80 |
![Loss 下降曲线](doc/loss_timeline.png)
82 |
83 | TF Graph:
84 |
![TF Graph](doc/graph.png)
86 |
87 | ---
88 |
89 | ## 验证
90 |
91 | **算法效率**(TPS): 每秒48163次 (0.0207625 ms/t)
92 |
93 | 测试命令:
94 |
95 | ```text
96 | python main_test.py -c configs/triplet_config.json
97 | ```
98 |
99 | 测试日志:
100 |
101 | ```text
102 | Using TensorFlow backend.
103 | [INFO] 解析配置...
104 | [INFO] 加载数据...
105 | [INFO] 预测数据...
106 | 展示数据: (10000, 784)
107 | 展示标签: (10000,)
108 | 日志目录: /Users/wang/workspace/triplet-loss-mnist/experiments/triplet_mnist/logs/default
109 | 2018-04-24 11:02:07.874682: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
110 | [INFO] model - triplet_loss shape: (?, 1)
111 | __________________________________________________________________________________________________
112 | Layer (type) Output Shape Param # Connected to
113 | ==================================================================================================
114 | anc_input (InputLayer) (None, 28, 28, 1) 0
115 | __________________________________________________________________________________________________
116 | pos_input (InputLayer) (None, 28, 28, 1) 0
117 | __________________________________________________________________________________________________
118 | neg_input (InputLayer) (None, 28, 28, 1) 0
119 | __________________________________________________________________________________________________
120 | model_1 (Model) (None, 128) 112096 anc_input[0][0]
121 | pos_input[0][0]
122 | neg_input[0][0]
123 | __________________________________________________________________________________________________
124 | concatenate_1 (Concatenate) (None, 384) 0 model_1[1][0]
125 | model_1[2][0]
126 | model_1[3][0]
127 | ==================================================================================================
128 | Total params: 112,096
129 | Trainable params: 112,096
130 | Non-trainable params: 0
131 | __________________________________________________________________________________________________
132 | TPS: 272553.829381 (0.003669 ms)
133 | 验证结果结构: (10000, 128)
134 | 展示数据: (10000, 128)
135 | 展示标签: (10000,)
136 | 日志目录: /Users/wang/workspace/triplet-loss-mnist/experiments/triplet_mnist/logs/test
137 | [INFO] 预测完成...
138 | ```
139 |
140 | MNIST验证集的效果:
141 |
```text
143 | [INFO] trainer - clz 0
144 | [INFO] trainer - distance - min: -15.4567, max: 1.98611, avg: -6.50481
145 | [INFO] acc: 0.996632996633
146 |
147 | [INFO] trainer - clz 1
148 | [INFO] trainer - distance - min: -13.09, max: 3.43779, avg: -6.66867
149 | [INFO] acc: 0.99214365881
150 |
151 | [INFO] trainer - clz 2
152 | [INFO] trainer - distance - min: -14.2524, max: 2.49437, avg: -5.60508
153 | [INFO] acc: 0.991021324355
154 |
155 | [INFO] trainer - clz 3
156 | [INFO] trainer - distance - min: -16.6555, max: 1.21776, avg: -6.32161
157 | [INFO] acc: 0.995510662177
158 |
159 | [INFO] trainer - clz 4
160 | [INFO] trainer - distance - min: -14.193, max: 1.65427, avg: -5.90896
161 | [INFO] acc: 0.991021324355
162 |
163 | [INFO] trainer - clz 5
164 | [INFO] trainer - distance - min: -14.1007, max: 2.01843, avg: -6.36086
165 | [INFO] acc: 0.994388327722
166 |
167 | [INFO] trainer - clz 6
168 | [INFO] trainer - distance - min: -16.8953, max: 2.84421, avg: -8.43978
169 | [INFO] acc: 0.995510662177
170 |
171 | [INFO] trainer - clz 7
172 | [INFO] trainer - distance - min: -16.6177, max: 3.49675, avg: -5.99822
173 | [INFO] acc: 0.989898989899
174 |
175 | [INFO] trainer - clz 8
176 | [INFO] trainer - distance - min: -14.937, max: 3.38141, avg: -5.4424
177 | [INFO] acc: 0.979797979798
178 |
179 | [INFO] trainer - clz 9
180 | [INFO] trainer - distance - min: -16.9519, max: 2.39112, avg: -5.93581
181 | [INFO] acc: 0.985409652076
182 | ```
183 |
184 | 测试的MNIST分布:
185 |
![测试的 MNIST 分布](doc/default.png)
187 |
188 | 输出的Triplet Loss MNIST分布:
189 |
![输出的 Triplet Loss MNIST 分布](doc/test.png)
191 |
192 | 本例仅仅使用2个Epoch,也没有特殊设置超参,实际效果仍有提升空间。
193 |
194 | ---
195 |
196 | By C. L. Wang
197 |
--------------------------------------------------------------------------------