├── README.md
├── wechat_test.py
└── wechat_utils.py
/README.md:
--------------------------------------------------------------------------------
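# wechat_callback
For details, see the Zhihu column article: https://zhuanlan.zhihu.com/p/25670072
## Requirements
- itchat
- keras
- numpy
- scipy
- _thread (Python 3 standard library; on Python 2 the code imports thread instead)
- matplotlib
- requests (its ConnectionError is imported; it is installed as a dependency of itchat)
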
## Functions and keywords
### Functions
- Automatically send training information to WeChat at every epoch
- Automatically send figures to WeChat at every epoch
- Get figures manually
- Shut down the computer, or cancel a pending shutdown
- Specify the epoch after which training stops
- Stop training manually (after the current epoch finishes)
- New: get GPU status
- New: query training progress

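### Keywords and commands
Commands are read only from messages sent to the WeChat File Transfer chat (filehelper). A keyword matches if it appears anywhere in the message; matching is case-sensitive.

stop_training_cmdlist=['Stop now',"That's enough",u'停止训练',u'放弃治疗']

Keywords for stopping training: if the message you send contains any of them, the command is accepted. ('停止训练' means "stop training"; '放弃治疗', literally "give up treatment", is slang for giving up.)
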
shut_down_cmdlist=[u'关机','Shut down','Shut down the computer',u'别浪费电了',u'洗洗睡吧']

Keywords for shutting down the computer; they work like stop_training_cmdlist. ('关机' means "shut down", '别浪费电了' "stop wasting electricity", and '洗洗睡吧' "go wash up and sleep".)

cancel_cmdlist=[u'取消','cancel','aaaa']

Keywords for canceling a pending shutdown; they work like stop_training_cmdlist. ('取消' means "cancel".)

get_fig_cmdlist=[u'获取图表','Show me the figure']

Keywords for getting figures; they work like stop_training_cmdlist. ('获取图表' means "get figures".)

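The keyword lists above are checked with plain substring tests. Because '取消关机' ("cancel shutdown") also contains the shutdown keyword '关机', a dispatcher has to decide which command wins. The sketch below is a minimal illustration, not the callback's own code: it checks the cancel keywords before the shutdown keywords. The helper names contains_any and match_commands and the returned labels are hypothetical.

```python
# Minimal sketch of substring keyword matching (hypothetical helper names).
STOP_TRAINING = ['Stop now', "That's enough", u'停止训练', u'放弃治疗']
SHUT_DOWN = [u'关机', 'Shut down', 'Shut down the computer', u'别浪费电了', u'洗洗睡吧']
CANCEL = [u'取消', 'cancel', 'aaaa']


def contains_any(text, keywords):
    """True if any keyword occurs anywhere in text (case-sensitive)."""
    return any(k in text for k in keywords)


def match_commands(text):
    """Return the labels of the commands a message triggers."""
    commands = []
    if contains_any(text, STOP_TRAINING):
        commands.append('stop_training')
    # Cancel wins over shut down, because u'取消关机' contains both u'取消' and u'关机'.
    if contains_any(text, CANCEL):
        commands.append('cancel')
    elif contains_any(text, SHUT_DOWN):
        commands.append('shut_down')
    return commands


print(match_commands('Stop now'))    # ['stop_training']
print(match_commands(u'取消关机'))     # ['cancel']
print(match_commands('stop now'))    # [] because matching is case-sensitive
```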
Specify the stop epoch:

Keyword: 'Stop at' followed by the epoch number

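gpu_cmdlist=['GPU','gpu',u'显卡']

type_list=['MEMORY', 'UTILIZATION', 'ECC', 'TEMPERATURE', 'POWER', 'CLOCK', 'COMPUTE', 'PIDS', 'PERFORMANCE', 'SUPPORTED_CLOCKS', 'PAGE_RETIREMENT', 'ACCOUNTING']

Keywords for querying GPU status ('显卡' means "graphics card") and the status types that can be queried; each type is passed to nvidia-smi -q --display=TYPE.

prog_cmdlist=[u'进度','Progress']

Keywords for querying the training progress and the estimated finish time. ('进度' means "progress".)
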
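The progress report extrapolates linearly: if a fraction p of all batches is done after t seconds, the remaining time is estimated as t * (1/p - 1), and wechat_utils.py adds 0.01 to p so the division is defined before the first batch finishes. A minimal sketch of that arithmetic follows; the function name estimate_finish is hypothetical. For example, 490 of 1000 batches after 600 s gives p = 0.49 + 0.01 = 0.5 and an estimate of 600 s remaining.

```python
import time


def estimate_finish(elapsed_sec, batches_done, batches_total, offset=0.01, now=None):
    """Linear ETA estimate: remaining = elapsed * (1 / progress - 1).

    `offset` mirrors the 0.01 that prog() adds to the progress fraction so
    the division is defined before any batch has finished.
    Returns (progress in percent, remaining seconds, finish time as a string).
    """
    progress = (batches_done / float(batches_total) if batches_total else 0.0) + offset
    remaining = elapsed_sec * (1.0 / progress - 1.0)
    now = time.time() if now is None else now
    finish = time.asctime(time.localtime(now + remaining))
    return progress * 100.0, remaining, finish


pct, remaining, finish = estimate_finish(600, 490, 1000)
print('%.1f%% done, %.0f s left, finishing at %s' % (pct, remaining, finish))
```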
## Examples
Specify the stop epoch

Example: send 'Stop at:8' from your phone, and training will stop after epoch 8 finishes.

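A sketch of how such a command can be parsed and applied, assuming (as in wechat_utils.py) that the first number in the message is the stop epoch and that Keras passes a 0-based epoch index to on_epoch_end. The helper names parse_stop_epoch and should_stop are hypothetical.

```python
import re


def parse_stop_epoch(text):
    """Return the stop epoch from a 'Stop at' message, or None."""
    if 'Stop at' not in text:
        return None
    numbers = re.findall(r'\d+', text)  # runs of digits, e.g. '8' in 'Stop at:8'
    return int(numbers[0]) if numbers else None


def should_stop(epoch_index, stop_epoch):
    """Keras passes a 0-based index, so epoch N ends with epoch_index == N - 1."""
    return epoch_index + 1 >= stop_epoch


stop = parse_stop_epoch('Stop at:8')
print(stop)                                               # 8
print([e for e in range(10) if should_stop(e, stop)][0])  # 7, the index of epoch 8
```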
Stop training after the current epoch finishes

Example: send 'Stop now' or '停止训练' from your phone, and training will stop after the current epoch.

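Shut down the computer after a given number of seconds. Specify the waiting time with {sec} and the filename of the saved model with [name] (without .h5).

Example: send 'Shut down now [test]{120}' from your phone; the model is saved as test.h5 and the computer shuts down after 120 s. Send 'Shut down now{120},don't save' (or include '不保存模型', "don't save the model") to shut down without saving. If [name] is omitted, the file is named after the fexten argument of the callback or, if that is empty, the training start time. If {sec} is missing or not greater than 30, the delay defaults to 120 s; on Linux it is rounded down to whole minutes (at least one).
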
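The [name] and {sec} parts are plain delimiters. A sketch that extracts them and builds the shutdown command, assuming the Windows syntax shutdown -s -t SECONDS used by the callback and the Linux syntax shutdown -h +MINUTES. The helpers middle, parse_shutdown and shutdown_command are hypothetical, and 'temp' stands in for the callback's default filename. The delay has to be compared as an integer: as strings, '120' > '30' is False.

```python
def middle(text, start, end):
    """Return the text between the first `start` and the next `end`, or ''."""
    i = text.find(start)
    if i < 0:
        return ''
    j = text.find(end, i + len(start))
    return text[i + len(start):j] if j >= 0 else ''


def parse_shutdown(text, default_name='temp', default_sec=120, min_sec=30):
    """Parse 'Shut down now [name]{sec}' into (filepath, seconds, save)."""
    name = middle(text, '[', ']')
    filepath = (name or default_name) + '.h5'
    sec_str = middle(text, '{', '}')
    # Integer comparison; missing, non-numeric or too-short delays fall back to the default.
    sec = int(sec_str) if sec_str.isdigit() and int(sec_str) > min_sec else default_sec
    save = not any(k in text for k in [u'不保存模型', "don't save"])
    return filepath, sec, save


def shutdown_command(sec, os_name):
    """Windows takes seconds; Linux takes +minutes, rounded down (at least 1)."""
    if os_name == 'Windows':
        return 'shutdown -s -t %d' % sec
    return 'shutdown -h +%d' % max(sec // 60, 1)


print(parse_shutdown('Shut down now [test]{120}'))      # ('test.h5', 120, True)
print(parse_shutdown("Shut down now{120},don't save"))  # ('temp.h5', 120, False)
print(shutdown_command(120, 'Windows'))                 # shutdown -s -t 120
print(shutdown_command(120, 'Linux'))                   # shutdown -h +2
```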
Cancel a pending shutdown

Example: send '取消关机' ("cancel shutdown") or 'cancel' from your phone.

Get figures of the training information. Specify the metrics and the level to show with [metrics] and {level}; both default to 'all'. Separate several metrics with spaces, e.g. [loss acc]; valid levels are all, batches and epochs.

Example: send 'Show me the figure [loss]{batches}' from your phone to receive a JPG image of the loss over batches. Send 'Show me the figure' to receive two JPG images showing all metrics over batches and over epochs. The Chinese keyword works the same way, e.g. '获取图表[loss]{batches}'.

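A sketch of the parsing and layout rules described above: metrics come from [...] split on whitespace, the level from {...}, unknown metric names are ignored (falling back to all metrics if none match), and the subplots are laid out two per row. The helper names are hypothetical and the plotting itself is left out.

```python
from math import ceil


def middle(text, start, end):
    """Return the text between the first `start` and the next `end`, or ''."""
    i = text.find(start)
    if i < 0:
        return ''
    j = text.find(end, i + len(start))
    return text[i + len(start):j] if j >= 0 else ''


def parse_fig_command(text):
    """Return (metrics, level); both default to 'all'."""
    metrics = middle(text, '[', ']').split() or ['all']
    level = middle(text, '{', '}') or 'all'
    if level not in ('all', 'batches', 'epochs'):
        level = 'all'
    return metrics, level


def select_metrics(available, requested):
    """Keep requested metrics that were logged; fall back to all of them."""
    chosen = [m for m in available if m in requested]
    return available if 'all' in requested or not chosen else chosen


def grid_rows(n_plots, n_cols=2):
    """Rows needed for n_plots subplots laid out n_cols per row."""
    return int(ceil(n_plots / float(n_cols)))


metrics, level = parse_fig_command(u'获取图表[loss]{batches}')
print(metrics, level)                                    # ['loss'] batches
print(select_metrics(['loss', 'acc', 'hinge'], metrics))  # ['loss']
print(grid_rows(3))                                      # 2
```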
Get GPU status

Example: send 'gpu[MEMORY]', 'GPU[MEMORY TEMPERATURE]' or '显卡[MEMORY]'. Separate several types with spaces; without a [type], MEMORY is shown.

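A sketch of how a GPU command can be turned into nvidia-smi queries, following the rules above: the [type] list is split on whitespace, MEMORY is the default, and names outside type_list are dropped. It only builds the command strings (the callback runs them with os.popen and forwards the output); gpu_commands and middle are hypothetical helper names.

```python
TYPE_LIST = ['MEMORY', 'UTILIZATION', 'ECC', 'TEMPERATURE', 'POWER', 'CLOCK',
             'COMPUTE', 'PIDS', 'PERFORMANCE', 'SUPPORTED_CLOCKS',
             'PAGE_RETIREMENT', 'ACCOUNTING']


def middle(text, start, end):
    """Return the text between the first `start` and the next `end`, or ''."""
    i = text.find(start)
    if i < 0:
        return ''
    j = text.find(end, i + len(start))
    return text[i + len(start):j] if j >= 0 else ''


def gpu_commands(text, default='MEMORY'):
    """Build one nvidia-smi query per requested (and known) status type."""
    requested = middle(text, '[', ']').split() or [default]
    return ['nvidia-smi -q --display=' + t for t in requested if t in TYPE_LIST]


for cmd in gpu_commands('GPU[MEMORY TEMPERATURE]'):
    print(cmd)
```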
--------------------------------------------------------------------------------
/wechat_test.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Thu Mar 9 12:18:18 2017

@author: Quantum Liu
"""
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation
import wechat_utils
# wechat_utils does not log in on import, so log in explicitly (scan the QR code)
wechat_utils.login()
# wechat_utils.sendmessage is the Keras callback class;
# pass an instance in the callbacks list when calling fit()

nb_sample=64*10000
batch_size=16
dim=784

model = Sequential()
model.add(Dense(1024, input_dim=dim))
model.add(Activation('relu'))
for i in range(9):
    model.add(Dense(2048))
    model.add(Activation('sigmoid'))
model.add(Dense(1,activation='sigmoid'))

# Random inputs and binary labels, only used to exercise the callback
x=np.random.rand(nb_sample,dim)
y=np.random.randint(2,size=(nb_sample,1))

train_x=x[:390*64]
train_y=y[:390*64]

val_x=x[-10*64:]
val_y=y[-10*64:]

model.compile(optimizer='RMSprop',loss='binary_crossentropy',metrics=['acc','hinge'])
#==============================================================================
# Train
#==============================================================================
# nb_epoch is the Keras 1 argument name; Keras 2 calls it epochs
model.fit(x=train_x,y=train_y,batch_size=batch_size,nb_epoch=60,validation_data=(val_x,val_y),callbacks=[wechat_utils.sendmessage(savelog=True,fexten='TEST')])
--------------------------------------------------------------------------------
/wechat_utils.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 7 12:44:30 2017

@author: Quantum Liu
"""
from keras import __version__ as kv
kv=int(kv[0])  # Keras major version: 1 uses nb_epoch/nb_sample, 2 uses epochs/samples
import platform
pv=int(platform.python_version()[0])  # Python major version
import numpy as np
import scipy.io as sio
import itchat
from keras.callbacks import Callback
import time
import matplotlib
matplotlib.use('Agg')  # non-interactive backend: figures are only saved to files
import matplotlib.pyplot as plt
from math import ceil
from itchat.content import TEXT
if pv>2:
    import _thread as th
else:
    import thread as th
import os
from os import system
import re
import traceback
from requests.exceptions import ConnectionError
#==============================================================================
#==============================================================================
# Log-in function; call it first
#==============================================================================
def login():
    # Print the login QR code in the terminal (the enableCmdQR value adjusts its width);
    # hotReload reuses a saved session so you do not have to scan again
    if 'Windows' in platform.system():
        itchat.auto_login(enableCmdQR=1,hotReload=True)
    else:
        itchat.auto_login(enableCmdQR=2,hotReload=True)
    itchat.dump_login_status()  # save the login status for hot reload
#==============================================================================
#
#==============================================================================
def send_text(text):
    # Send a text message to 'filehelper' (the File Transfer chat)
    try:
        itchat.send_msg(msg=text,toUserName='filehelper')
    except (ConnectionError,NotImplementedError,KeyError):
        traceback.print_exc()
        print('\nConnection error, failed to send the message!\n')

def send_img(filename):
    # Send an image file to 'filehelper'
    try:
        itchat.send_image(filename,toUserName='filehelper')
    except (ConnectionError,NotImplementedError,KeyError):
        traceback.print_exc()
        print('\nConnection error, failed to send the figure!\n')
#==============================================================================
#
#==============================================================================
class sendmessage(Callback):
    # A subclass of keras.callbacks.Callback
    def __init__(self,savelog=True,fexten=''):
        super(sendmessage, self).__init__()
        self.fexten=(fexten if fexten else '')  # base name of the log and figure files
        self.savelog=bool(savelog)  # whether to save the logs as .mat files

    def t_send(self,msg,toUserName='filehelper'):
        try:
            itchat.send_msg(msg=msg,toUserName=toUserName)
        except (ConnectionError,NotImplementedError,KeyError):
            traceback.print_exc()
            print('\nConnection error, failed to send the message!\n')

    def t_send_img(self,filename,toUserName='filehelper'):
        try:
            itchat.send_image(filename,toUserName=toUserName)
        except (ConnectionError,NotImplementedError,KeyError):
            traceback.print_exc()
            print('\nConnection error, failed to send the figure!\n')

    def shutdown(self,sec,save=True,filepath='temp.h5'):
        # Shut down the computer
        # sec: seconds to wait before shutting down
        # save: whether to save the model first
        # filepath: file path for the saved model
        if save:
            self.model.save(filepath, overwrite=True)
            self.t_send('Command accepted, the model has been saved, shutting down the computer....', toUserName='filehelper')
        else:
            self.t_send('Command accepted, shutting down the computer....', toUserName='filehelper')
        if 'Windows' in platform.system():
            th.start_new_thread(system, ('shutdown -s -t %d' %sec,))
        else:
            # Linux shutdown takes the delay as +MINUTES, so round down to whole minutes (at least 1)
            m=(int(sec/60) if int(sec/60) else 1)
            th.start_new_thread(system, ('shutdown -h +%d' %m,))

#==============================================================================
#
#==============================================================================
    def cancel(self):
        # Cancel a pending shutdown
        self.t_send('Command accepted, canceling the shutdown....', toUserName='filehelper')
        if 'Windows' in platform.system():
            th.start_new_thread(system, ('shutdown -a',))
        else:
            th.start_new_thread(system, ('shutdown -c',))
#==============================================================================
#
#==============================================================================
    def GetMiddleStr(self,content,startStr,endStr):
        # Return the substring between startStr and the next endStr, or '' if either is missing
        try:
            startIndex = content.index(startStr) + len(startStr)
            endIndex = content.index(endStr, startIndex)
            return content[startIndex:endIndex]
        except ValueError:
            return ''
#==============================================================================
#
#==============================================================================
    def validateTitle(self,title):
        # Turn a string into a valid filename by removing / \ : * ? " < > | and spaces
        rstr = r"[\/\\\:\*\?\"\<\>\|]"  # '/\:*?"<>|'
        new_title = re.sub(rstr, "", title).replace(' ','')
        return new_title
#==============================================================================
#
#==============================================================================
    def prog(self):
        # Report overall and current-epoch progress with estimated finish times
        nb_epoch=(self.params['epochs'] if kv>1 else self.params['nb_epoch'])
        nb_sample=(self.params['samples'] if kv>1 else self.params['nb_sample'])
        nb_batches_epoch=float(nb_sample)/self.params['batch_size']
        nb_batches_total=nb_epoch*nb_batches_epoch
        # 0.01 is added so that 1/prog is defined before the first batch finishes
        prog_total=(self.t_batches/nb_batches_total if nb_batches_total else 0)+0.01
        prog_epoch=(self.c_batches/nb_batches_epoch if nb_batches_epoch else 0)+0.01
        now=time.time()
        eta_t=(now-self.train_start)*((1/prog_total)-1)
        if self.t_epochs:
            # Current-epoch ETA from the mean duration of the finished epochs
            t_mean=float(sum(self.t_epochs)) / len(self.t_epochs)
            eta_e=t_mean*(1-prog_epoch)
        else:
            # First epoch: extrapolate from the time elapsed so far
            eta_e=(now-self.train_start)*((1/prog_epoch)-1)
        t_end=time.asctime(time.localtime(now+eta_t))
        e_end=time.asctime(time.localtime(now+eta_e))
        m=('\nTotal:\nProg:'+str(prog_total*100.)[:5]+'%\nEpoch:'+str(self.epoch[-1]+1)+'/'+str(self.stopped_epoch)
           +'\nETA:'+str(eta_t)[:8]+'sec\nTrain will be finished at '+t_end
           +'\nCurrent epoch:\nProg:'+str(prog_epoch*100.)[:5]+'%\nETA:'+str(eta_e)[:8]+'sec\nCurrent epoch will be finished at '+e_end)
        self.t_send(msg=m)
        print(m)

#==============================================================================
#
#==============================================================================
    def get_fig(self,level='all',metrics=['all']):
        # Plot the training logs and send the figures to filehelper
        # level: 'batches', 'epochs' or 'all'
        # metrics: metrics to show; names that were not logged are ignored,
        # and if none of the requested names exist, all metrics are shown
        color_list='rgbyck'*10
        def batches(color_list='rgbyck'*10,metrics=['all']):
            keys=list(self.logs_batches.keys())
            selected=[val for val in keys if val in metrics]
            m_available=(keys if 'all' in metrics or not selected else selected)
            nb_rows_batches=int(ceil(len(m_available)*1.0/2))  # two subplots per row
            fig_batches=plt.figure('all_subs_batches')
            for i,k in enumerate(m_available):
                p=plt.subplot(nb_rows_batches,2,i+1)
                data=self.logs_batches[k]
                p.plot(range(len(data)),data,color_list[i]+'-',label=k)
                p.set_title(k+' in batches',fontsize=14)
                p.set_xlabel('batch',fontsize=10)
                p.set_ylabel(k,fontsize=10)
            filename=(self.fexten if self.fexten else self.validateTitle(self.localtime))+'_batches.jpg'
            plt.tight_layout()
            plt.savefig(filename)
            plt.close('all')
            self.t_send_img(filename,toUserName='filehelper')
            time.sleep(.5)
            self.t_send('Sending batches figure',toUserName='filehelper')
#==============================================================================
#
#==============================================================================
        def epochs(color_list='rgbyck'*10,metrics=['all']):
            keys=list(self.logs_epochs.keys())
            selected=[val for val in keys if val in metrics]
            m_available=(keys if 'all' in metrics or not selected else selected)
            nb_rows_epochs=int(ceil(len(m_available)*1.0/2))  # two subplots per row
            fig_epochs=plt.figure('all_subs_epochs')
            for i,k in enumerate(m_available):
                p=plt.subplot(nb_rows_epochs,2,i+1)
                data=self.logs_epochs[k]
                p.plot(range(len(data)),data,color_list[i]+'-',label=k)
                p.set_title(k+' in epochs',fontsize=14)
                p.set_xlabel('epoch',fontsize=10)
                p.set_ylabel(k,fontsize=10)
            filename=(self.fexten if self.fexten else self.validateTitle(self.localtime))+'_epochs.jpg'
            plt.tight_layout()
            plt.savefig(filename)
            plt.close('all')
            self.t_send_img(filename,toUserName='filehelper')
            time.sleep(.5)
            self.t_send('Sending epochs figure',toUserName='filehelper')
#==============================================================================
#
#==============================================================================
        # get_fig normally runs in its own thread (th.start_new_thread);
        # returning from it ends the thread
        try:
            if level not in ('batches','epochs'):
                level='all'
            # No epoch has finished yet, so only batch-level logs exist
            if not self.logs_epochs and level in ('all','epochs'):
                level='batches'
            if level in ('all','batches'):
                batches(metrics=metrics)
            if level in ('all','epochs'):
                epochs(metrics=metrics)
        except Exception:
            traceback.print_exc()
            self.t_send('Failed to send figure',toUserName='filehelper')
#==============================================================================
#
#==============================================================================
    def gpu_status(self,av_type_list):
        # Query nvidia-smi for each requested status type and send the output to filehelper
        if not av_type_list:
            self.t_send('No valid GPU status type given', toUserName='filehelper')
        for t in av_type_list:
            cmd='nvidia-smi -q --display='+t
            r=os.popen(cmd)
            info=r.readlines()
            r.close()
            content = " ".join(info)
            index=content.find('Attached GPUs')
            # Drop the header before 'Attached GPUs' and the spaces used for alignment
            s=content[max(index,0):].replace(' ','').rstrip('\n')
            self.t_send(s, toUserName='filehelper')
            time.sleep(.5)
#==============================================================================
#
#==============================================================================
    def on_train_begin(self, logs=None):
        self.epoch=[]          # indices of the epochs that have started
        self.t_epochs=[]       # duration of each finished epoch, in seconds
        self.t_batches=0       # batches finished in total
        self.c_batches=0       # batches finished in the current epoch
        self.logs_batches={}   # metric name -> list of per-batch values
        self.logs_epochs={}    # metric name -> list of per-epoch values
        self.train_start=time.time()
        self.localtime = time.asctime( time.localtime(self.train_start) )
        self.mesg = 'Train started at: '+self.localtime
        self.t_send(self.mesg, toUserName='filehelper')
        self.stopped_epoch = (self.params['epochs'] if kv>1 else self.params['nb_epoch'])
        #==============================================================================
        # Register the message handler; it works like the main loop for commands
        #==============================================================================
        @itchat.msg_register(TEXT)
        def manualstop(msg):
            text=msg['Text']
            # Keyword lists: a command is accepted if the message contains any of its keywords
            stop_training_cmdlist=['Stop now',"That's enough",u'停止训练',u'放弃治疗']
            shut_down_cmdlist=[u'关机','Shut down','Shut down the computer',u'别浪费电了',u'洗洗睡吧']
            cancel_cmdlist=[u'取消','cancel','aaaa']
            get_fig_cmdlist=[u'获取图表','Show me the figure']
            gpu_cmdlist=['GPU','gpu',u'显卡']
            type_list=['MEMORY', 'UTILIZATION', 'ECC', 'TEMPERATURE', 'POWER', 'CLOCK', 'COMPUTE', 'PIDS', 'PERFORMANCE', 'SUPPORTED_CLOCKS', 'PAGE_RETIREMENT', 'ACCOUNTING']
            prog_cmdlist=[u'进度','Progress']
            # Only messages sent to the File Transfer chat are treated as commands
            if msg['ToUserName']=='filehelper':
                print('\n',text,'\n')
                if 'Stop at' in text:
                    # Specify the stop epoch: 'Stop at:8' stops training after epoch 8
                    numbers=re.findall(r"\d+",text)
                    if numbers:
                        self.stopped_epoch = int(numbers[0])
                        if kv>1:
                            self.params['epochs']=self.stopped_epoch
                        else:
                            self.params['nb_epoch']=self.stopped_epoch
                        self.t_send('Command accepted, training will be stopped at epoch'+str(self.stopped_epoch), toUserName='filehelper')
                if any((k in text) for k in stop_training_cmdlist):
                    # Stop training, e.g. 'Stop now' or '停止训练'
                    self.model.stop_training = True
                    current=(self.epoch[-1]+1 if self.epoch else 1)
                    self.t_send('Command accepted, stop training now at epoch'+str(current), toUserName='filehelper')
                # Check cancel before shut down: '取消关机' contains both '取消' and '关机'
                if any((k in text) for k in cancel_cmdlist):
                    self.cancel()
                elif any((k in text) for k in shut_down_cmdlist):
                    # e.g. 'Shut down now [test]{120}' saves test.h5 and shuts down after 120 s;
                    # "don't save" or '不保存模型' skips saving the model
                    save=not any((k in text) for k in [u'不保存模型',"don't save"])
                    name=self.GetMiddleStr(text,'[',']')
                    filepath=(name if name else (self.fexten if self.fexten else self.validateTitle(self.localtime)))+'.h5'
                    print('\n',filepath,'\n')
                    # Compare as integers; missing delays or delays of 30 s or less fall back to 120 s
                    sec_str=self.GetMiddleStr(text,'{','}')
                    sec=(int(sec_str) if sec_str.isdigit() and int(sec_str)>30 else 120)
                    self.shutdown(sec,save=save,filepath=filepath)
                if any((k in text) for k in get_fig_cmdlist):
                    # e.g. 'Show me the figure [loss]{batches}'; metrics and level both default to 'all'
                    metrics=(self.GetMiddleStr(text,'[',']').split() or ['all'])
                    level=(self.GetMiddleStr(text,'{','}') or 'all')
                    if level in ['all','epochs','batches']:
                        th.start_new_thread(self.get_fig,(level,metrics))
                    else:
                        print("\nGot no valid level, using default 'all'\n")
                        self.t_send("Got no valid level, using default 'all'", toUserName='filehelper')
                        th.start_new_thread(self.get_fig,())
                if any((k in text) for k in gpu_cmdlist):
                    # e.g. 'gpu[MEMORY TEMPERATURE]'; defaults to MEMORY
                    sp_type_list=(self.GetMiddleStr(text,'[',']').split() or ['MEMORY'])
                    av_type_list=[val for val in sp_type_list if val in type_list]
                    self.gpu_status(av_type_list)
                if any((k in text) for k in prog_cmdlist):
                    try:
                        self.prog()
                    except Exception:
                        traceback.print_exc()
        # Poll for incoming messages in a background thread
        th.start_new_thread(itchat.run, ())
#==============================================================================
#
#==============================================================================
    def on_batch_end(self, batch, logs=None):
        # Record the metrics reported for this batch
        logs = logs or {}
        for k in self.params['metrics']:
            if k in logs:
                self.logs_batches.setdefault(k, []).append(logs[k])
        self.c_batches+=1
        self.t_batches+=1
#==============================================================================
#
#==============================================================================
    def on_epoch_begin(self, epoch, logs=None):
        self.t_s=time.time()  # start time of this epoch
        self.epoch.append(epoch)
        self.c_batches=0
        self.t_send('Epoch'+str(epoch+1)+'/'+str(self.stopped_epoch)+' started', toUserName='filehelper')
        self.mesg = ('Epoch:'+str(epoch+1)+' ')
#==============================================================================
#
#==============================================================================
    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        for k in self.params['metrics']:
            if k in logs:
                self.mesg+=(k+': '+str(logs[k])[:5]+' ')
                self.logs_epochs.setdefault(k, []).append(logs[k])
        # Honour a stop epoch set with 'Stop at'
        if epoch+1>=self.stopped_epoch:
            self.model.stop_training = True
        self.t_epochs.append(time.time()-self.t_s)
        if self.savelog:
            name=(self.fexten if self.fexten else self.validateTitle(self.localtime))
            sio.savemat(name+'_logs_batches.mat',{'log':np.array(self.logs_batches)})
            sio.savemat(name+'_logs_epochs.mat',{'log':np.array(self.logs_epochs)})
        th.start_new_thread(self.get_fig,())
        self.t_send(self.mesg, toUserName='filehelper')
#==============================================================================
#
#==============================================================================
    def on_train_end(self, logs=None):
        self.t_send('Train stopped at epoch'+str(self.epoch[-1]+1), toUserName='filehelper')

470 |
--------------------------------------------------------------------------------