├── .idea ├── dbn_pytorch.iml ├── encodings.xml ├── misc.xml ├── modules.xml ├── other.xml ├── vcs.xml └── workspace.xml ├── DBN.py ├── RBM.py ├── README.md ├── __pycache__ ├── DBN.cpython-36.pyc └── RBM.cpython-36.pyc ├── describledata ├── corrected ├── kddcup.data_10_percent_corrected ├── ~$说明.docx └── 说明.docx ├── exercise.ipynb ├── null ├── process.py └── smote.py /.idea/dbn_pytorch.iml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 12 | -------------------------------------------------------------------------------- /.idea/encodings.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | -------------------------------------------------------------------------------- /.idea/misc.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | 7 | -------------------------------------------------------------------------------- /.idea/modules.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | -------------------------------------------------------------------------------- /.idea/other.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 6 | -------------------------------------------------------------------------------- /.idea/vcs.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | -------------------------------------------------------------------------------- /.idea/workspace.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | 13 | 14 | 15 | 16 | 17 | 27 | 28 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 189 | 190 | 191 | 192 | 193 | 209 | 210 | 211 | 212 | 213 | 214 | 215 | 216 | 1554963002975 217 | 228 | 229 | 
# --------------------------------------------------------------------------------
# /DBN.py:
# --------------------------------------------------------------------------------
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from RBM import RBM
import time


class DBN(nn.Module):
    """Stack of RBMs followed by a small fully connected classifier (``BPNN``).

    The RBM layers are pre-trained one at a time with :meth:`train_static`
    (or individually with :meth:`train_ith`); :meth:`trainBP` then fits the
    classifier on the hidden probabilities produced by the top RBM.
    """

    def __init__(self,
                 visible_units=41,
                 hidden_units=(11, 6, 11),
                 k=5,
                 learning_rate=1e-3,
                 momentum_coefficient=0.9,
                 weight_decay=1e-4,
                 use_gpu=False,
                 _activation='sigmoid',
                 n_classes=5):
        """Build the RBM stack and the classifier head.

        Args:
            visible_units: Number of input features (width of the first RBM).
            hidden_units: Hidden size of each RBM, bottom to top. Must not be
                empty.
            k: Number of Gibbs sampling steps used by every RBM.
            learning_rate: Learning rate passed to every RBM.
            momentum_coefficient: Momentum coefficient passed to every RBM.
            weight_decay: Weight decay passed to every RBM.
            use_gpu: When true the RBMs are moved to ``cuda:0``.
            _activation: Activation name forwarded to every RBM.
            n_classes: Width of the classifier output layer. Defaults to 5,
                the value that used to be hard-coded.

        Raises:
            ValueError: If ``hidden_units`` is empty.
        """
        super(DBN, self).__init__()
        hidden_units = list(hidden_units)
        if not hidden_units:
            raise ValueError("hidden_units must contain at least one layer size")

        device = torch.device("cuda:0" if use_gpu else "cpu")
        self.n_layers = len(hidden_units)  # number of hidden (RBM) layers
        # ModuleList rather than a plain list so that .to() and state_dict()
        # on the DBN also reach the RBM layers.
        self.rbm_layers = nn.ModuleList()
        self.rbm_nodes = []

        # Each RBM takes the previous layer's hidden size as its visible size.
        input_size = visible_units
        for n_hidden in hidden_units:
            rbm = RBM(visible_units=input_size,
                      hidden_units=n_hidden,
                      k=k,
                      learning_rate=learning_rate,
                      momentum_coefficient=momentum_coefficient,
                      weight_decay=weight_decay,
                      use_gpu=use_gpu,
                      _activation=_activation).to(device)
            self.rbm_layers.append(rbm)
            input_size = n_hidden

        # Parameter copies/views of the RBM tensors. The ``*_rec`` entries are
        # clones; the ``*_gen`` and ``*_mem`` entries wrap the RBM tensors'
        # data directly (no clone), exactly as before.
        lower_layers = list(self.rbm_layers)[:-1]
        top_layer = self.rbm_layers[-1]
        self.W_rec = [nn.Parameter(rbm.weight.data.clone()) for rbm in lower_layers]
        self.W_gen = [nn.Parameter(rbm.weight.data) for rbm in lower_layers]
        self.bias_rec = [nn.Parameter(rbm.c.data.clone()) for rbm in lower_layers]
        self.bias_gen = [nn.Parameter(rbm.b.data) for rbm in lower_layers]
        self.W_mem = nn.Parameter(top_layer.weight.data)
        self.v_bias_mem = nn.Parameter(top_layer.b.data)
        self.h_bias_mem = nn.Parameter(top_layer.c.data)
        for i in range(self.n_layers - 1):
            self.register_parameter('W_rec%i' % i, self.W_rec[i])
            self.register_parameter('W_gen%i' % i, self.W_gen[i])
            self.register_parameter('bias_rec%i' % i, self.bias_rec[i])
            self.register_parameter('bias_gen%i' % i, self.bias_gen[i])

        # Classifier used for prediction and back-propagation fine-tuning.
        # Sized from the top RBM so it matches any ``hidden_units`` setting
        # (the default configuration still yields 11 -> 11 -> 5).
        top_units = hidden_units[-1]
        self.BPNN = nn.Sequential(
            torch.nn.Linear(top_units, top_units),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.5),
            torch.nn.Linear(top_units, n_classes),
        )

    def forward(self, input_data):
        """Feed ``input_data`` through every RBM and then the classifier.

        Each layer receives the *sampled* hidden state of the layer below;
        the classifier receives the hidden probabilities of the top RBM.

        Returns:
            The classifier output, one row per input sample.
        """
        v = input_data
        p_v = None
        for rbm in self.rbm_layers:
            # Flatten to (n_samples, n_features); .float() keeps the device,
            # unlike .type(torch.FloatTensor) which forces a CPU tensor.
            v = v.view((v.shape[0], -1)).float()
            p_v, v = rbm.forward(v)
        return self.BPNN(p_v)

    def train_static(self, train_data, train_labels, num_epochs, batch_size):
        """Greedy layer-wise RBM training, lower layers held fixed.

        Every RBM is trained on the hidden probabilities produced by the
        already-trained layer below it.

        Args:
            train_data: Training samples, first dimension is the sample axis.
            train_labels: Labels; only carried along in the dataset, the RBMs
                do not use them.
            num_epochs: Epochs per RBM.
            batch_size: Mini-batch size used for every RBM.
        """
        tmp = train_data

        for i, rbm in enumerate(self.rbm_layers):
            print("-" * 20)
            print("Training the {} st rbm layer".format(i + 1))

            tensor_x = tmp.view((tmp.shape[0], -1)).float()
            tensor_y = train_labels.float()
            _dataset = torch.utils.data.TensorDataset(tensor_x, tensor_y)
            # batch_size must be given here: RBM.trains uses a ready-made
            # DataLoader as-is, and the DataLoader default is 1.
            _dataloader = torch.utils.data.DataLoader(_dataset, batch_size=batch_size)

            rbm.trains(_dataloader, num_epochs, batch_size)

            # Hidden probabilities of this layer are the next layer's input.
            v, _ = rbm.forward(tensor_x)
            tmp = v
        return

    def train_ith(self, train_data, num_epochs, batch_size, ith_layer, rbm_layers=None):
        """Train only the RBM at index ``ith_layer`` (useful for tuning).

        ``train_data`` is first propagated through the layers below it.

        Args:
            train_data: Input samples for the bottom layer.
            num_epochs: Epochs to train the selected RBM.
            batch_size: Mini-batch size.
            ith_layer: Zero-based index of the RBM to train. Out-of-range
                values make the call a no-op.
            rbm_layers: Only its length is used, for the bounds check.
                Defaults to this network's own layers.
        """
        if rbm_layers is None:
            rbm_layers = self.rbm_layers
        if not 0 <= ith_layer < len(rbm_layers):
            return

        v = train_data.view((train_data.shape[0], -1)).float()
        for ith in range(ith_layer):
            v, out_ = self.rbm_layers[ith].forward(v)

        self.rbm_layers[ith_layer].trains(v, num_epochs, batch_size)
        return

    def trainBP(self, trainloader, num_epochs=5, lr=0.005, momentum=0.7):
        """Fit the classifier head with SGD and cross-entropy loss.

        Only ``self.BPNN`` is handed to the optimizer.

        Args:
            trainloader: Iterable yielding ``(inputs, class_index_labels)``.
            num_epochs: Number of passes over ``trainloader``.
            lr: SGD learning rate.
            momentum: SGD momentum.
        """
        optimizer = torch.optim.SGD(self.BPNN.parameters(), lr=lr, momentum=momentum)
        loss_func = torch.nn.CrossEntropyLoss()
        for epoch in range(num_epochs):
            for step, (x, y) in enumerate(trainloader):
                # forward() returns the logits tensor itself; indexing it
                # with [1] would select a single sample row.
                out = self.forward(x)
                loss = loss_func(out, y.to(out.device))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                if step % 10 == 0:
                    print('Epoch: ', epoch, 'step:', step, '| train loss: %.4f' % loss.item())


# --------------------------------------------------------------------------------
# /RBM.py:
# --------------------------------------------------------------------------------
import torch
import torchvision

import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

# Kept for backward compatibility with code that imports it. Parameter updates
# are scaled by the size of the batch actually processed, not by this value.
BATCH_SIZE = 200


class RBM(nn.Module):
    """Restricted Boltzmann machine trained with k-step contrastive divergence.

    Weights, biases and momentum terms are registered buffers: they follow
    ``.to(device)`` and appear in ``state_dict()``, but are updated manually
    in :meth:`contrastive_divergence` rather than by an optimizer.
    """

    def __init__(self, visible_units=26,
                 hidden_units=9,
                 k=2,
                 learning_rate=1e-3,
                 momentum_coefficient=0.5,
                 weight_decay=1e-4,
                 use_gpu=False,
                 _activation='sigmoid'):
        """Store the hyper-parameters and randomly initialise the tensors.

        Args:
            visible_units: Size of the visible layer.
            hidden_units: Size of the hidden layer.
            k: Number of Gibbs sampling steps per update.
            learning_rate: Step size of the manual parameter update.
            momentum_coefficient: Decay applied to the momentum buffers.
            weight_decay: Multiplicative L2 decay applied to ``weight``.
            use_gpu: Stored on the instance; moving tensors is done with
                ``.to(device)``.
            _activation: ``'sigmoid'``, ``'tanh'`` or ``'relu'``.
        """
        super(RBM, self).__init__()
        # All of these are tunable hyper-parameters.
        self.visible_units = visible_units
        self.hidden_units = hidden_units
        self.k = k
        self.learning_rate = learning_rate
        self.momentum_coefficient = momentum_coefficient
        self.weight_decay = weight_decay
        self.use_gpu = use_gpu
        self._activation = _activation

        # Random initialisation, drawn in the same order as before
        # (weight, c, b) so seeded runs produce the same starting point.
        self.register_buffer(
            'weight',
            torch.randn(self.visible_units, self.hidden_units) / math.sqrt(self.visible_units))
        self.register_buffer(
            'c', torch.randn(self.hidden_units) / math.sqrt(self.hidden_units))  # hidden bias
        self.register_buffer(
            'b', torch.randn(self.visible_units) / math.sqrt(self.visible_units))  # visible bias

        self.register_buffer('W_momentum', torch.zeros(self.visible_units, self.hidden_units))
        self.register_buffer('b_momentum', torch.zeros(self.visible_units))
        self.register_buffer('c_momentum', torch.zeros(self.hidden_units))

    def activation(self, X):
        """Apply the configured activation function to ``X``.

        Raises:
            ValueError: If ``_activation`` is not a supported name.
        """
        if self._activation == 'sigmoid':
            return torch.sigmoid(X)
        elif self._activation == 'tanh':
            return torch.tanh(X)
        elif self._activation == 'relu':
            return nn.functional.relu(X)
        else:
            raise ValueError("Invalid Activation Function")

    def to_hidden(self, X):
        """Compute the hidden layer from visible values.

        Args:
            X: Tensor of shape ``(n_samples, visible_units)``.

        Returns:
            ``(hidden, sample_h)``: the activated ``X @ weight + c`` and a
            Bernoulli sample drawn from it.
        """
        hidden = torch.matmul(X, self.weight)
        hidden = torch.add(hidden, self.c)
        hidden = self.activation(hidden)

        sample_h = self.sampling(hidden)
        return hidden, sample_h

    def to_visible(self, X):
        """Reconstruct the visible layer from hidden values.

        Args:
            X: Tensor of shape ``(n_samples, hidden_units)``.

        Returns:
            ``(X_dash, sample_X_dash)``: the activated ``X @ weight.T + b``
            and a Bernoulli sample drawn from it.
        """
        X_dash = torch.matmul(X, self.weight.transpose(0, 1))
        X_dash = torch.add(X_dash, self.b)
        X_dash = self.activation(X_dash)

        sample_X_dash = self.sampling(X_dash)
        return X_dash, sample_X_dash

    def sampling(self, s):
        """Draw a Bernoulli sample using ``s`` as the probabilities."""
        distribution = torch.distributions.Bernoulli(s)
        return distribution.sample()

    def reconstruction_error(self, data):
        """Return the reconstruction error of ``data`` without updating."""
        return self.contrastive_divergence(data, False)

    def contrastive_divergence(self, input_data, training=True):
        """Run k-step contrastive divergence on one batch.

        Args:
            input_data: Tensor of shape ``(batch, visible_units)``.
            training: When true, update the momentum buffers, ``weight``,
                ``b`` and ``c`` in place.

        Returns:
            Mean over the batch of the summed squared difference between the
            input and the reconstructed visible probabilities.
        """
        input_data = input_data.to(self.weight.device)

        # Positive phase.
        positive_hidden_probabilities, positive_hidden_act = self.to_hidden(input_data)
        positive_associations = torch.matmul(input_data.t(), positive_hidden_act)

        # Negative phase: alternate visible/hidden for k Gibbs steps.
        hidden_activations = positive_hidden_act
        for _step in range(self.k):
            visible_p, _ = self.to_visible(hidden_activations)
            hidden_probabilities, hidden_activations = self.to_hidden(visible_p)

        negative_visible_probabilities = visible_p
        negative_hidden_probabilities = hidden_probabilities
        negative_associations = torch.matmul(negative_visible_probabilities.t(),
                                             negative_hidden_probabilities)

        if training:
            self.W_momentum *= self.momentum_coefficient
            self.W_momentum += (positive_associations - negative_associations)

            self.b_momentum *= self.momentum_coefficient
            self.b_momentum += torch.sum(input_data - negative_visible_probabilities, dim=0)

            self.c_momentum *= self.momentum_coefficient
            self.c_momentum += torch.sum(
                positive_hidden_probabilities - negative_hidden_probabilities, dim=0)

            # Average over the batch that was actually processed; the module
            # constant BATCH_SIZE does not track the caller's batch size.
            batch_size = input_data.size(0)

            self.weight += self.W_momentum * self.learning_rate / batch_size
            self.b += self.b_momentum * self.learning_rate / batch_size
            self.c += self.c_momentum * self.learning_rate / batch_size

            self.weight -= self.weight * self.weight_decay  # L2 weight decay

        # Reconstruction error.
        error = torch.mean(torch.sum((input_data - negative_visible_probabilities) ** 2, 1))
        return error

    def forward(self, input_data):
        """Return ``to_hidden`` of ``input_data`` (probabilities, sample)."""
        return self.to_hidden(input_data.to(self.weight.device))

    def step(self, input_data):
        """Run one training update on a batch and return its error."""
        return self.contrastive_divergence(input_data, True)

    def trains(self, train_data, num_epochs=50, batch_size=20):
        """Train on ``train_data`` and print the summed error per epoch.

        Args:
            train_data: A ``DataLoader`` (used as-is, so its own batch size
                applies) or anything ``DataLoader`` accepts as a dataset.
                Batches may be bare tensors or ``(tensor, label, ...)``
                sequences; only the first element is used.
            num_epochs: Number of passes over the data.
            batch_size: Batch size used when a loader has to be created.
        """
        if isinstance(train_data, torch.utils.data.DataLoader):
            train_loader = train_data
        else:
            train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)

        for epochs in range(num_epochs):
            epoch_err = 0.0

            for batch in train_loader:
                if isinstance(batch, (list, tuple)):
                    batch = batch[0]  # drop labels, the RBM is unsupervised
                # Follow wherever the parameters live instead of calling
                # .cuda() unconditionally.
                batch = batch.to(self.weight.device)
                epoch_err += float(self.step(batch))

            print("Epoch Error(epoch:%d) : %.4f" % (epochs, epoch_err))


# NOTE: the README.md text that shared this physical line of the dump is
# preserved verbatim below; it continues on the following line of the dump.
# --------------------------------------------------------------------------------
# /README.md:
# --------------------------------------------------------------------------------
# 1 | # dbn_pytorch
# 2 | 用Pytorch 实现dbn 如果有别的问题,请联系我
# 3 |
# 4 | 要求的包 pytorch, numpy ,tqdm
# 5 |
# 6 |
# 7 | 关于rbm的原理可以看[这里](https://blog.csdn.net/itplus/article/details/19168937)
# 8 |
# 9 | dbn只不过是将rbm进行固定住而已,然后推往下一层
# 10 | 其中 有几点比较重要的他没有仔细提到的就是关于MCMC中的Metropolis–Hastings算法与吉布斯采样
# 11 |
# 12 | 这几点着实也困扰了我好长时间 可以通过下面俩篇博客进行学习
# 13 |
# 14 |
# 15 | [白马博客](https://blog.csdn.net/baimafujinji/article/details/53946367)
# 16 |
# 17 | [刘建平博客](https://www.cnblogs.com/pinard/p/6638955.html)
# 18 |
# 19 | 关于数据集的描述[查看](https://blog.csdn.net/com_stu_zhang/article/details/6987632)
# 20 |
# 21 |
# 22 | 关于代码 的运行步骤请看exercise.ipynb,里面有详细的代码
# 23 | 另外data.npz.npy
、adddata、testdata.npz.npy 分别是经过预处理的训练数据 24 | 加强后smote生成了5000个点的数据 testdata.npz.npy 是测试数据 25 | 26 | 数据增强 -> smote..py 27 | 28 | 29 | 在这里我把 最外层的softmax层设置成了5个单元,当进行2分类的时候请修改 30 | 31 | ## Tips 32 | 33 | 另外: 虽然用了gibbs采样代码还是非常慢,请合理设置epoch! 34 | 35 | 刚开始我没用gibbs而是直接进行训练,跑了我整整一天,sad!!!! 36 | 37 | 后来我又用了svd,train了一发,发现同样不行 sad*2!!!!! 38 | 39 | 效果烂的快哭了,请不要和我一样心态失衡! 40 | 41 | 对了,如果exercise.ipynb 你对dbn的参数进行修改,然后运行ipynb的时候,记得重启下。 42 | 43 | 我也不知道为什么ipynb对缓存没有消除,就酱! 44 | 45 | -------------------------------------------------------------------------------- /__pycache__/DBN.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jxinyee/dbn_pytorch/8657c5b1048a1b73f8b2d400b000c09f2b38dfb5/__pycache__/DBN.cpython-36.pyc -------------------------------------------------------------------------------- /__pycache__/RBM.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jxinyee/dbn_pytorch/8657c5b1048a1b73f8b2d400b000c09f2b38dfb5/__pycache__/RBM.cpython-36.pyc -------------------------------------------------------------------------------- /describledata/~$说明.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jxinyee/dbn_pytorch/8657c5b1048a1b73f8b2d400b000c09f2b38dfb5/describledata/~$说明.docx -------------------------------------------------------------------------------- /describledata/说明.docx: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jxinyee/dbn_pytorch/8657c5b1048a1b73f8b2d400b000c09f2b38dfb5/describledata/说明.docx -------------------------------------------------------------------------------- /exercise.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | 
"collapsed": true 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "import pandas as pd\n", 13 | "import codecs\n", 14 | "import matplotlib\n", 15 | "import matplotlib.pyplot as plt\n", 16 | "import random\n", 17 | "import random\n", 18 | "import codecs\n", 19 | "# this input_path is just example\n", 20 | "#inputpath = \"C:/Users/Administrator/Desktop/Smote/describledata/kddcup.data_10_percent_corrected\"\n", 21 | "Nor_list = ['normal']\n", 22 | "Dos_list = ['back', 'land', 'neptune' , 'pod', 'smurf', 'teardrop', 'mailbomb', 'processtable', 'udpstorm', 'apache2', 'worm']\n", 23 | "R2L_list =['guess_passwd', 'ftp_write', 'imap', 'phf', 'multihop', 'warezmaster', 'warezclient', 'xlock','xsnoop', 'snmpguess', 'snmpgetattack', 'httptunnel','spy','named','snmpguess']\n", 24 | "Probe_list = ['satan','ipsweep', 'nmap', 'portsweep', 'mscan', 'saint']\n", 25 | "U2R_list =['buffer_overflow', 'loadmodule', 'rootkit', 'perl', 'sqlattack', 'xterm', 'ps','httptunnel']\n", 26 | "continue_col_index = [0,4,5,7,8,9,10,12,15,16,17,18,19,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40]\n", 27 | "decreate_col_index =[1,2,3,6,11,13,14,20,21]" 28 | ] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": {}, 34 | "outputs": [], 35 | "source": [ 36 | "def gen_lable1(seg):\n", 37 | " if seg in Nor_list:\n", 38 | " return 0\n", 39 | " elif seg in Dos_list:\n", 40 | " return 1\n", 41 | " elif seg in R2L_list:\n", 42 | " return 2\n", 43 | " elif seg in Probe_list:\n", 44 | " return 3\n", 45 | " elif seg in U2R_list: \n", 46 | " return 4\n", 47 | " \n", 48 | " \n", 49 | "def gen_label2(seg):\n", 50 | " if seg in Nor_list:\n", 51 | " return 0\n", 52 | " else:\n", 53 | " return 1\n", 54 | "def replace(columnlist,data):\n", 55 | " #离散值代替\n", 56 | " #listt = []\n", 57 | " for i in columnlist:\n", 58 | " listt = []\n", 59 | " for line,seg in enumerate(data[:,i]):\n", 60 | " if seg in listt:\n", 61 | " data[line][i] = 
listt.index(seg)\n", 62 | " else:\n", 63 | " listt.append(seg)\n", 64 | " data[line][i] = listt.index(seg)\n", 65 | "\n", 66 | "def processdata(file_path):\n", 67 | " tmp = None\n", 68 | " with codecs.open(file_path,'r') as f:\n", 69 | " content = f.readlines()\n", 70 | " datas =[]\n", 71 | " for line in content:\n", 72 | " line = line.strip()\n", 73 | " line = line[:-1].split(',')\n", 74 | " datas.append(line)\n", 75 | " datas = np.array(datas)\n", 76 | " replace(decreate_col_index,datas)\n", 77 | " new_datas =[]\n", 78 | " for index,col in enumerate(datas):\n", 79 | " new_datas.append([float(k) for k in col[:-1]])\n", 80 | " if random.random()<0.00005:\n", 81 | " print(col[-1])\n", 82 | " new_datas[index].append(gen_lable1(col[-1]))\n", 83 | " new_datas[index].append(gen_label2(col[-1]))\n", 84 | " tmp = new_datas\n", 85 | " datas = np.array(tmp).astype('float32')\n", 86 | " \n", 87 | " for j in continue_col_index:\n", 88 | " meanVal=np.mean(datas[:,j])\n", 89 | " stdVal=np.std(datas[:,j])\n", 90 | " datas[:,j]=(datas[:,j]-meanVal)/stdVal\n", 91 | " return datas\n" 92 | ] 93 | }, 94 | { 95 | "cell_type": "code", 96 | "execution_count": 3, 97 | "metadata": { 98 | "collapsed": true 99 | }, 100 | "outputs": [ 101 | { 102 | "name": "stdout", 103 | "output_type": "stream", 104 | "text": [ 105 | "normal\n" 106 | ] 107 | }, 108 | { 109 | "name": "stdout", 110 | "output_type": "stream", 111 | "text": [ 112 | "normal\n" 113 | ] 114 | }, 115 | { 116 | "name": "stdout", 117 | "output_type": "stream", 118 | "text": [ 119 | "normal\nsmurf\n" 120 | ] 121 | }, 122 | { 123 | "name": "stdout", 124 | "output_type": "stream", 125 | "text": [ 126 | "neptune\n" 127 | ] 128 | }, 129 | { 130 | "name": "stdout", 131 | "output_type": "stream", 132 | "text": [ 133 | "neptune\n" 134 | ] 135 | }, 136 | { 137 | "name": "stdout", 138 | "output_type": "stream", 139 | "text": [ 140 | "smurf\nsmurf\nsmurf\n" 141 | ] 142 | }, 143 | { 144 | "name": "stdout", 145 | "output_type": "stream", 146 | 
"text": [ 147 | "smurf\n" 148 | ] 149 | }, 150 | { 151 | "name": "stdout", 152 | "output_type": "stream", 153 | "text": [ 154 | "smurf\n" 155 | ] 156 | }, 157 | { 158 | "name": "stdout", 159 | "output_type": "stream", 160 | "text": [ 161 | "smurf\nsmurf\n" 162 | ] 163 | }, 164 | { 165 | "name": "stdout", 166 | "output_type": "stream", 167 | "text": [ 168 | "smurf\n" 169 | ] 170 | }, 171 | { 172 | "name": "stdout", 173 | "output_type": "stream", 174 | "text": [ 175 | "normal\n" 176 | ] 177 | }, 178 | { 179 | "name": "stdout", 180 | "output_type": "stream", 181 | "text": [ 182 | "neptune\n" 183 | ] 184 | }, 185 | { 186 | "name": "stdout", 187 | "output_type": "stream", 188 | "text": [ 189 | "neptune\n" 190 | ] 191 | }, 192 | { 193 | "name": "stdout", 194 | "output_type": "stream", 195 | "text": [ 196 | "smurf\n" 197 | ] 198 | }, 199 | { 200 | "name": "stdout", 201 | "output_type": "stream", 202 | "text": [ 203 | "smurf\nsmurf\nnormal\n" 204 | ] 205 | }, 206 | { 207 | "name": "stderr", 208 | "output_type": "stream", 209 | "text": [ 210 | "C:\\Users\\Administrator\\AppData\\Local\\Programs\\Python\\Python36\\lib\\site-packages\\ipykernel_launcher.py:55: RuntimeWarning: invalid value encountered in true_divide\n" 211 | ] 212 | } 213 | ], 214 | "source": [ 215 | "data = processdata('./describledata/kddcup.data_10_percent_corrected')" 216 | ] 217 | }, 218 | { 219 | "cell_type": "code", 220 | "execution_count": 4, 221 | "metadata": {}, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "smurf\nsmurf\n" 228 | ] 229 | }, 230 | { 231 | "name": "stdout", 232 | "output_type": "stream", 233 | "text": [ 234 | "normal\n" 235 | ] 236 | }, 237 | { 238 | "name": "stdout", 239 | "output_type": "stream", 240 | "text": [ 241 | "normal\n" 242 | ] 243 | }, 244 | { 245 | "name": "stdout", 246 | "output_type": "stream", 247 | "text": [ 248 | "neptune\nneptune\n" 249 | ] 250 | }, 251 | { 252 | "name": "stdout", 253 | "output_type": 
"stream", 254 | "text": [ 255 | "smurf\n" 256 | ] 257 | }, 258 | { 259 | "name": "stdout", 260 | "output_type": "stream", 261 | "text": [ 262 | "smurf\n" 263 | ] 264 | }, 265 | { 266 | "name": "stdout", 267 | "output_type": "stream", 268 | "text": [ 269 | "smurf\n" 270 | ] 271 | }, 272 | { 273 | "name": "stdout", 274 | "output_type": "stream", 275 | "text": [ 276 | "neptune\nneptune\n" 277 | ] 278 | }, 279 | { 280 | "name": "stderr", 281 | "output_type": "stream", 282 | "text": [ 283 | "C:\\Users\\Administrator\\AppData\\Local\\Programs\\Python\\Python36\\lib\\site-packages\\ipykernel_launcher.py:55: RuntimeWarning: invalid value encountered in true_divide\n" 284 | ] 285 | } 286 | ], 287 | "source": [ 288 | "#\n", 289 | "label_dict ={}\n", 290 | "for i in data[:,-2]:\n", 291 | " if i in label_dict.keys():\n", 292 | " continue\n", 293 | " else:\n", 294 | " label_dict[i] = len(label_dict)\n", 295 | "test_data = processdata('./describledata/corrected')" 296 | ] 297 | }, 298 | { 299 | "cell_type": "code", 300 | "execution_count": 5, 301 | "metadata": {}, 302 | "outputs": [], 303 | "source": [ 304 | "def fill_nan(data):\n", 305 | " #m,n = data.shape\n", 306 | " choice_list =[]\n", 307 | " matrix = np.isnan(data)\n", 308 | " for m in range(len(data)):\n", 309 | " for n in range(len(data[0])):\n", 310 | " if matrix[m,n] ==True:\n", 311 | " data[m][n] = 0" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 6, 317 | "metadata": { 318 | "collapsed": true 319 | }, 320 | "outputs": [], 321 | "source": [ 322 | "fill_nan(data)\n", 323 | "fill_nan(test_data)" 324 | ] 325 | }, 326 | { 327 | "cell_type": "code", 328 | "execution_count": 10, 329 | "metadata": {}, 330 | "outputs": [], 331 | "source": [ 332 | "import pandas as pd\n", 333 | "import numpy as np\n", 334 | "import random\n", 335 | "from sklearn.neighbors import NearestNeighbors\n", 336 | "import math\n", 337 | "from random import randint\n", 338 | "import matplotlib.pyplot as plt\n", 339 | "from 
sklearn.decomposition import TruncatedSVD\n", 340 | "class Smote():\n", 341 | " def __init__(self,distance,range1,range2):\n", 342 | " self.synthetic_arr = []\n", 343 | " self.newindex = 0\n", 344 | " self.distance_measure = distance\n", 345 | " self.range1 =range1\n", 346 | " self.range2 = range2\n", 347 | "\n", 348 | " def Populate(self, N, i, indices, min_samples, k):\n", 349 | " \"\"\"\n", 350 | " 此代码主要作用是生成增强数组\n", 351 | "\n", 352 | " Returns:返回增强后的数组\n", 353 | " \"\"\"\n", 354 | "\n", 355 | " choice_list =[]\n", 356 | " def choice(data):\n", 357 | " p = []\n", 358 | " wc = {}\n", 359 | " for num in data:\n", 360 | " # print(num)\n", 361 | " wc[num] = wc.setdefault(num, 0) + 1\n", 362 | " for key in wc.keys():\n", 363 | " p.append(wc[key] / len(data))\n", 364 | " # print(p)\n", 365 | " keylist = np.array([key for key in wc.keys()])\n", 366 | " # print(wc[0])\n", 367 | " return keylist,p\n", 368 | " for index in self.range1:\n", 369 | " choice_list.append(choice(min_samples[:,index]))\n", 370 | "\n", 371 | " while N != 0:\n", 372 | " arr = np.zeros(len(min_samples[0]))\n", 373 | " arr[-2] = min_samples[i][-2]\n", 374 | " arr[-1] = min_samples[i][-1]\n", 375 | " nn = randint(0, k - 2)\n", 376 | " # 统计离散型变量\n", 377 | " for rowindex,index in enumerate(self.range1):\n", 378 | " arr[index] = np.random.choice(choice_list[rowindex][0],size=1,p=choice_list[rowindex][1])\n", 379 | " #for attr in features2:\n", 380 | " for attr in self.range2:\n", 381 | " min_samples[i][attr] = float(min_samples[i][attr])\n", 382 | " min_samples[indices[nn]][attr] = float(min_samples[indices[nn]][attr])\n", 383 | " try:\n", 384 | " diff = float(min_samples[indices[nn]][attr]) - float(min_samples[i][attr])\n", 385 | " except:\n", 386 | " print('这是第%d列'%attr,min_samples[indices[nn]][attr],min_samples[i][attr])\n", 387 | " gap = random.uniform(0, 1)\n", 388 | "\n", 389 | " arr[attr] = float(min_samples[i][attr]) + gap * diff\n", 390 | " #print(arr)\n", 391 | " 
self.synthetic_arr.append(arr)\n", 392 | " self.newindex = self.newindex + 1\n", 393 | " N = N - 1\n", 394 | "\n", 395 | " def k_neighbors(self, euclid_distance, k):\n", 396 | " nearest_idx_npy = np.empty([euclid_distance.shape[0], euclid_distance.shape[0]], dtype=np.int64)\n", 397 | "\n", 398 | " for i in range(len(euclid_distance)):\n", 399 | " idx = np.argsort(euclid_distance[i])\n", 400 | " nearest_idx_npy[i] = idx\n", 401 | " idx = 0\n", 402 | "\n", 403 | " return nearest_idx_npy[:, 1:k]\n", 404 | "\n", 405 | " def find_k(self, X, k):\n", 406 | "\n", 407 | " \"\"\"\n", 408 | " Finds k nearest neighbors using euclidian distance\n", 409 | "\n", 410 | " Returns: The k nearest neighbor\n", 411 | " \"\"\"\n", 412 | "\n", 413 | " euclid_distance = np.empty([X.shape[0], X.shape[0]], dtype=np.float32)\n", 414 | "\n", 415 | " for i in range(len(X)):\n", 416 | " dist_arr = []\n", 417 | " for j in range(len(X)):\n", 418 | " dist_arr.append(math.sqrt(sum((X[j] - X[i]) ** 2)))\n", 419 | " dist_arr = np.asarray(dist_arr, dtype=np.float32)\n", 420 | " euclid_distance[i] = dist_arr\n", 421 | "\n", 422 | " return self.k_neighbors(euclid_distance, k)\n", 423 | "\n", 424 | " def generate_synthetic_points(self, min_samples, N, k):\n", 425 | "\n", 426 | " \"\"\"\n", 427 | "\n", 428 | " Parameters\n", 429 | " ----------\n", 430 | " min_samples : 要增强的数据\n", 431 | " N :要额外生成的负样本的数目\n", 432 | " k : int. Number of nearest neighbours.\n", 433 | " Returns\n", 434 | " -------\n", 435 | " S : Synthetic samples. 
array,\n", 436 | " shape = [(N/100) * n_minority_samples, n_features].\n", 437 | " \"\"\"\n", 438 | "\n", 439 | " if N < 1:\n", 440 | " raise ValueError(\"Value of N cannot be less than 100%\")\n", 441 | "\n", 442 | " if self.distance_measure not in ('euclidian', 'ball_tree'):\n", 443 | " raise ValueError(\"Invalid Distance Measure.You can use only Euclidian or ball_tree\")\n", 444 | "\n", 445 | " if k > min_samples.shape[0]:\n", 446 | " raise ValueError(\"Size of k cannot exceed the number of samples.\")\n", 447 | "\n", 448 | " T = min_samples.shape[0]\n", 449 | "\n", 450 | " if self.distance_measure == 'euclidian':\n", 451 | " indices = self.find_k(min_samples, k)\n", 452 | "\n", 453 | " elif self.distance_measure == 'ball_tree':\n", 454 | " nb = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(min_samples)\n", 455 | " distance, indices = nb.kneighbors(min_samples)\n", 456 | " indices = indices[:, 1:]\n", 457 | "\n", 458 | " for i in range(indices.shape[0]):\n", 459 | " self.Populate(N, i, indices[i], min_samples, k)\n", 460 | "\n", 461 | " return np.asarray(self.synthetic_arr)" 462 | ] 463 | }, 464 | { 465 | "cell_type": "code", 466 | "execution_count": 14, 467 | "metadata": { 468 | "collapsed": true 469 | }, 470 | "outputs": [], 471 | "source": [ 472 | "rowindex = []\n", 473 | "for i, line in enumerate(data):\n", 474 | " if line[-2] ==4.0:\n", 475 | " rowindex.append(i)\n", 476 | "range1 = [1,2,3,6,11,13,14,20,21]\n", 477 | "range2 = [0,4,5,7,8,9,10,12,15,16,17,18,19,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40]\n", 478 | "minsamples = data[rowindex]\n", 479 | "smote = Smote(distance='ball_tree',range1=range1,range2=range2)\n", 480 | "smotedata = smote.generate_synthetic_points(min_samples=minsamples,N=100,k=9)\n" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": 15, 486 | "metadata": { 487 | "collapsed": true 488 | }, 489 | "outputs": [ 490 | { 491 | "data": { 492 | "text/plain": [ 493 | "(5200, 43)" 494 | ] 495 
| }, 496 | "execution_count": 15, 497 | "metadata": {}, 498 | "output_type": "execute_result" 499 | } 500 | ], 501 | "source": [ 502 | "smotedata.shape" 503 | ] 504 | }, 505 | { 506 | "cell_type": "code", 507 | "execution_count": 16, 508 | "metadata": {}, 509 | "outputs": [], 510 | "source": [ 511 | "adddata = np.concatenate((smotedata,data),axis=0)" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": 17, 517 | "metadata": {}, 518 | "outputs": [ 519 | { 520 | "data": { 521 | "text/plain": [ 522 | "(499221, 43)" 523 | ] 524 | }, 525 | "execution_count": 17, 526 | "metadata": {}, 527 | "output_type": "execute_result" 528 | } 529 | ], 530 | "source": [ 531 | "adddata.shape" 532 | ] 533 | }, 534 | { 535 | "cell_type": "code", 536 | "execution_count": 18, 537 | "metadata": {}, 538 | "outputs": [], 539 | "source": [ 540 | "fill_nan(test_data)" 541 | ] 542 | }, 543 | { 544 | "cell_type": "code", 545 | "execution_count": 2, 546 | "metadata": {}, 547 | "outputs": [], 548 | "source": [ 549 | "# 保存数据\n", 550 | "\"\"\"\n", 551 | "np.save('data.npz',data)\n", 552 | "np.save('adddata.npz',adddata)\n", 553 | "np.save('testdata.npz',test_data)\n", 554 | "\"\"\"\n", 555 | "import numpy as np\n", 556 | "import pandas as pd\n", 557 | "import codecs\n", 558 | "import matplotlib\n", 559 | "import matplotlib.pyplot as plt\n", 560 | "import random\n", 561 | "import random\n", 562 | "import codecs\n", 563 | "data = np.load('data.npz.npy')\n", 564 | "adddata =np.load('adddata.npz.npy')\n", 565 | "test_data =np.load('testdata.npz.npy')\n" 566 | ] 567 | }, 568 | { 569 | "cell_type": "code", 570 | "execution_count": 3, 571 | "metadata": {}, 572 | "outputs": [], 573 | "source": [ 574 | "def gen_train_test(dat): #分割训练集和测试集\n", 575 | " index = int(len(dat)*0.8)\n", 576 | " np.random.shuffle(dat)\n", 577 | " traindat=torch.from_numpy(dat[:index,:-2]).float()\n", 578 | " trainlabel=torch.from_numpy(dat[:index,-2]).long()\n", 579 | " validdat = 
torch.from_numpy(dat[index:,:-2]).float()\n", 580 | " validlabel = torch.from_numpy(dat[index:,-2]).long()\n", 581 | " # print(dat[0])\n", 582 | " return traindat,trainlabel,validdat,validlabel" 583 | ] 584 | }, 585 | { 586 | "cell_type": "code", 587 | "execution_count": 4, 588 | "metadata": {}, 589 | "outputs": [], 590 | "source": [ 591 | "import tqdm\n", 592 | "import torch\n", 593 | "from DBN import DBN\n", 594 | "import torch.utils.data as Data\n", 595 | "def train_batch(traind,trainl,SIZE=500,SHUFFLE=True): #分批处理\n", 596 | " trainset=Data.TensorDataset(traind,trainl)\n", 597 | " trainloader=Data.DataLoader(\n", 598 | " dataset=trainset,\n", 599 | " batch_size=SIZE,\n", 600 | " shuffle=SHUFFLE)\n", 601 | " return trainloader\n" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": 5, 607 | "metadata": {}, 608 | "outputs": [], 609 | "source": [ 610 | "traindat,trainlabel, validdat, validlabel = gen_train_test(data)" 611 | ] 612 | }, 613 | { 614 | "cell_type": "code", 615 | "execution_count": 6, 616 | "metadata": {}, 617 | "outputs": [], 618 | "source": [ 619 | "trainloader = train_batch(traindat,trainlabel)\n", 620 | "from torch.autograd.variable import Variable" 621 | ] 622 | }, 623 | { 624 | "cell_type": "code", 625 | "execution_count": 7, 626 | "metadata": {}, 627 | "outputs": [], 628 | "source": [ 629 | "device = torch.device('cuda:0')\n" 630 | ] 631 | }, 632 | { 633 | "cell_type": "code", 634 | "execution_count": 9, 635 | "metadata": { 636 | "collapsed": true 637 | }, 638 | "outputs": [ 639 | { 640 | "name": "stdout", 641 | "output_type": "stream", 642 | "text": [ 643 | "--------------------" 644 | ] 645 | }, 646 | { 647 | "name": "stdout", 648 | "output_type": "stream", 649 | "text": [ 650 | "\n" 651 | ] 652 | }, 653 | { 654 | "name": "stdout", 655 | "output_type": "stream", 656 | "text": [ 657 | "Training the 1 st rbm layer" 658 | ] 659 | }, 660 | { 661 | "name": "stdout", 662 | "output_type": "stream", 663 | "text": [ 664 | "\n" 
665 | ] 666 | }, 667 | { 668 | "name": "stderr", 669 | "output_type": "stream", 670 | "text": [ 671 | "C:\\Users\\Administrator\\AppData\\Local\\Programs\\Python\\Python36\\lib\\site-packages\\torch\\nn\\functional.py:1332: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n" 672 | ] 673 | }, 674 | { 675 | "name": "stdout", 676 | "output_type": "stream", 677 | "text": [ 678 | "Epoch Error(epoch:0) : 45406564.0000" 679 | ] 680 | }, 681 | { 682 | "name": "stdout", 683 | "output_type": "stream", 684 | "text": [ 685 | "\n" 686 | ] 687 | } 688 | ], 689 | "source": [ 690 | "import time\n", 691 | "\n", 692 | "start_time = time.time()\n", 693 | "dbn=DBN(visible_units=len(traindat[0]))\n", 694 | "dbn.to(device)\n", 695 | "dbn.train()\n", 696 | "dbn.train_static(train_data=traindat,train_labels=trainlabel,num_epochs=3,batch_size=20)\n", 697 | "\n", 698 | "optimizer = torch.optim.SGD(dbn.parameters(), lr=0.001, momentum=0.9)\n", 699 | "loss_func = torch.nn.CrossEntropyLoss()\n", 700 | "\n", 701 | "dbn.trainBP(trainloader)\n", 702 | "\n", 703 | "for epoch in range(5):\n", 704 | " for step,(x,y) in tqdm(enumerate(trainloader)):\n", 705 | " # print(x.data.numpy(),y.data.numpy())\n", 706 | "\n", 707 | " b_x=Variable(x)\n", 708 | " b_y=Variable(y)\n", 709 | "\n", 710 | " output=dbn(b_x)\n", 711 | " # print(output)\n", 712 | " # print(prediction);print(output);print(b_y)\n", 713 | "\n", 714 | " loss=loss_func(output,b_y)\n", 715 | " optimizer.zero_grad()\n", 716 | " loss.backward()\n", 717 | " optimizer.step()\n", 718 | "\n", 719 | " if step%10==0:\n", 720 | " print('Epoch: ', epoch, 'step:',step,'| train loss: %.4f' % loss.data.numpy())\n", 721 | "duration=time.time()-start_time\n", 722 | "\n", 723 | "dbn.eval()\n", 724 | "test_x = Variable(validdat);test_y = Variable(validlabel)\n", 725 | "test_out = dbn(test_x)\n", 726 | "# print(test_out)\n", 727 | "test_pred = 
torch.max(test_out, 1)[1]\n", 728 | "pre_val = test_pred.data.squeeze().numpy()\n", 729 | "y_val = test_y.data.squeeze().numpy()\n", 730 | "print('prediciton:',pre_val);print('true value:',y_val)\n", 731 | "accuracy = float((pre_val == y_val).astype(int).sum()) / float(test_y.size(0))\n", 732 | "print('test accuracy: %.2f' % accuracy,'duration:%.4f' % duration)\n" 733 | ] 734 | }, 735 | { 736 | "cell_type": "code", 737 | "execution_count": 9, 738 | "metadata": { 739 | "collapsed": true 740 | }, 741 | "outputs": [ 742 | { 743 | "name": "stdout", 744 | "output_type": "stream", 745 | "text": [ 746 | "" 747 | ] 748 | }, 749 | { 750 | "name": "stdout", 751 | "output_type": "stream", 752 | "text": [ 753 | " " 754 | ] 755 | }, 756 | { 757 | "name": "stdout", 758 | "output_type": "stream", 759 | "text": [ 760 | "" 761 | ] 762 | }, 763 | { 764 | "name": "stdout", 765 | "output_type": "stream", 766 | "text": [ 767 | " " 768 | ] 769 | }, 770 | { 771 | "name": "stdout", 772 | "output_type": "stream", 773 | "text": [ 774 | "" 775 | ] 776 | }, 777 | { 778 | "name": "stdout", 779 | "output_type": "stream", 780 | "text": [ 781 | " " 782 | ] 783 | }, 784 | { 785 | "name": "stdout", 786 | "output_type": "stream", 787 | "text": [ 788 | "" 789 | ] 790 | }, 791 | { 792 | "name": "stdout", 793 | "output_type": "stream", 794 | "text": [ 795 | " " 796 | ] 797 | }, 798 | { 799 | "name": "stdout", 800 | "output_type": "stream", 801 | "text": [ 802 | "" 803 | ] 804 | }, 805 | { 806 | "name": "stdout", 807 | "output_type": "stream", 808 | "text": [ 809 | "\n" 810 | ] 811 | }, 812 | { 813 | "name": "stdout", 814 | "output_type": "stream", 815 | "text": [ 816 | "--------------------" 817 | ] 818 | }, 819 | { 820 | "name": "stdout", 821 | "output_type": "stream", 822 | "text": [ 823 | "\n" 824 | ] 825 | }, 826 | { 827 | "name": "stdout", 828 | "output_type": "stream", 829 | "text": [ 830 | "Training the 1 st rbm layer" 831 | ] 832 | }, 833 | { 834 | "name": "stdout", 835 | "output_type": 
"stream", 836 | "text": [ 837 | "\n" 838 | ] 839 | }, 840 | { 841 | "name": "stderr", 842 | "output_type": "stream", 843 | "text": [ 844 | "C:\\Users\\Administrator\\AppData\\Local\\Programs\\Python\\Python36\\lib\\site-packages\\torch\\nn\\functional.py:1332: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n" 845 | ] 846 | }, 847 | { 848 | "ename": "KeyboardInterrupt", 849 | "evalue": "", 850 | "traceback": [ 851 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", 852 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", 853 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_and_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraindat\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtrainlabel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvaliddat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalidlabel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtrainloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", 854 | "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain_and_test\u001b[0;34m(traind, trainl, testdat, testlabel, loader)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mdbn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mdbn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m 
\u001b[0mdbn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_static\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtraind\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtrain_labels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrainl\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m50\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSGD\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdbn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.001\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmomentum\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.9\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 855 | "\u001b[0;32mC:\\Users\\Administrator\\Documents\\GitHub\\dbn_pytorch\\DBN.py\u001b[0m in \u001b[0;36mtrain_static\u001b[0;34m(self, train_data, train_labels, num_epochs, batch_size)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0m_dataloader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataLoader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_dataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrbm_layers\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrains\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_dataloader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_dataloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# print(train_data.shape)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 856 | "\u001b[0;32mC:\\Users\\Administrator\\Documents\\GitHub\\dbn_pytorch\\RBM.py\u001b[0m in \u001b[0;36mtrains\u001b[0;34m(self, train_data, num_epochs, batch_size)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0;32mif\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_gpu\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0mbatch_err\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0mepoch_err\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mbatch_err\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 857 | "\u001b[0;32mC:\\Users\\Administrator\\Documents\\GitHub\\dbn_pytorch\\RBM.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, input_data)\u001b[0m\n\u001b[1;32m 156\u001b[0m 
'''\n\u001b[1;32m 157\u001b[0m \u001b[0;31m# print('w:',self.weight);print('b:',self.b);print('c:',self.c)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 158\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontrastive_divergence\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_data\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 159\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 858 | "\u001b[0;32mC:\\Users\\Administrator\\Documents\\GitHub\\dbn_pytorch\\RBM.py\u001b[0m in \u001b[0;36mcontrastive_divergence\u001b[0;34m(self, input_data, training)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0mhidden_activations\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpositive_hidden_act\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m#采样步数\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0mvisible_p\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_visible\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden_activations\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m \u001b[0mhidden_probabilities\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mhidden_activations\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_hidden\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvisible_p\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", 859 | 
"\u001b[0;32mC:\\Users\\Administrator\\Documents\\GitHub\\dbn_pytorch\\RBM.py\u001b[0m in \u001b[0;36mto_visible\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;31m# 计算隐含层激活,然后转换为概率\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;31m# print('vinput:',X)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 78\u001b[0;31m \u001b[0mX_dash\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m \u001b[0;34m,\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 79\u001b[0m \u001b[0mX_dash\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_dash\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#W.T*x+b\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[0;31m# print('mm:',X_dash)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 860 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: " 861 | ], 862 | "output_type": "error" 863 | } 864 | ], 865 | "source": [ 866 | "" 867 | ] 868 | }, 869 | { 870 | "cell_type": "code", 871 | "execution_count": 14, 872 | "metadata": {}, 873 | "outputs": [], 874 | "source": [] 875 | } 876 | ], 877 | "metadata": { 878 | "kernelspec": { 879 | "display_name": "Python 2", 880 | "language": "python", 881 | "name": "python2" 882 | }, 883 | "language_info": { 884 | "codemirror_mode": { 885 | "name": "ipython", 886 | "version": 2 887 | }, 888 | "file_extension": ".py", 889 | "mimetype": "text/x-python", 890 | "name": "python", 891 | 
"nbconvert_exporter": "python", 892 | "pygments_lexer": "ipython2", 893 | "version": "2.7.6" 894 | } 895 | }, 896 | "nbformat": 4, 897 | "nbformat_minor": 0 898 | } 899 | -------------------------------------------------------------------------------- /null: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jxinyee/dbn_pytorch/8657c5b1048a1b73f8b2d400b000c09f2b38dfb5/null -------------------------------------------------------------------------------- /process.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import codecs 4 | import matplotlib 5 | import matplotlib.pyplot as plt 6 | import random 7 | import codecs 8 | # this input_path is just example 9 | #inputpath = "C:/Users/Administrator/Desktop/Smote/describledata/kddcup.data_10_percent_corrected" 10 | Nor_list = ['normal'] 11 | Dos_list = ['back', 'land', 'neptune' , 'pod', 'smurf', 'teardrop', 'mailbomb', 'processtable', 'udpstorm', 'apache2', 'worm'] 12 | R2L_list =['guess_passwd', 'ftp_write', 'imap', 'phf', 'multihop', 'warezmaster', 'warezclient', 'xlock','xsnoop', 'snmpguess', 'snmpgetattack', 'httptunnel','spy','named','snmpguess'] 13 | Probe_list = ['satan','ipsweep', 'nmap', 'portsweep', 'mscan', 'saint'] 14 | U2R_list =['buffer_overflow', 'loadmodule', 'rootkit', 'perl', 'sqlattack', 'xterm', 'ps','httptunnel'] 15 | continue_col_index = [0,4,5,7,8,9,10,12,15,16,17,18,19,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40] 16 | decreate_col_index =[1,2,3,6,11,13,14,20,21] 17 | 18 | 19 | def gen_lable1(seg): 20 | if seg in Nor_list: 21 | return 0 22 | elif seg in Dos_list: 23 | return 1 24 | elif seg in R2L_list: 25 | return 2 26 | elif seg in Probe_list: 27 | return 3 28 | elif seg in U2R_list: 29 | return 4 30 | 31 | 32 | def gen_label2(seg): 33 | if seg in Nor_list: 34 | return 0 35 | else: 36 | return 1 37 | 38 | 39 | def replace(columnlist, 
def processdata(file_path):
    """Load a KDD-Cup-99 style CSV file into a float32 matrix.

    Every non-blank line is one comma-separated record whose last field is
    the attack name, optionally terminated by a '.'.  Discrete columns
    (``decreate_col_index``) are integer-encoded in place by ``replace``,
    continuous columns (``continue_col_index``) are z-score standardised.

    Args:
        file_path: path of the text file to read.

    Returns:
        ``np.ndarray`` of dtype float32 with shape
        ``(n_records, n_features + 2)``.  The second-to-last column is the
        five-way class code from ``gen_lable1``; the last column is the
        binary normal/attack code from ``gen_label2``.

    Raises:
        ValueError: if the file contains no records.
    """
    rows = []
    with codecs.open(file_path, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line would create a ragged row and break np.array.
                continue
            # Records end with a '.'; drop it so the label matches the
            # category lists, but do not eat a real character if absent.
            if line.endswith('.'):
                line = line[:-1]
            rows.append(line.split(','))
    if not rows:
        raise ValueError("no records found in %s" % file_path)

    datas = np.array(rows)
    replace(decreate_col_index, datas)

    labelled = []
    for record in datas:
        attack = record[-1]
        features = [float(value) for value in record[:-1]]
        features.append(gen_lable1(attack))  # five-way class code
        features.append(gen_label2(attack))  # binary normal/attack code
        labelled.append(features)
    datas = np.array(labelled).astype('float32')

    for j in continue_col_index:
        column = datas[:, j]
        std_val = np.std(column)
        if std_val == 0:
            # Constant column: (x - mean) is 0 everywhere and dividing by a
            # zero deviation would fill the column with NaN.
            datas[:, j] = 0.0
            continue
        datas[:, j] = (column - np.mean(column)) / std_val
    return datas
class Smote():
    """SMOTE-style oversampler for rows mixing discrete and continuous columns.

    Discrete columns (``range1``) of a synthetic row are drawn from the
    empirical value distribution of that column; continuous columns
    (``range2``) are interpolated between a sample and one of its nearest
    neighbours.  The last two columns of every row are treated as labels and
    copied unchanged from the source sample.

    Generated rows accumulate in ``self.synthetic_arr`` across calls.
    """

    def __init__(self, distance, range1, range2):
        """Store the configuration.

        Args:
            distance: neighbour search method, 'euclidian' or 'ball_tree'.
            range1: iterable of discrete column indices.
            range2: iterable of continuous column indices.
        """
        self.synthetic_arr = []
        self.newindex = 0
        self.distance_measure = distance
        self.range1 = range1
        self.range2 = range2

    def _discrete_distributions(self, min_samples):
        """Return {column: (values, probabilities)} for every discrete column."""
        dists = {}
        for col in self.range1:
            values, counts = np.unique(min_samples[:, col], return_counts=True)
            dists[col] = (values, counts / float(counts.sum()))
        return dists

    def Populate(self, N, i, indices, min_samples, k, discrete_dists=None):
        """Append N synthetic rows derived from sample ``i``.

        Args:
            N: number of synthetic rows to create for this sample.
            i: row index of the source sample in ``min_samples``.
            indices: row indices of the source sample's neighbours
                (``k - 1`` entries).
            min_samples: 2-D array of the samples being augmented.
            k: neighbour count used to build ``indices``; must be >= 2.
            discrete_dists: optional precomputed result of
                ``_discrete_distributions``; computed here when omitted.
        """
        if discrete_dists is None:
            discrete_dists = self._discrete_distributions(min_samples)
        n_columns = len(min_samples[0])

        for _ in range(N):
            # One zero-initialised slot per column; a data row is not a shape.
            arr = np.zeros(n_columns)
            arr[-2] = min_samples[i][-2]
            arr[-1] = min_samples[i][-1]
            nn = randint(0, k - 2)
            neighbour = min_samples[indices[nn]]

            # Discrete columns: sample from the empirical distribution,
            # looked up by column index rather than list position.
            for col in self.range1:
                values, probs = discrete_dists[col]
                arr[col] = np.random.choice(values, p=probs)

            # Continuous columns: random point on the segment between the
            # source sample and the chosen neighbour.
            for attr in self.range2:
                base = float(min_samples[i][attr])
                diff = float(neighbour[attr]) - base
                gap = random.uniform(0, 1)
                arr[attr] = base + gap * diff

            self.synthetic_arr.append(arr)
            self.newindex = self.newindex + 1

    def k_neighbors(self, euclid_distance, k):
        """Return, per row, the indices of the ``k - 1`` closest other rows.

        Column 0 of the sorted order is skipped because it is normally the
        row itself (distance 0).
        """
        order = np.argsort(euclid_distance, axis=1)
        return order[:, 1:k]

    def find_k(self, X, k):
        """Find nearest neighbours of every row of X by Euclidean distance.

        Returns:
            Integer array of shape ``(len(X), k - 1)`` with neighbour indices.
        """
        n_rows = X.shape[0]
        euclid_distance = np.empty([n_rows, n_rows], dtype=np.float32)
        for i in range(n_rows):
            # Distances from row i to every row, computed in one vector op.
            delta = X - X[i]
            euclid_distance[i] = np.sqrt((delta ** 2).sum(axis=1))
        return self.k_neighbors(euclid_distance, k)

    def generate_synthetic_points(self, min_samples, N, k):
        """Generate N synthetic rows for every row of ``min_samples``.

        Args:
            min_samples: 2-D array of the samples to augment; the last two
                columns are labels.
            N: number of synthetic rows to create per input row (>= 1).
            k: neighbour count (>= 2, at most the number of samples);
                ``k - 1`` neighbours are used per sample.

        Returns:
            Array of all synthetic rows accumulated on this instance; for a
            fresh instance its shape is ``(N * n_samples, n_columns)``.

        Raises:
            ValueError: on an invalid ``N``, ``k`` or distance measure.
        """
        if N < 1:
            raise ValueError("Value of N cannot be less than 1")

        if self.distance_measure not in ('euclidian', 'ball_tree'):
            raise ValueError("Invalid Distance Measure.You can use only Euclidian or ball_tree")

        if k > min_samples.shape[0]:
            raise ValueError("Size of k cannot exceed the number of samples.")

        if k < 2:
            # k - 1 neighbours are used, so k < 2 leaves nothing to pick from.
            raise ValueError("Size of k must be at least 2.")

        if self.distance_measure == 'euclidian':
            indices = self.find_k(min_samples, k)
        else:
            nb = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(min_samples)
            _, indices = nb.kneighbors(min_samples)
            indices = indices[:, 1:]  # drop the sample itself

        # The discrete distributions depend only on min_samples: build once.
        discrete_dists = self._discrete_distributions(min_samples)
        for i in range(indices.shape[0]):
            self.Populate(N, i, indices[i], min_samples, k, discrete_dists)

        return np.asarray(self.synthetic_arr)