├── .gitignore
├── .idea
│   ├── libraries
│   │   └── R_User_Library.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── stack_autoencoder.iml
│   ├── vcs.xml
│   └── workspace.xml
├── SAE.py
├── __pycache__
│   └── SAE.cpython-36.pyc
├── result
└── trian.py
/.gitignore:
--------------------------------------------------------------------------------
/data/
--------------------------------------------------------------------------------
/SAE.py:
--------------------------------------------------------------------------------
import torch as th
from torch import nn
import torch.nn.functional as F


class AutoEncoder(nn.Module):
    """A single encoder/decoder pair. With rep=True the forward pass
    returns only the hidden code, which is what the stacking uses."""

    def __init__(self, inputDim, hiddenDim):
        super().__init__()
        self.inputDim = inputDim
        self.hiddenDim = hiddenDim
        self.encoder = nn.Linear(inputDim, hiddenDim, bias=True)
        self.decoder = nn.Linear(hiddenDim, inputDim, bias=True)
        self.act = th.sigmoid

    def forward(self, x, rep=False):
        hidden = self.act(self.encoder(x))
        if rep:
            return hidden              # hidden representation, for stacking
        out = self.decoder(hidden)     # reconstruction of the input
        #out = self.act(out)
        return out


class SAE(nn.Module):
    """Stacked autoencoder: the pre-trained encoders followed by a
    softmax classification head."""

    def __init__(self, encoderList):
        super().__init__()
        self.encoderList = encoderList
        self.en1 = encoderList[0]
        self.en2 = encoderList[1]
        #self.en3 = encoderList[2]
        self.fc = nn.Linear(64, 10, bias=True)  # 64 = hidden size of en2

    def forward(self, x):
        out = self.en1(x, rep=True)
        out = self.en2(out, rep=True)
        #out = self.en3(out, rep=True)
        out = self.fc(out)
        return F.log_softmax(out, dim=1)


class MLP(nn.Module):
    """Plain feed-forward baseline trained from scratch."""

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 392, bias=True)
        self.fc2 = nn.Linear(392, 196, bias=True)
        self.fc3 = nn.Linear(196, 98, bias=True)  # only used in the 4-layer variant
        # classify takes fc2's 196-dim output while fc3 is commented out below;
        # switch back to nn.Linear(98, 10) if fc3 is re-enabled
        self.classify = nn.Linear(196, 10, bias=True)
        self.act = th.sigmoid

    def forward(self, x):
        out = self.act(self.fc1(x))
        out = self.act(self.fc2(out))
        #out = self.act(self.fc3(out))
        out = self.classify(out)
        return F.log_softmax(out, dim=1)
--------------------------------------------------------------------------------
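(For reference: a minimal smoke test, not part of the repository, assuming the classes above are importable from SAE.py under a recent PyTorch. It shows what the two rep modes of AutoEncoder return.)

import torch as th
from SAE import AutoEncoder

ae = AutoEncoder(784, 256)
x = th.randn(8, 784)          # a fake batch of flattened 28x28 images
recon = ae(x, rep=False)      # reconstruction, shape (8, 784)
code = ae(x, rep=True)        # hidden code used for stacking, shape (8, 256)
assert recon.shape == (8, 784) and code.shape == (8, 256)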
/__pycache__/SAE.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhangxu0307/stack-autoencoder/4ea33f37ebed21a227ac551a999ded2796dd8f2c/__pycache__/SAE.cpython-36.pyc
--------------------------------------------------------------------------------
/result:
--------------------------------------------------------------------------------
model  layers  activation  train epochs  AE pre-training  accuracy
ANN    4       sigmoid     15            -                0.957
SAE    4       sigmoid     15            none             0.958
SAE    4       sigmoid     15            10 epochs        0.951
SAE    4       sigmoid     15            15 epochs        0.955
SAE    4       sigmoid     15            20 epochs        0.957

ANN    3       sigmoid     15            -                0.962
--------------------------------------------------------------------------------
/trian.py:
--------------------------------------------------------------------------------
import torch as th
from torch import nn
from torch import optim
from torchvision import datasets
import torchvision.transforms as transforms
from SAE import *


def loadMNIST(batchSize):

    root = "./data/"

    trans = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.1307,), (0.3081,))])

    train_set = datasets.MNIST(root=root, train=True, transform=trans, download=True)
    test_set = datasets.MNIST(root=root, train=False, transform=trans)

    train_loader = th.utils.data.DataLoader(dataset=train_set, batch_size=batchSize, shuffle=True)
    test_loader = th.utils.data.DataLoader(dataset=test_set, batch_size=batchSize, shuffle=False)

    print('==>>> total training batch number: {}'.format(len(train_loader)))
    print('==>>> total testing batch number: {}'.format(len(test_loader)))

    return train_loader, test_loader


def trainAE(encoderList, trainLayer, batchSize, epoch, useCuda=False):

    if useCuda:
        for i in range(len(encoderList)):
            encoderList[i].cuda()

    optimizer = optim.SGD(encoderList[trainLayer].parameters(), lr=0.1)
    criterion = nn.L1Loss()
    trainLoader, testLoader = loadMNIST(batchSize=batchSize)

    # freeze the parameters of every encoder before the one being trained
    # (for layer 0 there are no predecessors, so nothing is frozen)
    for i in range(trainLayer):
        for param in encoderList[i].parameters():
            param.requires_grad = False

    for ep in range(epoch):

        sum_loss = 0

        for batch_idx, (x, target) in enumerate(trainLoader):
            optimizer.zero_grad()
            if useCuda:
                x, target = x.cuda(), target.cuda()
            x = x.view(-1, 784)

            # pass the batch through the frozen encoders to produce
            # the input of the layer currently being trained
            out = x
            for i in range(trainLayer):
                out = encoderList[i](out, rep=True)

            # train the selected autoencoder to reconstruct its own input
            pred = encoderList[trainLayer](out, rep=False)

            loss = criterion(pred, out)
            sum_loss += loss.item()
            loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == len(trainLoader):
                print('==>>> train layer: {}, epoch: {}, batch index: {}, train loss: {:.6f}'
                      .format(trainLayer, ep, batch_idx + 1, sum_loss / (batch_idx + 1)))


def trainClassifier(model, batchSize, epoch, useCuda=False):

    if useCuda:
        model = model.cuda()

    # unfreeze every parameter so the whole stack is fine-tuned end to end
    for param in model.parameters():
        param.requires_grad = True

    optimizer = optim.SGD(model.parameters(), lr=0.1)
    criterion = nn.NLLLoss()
    trainLoader, testLoader = loadMNIST(batchSize=batchSize)

    for ep in range(epoch):

        # training
        sum_loss = 0

        for batch_idx, (x, target) in enumerate(trainLoader):
            optimizer.zero_grad()
            if useCuda:
                x, target = x.cuda(), target.cuda()
            x = x.view(-1, 784)

            out = model(x)

            loss = criterion(out, target)
            sum_loss += loss.item()
            loss.backward()
            optimizer.step()

            if (batch_idx + 1) % 10 == 0 or (batch_idx + 1) == len(trainLoader):
                print('==>>> epoch: {}, batch index: {}, train loss: {:.6f}'.format(
                    ep, batch_idx + 1, sum_loss / (batch_idx + 1)))

        # testing
        correct_cnt, sum_loss = 0, 0
        total_cnt = 0
        with th.no_grad():
            for batch_idx, (x, target) in enumerate(testLoader):
                if useCuda:
                    x, target = x.cuda(), target.cuda()
                x = x.view(-1, 784)

                out = model(x)
                loss = criterion(out, target)
                sum_loss += loss.item()
                _, pred_label = th.max(out, 1)
                total_cnt += x.size(0)
                correct_cnt += (pred_label == target).sum().item()

                # smooth average
                if (batch_idx + 1) % 100 == 0 or (batch_idx + 1) == len(testLoader):
                    print('==>>> epoch: {}, batch index: {}, test loss: {:.6f}, acc: {:.3f}'.format(
                        ep, batch_idx + 1, sum_loss / (batch_idx + 1), correct_cnt * 1.0 / total_cnt))


if __name__ == '__main__':

    batchSize = 128
    AEepoch = 20
    epoch = 10

    encoder1 = AutoEncoder(784, 256)
    encoder2 = AutoEncoder(256, 64)
    #encoder3 = AutoEncoder(196, 98)

    encoderList = [encoder1, encoder2]

    # greedy layer-wise pre-training, one autoencoder at a time
    trainAE(encoderList, 0, batchSize, AEepoch, useCuda=True)
    trainAE(encoderList, 1, batchSize, AEepoch, useCuda=True)
    #trainAE(encoderList, 2, batchSize, AEepoch, useCuda=True)

    # fine-tune the stacked encoders with a classification head
    model = SAE(encoderList)
    #model = MLP()
    trainClassifier(model, batchSize, epoch, useCuda=True)
--------------------------------------------------------------------------------
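(For reference: a stand-alone sketch, not part of the repository, of the freezing trick that trainAE relies on. Parameters with requires_grad=False receive no gradient, so the optimizer leaves the already-trained encoders untouched while the next layer learns.)

import torch as th
from torch import nn, optim

frozen = nn.Linear(4, 4)        # stands in for an already-trained encoder
trained = nn.Linear(4, 4)       # stands in for the layer being trained
for p in frozen.parameters():
    p.requires_grad = False

opt = optim.SGD(trained.parameters(), lr=0.1)
x = th.randn(2, 4)
loss = trained(frozen(x)).pow(2).mean()
opt.zero_grad()
loss.backward()
opt.step()
assert frozen.weight.grad is None   # the frozen layer accumulated no gradient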