├── 1.preprocess.py ├── 2.relatives.py ├── 3.cluster.py ├── 4.validate.py ├── README.md ├── apriori.py ├── classical.csv ├── hier.py └── utils.py /1.preprocess.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | #from treat import * 3 | import re 4 | import pandas as pd 5 | import numpy as np 6 | import itertools 7 | import utils 8 | 9 | #1.1 初步清洗目标列的数据,并返回 10 | def preprocess(column_name, data): 11 | for i in range(length): 12 | item = data[column_name].loc[i] 13 | if column_name == '处方(处理)': 14 | item = re.sub(r'\S钱', '', item) 15 | item = re.sub(r'[\或\以及]', '', item) 16 | pre = re.sub(r'\(.*?\)|{.*?\}', '', item) 17 | pre = re.split(r'[\d\.\!\……\。\;\,\,]', pre) 18 | pre = [x for x in pre if x != ''] 19 | data[column_name].loc[i] = pre 20 | return data 21 | 22 | #1.2 判断每一条数据的长度,这里阈值设为5;过长的数据打印出来,由此设置manual_process(人工处理) 23 | def check(name, data): 24 | for i in range(length): 25 | item = data[name][i] 26 | for n in item: 27 | if len(n) > 5: 28 | print(n, i, item) 29 | print('-'*50) 30 | 31 | def manual_process(name, data): 32 | data[name][16][0] = '桃仁' 33 | data[name][16].append('红花') 34 | data[name][33][-1] = '补骨脂' 35 | data[name][51][0] = '川芎' 36 | data[name][51][1:] = ['阿胶', '甘草', '艾叶', '当归', '芍药', '干地黄'] 37 | return data 38 | 39 | #2.1 计算方剂所报含的所有药物各自的频率,以字典保存 40 | def count_dic(column_name, data_length, data): 41 | count_dic = {} 42 | for i in range(data_length): 43 | for per in data[column_name][i]: 44 | count_dic[per] = (count_dic[per] if per in count_dic else 0) + 1 45 | return count_dic 46 | 47 | #2.2 按value排列字典,这里指频率降序排列药物,并将药物名称和频率保存各自保存到list 48 | def dic_list(dic): 49 | list_name, list_frequecy = [], [] 50 | reversed_list = sorted(dic.items(),key = lambda x:x[1],reverse = True) 51 | for i in reversed_list: 52 | list_name.append(i[0]) 53 | list_frequecy.append(i[-1]) 54 | return list_name, list_frequecy 55 | 56 | # 3. 生成关于药物的one-hot,即每条药方根据药物有无 标 1or0 57 | def one_hot(list_name, input_data, column_name = '处方(处理)'): 58 | df = pd.DataFrame(np.zeros((length, len(list_name))), columns= list_name) 59 | for i in range(length): 60 | for name in input_data[column_name].loc[i]: 61 | df[name].loc[i] = 1 62 | return df 63 | 64 | # 4. 
药物两两组合,计算组合频率,返回字典格式 65 | def combinations_dic(data_length, input_list_name, one_hot_data): 66 | combinations = list(itertools.combinations(range(len(input_list_name)),2)) 67 | combinations_fre = {} 68 | for i in range(data_length): 69 | for item in combinations: 70 | pre, suf = item 71 | if one_hot_data[input_list_name[pre]][i] == 1 and one_hot_data[input_list_name[suf]][i] == 1: 72 | combinations_fre[item] = (combinations_fre[item] if item in combinations_fre else 0) + 1 73 | return combinations_fre 74 | 75 | 76 | if __name__ == "__main__": 77 | data = pd.read_csv(open('classical.csv')) 78 | length = data.shape[0] 79 | data = preprocess('处方(处理)', data) 80 | check('处方(处理)', data) 81 | data = manual_process('处方(处理)', data) 82 | 83 | count_dic = count_dic('处方(处理)', length, data) 84 | list_name, list_frequency = dic_list(count_dic) #l1药物名:168,list_frequecy频数:474 85 | one_hot_df = one_hot(list_name, data) 86 | one_hot_df.to_csv('one_hot_df.csv', index = False, encoding = 'utf-8') 87 | combinations_dic_fre = combinations_dic(length, list_name, one_hot_df) 88 | combinations_list, combinations_frequency = dic_list(combinations_dic_fre) #l3药物组合index,l4频数 89 | list_fre = [i/sum(list_frequency) for i in list_frequency] 90 | combinations_fre = [i/sum(list_frequency) for i in combinations_frequency] 91 | 92 | 93 | utils.save_pickle('list_name.txt', list_name) 94 | utils.save_pickle('list_fre.txt', list_fre) 95 | utils.save_pickle('combinations_list.txt', combinations_list) 96 | utils.save_pickle('combinations_fre.txt', combinations_fre) 97 | 98 | # ============================================================================= 99 | # list_1 = [] 100 | # for item in data['处方(处理)']: 101 | # list_1.append(item) 102 | # utils.save_pickle('data.txt', list_1) 103 | # ============================================================================= 104 | 105 | 106 | 107 | 108 | 109 | -------------------------------------------------------------------------------- /2.relatives.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pandas as pd 3 | from math import log 4 | import utils 5 | 6 | def comb_names(list_name): 7 | combinations_name = [] #组合药物名 8 | for i in combinations_list: 9 | pre, suf = i 10 | combinations_name.append([list_name[pre], list_name[suf]]) 11 | return combinations_name 12 | 13 | def calculate_correlation(combinations_list, combinations_fre, list_fre): 14 | correlation = [] #关联度系数 15 | for i in range(len(combinations_list)): 16 | flag_1, flag_2, flag_3 = 1, 1, 1 17 | pre, suf = combinations_list[i] 18 | H_pre = -list_fre[pre] * log(list_fre[pre]) - (1- list_fre[pre]) * log((1- list_fre[pre])) 19 | H_suf = -list_fre[suf] * log(list_fre[suf]) - (1- list_fre[suf]) * log((1- list_fre[suf])) 20 | param_1 = list_fre[pre] - combinations_fre[i] 21 | param_2 = list_fre[suf] - combinations_fre[i] 22 | param_3 = 1 + combinations_fre[i] - list_fre[pre] -list_fre[suf] 23 | if param_1 ==0: 24 | flag_1 = 0 25 | param_1 = 1 26 | if param_2 ==0: 27 | flag_2 = 0 28 | param_2 = 1 29 | if param_3 ==0: 30 | flag_3 = 0 31 | param_3 = 1 32 | H_pre_suf = -combinations_fre[i] * log(combinations_fre[i]) - flag_1 * param_1 * log(param_1) - flag_2 * param_2 * log(param_2) - flag_3 * param_3 * log(param_3) 33 | result = H_pre + H_suf - H_pre_suf 34 | #result = combinations_fre[i] * log(combinations_fre[i]/(list_fre[pre]*list_fre[suf])) 35 | correlation.append(result) 36 | return correlation 37 | 38 | def relatives(list_name, data, relatives_num): 39 | 
relatives_list = [] #药物亲友团
40 |     length = data.shape[0]
41 |     for item in list_name:
42 |         list_ = []
43 |         for i in range(length):
44 |             if item in data['药物'][i]:
45 |                 list_.append(data['药物'][i])
46 |         relatives_list.append(list_)
47 |     def limited_relatives(input_list, relatives_num):
48 |         result = []
49 |         for i in range(len(input_list)):
50 |             if len(input_list[i]) >= relatives_num:
51 |                 result.append(input_list[i][:relatives_num])
52 |             else:
53 |                 result.append(input_list[i])
54 |         return result
55 |     return limited_relatives(relatives_list, relatives_num)
56 | 
57 | 
58 | if __name__ == "__main__":
59 |     list_name = utils.load_pickle('list_name.txt')
60 |     list_fre = utils.load_pickle('list_fre.txt')
61 |     combinations_list = utils.load_pickle('combinations_list.txt')
62 |     combinations_fre = utils.load_pickle('combinations_fre.txt')
63 |     correlation = calculate_correlation(combinations_list, combinations_fre, list_fre)
64 |     combinations_name = comb_names(list_name)
65 |     column_1 = pd.Series(combinations_name, name='药物')
66 |     column_2 = pd.Series(correlation, name='关联度系数')
67 |     data = pd.concat([column_1, column_2], axis=1)
68 |     data = data.sort_values(by = '关联度系数', ascending=False).reset_index(drop=True) # reset the index so relatives() walks the pairs in descending correlation order
69 |     data.to_csv('rel2.csv', index = False, encoding = 'utf-8')
70 |     relatives_list = relatives(list_name, data, 5)
71 |     utils.save_pickle('relatives_list.txt', relatives_list)
--------------------------------------------------------------------------------
/3.cluster.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | import utils
4 | import pandas as pd
5 | import itertools
6 | 
7 | 
8 | #[[a,b], [a,c], [a,d]]变为[b, c, d]
9 | def duplicate_removal(relatives_list):
10 |     result = []
11 |     for item in relatives_list:
12 |         new2 = []
13 |         for n in item:
14 |             for q in n:
15 |                 new2.append(q)
16 |         guodu = list(set(new2))
17 |         guodu.sort(key = new2.index)
18 |         if guodu != []:
19 |             guodu.remove(list_name[relatives_list.index(item)])
20 |         result.append(guodu)
21 |     return result
22 | 
23 | #两两强相关的组合,这里作为判别其他组合的基础,即是strong_related(判别强相关)的基础
24 | def cluster_2_members(qyt):
25 |     cluster_list = []
26 |     for i in range(len(qyt)):
27 |         guodu = []
28 |         for item in qyt[i]:
29 |             if list_name[i] in qyt[list_name.index(item)]:
30 |                 guodu.append([item])
31 |             else:
32 |                 guodu.append('')
33 |         cluster_list.append([x for x in guodu if x != ''])
34 |     return cluster_list
35 | 
36 | def strong_related(index_1, index_2):
37 |     if [list_name[index_1]] in cluster_2[index_2]:
38 |         return True
39 |     else:
40 |         return False
41 | 
42 | #聚类的基本体,cluster_num代表聚类的类别数
43 | def cluster(cluster, cluster_num):
44 |     cluster_list2 = []
45 |     for i in range(len(cluster)):
46 |         comb = list(itertools.combinations(range(len(cluster[i])), cluster_num-1))
47 |         cluster_list1 = []
48 |         for n in comb:
49 |             guodu = []
50 |             guodu2 = []
51 |             for item in n:
52 |                 index = list_name.index(cluster[i][item][0])
53 |                 guodu2.append(index)
54 |             a = list(itertools.combinations(guodu2, 2))
55 |             for item1 in n:
56 |                 flag = 1
57 |                 for per in a:
58 |                     if strong_related(per[0], per[1]) == False:
59 |                         flag = 0
60 |                 if flag == 1:
61 |                     guodu.append(cluster[i][item1][0])
62 |                 else:
63 |                     guodu.append('')
64 |             cluster_list1.append([x for x in guodu if x != ''])
65 |         cluster_list2.append([x for x in cluster_list1 if x != []])
66 |     return cluster_list2
67 | 
68 | def cluster_main(input_cluster):
69 |     flag = True
70 |     cluster_num = 3
71 |     while flag:
72 |         cluster_ = cluster(input_cluster, cluster_num)
73 |         for i in cluster_:
74 |             if i != []:
75 | 
cluster_num += 1 76 | break 77 | else: 78 | flag = False 79 | return cluster(input_cluster, cluster_num-1) 80 | 81 | 82 | if __name__ == "__main__": 83 | relatives_list = utils.load_pickle('relatives_list.txt') 84 | list_name = utils.load_pickle('list_name.txt') 85 | list_qyt = duplicate_removal(relatives_list) 86 | cluster_2 = cluster_2_members(list_qyt) 87 | cluster_4 = cluster_main(cluster_2) 88 | 89 | column_1 = pd.Series(list_name, name='药物') 90 | column_2 = pd.Series(cluster_4, name='cluster') 91 | data = pd.concat([column_1, column_2], axis=1) 92 | data.to_csv('cluster2.csv', index = False, encoding = 'utf-8') 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | -------------------------------------------------------------------------------- /4.validate.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import pandas as pd 4 | import utils 5 | 6 | one_hot_df = pd.read_csv('one_hot_df.csv') 7 | data = pd.read_csv(open('classical.csv')) 8 | cluster = utils.load_pickle('cluster.txt') 9 | list_name = utils.load_pickle('list_name.txt') 10 | length = len(cluster) 11 | length_2 = one_hot_df.shape[0] 12 | 13 | item_list = [] 14 | l2 = [] 15 | 16 | index_list = [] 17 | for i in range(length): 18 | item = cluster[i] 19 | if item != []: 20 | item[0].insert(0, list_name[i]) 21 | item_list.append(sorted(item[0])) 22 | 23 | [l2.append(i) for i in item_list if i not in l2] 24 | 25 | ''' 26 | [['人参', '当归', '炙甘草', '白术'], 27 | ['仙茅', '巴戟天', '淫羊藿', '知母'], 28 | ['大怀熟地', '山茱萸', '枸杞', '菟丝子'], 29 | ['制何首乌', '牛膝', '酒浸当归', '酒浸枸杞子']] 30 | 31 | ''' 32 | for n in range(length_2): 33 | index = [] 34 | for item in l2: 35 | flag = 1 36 | for i in item: 37 | column = one_hot_df[i] 38 | if column[n] == 0: 39 | flag = 0 40 | break 41 | if flag == 1: 42 | index.append(n) 43 | index_list.append([item.copy(), index]) 44 | 45 | 46 | 47 | ''' 48 | [[['人参', '当归', '炙甘草', '白术'], [10]], 49 | [['人参', '当归', '炙甘草', '白术'], [14]], 50 | [['人参', '当归', '炙甘草', '白术'], [18]], 51 | [['人参', '当归', '炙甘草', '白术'], [24]], 52 | [['大怀熟地', '山茱萸', '枸杞', '菟丝子'], [26]], 53 | [['仙茅', '巴戟天', '淫羊藿', '知母'], [30]], 54 | [['制何首乌', '牛膝', '酒浸当归', '酒浸枸杞子'], [33]]] 55 | ''' 56 | for item in index_list: 57 | item.append(data['证型'][item[-1][0]]) 58 | #item.append(data['功用大类'][item[-1][0]]) 59 | 60 | 61 | # ============================================================================= 62 | # l2 = [['人参', '当归', '炙甘草', '白术'], 63 | # ['大怀熟地', '山茱萸', '枸杞', '菟丝子'], 64 | # ['仙茅', '巴戟天', '淫羊藿', '知母'], 65 | # ['制何首乌', '牛膝', '酒浸当归', '酒浸枸杞子']] 66 | # index_list[0][-1] = '血虚证' 67 | # set_point = 0 68 | # set_point_2 = 0 69 | # for n in l2[set_point:]: 70 | # i = 0 71 | # i_list = [] 72 | # max_count = 0 73 | # for item in index_list[set_point_2:]: 74 | # if item[0] == n: 75 | # i += 1 76 | # set_point_2 += 1 77 | # i_list.append(item[-1]) 78 | # print(item[0], n, set_point_2) 79 | # else: 80 | # break 81 | # set_point += 1 82 | # i_set = set(i_list) 83 | # for q in i_set: 84 | # if max_count < i_list.count(q): 85 | # max_count = i_list.count(q) 86 | # n.append([i, max_count]) 87 | # ============================================================================= 88 | for n in l2: 89 | i = 0 90 | i_list = [] 91 | max_count = 0 92 | for item in index_list: 93 | if item[0] == n: 94 | i += 1 95 | i_list.append(item[-1]) 96 | i_set = set(i_list) 97 | for q in i_set: 98 | if max_count < i_list.count(q): 99 | max_count = i_list.count(q) 100 | 101 | n.append([i, max_count]) 102 | 103 | 
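# A minimal sketch of the final sensitivity report. Each entry of l2 now ends with
# [i, max_count], where i is the number of prescriptions that contain the whole
# cluster and max_count is the size of its largest single 证型 group; the README
# calls the validation criterion "敏感性" (sensitivity), and the reading that it
# equals max_count / i is an assumption here, not stated in the script itself.
for n in l2:
    matched, dominant = n[-1]
    if matched > 0:
        print(n[:-1], 'sensitivity =', dominant / matched)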
104 | 
105 | 
106 | 
107 | 
108 | 
109 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ## Automatic Composition of TCM Prescriptions (中医药自动组方)
2 | 
3 | Objective: starting from data collected by a separate web crawler, clean (preprocess) the prescriptions, compute pairwise herb correlations, cluster the herbs, and mine association rules.
4 | 
5 | Environment: Python 3.5 + pandas
6 | 
7 | Sample data: classical.csv
8 | 
9 | ------
10 | 
11 | ### 1. Data preprocessing
12 | 
13 | ``` 1.preprocess.py ``` cleans the sample data (see the inline comments for details) and saves the outputs, including the herb vocabulary and the herb and pair frequencies, to local files.
14 | 
15 | ### 2. Similarity and "relative groups" (亲友团)
16 | 
17 | ``` 2.relatives.py ``` uses the saved herb frequencies to compute a similarity derived from complex-system entropy (essentially mutual information) and builds each herb's relative group from it.
18 | 
19 | (Note: an herb's relative group consists of the partners ranked highest when its similarities are sorted in descending order. For example, if sorting the similarities of herb a yields b, c, d, e and the group size is set to 2, then a's relative group is b, c.)
20 | 
21 | ### 3. Clustering
22 | 
23 | ``` 3.cluster.py ``` treats herbs that appear in each other's relative groups as strongly related pairs and grows these pairs into the largest strongly related combinations.
24 | 
25 | ### 4. Sensitivity
26 | 
27 | ``` 4.validate.py ``` computes sensitivity as the criterion for validating the model.
28 | 
29 | ### Other files
30 | 
31 | ``` apriori.py ``` and ``` hier.py ``` build association rules and perform hierarchical clustering, respectively;
32 | ```utils.py``` contains the shared pickle utilities.
33 | 
34 | 
--------------------------------------------------------------------------------
/apriori.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | 
3 | '''
4 | Created on Mar 24, 2011
5 | Update on 2017-05-18
6 | Ch 11 code
7 | @author: Peter/片刻
8 | 《机器学习实战》更新地址:https://github.com/apachecn/MachineLearning
9 | '''
10 | print(__doc__)
11 | from numpy import *
12 | import utils
13 | 
14 | # 加载数据集
15 | def loadDataSet():
16 |     return utils.load_pickle('data.txt')
17 | 
18 | # 创建集合 C1。即对 dataSet 进行去重,排序,放入 list 中,然后转换所有的元素为 frozenset
19 | def createC1(dataSet):
20 |     """createC1(创建集合 C1)
21 |     Args:
22 |         dataSet 原始数据集
23 |     Returns:
24 |         frozenset 返回一个 frozenset 格式的 list
25 |     """
26 | 
27 |     C1 = []
28 |     for transaction in dataSet:
29 |         for item in transaction:
30 |             if not [item] in C1:
31 |                 # 遍历所有的元素,如果不在 C1 出现过,那么就 append
32 |                 C1.append([item])
33 |     # 对数组进行 `从小到大` 的排序
34 |     # print 'sort 前=', C1
35 |     C1.sort()
36 |     # frozenset 表示冻结的 set 集合,元素无改变;可以把它当字典的 key 来使用
37 |     # print 'sort 后=', C1
38 |     # print 'frozenset=', map(frozenset, C1)
39 |     return list(map(frozenset, C1))
40 | 
41 | # 计算候选数据集 CK 在数据集 D 中的支持度,并返回支持度大于最小支持度(minSupport)的数据
42 | def scanD(D, Ck, minSupport):
43 |     """scanD(计算候选数据集 CK 在数据集 D 中的支持度,并返回支持度大于最小支持度 minSupport 的数据)
44 |     Args:
45 |         D 数据集
46 |         Ck 候选项集列表
47 |         minSupport 最小支持度
48 |     Returns:
49 |         retList 支持度大于 minSupport 的集合
50 |         supportData 候选项集支持度数据
51 |     """
52 | 
53 |     # ssCnt 临时存放选数据集 Ck 的频率. 例如: a->10, b->5, c->8
54 |     ssCnt = {}
55 |     for tid in D:
56 |         for can in Ck:
57 |             # s.issubset(t) 测试是否 s 中的每一个元素都在 t 中
58 |             if can.issubset(tid):
59 |                 if can not in ssCnt:
60 |                     ssCnt[can] = 1
61 |                 else:
62 |                     ssCnt[can] += 1
63 |     numItems = float(len(D)) # 数据集 D 的数量
64 |     retList = []
65 |     supportData = {}
66 |     for key in ssCnt:
67 |         # 支持度 = 候选项(key)出现的次数 / 所有数据集的数量
68 |         support = ssCnt[key]/numItems
69 |         if support >= minSupport:
70 |             # 在 retList 的首位插入元素,只存储支持度满足频繁项集的值
71 |             retList.insert(0, key)
72 |         # 存储所有的候选项(key)和对应的支持度(support)
73 |         supportData[key] = support
74 |     return retList, supportData
75 | 
76 | # 输入频繁项集列表 Lk 与返回的元素个数 k,然后输出所有可能的候选项集 Ck
77 | def aprioriGen(Lk, k):
78 |     """aprioriGen(输入频繁项集列表 Lk 与返回的元素个数 k,然后输出候选项集 Ck。
79 |        例如: 以 {0},{1},{2} 为输入且 k = 2 则输出 {0,1}, {0,2}, {1,2}. 
以 {0,1},{0,2},{1,2} 为输入且 k = 3 则输出 {0,1,2} 80 | 仅需要计算一次,不需要将所有的结果计算出来,然后进行去重操作 81 | 这是一个更高效的算法) 82 | Args: 83 | Lk 频繁项集列表 84 | k 返回的项集元素个数(若元素的前 k-2 相同,就进行合并) 85 | Returns: 86 | retList 元素两两合并的数据集 87 | """ 88 | 89 | retList = [] 90 | lenLk = len(Lk) 91 | for i in range(lenLk): 92 | for j in range(i+1, lenLk): 93 | L1 = list(Lk[i])[: k-2] 94 | L2 = list(Lk[j])[: k-2] 95 | # ============================================================================= 96 | # print('-----i=', i, k-2, Lk, Lk[i], list(Lk[i])[: k-2]) 97 | # print('-----j=', j, k-2, Lk, Lk[j], list(Lk[j])[: k-2]) 98 | # ============================================================================= 99 | L1.sort() 100 | L2.sort() 101 | # 第一次 L1,L2 为空,元素直接进行合并,返回元素两两合并的数据集 102 | # if first k-2 elements are equal 103 | #print('l1',L1,'l2',L2) 104 | if L1 == L2: 105 | # set union 106 | #print('union=', Lk[i] | Lk[j], Lk[i], Lk[j]) 107 | retList.append(Lk[i] | Lk[j]) 108 | #print(retList) 109 | return retList 110 | 111 | # 找出数据集 dataSet 中支持度 >= 最小支持度的候选项集以及它们的支持度。即我们的频繁项集。 112 | def apriori(dataSet, minSupport=0.5): 113 | """apriori(首先构建集合 C1,然后扫描数据集来判断这些只有一个元素的项集是否满足最小支持度的要求。那么满足最小支持度要求的项集构成集合 L1。然后 L1 中的元素相互组合成 C2,C2 再进一步过滤变成 L2,然后以此类推,知道 CN 的长度为 0 时结束,即可找出所有频繁项集的支持度。) 114 | Args: 115 | dataSet 原始数据集 116 | minSupport 支持度的阈值 117 | Returns: 118 | L 频繁项集的全集 119 | supportData 所有元素和支持度的全集 120 | """ 121 | # C1 即对 dataSet 进行去重,排序,放入 list 中,然后转换所有的元素为 frozenset 122 | C1 = createC1(dataSet) 123 | # print 'C1: ', C1 124 | # 对每一行进行 set 转换,然后存放到集合中 125 | D = list(map(set, dataSet)) 126 | # print 'D=', D 127 | # 计算候选数据集 C1 在数据集 D 中的支持度,并返回支持度大于 minSupport 的数据 128 | L1, supportData = scanD(D, C1, minSupport) 129 | # print "L1=", L1, "\n", "outcome: ", supportData 130 | 131 | # L 加了一层 list, L 一共 2 层 list 132 | L = [L1] 133 | k = 2 134 | # 判断 L 的第 k-2 项的数据长度是否 > 0。 135 | # 第一次执行时 L 为 [[frozenset([1]), frozenset([3]), frozenset([2]), frozenset([5])]]。L[k-2]=L[0]=[frozenset([1]), frozenset([3]), frozenset([2]), frozenset([5])],最后面 k += 1 136 | while (len(L[k-2]) > 0): 137 | #print('k=', k,'\n', L,'\n', L[k-2]) 138 | Ck = aprioriGen(L[k-2], k) # 例如: 以 {0},{1},{2} 为输入且 k = 2 则输出 {0,1}, {0,2}, {1,2}. 以 {0,1},{0,2},{1,2} 为输入且 k = 3 则输出 {0,1,2} 139 | # print 'Ck', Ck 140 | 141 | Lk, supK = scanD(D, Ck, minSupport) # 计算候选数据集 CK 在数据集 D 中的支持度,并返回支持度大于 minSupport 的数据 142 | # 保存所有候选项集的支持度,如果字典没有,就追加元素,如果有,就更新元素 143 | supportData.update(supK) 144 | if len(Lk) == 0: 145 | break 146 | # Lk 表示满足频繁子项的集合,L 元素在增加,例如: 147 | # l=[[set(1), set(2), set(3)]] 148 | # l=[[set(1), set(2), set(3)], [set(1, 2), set(2, 3)]] 149 | L.append(Lk) 150 | k += 1 151 | # print 'k=', k, len(L[k-2]) 152 | return L, supportData 153 | 154 | # 计算可信度(confidence) 155 | def calcConf(freqSet, H, supportData, brl, minConf=0.7): 156 | """calcConf(对两个元素的频繁项,计算可信度,例如: {1,2}/{1} 或者 {1,2}/{2} 看是否满足条件) 157 | Args: 158 | freqSet 频繁项集中的元素,例如: frozenset([1, 3]) 159 | H 频繁项集中的元素的集合,例如: [frozenset([1]), frozenset([3])] 160 | supportData 所有元素的支持度的字典 161 | brl 关联规则列表的空数组 162 | minConf 最小可信度 163 | Returns: 164 | prunedH 记录 可信度大于阈值的集合 165 | """ 166 | # 记录可信度大于最小可信度(minConf)的集合 167 | prunedH = [] 168 | for conseq in H: # 假设 freqSet = frozenset([1, 3]), H = [frozenset([1]), frozenset([3])],那么现在需要求出 frozenset([1]) -> frozenset([3]) 的可信度和 frozenset([3]) -> frozenset([1]) 的可信度 169 | 170 | # print 'confData=', freqSet, H, conseq, freqSet-conseq 171 | conf = supportData[freqSet]/supportData[freqSet-conseq] # 支持度定义: a -> b = support(a | b) / support(a). 
假设 freqSet = frozenset([1, 3]), conseq = [frozenset([1])],那么 frozenset([1]) 至 frozenset([3]) 的可信度为 = support(a | b) / support(a) = supportData[freqSet]/supportData[freqSet-conseq] = supportData[frozenset([1, 3])] / supportData[frozenset([1])] 172 | if conf >= minConf: 173 | # 只要买了 freqSet-conseq 集合,一定会买 conseq 集合(freqSet-conseq 集合和 conseq集合 是全集) 174 | print(freqSet-conseq, '-->', conseq, 'conf:', conf) 175 | brl.append((freqSet-conseq, conseq, conf)) 176 | prunedH.append(conseq) 177 | return prunedH 178 | 179 | # 递归计算频繁项集的规则 180 | def rulesFromConseq(freqSet, H, supportData, brl, minConf=0.7): 181 | """rulesFromConseq 182 | Args: 183 | freqSet 频繁项集中的元素,例如: frozenset([2, 3, 5]) 184 | H 频繁项集中的元素的集合,例如: [frozenset([2]), frozenset([3]), frozenset([5])] 185 | supportData 所有元素的支持度的字典 186 | brl 关联规则列表的数组 187 | minConf 最小可信度 188 | """ 189 | # H[0] 是 freqSet 的元素组合的第一个元素,并且 H 中所有元素的长度都一样,长度由 aprioriGen(H, m+1) 这里的 m + 1 来控制 190 | # 该函数递归时,H[0] 的长度从 1 开始增长 1 2 3 ... 191 | # 假设 freqSet = frozenset([2, 3, 5]), H = [frozenset([2]), frozenset([3]), frozenset([5])] 192 | # 那么 m = len(H[0]) 的递归的值依次为 1 2 193 | # 在 m = 2 时, 跳出该递归。假设再递归一次,那么 H[0] = frozenset([2, 3, 5]),freqSet = frozenset([2, 3, 5]) ,没必要再计算 freqSet 与 H[0] 的关联规则了。 194 | m = len(H[0]) 195 | if (len(freqSet) > (m + 1)): 196 | # print 'freqSet******************', len(freqSet), m + 1, freqSet, H, H[0] 197 | # 生成 m+1 个长度的所有可能的 H 中的组合,假设 H = [frozenset([2]), frozenset([3]), frozenset([5])] 198 | # 第一次递归调用时生成 [frozenset([2, 3]), frozenset([2, 5]), frozenset([3, 5])] 199 | # 第二次 。。。没有第二次,递归条件判断时已经退出了 200 | Hmp1 = aprioriGen(H, m+1) 201 | # 返回可信度大于最小可信度的集合 202 | Hmp1 = calcConf(freqSet, Hmp1, supportData, brl, minConf) 203 | # ============================================================================= 204 | # print('Hmp1=', Hmp1) 205 | # print('len(Hmp1)=', len(Hmp1), 'len(freqSet)=', len(freqSet)) 206 | # ============================================================================= 207 | # 计算可信度后,还有数据大于最小可信度的话,那么继续递归调用,否则跳出递归 208 | if (len(Hmp1) > 1): 209 | # print '----------------------', Hmp1 210 | # print len(freqSet), len(Hmp1[0]) + 1 211 | rulesFromConseq(freqSet, Hmp1, supportData, brl, minConf) 212 | 213 | # 生成关联规则 214 | def generateRules(L, supportData, minConf=0.7): 215 | """generateRules 216 | Args: 217 | L 频繁项集列表 218 | supportData 频繁项集支持度的字典 219 | minConf 最小置信度 220 | Returns: 221 | bigRuleList 可信度规则列表(关于 (A->B+置信度) 3个字段的组合) 222 | """ 223 | bigRuleList = [] 224 | # 假设 L = [[frozenset([1]), frozenset([3]), frozenset([2]), frozenset([5])], [frozenset([1, 3]), frozenset([2, 5]), frozenset([2, 3]), frozenset([3, 5])], [frozenset([2, 3, 5])]] 225 | for i in range(1, len(L)): 226 | # 获取频繁项集中每个组合的所有元素 227 | for freqSet in L[i]: 228 | # 假设:freqSet= frozenset([1, 3]), H1=[frozenset([1]), frozenset([3])] 229 | # 组合总的元素并遍历子元素,并转化为 frozenset 集合,再存放到 list 列表中 230 | H1 = [frozenset([item]) for item in freqSet] 231 | # 2 个的组合,走 else, 2 个以上的组合,走 if 232 | if (i > 1): 233 | rulesFromConseq(freqSet, H1, supportData, bigRuleList, minConf) 234 | else: 235 | calcConf(freqSet, H1, supportData, bigRuleList, minConf) 236 | return bigRuleList 237 | 238 | 239 | def testApriori(): 240 | # 加载测试数据集 241 | dataSet = loadDataSet() 242 | print('dataSet: ', dataSet) 243 | 244 | # Apriori 算法生成频繁项集以及它们的支持度 245 | L1, supportData1 = apriori(dataSet, minSupport=0.7) 246 | print('L(0.7): ', L1) 247 | print('supportData(0.7): ', supportData1) 248 | 249 | print('->->->->->->->->->->->->->->->->->->->->->->->->->->->->') 250 | 251 | # Apriori 算法生成频繁项集以及它们的支持度 252 | L2, supportData2 
= apriori(dataSet, minSupport=0.5) 253 | print('L(0.5): ', L2) 254 | print('supportData(0.5): ', supportData2) 255 | 256 | def testGenerateRules(): 257 | # 加载测试数据集 258 | dataSet = loadDataSet() 259 | #print('dataSet: ', dataSet) 260 | 261 | # Apriori 算法生成频繁项集以及它们的支持度 262 | L1, supportData1 = apriori(dataSet, minSupport=0.1) 263 | #print('L(0.1): ', L1) 264 | #print('supportData(0.1): ', supportData1) 265 | 266 | # 生成关联规则 267 | rules = generateRules(L1, supportData1, minConf=0.2) 268 | #print('rules: ', rules) 269 | 270 | def main(): 271 | # 测试 Apriori 算法 272 | #testApriori() 273 | 274 | # 生成关联规则 275 | testGenerateRules() 276 | 277 | # # 项目案例 278 | # # 构建美国国会投票记录的事务数据集 279 | # actionIdList, billTitleList = getActionIds() 280 | # # 测试前2个 281 | # # transDict, itemMeaning = getTransList(actionIdList[: 2], billTitleList[: 2]) 282 | # # transDict 表示 action_id的集合,transDict[key]这个就是action_id对应的选项,例如 [1, 2, 3] 283 | # transDict, itemMeaning = getTransList(actionIdList, billTitleList) 284 | # # 得到全集的数据 285 | # dataSet = [transDict[key] for key in transDict.keys()] 286 | # L, supportData = apriori(dataSet, minSupport=0.3) 287 | # rules = generateRules(L, supportData, minConf=0.95) 288 | # print rules 289 | 290 | # # 项目案例 291 | # # 发现毒蘑菇的相似特性 292 | # # 得到全集的数据 293 | # dataSet = [line.split() for line in open("input/11.Apriori/mushroom.dat").readlines()] 294 | # L, supportData = apriori(dataSet, minSupport=0.3) 295 | # # 2表示毒蘑菇,1表示可食用的蘑菇 296 | # # 找出关于2的频繁子项出来,就知道如果是毒蘑菇,那么出现频繁的也可能是毒蘑菇 297 | # for item in L[1]: 298 | # if item.intersection('2'): 299 | # print item 300 | 301 | # for item in L[2]: 302 | # if item.intersection('2'): 303 | # print item 304 | 305 | if __name__ == "__main__": 306 | main() 307 | -------------------------------------------------------------------------------- /classical.csv: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/gcaxuxi/cluster_2/72bf1c9ec63f2bdc2eaf0875ba7edf1f72bfc53e/classical.csv -------------------------------------------------------------------------------- /hier.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import math 4 | from math import log 5 | import utils 6 | 7 | #数据集 8 | 9 | 10 | list_name = utils.load_pickle('list_name.txt') 11 | list_fre = utils.load_pickle('list_fre.txt') 12 | combinations_fre = utils.load_pickle('combinations_fre.txt') 13 | combinations_list = utils.load_pickle('combinations_list.txt') 14 | data = list(range(len(list_name))) 15 | 16 | def calculate_correlation(pre, suf): 17 | flag_1, flag_2, flag_3 = 1, 1, 1 18 | if pre == suf or (pre,suf) not in combinations_list: 19 | result = 0 20 | else: 21 | i = combinations_list.index((pre, suf)) 22 | H_pre = -list_fre[pre] * log(list_fre[pre]) - (1- list_fre[pre]) * log((1- list_fre[pre])) 23 | H_suf = -list_fre[suf] * log(list_fre[suf]) - (1- list_fre[suf]) * log((1- list_fre[suf])) 24 | param_1 = list_fre[pre] - combinations_fre[i] 25 | param_2 = list_fre[suf] - combinations_fre[i] 26 | param_3 = 1 + combinations_fre[i] - list_fre[pre] -list_fre[suf] 27 | if param_1 ==0: 28 | flag_1 = 0 29 | param_1 = 1 30 | if param_2 ==0: 31 | flag_2 = 0 32 | param_2 = 1 33 | if param_3 ==0: 34 | flag_3 = 0 35 | param_3 = 1 36 | H_pre_suf = -combinations_fre[i] * log(combinations_fre[i]) - flag_1 * param_1 * log(param_1) - flag_2 * param_2 * log(param_2) - flag_3 * param_3 * log(param_3) 37 | result = (H_pre + H_suf - H_pre_suf)/((H_pre**0.5)*(H_suf**0.5)) 38 | return 
result 39 | 40 | #计算欧几里得距离,a,b分别为两个元组 41 | def dist(a, b): 42 | return math.sqrt(math.pow(a[0]-b[0], 2)+math.pow(a[1]-b[1], 2)) 43 | 44 | #dist_min 45 | def dist_min(Ci, Cj): 46 | return min(calculate_correlation(i, j) for i in Ci for j in Cj) 47 | #dist_max 48 | def dist_max(Ci, Cj): 49 | return max(calculate_correlation(i, j) for i in Ci for j in Cj) 50 | #dist_avg 51 | def dist_avg(Ci, Cj): 52 | return sum(calculate_correlation(i, j) for i in Ci for j in Cj)/(len(Ci)*len(Cj)) 53 | 54 | #找到距离最小的下标 55 | def find_Max(M): 56 | max = 0 57 | x = 0; y = 0 58 | for i in range(len(M)): 59 | for j in range(len(M[i])): 60 | if i != j and M[i][j] > max: 61 | max = M[i][j];x = i; y = j 62 | return (x, y, max) 63 | 64 | #算法模型: 65 | def AGNES(dataset, dist, k): 66 | #初始化C和M 67 | C = [];M = [] 68 | for i in dataset: 69 | Ci = [] 70 | Ci.append(i) 71 | C.append(Ci) 72 | for i in C: 73 | Mi = [] 74 | for j in C: 75 | Mi.append(dist(i, j)) 76 | M.append(Mi) 77 | #print('1', M) 78 | q = len(dataset) 79 | #合并更新 80 | while q > k: 81 | x, y, max = find_Max(M) 82 | C[x].extend(C[y]) 83 | C.remove(C[y]) 84 | M = [] 85 | for i in C: 86 | Mi = [] 87 | for j in C: 88 | Mi.append(dist(i, j)) 89 | M.append(Mi) 90 | q -= 1 91 | print(q) 92 | return C, M 93 | 94 | 95 | c, m = AGNES(data, dist_avg, 5) 96 | cluster = [] 97 | for item in c: 98 | clu = [] 99 | for i in item: 100 | clu.append(list_name[i]) 101 | cluster.append(clu) 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | import pickle 3 | 4 | def save_pickle(file_name, input_data): 5 | with open(file_name, 'wb') as f: 6 | pickle.dump(input_data, f) 7 | 8 | def load_pickle(file_name): 9 | with open(file_name, 'rb') as f: 10 | output_data = pickle.load(f) 11 | return output_data --------------------------------------------------------------------------------
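For reference, the "complex-system entropy" similarity computed by calculate_correlation in 2.relatives.py and hier.py is pairwise mutual information: I(A;B) = H(A) + H(B) - H(A,B), where the joint entropy runs over the four cells p(AB), p(A)-p(AB), p(B)-p(AB) and 1-p(A)-p(B)+p(AB), and the flag_*/param_* juggling only serves to treat 0*log(0) as 0. The sketch below restates the same computation in one self-contained function; the helper name mutual_information and the example frequencies are illustrative only (they do not appear in the repository), and note that hier.py additionally divides the result by sqrt(H(A)*H(B)).

```python
from math import log

def mutual_information(p_a, p_b, p_ab):
    """I(A;B) = H(A) + H(B) - H(A,B), with 0*log(0) counted as 0."""
    def entropy(*cells):
        # sum of -p*log(p) over the non-empty cells of a distribution
        return -sum(p * log(p) for p in cells if p > 0)
    h_a = entropy(p_a, 1 - p_a)
    h_b = entropy(p_b, 1 - p_b)
    h_ab = entropy(p_ab, p_a - p_ab, p_b - p_ab, 1 - p_a - p_b + p_ab)
    return h_a + h_b - h_ab

# two herbs with marginal frequencies 0.3 that co-occur with frequency 0.2
print(mutual_information(0.3, 0.3, 0.2))
```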