├── .gitignore
├── .idea
├── dataSources.local.xml
├── dataSources.xml
├── dataSources
│ ├── 0e2e93c2-ec73-47f1-a3f3-868befde3f72.xml
│ └── eb023117-f844-4ce6-9dbd-5a03c863f17a.xml
├── encodings.xml
├── inspectionProfiles
│ └── Project_Default.xml
├── misc.xml
├── modules.xml
├── other.xml
├── transformer-master.iml
├── vcs.xml
└── workspace.xml
├── LICENSE
├── README.md
├── __pycache__
├── beam_search.cpython-36.pyc
├── data_load.cpython-36.pyc
├── hparams.cpython-36.pyc
├── model.cpython-36.pyc
├── modules.cpython-36.pyc
└── utils.cpython-36.pyc
├── beam_search.py
├── data_load.py
├── fig
├── structure.jpg
├── transformer-loss.png
└── transformer-pointer gererator-loss.png
├── hparams.py
├── model.py
├── modules.py
├── pred.py
├── requirements.txt
├── train.py
├── utils.py
└── vocab
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/.idea/dataSources.local.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 | #@
7 | `
8 |
9 |
10 | master_key
11 | juexiaotest
12 |
13 |
14 |
15 | #@
16 | `
17 |
18 |
19 | master_key
20 | juexiaotimered
21 |
22 |
23 |
--------------------------------------------------------------------------------
/.idea/dataSources.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | mysql.8
6 | true
7 | com.mysql.cj.jdbc.Driver
8 | jdbc:mysql://47.102.145.57:3306/jx
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 | mysql.8
20 | true
21 | com.mysql.cj.jdbc.Driver
22 | jdbc:mysql://juexiaotime.rwlb.rds.aliyuncs.com:3306/jx
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
--------------------------------------------------------------------------------
/.idea/dataSources/0e2e93c2-ec73-47f1-a3f3-868befde3f72.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | 5.6.16
6 | InnoDB
7 | InnoDB
8 | lower/lower
9 |
10 |
11 | utf8_general_ci
12 |
13 |
14 | 1
15 | utf8_general_ci
16 |
17 |
18 | utf8_general_ci
19 |
20 |
21 | utf8_general_ci
22 |
23 |
24 | armscii8
25 | 0
26 |
27 |
28 | armscii8
29 | 1
30 |
31 |
32 | ascii
33 | 0
34 |
35 |
36 | ascii
37 | 1
38 |
39 |
40 | big5
41 | 0
42 |
43 |
44 | big5
45 | 1
46 |
47 |
48 | binary
49 | 1
50 |
51 |
52 | cp1250
53 | 0
54 |
55 |
56 | cp1250
57 | 0
58 |
59 |
60 | cp1250
61 | 0
62 |
63 |
64 | cp1250
65 | 1
66 |
67 |
68 | cp1250
69 | 0
70 |
71 |
72 | cp1251
73 | 0
74 |
75 |
76 | cp1251
77 | 0
78 |
79 |
80 | cp1251
81 | 1
82 |
83 |
84 | cp1251
85 | 0
86 |
87 |
88 | cp1251
89 | 0
90 |
91 |
92 | cp1256
93 | 0
94 |
95 |
96 | cp1256
97 | 1
98 |
99 |
100 | cp1257
101 | 0
102 |
103 |
104 | cp1257
105 | 1
106 |
107 |
108 | cp1257
109 | 0
110 |
111 |
112 | cp850
113 | 0
114 |
115 |
116 | cp850
117 | 1
118 |
119 |
120 | cp852
121 | 0
122 |
123 |
124 | cp852
125 | 1
126 |
127 |
128 | cp866
129 | 0
130 |
131 |
132 | cp866
133 | 1
134 |
135 |
136 | cp932
137 | 0
138 |
139 |
140 | cp932
141 | 1
142 |
143 |
144 | dec8
145 | 0
146 |
147 |
148 | dec8
149 | 1
150 |
151 |
152 | eucjpms
153 | 0
154 |
155 |
156 | eucjpms
157 | 1
158 |
159 |
160 | euckr
161 | 0
162 |
163 |
164 | euckr
165 | 1
166 |
167 |
168 | gb2312
169 | 0
170 |
171 |
172 | gb2312
173 | 1
174 |
175 |
176 | gbk
177 | 0
178 |
179 |
180 | gbk
181 | 1
182 |
183 |
184 | geostd8
185 | 0
186 |
187 |
188 | geostd8
189 | 1
190 |
191 |
192 | greek
193 | 0
194 |
195 |
196 | greek
197 | 1
198 |
199 |
200 | hebrew
201 | 0
202 |
203 |
204 | hebrew
205 | 1
206 |
207 |
208 | hp8
209 | 0
210 |
211 |
212 | hp8
213 | 1
214 |
215 |
216 | keybcs2
217 | 0
218 |
219 |
220 | keybcs2
221 | 1
222 |
223 |
224 | koi8r
225 | 0
226 |
227 |
228 | koi8r
229 | 1
230 |
231 |
232 | koi8u
233 | 0
234 |
235 |
236 | koi8u
237 | 1
238 |
239 |
240 | latin1
241 | 0
242 |
243 |
244 | latin1
245 | 0
246 |
247 |
248 | latin1
249 | 0
250 |
251 |
252 | latin1
253 | 0
254 |
255 |
256 | latin1
257 | 0
258 |
259 |
260 | latin1
261 | 0
262 |
263 |
264 | latin1
265 | 0
266 |
267 |
268 | latin1
269 | 1
270 |
271 |
272 | latin2
273 | 0
274 |
275 |
276 | latin2
277 | 0
278 |
279 |
280 | latin2
281 | 0
282 |
283 |
284 | latin2
285 | 1
286 |
287 |
288 | latin2
289 | 0
290 |
291 |
292 | latin5
293 | 0
294 |
295 |
296 | latin5
297 | 1
298 |
299 |
300 | latin7
301 | 0
302 |
303 |
304 | latin7
305 | 0
306 |
307 |
308 | latin7
309 | 1
310 |
311 |
312 | latin7
313 | 0
314 |
315 |
316 | macce
317 | 0
318 |
319 |
320 | macce
321 | 1
322 |
323 |
324 | macroman
325 | 0
326 |
327 |
328 | macroman
329 | 1
330 |
331 |
332 | sjis
333 | 0
334 |
335 |
336 | sjis
337 | 1
338 |
339 |
340 | swe7
341 | 0
342 |
343 |
344 | swe7
345 | 1
346 |
347 |
348 | tis620
349 | 0
350 |
351 |
352 | tis620
353 | 1
354 |
355 |
356 | ucs2
357 | 0
358 |
359 |
360 | ucs2
361 | 0
362 |
363 |
364 | ucs2
365 | 0
366 |
367 |
368 | ucs2
369 | 0
370 |
371 |
372 | ucs2
373 | 0
374 |
375 |
376 | ucs2
377 | 0
378 |
379 |
380 | ucs2
381 | 1
382 |
383 |
384 | ucs2
385 | 0
386 |
387 |
388 | ucs2
389 | 0
390 |
391 |
392 | ucs2
393 | 0
394 |
395 |
396 | ucs2
397 | 0
398 |
399 |
400 | ucs2
401 | 0
402 |
403 |
404 | ucs2
405 | 0
406 |
407 |
408 | ucs2
409 | 0
410 |
411 |
412 | ucs2
413 | 0
414 |
415 |
416 | ucs2
417 | 0
418 |
419 |
420 | ucs2
421 | 0
422 |
423 |
424 | ucs2
425 | 0
426 |
427 |
428 | ucs2
429 | 0
430 |
431 |
432 | ucs2
433 | 0
434 |
435 |
436 | ucs2
437 | 0
438 |
439 |
440 | ucs2
441 | 0
442 |
443 |
444 | ucs2
445 | 0
446 |
447 |
448 | ucs2
449 | 0
450 |
451 |
452 | ucs2
453 | 0
454 |
455 |
456 | ucs2
457 | 0
458 |
459 |
460 | ucs2
461 | 0
462 |
463 |
464 | ujis
465 | 0
466 |
467 |
468 | ujis
469 | 1
470 |
471 |
472 | utf16
473 | 0
474 |
475 |
476 | utf16
477 | 0
478 |
479 |
480 | utf16
481 | 0
482 |
483 |
484 | utf16
485 | 0
486 |
487 |
488 | utf16
489 | 0
490 |
491 |
492 | utf16
493 | 0
494 |
495 |
496 | utf16
497 | 1
498 |
499 |
500 | utf16
501 | 0
502 |
503 |
504 | utf16
505 | 0
506 |
507 |
508 | utf16
509 | 0
510 |
511 |
512 | utf16
513 | 0
514 |
515 |
516 | utf16
517 | 0
518 |
519 |
520 | utf16
521 | 0
522 |
523 |
524 | utf16
525 | 0
526 |
527 |
528 | utf16
529 | 0
530 |
531 |
532 | utf16
533 | 0
534 |
535 |
536 | utf16
537 | 0
538 |
539 |
540 | utf16
541 | 0
542 |
543 |
544 | utf16
545 | 0
546 |
547 |
548 | utf16
549 | 0
550 |
551 |
552 | utf16
553 | 0
554 |
555 |
556 | utf16
557 | 0
558 |
559 |
560 | utf16
561 | 0
562 |
563 |
564 | utf16
565 | 0
566 |
567 |
568 | utf16
569 | 0
570 |
571 |
572 | utf16
573 | 0
574 |
575 |
576 | utf16le
577 | 0
578 |
579 |
580 | utf16le
581 | 1
582 |
583 |
584 | utf32
585 | 0
586 |
587 |
588 | utf32
589 | 0
590 |
591 |
592 | utf32
593 | 0
594 |
595 |
596 | utf32
597 | 0
598 |
599 |
600 | utf32
601 | 0
602 |
603 |
604 | utf32
605 | 0
606 |
607 |
608 | utf32
609 | 1
610 |
611 |
612 | utf32
613 | 0
614 |
615 |
616 | utf32
617 | 0
618 |
619 |
620 | utf32
621 | 0
622 |
623 |
624 | utf32
625 | 0
626 |
627 |
628 | utf32
629 | 0
630 |
631 |
632 | utf32
633 | 0
634 |
635 |
636 | utf32
637 | 0
638 |
639 |
640 | utf32
641 | 0
642 |
643 |
644 | utf32
645 | 0
646 |
647 |
648 | utf32
649 | 0
650 |
651 |
652 | utf32
653 | 0
654 |
655 |
656 | utf32
657 | 0
658 |
659 |
660 | utf32
661 | 0
662 |
663 |
664 | utf32
665 | 0
666 |
667 |
668 | utf32
669 | 0
670 |
671 |
672 | utf32
673 | 0
674 |
675 |
676 | utf32
677 | 0
678 |
679 |
680 | utf32
681 | 0
682 |
683 |
684 | utf32
685 | 0
686 |
687 |
688 | utf8
689 | 0
690 |
691 |
692 | utf8
693 | 0
694 |
695 |
696 | utf8
697 | 0
698 |
699 |
700 | utf8
701 | 0
702 |
703 |
704 | utf8
705 | 0
706 |
707 |
708 | utf8
709 | 0
710 |
711 |
712 | utf8
713 | 1
714 |
715 |
716 | utf8
717 | 0
718 |
719 |
720 | utf8
721 | 0
722 |
723 |
724 | utf8
725 | 0
726 |
727 |
728 | utf8
729 | 0
730 |
731 |
732 | utf8
733 | 0
734 |
735 |
736 | utf8
737 | 0
738 |
739 |
740 | utf8
741 | 0
742 |
743 |
744 | utf8
745 | 0
746 |
747 |
748 | utf8
749 | 0
750 |
751 |
752 | utf8
753 | 0
754 |
755 |
756 | utf8
757 | 0
758 |
759 |
760 | utf8
761 | 0
762 |
763 |
764 | utf8
765 | 0
766 |
767 |
768 | utf8
769 | 0
770 |
771 |
772 | utf8
773 | 0
774 |
775 |
776 | utf8
777 | 0
778 |
779 |
780 | utf8
781 | 0
782 |
783 |
784 | utf8
785 | 0
786 |
787 |
788 | utf8
789 | 0
790 |
791 |
792 | utf8
793 | 0
794 |
795 |
796 | utf8mb4
797 | 0
798 |
799 |
800 | utf8mb4
801 | 0
802 |
803 |
804 | utf8mb4
805 | 0
806 |
807 |
808 | utf8mb4
809 | 0
810 |
811 |
812 | utf8mb4
813 | 0
814 |
815 |
816 | utf8mb4
817 | 0
818 |
819 |
820 | utf8mb4
821 | 1
822 |
823 |
824 | utf8mb4
825 | 0
826 |
827 |
828 | utf8mb4
829 | 0
830 |
831 |
832 | utf8mb4
833 | 0
834 |
835 |
836 | utf8mb4
837 | 0
838 |
839 |
840 | utf8mb4
841 | 0
842 |
843 |
844 | utf8mb4
845 | 0
846 |
847 |
848 | utf8mb4
849 | 0
850 |
851 |
852 | utf8mb4
853 | 0
854 |
855 |
856 | utf8mb4
857 | 0
858 |
859 |
860 | utf8mb4
861 | 0
862 |
863 |
864 | utf8mb4
865 | 0
866 |
867 |
868 | utf8mb4
869 | 0
870 |
871 |
872 | utf8mb4
873 | 0
874 |
875 |
876 | utf8mb4
877 | 0
878 |
879 |
880 | utf8mb4
881 | 0
882 |
883 |
884 | utf8mb4
885 | 0
886 |
887 |
888 | utf8mb4
889 | 0
890 |
891 |
892 | utf8mb4
893 | 0
894 |
895 |
896 | utf8mb4
897 | 0
898 |
899 |
900 |
--------------------------------------------------------------------------------
/.idea/dataSources/eb023117-f844-4ce6-9dbd-5a03c863f17a.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | 5.6.43
6 | InnoDB
7 | InnoDB
8 | exact
9 |
10 |
11 | utf8_bin
12 |
13 |
14 | utf8_general_ci
15 |
16 |
17 | utf8_general_ci
18 |
19 |
20 | utf8_bin
21 |
22 |
23 | 1
24 | utf8_general_ci
25 |
26 |
27 | utf8_general_ci
28 |
29 |
30 | utf8_general_ci
31 |
32 |
33 | armscii8
34 | 0
35 |
36 |
37 | armscii8
38 | 1
39 |
40 |
41 | ascii
42 | 0
43 |
44 |
45 | ascii
46 | 1
47 |
48 |
49 | big5
50 | 0
51 |
52 |
53 | big5
54 | 1
55 |
56 |
57 | binary
58 | 1
59 |
60 |
61 | cp1250
62 | 0
63 |
64 |
65 | cp1250
66 | 0
67 |
68 |
69 | cp1250
70 | 0
71 |
72 |
73 | cp1250
74 | 1
75 |
76 |
77 | cp1250
78 | 0
79 |
80 |
81 | cp1251
82 | 0
83 |
84 |
85 | cp1251
86 | 0
87 |
88 |
89 | cp1251
90 | 1
91 |
92 |
93 | cp1251
94 | 0
95 |
96 |
97 | cp1251
98 | 0
99 |
100 |
101 | cp1256
102 | 0
103 |
104 |
105 | cp1256
106 | 1
107 |
108 |
109 | cp1257
110 | 0
111 |
112 |
113 | cp1257
114 | 1
115 |
116 |
117 | cp1257
118 | 0
119 |
120 |
121 | cp850
122 | 0
123 |
124 |
125 | cp850
126 | 1
127 |
128 |
129 | cp852
130 | 0
131 |
132 |
133 | cp852
134 | 1
135 |
136 |
137 | cp866
138 | 0
139 |
140 |
141 | cp866
142 | 1
143 |
144 |
145 | cp932
146 | 0
147 |
148 |
149 | cp932
150 | 1
151 |
152 |
153 | dec8
154 | 0
155 |
156 |
157 | dec8
158 | 1
159 |
160 |
161 | eucjpms
162 | 0
163 |
164 |
165 | eucjpms
166 | 1
167 |
168 |
169 | euckr
170 | 0
171 |
172 |
173 | euckr
174 | 1
175 |
176 |
177 | gb2312
178 | 0
179 |
180 |
181 | gb2312
182 | 1
183 |
184 |
185 | gbk
186 | 0
187 |
188 |
189 | gbk
190 | 1
191 |
192 |
193 | geostd8
194 | 0
195 |
196 |
197 | geostd8
198 | 1
199 |
200 |
201 | greek
202 | 0
203 |
204 |
205 | greek
206 | 1
207 |
208 |
209 | hebrew
210 | 0
211 |
212 |
213 | hebrew
214 | 1
215 |
216 |
217 | hp8
218 | 0
219 |
220 |
221 | hp8
222 | 1
223 |
224 |
225 | keybcs2
226 | 0
227 |
228 |
229 | keybcs2
230 | 1
231 |
232 |
233 | koi8r
234 | 0
235 |
236 |
237 | koi8r
238 | 1
239 |
240 |
241 | koi8u
242 | 0
243 |
244 |
245 | koi8u
246 | 1
247 |
248 |
249 | latin1
250 | 0
251 |
252 |
253 | latin1
254 | 0
255 |
256 |
257 | latin1
258 | 0
259 |
260 |
261 | latin1
262 | 0
263 |
264 |
265 | latin1
266 | 0
267 |
268 |
269 | latin1
270 | 0
271 |
272 |
273 | latin1
274 | 0
275 |
276 |
277 | latin1
278 | 1
279 |
280 |
281 | latin2
282 | 0
283 |
284 |
285 | latin2
286 | 0
287 |
288 |
289 | latin2
290 | 0
291 |
292 |
293 | latin2
294 | 1
295 |
296 |
297 | latin2
298 | 0
299 |
300 |
301 | latin5
302 | 0
303 |
304 |
305 | latin5
306 | 1
307 |
308 |
309 | latin7
310 | 0
311 |
312 |
313 | latin7
314 | 0
315 |
316 |
317 | latin7
318 | 1
319 |
320 |
321 | latin7
322 | 0
323 |
324 |
325 | macce
326 | 0
327 |
328 |
329 | macce
330 | 1
331 |
332 |
333 | macroman
334 | 0
335 |
336 |
337 | macroman
338 | 1
339 |
340 |
341 | sjis
342 | 0
343 |
344 |
345 | sjis
346 | 1
347 |
348 |
349 | swe7
350 | 0
351 |
352 |
353 | swe7
354 | 1
355 |
356 |
357 | tis620
358 | 0
359 |
360 |
361 | tis620
362 | 1
363 |
364 |
365 | ucs2
366 | 0
367 |
368 |
369 | ucs2
370 | 0
371 |
372 |
373 | ucs2
374 | 0
375 |
376 |
377 | ucs2
378 | 0
379 |
380 |
381 | ucs2
382 | 0
383 |
384 |
385 | ucs2
386 | 0
387 |
388 |
389 | ucs2
390 | 1
391 |
392 |
393 | ucs2
394 | 0
395 |
396 |
397 | ucs2
398 | 0
399 |
400 |
401 | ucs2
402 | 0
403 |
404 |
405 | ucs2
406 | 0
407 |
408 |
409 | ucs2
410 | 0
411 |
412 |
413 | ucs2
414 | 0
415 |
416 |
417 | ucs2
418 | 0
419 |
420 |
421 | ucs2
422 | 0
423 |
424 |
425 | ucs2
426 | 0
427 |
428 |
429 | ucs2
430 | 0
431 |
432 |
433 | ucs2
434 | 0
435 |
436 |
437 | ucs2
438 | 0
439 |
440 |
441 | ucs2
442 | 0
443 |
444 |
445 | ucs2
446 | 0
447 |
448 |
449 | ucs2
450 | 0
451 |
452 |
453 | ucs2
454 | 0
455 |
456 |
457 | ucs2
458 | 0
459 |
460 |
461 | ucs2
462 | 0
463 |
464 |
465 | ucs2
466 | 0
467 |
468 |
469 | ucs2
470 | 0
471 |
472 |
473 | ujis
474 | 0
475 |
476 |
477 | ujis
478 | 1
479 |
480 |
481 | utf16
482 | 0
483 |
484 |
485 | utf16
486 | 0
487 |
488 |
489 | utf16
490 | 0
491 |
492 |
493 | utf16
494 | 0
495 |
496 |
497 | utf16
498 | 0
499 |
500 |
501 | utf16
502 | 0
503 |
504 |
505 | utf16
506 | 1
507 |
508 |
509 | utf16
510 | 0
511 |
512 |
513 | utf16
514 | 0
515 |
516 |
517 | utf16
518 | 0
519 |
520 |
521 | utf16
522 | 0
523 |
524 |
525 | utf16
526 | 0
527 |
528 |
529 | utf16
530 | 0
531 |
532 |
533 | utf16
534 | 0
535 |
536 |
537 | utf16
538 | 0
539 |
540 |
541 | utf16
542 | 0
543 |
544 |
545 | utf16
546 | 0
547 |
548 |
549 | utf16
550 | 0
551 |
552 |
553 | utf16
554 | 0
555 |
556 |
557 | utf16
558 | 0
559 |
560 |
561 | utf16
562 | 0
563 |
564 |
565 | utf16
566 | 0
567 |
568 |
569 | utf16
570 | 0
571 |
572 |
573 | utf16
574 | 0
575 |
576 |
577 | utf16
578 | 0
579 |
580 |
581 | utf16
582 | 0
583 |
584 |
585 | utf16le
586 | 0
587 |
588 |
589 | utf16le
590 | 1
591 |
592 |
593 | utf32
594 | 0
595 |
596 |
597 | utf32
598 | 0
599 |
600 |
601 | utf32
602 | 0
603 |
604 |
605 | utf32
606 | 0
607 |
608 |
609 | utf32
610 | 0
611 |
612 |
613 | utf32
614 | 0
615 |
616 |
617 | utf32
618 | 1
619 |
620 |
621 | utf32
622 | 0
623 |
624 |
625 | utf32
626 | 0
627 |
628 |
629 | utf32
630 | 0
631 |
632 |
633 | utf32
634 | 0
635 |
636 |
637 | utf32
638 | 0
639 |
640 |
641 | utf32
642 | 0
643 |
644 |
645 | utf32
646 | 0
647 |
648 |
649 | utf32
650 | 0
651 |
652 |
653 | utf32
654 | 0
655 |
656 |
657 | utf32
658 | 0
659 |
660 |
661 | utf32
662 | 0
663 |
664 |
665 | utf32
666 | 0
667 |
668 |
669 | utf32
670 | 0
671 |
672 |
673 | utf32
674 | 0
675 |
676 |
677 | utf32
678 | 0
679 |
680 |
681 | utf32
682 | 0
683 |
684 |
685 | utf32
686 | 0
687 |
688 |
689 | utf32
690 | 0
691 |
692 |
693 | utf32
694 | 0
695 |
696 |
697 | utf8
698 | 0
699 |
700 |
701 | utf8
702 | 0
703 |
704 |
705 | utf8
706 | 0
707 |
708 |
709 | utf8
710 | 0
711 |
712 |
713 | utf8
714 | 0
715 |
716 |
717 | utf8
718 | 0
719 |
720 |
721 | utf8
722 | 1
723 |
724 |
725 | utf8
726 | 0
727 |
728 |
729 | utf8
730 | 0
731 |
732 |
733 | utf8
734 | 0
735 |
736 |
737 | utf8
738 | 0
739 |
740 |
741 | utf8
742 | 0
743 |
744 |
745 | utf8
746 | 0
747 |
748 |
749 | utf8
750 | 0
751 |
752 |
753 | utf8
754 | 0
755 |
756 |
757 | utf8
758 | 0
759 |
760 |
761 | utf8
762 | 0
763 |
764 |
765 | utf8
766 | 0
767 |
768 |
769 | utf8
770 | 0
771 |
772 |
773 | utf8
774 | 0
775 |
776 |
777 | utf8
778 | 0
779 |
780 |
781 | utf8
782 | 0
783 |
784 |
785 | utf8
786 | 0
787 |
788 |
789 | utf8
790 | 0
791 |
792 |
793 | utf8
794 | 0
795 |
796 |
797 | utf8
798 | 0
799 |
800 |
801 | utf8
802 | 0
803 |
804 |
805 | utf8mb4
806 | 0
807 |
808 |
809 | utf8mb4
810 | 0
811 |
812 |
813 | utf8mb4
814 | 0
815 |
816 |
817 | utf8mb4
818 | 0
819 |
820 |
821 | utf8mb4
822 | 0
823 |
824 |
825 | utf8mb4
826 | 0
827 |
828 |
829 | utf8mb4
830 | 1
831 |
832 |
833 | utf8mb4
834 | 0
835 |
836 |
837 | utf8mb4
838 | 0
839 |
840 |
841 | utf8mb4
842 | 0
843 |
844 |
845 | utf8mb4
846 | 0
847 |
848 |
849 | utf8mb4
850 | 0
851 |
852 |
853 | utf8mb4
854 | 0
855 |
856 |
857 | utf8mb4
858 | 0
859 |
860 |
861 | utf8mb4
862 | 0
863 |
864 |
865 | utf8mb4
866 | 0
867 |
868 |
869 | utf8mb4
870 | 0
871 |
872 |
873 | utf8mb4
874 | 0
875 |
876 |
877 | utf8mb4
878 | 0
879 |
880 |
881 | utf8mb4
882 | 0
883 |
884 |
885 | utf8mb4
886 | 0
887 |
888 |
889 | utf8mb4
890 | 0
891 |
892 |
893 | utf8mb4
894 | 0
895 |
896 |
897 | utf8mb4
898 | 0
899 |
900 |
901 | utf8mb4
902 | 0
903 |
904 |
905 | utf8mb4
906 | 0
907 |
908 |
909 |
--------------------------------------------------------------------------------
/.idea/encodings.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/inspectionProfiles/Project_Default.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
14 |
15 |
16 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/other.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
--------------------------------------------------------------------------------
/.idea/transformer-master.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
153 |
154 |
155 |
156 | ”“
157 | ga
158 | elems
159 | bleu score
160 | self.tf.
161 | evaldir
162 | (
163 | oovs
164 | memory
165 | 苏泊尔被曝锰含量超标近4倍可致帕金森病
166 | LSCTS
167 | ckpt = self.tf.train.get_checkpoint_state(self.model_dir).all_model_checkpoint_paths[-1]
168 | gpu_list
169 | gpu_nums
170 | hp.eval3
171 |
172 | num_eval_batches
173 | tf
174 | _gs
175 | i
176 | eval_rouge
177 | num_train_samples
178 | iter
179 | validation
180 | iterator
181 | 隋炀帝墓135件文物首次展出杨广2颗牙齿亮相
182 | 中医药为何逐渐被边缘化
183 | self.vocab_file
184 | output_shapes
185 | _encode
186 |
187 |
188 | tf.
189 | LCSTS
190 | hp.test
191 | hp.eval_rouge
192 | iterator
193 | val
194 | iter
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 |
366 |
367 |
368 |
369 |
370 |
371 |
372 |
373 |
374 |
375 |
376 |
377 |
378 |
379 |
380 |
381 |
382 |
383 |
384 |
385 |
386 |
387 |
388 |
389 |
390 |
391 |
392 |
393 |
394 |
395 |
396 |
397 |
398 |
399 |
400 |
401 |
402 |
403 |
404 |
405 |
406 |
407 |
408 |
409 |
410 |
411 |
412 |
413 |
414 |
415 |
416 |
417 |
418 |
419 |
420 |
421 |
422 |
423 |
424 |
425 |
426 |
427 |
428 |
429 |
430 |
431 |
432 |
433 |
434 |
435 |
436 |
437 |
438 |
439 |
440 |
441 |
442 |
443 |
444 |
445 |
446 |
447 |
448 |
449 |
450 |
451 |
452 |
453 |
454 |
455 |
456 |
457 |
458 |
459 |
460 |
461 |
462 |
463 |
464 |
465 |
466 |
467 |
468 |
469 |
470 |
471 |
472 |
473 |
474 |
475 |
476 |
477 |
478 |
479 |
480 |
481 |
482 |
483 |
484 |
485 |
486 |
487 |
488 |
489 |
490 |
491 |
492 |
493 |
494 |
495 | 1552362774445
496 |
497 |
498 | 1552362774445
499 |
500 |
501 |
502 |
503 |
504 |
505 |
506 |
507 |
508 |
509 |
510 |
511 |
512 |
513 |
514 |
515 |
516 |
517 |
518 |
519 |
520 |
521 |
522 |
523 |
524 |
525 |
526 |
527 |
528 |
529 |
530 |
531 |
532 |
533 | 1557365218479
534 |
535 |
536 |
537 | 1557365218479
538 |
539 |
540 | 1557972785135
541 |
542 |
543 |
544 | 1557972785135
545 |
546 |
547 | 1558085745605
548 |
549 |
550 |
551 | 1558085745605
552 |
553 |
554 | 1558337002489
555 |
556 |
557 |
558 | 1558337002489
559 |
560 |
561 | 1558348841756
562 |
563 |
564 |
565 | 1558348841756
566 |
567 |
568 | 1558401210881
569 |
570 |
571 |
572 | 1558401210881
573 |
574 |
575 | 1558408517378
576 |
577 |
578 |
579 | 1558408517378
580 |
581 |
582 | 1558409466718
583 |
584 |
585 |
586 | 1558409466718
587 |
588 |
589 | 1558409636201
590 |
591 |
592 |
593 | 1558409636201
594 |
595 |
596 | 1558409968858
597 |
598 |
599 |
600 | 1558409968858
601 |
602 |
603 | 1558410453262
604 |
605 |
606 |
607 | 1558410453262
608 |
609 |
610 | 1558410517294
611 |
612 |
613 |
614 | 1558410517294
615 |
616 |
617 | 1558422819896
618 |
619 |
620 |
621 | 1558422819896
622 |
623 |
624 | 1558493250965
625 |
626 |
627 |
628 | 1558493250965
629 |
630 |
631 | 1558493509793
632 |
633 |
634 |
635 | 1558493509793
636 |
637 |
638 | 1558692358913
639 |
640 |
641 |
642 | 1558692358913
643 |
644 |
645 | 1558692586145
646 |
647 |
648 |
649 | 1558692586145
650 |
651 |
652 | 1558768389765
653 |
654 |
655 |
656 | 1558768389765
657 |
658 |
659 | 1558768991425
660 |
661 |
662 |
663 | 1558768991425
664 |
665 |
666 | 1558777069599
667 |
668 |
669 |
670 | 1558777069599
671 |
672 |
673 | 1558778049864
674 |
675 |
676 |
677 | 1558778049864
678 |
679 |
680 | 1558780427260
681 |
682 |
683 |
684 | 1558780427260
685 |
686 |
687 |
688 |
689 |
690 |
691 |
692 |
693 |
694 |
695 |
696 |
697 |
698 |
699 |
700 |
701 |
702 |
703 |
704 |
705 |
706 |
707 |
708 |
709 |
710 |
711 |
712 |
713 |
714 |
715 |
716 |
717 |
718 |
719 |
720 |
721 |
722 |
723 |
724 |
725 |
726 |
727 |
728 |
729 |
730 |
731 |
732 |
733 |
734 |
735 |
736 |
737 |
738 |
739 |
740 |
741 |
742 |
743 |
744 |
745 |
746 |
747 |
748 |
749 |
750 |
751 |
752 |
753 |
754 |
755 |
756 |
757 |
758 |
759 |
760 | file://$PROJECT_DIR$/utils.py
761 | 143
762 |
763 |
764 |
765 |
766 |
767 |
768 |
769 |
770 |
771 |
772 |
773 |
774 |
775 |
776 |
777 |
778 |
779 |
780 |
781 |
782 |
783 |
784 |
785 |
786 |
787 |
788 |
789 |
790 |
791 |
792 |
793 |
794 |
795 |
796 |
797 |
798 |
799 |
800 |
801 |
802 |
803 |
804 |
805 |
806 |
807 |
808 |
809 |
810 |
811 |
812 |
813 |
814 |
815 |
816 |
817 |
818 |
819 |
820 |
821 |
822 |
823 |
824 |
825 |
826 |
827 |
828 |
829 |
830 |
831 |
832 |
833 |
834 |
835 |
836 |
837 |
838 |
839 |
840 |
841 |
842 |
843 |
844 |
845 |
846 |
847 |
848 |
849 |
850 |
851 |
852 |
853 |
854 |
855 |
856 |
857 |
858 |
859 |
860 |
861 |
862 |
863 |
864 |
865 |
866 |
867 |
868 |
869 |
870 |
871 |
872 |
873 |
874 |
875 |
876 |
877 |
878 |
879 |
880 |
881 |
882 |
883 |
884 |
885 |
886 |
887 |
888 |
889 |
890 |
891 |
892 |
893 |
894 |
895 |
896 |
897 |
898 |
899 |
900 |
901 |
902 |
903 |
904 |
905 |
906 |
907 |
908 |
909 |
910 |
911 |
912 |
913 |
914 |
915 |
916 |
917 |
918 |
919 |
920 |
921 |
922 |
923 |
924 |
925 |
926 |
927 |
928 |
929 |
930 |
931 |
932 |
933 |
934 |
935 |
936 |
937 |
938 |
939 |
940 |
941 |
942 |
943 |
944 |
945 |
946 |
947 |
948 |
949 |
950 |
951 |
952 |
953 |
954 |
955 |
956 |
957 |
958 |
959 |
960 |
961 |
962 |
963 |
964 |
965 |
966 |
967 |
968 |
969 |
970 |
971 |
972 |
973 |
974 |
975 |
976 |
977 |
978 |
979 |
980 |
981 |
982 |
983 |
984 |
985 |
986 |
987 |
988 |
989 |
990 |
991 |
992 |
993 |
994 |
995 |
996 |
997 |
998 |
999 |
1000 |
1001 |
1002 |
1003 |
1004 |
1005 |
1006 |
1007 |
1008 |
1009 |
1010 |
1011 |
1012 |
1013 |
1014 |
1015 |
1016 |
1017 |
1018 |
1019 |
1020 |
1021 |
1022 |
1023 |
1024 |
1025 |
1026 |
1027 |
1028 |
1029 |
1030 |
1031 |
1032 |
1033 |
1034 |
1035 |
1036 |
1037 |
1038 |
1039 |
1040 |
1041 |
1042 |
1043 |
1044 |
1045 |
1046 |
1047 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Cally
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # An Abstractive Summarization Implementation with Transformer and Pointer-Generator
2 | When I wanted to generate summaries with a neural network, I tried many ways to produce abstractive summaries, but the results were not good.
3 | When I heard about the 2018 Byte Cup, I found some information about it, and the champion's solution attracted me. However, after searching websites
4 | like GitHub and GitLab, I didn't find the official code, so I decided to implement it myself.
5 |
6 | ## Requirements
7 | * python==3.x (Let's move on to python 3 if you still use python 2)
8 | * tensorflow==1.12.0
9 | * tqdm>=4.28.1
10 | * jieba>=0.3x
11 | * sumeval>=0.2.0
12 |
13 | ## Model Structure
14 | ### Based
15 | My model is based on [Attention Is All You Need](https://arxiv.org/abs/1706.03762) and [Get To The Point: Summarization with Pointer-Generator Networks](https://arxiv.org/abs/1704.04368)
16 | ### Change
17 | * The pointer-generator model has two mechanisms, the **copy mechanism** and the **coverage mechanism**. I found some materials
18 | showing that the coverage mechanism doesn't suit short summaries, so I didn't use that mechanism and only used the first one.
19 | * The pointer-generator model has an inadequacy that can make the loss become NaN. I tried several times to fix it,
20 | but I couldn't. I think the reason is that when calculating the final logits, it
21 | extends the vocab length to the OOV-plus-vocab length, which produces more zeroes. So I removed the mechanism that extends the final logits and just kept the mechanism of
22 | decoding from the article and the vocab. There is more [detail](https://github.com/abisee/pointer-generator/issues/4) about it;
23 | in this model, I just use words from the vocab, an idea that comes from BERT.
24 | ### Structure
25 |
26 |
27 | ## Training
28 | * STEP 1. [download](https://pan.baidu.com/s/1szq0Wa60AS5ISpM_SNPcbA) the dataset (password: ayn6). The dataset is LCSTS after pre-processing, so in the file you will see a very different structure from the original LCSTS.
29 | Each line contains an abstract and an article, split by **","**. If you worry that the amount of data differs between mine and LCSTS, don't
30 | worry — the amount of data is the same as LCSTS.
31 | * STEP 2. Run the following command.
32 | ```
33 | python train.py
34 | ```
35 | Check `hparams.py` to see which parameters are possible. For example,
36 | ```
37 | python train.py --logdir myLog --batch_size 32 --train myTrain --eval myEval
38 | ```
39 | My code also supports multi-GPU training for this model; if you have more than one GPU, just run it like this:
40 | ```
41 | python train.py --logdir myLog --batch_size 32 --train myTrain --eval myEval --gpu_nums=myGPUNums
42 | ```
43 |
44 | | name | type | detail |
45 | |--------------------|------|-------------|
46 | vocab_size | int | vocab size
47 | train | str | train dataset dir
48 | eval | str| eval dataset dir
49 | test | str| data for calculate rouge score
50 | vocab | str| vocabulary file path
51 | batch_size | int| train batch size
52 | eval_batch_size | int| eval batch size
53 | lr | float| learning rate
54 | warmup_steps | int| warmup steps for the learning rate
55 | logdir | str| log directory
56 | num_epochs | int| the number of train epoch
57 | evaldir | str| evaluation dir
58 | d_model | int| hidden dimension of encoder/decoder
59 | d_ff | int| hidden dimension of feedforward layer
60 | num_blocks | int| number of encoder/decoder blocks
61 | num_heads | int| number of attention heads
62 | maxlen1 | int| maximum length of a source sequence
63 | maxlen2 | int| maximum length of a target sequence
64 | dropout_rate | float| dropout rate
65 | beam_size | int| beam size for decode
66 | gpu_nums | int| gpu amount, which can allow how many gpu to train this model, default 1
67 | ### Note
68 | Don't change the hyper-parameters of the transformer unless you have a good reason — otherwise the loss may not go down! If you find a better configuration, I hope you can tell me.
69 |
70 | ## Evaluation
71 | ### Loss
72 | * Transformer-Pointer generator
73 |
74 | * Transformer
75 |
76 | As you see, transformer-pointer generator model can let the loss go down very quickly!
77 |
78 | ## If you like it, and think it useful for you, hope you can star.
--------------------------------------------------------------------------------
/__pycache__/beam_search.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiongma/transformer-pointer-generator/6547cd40aa7aaa3802081b14b4d0490493cd794d/__pycache__/beam_search.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/data_load.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiongma/transformer-pointer-generator/6547cd40aa7aaa3802081b14b4d0490493cd794d/__pycache__/data_load.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/hparams.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiongma/transformer-pointer-generator/6547cd40aa7aaa3802081b14b4d0490493cd794d/__pycache__/hparams.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiongma/transformer-pointer-generator/6547cd40aa7aaa3802081b14b4d0490493cd794d/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/modules.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiongma/transformer-pointer-generator/6547cd40aa7aaa3802081b14b4d0490493cd794d/__pycache__/modules.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/utils.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiongma/transformer-pointer-generator/6547cd40aa7aaa3802081b14b4d0490493cd794d/__pycache__/utils.cpython-36.pyc
--------------------------------------------------------------------------------
/beam_search.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/python3
3 | '''
4 | date: 2019/5/21
5 | mail: cally.maxiong@gmail.com
6 | page: http://www.cnblogs.com/callyblog/
7 | '''
8 |
9 | import tensorflow as tf
10 | # add self to decode memory
class Hypothesis:
    """A single partial decoding result tracked during beam search."""

    def __init__(self, tokens, log_prob, sents, normalize_by_length=True):
        """
        :param tokens: list of token ids decoded so far
        :param log_prob: accumulated log probability of this hypothesis
        :param sents: words decoded so far, as one string
        :param normalize_by_length: if True, rank hypotheses by prob / length
        """
        self.tokens = tokens
        self.log_prob = log_prob
        self.sents = sents
        self.normalize_by_length = normalize_by_length

    def extend(self, token, log_prob, word):
        """Return a new Hypothesis grown by one decoding step.

        :param token: token id produced by the latest step
        :param log_prob: log prob of that token
        :param word: surface form of the decoded token
        :return: a fresh Hypothesis; the current one is left untouched
        """
        return Hypothesis(self.tokens + [token],
                          self.log_prob + log_prob,
                          self.sents + word)

    @property
    def latest_token(self):
        """Most recently decoded token id."""
        return self.tokens[-1]

    def __str__(self):
        return ''.join(self.sents)
44 |
class BeamSearch:
    """Beam-search decoder driven through a TensorFlow session.

    Repeatedly feeds the growing hypotheses back through the decoder graph
    and keeps the `beam_size` best continuations at each step.
    """
    def __init__(self, model, beam_size, start_token, end_token, id2token, max_steps, input_x, input_y, logits,
                 normalize_by_length=False):
        """
        :param model: transformer model (provides the `memory` placeholder fed in `search`)
        :param beam_size: beam size
        :param start_token: id of the sequence-start token
        :param end_token: id of the sequence-end token
        :param id2token: id -> token dict
        :param max_steps: maximum number of decode steps
        :param input_x: encoder input placeholder
        :param input_y: decoder input placeholder
        :param logits: decoder output logits tensor
        :param normalize_by_length: sort hypotheses by prob / len; if False, by prob only
        """
        # basic params
        self.model = model
        self.beam_size = beam_size
        self.start_token = start_token
        self.end_token = end_token
        self.max_steps = max_steps
        self.id2token = id2token

        # placeholder
        self.input_x = input_x
        self.input_y = input_y

        # take the top 2*beam_size candidates per step so the beam can be
        # refilled even when some hypotheses finish with the end token
        self.top_k_ = tf.nn.top_k(tf.nn.softmax(logits), k=self.beam_size * 2)

        # This length normalization is only effective for the final results.
        self.normalize_by_length = normalize_by_length

    def search(self, sess, input_x, memory):
        """
        use beam search for decoding
        :param sess: tensorflow session
        :param input_x: article as a list of ids (already converted via vocab)
        :param memory: transformer encoder output for this article
        :return: list of Hypothesis, the best hypotheses found by beam search,
                 ordered by score

        NOTE(review): the values accumulated into Hypothesis.log_prob are
        softmax probabilities, not log-probabilities — the ranking is still
        monotone per step, but it is not a true log-likelihood; confirm intent.
        """
        # every beam slot starts from the same start-token-only hypothesis;
        # sharing one instance is safe because extend() returns new objects
        hyps = [Hypothesis([self.start_token], 0.0, '')] * self.beam_size

        results = []
        steps = 0
        while steps < self.max_steps and len(results) < self.beam_size:
            # one decode step for the whole beam at once
            top_k = sess.run([self.top_k_], feed_dict={self.model.memory: [memory] * self.beam_size,
                                                       self.input_x: [input_x] * self.beam_size,
                                                       self.input_y: [h.tokens for h in hyps]})
            # keep only the candidates of the latest time step ([-1])
            indices = [list(indice[-1]) for indice in top_k[0][1]]
            probs = [list(prob[-1]) for prob in top_k[0][0]]

            all_hyps = []

            # at step 0 all hypotheses are identical, so expand only one
            num_orig_hyps = 1 if steps == 0 else len(hyps)
            for i in range(num_orig_hyps):
                h = hyps[i]
                for j in range(self.beam_size*2):
                    new_h = h.extend(indices[i][j], probs[i][j], self.id2token[indices[i][j]])
                    all_hyps.append(new_h)

            # Filter and collect any hypotheses that have the end token
            hyps = []
            for h in self.best_hyps(all_hyps):
                if h.latest_token == self.end_token:
                    # Pull the hypothesis off the beam if the end token is reached.
                    results.append(h)
                else:
                    # Otherwise continue to the extend the hypothesis.
                    hyps.append(h)
                if len(hyps) == self.beam_size or len(results) == self.beam_size:
                    break

            steps += 1

        # ran out of steps: treat the unfinished hypotheses as results too
        if steps == self.max_steps:
            results.extend(hyps)

        return self.best_hyps(results)

    def best_hyps(self, hyps):
        """
        Sort the hyps based on log probs and length.
        :param hyps: A list of hypothesis
        :return: A list of sorted hypothesis in reverse log_prob order.
        """
        # This length normalization is only effective for the final results.
        if self.normalize_by_length:
            return sorted(hyps, key=lambda h: h.log_prob / len(h.tokens), reverse=True)
        else:
            return sorted(hyps, key=lambda h: h.log_prob, reverse=True)
--------------------------------------------------------------------------------
/data_load.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python3
3 | '''
4 | date: 2019/5/21
5 | mail: cally.maxiong@gmail.com
6 | page: http://www.cnblogs.com/callyblog/
7 | '''
8 |
9 | import tensorflow as tf
10 | from utils import calc_num_batches
11 |
def _load_vocab(vocab_fpath):
    '''Load the vocabulary file and build id<->token lookup tables.

    vocab_fpath: string. vocabulary file path, one token per line;
        line order defines the ids (line 0 -> id 0, line 1 -> id 1, ...).

    Returns
    two dictionaries: token -> idx and idx -> token.
    '''
    with open(vocab_fpath, 'r', encoding='utf-8') as f:
        tokens = [line.replace('\n', '') for line in f]

    token2idx = {tok: i for i, tok in enumerate(tokens)}
    idx2token = dict(enumerate(tokens))

    return token2idx, idx2token
29 |
def load_stop(vocab_path):
    """Read a stop-word file (one word per line).

    :param vocab_path: path to the stop-word file
    :return: stop words sorted by length, longest first
    """
    with open(vocab_path, 'r', encoding='utf-8') as f:
        words = [line.replace('\n', '') for line in f]

    # longest-first so that longer stop words are matched before their prefixes
    words.sort(key=len, reverse=True)
    return words
42 |
def _load_data(fpaths, maxlen1, maxlen2):
    '''Load (abstract, article) pairs and drop over-length samples.

    fpaths: '|'-separated list of csv file paths; each line is
        "<abstract>,<article>". Lines that do not split into exactly
        two fields on ',' are skipped.
    maxlen1: source (article) sentence maximum length. scalar.
    maxlen2: target (abstract) sentence maximum length. scalar.

    Returns
    sents1: list of utf-8 encoded source sentences (articles)
    sents2: list of utf-8 encoded target sentences (abstracts)
    Both lists are capped at 400000 samples.
    '''
    articles, abstracts = [], []
    for path in fpaths.split('|'):
        with open(path, 'r', encoding='utf-8') as fin:
            for line in fin:
                parts = line.split(',')
                if len(parts) != 2:
                    continue
                article = parts[1].replace('\n', '').strip()
                abstract = parts[0].strip()
                # leave room for the special tokens added later by _encode
                if len(article) + 1 > maxlen1 - 2:
                    continue
                if len(abstract) + 1 > maxlen2 - 1:
                    continue

                articles.append(article.encode('utf-8'))
                abstracts.append(abstract.encode('utf-8'))

    return articles[:400000], abstracts[:400000]
69 |
def _encode(inp, token2idx, maxlen, type):
    '''Convert a byte string to a list of token ids. Used for `generator_fn`.

    inp: 1d byte array (utf-8 encoded sentence).
    token2idx: token -> id dictionary.
    maxlen: pad the output out to this length.
    type: "x" (source side); anything else is treated as the target side.

    NOTE(review): the special tokens here are all the empty string '' —
    presumably stripped <s>/</s>/<pad>/<unk> markers; verify against the
    vocabulary file.

    Returns
    for "x": a single list of ids; otherwise a (decoder_inputs, target) pair.
    '''
    def to_id(tok):
        # unknown tokens fall back to the id of ''
        return token2idx.get(tok, token2idx[''])

    chars = list(inp.decode('utf-8'))

    if type == 'x':
        tokens = [''] + chars + ['']
        tokens += [''] * (maxlen - len(tokens))
        return [to_id(tok) for tok in tokens]

    inputs = [''] + chars
    target = chars + ['']
    n_pad = maxlen - len(target)
    if n_pad > 0:
        inputs += [''] * n_pad
        target += [''] * n_pad
    return [to_id(tok) for tok in inputs], [to_id(tok) for tok in target]
93 |
def _generator_fn(sents1, sents2, vocab_fpath, maxlen1, maxlen2):
    '''Generate training / evaluation samples one at a time.

    sents1: list of source sents (utf-8 byte strings)
    sents2: list of target sents (utf-8 byte strings)
    vocab_fpath: string. vocabulary file path.
    maxlen1 / maxlen2: padded lengths for source / target.

    yields
    (x, sent1): encoded source ids and the raw source sentence
    (decoder_inputs, targets, sent2): encoded decoder inputs, encoded
        targets, and the raw target sentence
    '''
    token2idx, _ = _load_vocab(vocab_fpath)
    for src, tgt in zip(sents1, sents2):
        x = _encode(src, token2idx, maxlen1, "x")
        decoder_inputs, targets = _encode(tgt, token2idx, maxlen2, "y")
        yield (x, src.decode('utf-8')), (decoder_inputs, targets, tgt.decode('utf-8'))
118 |
def _input_fn(sents1, sents2, vocab_fpath, batch_size, gpu_nums, maxlen1, maxlen2, shuffle=False):
    '''Wrap the sample generator into a batched tf.data.Dataset.

    sents1: list of source sents
    sents2: list of target sents
    vocab_fpath: string. vocabulary file path.
    batch_size: per-GPU batch size. scalar.
    gpu_nums: number of GPUs; the dataset is batched batch_size*gpu_nums
        wide so it can be split across devices later.
    maxlen1 / maxlen2: padded source / target lengths.
    shuffle: boolean, True for training.

    Returns
    an infinitely repeating dataset of
        ((x, sent1), (decoder_input, y, sent2)) batches, where
        x: int32 (N, T1), decoder_input / y: int32 (N, T2),
        sent1 / sent2: string (N,).
    '''
    output_shapes = (([maxlen1], ()),
                     ([maxlen2], [maxlen2], ()))
    output_types = ((tf.int32, tf.string),
                    (tf.int32, tf.int32, tf.string))

    # generator args are converted to numpy (byte-)string arrays by tf
    dataset = tf.data.Dataset.from_generator(
        _generator_fn,
        output_shapes=output_shapes,
        output_types=output_types,
        args=(sents1, sents2, vocab_fpath, maxlen1, maxlen2))

    if shuffle:  # for training
        dataset = dataset.shuffle(128 * batch_size * gpu_nums)

    # repeat forever, then batch across all GPUs at once
    return dataset.repeat().batch(batch_size * gpu_nums)
156 |
def get_batch(fpath, maxlen1, maxlen2, vocab_fpath, batch_size, gpu_nums, shuffle=False):
    '''Build training / evaluation mini-batches.

    fpath: source file path(s), '|'-separated. string.
    maxlen1: source sent maximum length. scalar.
    maxlen2: target sent maximum length. scalar.
    vocab_fpath: string. vocabulary file path.
    batch_size: per-GPU batch size. scalar.
    gpu_nums: number of GPUs.
    shuffle: boolean, True for training.

    Returns
    batches: a tf.data.Dataset of mini-batches
    num_batches: number of mini-batches per epoch
    num_samples: total number of samples loaded
    '''
    sents1, sents2 = _load_data(fpath, maxlen1, maxlen2)
    batches = _input_fn(sents1, sents2, vocab_fpath, batch_size, gpu_nums, maxlen1, maxlen2, shuffle=shuffle)
    num_batches = calc_num_batches(len(sents1), batch_size * gpu_nums)
    num_samples = len(sents1)
    return batches, num_batches, num_samples
175 |
--------------------------------------------------------------------------------
/fig/structure.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiongma/transformer-pointer-generator/6547cd40aa7aaa3802081b14b4d0490493cd794d/fig/structure.jpg
--------------------------------------------------------------------------------
/fig/transformer-loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiongma/transformer-pointer-generator/6547cd40aa7aaa3802081b14b4d0490493cd794d/fig/transformer-loss.png
--------------------------------------------------------------------------------
/fig/transformer-pointer gererator-loss.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiongma/transformer-pointer-generator/6547cd40aa7aaa3802081b14b4d0490493cd794d/fig/transformer-pointer gererator-loss.png
--------------------------------------------------------------------------------
/hparams.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python3
3 | '''
4 | date: 2019/5/21
5 | mail: cally.maxiong@gmail.com
6 | page: http://www.cnblogs.com/callyblog/
7 | '''
8 | import argparse
9 |
class Hparams:
    """Command-line hyper-parameters for training / evaluation.

    Usage: hp = Hparams().parser.parse_args()

    NOTE(review): the --train/--eval defaults point at the test split
    (data/test.csv) — override them for real training.
    """
    parser = argparse.ArgumentParser()

    # preprocessing
    parser.add_argument('--vocab_size', default=10598, type=int)

    # train
    ## files
    parser.add_argument('--train', default='data/test.csv',
                        help="data for train")

    parser.add_argument('--eval', default='data/test.csv',
                        help="data for evaluation")
    parser.add_argument('--eval_rouge', default='data/test_summary.csv',
                        help="data for calculate rouge score")

    ## vocabulary
    parser.add_argument('--vocab', default='vocab',
                        help="vocabulary file path")

    parser.add_argument('--stop_vocab', default='stop_vocab',
                        help="stop vocabulary file path")

    # training scheme
    parser.add_argument('--batch_size', default=32, type=int)
    parser.add_argument('--eval_batch_size', default=32, type=int)

    parser.add_argument('--lr', default=0.0005, type=float, help="learning rate")
    parser.add_argument('--warmup_steps', default=4000, type=int)
    parser.add_argument('--logdir', default="log/2", help="log directory")
    parser.add_argument('--num_epochs', default=5, type=int)
    parser.add_argument('--evaldir', default="eval/1", help="evaluation dir")

    # model
    parser.add_argument('--d_model', default=512, type=int,
                        help="hidden dimension of encoder/decoder")
    parser.add_argument('--d_ff', default=2048, type=int,
                        help="hidden dimension of feedforward layer")
    parser.add_argument('--num_blocks', default=6, type=int,
                        help="number of encoder/decoder blocks")
    parser.add_argument('--num_heads', default=8, type=int,
                        help="number of attention heads")
    parser.add_argument('--maxlen1', default=150, type=int,
                        help="maximum length of a source sequence")
    parser.add_argument('--maxlen2', default=25, type=int,
                        help="maximum length of a target sequence")
    parser.add_argument('--dropout_rate', default=0.1, type=float)
    parser.add_argument('--beam_size', default=4, type=int,
                        help="beam size")
    parser.add_argument('--gpu_nums', default=1, type=int,
                        help="gpu amount, which can allow how many gpu to train this model")
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python3
3 | '''
4 | date: 2019/5/21
5 | mail: cally.maxiong@gmail.com
6 | page: http://www.cnblogs.com/callyblog/
7 | '''
8 | import logging
9 |
10 | import tensorflow as tf
11 | from tqdm import tqdm
12 |
13 | from data_load import _load_vocab
14 | from modules import get_token_embeddings, ff, positional_encoding, multihead_attention, noam_scheme
15 | from utils import convert_idx_to_token_tensor, split_input
16 |
17 | logging.basicConfig(level=logging.INFO)
18 |
class Transformer:
    """Transformer encoder/decoder with a pointer-generator output head.

    The final distribution mixes the decoder's vocabulary distribution with
    the encoder-attention distribution (copy mechanism), weighted by a learned
    generation probability. `train()` supports data-parallel multi-GPU
    training with averaged gradients.

    NOTE(review): several special-token literals in this file are empty
    strings (e.g. self.token2idx[""]) — they look like stripped <s>/</s>/<pad>
    markers; verify against the vocabulary file.
    """
    def __init__(self, hp):
        # hp: parsed hyper-parameters, see hparams.py
        self.hp = hp
        self.token2idx, self.idx2token = _load_vocab(hp.vocab)
        # embedding matrix shared by encoder, decoder and output projection;
        # row 0 (padding id) is zeroed out
        self.embeddings = get_token_embeddings(self.hp.vocab_size, self.hp.d_model, zero_pad=True)

    def encode(self, xs, training=True):
        '''Run the encoder stack.

        :param xs: tuple (x, sents1) — int32 token ids (N, T1) and the raw
            source sentences (N,)
        :param training: enables dropout when True

        Returns
        memory: encoder outputs. (N, T1, d_model)
        sents1: raw source sentences, passed through unchanged
        '''
        with tf.variable_scope("encoder", reuse=tf.AUTO_REUSE):
            self.x, sents1 = xs

            # embedding
            enc = tf.nn.embedding_lookup(self.embeddings, self.x) # (N, T1, d_model)
            enc *= self.hp.d_model**0.5 # scale

            enc += positional_encoding(enc, self.hp.maxlen1)
            enc = tf.layers.dropout(enc, self.hp.dropout_rate, training=training)
            ## Blocks
            for i in range(self.hp.num_blocks):
                with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                    # self-attention
                    enc, _ = multihead_attention(queries=enc,
                                                 keys=enc,
                                                 values=enc,
                                                 num_heads=self.hp.num_heads,
                                                 dropout_rate=self.hp.dropout_rate,
                                                 training=training,
                                                 causality=False)
                    # feed forward
                    enc = ff(enc, num_units=[self.hp.d_ff, self.hp.d_model])
        self.enc_output = enc
        return self.enc_output, sents1

    def decode(self, xs, ys, memory, training=True):
        '''Run the decoder stack and the pointer-generator head.

        memory: encoder outputs. (N, T1, d_model)

        Returns
        logits: final mixed (copy + generate) distribution. (N, T2, V). float32.
        y: (N, T2). int32
        sents2: (N,). string.
        '''
        self.memory = memory
        with tf.variable_scope("decoder", reuse=tf.AUTO_REUSE):
            self.decoder_inputs, y, sents2 = ys
            x, _, = xs

            # embedding
            dec = tf.nn.embedding_lookup(self.embeddings, self.decoder_inputs) # (N, T2, d_model)
            dec *= self.hp.d_model ** 0.5 # scale

            dec += positional_encoding(dec, self.hp.maxlen2)

            # snapshot before dropout: fed into the generation-probability
            # layer below together with the final decoder state
            before_dec = dec

            dec = tf.layers.dropout(dec, self.hp.dropout_rate, training=training)

            # cross-attention distributions, one per block; the last one is
            # used as the copy distribution
            attn_dists = []
            # Blocks
            for i in range(self.hp.num_blocks):
                with tf.variable_scope("num_blocks_{}".format(i), reuse=tf.AUTO_REUSE):
                    # Masked self-attention (Note that causality is True at this time)
                    dec, _ = multihead_attention(queries=dec,
                                                 keys=dec,
                                                 values=dec,
                                                 num_heads=self.hp.num_heads,
                                                 dropout_rate=self.hp.dropout_rate,
                                                 training=training,
                                                 causality=True,
                                                 scope="self_attention")
                    # Vanilla attention
                    dec, attn_dist = multihead_attention(queries=dec,
                                                         keys=self.memory,
                                                         values=self.memory,
                                                         num_heads=self.hp.num_heads,
                                                         dropout_rate=self.hp.dropout_rate,
                                                         training=training,
                                                         causality=False,
                                                         scope="vanilla_attention")
                    attn_dists.append(attn_dist)
                    ### Feed Forward
                    dec = ff(dec, num_units=[self.hp.d_ff, self.hp.d_model])

        # Final linear projection (embedding weights are shared)
        weights = tf.transpose(self.embeddings) # (d_model, vocab_size)
        logits = tf.einsum('ntd,dk->ntk', dec, weights) # (N, T2, vocab_size)

        with tf.variable_scope("gen", reuse=tf.AUTO_REUSE):
            # p_gen in [0, 1]: one scalar per decoder position deciding how
            # much to generate from the vocab vs. copy from the article
            gens = tf.layers.dense(tf.concat([before_dec, dec, attn_dists[-1]], axis=-1), units=1, activation=tf.sigmoid,
                                   trainable=training, use_bias=False)

        # despite the name, `logits` holds probabilities from here on
        logits = tf.nn.softmax(logits)

        # final distribution
        self.logits = self._calc_final_dist(x, gens, logits, attn_dists[-1])

        return self.logits, y, sents2

    def _calc_final_dist(self, x, gens, vocab_dists, attn_dists):
        """Calculate the final distribution, for the pointer-generator model

        Args:
            x: encoder input which contain oov number
            gens: the generation probability p_gen, choose vocab from article or vocab
            vocab_dists: The vocabulary distributions. List length max_dec_steps of (batch_size, vsize) arrays.
                The words are in the order they appear in the vocabulary file.
            attn_dists: The attention distributions. List length max_dec_steps of (batch_size, attn_len) arrays

        Returns:
            final_dists: The final distributions. List length max_dec_steps of (batch_size, extended_vsize) arrays.
        """
        with tf.variable_scope('final_distribution', reuse=tf.AUTO_REUSE):
            # Multiply vocab dists by p_gen and attention dists by (1-p_gen)
            vocab_dists = gens * vocab_dists
            attn_dists = (1-gens) * attn_dists

            batch_size = tf.shape(attn_dists)[0]
            dec_t = tf.shape(attn_dists)[1]
            attn_len = tf.shape(attn_dists)[2]

            # build [decoder_step, source_token_id] index pairs so the
            # attention weights can be scattered into vocab-sized rows
            dec = tf.range(0, limit=dec_t) # [dec]
            dec = tf.expand_dims(dec, axis=-1) # [dec, 1]
            dec = tf.tile(dec, [1, attn_len]) # [dec, atten_len]
            dec = tf.expand_dims(dec, axis=0) # [1, dec, atten_len]
            dec = tf.tile(dec, [batch_size, 1, 1]) # [batch_size, dec, atten_len]

            x = tf.expand_dims(x, axis=1) # [batch_size, 1, atten_len]
            x = tf.tile(x, [1, dec_t, 1]) # [batch_size, dec, atten_len]
            x = tf.stack([dec, x], axis=3)

            # scatter each step's attention mass onto the vocabulary ids of
            # the source tokens; repeated source tokens accumulate
            attn_dists_projected = tf.map_fn(fn=lambda y: tf.scatter_nd(y[0], y[1], [dec_t, self.hp.vocab_size]),
                                             elems=(x, attn_dists), dtype=tf.float32)

            final_dists = attn_dists_projected + vocab_dists

            return final_dists

    def _calc_loss(self, targets, final_dists):
        """
        calculate loss
        :param targets: reference token ids. (N, T2) int32
        :param final_dists: transformer decoder output add by pointer generator
        :return: scalar loss, averaged over non-padding positions
        """
        with tf.name_scope('loss'):
            # gather the probability assigned to each gold token
            dec = tf.shape(targets)[1]
            batch_nums = tf.shape(targets)[0]
            dec = tf.range(0, limit=dec)
            dec = tf.expand_dims(dec, axis=0)
            dec = tf.tile(dec, [batch_nums, 1])
            indices = tf.stack([dec, targets], axis=2) # [batch_size, dec, 2]

            loss = tf.map_fn(fn=lambda x: tf.gather_nd(x[1], x[0]), elems=(indices, final_dists), dtype=tf.float32)
            # negative log-likelihood shifted by the constant log(0.9)
            # NOTE(review): the log(0.9) offset looks like a margin/smoothing
            # trick — confirm intent; it does not affect gradients
            loss = tf.log(0.9) - tf.log(loss)

            nonpadding = tf.to_float(tf.not_equal(targets, self.token2idx[""])) # 0:
            loss = tf.reduce_sum(loss * nonpadding) / (tf.reduce_sum(nonpadding) + 1e-7)

            return loss

    def train(self, xs, ys):
        """
        train model (data-parallel across gpu_nums GPUs)
        :param xs: dataset xs
        :param ys: dataset ys
        :return: loss
            train op
            global step
            tensorflow summary
        """
        tower_grads = []
        global_step = tf.train.get_or_create_global_step()
        # scale the reported step so it counts samples across all GPUs
        global_step_ = global_step * self.hp.gpu_nums
        lr = noam_scheme(self.hp.d_model, global_step_, self.hp.warmup_steps)
        optimizer = tf.train.AdamOptimizer(lr)
        losses = []
        # split one big batch into per-GPU shards
        xs, ys = split_input(xs, ys, self.hp.gpu_nums)
        with tf.variable_scope(tf.get_variable_scope()):
            for no in range(self.hp.gpu_nums):
                with tf.device("/gpu:%d" % no):
                    with tf.name_scope("tower_%d" % no):
                        memory, sents1 = self.encode(xs[no])
                        logits, y, sents2 = self.decode(xs[no], ys[no], memory)
                        # share variables across towers
                        tf.get_variable_scope().reuse_variables()

                        loss = self._calc_loss(y, logits)
                        losses.append(loss)
                        grads = optimizer.compute_gradients(loss)
                        tower_grads.append(grads)

        with tf.device("/cpu:0"):
            # average the per-tower gradients and apply once
            grads = self.average_gradients(tower_grads)
            train_op = optimizer.apply_gradients(grads, global_step=global_step)
            loss = sum(losses) / len(losses)
            tf.summary.scalar('lr', lr)
            tf.summary.scalar("train_loss", loss)
            summaries = tf.summary.merge_all()

        return loss, train_op, global_step_, summaries

    def average_gradients(self, tower_grads):
        """
        average gradients of all gpu gradients
        :param tower_grads: list, each element is the gradient list of one gpu
        :return: averaged (gradient, variable) pairs
        """
        average_grads = []
        # grad_and_vars: one (grad, var) pair per tower, all for the same var
        for grad_and_vars in zip(*tower_grads):
            grads = []
            for g, _ in grad_and_vars:
                expend_g = tf.expand_dims(g, 0)
                grads.append(expend_g)
            grad = tf.concat(grads, 0)
            grad = tf.reduce_mean(grad, 0)
            # variables are shared, so any tower's var handle works
            v = grad_and_vars[0][1]
            grad_and_var = (grad, v)
            average_grads.append(grad_and_var)

        return average_grads

    def eval(self, xs, ys):
        '''Predicts autoregressively
        At inference, input ys is ignored.
        Returns
        y_hat: (N, T2)
        tensorflow summary
        '''
        # decoder_inputs sentences
        decoder_inputs, y, sents2 = ys

        # start decoding from the start token only; shape (batch_size, 1)
        decoder_inputs = tf.ones((tf.shape(xs[0])[0], 1), tf.int32) * self.token2idx[""]
        ys = (decoder_inputs, y, sents2)

        memory, sents1 = self.encode(xs, False)

        y_hat = None
        logging.info("Inference graph is being built. Please be patient.")
        # greedy decoding: one decoder pass per output position
        for _ in tqdm(range(self.hp.maxlen2)):
            logits, y, sents2 = self.decode(xs, ys, memory, False)
            y_hat = tf.to_int32(tf.argmax(logits, axis=-1))

            # NOTE(review): this compares a graph Tensor to a Python int at
            # graph-construction time, which is always False — the loop never
            # breaks early and always builds maxlen2 steps; confirm intent
            if tf.reduce_sum(y_hat, 1) == self.token2idx[""]: break

            # feed everything decoded so far back in as decoder input
            _decoder_inputs = tf.concat((decoder_inputs, y_hat), 1)
            ys = (_decoder_inputs, y, sents2)

        # monitor a random sample
        n = tf.random_uniform((), 0, tf.shape(y_hat)[0]-1, tf.int32)
        sent1 = sents1[n]
        pred = convert_idx_to_token_tensor(y_hat[n], self.idx2token)
        sent2 = sents2[n]

        tf.summary.text("sent1", sent1)
        tf.summary.text("pred", pred)
        tf.summary.text("sent2", sent2)
        summaries = tf.summary.merge_all()

        return y_hat, summaries
--------------------------------------------------------------------------------
/modules.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python3
3 | '''
4 | date: 2019/5/21
5 | mail: cally.maxiong@gmail.com
6 | page: http://www.cnblogs.com/callyblog/
7 | '''
8 |
9 | import numpy as np
10 | import tensorflow as tf
11 |
def ln(inputs, epsilon=1e-8, scope="ln"):
    '''Layer normalization. See https://arxiv.org/abs/1607.06450.

    inputs: tensor with 2 or more dimensions; the first dimension is `batch_size`.
    epsilon: small constant guarding against division by zero.
    scope: optional scope name for `variable_scope`.

    Returns:
    A tensor with the same shape and dtype as `inputs`.
    '''
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # one learned (gamma, beta) pair per feature, i.e. over the last axis
        params_shape = inputs.get_shape()[-1:]

        mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
        beta = tf.get_variable("beta", params_shape, initializer=tf.zeros_initializer())
        gamma = tf.get_variable("gamma", params_shape, initializer=tf.ones_initializer())

        normalized = (inputs - mean) / ((variance + epsilon) ** .5)
        return gamma * normalized + beta
32 |
def get_token_embeddings(vocab_size, num_units, zero_pad=True):
    '''Build the token embedding matrix.

    vocab_size: scalar. V.
    num_units: embedding dimensionality. E.
    zero_pad: Boolean. If True, row 0 (the padding id) is held at constant
        zero so query/key masks can be applied easily.

    Returns
    weight variable: (V, E)
    '''
    with tf.variable_scope("shared_weight_matrix", reuse=tf.AUTO_REUSE):
        embeddings = tf.get_variable('weight_mat',
                                     dtype=tf.float32,
                                     shape=(vocab_size, num_units),
                                     initializer=tf.contrib.layers.xavier_initializer())
        if zero_pad:
            # overwrite row 0 with zeros; the remaining rows stay trainable
            zero_row = tf.zeros(shape=[1, num_units])
            embeddings = tf.concat((zero_row, embeddings[1:, :]), 0)
    return embeddings
53 |
def scaled_dot_product_attention(Q, K, V,
                                 num_heads,
                                 causality=False, dropout_rate=0.,
                                 training=True,
                                 scope="scaled_dot_product_attention"):
    '''See 3.2.1.
    Q: Packed queries. 3d tensor. [N, T_q, d_k].
    K: Packed keys. 3d tensor. [N, T_k, d_k].
    V: Packed values. 3d tensor. [N, T_k, d_v].
    num_heads: number of heads; used to merge the per-head attention weights
        back into one distribution (assumes heads are stacked along the batch
        axis by the caller — TODO confirm against multihead_attention).
    causality: If True, applies masking for future blinding
    dropout_rate: A floating point number of [0, 1].
    training: boolean for controlling droput
    scope: Optional scope for `variable_scope`.

    Returns
    outputs: context vectors, [N, T_q, d_v]
    attn_dists: head-merged attention distribution over the keys
    '''
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        d_k = Q.get_shape().as_list()[-1]

        # dot product
        outputs = tf.matmul(Q, tf.transpose(K, [0, 2, 1]))  # (N, T_q, T_k)

        # scale by sqrt(d_k) to keep the dot products in a stable range
        outputs /= d_k ** 0.5

        # key masking, delete key 0 (padding positions get a large negative score)
        outputs = mask(outputs, Q, K, type="key")

        # causality or future blinding masking
        if causality:
            outputs = mask(outputs, type="future")

        # softmax; attn_dists sums the head-splits before the softmax so the
        # pointer mechanism sees a single merged distribution
        attn_dists = tf.nn.softmax(tf.reduce_sum(tf.split(outputs, num_heads, axis=0), axis=0))
        outputs = tf.nn.softmax(outputs)
        attention = tf.transpose(outputs, [0, 2, 1])
        tf.summary.image("attention", tf.expand_dims(attention[:1], -1))

        # query masking, delete query (zero out rows of padding queries)
        outputs = mask(outputs, Q, K, type="query")

        # dropout
        outputs = tf.layers.dropout(outputs, rate=dropout_rate, training=training)

        # weighted sum (context vectors)
        outputs = tf.matmul(outputs, V)  # (N, T_q, d_v)

        return outputs, attn_dists
100 |
def mask(inputs, queries=None, keys=None, type=None):
    """Masks paddings on keys or queries to inputs.

    inputs: 3d tensor. (N, T_q, T_k)
    queries: 3d tensor. (N, T_q, d)
    keys: 3d tensor. (N, T_k, d)
    type: which mask to apply — "key"/"k"/"keys", "query"/"q"/"queries",
        or "future"/"f"/"right".

    Raises:
        ValueError: if `type` is not one of the recognized mask kinds.
        (Previously this case only printed a warning and then crashed with
        an UnboundLocalError on `outputs`.)

    e.g.,
    >> queries = tf.constant([[[1.],
                        [2.],
                        [0.]]], tf.float32) # (1, 3, 1)
    >> keys = tf.constant([[[4.],
                     [0.]]], tf.float32)  # (1, 2, 1)
    >> inputs = tf.constant([[[4., 0.],
                               [8., 0.],
                               [0., 0.]]], tf.float32)
    >> mask(inputs, queries, keys, "key")
    array([[[ 4.0000000e+00, -4.2949673e+09],
        [ 8.0000000e+00, -4.2949673e+09],
        [ 0.0000000e+00, -4.2949673e+09]]], dtype=float32)
    >> inputs = tf.constant([[[1., 0.],
                             [1., 0.],
                              [1., 0.]]], tf.float32)
    >> mask(inputs, queries, keys, "query")
    array([[[1., 0.],
        [1., 0.],
        [0., 0.]]], dtype=float32)
    """
    # large negative value so masked positions vanish after softmax
    padding_num = -2 ** 32 + 1
    if type in ("k", "key", "keys"):
        # Generate masks: a key position is padding iff its vector is all-zero
        masks = tf.sign(tf.reduce_sum(tf.abs(keys), axis=-1))  # (N, T_k)
        masks = tf.expand_dims(masks, 1)  # (N, 1, T_k)
        masks = tf.tile(masks, [1, tf.shape(queries)[1], 1])  # (N, T_q, T_k)

        # Apply masks to inputs
        paddings = tf.ones_like(inputs) * padding_num
        outputs = tf.where(tf.equal(masks, 0), paddings, inputs)  # (N, T_q, T_k)
    elif type in ("q", "query", "queries"):
        # Generate masks: a query position is padding iff its vector is all-zero
        masks = tf.sign(tf.reduce_sum(tf.abs(queries), axis=-1))  # (N, T_q)
        masks = tf.expand_dims(masks, -1)  # (N, T_q, 1)
        masks = tf.tile(masks, [1, 1, tf.shape(keys)[1]])  # (N, T_q, T_k)

        # Apply masks to inputs (zero out rows of padding queries)
        outputs = inputs*masks
    elif type in ("f", "future", "right"):
        # lower-triangular mask blocks attention to future positions
        diag_vals = tf.ones_like(inputs[0, :, :])  # (T_q, T_k)
        tril = tf.linalg.LinearOperatorLowerTriangular(diag_vals).to_dense()  # (T_q, T_k)
        masks = tf.tile(tf.expand_dims(tril, 0), [tf.shape(inputs)[0], 1, 1])  # (N, T_q, T_k)

        paddings = tf.ones_like(masks) * padding_num
        outputs = tf.where(tf.equal(masks, 0), paddings, inputs)
    else:
        # fail loudly instead of printing and falling through to an
        # UnboundLocalError on `outputs`
        raise ValueError("Unknown mask type: %r (expected key/query/future)" % (type,))

    return outputs
159 |
def multihead_attention(queries, keys, values,
                        num_heads=8,
                        dropout_rate=0,
                        training=True,
                        causality=False,
                        scope="multihead_attention"):
    '''Applies multihead attention. See 3.2.2 of "Attention Is All You Need".

    queries: A 3d tensor with shape of [N, T_q, d_model].
    keys: A 3d tensor with shape of [N, T_k, d_model].
    values: A 3d tensor with shape of [N, T_k, d_model].
    num_heads: An int. Number of attention heads.
    dropout_rate: A floating point number.
    training: Boolean. Controller of mechanism for dropout.
    causality: Boolean. If true, units that reference the future are masked.
    scope: Optional scope for `variable_scope`.

    Returns
      A 3d tensor with shape of (N, T_q, d_model), plus the attention
      distributions produced by scaled_dot_product_attention.
    '''
    d_model = queries.get_shape().as_list()[-1]
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Linear projections (no bias).
        q_proj = tf.layers.dense(queries, d_model, use_bias=False)  # (N, T_q, d_model)
        k_proj = tf.layers.dense(keys, d_model, use_bias=False)     # (N, T_k, d_model)
        v_proj = tf.layers.dense(values, d_model, use_bias=False)   # (N, T_k, d_model)

        # Split the channel dimension into heads; stack heads on the batch axis.
        q_heads = tf.concat(tf.split(q_proj, num_heads, axis=2), axis=0)  # (h*N, T_q, d_model/h)
        k_heads = tf.concat(tf.split(k_proj, num_heads, axis=2), axis=0)  # (h*N, T_k, d_model/h)
        v_heads = tf.concat(tf.split(v_proj, num_heads, axis=2), axis=0)  # (h*N, T_k, d_model/h)

        # Scaled dot-product attention over all heads at once.
        context, attn_dists = scaled_dot_product_attention(
            q_heads, k_heads, v_heads, num_heads, causality, dropout_rate, training)

        # Undo the head stacking, then residual connection + layer norm.
        context = tf.concat(tf.split(context, num_heads, axis=0), axis=2)  # (N, T_q, d_model)
        outputs = ln(queries + context)

    return outputs, attn_dists
203 |
def ff(inputs, num_units, scope="positionwise_feedforward"):
    '''Position-wise feed-forward network. See 3.3 of "Attention Is All You Need".

    inputs: A 3d tensor with shape of [N, T, C].
    num_units: A list of two integers: inner (hidden) and outer widths.
    scope: Optional scope for `variable_scope`.

    Returns:
      A 3d tensor with the same shape and dtype as inputs.
    '''
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # ReLU expansion followed by a linear projection back to C channels.
        hidden = tf.layers.dense(inputs, num_units[0], activation=tf.nn.relu)
        projected = tf.layers.dense(hidden, num_units[1])

        # Residual connection, then layer normalization.
        outputs = ln(projected + inputs)

    return outputs
228 |
def label_smoothing(inputs, epsilon=0.1):
    '''Applies label smoothing. See 5.4 and https://arxiv.org/abs/1512.00567.

    inputs: 3d tensor. [N, T, V], where V is the number of vocabulary;
      typically one-hot target distributions.
    epsilon: Smoothing rate.

    Each distribution is shrunk toward the uniform distribution: a one-hot
    1 becomes 1 - epsilon + epsilon/V and every 0 becomes epsilon/V.
    E.g. with V=3 and epsilon=0.1, [0, 0, 1] maps to
    [0.0333..., 0.0333..., 0.9333...].
    '''
    num_classes = tf.cast(tf.shape(inputs)[-1], tf.float32)  # V as float
    return inputs * (1 - epsilon) + epsilon / num_classes
263 |
def positional_encoding(inputs,
                        maxlen,
                        masking=True,
                        scope="positional_encoding"):
    '''Sinusoidal positional encoding. See 3.5 of "Attention Is All You Need".

    inputs: 3d tensor. (N, T, E)
    maxlen: scalar. Must be >= T
    masking: Boolean. If True, positions where inputs is exactly zero keep
      their original (zero) values instead of receiving an encoding.
    scope: Optional scope for `variable_scope`.

    returns
      3d float32 tensor with the same shape as inputs.
    '''

    E = inputs.get_shape().as_list()[-1]  # static channel size
    N, T = tf.shape(inputs)[0], tf.shape(inputs)[1]  # dynamic batch/time
    with tf.variable_scope(scope, reuse=tf.AUTO_REUSE):
        # Every row of the batch gets the index sequence 0..T-1.
        pos_indices = tf.tile(tf.expand_dims(tf.range(T), 0), [N, 1])  # (N, T)

        # Angle table: pos / 10000^(2i/E); paired columns share an exponent.
        angle_table = np.array([
            [pos / np.power(10000, (i-i%2)/E) for i in range(E)]
            for pos in range(maxlen)])

        # Even channels -> sin, odd channels -> cos.
        angle_table[:, 0::2] = np.sin(angle_table[:, 0::2])  # dim 2i
        angle_table[:, 1::2] = np.cos(angle_table[:, 1::2])  # dim 2i+1
        lookup_table = tf.convert_to_tensor(angle_table, tf.float32)  # (maxlen, E)

        # Gather the encoding for each position index.
        outputs = tf.nn.embedding_lookup(lookup_table, pos_indices)

        if masking:
            # Keep original values wherever the input element equals zero.
            outputs = tf.where(tf.equal(inputs, 0), inputs, outputs)

    return tf.to_float(outputs)
302 |
def noam_scheme(d_model, global_step, warmup_steps=4000.):
    '''Noam learning-rate schedule from "Attention Is All You Need" (5.3).

    The rate grows linearly during the first `warmup_steps` steps, then
    decays with the inverse square root of the step number; the whole
    curve is scaled by d_model ** -0.5.

    d_model: scalar. Encoder/decoder embedding dimensionality.
    global_step: scalar tensor. Current training step.
    warmup_steps: scalar. Length of the linear warmup phase.
    '''
    step = tf.cast(global_step + 1, dtype=tf.float32)
    warmup_term = step * warmup_steps ** -1.5
    decay_term = step ** -0.5
    return d_model ** -0.5 * tf.minimum(warmup_term, decay_term)
--------------------------------------------------------------------------------
/pred.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python3
3 | '''
4 | date: 2019/5/21
5 | mail: cally.maxiong@gmail.com
6 | page: http://www.cnblogs.com/callyblog/
7 | '''
8 | import os
9 |
10 | from beam_search import BeamSearch
11 | from data_load import _load_vocab
12 | from hparams import Hparams
13 | from model import Transformer
14 |
def import_tf(device_id=-1, verbose=False):
    """
    Import tensorflow after configuring device visibility and log level.

    :param device_id: GPU id; any negative value selects CPU-only
    :param verbose: if True, enable DEBUG-level tensorflow logging
    :return: the imported tensorflow module
    """
    # Device visibility must be configured before tensorflow is imported.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(device_id) if device_id >= 0 else '-1'
    os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' if verbose else '3'
    import tensorflow as tf
    tf.logging.set_verbosity(tf.logging.DEBUG if verbose else tf.logging.ERROR)
    return tf
28 |
class Prediction:
    """Wraps a trained Transformer checkpoint for inference: restores the
    session once, then turns raw article strings into summaries via beam
    search."""

    def __init__(self, args):
        """
        :param args: parsed hparams namespace; must provide logdir (model
            dir path), vocab (vocab file path), maxlen1, beam_size, maxlen2
        """
        # NOTE(review): device_id=0 requests GPU 0 here even though
        # import_tf defaults to CPU — confirm this is intended.
        self.tf = import_tf(0)

        self.args = args
        self.model_dir = args.logdir
        self.vocab_file = args.vocab
        self.token2idx, self.idx2token = _load_vocab(args.vocab)

        hparams = Hparams()
        parser = hparams.parser
        self.hp = parser.parse_args()

        self.model = Transformer(self.hp)

        # Placeholders must exist before the graph below is built on them.
        self._add_placeholder()
        self._init_graph()

    def _init_graph(self):
        """
        Build the encoder/decoder graph, restore the latest checkpoint into
        a fresh session, and prepare the beam-search helper.
        """
        # Only the leading tensors of the (x, ...) / (y, ...) tuples are
        # needed at inference time; the remaining slots are passed as None.
        self.ys = (self.input_y, None, None)
        self.xs = (self.input_x, None)
        self.memory = self.model.encode(self.xs, False)[0]
        self.logits = self.model.decode(self.xs, self.ys, self.memory, False)[0]

        # Most recent checkpoint path recorded in the checkpoint state file.
        ckpt = self.tf.train.get_checkpoint_state(self.model_dir).all_model_checkpoint_paths[-1]

        graph = self.logits.graph
        sess_config = self.tf.ConfigProto(allow_soft_placement=True)
        sess_config.gpu_options.allow_growth = True

        saver = self.tf.train.Saver()
        self.sess = self.tf.Session(config=sess_config, graph=graph)

        # Initialize all variables, then overwrite them from the checkpoint.
        self.sess.run(self.tf.global_variables_initializer())
        self.tf.reset_default_graph()
        saver.restore(self.sess, ckpt)

        # NOTE(review): vocab indices 2 and 3 are presumably the start/end
        # tokens — confirm against the vocab file ordering.
        self.bs = BeamSearch(self.model,
                             self.hp.beam_size,
                             list(self.idx2token.keys())[2],
                             list(self.idx2token.keys())[3],
                             self.idx2token,
                             self.hp.maxlen2,
                             self.input_x,
                             self.input_y,
                             self.logits)

    def predict(self, content):
        """
        abstract prediction by beam search
        :param content: article content (str); split into characters
        :return: prediction result (beam-search summaries)
        """
        input_x = list(content)
        # Pad with '' up to maxlen1, then truncate to exactly maxlen1.
        # NOTE(review): '' looks like a placeholder for a lost special
        # token (e.g. <pad>/<unk>) — confirm against the vocab file.
        while len(input_x) < self.args.maxlen1: input_x.append('')
        input_x = input_x[:self.args.maxlen1]

        input_x = [self.token2idx.get(s, self.token2idx['']) for s in input_x]

        # Encode once; beam search then drives the decoder step by step.
        memory = self.sess.run(self.memory, feed_dict={self.input_x: [input_x]})

        return self.bs.search(self.sess, input_x, memory[0])

    def _add_placeholder(self):
        """
        add tensorflow placeholders: input_x is the fixed-length (maxlen1)
        article; input_y is the variable-length decoder-side input.
        """
        self.input_x = self.tf.placeholder(dtype=self.tf.int32, shape=[None, self.args.maxlen1], name='input_x')
        self.input_y = self.tf.placeholder(dtype=self.tf.int32, shape=[None, None], name='input_y')
105 |
if __name__ == '__main__':
    # Demo entry point: build hparams from the command line, restore the
    # model, and summarize one sample article.
    hparams = Hparams()
    parser = hparams.parser
    hp = parser.parse_args()
    preds = Prediction(hp)
    # Sample article (Chinese); predict() returns candidate summaries.
    content = '2014年,51信用卡管家跟宜信等P2P公司合作,推出线上信贷产品“瞬时贷”,其是一种纯在线操作的信贷模式。51信用卡管家创始人孙海涛说,51目前每天放贷1000万,预计2015年,自营产品加上瞬>时贷,放贷额度将远超'
    result = preds.predict(content)
    for res in result:
        print(res)
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tensorflow==1.15.4
2 | tqdm>=4.28.1
jieba>=0.39
4 | sumeval>=0.2.0
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python3
3 | '''
4 | date: 2019/5/21
5 | mail: cally.maxiong@gmail.com
6 | page: http://www.cnblogs.com/callyblog/
7 | '''
8 |
9 | import logging
10 | import os
11 |
12 | from sumeval.metrics.rouge import RougeCalculator
13 | from tqdm import tqdm
14 |
15 | from beam_search import BeamSearch
16 | from data_load import get_batch, _load_vocab
17 | from hparams import Hparams
18 | from model import Transformer
19 | from utils import save_hparams, save_variable_specs, get_hypotheses, calc_rouge, import_tf
20 |
logging.basicConfig(level=logging.INFO)

# ROUGE calculator for Chinese evaluation (stopwords filtered).
rouge = RougeCalculator(stopwords=True, lang="zh")

logging.info("# hparams")
hparams = Hparams()
parser = hparams.parser
hp = parser.parse_args()

# import tensorflow
# Expose GPUs [0, gpu_nums) via CUDA_VISIBLE_DEVICES.
gpu_list = [str(i) for i in list(range(hp.gpu_nums))]
tf = import_tf(gpu_list)

save_hparams(hp, hp.logdir)

logging.info("# Prepare train/eval batches")
train_batches, num_train_batches, num_train_samples = get_batch(hp.train,
                                                                hp.maxlen1,
                                                                hp.maxlen2,
                                                                hp.vocab,
                                                                hp.batch_size,
                                                                hp.gpu_nums,
                                                                shuffle=True)

eval_batches, num_eval_batches, num_eval_samples = get_batch(hp.eval,
                                                             hp.maxlen1,
                                                             hp.maxlen2,
                                                             hp.vocab,
                                                             hp.eval_batch_size,
                                                             hp.gpu_nums,
                                                             shuffle=False)

# Feedable iterator: a string handle switches between train/eval datasets.
handle = tf.placeholder(tf.string, shape=[])
# NOTE(review): `iter` shadows the Python builtin of the same name.
iter = tf.data.Iterator.from_string_handle(
    handle, train_batches.output_types, train_batches.output_shapes)

# create a iter of the correct shape and type
xs, ys = iter.get_next()

logging.info('# init data')
training_iter = train_batches.make_one_shot_iterator()
val_iter = eval_batches.make_initializable_iterator()

logging.info("# Load model")
m = Transformer(hp)

# get op
loss, train_op, global_step, train_summaries = m.train(xs, ys)
y_hat, eval_summaries = m.eval(xs, ys)

token2idx, idx2token = _load_vocab(hp.vocab)

# NOTE(review): vocab indices 2 and 3 are presumably the start/end tokens
# handed to BeamSearch — confirm against the vocab file ordering.
bs = BeamSearch(m, hp.beam_size, list(idx2token.keys())[2], list(idx2token.keys())[3], idx2token, hp.maxlen2, m.x,
                m.decoder_inputs, m.logits)

logging.info("# Session")
saver = tf.train.Saver(max_to_keep=hp.num_epochs)
with tf.Session(config=tf.ConfigProto(allow_soft_placement=True)) as sess:
    ckpt = tf.train.latest_checkpoint(hp.logdir)
    if ckpt is None:
        logging.info("Initializing from scratch")
        sess.run(tf.global_variables_initializer())
        save_variable_specs(os.path.join(hp.logdir, "specs"))
    else:
        saver.restore(sess, ckpt)

    summary_writer = tf.summary.FileWriter(hp.logdir, sess.graph)

    # Iterator.string_handle() get a tensor that can be got value to feed handle placeholder
    training_handle = sess.run(training_iter.string_handle())
    val_handle = sess.run(val_iter.string_handle())

    total_steps = hp.num_epochs * num_train_batches
    _gs = sess.run(global_step)
    for i in tqdm(range(_gs, total_steps+1)):
        _, _gs, _summary = sess.run([train_op, global_step, train_summaries], feed_dict={handle: training_handle})
        summary_writer.add_summary(_summary, _gs)
        # Every gpu_nums*5000 steps: evaluate, score with ROUGE, checkpoint.
        if _gs % (hp.gpu_nums * 5000) == 0 and _gs != 0:
            logging.info("steps {} is done".format(_gs))

            logging.info("# test evaluation")
            sess.run(val_iter.initializer) # initial val dataset
            _eval_summaries = sess.run(eval_summaries, feed_dict={handle: val_handle})
            summary_writer.add_summary(_eval_summaries, _gs)

            logging.info("# beam search")
            hypotheses, all_targets = get_hypotheses(num_eval_batches, num_eval_samples, sess, m, bs, [xs[0], ys[2]],
                                                     handle, val_handle)

            logging.info("# calc rouge score ")
            if not os.path.exists(hp.evaldir): os.makedirs(hp.evaldir)
            # NOTE(review): calc_rouge returns the mean ROUGE-1 score,
            # despite this variable being named rouge_l.
            rouge_l = calc_rouge(rouge, all_targets, hypotheses, _gs, hp.evaldir)

            # Checkpoint name encodes the step count and the ROUGE score.
            model_output = "trans_pointer%02dL%.2f" % (_gs, rouge_l)

            logging.info('# write hypotheses')
            with open(os.path.join(hp.evaldir, model_output), 'w', encoding='utf-8') as f:
                for target, hypothes in zip(all_targets, hypotheses):
                    f.write('{}-{} \n'.format(target, ' '.join(hypothes)))

            logging.info("# save models")

            ckpt_name = os.path.join(hp.logdir, model_output)
            saver.save(sess, ckpt_name, global_step=_gs)
            logging.info("after training of {} steps, {} has been saved.".format(_gs, ckpt_name))

            logging.info("# fall back to train mode")
    summary_writer.close()

logging.info("Done")
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #/usr/bin/python3
3 | '''
4 | date: 2019/5/21
5 | mail: cally.maxiong@gmail.com
6 | page: http://www.cnblogs.com/callyblog/
7 | '''
8 | import json
9 | import logging
10 | import os
11 | from tqdm import tqdm
12 |
13 | logging.basicConfig(level=logging.INFO)
14 |
def calc_num_batches(total_num, batch_size):
    '''Calculates the number of batches.

    total_num: total sample number
    batch_size: samples per batch

    Returns
      number of batches, counting a final partial batch if any.'''
    full_batches, leftover = divmod(total_num, batch_size)
    return full_batches + (1 if leftover else 0)
23 |
def convert_idx_to_token_tensor(inputs, idx2token):
    '''Converts an int32 index tensor to a string tensor.

    inputs: 1d int32 tensor. token indices.
    idx2token: dictionary mapping index -> token string.

    Returns
      string tensor holding the space-joined tokens.
    '''
    import tensorflow as tf

    def _join_tokens(indices):
        # Executed eagerly inside tf.py_func: plain Python dict lookups.
        return " ".join(idx2token[idx] for idx in indices)

    return tf.py_func(_join_tokens, [inputs], tf.string)
37 |
def postprocess(hypotheses):
    '''Processes translation outputs.

    hypotheses: list of predictions (each element is str()-converted)

    Returns
      list of processed hypothesis strings

    NOTE(review): the three replace targets below are empty strings, which
    makes them no-ops; they look like special tokens (e.g. <s>, </s>,
    <pad>) that were lost from the source — confirm against the vocab.
    '''
    return [str(h).replace('', '').replace('', '').replace('', '')
            for h in hypotheses]
55 |
def save_hparams(hparams, path):
    '''Saves hparams to path.

    hparams: argparse namespace object.
    path: output directory (created if it does not exist).

    Writes
      the hparams as a JSON dictionary to a file named "hparams" in path.
    '''
    if not os.path.exists(path):
        os.makedirs(path)
    serialized = json.dumps(vars(hparams))
    with open(os.path.join(path, "hparams"), 'w') as fout:
        fout.write(serialized)
68 |
def load_hparams(parser, path):
    '''Loads saved hparams and overrides the parser's defaults.

    parser: argparse parser whose flag defaults should be overridden
    path: directory (or a file inside the directory) where hparams were
      written by save_hparams

    Bug fix: the previous implementation did `parser.f = v`, which merely
    set an attribute literally named "f" on the parser object and never
    overrode any flag; set_defaults applies each saved value to its flag.
    '''
    if not os.path.isdir(path):
        path = os.path.dirname(path)
    # Close the file promptly instead of leaking the handle.
    with open(os.path.join(path, "hparams"), 'r') as fin:
        flag2val = json.loads(fin.read())
    parser.set_defaults(**flag2val)
80 |
def save_variable_specs(fpath):
    '''Saves information about global variables: name, shape, and the
    total parameter count.

    fpath: string. output file path

    Writes
      a text file named fpath: the parameter count on the first line,
      followed by one "name===shape" line per variable.
    '''
    import tensorflow as tf

    def _get_size(shp):
        '''Returns the number of elements implied by a TensorShape.'''
        size = 1
        for d in range(len(shp)):
            size *= shp[d]
        return size

    params, num_params = [], 0
    for v in tf.global_variables():
        params.append("{}==={}".format(v.name, v.shape))
        num_params += _get_size(v.shape)
    # Use the module logger instead of print, for consistency with the
    # rest of this file.
    logging.info("num_params: %s", num_params)
    with open(fpath, 'w') as fout:
        fout.write("num_params: {}\n".format(num_params))
        fout.write("\n".join(params))
    logging.info("Variables info has been saved.")
111 |
def get_hypotheses(num_batches, num_samples, sess, model, beam_search, tensor, handle_placehoder, handle):
    '''Gets hypotheses.

    num_batches: scalar. number of evaluation batches to consume.
    num_samples: scalar. number of samples to keep from the results.
    sess: tensorflow session object
    model: model instance; model.x and model.enc_output are used to encode.
    beam_search: BeamSearch instance used to decode each article.
    tensor: pair of tensors (articles, targets) fetched once per batch.
    handle_placehoder: dataset-handle placeholder (sic — name kept as-is).
    handle: evaluated string handle selecting the eval dataset.

    Returns
      hypotheses: list of post-processed beam-search summaries, and the
      decoded target strings, both truncated to num_samples.
    '''
    hypotheses, all_targets = [], []
    for _ in tqdm(range(num_batches)):
        # One batch of encoded articles and reference targets.
        articles, targets = sess.run(tensor, feed_dict={handle_placehoder: handle})
        # Encode the whole batch once; beam search then runs per article.
        memories = sess.run(model.enc_output, feed_dict={model.x: articles})
        for article, memory in zip(articles, memories):
            summary = beam_search.search(sess, article, memory)
            summary = postprocess(summary)
            hypotheses.append(summary)
        # Targets arrive as bytes; decode to text for ROUGE scoring.
        all_targets.extend([target.decode('utf-8') for target in targets])

    return hypotheses[:num_samples], all_targets[:num_samples]
134 |
def calc_rouge(rouge, references, models, global_step, logdir):
    """
    Calculate mean ROUGE scores and append them to <logdir>/rouge.

    :param rouge: sumeval RougeCalculator instance
    :param references: reference sentences (list of str)
    :param models: model predictions; one list of candidate summaries per sample
    :param global_step: global step (used only in the report line)
    :param logdir: directory that receives the "rouge" report file
    :return: mean ROUGE-1 score
    """
    # delete symbol
    # NOTE(review): the replace target is an empty string (a no-op); it
    # looks like a special token (e.g. </s>) lost from the source — confirm.
    references = [reference.replace('', '') for reference in references]

    # best score per sample, for each metric
    rouge1_scores = [_rouge(rouge, model, reference, type='rouge1') for model, reference in zip(models, references)]
    rouge2_scores = [_rouge(rouge, model, reference, type='rouge2') for model, reference in zip(models, references)]
    rougel_scores = [_rouge(rouge, model, reference, type='rougel') for model, reference in zip(models, references)]

    # corpus means. Bug fix: ROUGE-L was previously divided by
    # len(rouge2_scores); harmless only because the lists always have the
    # same length, but corrected for clarity and safety.
    rouge1_score = sum(rouge1_scores) / len(rouge1_scores)
    rouge2_score = sum(rouge2_scores) / len(rouge2_scores)
    rougel_score = sum(rougel_scores) / len(rougel_scores)

    # append one report line per evaluation
    with open(os.path.join(logdir, 'rouge'), 'a', encoding='utf-8') as f:
        f.write('global step: {}, ROUGE 1: {}, ROUGE 2: {}, ROUGE L: {}\n'.format(str(global_step), str(rouge1_score),
                                                                                  str(rouge2_score), str(rougel_score)))
    return rouge1_score

def _rouge(rouge, model, reference, type='rouge1'):
    """
    Calculate the best rouge score among one sample's candidate summaries.

    :param rouge: sumeval RougeCalculator instance
    :param model: candidate summaries for one sample, list of str
    :param reference: reference summary, str
    :param type: one of 'rouge1', 'rouge2', 'rougel'
    :return: max score over the candidates
    :raises ValueError: if type is unknown (previously this crashed with a
      TypeError from max(None))
    """
    if type == 'rouge1':
        scores = [rouge.rouge_n(summary=m, references=reference, n=1) for m in model]
    elif type == 'rouge2':
        scores = [rouge.rouge_n(summary=m, references=reference, n=2) for m in model]
    elif type == 'rougel':
        scores = [rouge.rouge_l(summary=m, references=reference) for m in model]
    else:
        raise ValueError("Unknown rouge type: {!r}".format(type))

    return max(scores)
183 |
def import_tf(gpu_list):
    """
    Set device visibility, then import tensorflow.

    :param gpu_list: list of GPU id strings to expose
    :return: the imported tensorflow module

    Bug fix: CUDA_VISIBLE_DEVICES is now set *before* tensorflow is
    imported (matching pred.import_tf); setting it afterwards risks TF
    having already initialized its view of the CUDA devices.
    """
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(gpu_list)
    import tensorflow as tf

    return tf
194 |
def split_input(xs, ys, gpu_nums):
    """
    Split a batch evenly across GPUs along the batch axis.

    :param xs: articles, tuple of two tensors
    :param ys: summaries, tuple of three tensors
    :param gpu_nums: gpu numbers (tf.split requires the batch dimension to
        divide evenly by this)
    :return: per-GPU lists: pairs built from xs and triples built from ys
    """
    import tensorflow as tf
    x_parts = [tf.split(tensor, num_or_size_splits=gpu_nums, axis=0) for tensor in xs]
    y_parts = [tf.split(tensor, num_or_size_splits=gpu_nums, axis=0) for tensor in ys]

    xs_per_gpu = [(x_parts[0][i], x_parts[1][i]) for i in range(gpu_nums)]
    ys_per_gpu = [(y_parts[0][i], y_parts[1][i], y_parts[2][i]) for i in range(gpu_nums)]
    return xs_per_gpu, ys_per_gpu
--------------------------------------------------------------------------------