├── .idea
│   ├── dbn_pytorch.iml
│   ├── encodings.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── other.xml
│   ├── vcs.xml
│   └── workspace.xml
├── DBN.py
├── RBM.py
├── README.md
├── __pycache__
│   ├── DBN.cpython-36.pyc
│   └── RBM.cpython-36.pyc
├── describledata
│   ├── corrected
│   ├── kddcup.data_10_percent_corrected
│   ├── ~$说明.docx
│   └── 说明.docx
├── exercise.ipynb
├── null
├── process.py
└── smote.py
/DBN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.autograd import Variable
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from RBM import RBM
6 | import time
7 | class DBN(nn.Module):
8 |
9 | def __init__(self,
10 |                  visible_units=41,  # visible units: set to your feature dimension (4 features -> 4, 26 -> 26)
11 |                  hidden_units=[11, 6, 11],  # hidden layer sizes
12 |                  k=5,  # number of Gibbs sampling steps
13 |                  learning_rate=1e-3,  # learning rate
14 |                  momentum_coefficient=0.9,  # momentum coefficient
15 |                  weight_decay=1e-4,  # weight decay
16 | use_gpu=False,
17 | _activation='sigmoid'):
18 | super(DBN, self).__init__()
19 | device = torch.device("cuda:0" if use_gpu else "cpu")
20 |         self.n_layers = len(hidden_units)  # number of hidden layers
21 |         self.rbm_layers = []  # the stacked RBMs
22 |         self.rbm_nodes = []
23 |         # build the individual RBM layers
24 | for i in range(self.n_layers):
25 |
26 | if i == 0:
27 | input_size = visible_units
28 | else:
29 | input_size = hidden_units[i - 1]
30 | rbm = RBM(visible_units=input_size,
31 | hidden_units=hidden_units[i],
32 | k=k,
33 | learning_rate=learning_rate,
34 | momentum_coefficient=momentum_coefficient,
35 | weight_decay=weight_decay,
36 | use_gpu=use_gpu,
37 | _activation=_activation).to(device)
38 |
39 | self.rbm_layers.append(rbm)
40 | self.W_rec = [nn.Parameter(self.rbm_layers[i].weight.data.clone()) for i in range(self.n_layers - 1)]
41 | self.W_gen = [nn.Parameter(self.rbm_layers[i].weight.data) for i in range(self.n_layers - 1)]
42 | self.bias_rec = [nn.Parameter(self.rbm_layers[i].c.data.clone()) for i in range(self.n_layers - 1)]
43 | self.bias_gen = [nn.Parameter(self.rbm_layers[i].b.data) for i in range(self.n_layers - 1)]
44 | self.W_mem = nn.Parameter(self.rbm_layers[-1].weight.data)
45 | self.v_bias_mem = nn.Parameter(self.rbm_layers[-1].b.data)
46 | self.h_bias_mem = nn.Parameter(self.rbm_layers[-1].c.data)
47 | for i in range(self.n_layers-1):
48 | self.register_parameter('W_rec%i'%i, self.W_rec[i])
49 | self.register_parameter('W_gen%i'%i, self.W_gen[i])
50 | self.register_parameter('bias_rec%i'%i, self.bias_rec[i])
51 | self.register_parameter('bias_gen%i'%i, self.bias_gen[i])
52 |
53 |         self.BPNN = nn.Sequential(  # classifier head, also used for backprop fine-tuning
54 | torch.nn.Linear(11, 11),
55 | torch.nn.ReLU(),
56 | torch.nn.Dropout(0.5),
57 | torch.nn.Linear(11,5),
58 | )
59 | def forward(self , input_data):
60 | '''
61 |         Feed forward through the stacked RBMs, then the BP classifier.
62 | '''
63 | v = input_data
64 | for i in range(len(self.rbm_layers)):
65 | v = v.view((v.shape[0] , -1)).type(torch.FloatTensor)#flatten
66 | p_v,v = self.rbm_layers[i].forward(v)
67 | # print('p_v:', p_v.shape,p_v)
68 | # print('v:',v.shape,v)
69 | out=self.BPNN(p_v)
70 | # print('out',out.shape,out)
71 | # print(self.BPNN(p_v))
72 | return out
73 |
74 | def train_static(self, train_data,train_labels,num_epochs,batch_size):
75 | '''
76 |         Greedy layer-wise RBM training; each trained layer is then frozen.
77 | '''
78 |
79 | tmp = train_data
80 |
81 | for i in range(len(self.rbm_layers)):
82 | print("-"*20)
83 | print("Training the {} st rbm layer".format(i+1))
84 |
85 | tensor_x = tmp.type(torch.FloatTensor)
86 | tensor_y = train_labels.type(torch.FloatTensor)
87 | _dataset = torch.utils.data.TensorDataset(tensor_x,tensor_y)
88 |             _dataloader = torch.utils.data.DataLoader(_dataset, batch_size=batch_size)
89 |
90 | self.rbm_layers[i].trains(_dataloader,num_epochs,batch_size)
91 | print(type(_dataloader))
92 | # print(train_data.shape)
93 | v = tmp.view((tmp.shape[0] , -1)).type(torch.FloatTensor)
94 | v,_ = self.rbm_layers[i].forward(v)
95 | tmp = v
96 | # print(v.shape)
97 | return
98 |
99 | def train_ith(self, train_data,num_epochs,batch_size,ith_layer,rbm_layers):
100 | '''
101 |         Train only one layer; can be used for fine-tuning.
102 | '''
103 | if(ith_layer>len(rbm_layers)):
104 | return
105 |
106 | v = train_data
107 | for ith in range(ith_layer):
108 | v,out_ = self.rbm_layers[ith].forward(v)
109 |
110 |
111 | self.rbm_layers[ith_layer].trains(v, num_epochs,batch_size)
112 | return
113 |
114 | def trainBP(self,trainloader):
115 | optimizer = torch.optim.SGD(self.BPNN.parameters(), lr=0.005, momentum=0.7)
116 | loss_func = torch.nn.CrossEntropyLoss()
117 | for epoch in range(5):
118 | for step,(x,y) in enumerate(trainloader):
119 | bx = Variable(x)
120 | by = Variable(y)
121 |                 out = self.forward(bx)  # forward returns the BPNN output directly
122 | # print(out)
123 | loss=loss_func(out,by)
124 | optimizer.zero_grad()
125 | loss.backward()
126 | optimizer.step()
127 | if step % 10 == 0:
128 | print('Epoch: ', epoch, 'step:', step, '| train loss: %.4f' % loss.data.numpy())
129 |
130 |
131 |
132 |
133 |
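134 | # Usage sketch (illustrative only; see exercise.ipynb for the full pipeline):
135 | #     dbn = DBN(visible_units=41)
136 | #     dbn.train_static(train_data, train_labels, num_epochs=3, batch_size=20)  # greedy RBM pre-training
137 | #     dbn.trainBP(trainloader)  # backprop fine-tuning of the BPNN classifier head
138 | #     logits = dbn(inputs)      # forward pass returns the BPNN output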
--------------------------------------------------------------------------------
/RBM.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torchvision
3 |
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import math
7 | import numpy as np
8 | BATCH_SIZE = 200
9 | class RBM(nn.Module):
10 | def __init__(self,visible_units=26,
11 | hidden_units = 9,
12 | k=2,
13 | learning_rate=1e-3,
14 | momentum_coefficient=0.5,
15 | weight_decay = 1e-4,
16 | use_gpu = False,
17 | _activation='sigmoid'):
18 | super(RBM,self).__init__()
19 |         # these are all tunable hyperparameters; set them as you like
20 | self.visible_units = visible_units
21 | self.hidden_units = hidden_units
22 | self.k = k
23 | self.learning_rate = learning_rate
24 | self.momentum_coefficient = momentum_coefficient
25 | self.weight_decay = weight_decay
26 |         self.use_gpu = use_gpu  # best set to True if you have a GPU
27 | self._activation = _activation
28 |         self.weight = torch.randn(self.visible_units, self.hidden_units) / math.sqrt(self.visible_units)  # initialization
29 | self.c = torch.randn(self.hidden_units) / math.sqrt(self.hidden_units)
30 | self.b = torch.randn(self.visible_units) / math.sqrt(self.visible_units)
31 |
32 | self.W_momentum = torch.zeros(self.visible_units, self.hidden_units)
33 | self.b_momentum = torch.zeros(self.visible_units)
34 | self.c_momentum = torch.zeros(self.hidden_units)
35 |     # select the activation function
36 |     def activation(self,X):
37 |         if self._activation=='sigmoid':
38 |             return torch.sigmoid(X)
39 |         elif self._activation=='tanh':
40 |             return torch.tanh(X)
41 |         elif self._activation=='relu':
42 |             return torch.relu(X)
43 | else:
44 | raise ValueError("Invalid Activation Function")
45 |
46 | def to_hidden(self, X):
47 | '''
48 |         Generate the hidden layer from the visible layer,
49 |         done via sampling.
50 |         X is the visible-layer probability distribution.
51 |         :param X: torch tensor shape = (n_samples , n_features)
52 |         :return - hidden   - the new hidden layer (probabilities)
53 |                   sample_h - Gibbs sample (1 or 0)
54 | '''
55 | # print('hinput:',X)
56 | hidden = torch.matmul(X, self.weight)
57 | hidden = torch.add(hidden, self.c) # W.x + c
58 | # print('mm:',hidden)
59 | hidden = self.activation(hidden)
60 |
61 | sample_h = self.sampling(hidden)
62 | # print('h:',hidden,'sam_h:',sample_h)
63 |
64 | return hidden, sample_h
65 |
66 |
67 | def to_visible(self,X):
68 | '''
69 |         Reconstruct the visible layer from the hidden layer,
70 |         also via sampling.
71 |         X is the hidden-layer probability distribution.
72 |         :returns - X_dash        - the reconstructed layer (probabilities)
73 |                    sample_X_dash - the new sample (Gibbs sampling)
74 |
75 | '''
76 |         # compute the reconstruction activation, then convert to probabilities
77 | # print('vinput:',X)
78 | X_dash = torch.matmul(X ,self.weight.transpose( 0 , 1) )
79 | X_dash = torch.add(X_dash , self.b) #W.T*x+b
80 | # print('mm:',X_dash)
81 | X_dash = self.activation(X_dash)
82 |
83 | sample_X_dash = self.sampling(X_dash)
84 | # print('v:',X_dash, 'sam_v:', sample_X_dash)
85 |
86 | return X_dash,sample_X_dash
87 |
88 | def sampling(self,s):
89 | '''
90 |         Gibbs sampling via a Bernoulli distribution
91 | '''
92 | s = torch.distributions.Bernoulli(s)
93 | return s.sample()
94 | def reconstruction_error(self , data):
95 | '''
96 |         Compute the reconstruction error via the loss
97 | '''
98 | return self.contrastive_divergence(data, False)
99 |
100 |
101 | def contrastive_divergence(self, input_data ,training = True):
102 | '''
103 |         Contrastive divergence (CD-k)
104 | '''
105 | # positive phase
106 | positive_hidden_probabilities,positive_hidden_act = self.to_hidden(input_data)
107 |
108 |         # positive statistics for W
109 | positive_associations = torch.matmul(input_data.t() , positive_hidden_act)
110 |
111 |
112 |
113 |         # negative phase
114 |         hidden_activations = positive_hidden_act
115 |         for i in range(self.k):  # k Gibbs sampling steps
116 | visible_p , _ = self.to_visible(hidden_activations)
117 | hidden_probabilities,hidden_activations = self.to_hidden(visible_p)
118 |
119 | negative_visible_probabilities = visible_p
120 | negative_hidden_probabilities = hidden_probabilities
121 |
122 |         # negative statistics for W
123 | negative_associations = torch.matmul(negative_visible_probabilities.t() , negative_hidden_probabilities)
124 |
125 |
126 |         # update the parameters
127 | if(training):
128 | self.W_momentum *= self.momentum_coefficient
129 | self.W_momentum += (positive_associations - negative_associations)
130 |
131 | self.b_momentum *= self.momentum_coefficient
132 | self.b_momentum += torch.sum(input_data - negative_visible_probabilities, dim=0)
133 |
134 | self.c_momentum *= self.momentum_coefficient
135 | self.c_momentum += torch.sum(positive_hidden_probabilities - negative_hidden_probabilities, dim=0)
136 |
137 | batch_size = input_data.size(0)
138 |
139 |             self.weight += self.W_momentum * self.learning_rate / batch_size  # normalize by the actual batch size
140 |             self.b += self.b_momentum * self.learning_rate / batch_size
141 |             self.c += self.c_momentum * self.learning_rate / batch_size
142 |
143 | self.weight -= self.weight * self.weight_decay # L2 weight decay
144 |
145 |         # compute the reconstruction error
146 | error = torch.mean(torch.sum((input_data - negative_visible_probabilities)**2 , 1))
147 | # print('i:',input_data,'o:',negative_hidden_probabilities)
148 |
149 | return error
150 | def forward(self,input_data):
151 | return self.to_hidden(input_data)
152 |
153 | def step(self,input_data):
154 | '''
155 |         One training step: feed-forward plus the CD gradient update.
156 | '''
157 | # print('w:',self.weight);print('b:',self.b);print('c:',self.c)
158 | return self.contrastive_divergence(input_data , True)
159 |
160 |
161 | def trains(self,train_data,num_epochs = 50,batch_size= 20):
162 |
163 |         # updates are normalized inside contrastive_divergence by the actual batch size
164 |
165 | if(isinstance(train_data ,torch.utils.data.DataLoader)):
166 | train_loader = train_data
167 | else:
168 | train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size)
169 |
170 |
171 | for epochs in range(num_epochs):
172 | epoch_err = 0.0
173 |
174 | for batch,_ in train_loader:
175 | # batch = batch.view(len(batch) , self.visible_units)
176 |
177 | if(self.use_gpu):
178 | batch = batch.cuda()
179 | batch_err = self.step(batch)
180 |
181 | epoch_err += batch_err
182 |
183 |
184 | print("Epoch Error(epoch:%d) : %.4f" % (epochs , epoch_err))
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
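194 | # Illustrative note: contrastive_divergence above alternates k Gibbs steps,
195 | #     v_prob, v_sample = self.to_visible(h)      # p(v|h) = sigmoid(h @ W.T + b)
196 | #     h_prob, h_sample = self.to_hidden(v_prob)  # p(h|v) = sigmoid(v @ W + c)
197 | # and then applies the CD update dW ∝ <v h>_data - <v h>_recon,
198 | # i.e. the positive minus negative associations computed above.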
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # dbn_pytorch
2 | A DBN (deep belief network) implemented with PyTorch. If you run into any other problems, please contact me.
3 | 
4 | Required packages: pytorch, numpy, tqdm
5 | 
6 | 
7 | For the theory behind RBMs, see [here](https://blog.csdn.net/itplus/article/details/19168937)
8 | 
9 | A DBN simply freezes each trained RBM and pushes its output to the next layer.
10 | A few important points not covered in detail there are the Metropolis–Hastings algorithm in MCMC and Gibbs sampling.
11 | 
12 | These points genuinely puzzled me for quite a while; the two blog posts below are good study material.
13 | 
14 | 
15 | [Baima's blog](https://blog.csdn.net/baimafujinji/article/details/53946367)
16 | 
17 | [Liu Jianping's blog](https://www.cnblogs.com/pinard/p/6638955.html)
18 | 
19 | For a description of the dataset, [see here](https://blog.csdn.net/com_stu_zhang/article/details/6987632)
20 | 
21 | 
22 | For the steps to run the code, see exercise.ipynb, which contains the detailed code; a quick-start sketch also follows below.
23 | In addition, data.npz.npy, adddata and testdata.npz.npy are, respectively, the preprocessed training data
24 | and the SMOTE-augmented data with 5000 generated points; testdata.npz.npy is the test data.
25 | 
26 | Data augmentation -> smote.py
27 | 
28 | 
29 | Here the outermost softmax layer is set to 5 units; please change it when doing binary classification.
30 | 
31 | ## Tips
32 | 
33 | Also: even with Gibbs sampling the code is still very slow, so set the number of epochs sensibly!
34 | 
35 | At first I trained directly without Gibbs sampling and it ran for a whole day, sad!!!!
36 | 
37 | Later I tried SVD and trained once more, and found that did not work either, sad*2!!!!!
38 | 
39 | The results were bad enough to make me cry; please don't lose your composure like I did!
40 | 
41 | By the way, if you modify the DBN parameters in exercise.ipynb, remember to restart the kernel before re-running the notebook.
42 | 
43 | I don't know why the notebook doesn't clear its cache. That's it!
44 | 
45 | 
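46 | ## Quick start
47 | 
48 | A minimal sketch of the whole pipeline (illustrative only; it assumes the preprocessed `data.npz.npy` produced in exercise.ipynb sits in the repo root):
49 | 
50 | ```python
51 | import numpy as np
52 | import torch
53 | import torch.utils.data as Data
54 | from DBN import DBN
55 | 
56 | data = np.load('data.npz.npy')  # preprocessed KDD data; last two columns are labels
57 | np.random.shuffle(data)
58 | split = int(len(data) * 0.8)
59 | traindat = torch.from_numpy(data[:split, :-2]).float()
60 | trainlabel = torch.from_numpy(data[:split, -2]).long()
61 | 
62 | dbn = DBN(visible_units=traindat.shape[1])
63 | dbn.train()
64 | # greedy layer-wise pre-training of the stacked RBMs
65 | dbn.train_static(train_data=traindat, train_labels=trainlabel, num_epochs=3, batch_size=20)
66 | # backprop fine-tuning of the BPNN classifier head
67 | loader = Data.DataLoader(Data.TensorDataset(traindat, trainlabel), batch_size=500, shuffle=True)
68 | dbn.trainBP(loader)
69 | ```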
--------------------------------------------------------------------------------
/__pycache__/DBN.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jxinyee/dbn_pytorch/8657c5b1048a1b73f8b2d400b000c09f2b38dfb5/__pycache__/DBN.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/RBM.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jxinyee/dbn_pytorch/8657c5b1048a1b73f8b2d400b000c09f2b38dfb5/__pycache__/RBM.cpython-36.pyc
--------------------------------------------------------------------------------
/describledata/~$说明.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jxinyee/dbn_pytorch/8657c5b1048a1b73f8b2d400b000c09f2b38dfb5/describledata/~$说明.docx
--------------------------------------------------------------------------------
/describledata/说明.docx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jxinyee/dbn_pytorch/8657c5b1048a1b73f8b2d400b000c09f2b38dfb5/describledata/说明.docx
--------------------------------------------------------------------------------
/exercise.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "import numpy as np\n",
12 | "import pandas as pd\n",
13 | "import codecs\n",
14 | "import matplotlib\n",
15 | "import matplotlib.pyplot as plt\n",
16 | "import random\n",
17 | "import random\n",
18 | "import codecs\n",
19 | "# this input_path is just example\n",
20 | "#inputpath = \"C:/Users/Administrator/Desktop/Smote/describledata/kddcup.data_10_percent_corrected\"\n",
21 | "Nor_list = ['normal']\n",
22 | "Dos_list = ['back', 'land', 'neptune' , 'pod', 'smurf', 'teardrop', 'mailbomb', 'processtable', 'udpstorm', 'apache2', 'worm']\n",
23 | "R2L_list =['guess_passwd', 'ftp_write', 'imap', 'phf', 'multihop', 'warezmaster', 'warezclient', 'xlock','xsnoop', 'snmpguess', 'snmpgetattack', 'httptunnel','spy','named','snmpguess']\n",
24 | "Probe_list = ['satan','ipsweep', 'nmap', 'portsweep', 'mscan', 'saint']\n",
25 | "U2R_list =['buffer_overflow', 'loadmodule', 'rootkit', 'perl', 'sqlattack', 'xterm', 'ps','httptunnel']\n",
26 | "continue_col_index = [0,4,5,7,8,9,10,12,15,16,17,18,19,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40]\n",
27 | "decreate_col_index =[1,2,3,6,11,13,14,20,21]"
28 | ]
29 | },
30 | {
31 | "cell_type": "code",
32 | "execution_count": 2,
33 | "metadata": {},
34 | "outputs": [],
35 | "source": [
36 | "def gen_lable1(seg):\n",
37 | " if seg in Nor_list:\n",
38 | " return 0\n",
39 | " elif seg in Dos_list:\n",
40 | " return 1\n",
41 | " elif seg in R2L_list:\n",
42 | " return 2\n",
43 | " elif seg in Probe_list:\n",
44 | " return 3\n",
45 | " elif seg in U2R_list: \n",
46 | " return 4\n",
47 | " \n",
48 | " \n",
49 | "def gen_label2(seg):\n",
50 | " if seg in Nor_list:\n",
51 | " return 0\n",
52 | " else:\n",
53 | " return 1\n",
54 | "def replace(columnlist,data):\n",
55 | " #离散值代替\n",
56 | " #listt = []\n",
57 | " for i in columnlist:\n",
58 | " listt = []\n",
59 | " for line,seg in enumerate(data[:,i]):\n",
60 | " if seg in listt:\n",
61 | " data[line][i] = listt.index(seg)\n",
62 | " else:\n",
63 | " listt.append(seg)\n",
64 | " data[line][i] = listt.index(seg)\n",
65 | "\n",
66 | "def processdata(file_path):\n",
67 | " tmp = None\n",
68 | " with codecs.open(file_path,'r') as f:\n",
69 | " content = f.readlines()\n",
70 | " datas =[]\n",
71 | " for line in content:\n",
72 | " line = line.strip()\n",
73 | " line = line[:-1].split(',')\n",
74 | " datas.append(line)\n",
75 | " datas = np.array(datas)\n",
76 | " replace(decreate_col_index,datas)\n",
77 | " new_datas =[]\n",
78 | " for index,col in enumerate(datas):\n",
79 | " new_datas.append([float(k) for k in col[:-1]])\n",
80 | " if random.random()<0.00005:\n",
81 | " print(col[-1])\n",
82 | " new_datas[index].append(gen_lable1(col[-1]))\n",
83 | " new_datas[index].append(gen_label2(col[-1]))\n",
84 | " tmp = new_datas\n",
85 | " datas = np.array(tmp).astype('float32')\n",
86 | " \n",
87 | " for j in continue_col_index:\n",
88 | " meanVal=np.mean(datas[:,j])\n",
89 | " stdVal=np.std(datas[:,j])\n",
90 | " datas[:,j]=(datas[:,j]-meanVal)/stdVal\n",
91 | " return datas\n"
92 | ]
93 | },
94 | {
95 | "cell_type": "code",
96 | "execution_count": 3,
97 | "metadata": {
98 | "collapsed": true
99 | },
100 | "outputs": [
101 | {
102 | "name": "stdout",
103 | "output_type": "stream",
104 | "text": [
105 | "normal\n"
106 | ]
107 | },
108 | {
109 | "name": "stdout",
110 | "output_type": "stream",
111 | "text": [
112 | "normal\n"
113 | ]
114 | },
115 | {
116 | "name": "stdout",
117 | "output_type": "stream",
118 | "text": [
119 | "normal\nsmurf\n"
120 | ]
121 | },
122 | {
123 | "name": "stdout",
124 | "output_type": "stream",
125 | "text": [
126 | "neptune\n"
127 | ]
128 | },
129 | {
130 | "name": "stdout",
131 | "output_type": "stream",
132 | "text": [
133 | "neptune\n"
134 | ]
135 | },
136 | {
137 | "name": "stdout",
138 | "output_type": "stream",
139 | "text": [
140 | "smurf\nsmurf\nsmurf\n"
141 | ]
142 | },
143 | {
144 | "name": "stdout",
145 | "output_type": "stream",
146 | "text": [
147 | "smurf\n"
148 | ]
149 | },
150 | {
151 | "name": "stdout",
152 | "output_type": "stream",
153 | "text": [
154 | "smurf\n"
155 | ]
156 | },
157 | {
158 | "name": "stdout",
159 | "output_type": "stream",
160 | "text": [
161 | "smurf\nsmurf\n"
162 | ]
163 | },
164 | {
165 | "name": "stdout",
166 | "output_type": "stream",
167 | "text": [
168 | "smurf\n"
169 | ]
170 | },
171 | {
172 | "name": "stdout",
173 | "output_type": "stream",
174 | "text": [
175 | "normal\n"
176 | ]
177 | },
178 | {
179 | "name": "stdout",
180 | "output_type": "stream",
181 | "text": [
182 | "neptune\n"
183 | ]
184 | },
185 | {
186 | "name": "stdout",
187 | "output_type": "stream",
188 | "text": [
189 | "neptune\n"
190 | ]
191 | },
192 | {
193 | "name": "stdout",
194 | "output_type": "stream",
195 | "text": [
196 | "smurf\n"
197 | ]
198 | },
199 | {
200 | "name": "stdout",
201 | "output_type": "stream",
202 | "text": [
203 | "smurf\nsmurf\nnormal\n"
204 | ]
205 | },
206 | {
207 | "name": "stderr",
208 | "output_type": "stream",
209 | "text": [
210 | "C:\\Users\\Administrator\\AppData\\Local\\Programs\\Python\\Python36\\lib\\site-packages\\ipykernel_launcher.py:55: RuntimeWarning: invalid value encountered in true_divide\n"
211 | ]
212 | }
213 | ],
214 | "source": [
215 | "data = processdata('./describledata/kddcup.data_10_percent_corrected')"
216 | ]
217 | },
218 | {
219 | "cell_type": "code",
220 | "execution_count": 4,
221 | "metadata": {},
222 | "outputs": [
223 | {
224 | "name": "stdout",
225 | "output_type": "stream",
226 | "text": [
227 | "smurf\nsmurf\n"
228 | ]
229 | },
230 | {
231 | "name": "stdout",
232 | "output_type": "stream",
233 | "text": [
234 | "normal\n"
235 | ]
236 | },
237 | {
238 | "name": "stdout",
239 | "output_type": "stream",
240 | "text": [
241 | "normal\n"
242 | ]
243 | },
244 | {
245 | "name": "stdout",
246 | "output_type": "stream",
247 | "text": [
248 | "neptune\nneptune\n"
249 | ]
250 | },
251 | {
252 | "name": "stdout",
253 | "output_type": "stream",
254 | "text": [
255 | "smurf\n"
256 | ]
257 | },
258 | {
259 | "name": "stdout",
260 | "output_type": "stream",
261 | "text": [
262 | "smurf\n"
263 | ]
264 | },
265 | {
266 | "name": "stdout",
267 | "output_type": "stream",
268 | "text": [
269 | "smurf\n"
270 | ]
271 | },
272 | {
273 | "name": "stdout",
274 | "output_type": "stream",
275 | "text": [
276 | "neptune\nneptune\n"
277 | ]
278 | },
279 | {
280 | "name": "stderr",
281 | "output_type": "stream",
282 | "text": [
283 | "C:\\Users\\Administrator\\AppData\\Local\\Programs\\Python\\Python36\\lib\\site-packages\\ipykernel_launcher.py:55: RuntimeWarning: invalid value encountered in true_divide\n"
284 | ]
285 | }
286 | ],
287 | "source": [
288 | "#\n",
289 | "label_dict ={}\n",
290 | "for i in data[:,-2]:\n",
291 | " if i in label_dict.keys():\n",
292 | " continue\n",
293 | " else:\n",
294 | " label_dict[i] = len(label_dict)\n",
295 | "test_data = processdata('./describledata/corrected')"
296 | ]
297 | },
298 | {
299 | "cell_type": "code",
300 | "execution_count": 5,
301 | "metadata": {},
302 | "outputs": [],
303 | "source": [
304 | "def fill_nan(data):\n",
305 | " #m,n = data.shape\n",
306 | " choice_list =[]\n",
307 | " matrix = np.isnan(data)\n",
308 | " for m in range(len(data)):\n",
309 | " for n in range(len(data[0])):\n",
310 | " if matrix[m,n] ==True:\n",
311 | " data[m][n] = 0"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": 6,
317 | "metadata": {
318 | "collapsed": true
319 | },
320 | "outputs": [],
321 | "source": [
322 | "fill_nan(data)\n",
323 | "fill_nan(test_data)"
324 | ]
325 | },
326 | {
327 | "cell_type": "code",
328 | "execution_count": 10,
329 | "metadata": {},
330 | "outputs": [],
331 | "source": [
332 | "import pandas as pd\n",
333 | "import numpy as np\n",
334 | "import random\n",
335 | "from sklearn.neighbors import NearestNeighbors\n",
336 | "import math\n",
337 | "from random import randint\n",
338 | "import matplotlib.pyplot as plt\n",
339 | "from sklearn.decomposition import TruncatedSVD\n",
340 | "class Smote():\n",
341 | " def __init__(self,distance,range1,range2):\n",
342 | " self.synthetic_arr = []\n",
343 | " self.newindex = 0\n",
344 | " self.distance_measure = distance\n",
345 | " self.range1 =range1\n",
346 | " self.range2 = range2\n",
347 | "\n",
348 | " def Populate(self, N, i, indices, min_samples, k):\n",
349 | " \"\"\"\n",
350 | " 此代码主要作用是生成增强数组\n",
351 | "\n",
352 | " Returns:返回增强后的数组\n",
353 | " \"\"\"\n",
354 | "\n",
355 | " choice_list =[]\n",
356 | " def choice(data):\n",
357 | " p = []\n",
358 | " wc = {}\n",
359 | " for num in data:\n",
360 | " # print(num)\n",
361 | " wc[num] = wc.setdefault(num, 0) + 1\n",
362 | " for key in wc.keys():\n",
363 | " p.append(wc[key] / len(data))\n",
364 | " # print(p)\n",
365 | " keylist = np.array([key for key in wc.keys()])\n",
366 | " # print(wc[0])\n",
367 | " return keylist,p\n",
368 | " for index in self.range1:\n",
369 | " choice_list.append(choice(min_samples[:,index]))\n",
370 | "\n",
371 | " while N != 0:\n",
372 | " arr = np.zeros(len(min_samples[0]))\n",
373 | " arr[-2] = min_samples[i][-2]\n",
374 | " arr[-1] = min_samples[i][-1]\n",
375 | " nn = randint(0, k - 2)\n",
376 | " # 统计离散型变量\n",
377 | " for rowindex,index in enumerate(self.range1):\n",
378 | " arr[index] = np.random.choice(choice_list[rowindex][0],size=1,p=choice_list[rowindex][1])\n",
379 | " #for attr in features2:\n",
380 | " for attr in self.range2:\n",
381 | " min_samples[i][attr] = float(min_samples[i][attr])\n",
382 | " min_samples[indices[nn]][attr] = float(min_samples[indices[nn]][attr])\n",
383 | " try:\n",
384 | " diff = float(min_samples[indices[nn]][attr]) - float(min_samples[i][attr])\n",
385 | " except:\n",
386 | " print('这是第%d列'%attr,min_samples[indices[nn]][attr],min_samples[i][attr])\n",
387 | " gap = random.uniform(0, 1)\n",
388 | "\n",
389 | " arr[attr] = float(min_samples[i][attr]) + gap * diff\n",
390 | " #print(arr)\n",
391 | " self.synthetic_arr.append(arr)\n",
392 | " self.newindex = self.newindex + 1\n",
393 | " N = N - 1\n",
394 | "\n",
395 | " def k_neighbors(self, euclid_distance, k):\n",
396 | " nearest_idx_npy = np.empty([euclid_distance.shape[0], euclid_distance.shape[0]], dtype=np.int64)\n",
397 | "\n",
398 | " for i in range(len(euclid_distance)):\n",
399 | " idx = np.argsort(euclid_distance[i])\n",
400 | " nearest_idx_npy[i] = idx\n",
401 | " idx = 0\n",
402 | "\n",
403 | " return nearest_idx_npy[:, 1:k]\n",
404 | "\n",
405 | " def find_k(self, X, k):\n",
406 | "\n",
407 | " \"\"\"\n",
408 | " Finds k nearest neighbors using euclidian distance\n",
409 | "\n",
410 | " Returns: The k nearest neighbor\n",
411 | " \"\"\"\n",
412 | "\n",
413 | " euclid_distance = np.empty([X.shape[0], X.shape[0]], dtype=np.float32)\n",
414 | "\n",
415 | " for i in range(len(X)):\n",
416 | " dist_arr = []\n",
417 | " for j in range(len(X)):\n",
418 | " dist_arr.append(math.sqrt(sum((X[j] - X[i]) ** 2)))\n",
419 | " dist_arr = np.asarray(dist_arr, dtype=np.float32)\n",
420 | " euclid_distance[i] = dist_arr\n",
421 | "\n",
422 | " return self.k_neighbors(euclid_distance, k)\n",
423 | "\n",
424 | " def generate_synthetic_points(self, min_samples, N, k):\n",
425 | "\n",
426 | " \"\"\"\n",
427 | "\n",
428 | " Parameters\n",
429 | " ----------\n",
430 | " min_samples : 要增强的数据\n",
431 | " N :要额外生成的负样本的数目\n",
432 | " k : int. Number of nearest neighbours.\n",
433 | " Returns\n",
434 | " -------\n",
435 | " S : Synthetic samples. array,\n",
436 | " shape = [(N/100) * n_minority_samples, n_features].\n",
437 | " \"\"\"\n",
438 | "\n",
439 | " if N < 1:\n",
440 | " raise ValueError(\"Value of N cannot be less than 100%\")\n",
441 | "\n",
442 | " if self.distance_measure not in ('euclidian', 'ball_tree'):\n",
443 | " raise ValueError(\"Invalid Distance Measure.You can use only Euclidian or ball_tree\")\n",
444 | "\n",
445 | " if k > min_samples.shape[0]:\n",
446 | " raise ValueError(\"Size of k cannot exceed the number of samples.\")\n",
447 | "\n",
448 | " T = min_samples.shape[0]\n",
449 | "\n",
450 | " if self.distance_measure == 'euclidian':\n",
451 | " indices = self.find_k(min_samples, k)\n",
452 | "\n",
453 | " elif self.distance_measure == 'ball_tree':\n",
454 | " nb = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(min_samples)\n",
455 | " distance, indices = nb.kneighbors(min_samples)\n",
456 | " indices = indices[:, 1:]\n",
457 | "\n",
458 | " for i in range(indices.shape[0]):\n",
459 | " self.Populate(N, i, indices[i], min_samples, k)\n",
460 | "\n",
461 | " return np.asarray(self.synthetic_arr)"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": 14,
467 | "metadata": {
468 | "collapsed": true
469 | },
470 | "outputs": [],
471 | "source": [
472 | "rowindex = []\n",
473 | "for i, line in enumerate(data):\n",
474 | " if line[-2] ==4.0:\n",
475 | " rowindex.append(i)\n",
476 | "range1 = [1,2,3,6,11,13,14,20,21]\n",
477 | "range2 = [0,4,5,7,8,9,10,12,15,16,17,18,19,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40]\n",
478 | "minsamples = data[rowindex]\n",
479 | "smote = Smote(distance='ball_tree',range1=range1,range2=range2)\n",
480 | "smotedata = smote.generate_synthetic_points(min_samples=minsamples,N=100,k=9)\n"
481 | ]
482 | },
483 | {
484 | "cell_type": "code",
485 | "execution_count": 15,
486 | "metadata": {
487 | "collapsed": true
488 | },
489 | "outputs": [
490 | {
491 | "data": {
492 | "text/plain": [
493 | "(5200, 43)"
494 | ]
495 | },
496 | "execution_count": 15,
497 | "metadata": {},
498 | "output_type": "execute_result"
499 | }
500 | ],
501 | "source": [
502 | "smotedata.shape"
503 | ]
504 | },
505 | {
506 | "cell_type": "code",
507 | "execution_count": 16,
508 | "metadata": {},
509 | "outputs": [],
510 | "source": [
511 | "adddata = np.concatenate((smotedata,data),axis=0)"
512 | ]
513 | },
514 | {
515 | "cell_type": "code",
516 | "execution_count": 17,
517 | "metadata": {},
518 | "outputs": [
519 | {
520 | "data": {
521 | "text/plain": [
522 | "(499221, 43)"
523 | ]
524 | },
525 | "execution_count": 17,
526 | "metadata": {},
527 | "output_type": "execute_result"
528 | }
529 | ],
530 | "source": [
531 | "adddata.shape"
532 | ]
533 | },
534 | {
535 | "cell_type": "code",
536 | "execution_count": 18,
537 | "metadata": {},
538 | "outputs": [],
539 | "source": [
540 | "fill_nan(test_data)"
541 | ]
542 | },
543 | {
544 | "cell_type": "code",
545 | "execution_count": 2,
546 | "metadata": {},
547 | "outputs": [],
548 | "source": [
549 | "# 保存数据\n",
550 | "\"\"\"\n",
551 | "np.save('data.npz',data)\n",
552 | "np.save('adddata.npz',adddata)\n",
553 | "np.save('testdata.npz',test_data)\n",
554 | "\"\"\"\n",
555 | "import numpy as np\n",
556 | "import pandas as pd\n",
557 | "import codecs\n",
558 | "import matplotlib\n",
559 | "import matplotlib.pyplot as plt\n",
560 | "import random\n",
561 | "import random\n",
562 | "import codecs\n",
563 | "data = np.load('data.npz.npy')\n",
564 | "adddata =np.load('adddata.npz.npy')\n",
565 | "test_data =np.load('testdata.npz.npy')\n"
566 | ]
567 | },
568 | {
569 | "cell_type": "code",
570 | "execution_count": 3,
571 | "metadata": {},
572 | "outputs": [],
573 | "source": [
574 | "def gen_train_test(dat): #分割训练集和测试集\n",
575 | " index = int(len(dat)*0.8)\n",
576 | " np.random.shuffle(dat)\n",
577 | " traindat=torch.from_numpy(dat[:index,:-2]).float()\n",
578 | " trainlabel=torch.from_numpy(dat[:index,-2]).long()\n",
579 | " validdat = torch.from_numpy(dat[index:,:-2]).float()\n",
580 | " validlabel = torch.from_numpy(dat[index:,-2]).long()\n",
581 | " # print(dat[0])\n",
582 | " return traindat,trainlabel,validdat,validlabel"
583 | ]
584 | },
585 | {
586 | "cell_type": "code",
587 | "execution_count": 4,
588 | "metadata": {},
589 | "outputs": [],
590 | "source": [
591 | "import tqdm\n",
592 | "import torch\n",
593 | "from DBN import DBN\n",
594 | "import torch.utils.data as Data\n",
595 | "def train_batch(traind,trainl,SIZE=500,SHUFFLE=True): #分批处理\n",
596 | " trainset=Data.TensorDataset(traind,trainl)\n",
597 | " trainloader=Data.DataLoader(\n",
598 | " dataset=trainset,\n",
599 | " batch_size=SIZE,\n",
600 | " shuffle=SHUFFLE)\n",
601 | " return trainloader\n"
602 | ]
603 | },
604 | {
605 | "cell_type": "code",
606 | "execution_count": 5,
607 | "metadata": {},
608 | "outputs": [],
609 | "source": [
610 | "traindat,trainlabel, validdat, validlabel = gen_train_test(data)"
611 | ]
612 | },
613 | {
614 | "cell_type": "code",
615 | "execution_count": 6,
616 | "metadata": {},
617 | "outputs": [],
618 | "source": [
619 | "trainloader = train_batch(traindat,trainlabel)\n",
620 | "from torch.autograd.variable import Variable"
621 | ]
622 | },
623 | {
624 | "cell_type": "code",
625 | "execution_count": 7,
626 | "metadata": {},
627 | "outputs": [],
628 | "source": [
629 | "device = torch.device('cuda:0')\n"
630 | ]
631 | },
632 | {
633 | "cell_type": "code",
634 | "execution_count": 9,
635 | "metadata": {
636 | "collapsed": true
637 | },
638 | "outputs": [
639 | {
640 | "name": "stdout",
641 | "output_type": "stream",
642 | "text": [
643 | "--------------------"
644 | ]
645 | },
646 | {
647 | "name": "stdout",
648 | "output_type": "stream",
649 | "text": [
650 | "\n"
651 | ]
652 | },
653 | {
654 | "name": "stdout",
655 | "output_type": "stream",
656 | "text": [
657 | "Training the 1 st rbm layer"
658 | ]
659 | },
660 | {
661 | "name": "stdout",
662 | "output_type": "stream",
663 | "text": [
664 | "\n"
665 | ]
666 | },
667 | {
668 | "name": "stderr",
669 | "output_type": "stream",
670 | "text": [
671 | "C:\\Users\\Administrator\\AppData\\Local\\Programs\\Python\\Python36\\lib\\site-packages\\torch\\nn\\functional.py:1332: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n"
672 | ]
673 | },
674 | {
675 | "name": "stdout",
676 | "output_type": "stream",
677 | "text": [
678 | "Epoch Error(epoch:0) : 45406564.0000"
679 | ]
680 | },
681 | {
682 | "name": "stdout",
683 | "output_type": "stream",
684 | "text": [
685 | "\n"
686 | ]
687 | }
688 | ],
689 | "source": [
690 | "import time\n",
691 | "\n",
692 | "start_time = time.time()\n",
693 | "dbn=DBN(visible_units=len(traindat[0]))\n",
694 | "dbn.to(device)\n",
695 | "dbn.train()\n",
696 | "dbn.train_static(train_data=traindat,train_labels=trainlabel,num_epochs=3,batch_size=20)\n",
697 | "\n",
698 | "optimizer = torch.optim.SGD(dbn.parameters(), lr=0.001, momentum=0.9)\n",
699 | "loss_func = torch.nn.CrossEntropyLoss()\n",
700 | "\n",
701 | "dbn.trainBP(trainloader)\n",
702 | "\n",
703 | "for epoch in range(5):\n",
704 | " for step,(x,y) in tqdm(enumerate(trainloader)):\n",
705 | " # print(x.data.numpy(),y.data.numpy())\n",
706 | "\n",
707 | " b_x=Variable(x)\n",
708 | " b_y=Variable(y)\n",
709 | "\n",
710 | " output=dbn(b_x)\n",
711 | " # print(output)\n",
712 | " # print(prediction);print(output);print(b_y)\n",
713 | "\n",
714 | " loss=loss_func(output,b_y)\n",
715 | " optimizer.zero_grad()\n",
716 | " loss.backward()\n",
717 | " optimizer.step()\n",
718 | "\n",
719 | " if step%10==0:\n",
720 | " print('Epoch: ', epoch, 'step:',step,'| train loss: %.4f' % loss.data.numpy())\n",
721 | "duration=time.time()-start_time\n",
722 | "\n",
723 | "dbn.eval()\n",
724 | "test_x = Variable(validdat);test_y = Variable(validlabel)\n",
725 | "test_out = dbn(test_x)\n",
726 | "# print(test_out)\n",
727 | "test_pred = torch.max(test_out, 1)[1]\n",
728 | "pre_val = test_pred.data.squeeze().numpy()\n",
729 | "y_val = test_y.data.squeeze().numpy()\n",
730 | "print('prediciton:',pre_val);print('true value:',y_val)\n",
731 | "accuracy = float((pre_val == y_val).astype(int).sum()) / float(test_y.size(0))\n",
732 | "print('test accuracy: %.2f' % accuracy,'duration:%.4f' % duration)\n"
733 | ]
734 | },
735 | {
736 | "cell_type": "code",
737 | "execution_count": 9,
738 | "metadata": {
739 | "collapsed": true
740 | },
741 | "outputs": [
812 | {
813 | "name": "stdout",
814 | "output_type": "stream",
815 | "text": [
816 | "--------------------"
817 | ]
818 | },
819 | {
820 | "name": "stdout",
821 | "output_type": "stream",
822 | "text": [
823 | "\n"
824 | ]
825 | },
826 | {
827 | "name": "stdout",
828 | "output_type": "stream",
829 | "text": [
830 | "Training the 1 st rbm layer"
831 | ]
832 | },
833 | {
834 | "name": "stdout",
835 | "output_type": "stream",
836 | "text": [
837 | "\n"
838 | ]
839 | },
840 | {
841 | "name": "stderr",
842 | "output_type": "stream",
843 | "text": [
844 | "C:\\Users\\Administrator\\AppData\\Local\\Programs\\Python\\Python36\\lib\\site-packages\\torch\\nn\\functional.py:1332: UserWarning: nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\n warnings.warn(\"nn.functional.sigmoid is deprecated. Use torch.sigmoid instead.\")\n"
845 | ]
846 | },
847 | {
848 | "ename": "KeyboardInterrupt",
849 | "evalue": "",
850 | "traceback": [
851 | "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
852 | "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
853 | "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0ma\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mb\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtrain_and_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtraindat\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtrainlabel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvaliddat\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalidlabel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtrainloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
854 | "\u001b[0;32m\u001b[0m in \u001b[0;36mtrain_and_test\u001b[0;34m(traind, trainl, testdat, testlabel, loader)\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mdbn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdevice\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[0mdbn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 7\u001b[0;31m \u001b[0mdbn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain_static\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_data\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtraind\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtrain_labels\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtrainl\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m50\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 8\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 9\u001b[0m \u001b[0moptimizer\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mSGD\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdbn\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparameters\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.001\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmomentum\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.9\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
855 | "\u001b[0;32mC:\\Users\\Administrator\\Documents\\GitHub\\dbn_pytorch\\DBN.py\u001b[0m in \u001b[0;36mtrain_static\u001b[0;34m(self, train_data, train_labels, num_epochs, batch_size)\u001b[0m\n\u001b[1;32m 88\u001b[0m \u001b[0m_dataloader\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mutils\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataLoader\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_dataset\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 89\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 90\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrbm_layers\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrains\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_dataloader\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mnum_epochs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 91\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_dataloader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 92\u001b[0m \u001b[0;31m# print(train_data.shape)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
856 | "\u001b[0;32mC:\\Users\\Administrator\\Documents\\GitHub\\dbn_pytorch\\RBM.py\u001b[0m in \u001b[0;36mtrains\u001b[0;34m(self, train_data, num_epochs, batch_size)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0;32mif\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0muse_gpu\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mbatch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcuda\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0mbatch_err\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0mepoch_err\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mbatch_err\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
857 | "\u001b[0;32mC:\\Users\\Administrator\\Documents\\GitHub\\dbn_pytorch\\RBM.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, input_data)\u001b[0m\n\u001b[1;32m 156\u001b[0m '''\n\u001b[1;32m 157\u001b[0m \u001b[0;31m# print('w:',self.weight);print('b:',self.b);print('c:',self.c)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 158\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcontrastive_divergence\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_data\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 159\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
858 | "\u001b[0;32mC:\\Users\\Administrator\\Documents\\GitHub\\dbn_pytorch\\RBM.py\u001b[0m in \u001b[0;36mcontrastive_divergence\u001b[0;34m(self, input_data, training)\u001b[0m\n\u001b[1;32m 114\u001b[0m \u001b[0mhidden_activations\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mpositive_hidden_act\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 115\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mk\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;31m#采样步数\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 116\u001b[0;31m \u001b[0mvisible_p\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0m_\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_visible\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden_activations\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 117\u001b[0m \u001b[0mhidden_probabilities\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mhidden_activations\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mto_hidden\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvisible_p\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 118\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
859 | "\u001b[0;32mC:\\Users\\Administrator\\Documents\\GitHub\\dbn_pytorch\\RBM.py\u001b[0m in \u001b[0;36mto_visible\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;31m# 计算隐含层激活,然后转换为概率\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 77\u001b[0m \u001b[0;31m# print('vinput:',X)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 78\u001b[0;31m \u001b[0mX_dash\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmatmul\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m \u001b[0;34m,\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtranspose\u001b[0m\u001b[0;34m(\u001b[0m \u001b[0;36m0\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 79\u001b[0m \u001b[0mX_dash\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_dash\u001b[0m \u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mb\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#W.T*x+b\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 80\u001b[0m \u001b[0;31m# print('mm:',X_dash)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
860 | "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
861 | ],
862 | "output_type": "error"
863 | }
864 | ],
865 | "source": [
866 | ""
867 | ]
868 | },
869 | {
870 | "cell_type": "code",
871 | "execution_count": 14,
872 | "metadata": {},
873 | "outputs": [],
874 | "source": []
875 | }
876 | ],
877 | "metadata": {
878 | "kernelspec": {
879 | "display_name": "Python 2",
880 | "language": "python",
881 | "name": "python2"
882 | },
883 | "language_info": {
884 | "codemirror_mode": {
885 | "name": "ipython",
886 | "version": 2
887 | },
888 | "file_extension": ".py",
889 | "mimetype": "text/x-python",
890 | "name": "python",
891 | "nbconvert_exporter": "python",
892 | "pygments_lexer": "ipython2",
893 | "version": "2.7.6"
894 | }
895 | },
896 | "nbformat": 4,
897 | "nbformat_minor": 0
898 | }
899 |
--------------------------------------------------------------------------------
/null:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Jxinyee/dbn_pytorch/8657c5b1048a1b73f8b2d400b000c09f2b38dfb5/null
--------------------------------------------------------------------------------
/process.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import codecs
3 | import random, time
4 | import torch
5 | import torch.utils.data as Data
6 | from torch.autograd import Variable
7 | from DBN import DBN
8 | # this input_path is just example
9 | #inputpath = "C:/Users/Administrator/Desktop/Smote/describledata/kddcup.data_10_percent_corrected"
10 | Nor_list = ['normal']
11 | Dos_list = ['back', 'land', 'neptune' , 'pod', 'smurf', 'teardrop', 'mailbomb', 'processtable', 'udpstorm', 'apache2', 'worm']
12 | R2L_list =['guess_passwd', 'ftp_write', 'imap', 'phf', 'multihop', 'warezmaster', 'warezclient', 'xlock','xsnoop', 'snmpguess', 'snmpgetattack', 'httptunnel','spy','named','snmpguess']
13 | Probe_list = ['satan','ipsweep', 'nmap', 'portsweep', 'mscan', 'saint']
14 | U2R_list =['buffer_overflow', 'loadmodule', 'rootkit', 'perl', 'sqlattack', 'xterm', 'ps','httptunnel']
15 | continue_col_index = [0,4,5,7,8,9,10,12,15,16,17,18,19,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40]
16 | decreate_col_index =[1,2,3,6,11,13,14,20,21]
17 |
18 |
19 | def gen_lable1(seg):
20 | if seg in Nor_list:
21 | return 0
22 | elif seg in Dos_list:
23 | return 1
24 | elif seg in R2L_list:
25 | return 2
26 | elif seg in Probe_list:
27 | return 3
28 | elif seg in U2R_list:
29 | return 4
30 |
31 |
32 | def gen_label2(seg):
33 | if seg in Nor_list:
34 | return 0
35 | else:
36 | return 1
37 |
38 |
39 | def replace(columnlist, data):
40 |     # replace discrete values with integer codes
41 | # listt = []
42 | for i in columnlist:
43 | listt = []
44 | for line, seg in enumerate(data[:, i]):
45 | if seg in listt:
46 | data[line][i] = listt.index(seg)
47 | else:
48 | listt.append(seg)
49 | data[line][i] = listt.index(seg)
50 |
51 |
52 | def processdata(file_path):
53 | tmp = None
54 | with codecs.open(file_path, 'r') as f:
55 | content = f.readlines()
56 | datas = []
57 | for line in content:
58 | line = line.strip()
59 | line = line[:-1].split(',')
60 | datas.append(line)
61 | datas = np.array(datas)
62 | replace(decreate_col_index, datas)
63 | new_datas = []
64 | for index, col in enumerate(datas):
65 | new_datas.append([float(k) for k in col[:-1]])
66 | if random.random() < 0.00005:
67 | print(col[-1])
68 | new_datas[index].append(gen_lable1(col[-1]))
69 |             new_datas[index].append(gen_label2(col[-1]))
70 | tmp = new_datas
71 | datas = np.array(tmp).astype('float32')
72 |
73 | for j in continue_col_index:
74 | meanVal = np.mean(datas[:, j])
75 | stdVal = np.std(datas[:, j])
76 | datas[:, j] = (datas[:, j] - meanVal) / stdVal
77 | return datas
78 | def gen_train_test(dat):  # split into train and validation sets
79 | index = int(len(dat)*0.8)
80 | np.random.shuffle(dat)
81 | traindat=torch.from_numpy(dat[:index,:-2]).float()
82 | trainlabel=torch.from_numpy(dat[:index,-2]).long()
83 | validdat = torch.from_numpy(dat[index:,:-2]).float()
84 | validlabel = torch.from_numpy(dat[index:,-2]).long()
85 | # print(dat[0])
86 | return traindat,trainlabel,validdat,validlabel
87 |
88 | def train_batch(traind,trainl,SIZE=200,SHUFFLE=True):  # mini-batching
89 | trainset=Data.TensorDataset(traind,trainl)
90 | trainloader=Data.DataLoader(
91 | dataset=trainset,
92 | batch_size=SIZE,
93 | shuffle=SHUFFLE)
94 | return trainloader
95 |
96 | def train_and_test(traind,trainl,testdat,testlabel,loader):
97 | print(type(traind),type(trainl),type(traind),type(trainl),type(loader))
98 | start_time = time.time()
99 | dbn=DBN(visible_units=len(traind[0]))
100 |
101 | dbn.train()
102 | dbn.train_static(train_data=traind,train_labels=trainl,num_epochs=0,batch_size=20)
103 |
104 | optimizer = torch.optim.SGD(dbn.parameters(), lr=0.001, momentum=0.9)
105 | loss_func = torch.nn.CrossEntropyLoss()
106 | train_loader = loader
107 | dbn.trainBP(train_loader)
108 |
109 | for epoch in range(0):
110 | for step,(x,y) in enumerate(train_loader):
111 | # print(x.data.numpy(),y.data.numpy())
112 |
113 | b_x=Variable(x)
114 | b_y=Variable(y)
115 |
116 | output=dbn(b_x)
117 | # print(output)
118 | # print(prediction);print(output);print(b_y)
119 |
120 | loss=loss_func(output,b_y)
121 | optimizer.zero_grad()
122 | loss.backward()
123 | optimizer.step()
124 |
125 | if step%10==0:
126 | print('Epoch: ', epoch, 'step:',step,'| train loss: %.4f' % loss.data.numpy())
127 | duration=time.time()-start_time
128 |
129 | dbn.eval()
130 | test_x = Variable(testdat);test_y = Variable(testlabel)
131 | test_out = dbn(test_x)
132 | # print(test_out)
133 | test_pred = torch.max(test_out, 1)[1]
134 | pre_val = test_pred.data.squeeze().numpy()
135 | y_val = test_y.data.squeeze().numpy()
136 |     print('prediction:',pre_val);print('true value:',y_val)
137 | accuracy = float((pre_val == y_val).astype(int).sum()) / float(test_y.size(0))
138 | print('test accuracy: %.2f' % accuracy,'duration:%.4f' % duration)
139 | return accuracy, duration
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
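158 | # Usage sketch (illustrative only; the path matches this repo's layout):
159 | #     data = processdata('./describledata/kddcup.data_10_percent_corrected')
160 | #     traindat, trainlabel, validdat, validlabel = gen_train_test(data)
161 | #     loader = train_batch(traindat, trainlabel)
162 | #     train_and_test(traindat, trainlabel, validdat, validlabel, loader)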
--------------------------------------------------------------------------------
/smote.py:
--------------------------------------------------------------------------------
1 | import pandas as pd
2 | import numpy as np
3 | import random
4 | from sklearn.neighbors import NearestNeighbors
5 | import math
6 | from random import randint
7 | import matplotlib.pyplot as plt
8 | from sklearn.decomposition import TruncatedSVD
9 | class Smote():
10 | def __init__(self,distance,range1,range2):
11 | self.synthetic_arr = []
12 | self.newindex = 0
13 | self.distance_measure = distance
14 | self.range1 =range1
15 | self.range2 = range2
16 |
17 | def Populate(self, N, i, indices, min_samples, k):
18 | """
19 |         Generates the synthetic (augmented) samples.
20 | 
21 |         Returns: the augmented array
22 | """
23 |
24 | choice_list =[]
25 | def choice(data):
26 | p = []
27 | wc = {}
28 | for num in data:
29 | # print(num)
30 | wc[num] = wc.setdefault(num, 0) + 1
31 | for key in wc.keys():
32 | p.append(wc[key] / len(data))
33 | # print(p)
34 | keylist = np.array([key for key in wc.keys()])
35 | # print(wc[0])
36 | return keylist,p
37 | for index in self.range1:
38 | choice_list.append(choice(min_samples[:,index]))
39 |
40 | while N != 0:
41 |             arr = np.zeros(len(min_samples[0]))
42 | arr[-2] = min_samples[i][-2]
43 | arr[-1] = min_samples[i][-1]
44 | nn = randint(0, k - 2)
45 |             # sample discrete variables from their empirical distribution
46 |             for rowindex, index in enumerate(self.range1):
47 |                 arr[index] = np.random.choice(choice_list[rowindex][0], size=1, p=choice_list[rowindex][1])
48 | #for attr in features2:
49 | for attr in self.range2:
50 | min_samples[i][attr] = float(min_samples[i][attr])
51 | min_samples[indices[nn]][attr] = float(min_samples[indices[nn]][attr])
52 | try:
53 | diff = float(min_samples[indices[nn]][attr]) - float(min_samples[i][attr])
54 | except:
55 |                     print('this is column %d' % attr, min_samples[indices[nn]][attr], min_samples[i][attr])
56 | gap = random.uniform(0, 1)
57 |
58 | arr[attr] = float(min_samples[i][attr]) + gap * diff
59 | #print(arr)
60 | self.synthetic_arr.append(arr)
61 | self.newindex = self.newindex + 1
62 | N = N - 1
63 |
64 | def k_neighbors(self, euclid_distance, k):
65 | nearest_idx_npy = np.empty([euclid_distance.shape[0], euclid_distance.shape[0]], dtype=np.int64)
66 |
67 | for i in range(len(euclid_distance)):
68 | idx = np.argsort(euclid_distance[i])
69 | nearest_idx_npy[i] = idx
70 | idx = 0
71 |
72 | return nearest_idx_npy[:, 1:k]
73 |
74 | def find_k(self, X, k):
75 |
76 | """
77 | Finds k nearest neighbors using euclidian distance
78 |
79 | Returns: The k nearest neighbor
80 | """
81 |
82 | euclid_distance = np.empty([X.shape[0], X.shape[0]], dtype=np.float32)
83 |
84 | for i in range(len(X)):
85 | dist_arr = []
86 | for j in range(len(X)):
87 | dist_arr.append(math.sqrt(sum((X[j] - X[i]) ** 2)))
88 | dist_arr = np.asarray(dist_arr, dtype=np.float32)
89 | euclid_distance[i] = dist_arr
90 |
91 | return self.k_neighbors(euclid_distance, k)
92 |
93 | def generate_synthetic_points(self, min_samples, N, k):
94 |
95 | """
96 |
97 | Parameters
98 | ----------
99 |         min_samples : the data to augment
100 |         N : the number of extra minority samples to generate per base sample
101 | k : int. Number of nearest neighbours.
102 | Returns
103 | -------
104 | S : Synthetic samples. array,
105 |                 shape = [N * n_minority_samples, n_features].
106 | """
107 |
108 | if N < 1:
109 | raise ValueError("Value of N cannot be less than 100%")
110 |
111 | if self.distance_measure not in ('euclidian', 'ball_tree'):
112 | raise ValueError("Invalid Distance Measure.You can use only Euclidian or ball_tree")
113 |
114 | if k > min_samples.shape[0]:
115 | raise ValueError("Size of k cannot exceed the number of samples.")
116 |
117 | T = min_samples.shape[0]
118 |
119 | if self.distance_measure == 'euclidian':
120 | indices = self.find_k(min_samples, k)
121 |
122 | elif self.distance_measure == 'ball_tree':
123 | nb = NearestNeighbors(n_neighbors=k, algorithm='ball_tree').fit(min_samples)
124 | distance, indices = nb.kneighbors(min_samples)
125 | indices = indices[:, 1:]
126 |
127 | for i in range(indices.shape[0]):
128 | self.Populate(N, i, indices[i], min_samples, k)
129 |
130 | return np.asarray(self.synthetic_arr)
131 |
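132 | # Usage sketch (illustrative only; mirrors exercise.ipynb): range1 lists the
133 | # discrete column indices, range2 the continuous ones, and minority_rows holds
134 | # the minority-class samples whose last two columns are labels:
135 | #     smote = Smote(distance='ball_tree', range1=[1, 2, 3], range2=[0, 4, 5])
136 | #     synthetic = smote.generate_synthetic_points(min_samples=minority_rows, N=100, k=9)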
--------------------------------------------------------------------------------