├── .gitignore ├── motion_detector ├── README.md ├── __init__.py ├── layer_info.py ├── merged_dcl.py └── pic_drawing.py ├── project_utils.py ├── requirements.txt └── root_dir.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Python template 3 | # Byte-compiled / optimized / DLL files 4 | __pycache__/ 5 | *.py[cod] 6 | *$py.class 7 | 8 | # C extensions 9 | *.so 10 | 11 | # Distribution / packaging 12 | .Python 13 | build/ 14 | develop-eggs/ 15 | dist/ 16 | downloads/ 17 | eggs/ 18 | .eggs/ 19 | lib/ 20 | lib64/ 21 | parts/ 22 | sdist/ 23 | var/ 24 | wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .coverage 44 | .coverage.* 45 | .cache 46 | nosetests.xml 47 | coverage.xml 48 | *.cover 49 | .hypothesis/ 50 | 51 | # Translations 52 | *.mo 53 | *.pot 54 | 55 | # Django stuff: 56 | *.log 57 | .static_storage/ 58 | .media/ 59 | local_settings.py 60 | 61 | # Flask stuff: 62 | instance/ 63 | .webassets-cache 64 | 65 | # Scrapy stuff: 66 | .scrapy 67 | 68 | # Sphinx documentation 69 | docs/_build/ 70 | 71 | # PyBuilder 72 | target/ 73 | 74 | # Jupyter Notebook 75 | .ipynb_checkpoints 76 | 77 | # pyenv 78 | .python-version 79 | 80 | # celery beat schedule file 81 | celerybeat-schedule 82 | 83 | # SageMath parsed files 84 | *.sage.py 85 | 86 | # Environments 87 | .env 88 | .venv 89 | env/ 90 | venv/ 91 | ENV/ 92 | env.bak/ 93 | venv.bak/ 94 | 95 | # Spyder project settings 96 | .spyderproject 97 | .spyproject 98 | 99 | # Rope project settings 100 | .ropeproject 101 | 102 | # mkdocs documentation 103 | /site 104 | 105 | # mypy 106 | .mypy_cache/ 107 | 108 | .idea/inspectionProfiles/ 109 | .idea/vcs.xml 110 | .idea/ 111 | 112 | # 隐藏数据集 113 | data/ 114 | test/ 115 | -------------------------------------------------------------------------------- /motion_detector/README.md: -------------------------------------------------------------------------------- 1 | # 基于DeepConvLSTM的手机运动模式识别 2 | 3 | [UCI数据源](https://archive.ics.uci.edu/ml/datasets/human+activity+recognition+using+smartphones) 4 | 5 | - DL:[DeepConvLSTM](http://www.mdpi.com/1424-8220/16/1/115/html) 6 | - Keras:2.1.5 7 | - TensorFlow:1.4.0 8 | 9 | [参考文档](https://www.jianshu.com/p/ecbd6b4f54d2) 10 | -------------------------------------------------------------------------------- /motion_detector/__init__.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | 4 | created by C.L.Wang 5 | """ -------------------------------------------------------------------------------- /motion_detector/layer_info.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | 4 | created by C.L.Wang 5 | """ 6 | import os 7 | 8 | from keras import models 9 | from keras.models import load_model 10 | 11 | from motion_detector.pic_drawing import draw_dict 12 | from project_utils import read_file 13 | from root_dir import ROOT_DIR 14 | import numpy as np 15 | from motion_detector.merged_dcl import load_data 16 | 17 | 18 | def show_layer(): 19 | """ 20 | 绘制层次的信号图 21 | :return: 写入图片 22 | """ 23 | data_path = os.path.join(ROOT_DIR, "data", "UCI_HAR_Dataset") 24 | output_path = os.path.join(ROOT_DIR, "data", "UCI_HAR_Dataset_output") 25 | 26 | lm_path = os.path.join(ROOT_DIR, "data/UCI_HAR_Dataset/activity_labels.txt") 27 | lm_lines = read_file(lm_path) 28 | 29 | label_meaning_dict = dict() 30 | for lm in lm_lines: 31 | [label, meaning] = lm.split(" ") 32 | label_meaning_dict[label] = meaning.rstrip() 33 | 34 | label_path = os.path.join(data_path, "train", "y_train.txt") 35 | 36 | labels = np.loadtxt(label_path) 37 | label_index_dict = dict() 38 | 39 | for label in label_meaning_dict.keys(): 40 | data_index = np.where(labels == int(label))[0][0] 41 | meaning = label_meaning_dict[str(label)] 42 | label_index_dict[meaning] = data_index 43 | 44 | print "label_index_dict: %s" % label_index_dict 45 | 46 | model_path = os.path.join(output_path, "merged_dcl.h5") 47 | model = load_model(model_path) 48 | model.summary() # As a reminder. 49 | # 48是Merge层, 3是卷积层 50 | activation_model = models.Model(inputs=model.input, outputs=[model.layers[48].output]) 51 | X_trainS1, X_trainS2, X_trainS3, Y_train, X_valS1, X_valS2, X_valS3, Y_val = load_data(data_path) 52 | 53 | label_data_dict = dict() 54 | for label in label_index_dict.keys(): 55 | index = label_index_dict[label] 56 | activations = activation_model.predict( 57 | [np.expand_dims(X_trainS1[index], axis=0), np.expand_dims(X_trainS2[index], axis=0), 58 | np.expand_dims(X_trainS3[index], axis=0)]) 59 | 60 | label_data_dict[label] = activations[0] 61 | 62 | color_list = ['red', 'blue', 'green', 'orange', 'purple', 'black'] 63 | out_path = os.path.join(ROOT_DIR, output_path, "merged_layer.png") 64 | draw_dict(label_data_dict, color_list, out_path) 65 | 66 | 67 | if __name__ == '__main__': 68 | show_layer() 69 | -------------------------------------------------------------------------------- /motion_detector/merged_dcl.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | 3 | import os 4 | 5 | import sys 6 | from keras import Input, Model 7 | from keras.callbacks import Callback 8 | from keras.layers import Conv1D, BatchNormalization, MaxPooling1D, Dropout, LSTM, Dense, Activation, np, Concatenate 9 | from keras.utils import to_categorical, plot_model 10 | from sklearn.metrics import precision_recall_fscore_support 11 | 12 | p = os.path.dirname(os.path.dirname((os.path.abspath(__file__)))) 13 | if p not in sys.path: 14 | sys.path.append(p) 15 | 16 | from project_utils import create_folder_try 17 | from root_dir import ROOT_DIR 18 | 19 | 20 | def load_data(data_path): 21 | """ 22 | 加载本地的UCI的训练数据和验证数据 23 | :param data_path 数据集 24 | :return: 训练数据和验证数据 25 | """ 26 | train_path = os.path.join(data_path, "train") 27 | train_X_path = os.path.join(train_path, "Inertial Signals") 28 | 29 | X_trainS1_x = np.loadtxt(os.path.join(train_X_path, "body_acc_x_train.txt")) 30 | X_trainS1_y = np.loadtxt(os.path.join(train_X_path, "body_acc_y_train.txt")) 31 | X_trainS1_z = np.loadtxt(os.path.join(train_X_path, "body_acc_z_train.txt")) 32 | X_trainS1 = np.array([X_trainS1_x, X_trainS1_y, X_trainS1_z]) 33 | X_trainS1 = X_trainS1.transpose([1, 2, 0]) 34 | 35 | X_trainS2_x = np.loadtxt(os.path.join(train_X_path, "body_gyro_x_train.txt")) 36 | X_trainS2_y = np.loadtxt(os.path.join(train_X_path, "body_gyro_y_train.txt")) 37 | X_trainS2_z = np.loadtxt(os.path.join(train_X_path, "body_gyro_z_train.txt")) 38 | X_trainS2 = np.array([X_trainS2_x, X_trainS2_y, X_trainS2_z]) 39 | X_trainS2 = X_trainS2.transpose([1, 2, 0]) 40 | 41 | X_trainS3_x = np.loadtxt(os.path.join(train_X_path, "total_acc_x_train.txt")) 42 | X_trainS3_y = np.loadtxt(os.path.join(train_X_path, "total_acc_y_train.txt")) 43 | X_trainS3_z = np.loadtxt(os.path.join(train_X_path, "total_acc_z_train.txt")) 44 | X_trainS3 = np.array([X_trainS3_x, X_trainS3_y, X_trainS3_z]) 45 | X_trainS3 = X_trainS3.transpose([1, 2, 0]) 46 | 47 | Y_train = np.loadtxt(os.path.join(train_path, "y_train.txt")) 48 | Y_train = to_categorical(Y_train - 1.0) # 标签是从1开始 49 | 50 | print "训练数据: " 51 | print "传感器1: %s, 传感器1的X轴: %s" % (str(X_trainS1.shape), str(X_trainS1_x.shape)) 52 | print "传感器2: %s, 传感器2的X轴: %s" % (str(X_trainS2.shape), str(X_trainS2_x.shape)) 53 | print "传感器3: %s, 传感器3的X轴: %s" % (str(X_trainS3.shape), str(X_trainS3_x.shape)) 54 | print "传感器标签: %s" % str(Y_train.shape) 55 | print "" 56 | 57 | test_path = os.path.join(data_path, "test") 58 | test_X_path = os.path.join(test_path, "Inertial Signals") 59 | 60 | X_valS1_x = np.loadtxt(os.path.join(test_X_path, "body_acc_x_test.txt")) 61 | X_valS1_y = np.loadtxt(os.path.join(test_X_path, "body_acc_y_test.txt")) 62 | X_valS1_z = np.loadtxt(os.path.join(test_X_path, "body_acc_z_test.txt")) 63 | X_valS1 = np.array([X_valS1_x, X_valS1_y, X_valS1_z]) 64 | X_valS1 = X_valS1.transpose([1, 2, 0]) 65 | 66 | X_valS2_x = np.loadtxt(os.path.join(test_X_path, "body_gyro_x_test.txt")) 67 | X_valS2_y = np.loadtxt(os.path.join(test_X_path, "body_gyro_y_test.txt")) 68 | X_valS2_z = np.loadtxt(os.path.join(test_X_path, "body_gyro_z_test.txt")) 69 | X_valS2 = np.array([X_valS2_x, X_valS2_y, X_valS2_z]) 70 | X_valS2 = X_valS2.transpose([1, 2, 0]) 71 | 72 | X_valS3_x = np.loadtxt(os.path.join(test_X_path, "total_acc_x_test.txt")) 73 | X_valS3_y = np.loadtxt(os.path.join(test_X_path, "total_acc_y_test.txt")) 74 | X_valS3_z = np.loadtxt(os.path.join(test_X_path, "total_acc_z_test.txt")) 75 | X_valS3 = np.array([X_valS3_x, X_valS3_y, X_valS3_z]) 76 | X_valS3 = X_valS3.transpose([1, 2, 0]) 77 | 78 | Y_val = np.loadtxt(os.path.join(test_path, "y_test.txt")) 79 | Y_val = to_categorical(Y_val - 1.0) 80 | 81 | print "验证数据: " 82 | print "传感器1: %s, 传感器1的X轴: %s" % (str(X_valS1.shape), str(X_valS1.shape)) 83 | print "传感器2: %s, 传感器2的X轴: %s" % (str(X_valS2.shape), str(X_valS2.shape)) 84 | print "传感器3: %s, 传感器3的X轴: %s" % (str(X_valS3.shape), str(X_valS3.shape)) 85 | print "传感器标签: %s" % str(Y_val.shape) 86 | print "\n" 87 | 88 | return X_trainS1, X_trainS2, X_trainS3, Y_train, X_valS1, X_valS2, X_valS3, Y_val 89 | 90 | 91 | class Metrics(Callback): 92 | """ 93 | 输出F, P, R 94 | """ 95 | 96 | def on_train_begin(self, logs=None): 97 | pass 98 | 99 | def on_epoch_end(self, batch, logs=None): 100 | val_targ = self.validation_data[-3] 101 | 102 | val_value = [x for x in self.validation_data[0:-3]] 103 | y_pred = np.asarray(self.model.predict(val_value)) 104 | 105 | precision, recall, f_score, _ = precision_recall_fscore_support( 106 | val_targ, (y_pred > 0.5).astype(int), average='micro') 107 | print "— val_f1: % f — val_precision: % f — val_recall % f" % (f_score, precision, recall) 108 | 109 | 110 | def main(data_path, output_path): 111 | X_trainS1, X_trainS2, X_trainS3, Y_train, X_valS1, X_valS2, X_valS3, Y_val = load_data(data_path) 112 | 113 | epochs = 100 114 | batch_size = 256 115 | kernel_size = 3 116 | pool_size = 2 117 | dropout_rate = 0.15 118 | n_classes = 6 119 | 120 | f_act = 'relu' 121 | 122 | # 三个子模型的输入数据 123 | main_input1 = Input(shape=(128, 3), name='main_input1') 124 | main_input2 = Input(shape=(128, 3), name='main_input2') 125 | main_input3 = Input(shape=(128, 3), name='main_input3') 126 | 127 | def cnn_lstm_cell(main_input): 128 | """ 129 | 基于DeepConvLSTM算法, 创建子模型 130 | :param main_input: 输入数据 131 | :return: 子模型 132 | """ 133 | sub_model = Conv1D(512, kernel_size, input_shape=(128, 3), activation=f_act, padding='same')(main_input) 134 | sub_model = BatchNormalization()(sub_model) 135 | sub_model = MaxPooling1D(pool_size=pool_size)(sub_model) 136 | sub_model = Dropout(dropout_rate)(sub_model) 137 | sub_model = Conv1D(64, kernel_size, activation=f_act, padding='same')(sub_model) 138 | sub_model = BatchNormalization()(sub_model) 139 | sub_model = MaxPooling1D(pool_size=pool_size)(sub_model) 140 | sub_model = Dropout(dropout_rate)(sub_model) 141 | sub_model = Conv1D(32, kernel_size, activation=f_act, padding='same')(sub_model) 142 | sub_model = BatchNormalization()(sub_model) 143 | sub_model = MaxPooling1D(pool_size=pool_size)(sub_model) 144 | sub_model = LSTM(128, return_sequences=True)(sub_model) 145 | sub_model = LSTM(128, return_sequences=True)(sub_model) 146 | sub_model = LSTM(128)(sub_model) 147 | main_output = Dropout(dropout_rate)(sub_model) 148 | return main_output 149 | 150 | first_model = cnn_lstm_cell(main_input1) 151 | second_model = cnn_lstm_cell(main_input2) 152 | third_model = cnn_lstm_cell(main_input3) 153 | 154 | model = Concatenate()([first_model, second_model, third_model]) # 合并模型 155 | model = Dropout(0.4)(model) 156 | model = Dense(n_classes)(model) 157 | model = BatchNormalization()(model) 158 | output = Activation('softmax', name="softmax")(model) 159 | 160 | model = Model([main_input1, main_input2, main_input3], output) 161 | model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy']) 162 | 163 | graph_path = os.path.join(output_path, "merged_model.png") 164 | plot_model(model, to_file=graph_path, show_shapes=True) # 绘制模型图 165 | 166 | metrics = Metrics() # 度量FPR 167 | history = model.fit([X_trainS1, X_trainS2, X_trainS3], Y_train, 168 | batch_size=batch_size, 169 | validation_data=([X_valS1, X_valS2, X_valS3], Y_val), 170 | epochs=epochs, 171 | callbacks=[metrics]) # 增加FPR输出 172 | 173 | model_path = os.path.join(output_path, "merged_dcl.h5") 174 | model.save(model_path) # 存储模型 175 | print history.history 176 | 177 | 178 | if __name__ == '__main__': 179 | data = os.path.join(ROOT_DIR, "data", "UCI_HAR_Dataset") 180 | output = os.path.join(ROOT_DIR, "data", "UCI_HAR_Dataset_output") 181 | create_folder_try(output) 182 | main(data_path=data, output_path=output) 183 | -------------------------------------------------------------------------------- /motion_detector/pic_drawing.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | 绘制传感器数据 4 | 5 | created by C.L.Wang 6 | """ 7 | import matplotlib.pyplot as plt 8 | from cycler import cycler 9 | 10 | from project_utils import * 11 | from root_dir import ROOT_DIR 12 | 13 | 14 | def show_data_of_motions(): 15 | """ 16 | 显示不同运动的曲线 17 | :return: 曲线图 18 | """ 19 | lm_path = os.path.join(ROOT_DIR, "data/UCI_HAR_Dataset/activity_labels.txt") 20 | lm_lines = read_file(lm_path) 21 | 22 | label_meaning_dict = dict() 23 | for lm in lm_lines: 24 | [label, meaning] = lm.split(" ") 25 | label_meaning_dict[label] = meaning.rstrip() 26 | 27 | train_dir = os.path.join(ROOT_DIR, "data/UCI_HAR_Dataset/train") 28 | data_path = os.path.join(train_dir, "Inertial Signals/body_acc_x_train.txt") 29 | label_path = os.path.join(train_dir, "y_train.txt") 30 | 31 | data = np.loadtxt(data_path) 32 | labels = np.loadtxt(label_path) 33 | label_data_dict = dict() 34 | 35 | for label in label_meaning_dict.keys(): 36 | data_index = np.where(labels == int(label))[0][0] 37 | meaning = label_meaning_dict[str(label)] 38 | label_data_dict[meaning] = data[data_index] 39 | 40 | color_list = ['red', 'blue', 'green', 'orange', 'purple', 'black'] 41 | out_path = os.path.join(ROOT_DIR, "data/UCI_HAR_Dataset_output/activity_lines.png") 42 | draw_dict(label_data_dict, color_list, out_path) 43 | 44 | 45 | def show_data_of_sensors(): 46 | """ 47 | 显示不同传感器的曲线 48 | :return: 曲线图 49 | """ 50 | paths, names = listdir_files(os.path.join(ROOT_DIR, "data/UCI_HAR_Dataset/train/Inertial Signals")) 51 | label_data_dict = dict() 52 | for path, name in zip(paths, names): 53 | data = np.loadtxt(path)[27] 54 | label = name.rstrip('.txt') 55 | label_data_dict[label] = data 56 | color_list = ['red', 'blue', 'green', 'orange', 'purple', 'black', 'hotpink', 'brown', 'crimson'] 57 | out_path = os.path.join(ROOT_DIR, "data/UCI_HAR_Dataset_output/sensors_lines.png") 58 | draw_dict(label_data_dict, color_list, out_path) 59 | 60 | 61 | def draw_dict(label_data_dict, color_list, out_path): 62 | """ 63 | 绘制字典中的数据 64 | :param label_data_dict: 数据字典 65 | :param color_list: 颜色列表 66 | :param out_path: 输出文件 67 | :return: None 68 | """ 69 | label_data_list = sorted(label_data_dict.items()) # 字典排序 70 | 71 | plt.gca().set_prop_cycle(cycler('color', color_list)) # 图片颜色 72 | plt.gcf().set_size_inches(18.5, 10.5) # 图片尺寸 73 | 74 | label_list = [] 75 | for label, data in label_data_list: 76 | plt.plot(data) # 写入绘制数据 77 | label_list.append(label) 78 | 79 | plt.legend(label_list, loc='upper left') # 字典标签 80 | plt.grid() # 网格 81 | plt.savefig(out_path) # 写入图片 82 | plt.show() 83 | 84 | 85 | if __name__ == '__main__': 86 | show_data_of_motions() 87 | show_data_of_sensors() 88 | -------------------------------------------------------------------------------- /project_utils.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | 项目的工具类 4 | 5 | created by C.L.Wang 6 | """ 7 | 8 | from __future__ import absolute_import 9 | 10 | import collections 11 | import glob 12 | import json 13 | import operator 14 | import os 15 | import random 16 | import re 17 | import shutil 18 | import time 19 | from datetime import timedelta, datetime 20 | from itertools import izip 21 | 22 | import numpy as np 23 | from dateutil.relativedelta import relativedelta 24 | 25 | FORMAT_DATE = '%Y%m%d' 26 | FORMAT_DATE_2 = '%Y-%m-%d' 27 | FORMAT_DATE_3 = '%Y%m%d%H%M%S' 28 | 29 | 30 | def ensure_unicode(data): 31 | if not data: 32 | return u'' 33 | if isinstance(data, unicode): 34 | return data 35 | else: 36 | return data.decode('utf-8') 37 | 38 | 39 | def datetime_to_str(date, date_format=FORMAT_DATE): 40 | return date.strftime(date_format) 41 | 42 | 43 | def str_to_datetime(date_str, date_format=FORMAT_DATE): 44 | date = time.strptime(date_str, date_format) 45 | return datetime(*date[:6]) 46 | 47 | 48 | def get_next_half_year(): 49 | """ 50 | 当前时间的半年前 51 | :return: 半年时间 52 | """ 53 | n_days = datetime.now() - timedelta(days=178) 54 | return n_days.strftime('%Y-%m-%d') 55 | 56 | 57 | def get_person_age(born_date, is_log=False): 58 | """ 59 | 根据出生日期, 获取年龄 60 | :param born_date: 出生日期 61 | :param is_log: 日志显示 62 | :return: 异常返回-1, 其他正常 63 | """ 64 | try: 65 | if is_log: 66 | print "当前时间: %s, 出生日期: %s" % (datetime.now(), datetime.fromtimestamp(born_date / 1000.0)) 67 | time_data = relativedelta(datetime.now(), datetime.fromtimestamp(born_date / 1000.0)) 68 | years = float(time_data.years) 69 | months = float(time_data.months) 70 | age = (years * 12.0 + months) / 12.0 71 | if age < 0: 72 | print "获取年龄异常: %s" % age 73 | return -1.0 74 | else: 75 | return age 76 | except Exception as e: 77 | print "获取年龄异常: %s" % e 78 | return -1.0 79 | 80 | 81 | def timestr_2_timestamp(time_str): 82 | """ 83 | 时间字符串转换为毫秒 84 | :param time_str: 时间字符串, 2017-10-11 85 | :return: 毫秒, 如1443715200000 86 | """ 87 | return int(time.mktime(datetime.strptime(time_str, "%Y-%m-%d").timetuple()) * 1000) 88 | 89 | 90 | def create_folder_try(atp_out_dir, is_delete=False): 91 | """ 92 | 创建文件夹 93 | :param atp_out_dir: 文件夹 94 | :param is_delete: 是否删除 95 | :return: 96 | """ 97 | if is_delete: 98 | if os.path.exists(atp_out_dir): 99 | shutil.rmtree(atp_out_dir) 100 | print '文件夹 "%s" 存在,删除文件夹。' % atp_out_dir 101 | 102 | if not os.path.exists(atp_out_dir): 103 | os.makedirs(atp_out_dir) 104 | print '文件夹 "%s" 不存在,创建文件夹。' % atp_out_dir 105 | 106 | 107 | def create_file(file_name): 108 | """ 109 | 创建文件 110 | :param file_name: 文件名 111 | :return: None 112 | """ 113 | if os.path.exists(file_name): 114 | print "文件存在,删除文件:%s" % file_name 115 | os.remove(file_name) # 删除已有文件 116 | if not os.path.exists(file_name): 117 | print "文件不存在,创建文件:%s" % file_name 118 | open(file_name, 'a').close() 119 | 120 | 121 | def remove_punctuation(line): 122 | """ 123 | 去除所有半角全角符号,只留字母、数字、中文 124 | :param line: 125 | :return: 126 | """ 127 | rule = re.compile(ur"[^a-zA-Z0-9\u4e00-\u9fa5]") 128 | line = rule.sub('', line) 129 | return line 130 | 131 | 132 | def check_punctuation(word): 133 | pattern = re.compile(ur"[^a-zA-Z0-9\u4e00-\u9fa5]") 134 | if pattern.search(word): 135 | return True 136 | else: 137 | return False 138 | 139 | 140 | def clean_text(text): 141 | """ 142 | text = "hello world nice ok done \n\n\n hhade\t\rjdla" 143 | result = hello world nice ok done hhade jdla 144 | 将多个空格换成一个 145 | :param text: 146 | :return: 147 | """ 148 | if not text: 149 | return '' 150 | return re.sub(r"\s+", " ", text) 151 | 152 | 153 | def merge_files(folder, merge_file): 154 | """ 155 | 将多个文件合并为一个文件 156 | :param folder: 文件夹 157 | :param merge_file: 合并后的文件 158 | :return: 159 | """ 160 | paths, _ = listdir_files(folder) 161 | with open(merge_file, 'w') as outfile: 162 | for file_path in paths: 163 | with open(file_path) as infile: 164 | for line in infile: 165 | outfile.write(line) 166 | 167 | 168 | def random_pick(some_list, probabilities): 169 | """ 170 | 根据概率随机获取元素 171 | :param some_list: 元素列表 172 | :param probabilities: 概率列表 173 | :return: 当前元素 174 | """ 175 | x = random.uniform(0, 1) 176 | cumulative_probability = 0.0 177 | item = some_list[0] 178 | for item, item_probability in zip(some_list, probabilities): 179 | cumulative_probability += item_probability 180 | if x < cumulative_probability: 181 | break 182 | return item 183 | 184 | 185 | def intersection_of_lists(l1, l2): 186 | """ 187 | 两个list的交集 188 | :param l1: 189 | :param l2: 190 | :return: 191 | """ 192 | return list(set(l1).intersection(set(l2))) 193 | 194 | 195 | def safe_div(x, y): 196 | """ 197 | 安全除法 198 | :param x: 分子 199 | :param y: 分母 200 | :return: 除法 201 | """ 202 | x = float(x) 203 | y = float(y) 204 | if y == 0.0: 205 | return 0.0 206 | return x / y 207 | 208 | 209 | def calculate_percent(x, y): 210 | """ 211 | 计算百分比 212 | :param x: 分子 213 | :param y: 分母 214 | :return: 百分比 215 | """ 216 | x = float(x) 217 | y = float(y) 218 | return safe_div(x, y) * 100 219 | 220 | 221 | def invert_dict(d): 222 | """ 223 | 当字典的元素不重复时, 反转字典 224 | :param d: 字典 225 | :return: 反转后的字典 226 | """ 227 | return dict((v, k) for k, v in d.iteritems()) 228 | 229 | 230 | def init_num_dict(): 231 | """ 232 | 初始化值是int的字典 233 | :return: 234 | """ 235 | return collections.defaultdict(int) 236 | 237 | 238 | def sort_dict_by_value(dict_, reverse=True): 239 | """ 240 | 按照values排序字典 241 | :param dict_: 待排序字典 242 | :param reverse: 默认从大到小 243 | :return: 排序后的字典 244 | """ 245 | return sorted(dict_.items(), key=operator.itemgetter(1), reverse=reverse) 246 | 247 | 248 | def get_current_time_str(): 249 | """ 250 | 输入当天的日期格式, 20170718_1137 251 | :return: 20170718_1137 252 | """ 253 | return datetime.now().strftime('%Y%m%d%H%M%S') 254 | 255 | 256 | def get_current_time_for_show(): 257 | """ 258 | 输入当天的日期格式, 20170718_1137 259 | :return: 20170718_1137 260 | """ 261 | return datetime.now().strftime('%Y-%m-%d %H:%M:%S') 262 | 263 | 264 | def get_current_day_str(): 265 | """ 266 | 输入当天的日期格式, 20170718 267 | :return: 20170718 268 | """ 269 | return datetime.now().strftime('%Y%m%d') 270 | 271 | 272 | def remove_line_of_file(ex_line, file_name): 273 | ex_line = ex_line.replace('\n', '') 274 | lines = read_file(file_name) 275 | 276 | out_file = open(file_name, "w") 277 | for line in lines: 278 | line = line.replace('\n', '') # 确认编码格式 279 | if line != ex_line: 280 | out_file.write(line + '\n') 281 | out_file.close() 282 | 283 | 284 | def map_to_ordered_list(data_dict, reverse=True): 285 | """ 286 | 将字段根据Key的值转换为有序列表 287 | :param data_dict: 字典 288 | :param reverse: 默认从大到小 289 | :return: 有序列表 290 | """ 291 | return sorted(data_dict.items(), key=operator.itemgetter(1), reverse=reverse) 292 | 293 | 294 | def map_to_index_list(data_list, all_list): 295 | """ 296 | 转换为one-hot形式 297 | :param data_list: 298 | :param all_list: 299 | :return: 300 | """ 301 | index_dict = {l.strip(): i for i, l in enumerate(all_list)} # 字典 302 | index = index_dict[data_list.strip()] 303 | index_list = np.zeros(len(all_list), np.float32) 304 | index_list[index] = 1 305 | return index_list 306 | 307 | 308 | def map_to_index(data_list, all_list): 309 | """ 310 | 转换为one-hot形式 311 | :param data_list: 312 | :param all_list: 313 | :return: 314 | """ 315 | index_dict = {l.strip(): i for i, l in enumerate(all_list)} # 字典 316 | index = index_dict[data_list.strip()] 317 | return index 318 | 319 | 320 | def n_lines_of_file(file_name): 321 | """ 322 | 获取文件行数 323 | :param file_name: 文件名 324 | :return: 数量 325 | """ 326 | return sum(1 for _ in open(file_name)) 327 | 328 | 329 | def remove_file(file_name): 330 | """ 331 | 删除文件 332 | :param file_name: 文件名 333 | :return: 删除文件 334 | """ 335 | if os.path.exists(file_name): 336 | os.remove(file_name) 337 | 338 | 339 | def find_sub_in_str(string, sub_str): 340 | """ 341 | 子字符串的起始位置 342 | :param string: 字符串 343 | :param sub_str: 子字符串 344 | :return: 当前字符串 345 | """ 346 | return [m.start() for m in re.finditer(sub_str, string)] 347 | 348 | 349 | def list_has_sub_str(string_list, sub_str): 350 | """ 351 | 字符串是否在子字符串中 352 | :param string_list: 字符串列表 353 | :param sub_str: 子字符串列表 354 | :return: 是否在其中 355 | """ 356 | for string in string_list: 357 | if sub_str in string: 358 | return True 359 | return False 360 | 361 | 362 | def remove_last_char(str_value, num): 363 | """ 364 | 删除最后的字符串 365 | :param str_value: 字符串 366 | :param num: 删除位置 367 | :return: 新的字符串 368 | """ 369 | str_list = list(str_value) 370 | return "".join(str_list[:(-1 * num)]) 371 | 372 | 373 | def read_file(data_file, mode='more'): 374 | """ 375 | 读文件, 原文件和数据文件 376 | :return: 单行或数组 377 | """ 378 | try: 379 | with open(data_file, 'r') as f: 380 | if mode == 'one': 381 | output = f.read() 382 | return output 383 | elif mode == 'more': 384 | output = f.readlines() 385 | return map(str.strip, output) 386 | else: 387 | return list() 388 | except IOError: 389 | return list() 390 | 391 | 392 | def find_word_position(original, word): 393 | """ 394 | 查询字符串的位置 395 | :param original: 原始字符串 396 | :param word: 单词 397 | :return: [起始位置, 终止位置] 398 | """ 399 | u_original = original.decode('utf-8') 400 | u_word = word.decode('utf-8') 401 | start_indexes = find_sub_in_str(u_original, u_word) 402 | end_indexes = [x + len(u_word) - 1 for x in start_indexes] 403 | return zip(start_indexes, end_indexes) 404 | 405 | 406 | def write_line(file_name, line): 407 | """ 408 | 将行数据写入文件 409 | :param file_name: 文件名 410 | :param line: 行数据 411 | :return: None 412 | """ 413 | if file_name == "": 414 | return 415 | with open(file_name, "a+") as fs: 416 | if type(line) is (tuple or list): 417 | fs.write("%s\n" % ", ".join(line)) 418 | else: 419 | fs.write("%s\n" % line) 420 | 421 | 422 | def show_set(data_set): 423 | """ 424 | 显示集合数据 425 | :param data_set: 数据集 426 | :return: None 427 | """ 428 | data_list = list(data_set) 429 | show_string(data_list) 430 | 431 | 432 | def show_string(obj): 433 | """ 434 | 用于显示UTF-8字符串, 尤其是含有中文的. 435 | :param obj: 输入对象, 可以是列表或字典 436 | :return: None 437 | """ 438 | print list_2_utf8(obj) 439 | 440 | 441 | def list_2_utf8(obj): 442 | """ 443 | 用于显示list汉字 444 | :param obj: 445 | :return: 446 | """ 447 | return json.dumps(obj, encoding="UTF-8", ensure_ascii=False) 448 | 449 | 450 | def grouped_list(iterable, n): 451 | """ 452 | "s -> (s0,s1,s2,...sn-1), (sn,sn+1,sn+2,...s2n-1), (s2n,s2n+1,s2n+2,...s3n-1), ..." 453 | 例子: 454 | for x, y in grouped(l, 2): 455 | print "%d + %d = %d" % (x, y, x + y) 456 | 457 | :param iterable: 迭代器 458 | :param n: 间隔 459 | :return: 组合 460 | """ 461 | return izip(*[iter(iterable)] * n) 462 | 463 | 464 | def listdir_no_hidden(root_dir): 465 | """ 466 | 显示顶层文件夹 467 | :param root_dir: 根目录 468 | :return: 文件夹列表 469 | """ 470 | return glob.glob(os.path.join(root_dir, '*')) 471 | 472 | 473 | def listdir_files(root_dir, ext=None): 474 | """ 475 | 列出文件夹中的文件 476 | :param root_dir: 根目录 477 | :param ext: 类型 478 | :return: [文件路径(相对路径), 文件夹名称, 文件名称] 479 | """ 480 | names_list = [] 481 | paths_list = [] 482 | for parent, _, fileNames in os.walk(root_dir): 483 | for name in fileNames: 484 | if name.startswith('.'): 485 | continue 486 | if ext: 487 | if name.endswith(tuple(ext)): 488 | names_list.append(name) 489 | paths_list.append(os.path.join(parent, name)) 490 | else: 491 | names_list.append(name) 492 | paths_list.append(os.path.join(parent, name)) 493 | return paths_list, names_list 494 | 495 | 496 | def time_elapsed(start, end): 497 | """ 498 | 输出时间 499 | :param start: 开始 500 | :param end: 结束 501 | :return: 502 | """ 503 | hours, rem = divmod(end - start, 3600) 504 | minutes, seconds = divmod(rem, 60) 505 | return "{:0>2}:{:0>2}:{:05.2f}".format(int(hours), int(minutes), seconds) 506 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | backports.functools-lru-cache==1.5 2 | backports.weakref==1.0.post1 3 | bleach==1.5.0 4 | cycler==0.10.0 5 | decorator==4.2.1 6 | enum34==1.1.6 7 | funcsigs==1.0.2 8 | futures==3.2.0 9 | graphviz==0.8.2 10 | h5py==2.7.1 11 | html5lib==0.9999999 12 | Keras==2.1.5 13 | keras-vis==0.4.1 14 | kiwisolver==1.0.1 15 | Markdown==2.6.11 16 | matplotlib==2.2.2 17 | mock==2.0.0 18 | networkx==2.1 19 | numpy==1.14.2 20 | pbr==4.0.1 21 | Pillow==5.1.0 22 | protobuf==3.5.2.post1 23 | pydot==1.2.4 24 | pyparsing==2.2.0 25 | python-dateutil==2.7.2 26 | pytz==2018.3 27 | PyWavelets==0.5.2 28 | PyYAML==3.12 29 | scikit-image==0.13.1 30 | scikit-learn==0.19.1 31 | scipy==1.0.1 32 | six==1.11.0 33 | sklearn==0.0 34 | subprocess32==3.2.7 35 | tensorflow==1.4.0 36 | tensorflow-tensorboard==0.4.0 37 | Werkzeug==0.14.1 38 | -------------------------------------------------------------------------------- /root_dir.py: -------------------------------------------------------------------------------- 1 | # -- coding: utf-8 -- 2 | """ 3 | 4 | created by C.L.Wang 5 | """ 6 | import os 7 | 8 | ROOT_DIR = os.path.dirname(os.path.abspath(__file__)) 9 | --------------------------------------------------------------------------------