├── doc ├── test.png ├── graph.png ├── default.png ├── tl_model.png ├── base_model.png ├── tl_formular.png ├── loss_timeline.png └── tl_algorithm.png ├── bases ├── __init__.py ├── trainer_base.py ├── data_loader_base.py ├── infer_base.py └── model_base.py ├── utils ├── __init__.py ├── np_utils.py ├── utils.py └── config_utils.py ├── models ├── __init__.py └── triplet_model.py ├── trainers ├── __init__.py └── triplet_trainer.py ├── data_loaders ├── __init__.py └── triplet_dl.py ├── requirements.txt ├── configs └── triplet_config.json ├── infers ├── __init__.py └── triplet_infer.py ├── root_dir.py ├── main_test.py ├── main_train.py ├── .gitignore └── README.md /doc/test.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/test.png -------------------------------------------------------------------------------- /doc/graph.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/graph.png -------------------------------------------------------------------------------- /doc/default.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/default.png -------------------------------------------------------------------------------- /doc/tl_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/tl_model.png -------------------------------------------------------------------------------- /doc/base_model.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/base_model.png -------------------------------------------------------------------------------- 
/doc/tl_formular.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/tl_formular.png -------------------------------------------------------------------------------- /doc/loss_timeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/loss_timeline.png -------------------------------------------------------------------------------- /doc/tl_algorithm.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/SpikeKing/triplet-loss-mnist/HEAD/doc/tl_algorithm.png -------------------------------------------------------------------------------- /bases/__init__.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | Copyright (c) 2018. All rights reserved. 4 | Created by C. L. Wang on 2018/4/18 5 | """ -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | Copyright (c) 2018. All rights reserved. 4 | Created by C. L. Wang on 2018/4/18 5 | """ -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | Copyright (c) 2018. All rights reserved. 4 | Created by C. L. Wang on 2018/4/18 5 | """ -------------------------------------------------------------------------------- /trainers/__init__.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | Copyright (c) 2018. All rights reserved. 4 | Created by C. L. 
Wang on 2018/4/18 5 | """ -------------------------------------------------------------------------------- /data_loaders/__init__.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | Copyright (c) 2018. All rights reserved. 4 | Created by C. L. Wang on 2018/4/18 5 | """ -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tensorflow==1.4.0 2 | keras==2.1.5 3 | bunch==1.0.1 4 | h5py==2.7.1 5 | pydot==1.2.4 6 | numpy==1.13.0 7 | scikit-learn==0.19.1 8 | -------------------------------------------------------------------------------- /configs/triplet_config.json: -------------------------------------------------------------------------------- 1 | { 2 | "exp_name": "triplet_mnist", 3 | "lr": 0.001, 4 | "decay": 0.0, 5 | "num_epochs": 10, 6 | "batch_size": 64 7 | } -------------------------------------------------------------------------------- /infers/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -- coding: utf-8 -- 3 | """ 4 | Copyright (c) 2018. All rights reserved. 5 | Created by C. L. Wang on 2018/4/18 6 | """ -------------------------------------------------------------------------------- /root_dir.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -- coding: utf-8 -- 3 | """ 4 | Copyright (c) 2018. All rights reserved. 5 | Created by C. L. Wang on 2018/4/23 6 | """ 7 | 8 | import os 9 | 10 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) # 存储项目所在的绝对路径 11 | -------------------------------------------------------------------------------- /bases/trainer_base.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | Copyright (c) 2018. 
All rights reserved. 4 | Created by C. L. Wang on 2018/4/18 5 | """ 6 | 7 | 8 | class TrainerBase(object): 9 | """ 10 | 训练器基类 11 | """ 12 | 13 | def __init__(self, model, data, config): 14 | self.model = model # 模型 15 | self.data = data # 数据 16 | self.config = config # 配置 17 | 18 | def train(self): 19 | """ 20 | 训练逻辑 21 | """ 22 | raise NotImplementedError 23 | -------------------------------------------------------------------------------- /bases/data_loader_base.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | Copyright (c) 2018. All rights reserved. 4 | Created by C. L. Wang on 2018/4/18 5 | """ 6 | 7 | 8 | class DataLoaderBase(object): 9 | """ 10 | 数据加载的基类 11 | """ 12 | 13 | def __init__(self, config): 14 | self.config = config # 设置配置信息 15 | 16 | def get_train_data(self): 17 | """ 18 | 获取训练数据 19 | """ 20 | raise NotImplementedError 21 | 22 | def get_test_data(self): 23 | """ 24 | 获取测试数据 25 | """ 26 | raise NotImplementedError 27 | -------------------------------------------------------------------------------- /bases/infer_base.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -- coding: utf-8 -- 3 | """ 4 | Copyright (c) 2018. All rights reserved. 5 | Created by C. L. Wang on 2018/4/18 6 | """ 7 | 8 | 9 | class InferBase(object): 10 | """ 11 | 推断基类 12 | """ 13 | 14 | def __init__(self, config): 15 | self.config = config # 配置 16 | 17 | def load_model(self, name): 18 | """ 19 | 加载模型 20 | """ 21 | raise NotImplementedError 22 | 23 | def predict(self, data): 24 | """ 25 | 预测结果 26 | """ 27 | raise NotImplementedError 28 | -------------------------------------------------------------------------------- /utils/np_utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -- coding: utf-8 -- 3 | """ 4 | Copyright (c) 2018. All rights reserved. 5 | Created by C. L. 
Wang on 2018/4/18 6 | 7 | Reference: https://juejin.im/post/5acfef976fb9a028db5918b5 8 | """ 9 | 10 | import numpy as np 11 | 12 | 13 | def prp_2_oh_array(arr): 14 | """ 15 | Convert a probability matrix to a one-hot (OH) matrix 16 | arr = np.array([[0.1, 0.5, 0.4], [0.2, 0.1, 0.6]]) 17 | :param arr: probability matrix 18 | :return: one-hot matrix 19 | """ 20 | arr_size = arr.shape[1] # number of classes 21 | arr_max = np.argmax(arr, axis=1) # index of the maximum in each row 22 | oh_arr = np.eye(arr_size)[arr_max] # one-hot matrix 23 | return oh_arr 24 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -- coding: utf-8 -- 3 | """ 4 | Copyright (c) 2018. All rights reserved. 5 | Created by C. L. Wang on 2018/4/18 6 | """ 7 | import os 8 | import shutil 9 | 10 | 11 | def mkdir_if_not_exist(dir, is_delete=False): 12 | """ 13 | Create a directory if it does not exist 14 | :param dir: directory path 15 | :param is_delete: whether to delete an existing directory first 16 | :return: whether the operation succeeded 17 | """ 18 | try: 19 | if is_delete: 20 | if os.path.exists(dir): 21 | shutil.rmtree(dir) 22 | print u'[INFO] 文件夹 "%s" 存在, 删除文件夹.' % dir 23 | 24 | if not os.path.exists(dir): 25 | os.makedirs(dir) 26 | print u'[INFO] 文件夹 "%s" 不存在, 创建文件夹.' % dir 27 | return True 28 | except Exception as e: 29 | print '[Exception] %s' % e 30 | return False 31 | -------------------------------------------------------------------------------- /main_test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -- coding: utf-8 -- 3 | """ 4 | Copyright (c) 2018. All rights reserved. 5 | Created by C. L. Wang on 2018/4/18 6 | """ 7 | from time import time 8 | 9 | import datetime 10 | import numpy as np 11 | 12 | from data_loaders.triplet_dl import TripletDL 13 | from infers.triplet_infer import TripletInfer 14 | from utils.config_utils import process_config, get_test_args 15 | 16 | 17 | def main_test(): 18 | print '[INFO] 解析配置...' 
19 | parser = None 20 | config = None 21 | 22 | try: 23 | args, parser = get_test_args() 24 | config = process_config(args.config) 25 | except Exception as e: 26 | print '[Exception] 配置无效, %s' % e 27 | if parser: 28 | parser.print_help() 29 | print '[Exception] 参考: python main_test.py -c configs/triplet_config.json' 30 | exit(0) 31 | 32 | print '[INFO] 加载数据...' 33 | 34 | print '[INFO] 预测数据...' 35 | infer = TripletInfer(config=config) 36 | infer.default_dist() 37 | infer.test_dist() 38 | 39 | print '[INFO] 预测完成...' 40 | 41 | 42 | if __name__ == '__main__': 43 | main_test() 44 | -------------------------------------------------------------------------------- /bases/model_base.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | Copyright (c) 2018. All rights reserved. 4 | Created by C. L. Wang on 2018/4/18 5 | """ 6 | 7 | 8 | class ModelBase(object): 9 | """ 10 | 模型基类 11 | """ 12 | 13 | def __init__(self, config): 14 | self.config = config # 配置 15 | self.model = None # 模型 16 | 17 | def save(self, checkpoint_path): 18 | """ 19 | 存储checkpoint, 路径定义于配置文件中 20 | """ 21 | if self.model is None: 22 | raise Exception("[Exception] You have to build the model first.") 23 | 24 | print("[INFO] Saving model...") 25 | self.model.save_weights(checkpoint_path) 26 | print("[INFO] Model saved") 27 | 28 | def load(self, checkpoint_path): 29 | """ 30 | 加载checkpoint, 路径定义于配置文件中 31 | """ 32 | if self.model is None: 33 | raise Exception("[Exception] You have to build the model first.") 34 | 35 | print("[INFO] Loading model checkpoint {} ...\n".format(checkpoint_path)) 36 | self.model.load_weights(checkpoint_path) 37 | print("[INFO] Model loaded") 38 | 39 | def build_model(self): 40 | """ 41 | 构建模型 42 | """ 43 | raise NotImplementedError 44 | -------------------------------------------------------------------------------- /data_loaders/triplet_dl.py: 
-------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | Copyright (c) 2018. All rights reserved. 4 | Created by C. L. Wang on 2018/4/18 5 | """ 6 | from keras.datasets import mnist 7 | from keras.utils import to_categorical 8 | 9 | from bases.data_loader_base import DataLoaderBase 10 | 11 | 12 | class TripletDL(DataLoaderBase): 13 | def __init__(self, config=None): 14 | super(TripletDL, self).__init__(config) 15 | (self.X_train, self.y_train), (self.X_test, self.y_test) = mnist.load_data() 16 | 17 | # 展平数据,for RNN,感知机 18 | # self.X_train = self.X_train.reshape((-1, 28 * 28)) 19 | # self.X_test = self.X_test.reshape((-1, 28 * 28)) 20 | 21 | # 图片数据,for CNN 22 | self.X_train = self.X_train.reshape(self.X_train.shape[0], 28, 28, 1) 23 | self.X_test = self.X_test.reshape(self.X_test.shape[0], 28, 28, 1) 24 | 25 | self.y_train = to_categorical(self.y_train) 26 | self.y_test = to_categorical(self.y_test) 27 | 28 | print "[INFO] X_train.shape: %s, y_train.shape: %s" \ 29 | % (str(self.X_train.shape), str(self.y_train.shape)) 30 | print "[INFO] X_test.shape: %s, y_test.shape: %s" \ 31 | % (str(self.X_test.shape), str(self.y_test.shape)) 32 | 33 | def get_train_data(self): 34 | return self.X_train, self.y_train 35 | 36 | def get_test_data(self): 37 | return self.X_test, self.y_test 38 | -------------------------------------------------------------------------------- /main_train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -- coding: utf-8 -- 3 | """ 4 | Copyright (c) 2018. All rights reserved. 5 | Created by C. L. 
Wang on 2018/4/18 6 | 7 | 参考: 8 | NumPy FutureWarning 9 | https://stackoverflow.com/questions/48340392/futurewarning-conversion-of-the-second-argument-of-issubdtype-from-float-to 10 | """ 11 | 12 | from data_loaders.triplet_dl import TripletDL 13 | from infers.triplet_infer import TripletInfer 14 | from models.triplet_model import TripletModel 15 | from trainers.triplet_trainer import TripletTrainer 16 | from utils.config_utils import process_config, get_train_args 17 | import numpy as np 18 | 19 | 20 | def main_train(): 21 | """ 22 | 训练模型 23 | 24 | :return: 25 | """ 26 | print '[INFO] 解析配置...' 27 | 28 | parser = None 29 | config = None 30 | 31 | try: 32 | args, parser = get_train_args() 33 | config = process_config(args.config) 34 | except Exception as e: 35 | print '[Exception] 配置无效, %s' % e 36 | if parser: 37 | parser.print_help() 38 | print '[Exception] 参考: python main_train.py -c configs/triplet_config.json' 39 | exit(0) 40 | # config = process_config('configs/triplet_config.json') 41 | 42 | print '[INFO] 加载数据...' 43 | dl = TripletDL(config=config) 44 | 45 | print '[INFO] 构造网络...' 46 | model = TripletModel(config=config) 47 | 48 | print '[INFO] 训练网络...' 49 | trainer = TripletTrainer( 50 | model=model.model, 51 | data=[dl.get_train_data(), dl.get_test_data()], 52 | config=config) 53 | trainer.train() 54 | print '[INFO] 训练完成...' 
55 | 56 | 57 | if __name__ == '__main__': 58 | main_train() 59 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | .static_storage/ 58 | .media/ 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 
| .mypy_cache/ 107 | 108 | .idea/ 109 | 110 | # data 111 | experiments/ 112 | 113 | -------------------------------------------------------------------------------- /utils/config_utils.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | Copyright (c) 2018. All rights reserved. 4 | Created by C. L. Wang on 2018/4/18 5 | """ 6 | import argparse 7 | import json 8 | 9 | import os 10 | from bunch import Bunch 11 | 12 | from root_dir import ROOT_DIR 13 | from utils import mkdir_if_not_exist 14 | 15 | 16 | def get_config_from_json(json_file): 17 | """ 18 | 将配置文件转换为配置类 19 | """ 20 | with open(json_file, 'r') as config_file: 21 | config_dict = json.load(config_file) # 配置字典 22 | 23 | config = Bunch(config_dict) # 将配置字典转换为类 24 | 25 | return config, config_dict 26 | 27 | 28 | def process_config(json_file): 29 | """ 30 | 解析Json文件 31 | :param json_file: 配置文件 32 | :return: 配置类 33 | """ 34 | config, _ = get_config_from_json(json_file) 35 | 36 | exp_dir = os.path.join(ROOT_DIR, "experiments") 37 | mkdir_if_not_exist(exp_dir) # 创建文件夹 38 | 39 | config.tb_dir = os.path.join(exp_dir, config.exp_name, "logs/") # 日志 40 | config.cp_dir = os.path.join(exp_dir, config.exp_name, "checkpoints/") # 模型 41 | config.img_dir = os.path.join(exp_dir, config.exp_name, "images/") # 网络 42 | 43 | mkdir_if_not_exist(config.tb_dir) # 创建文件夹 44 | mkdir_if_not_exist(config.cp_dir) # 创建文件夹 45 | mkdir_if_not_exist(config.img_dir) # 创建文件夹 46 | return config 47 | 48 | 49 | def get_train_args(): 50 | """ 51 | 添加训练参数 52 | """ 53 | parser = argparse.ArgumentParser(description=__doc__) 54 | parser.add_argument( 55 | '-c', '--cfg', 56 | dest='config', 57 | metavar='path', 58 | default='None', 59 | help='add a configuration file') 60 | args = parser.parse_args() 61 | return args, parser 62 | 63 | 64 | def get_test_args(): 65 | """ 66 | 添加测试路径 67 | :return: 参数 68 | """ 69 | parser = argparse.ArgumentParser(description=__doc__) 70 | 
parser.add_argument( 71 | '-c', '--cfg', 72 | dest='config', 73 | metavar='C', 74 | default='None', 75 | help='add a configuration file') 76 | args = parser.parse_args() 77 | return args, parser 78 | -------------------------------------------------------------------------------- /models/triplet_model.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | Copyright (c) 2018. All rights reserved. 4 | Created by C. L. Wang on 2018/4/18 5 | """ 6 | import os 7 | from keras import Input, Model 8 | from keras.layers import Dense, Conv2D, MaxPooling2D, Dropout, Flatten, K, merge, Concatenate 9 | from keras.optimizers import Adam 10 | from keras.utils import plot_model 11 | 12 | from bases.model_base import ModelBase 13 | 14 | 15 | class TripletModel(ModelBase): 16 | """ 17 | TripletLoss模型 18 | """ 19 | 20 | MARGIN = 1.0 # 超参 21 | 22 | def __init__(self, config): 23 | super(TripletModel, self).__init__(config) 24 | self.build_model() 25 | 26 | def build_model(self): 27 | self.model = self.triplet_loss_model() # 使用Triplet Loss训练Model 28 | 29 | def triplet_loss_model(self): 30 | anc_input = Input(shape=(28, 28, 1), name='anc_input') # anchor 31 | pos_input = Input(shape=(28, 28, 1), name='pos_input') # positive 32 | neg_input = Input(shape=(28, 28, 1), name='neg_input') # negative 33 | 34 | shared_model = self.base_model() # 共享模型 35 | 36 | std_out = shared_model(anc_input) 37 | pos_out = shared_model(pos_input) 38 | neg_out = shared_model(neg_input) 39 | 40 | print "[INFO] model - 锚shape: %s" % str(std_out.get_shape()) 41 | print "[INFO] model - 正shape: %s" % str(pos_out.get_shape()) 42 | print "[INFO] model - 负shape: %s" % str(neg_out.get_shape()) 43 | 44 | output = Concatenate()([std_out, pos_out, neg_out]) # 连接 45 | model = Model(inputs=[anc_input, pos_input, neg_input], outputs=output) 46 | 47 | plot_model(model, to_file=os.path.join(self.config.img_dir, "triplet_loss_model.png"), 48 | show_shapes=True) 
# 绘制模型图 49 | model.compile(loss=self.triplet_loss, optimizer=Adam()) 50 | 51 | return model 52 | 53 | @staticmethod 54 | def triplet_loss(y_true, y_pred): 55 | """ 56 | Triplet Loss的损失函数 57 | """ 58 | 59 | anc, pos, neg = y_pred[:, 0:128], y_pred[:, 128:256], y_pred[:, 256:] 60 | 61 | # 欧式距离 62 | pos_dist = K.sum(K.square(anc - pos), axis=-1, keepdims=True) 63 | neg_dist = K.sum(K.square(anc - neg), axis=-1, keepdims=True) 64 | basic_loss = pos_dist - neg_dist + TripletModel.MARGIN 65 | 66 | loss = K.maximum(basic_loss, 0.0) 67 | 68 | print "[INFO] model - triplet_loss shape: %s" % str(loss.shape) 69 | return loss 70 | 71 | def base_model(self): 72 | """ 73 | Triplet Loss的基础网络,可以替换其他网络结构 74 | """ 75 | ins_input = Input(shape=(28, 28, 1)) 76 | x = Conv2D(32, kernel_size=(3, 3), kernel_initializer='random_uniform', activation='relu')(ins_input) 77 | x = MaxPooling2D(pool_size=(2, 2))(x) 78 | x = Conv2D(32, kernel_size=(3, 3), kernel_initializer='random_uniform', activation='relu')(x) 79 | x = MaxPooling2D(pool_size=(2, 2))(x) 80 | x = Dropout(0.25)(x) 81 | x = Flatten()(x) 82 | out = Dense(128, activation='relu')(x) 83 | 84 | model = Model(ins_input, out) 85 | plot_model(model, to_file=os.path.join(self.config.img_dir, "base_model.png"), show_shapes=True) # 绘制模型图 86 | return model 87 | -------------------------------------------------------------------------------- /infers/triplet_infer.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | Copyright (c) 2018. All rights reserved. 4 | Created by C. L. 
Wang on 2018/4/18 5 | """ 6 | import os 7 | 8 | import datetime 9 | from keras.datasets import mnist 10 | from tensorflow.contrib.tensorboard.plugins import projector 11 | import tensorflow as tf 12 | import numpy as np 13 | 14 | from bases.infer_base import InferBase 15 | from keras.models import load_model 16 | 17 | from models.triplet_model import TripletModel 18 | from root_dir import ROOT_DIR 19 | from utils.utils import mkdir_if_not_exist 20 | 21 | 22 | class TripletInfer(InferBase): 23 | def __init__(self, name=None, config=None): 24 | super(TripletInfer, self).__init__(config) 25 | if name: 26 | self.model = self.load_model(name) 27 | 28 | def load_model(self, name): 29 | model = os.path.join(self.config.cp_dir, name) 30 | return load_model(model) 31 | 32 | def predict(self, data): 33 | return self.model.predict(data) 34 | 35 | def test_dist(self): 36 | model_path = os.path.join(self.config.cp_dir, "triplet_loss_model.h5") 37 | 38 | model = load_model(model_path, custom_objects={'triplet_loss': TripletModel.triplet_loss}) 39 | model.summary() 40 | 41 | (_, _), (X_test, y_test) = mnist.load_data() 42 | anchor = np.reshape(X_test, (-1, 28, 28, 1)) 43 | X = { 44 | 'anc_input': anchor, 45 | 'pos_input': np.zeros(anchor.shape), 46 | 'neg_input': np.zeros(anchor.shape) 47 | } 48 | s_time = datetime.datetime.now() 49 | res = model.predict(X) 50 | e_time = datetime.datetime.now() - s_time 51 | micro = e_time.total_seconds() * 1000000.0 / float(len(y_test)) # microseconds per sample; timedelta.microseconds alone drops whole seconds 52 | print "TPS: %s (%s ms)" % ((1000000.0 / micro), (micro / 1000.0)) 53 | data = res[:, :128] 54 | print "验证结果结构: %s" % str(data.shape) 55 | log_dir = os.path.join(ROOT_DIR, self.config.tb_dir, "test") 56 | mkdir_if_not_exist(log_dir) 57 | self.tb_projector(data, y_test, log_dir) 58 | 59 | def default_dist(self): 60 | (_, _), (X_test, y_test) = mnist.load_data() 61 | X_test = np.reshape(X_test, (-1, 28 * 28)) 62 | log_dir = os.path.join(ROOT_DIR, self.config.tb_dir, 'default') 63 | mkdir_if_not_exist(log_dir) 64 | 
self.tb_projector(X_test, y_test, log_dir) 65 | 66 | @staticmethod 67 | def tb_projector(X_test, y_test, log_dir): 68 | """ 69 | TB的映射器 70 | :param X_test: 数据 71 | :param y_test: 标签, 数值型 72 | :param log_dir: 文件夹 73 | :return: 写入日志 74 | """ 75 | print "展示数据: %s" % str(X_test.shape) 76 | print "展示标签: %s" % str(y_test.shape) 77 | print "日志目录: %s" % str(log_dir) 78 | 79 | metadata = os.path.join(log_dir, 'metadata.tsv') 80 | 81 | images = tf.Variable(X_test) 82 | 83 | # 把标签写入metadata 84 | with open(metadata, 'w') as metadata_file: 85 | for row in y_test: 86 | metadata_file.write('%d\n' % row) 87 | 88 | with tf.Session() as sess: 89 | saver = tf.train.Saver([images]) # 把数据存储为矩阵 90 | 91 | sess.run(images.initializer) # 图像初始化 92 | saver.save(sess, os.path.join(log_dir, 'images.ckpt')) # 图像存储于images.ckpt 93 | 94 | config = projector.ProjectorConfig() # 配置 95 | # One can add multiple embeddings. 96 | embedding = config.embeddings.add() # 嵌入向量添加 97 | embedding.tensor_name = images.name # Tensor名称 98 | # Link this tensor to its metadata file (e.g. labels). 99 | embedding.metadata_path = metadata # Metadata的路径 100 | # Saves a config file that TensorBoard will read during startup. 101 | projector.visualize_embeddings(tf.summary.FileWriter(log_dir), config) # 可视化嵌入向量 102 | -------------------------------------------------------------------------------- /trainers/triplet_trainer.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | Copyright (c) 2018. All rights reserved. 4 | Created by C. L. 
Wang on 2018/4/18 5 | """ 6 | import os 7 | import random 8 | import warnings 9 | 10 | import numpy as np 11 | from keras.callbacks import TensorBoard, ModelCheckpoint, Callback 12 | from sklearn.exceptions import UndefinedMetricWarning 13 | from sklearn.metrics import precision_recall_fscore_support 14 | 15 | from bases.trainer_base import TrainerBase 16 | from root_dir import ROOT_DIR 17 | from utils.np_utils import prp_2_oh_array 18 | from utils.utils import mkdir_if_not_exist 19 | 20 | 21 | class TripletTrainer(TrainerBase): 22 | def __init__(self, model, data, config): 23 | super(TripletTrainer, self).__init__(model, data, config) 24 | self.callbacks = [] 25 | self.loss = [] 26 | self.acc = [] 27 | self.val_loss = [] 28 | self.val_acc = [] 29 | self.init_callbacks() 30 | 31 | def init_callbacks(self): 32 | train_dir = os.path.join(ROOT_DIR, self.config.tb_dir, "train") 33 | mkdir_if_not_exist(train_dir) 34 | self.callbacks.append( 35 | TensorBoard( 36 | log_dir=train_dir, 37 | write_images=True, 38 | write_graph=True, 39 | ) 40 | ) 41 | 42 | # self.callbacks.append(FPRMetric()) 43 | # self.callbacks.append(FPRMetricDetail()) 44 | 45 | def train(self): 46 | x_train = self.data[0][0] 47 | y_train = np.argmax(self.data[0][1], axis=1) 48 | x_test = self.data[1][0] 49 | y_test = np.argmax(self.data[1][1], axis=1) 50 | 51 | clz_size = len(np.unique(y_train)) 52 | print "[INFO] trainer - 类别数: %s" % clz_size 53 | digit_indices = [np.where(y_train == i)[0] for i in range(10)] 54 | tr_pairs = self.create_pairs(x_train, digit_indices, clz_size) 55 | 56 | digit_indices = [np.where(y_test == i)[0] for i in range(10)] 57 | te_pairs = self.create_pairs(x_test, digit_indices, clz_size) 58 | 59 | anc_ins = tr_pairs[:, 0] 60 | pos_ins = tr_pairs[:, 1] 61 | neg_ins = tr_pairs[:, 2] 62 | 63 | X = { 64 | 'anc_input': anc_ins, 65 | 'pos_input': pos_ins, 66 | 'neg_input': neg_ins 67 | } 68 | 69 | anc_ins_te = te_pairs[:, 0] 70 | pos_ins_te = te_pairs[:, 1] 71 | neg_ins_te = 
te_pairs[:, 2] 72 | 73 | X_te = { 74 | 'anc_input': anc_ins_te, 75 | 'pos_input': pos_ins_te, 76 | 'neg_input': neg_ins_te 77 | } 78 | 79 | self.model.fit( 80 | X, np.ones(len(anc_ins)), 81 | batch_size=32, 82 | epochs=2, 83 | validation_data=[X_te, np.ones(len(anc_ins_te))], 84 | verbose=1, 85 | callbacks=self.callbacks) 86 | 87 | self.model.save(os.path.join(self.config.cp_dir, "triplet_loss_model.h5")) # 存储模型 88 | 89 | y_pred = self.model.predict(X_te) # 验证模型 90 | self.show_acc_facets(y_pred, y_pred.shape[0] / clz_size, clz_size) 91 | 92 | @staticmethod 93 | def show_acc_facets(y_pred, n, clz_size): 94 | """ 95 | 展示模型的准确率 96 | :param y_pred: 测试结果数据组 97 | :param n: 数据长度 98 | :param clz_size: 类别数 99 | :return: 打印数据 100 | """ 101 | print "[INFO] trainer - n_clz: %s" % n 102 | for i in range(clz_size): 103 | print "[INFO] trainer - clz %s" % i 104 | final = y_pred[n * i:n * (i + 1), :] 105 | anchor, positive, negative = final[:, 0:128], final[:, 128:256], final[:, 256:] 106 | 107 | pos_dist = np.sum(np.square(anchor - positive), axis=-1, keepdims=True) 108 | neg_dist = np.sum(np.square(anchor - negative), axis=-1, keepdims=True) 109 | basic_loss = pos_dist - neg_dist 110 | r_count = basic_loss[np.where(basic_loss < 0)].shape[0] 111 | print "[INFO] trainer - distance - min: %s, max: %s, avg: %s" % ( 112 | np.min(basic_loss), np.max(basic_loss), np.average(basic_loss)) 113 | print "[INFO] acc: %s" % (float(r_count) / float(n)) 114 | print "" 115 | 116 | @staticmethod 117 | def create_pairs(x, digit_indices, num_classes): 118 | """ 119 | 创建正例和负例的Pairs 120 | :param x: 数据 121 | :param digit_indices: 不同类别的索引列表 122 | :param num_classes: 类别 123 | :return: Triplet Loss 的 Feed 数据 124 | """ 125 | 126 | pairs = [] 127 | n = min([len(digit_indices[d]) for d in range(num_classes)]) - 1 # 最小类别数 128 | for d in range(num_classes): 129 | for i in range(n): 130 | z1, z2 = digit_indices[d][i], digit_indices[d][i + 1] 131 | inc = random.randrange(1, num_classes) 132 | dn = (d + inc) % 
num_classes 133 | z3 = digit_indices[dn][i] 134 | pairs += [[x[z1], x[z2], x[z3]]] 135 | return np.array(pairs) 136 | 137 | 138 | class FPRMetric(Callback): 139 | """ 140 | Print F1, precision, and recall (macro average) 141 | """ 142 | 143 | def on_epoch_end(self, batch, logs=None): 144 | val_x = self.validation_data[0] 145 | val_y = self.validation_data[1] 146 | 147 | prd_y = prp_2_oh_array(np.asarray(self.model.predict(val_x))) 148 | 149 | warnings.filterwarnings("ignore", category=UndefinedMetricWarning) 150 | precision, recall, f_score, _ = precision_recall_fscore_support( 151 | val_y, prd_y, average='macro') 152 | print " — val_f1: % 0.4f — val_pre: % 0.4f — val_rec % 0.4f" % (f_score, precision, recall) 153 | 154 | 155 | class FPRMetricDetail(Callback): 156 | """ 157 | Print F1, precision, and recall for each class 158 | """ 159 | 160 | def on_epoch_end(self, batch, logs=None): 161 | val_x = self.validation_data[0] 162 | val_y = self.validation_data[1] 163 | 164 | prd_y = prp_2_oh_array(np.asarray(self.model.predict(val_x))) 165 | 166 | warnings.filterwarnings("ignore", category=UndefinedMetricWarning) 167 | precision, recall, f_score, support = precision_recall_fscore_support(val_y, prd_y) 168 | 169 | for p, r, f, s in zip(precision, recall, f_score, support): 170 | print " — val_f1: % 0.4f — val_pre: % 0.4f — val_rec % 0.4f - ins %s" % (f, p, r, s) 171 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Triplet Loss 2 | 3 | [Triplet Loss](https://arxiv.org/pdf/1503.03832.pdf) is a loss function used in deep learning to train on samples whose **differences are small**, such as faces. Each training example fed to the model consists of an anchor sample, a positive sample, and a negative sample. The model is optimized so that the distance between the anchor and the positive is **smaller** than the distance between the anchor and the negative, which yields a similarity measure between samples. 4 | 5 | 6 | 7 | Dataset: [MNIST](http://yann.lecun.com/exdb/mnist/) 8 | 9 | Framework: [DL-Project-Template](https://github.com/SpikeKing/DL-Project-Template) 10 | 11 | Goal: train a model with Triplet Loss to compute the similarity of handwritten digit images. 12 | 13 | 
[Project](https://github.com/SpikeKing/triplet-loss-mnist): https://github.com/SpikeKing/triplet-loss-mnist 14 | 15 | --- 16 | 17 | ## Model 18 | 19 | The core of Triplet Loss is that the anchor, positive, and negative samples share one model, which pulls the anchor toward the positive and pushes it away from the negative. 20 | 21 | The structure of the **Triplet Loss Model** is as follows: 22 | 23 | 24 | 25 | - Input: three inputs (anchor, positive, and negative samples), all with the same **shape**; 26 | - Model: one shared model, which can be replaced with **any** network architecture; 27 | - Output: one output, the concatenation of the three model outputs. 28 | 29 | The **Shared Model** is a conventional convolutional network whose output is a 128-dimensional fully connected layer: 30 | 31 | 32 | 33 | The Triplet Loss **loss function** is computed as follows: 34 | 35 | In `TripletModel.triplet_loss` this is `max(d(anc, pos) - d(anc, neg) + MARGIN, 0)`, where `d` is the squared Euclidean distance between two embeddings. 36 | 37 | --- 38 | 39 | ## Training 40 | 41 | Model parameters: 42 | 43 | - batch_size: 32 44 | - epochs: 2 45 | 46 | Hyperparameters: 47 | 48 | - The margin is set to ``1``. 49 | 50 | Training command: 51 | 52 | ```text 53 | python main_train.py -c configs/triplet_config.json 54 | ``` 55 | 56 | Training log: 57 | 58 | ```text 59 | Using TensorFlow backend. 60 | [INFO] 解析配置... 61 | [INFO] 加载数据... 62 | [INFO] X_train.shape: (60000, 28, 28, 1), y_train.shape: (60000, 10) 63 | [INFO] X_test.shape: (10000, 28, 28, 1), y_test.shape: (10000, 10) 64 | [INFO] 构造网络... 65 | [INFO] model - 锚shape: (?, 128) 66 | [INFO] model - 正shape: (?, 128) 67 | [INFO] model - 负shape: (?, 128) 68 | [INFO] model - triplet_loss shape: (?, 1) 69 | [INFO] 训练网络... 70 | [INFO] trainer - 类别数: 10 71 | Train on 54200 samples, validate on 8910 samples 72 | 2018-04-24 10:59:06.130952: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA 73 | Epoch 1/2 74 | 54200/54200 [==============================] - 84s 2ms/step - loss: 0.6083 - val_loss: 0.0515 75 | Epoch 2/2 76 | 54200/54200 [==============================] - 83s 2ms/step - loss: 0.0869 - val_loss: 0.0314 77 | ``` 78 | 79 | The algorithm converges well: the training loss drops from 0.6083 to 0.0869 over the two epochs: 80 | 81 | 82 | 83 | TF Graph: 84 | 85 | 86 | 87 | --- 88 | 89 | ## Validation 90 | 91 | **Throughput** (TPS): 48163 predictions per second (0.0207625 ms per prediction) 92 | 93 | Test command: 94 | 95 | ```text 96 | python main_test.py -c configs/triplet_config.json 97 | ``` 98 | 99 | Test log: 100 | 101 | ```text 102 | Using TensorFlow backend. 
[INFO] Parsing config...
[INFO] Loading data...
[INFO] Predicting data...
Display data: (10000, 784)
Display labels: (10000,)
Log directory: /Users/wang/workspace/triplet-loss-mnist/experiments/triplet_mnist/logs/default
2018-04-24 11:02:07.874682: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
[INFO] model - triplet_loss shape: (?, 1)
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
anc_input (InputLayer)          (None, 28, 28, 1)    0
__________________________________________________________________________________________________
pos_input (InputLayer)          (None, 28, 28, 1)    0
__________________________________________________________________________________________________
neg_input (InputLayer)          (None, 28, 28, 1)    0
__________________________________________________________________________________________________
model_1 (Model)                 (None, 128)          112096      anc_input[0][0]
                                                                 pos_input[0][0]
                                                                 neg_input[0][0]
__________________________________________________________________________________________________
concatenate_1 (Concatenate)     (None, 384)          0           model_1[1][0]
                                                                 model_1[2][0]
                                                                 model_1[3][0]
==================================================================================================
Total params: 112,096
Trainable params: 112,096
Non-trainable params: 0
__________________________________________________________________________________________________
TPS: 272553.829381 (0.003669 ms)
Validation result shape: (10000, 128)
Display data: (10000, 128)
Display labels: (10000,)
Log directory: /Users/wang/workspace/triplet-loss-mnist/experiments/triplet_mnist/logs/test
[INFO] Prediction complete...
```

Results on the MNIST validation set:

```text
[INFO] trainer - clz 0
[INFO] trainer - distance - min: -15.4567, max: 1.98611, avg: -6.50481
[INFO] acc: 0.996632996633

[INFO] trainer - clz 1
[INFO] trainer - distance - min: -13.09, max: 3.43779, avg: -6.66867
[INFO] acc: 0.99214365881

[INFO] trainer - clz 2
[INFO] trainer - distance - min: -14.2524, max: 2.49437, avg: -5.60508
[INFO] acc: 0.991021324355

[INFO] trainer - clz 3
[INFO] trainer - distance - min: -16.6555, max: 1.21776, avg: -6.32161
[INFO] acc: 0.995510662177

[INFO] trainer - clz 4
[INFO] trainer - distance - min: -14.193, max: 1.65427, avg: -5.90896
[INFO] acc: 0.991021324355

[INFO] trainer - clz 5
[INFO] trainer - distance - min: -14.1007, max: 2.01843, avg: -6.36086
[INFO] acc: 0.994388327722

[INFO] trainer - clz 6
[INFO] trainer - distance - min: -16.8953, max: 2.84421, avg: -8.43978
[INFO] acc: 0.995510662177

[INFO] trainer - clz 7
[INFO] trainer - distance - min: -16.6177, max: 3.49675, avg: -5.99822
[INFO] acc: 0.989898989899

[INFO] trainer - clz 8
[INFO] trainer - distance - min: -14.937, max: 3.38141, avg: -5.4424
[INFO] acc: 0.979797979798

[INFO] trainer - clz 9
[INFO] trainer - distance - min: -16.9519, max: 2.39112, avg: -5.93581
[INFO] acc: 0.985409652076
```

Distribution of the MNIST test data:



Distribution of the Triplet Loss output on MNIST:



This example trains for only 2 epochs and uses no special hyperparameter tuning, so there is still room to improve the results.

---

By C. L. Wang

--------------------------------------------------------------------------------