├── .gitignore
├── .idea
├── DeepQoE_pytorch.iml
├── dbnavigator.xml
├── misc.xml
├── modules.xml
├── vcs.xml
└── workspace.xml
├── DeepQoE
├── __init__.py
├── config.py
├── data_loader.py
├── glove_dict.py
└── nets.py
├── README.md
├── cal_corre.py
├── data
└── __init__.py
├── demo.py
├── log
└── __init__.py
├── models
├── C3D
│ └── __init__.py
├── GloVe
│ └── __init__.py
├── LIVE_NFLX
│ └── __init__.py
├── __init__.py
└── protect
│ └── __init__.py
├── scripts
├── __init__.py
├── download_glove_6B.sh
├── generate_pretrained_models.py
└── min_max.py
├── test.py
├── train_test.py
├── train_test_MOS.py
├── train_test_classifiers.py
├── train_test_streaming.py
└── train_test_video.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 |
49 | # Translations
50 | *.mo
51 | *.pot
52 |
53 | # Django stuff:
54 | *.log
55 | .static_storage/
56 | .media/
57 | local_settings.py
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | *.pkl
107 | *.pt
108 | *.zip
109 | *.txt
110 | *.mat
111 | scripts/preprocess_data.py
112 |
--------------------------------------------------------------------------------
/.idea/DeepQoE_pytorch.iml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/.idea/dbnavigator.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 |
366 |
367 |
368 |
369 |
370 |
371 |
372 |
373 |
374 |
375 |
376 |
377 |
378 |
379 |
380 |
381 |
382 |
383 |
384 |
385 |
386 |
387 |
388 |
389 |
390 |
391 |
392 |
393 |
394 |
395 |
396 |
397 |
398 |
399 |
400 |
401 |
402 |
403 |
404 |
405 |
406 |
407 |
408 |
409 |
410 |
411 |
412 |
413 |
414 |
415 |
416 |
417 |
418 |
419 |
420 |
421 |
422 |
423 |
424 |
425 |
426 |
427 |
428 |
429 |
430 |
431 |
432 |
433 |
434 |
435 |
436 |
437 |
438 |
439 |
440 |
441 |
442 |
443 |
444 |
445 |
446 |
447 |
448 |
--------------------------------------------------------------------------------
/.idea/misc.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
--------------------------------------------------------------------------------
/.idea/modules.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
--------------------------------------------------------------------------------
/.idea/vcs.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
--------------------------------------------------------------------------------
/.idea/workspace.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 |
22 |
23 |
24 |
25 |
26 |
27 |
28 |
29 |
30 |
31 |
32 |
33 |
34 |
35 |
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 |
44 |
45 |
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 |
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 |
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 |
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
123 |
124 |
125 |
126 |
127 |
128 |
129 |
130 |
131 |
132 |
133 |
134 |
135 |
136 |
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 |
146 |
147 |
148 |
149 |
150 |
151 |
152 |
153 |
154 |
155 |
156 |
157 |
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 |
174 |
175 |
176 |
177 |
178 |
179 |
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
207 |
208 |
209 |
210 |
211 |
212 |
213 |
214 |
215 |
216 |
217 |
218 |
219 |
220 |
221 |
222 |
223 |
224 |
225 |
226 |
227 |
228 |
229 |
230 |
231 |
232 |
233 |
234 |
235 |
236 |
237 |
238 |
239 |
240 |
241 |
242 |
243 |
244 |
245 |
246 |
247 |
248 |
249 |
250 |
251 |
252 |
253 |
254 |
255 |
256 |
257 |
258 |
259 |
260 |
261 |
262 |
263 |
264 |
265 |
266 |
267 |
268 |
269 |
270 |
271 |
272 |
273 |
274 |
275 |
276 |
277 |
278 |
279 |
280 |
281 |
282 |
283 |
284 |
285 |
286 |
287 |
288 |
289 |
290 |
291 |
292 |
293 |
294 |
295 |
296 |
297 |
298 |
299 |
300 |
301 |
302 |
303 |
304 |
305 |
306 |
307 |
308 |
309 |
310 |
311 |
312 |
313 |
314 |
315 |
316 |
317 |
318 |
319 |
320 |
321 |
322 |
323 |
324 |
325 |
326 |
327 |
328 |
329 |
330 |
331 |
332 |
333 |
334 |
335 |
336 |
337 |
338 |
339 |
340 |
341 |
342 |
343 |
344 |
345 |
346 |
347 |
348 |
349 |
350 |
351 |
352 |
353 |
354 |
355 |
356 |
357 |
358 |
359 |
360 |
361 |
362 |
363 |
364 |
365 |
366 |
367 |
368 |
369 |
370 |
371 |
372 |
373 |
374 |
375 |
376 |
377 |
378 |
379 |
380 |
381 |
382 |
383 |
384 |
385 |
386 |
387 |
388 |
389 |
390 |
391 |
392 |
393 |
394 |
395 |
396 |
397 |
398 |
399 |
400 |
401 |
402 |
403 |
404 |
405 |
406 |
407 |
408 |
409 |
410 |
411 |
412 |
413 |
414 |
415 |
416 |
417 |
418 |
419 |
420 |
421 |
422 |
423 |
424 |
425 |
426 |
427 |
428 |
429 |
430 |
431 |
432 |
433 |
434 |
435 |
436 |
437 |
438 | 1521966246895
439 |
440 |
441 | 1521966246895
442 |
443 |
444 |
445 |
446 |
447 |
448 |
449 |
450 |
451 |
452 |
453 |
454 |
455 |
456 |
457 |
458 |
459 |
460 |
461 |
462 |
463 |
464 |
465 |
466 |
467 |
468 |
469 |
470 |
471 |
472 |
473 |
474 |
475 |
476 |
477 |
478 |
479 |
480 |
481 |
482 |
483 |
484 |
485 |
486 |
487 |
488 |
489 |
490 |
491 |
492 |
493 |
494 |
495 |
496 |
497 |
498 |
499 |
500 |
501 |
502 |
503 |
504 |
505 |
506 |
507 |
508 |
509 |
510 |
511 |
512 |
513 |
514 |
515 |
516 |
517 |
518 |
519 |
520 |
521 |
522 |
523 |
524 |
525 |
526 |
527 |
528 |
529 |
530 |
531 |
532 |
533 |
534 |
535 |
536 |
537 |
538 |
539 |
540 |
541 |
542 |
543 |
544 |
545 |
546 |
547 |
548 |
549 |
550 |
551 |
552 |
553 |
554 |
555 |
556 |
557 |
558 |
559 |
560 |
561 |
562 |
563 |
564 |
565 |
566 |
567 |
568 |
569 |
570 |
571 |
572 |
573 |
574 |
575 |
576 |
577 |
578 |
579 |
580 |
581 |
582 |
583 |
584 |
585 |
586 |
587 |
588 |
589 |
590 |
591 |
592 |
593 |
594 |
595 |
596 |
597 |
598 |
599 |
600 |
601 |
602 |
603 |
604 |
605 |
606 |
607 |
608 |
609 |
610 |
611 |
612 |
613 |
614 |
615 |
616 |
617 |
618 |
619 |
620 |
621 |
622 |
623 |
624 |
625 |
626 |
627 |
628 |
629 |
630 |
631 |
632 |
633 |
634 |
635 |
636 |
637 |
638 |
639 |
640 |
641 |
642 |
643 |
644 |
645 |
646 |
647 |
648 |
649 |
650 |
651 |
652 |
653 |
654 |
655 |
656 |
657 |
658 |
659 |
660 |
661 |
662 |
663 |
664 |
665 |
666 |
667 |
668 |
669 |
670 |
671 |
672 |
673 |
674 |
675 |
676 |
677 |
678 |
679 |
680 |
681 |
682 |
683 |
684 |
685 |
686 |
687 |
688 |
689 |
690 |
691 |
692 |
693 |
694 |
695 |
696 |
697 |
698 |
699 |
700 |
701 |
702 |
703 |
704 |
705 |
706 |
707 |
708 |
709 |
710 |
711 |
712 |
713 |
714 |
715 |
716 |
717 |
718 |
719 |
720 |
721 |
722 |
723 |
724 |
725 |
726 |
727 |
728 |
729 |
730 |
731 |
732 |
733 |
734 |
735 |
736 |
737 |
738 |
739 |
740 |
741 |
742 |
743 |
744 |
745 |
746 |
747 |
748 |
749 |
750 |
751 |
752 |
753 |
754 |
755 |
756 |
757 |
758 |
759 |
760 |
761 |
762 |
763 |
764 |
765 |
766 |
767 |
768 |
769 |
770 |
771 |
772 |
773 |
774 |
775 |
776 |
777 |
778 |
779 |
780 |
781 |
782 |
783 |
784 |
785 |
786 |
787 |
788 |
789 |
790 |
791 |
792 |
793 |
794 |
795 |
796 |
797 |
798 |
799 |
800 |
801 |
802 |
803 |
804 |
805 |
806 |
807 |
808 |
809 |
810 |
811 |
812 |
813 |
814 |
815 |
816 |
817 |
818 |
819 |
820 |
821 |
822 |
823 |
824 |
825 |
826 |
827 |
828 |
829 |
830 |
831 |
832 |
833 |
834 |
835 |
836 |
837 |
838 |
839 |
840 |
841 |
842 |
843 |
844 |
845 |
846 |
847 |
848 |
849 |
850 |
851 |
852 |
853 |
854 |
855 |
856 |
857 |
858 |
859 |
860 |
861 |
862 |
863 |
864 |
865 |
866 |
867 |
868 |
869 |
870 |
871 |
872 |
873 |
874 |
875 |
876 |
877 |
878 |
879 |
880 |
881 |
882 |
883 |
884 |
885 |
886 |
887 |
888 |
889 |
890 |
891 |
892 |
893 |
894 |
895 |
896 |
897 |
898 |
899 |
900 |
901 |
902 |
903 |
904 |
905 |
906 |
907 |
908 |
909 |
910 |
911 |
912 |
913 |
914 |
915 |
916 |
--------------------------------------------------------------------------------
/DeepQoE/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 28/8/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : __init__.py
--------------------------------------------------------------------------------
/DeepQoE/config.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 28/8/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : config.py
7 | import argparse
8 | from easydict import EasyDict as edict
9 |
10 | from sklearn import svm
11 | from sklearn.tree import DecisionTreeClassifier
12 | from sklearn.ensemble import RandomForestClassifier, AdaBoostClassifier
13 | from sklearn.naive_bayes import GaussianNB
14 |
15 |
# Global configuration object. Other modules import it as
# ``from DeepQoE.config import cfg`` and read the attributes set below.
__C = edict()
cfg = __C

# Dataset locations (relative paths, resolved against the working directory).
# cal_corre.py unpickles MOS_DATA as a 2-D array whose last column is the target.
__C.EMBEDDING_DATA = 'data/train_test_glove.6B.50d.pkl'
__C.MOS_DATA = 'data/MOS_train_data.pkl'
__C.JND_DATA = 'data/videos_c3d_features_labels.pkl'
__C.NFLX_DATA = 'data/LIVE_NFLX_PublicData_VideoATLAS_Release'

# Fraction of samples used for training; presumably the remainder is held out
# for testing -- TODO confirm against the train_test_*.py scripts.
__C.TRAIN_RATIO = 0.9

# Output width of HybridNN's bitrate projection (nets.HybridNN also adds it to
# the input size of fc1).
__C.DIMENSION = 60
# NOTE(review): the meaning of LOG is not visible in this file -- confirm with its users.
__C.LOG = 'bit'


# Model weight files; cal_corre.py loads MODEL_SAVE_MOS into a HybridRR with
# torch.load. Presumably the training scripts write these paths -- verify.
__C.MODEL_SAVE_TEXT = 'models/GloVe/QoE_score.pt'
__C.MODEL_SAVE_MOS = 'models/GloVe/QoE_MOS_sport.pt'
__C.MODEL_SAVE_VIDEO = 'models/C3D/QoE_JND.pt'

# Baseline scikit-learn classifiers; CLASSIFIER_NAME[i] is the label for CLASSIFIER[i].
# These are shared, module-level estimator instances.
__C.CLASSIFIER_NAME = ['SVM', 'Decision Tree', 'Random Forest', 'AdaBoost', 'Naive Bayes']
__C.CLASSIFIER = [svm.SVC(decision_function_shape='ovo'),
                  DecisionTreeClassifier(max_depth=5),
                  RandomForestClassifier(max_depth=5, n_estimators=10, max_features=3),
                  AdaBoostClassifier(),
                  GaussianNB()]


# LIVE-Netflix (Video ATLAS) settings used by scripts/generate_pretrained_models.py:
# the VQA feature is read from the .mat key "<QUALITY_MODEL>_<POOLING_TYPE>",
# PREPROC toggles StandardScaler normalisation fitted on the training split,
# FEATURE_NAMES selects the feature columns, DB_PATH is the directory of
# per-video .mat files and TRAIN_TEST_NFLX names both the split-matrix file
# (data/<name>.mat) and the variable inside it.
__C.POOLING_TYPE = 'mean'
__C.QUALITY_MODEL = 'STRRED'
__C.PREPROC = False
__C.FEATURE_NAMES = ["VQA", "R$_1$", "R$_2$", "M", "I"]
__C.DB_PATH = 'data/LIVE_NFLX_PublicData_VideoATLAS_Release/'
__C.TRAIN_TEST_NFLX = 'TrainingMatrix_LIVENetflix_1000_trials'
48 |
def parse_arguments(argv):
    """Parse the DeepQoE training command-line options.

    Args:
        argv: list of argument strings (e.g. ``sys.argv[1:]``).

    Returns:
        argparse.Namespace with ``model_number``, ``batch_size``, ``epochs``,
        ``lr``, ``momentum``, ``use_gpu``, ``gpu_id`` and ``classifier``.
    """
    parser = argparse.ArgumentParser(description='DeepQoE training')
    parser.add_argument('--model_number', type=int, choices=[0, 1, 2, 3], default=0,
                        help='index of the model to train (default: 0)')
    parser.add_argument('--batch_size', type=int, default=64,
                        help='input batch size for training (default: 64)')
    parser.add_argument('--epochs', type=int, default=500,
                        help='number of epochs to train (default: 500)')
    parser.add_argument('--lr', type=float, default=0.01,
                        help='learning rate (default: 0.01)')
    parser.add_argument('--momentum', type=float, default=0.5,
                        help='SGD momentum (default: 0.5)')
    # ``--use_gpu`` defaults to True, so on its own it can never turn the GPU
    # off. It is kept for backward compatibility; ``--no_gpu`` writes to the
    # same destination and provides the opt-out.
    parser.add_argument('--use_gpu', dest='use_gpu', action='store_true', default=True,
                        help='use GPU')
    parser.add_argument('--no_gpu', dest='use_gpu', action='store_false',
                        help='run on the CPU instead of the GPU')
    parser.add_argument('--gpu_id', type=int, default=1,
                        help='selected a gpu')
    parser.add_argument('--classifier', type=int, default=1,
                        help='selected a classifier')
    return parser.parse_args(argv)
67 |
--------------------------------------------------------------------------------
/DeepQoE/data_loader.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 7/9/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : data_loader.py
7 |
8 | import torch
9 | import numpy as np
10 | from torch.utils.data import Dataset
11 | import pickle
12 |
13 |
class QoETextDataset(Dataset):
    """Samples of GloVe text features plus resolution/bitrate/gender/age metadata.

    Args:
        x: 2-D array with one row per sample. Columns 0-49 hold the 50-d GloVe
            vector; the last four columns hold resolution index, bitrate,
            gender index and age, in that order.
        y: labels, indexed in parallel with the rows of ``x``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of model inputs plus its ``label``.

        Continuous features become float32 tensors. The categorical ``res`` and
        ``gender`` columns are returned as 1-element int64 numpy arrays, because
        they index the ``nn.Embedding`` layers in nets.py.
        """
        # np.float / np.int were removed in NumPy 1.24; use explicit dtypes.
        glove_index = torch.from_numpy(self.x[idx, 0:50].astype(np.float64)).float()
        res = self.x[idx, [-4]].astype(np.int64)
        bitrate = torch.from_numpy(self.x[idx, [-3]].astype(np.float64)).float()
        gender = self.x[idx, [-2]].astype(np.int64)
        age = torch.from_numpy(self.x[idx, [-1]].astype(np.float64)).float()

        label = self.y[idx]
        return {'glove': glove_index, 'res': res, 'bitrate': bitrate,
                'gender': gender, 'age': age, 'label': label}
33 |
34 |
class QoENFLXDataset(Dataset):
    """Samples of the five LIVE-Netflix streaming features.

    Args:
        x: 2-D array with one row per sample and columns VQA, R1, R2, M, I
            (R2 is categorical and is returned as an integer index).
        y: labels, indexed in parallel with the rows of ``x``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict keyed by feature name plus ``label``.

        Continuous features are 1-element float32 tensors; ``R2`` is a
        1-element int64 numpy array so it can index an ``nn.Embedding``.
        """
        # np.float / np.int were removed in NumPy 1.24; use explicit dtypes.
        VQA = torch.from_numpy(self.x[idx, [0]].astype(np.float64)).float()
        R1 = torch.from_numpy(self.x[idx, [1]].astype(np.float64)).float()
        R2 = self.x[idx, [2]].astype(np.int64)
        M = torch.from_numpy(self.x[idx, [3]].astype(np.float64)).float()
        I = torch.from_numpy(self.x[idx, [4]].astype(np.float64)).float()
        label = self.y[idx]
        return {'VQA': VQA, 'R1': R1, 'R2': R2, 'Mem': M, 'Impair': I, 'label': label}
52 |
53 |
54 |
55 |
56 |
class QoEMOSDataset(Dataset):
    """Samples of GloVe text features plus resolution and bitrate, with MOS labels.

    Args:
        x: 2-D array with one row per sample. Columns 0-49 hold the 50-d GloVe
            vector; the last two columns hold resolution index and bitrate.
        y: labels, indexed in parallel with the rows of ``x``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{'glove', 'res', 'bitrate', 'label'}``.

        ``res`` is a 1-element int64 numpy array (an embedding index); the other
        inputs are float32 tensors.
        """
        # np.float / np.int were removed in NumPy 1.24; use explicit dtypes.
        glove_index = torch.from_numpy(self.x[idx, 0:50].astype(np.float64)).float()
        res = self.x[idx, [-2]].astype(np.int64)
        bitrate = torch.from_numpy(self.x[idx, [-1]].astype(np.float64)).float()
        label = self.y[idx]
        return {'glove': glove_index, 'res': res, 'bitrate': bitrate, 'label': label}
72 |
class QoEVideoDataset(Dataset):
    """Samples of a video feature vector plus a resolution index.

    Args:
        x: 2-D array with one row per sample. All columns but the last form the
            video feature vector; the last column is the resolution index.
        y: labels, indexed in parallel with the rows of ``x``.
    """

    def __init__(self, x, y):
        self.x, self.y = x, y

    def __len__(self):
        return len(self.x)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{'video', 'res', 'label'}``.

        ``video`` is a float32 tensor; ``res`` is a 1-element int64 numpy array
        (an embedding index).
        """
        # np.float / np.int were removed in NumPy 1.24; use explicit dtypes.
        video = torch.from_numpy(self.x[idx, 0:-1].astype(np.float64)).float()
        res = self.x[idx, [-1]].astype(np.int64)
        label = self.y[idx]
        return {'video': video, 'res': res, 'label': label}
87 |
88 |
--------------------------------------------------------------------------------
/DeepQoE/glove_dict.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 28/8/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : glove_dict.py
7 |
8 | from __future__ import division
9 | from __future__ import print_function
10 |
11 | import os
12 | from torchtext import vocab
13 | from collections import Counter
14 |
15 | from config import cfg
16 |
17 |
def get_dicts_base(words=('hello', 'world'), glove_path='../models/GloVe/glove.6B.50d.txt'):
    """Look up GloVe vectors for ``words`` by scanning a plain-text GloVe file.

    Args:
        words: iterable of tokens to look up. The default is an immutable tuple
            so it cannot be shared and mutated across calls.
        glove_path: path to a GloVe ``.txt`` file in which each line is a token
            followed by its space-separated float components.

    Returns:
        dict mapping each requested word to its vector (a list of floats), or
        ``False`` if ``glove_path`` is not an existing file.

    Raises:
        KeyError: if a requested word does not occur in the file.
    """
    import io  # io.open accepts ``encoding`` on both Python 2 and 3

    if not os.path.isfile(glove_path):
        print("Please use script to download pre-trained glove models...")
        return False

    words = list(words)
    wanted = set(words)
    vectors = {}
    # Only rows for requested words are converted to floats; the full 6B
    # vocabulary is not materialised. GloVe files are UTF-8 encoded.
    with io.open(glove_path, 'r', encoding='utf-8') as f:
        for line in f:
            vals = line.rstrip().split(' ')
            if vals[0] in wanted:
                vectors[vals[0]] = [float(x) for x in vals[1:]]
    return {w: vectors[w] for w in words}
32 |
33 |
def get_dicts(words=('hello', 'world'), glove='glove.6B.50d'):
    """Look up pretrained GloVe vectors for ``words`` through torchtext.

    Args:
        words: iterable of tokens. It is materialised once, so generators are
            supported. The default is an immutable tuple.
        glove: name of the pretrained vectors passed to ``vocab.Vocab``
            (torchtext may download them on first use -- depends on version).

    Returns:
        dict mapping each word to its row of the vocabulary's vector matrix
        (a numpy array).
    """
    words = list(words)  # Counter below would otherwise exhaust a generator
    v = vocab.Vocab(Counter(words), vectors=glove)
    # Convert the embedding matrix once instead of once per word.
    table = v.vectors.numpy()
    return {w: table[v.stoi[w]] for w in words}
41 |
42 |
# Script entry point: look up the default words through torchtext's GloVe
# vectors (the result is discarded; this only exercises the lookup).
if __name__ == '__main__':
    get_dicts()
45 |
--------------------------------------------------------------------------------
/DeepQoE/nets.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 29/8/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : nets.py
7 |
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from DeepQoE.config import cfg
12 |
class HybridNN(nn.Module):
    """5-way classifier fusing a GloVe text vector with per-sample metadata.

    ``forward(x1, x2, x3, x4, x5)`` inputs, one row per sample:
        x1: 50-d GloVe vector, projected to 5 dims.
        x2: resolution index (3 classes), embedded to 8 dims.
        x3: bitrate scalar, projected to ``cfg.DIMENSION`` dims.
        x4: gender index (2 classes), embedded to 1 dim.
        x5: age scalar, projected to 1 dim.
    The concatenated ``15 + cfg.DIMENSION`` features pass through a
    512 -> 256 -> 5 tanh MLP with dropout (p=0.5).

    Returns ``(log_probs, features)``: log-softmax scores over the 5 classes
    and the 256-d ``fc2`` activations taken before dropout.
    """

    def __init__(self):
        super(HybridNN, self).__init__()
        self.layer1_glove = nn.Linear(50, 5)
        self.layer1_res = nn.Embedding(3, 8)
        self.layer1_bit = nn.Linear(1, cfg.DIMENSION)
        self.layer1_gender = nn.Embedding(2, 1)
        self.layer1_age = nn.Linear(1, 1)

        # 5 (glove) + 8 (res) + 1 (gender) + 1 (age) = 15, plus the bitrate width.
        self.fc1 = nn.Linear(15 + cfg.DIMENSION, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 5)

    def forward(self, x1, x2, x3, x4, x5):
        x_res = self.layer1_res(x2).view(-1, 8)
        x_gender = self.layer1_gender(x4).view(-1, 1)
        h = torch.cat((self.layer1_glove(x1), x_res,
                       self.layer1_bit(x3), x_gender,
                       self.layer1_age(x5)), 1)

        # torch.tanh replaces the deprecated F.tanh (identical math).
        h = torch.tanh(self.fc1(h))
        h = F.dropout(h, p=0.5, training=self.training)
        fc2 = torch.tanh(self.fc2(h))
        h = F.dropout(fc2, p=0.5, training=self.training)
        # dim=1 is what the deprecated implicit-dim call picked for 2-D input.
        h = F.log_softmax(self.fc3(h), dim=1)
        return h, fc2
52 |
53 |
class C3DHybridNN(nn.Module):
    """4-way classifier over a 4096-d video content vector and a resolution index.

    The content vector (presumably a C3D clip feature -- see cfg.JND_DATA) is
    projected to 4088 dims. It is concatenated with an 8-d resolution
    embedding (4 classes), giving a 4096-d input to a 1024 -> 1024 -> 512 -> 4
    tanh MLP with dropout (p=0.5).

    Returns ``(log_probs, features)``: log-softmax scores over the 4 classes
    and the 512-d ``fc3`` activations taken before dropout.
    """

    def __init__(self):
        super(C3DHybridNN, self).__init__()
        # 4088 + 8 (resolution embedding) = 4096 inputs to fc1.
        self.layer1_content = nn.Linear(4096, 4088)
        self.layer1_res = nn.Embedding(4, 8)

        self.fc1 = nn.Linear(4096, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc4 = nn.Linear(512, 4)

    def forward(self, x1, x2):
        x_res = self.layer1_res(x2).view(-1, 8)
        h = torch.cat((self.layer1_content(x1), x_res), 1)
        # torch.tanh replaces the deprecated F.tanh (identical math).
        h = torch.tanh(h)
        h = torch.tanh(self.fc1(h))
        h = F.dropout(h, p=0.5, training=self.training)
        h = torch.tanh(self.fc2(h))
        h = F.dropout(h, p=0.5, training=self.training)
        fc3 = torch.tanh(self.fc3(h))
        h = F.dropout(fc3, p=0.5, training=self.training)
        # dim=1 is what the deprecated implicit-dim call picked for 2-D input.
        h = F.log_softmax(self.fc4(h), dim=1)
        return h, fc3
77 |
78 |
class HybridRR(nn.Module):
    """Regressor over a GloVe text vector, a resolution index and a bitrate.

    ``forward(x1, x2, x3)`` takes a 50-d GloVe vector, a resolution index in
    ``[0, 3)`` and a bitrate scalar per sample. It returns ``(score, features)``:
    one output column per sample, and the 256-d ReLU activations of ``fc2``
    taken before dropout.
    """

    def __init__(self):
        super(HybridRR, self).__init__()
        # Per-input encoders; widths 5 + 8 + 1 = 14 feed fc1.
        self.layer1_glove = nn.Linear(50, 5)
        self.layer1_res = nn.Embedding(3, 8)
        self.layer1_bit = nn.Linear(1, 1)
        # Regression head.
        self.fc1 = nn.Linear(14, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x1, x2, x3):
        text_feat = self.layer1_glove(x1)
        res_feat = self.layer1_res(x2).view(-1, 8)
        bit_feat = self.layer1_bit(x3)
        fused = torch.cat((text_feat, res_feat, bit_feat), 1)

        hidden = F.dropout(F.relu(self.fc1(fused)), p=0.5, training=self.training)
        features = F.relu(self.fc2(hidden))
        score = self.fc3(F.dropout(features, p=0.5, training=self.training))
        return score, features
100 |
101 |
class HybridStreaming(nn.Module):
    """Regressor over the five LIVE-Netflix streaming features.

    ``forward(x1, ..., x5)`` takes VQA, R1, R2 (an index in ``[0, 3)``), M and I,
    each with one entry per sample. It returns ``(score, features)``: one
    output column per sample, and the 256-d ReLU activations of ``fc2`` taken
    before dropout.
    """

    def __init__(self):
        super(HybridStreaming, self).__init__()
        # Per-feature encoders; widths 20 + 5 + 5 + 10 + 10 = 50 feed fc1.
        self.layer1_VQA = nn.Linear(1, 20)
        self.layer1_R1 = nn.Linear(1, 5)
        self.layer1_R2 = nn.Embedding(3, 5)
        self.layer1_M = nn.Linear(1, 10)
        self.layer1_I = nn.Linear(1, 10)
        # Regression head.
        self.fc1 = nn.Linear(50, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)

    def forward(self, x1, x2, x3, x4, x5):
        encoded = [
            self.layer1_VQA(x1),
            self.layer1_R1(x2),
            self.layer1_R2(x3).view(-1, 5),
            self.layer1_M(x4),
            self.layer1_I(x5),
        ]
        fused = torch.cat(encoded, 1)

        hidden = F.dropout(F.relu(self.fc1(fused)), p=0.5, training=self.training)
        features = F.relu(self.fc2(hidden))
        score = self.fc3(F.dropout(features, p=0.5, training=self.training))
        return score, features
125 |
126 |
127 |
128 |
129 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### Please go to [project website](https://huaizhengzhang.github.io//DeepQoE/) to find more information.
2 |
--------------------------------------------------------------------------------
/cal_corre.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 19/2/18 8:43 PM
4 | # @Author : Huaizheng Zhang
5 | # @Site : zhanghuaizheng.info
6 | # @File : cal_corre.py
7 |
8 | from __future__ import print_function
9 | import torch
10 | import numpy as np
11 | from scripts.generate_pretrained_models import show_corre
12 | from scipy.stats.stats import pearsonr
13 | from scipy.stats import kendalltau
14 | from scipy.stats import spearmanr
15 | import torch.utils.data as data_utils
16 | from DeepQoE.nets import *
17 | from DeepQoE.config import cfg
18 | from DeepQoE.data_loader import QoEMOSDataset
19 | import pickle
20 |
# Evaluate the saved MOS regressor on the whole MOS dataset and report
# Pearson / Kendall / Spearman correlations between targets and predictions.
with open(cfg.MOS_DATA, 'rb') as f:
    # NOTE: pickle can execute arbitrary code; only load trusted local files.
    data = pickle.load(f)
# All columns but the last are model inputs; the last column is the target.
x = data[:, 0:data.shape[1] - 1]
# np.float was removed in NumPy 1.24; the builtin float is the same dtype.
y = np.array(data[:, -1], dtype=float)

# Fall back to the CPU when CUDA is unavailable instead of crashing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = HybridRR()
model.load_state_dict(torch.load(cfg.MODEL_SAVE_MOS, map_location=device))
model.to(device)
model.eval()

test_data = QoEMOSDataset(x, y)
test_loader = data_utils.DataLoader(test_data, batch_size=72, shuffle=False)

# Accumulate every batch's output so y_pred stays aligned with y even when the
# data spans more than one batch.
batch_preds = []
with torch.no_grad():
    for sample_batched in test_loader:
        x_1 = sample_batched['glove'].to(device)
        x_2 = sample_batched['res'].to(device)
        x_3 = sample_batched['bitrate'].to(device)
        output, _ = model(x_1, x_2, x_3)
        batch_preds.append(output.cpu().numpy())

y_pred = np.concatenate(batch_preds).squeeze()
print(y, y_pred)

# Rows 36:54 are plotted on their own; show_corre titles this plot "Sport SROCC".
show_corre(y[36:54], y_pred[36:54])

print("Pearson: {}".format(pearsonr(y, y_pred)))
print("Kendal: {}".format(kendalltau(y, y_pred)))
print("Spearman: {}".format(spearmanr(y, y_pred)))
48 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 29/8/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : __init__.py
7 |
8 |
9 |
--------------------------------------------------------------------------------
/demo.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 28/8/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : demo.py
--------------------------------------------------------------------------------
/log/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cap-ntu/DeepQoE/cb84870f57cd037a6396ff6a22b8d4eb486fe363/log/__init__.py
--------------------------------------------------------------------------------
/models/C3D/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 29/8/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : __init__.py
--------------------------------------------------------------------------------
/models/GloVe/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 29/8/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : __init__.py
--------------------------------------------------------------------------------
/models/LIVE_NFLX/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 8/3/18 10:50 PM
4 | # @Author : Huaizheng Zhang
5 | # @Site : zhanghuaizheng.info
6 | # @File : __init__.py.py
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 8/3/18 10:54 PM
4 | # @Author : Huaizheng Zhang
5 | # @Site : zhanghuaizheng.info
6 | # @File : __init__.py.py
--------------------------------------------------------------------------------
/models/protect/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cap-ntu/DeepQoE/cb84870f57cd037a6396ff6a22b8d4eb486fe363/models/protect/__init__.py
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 8/3/18 10:54 PM
4 | # @Author : Huaizheng Zhang
5 | # @Site : zhanghuaizheng.info
6 | # @File : __init__.py.py
--------------------------------------------------------------------------------
/scripts/download_glove_6B.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
# Download the GloVe 6B embeddings into models/GloVe/ and unpack them, skipping
# the download when a copy with the expected md5 is already present.
FILE=glove.6B.zip
URL=http://nlp.stanford.edu/data/glove.6B.zip
CHECKSUM=056ea991adb4740ac6bf1b6d9b50408b

# Resolve models/GloVe/ relative to this script so it can run from any cwd.
# Abort if it cannot be resolved; a bare `cd ""` would silently go to $HOME.
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../models/GloVe/" && pwd )" || exit 1
cd "$DIR" || exit 1

# Print the md5 of file $1 with the tool available on this OS. Prints nothing
# on unsupported systems, which makes the comparisons below fail (re-download).
file_md5() {
    case "$(uname -s)" in
        Linux)  md5sum "$1" | awk '{ print $1 }' ;;
        Darwin) md5 -q "$1" ;;
    esac
}

if [ -f "$FILE" ]; then
    echo "File already exists. Checking md5..."
    if [ "$(file_md5 "$FILE")" = "$CHECKSUM" ]; then
        echo "Checksum is correct. No need to download."
        exit 0
    else
        echo "Checksum is incorrect. Need to download again."
    fi
fi

echo "Downloading GloVe models..."

if ! wget "$URL" -O "$FILE"; then
    echo "Download failed." >&2
    exit 1
fi

if [ "$(file_md5 "$FILE")" != "$CHECKSUM" ]; then
    echo "Warning: checksum of $FILE does not match $CHECKSUM; the download may be corrupt." >&2
fi

echo "Unzipping..."

# -o: overwrite previously extracted files instead of prompting.
unzip -o "$FILE" || exit 1

echo "Done."
35 |
--------------------------------------------------------------------------------
/scripts/generate_pretrained_models.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 8/3/18 10:42 PM
4 | # @Author : Huaizheng Zhang
5 | # @Site : zhanghuaizheng.info
6 | # @File : generate_pretrained_models.py
7 |
8 | # References:
9 | # 1) C. G. Bampis and A. C. Bovik, "Video ATLAS Software Release"
10 | # URL: http://live.ece.utexas.edu/research/Quality/VideoATLAS_release.zip, 2016
11 | # 2) C. G. Bampis and A. C. Bovik, "Learning to Predict Streaming Video QoE: Distortions, Rebuffering and Memory," under review
12 |
13 | from __future__ import print_function
14 |
15 | import os
16 | import re
17 | import numpy as np
18 | import scipy.io as sio
19 | import copy
20 | from sklearn import preprocessing
21 | from DeepQoE.config import cfg
22 | from scipy.stats import spearmanr as sr
23 | import matplotlib.pyplot as plt
24 |
25 |
def generate_streaming_data():
    """Load LIVE-Netflix features and scores and split them with a random trial.

    Reads every file in ``cfg.DB_PATH``, sorted in natural numeric order. It
    picks one column of the predefined split matrix stored in
    ``data/<cfg.TRAIN_TEST_NFLX>.mat`` at random (1 = train, 0 = test) and
    splits the data with it.

    Returns:
        ``(X_train, y_train, X_test, y_test, feature_labels)``.
        - The X arrays have one column per entry of ``cfg.FEATURE_NAMES``. When
          ``cfg.PREPROC`` is set they are standardised with statistics fitted on
          the training split.
        - The y arrays are column vectors of ``final_subj_score``.
        - ``feature_labels`` lists the .mat keys read for each column.
    """
    db_files = os.listdir(cfg.DB_PATH)
    # Natural sort so that e.g. "x_2.mat" precedes "x_10.mat".
    db_files.sort(key=lambda name: [int(tok) if tok.isdigit() else tok
                                    for tok in re.findall(r'[^0-9]|[0-9]+', name)])
    n_videos = len(db_files)

    trials = sio.loadmat('data/' + cfg.TRAIN_TEST_NFLX + '.mat')[cfg.TRAIN_TEST_NFLX]

    # Randomly pick one of the predefined trials (columns) as a plain int index.
    nt_rand = int(np.random.choice(np.shape(trials)[1], 1)[0])
    n_train = [ind for ind in range(n_videos) if trials[ind, nt_rand] == 1]
    n_test = [ind for ind in range(n_videos) if trials[ind, nt_rand] == 0]

    # .mat key holding each requested feature; any other name maps to "lt_norm".
    label_for = {
        "VQA": cfg.QUALITY_MODEL + "_" + cfg.POOLING_TYPE,
        "R$_1$": "ds_norm",
        "R$_2$": "ns",
        "M": "tsl_norm",
    }
    feature_labels = [label_for.get(typ, "lt_norm") for typ in cfg.FEATURE_NAMES]

    X = np.zeros((n_videos, len(feature_labels)))
    y = np.zeros((n_videos, 1))
    for i, fname in enumerate(db_files):
        data = sio.loadmat(os.path.join(cfg.DB_PATH, fname))
        for feat_cnt, feat in enumerate(feature_labels):
            # loadmat returns scalars as size-1 arrays. .item() extracts the value
            # explicitly; implicit array->scalar assignment is deprecated in NumPy.
            X[i, feat_cnt] = np.asarray(data[feat]).item()
        y[i] = data["final_subj_score"]

    X_train_before_scaling = X[n_train, :]
    X_test_before_scaling = X[n_test, :]
    y_train = y[n_train]
    y_test = y[n_test]

    if cfg.PREPROC:
        # Fit on the training split only so no test statistics leak in.
        scaler = preprocessing.StandardScaler().fit(X_train_before_scaling)
        X_train = scaler.transform(X_train_before_scaling)
        X_test = scaler.transform(X_test_before_scaling)
    else:
        # Fancy indexing above already produced fresh arrays; no copy needed.
        X_train = X_train_before_scaling
        X_test = X_test_before_scaling

    return X_train, y_train, X_test, y_test, feature_labels
75 |
def show_results(X_test, X_test_before_scaling, y_test, regressor_name, feature_labels, answer):
    """Plot and print the SROCC of the raw VQA feature and of the model prediction.

    Two scatter plots against MOS are shown: the unscaled VQA feature column
    ("before") and ``answer`` ("after").  Both SROCC values are then printed.
    If the configured VQA feature (``cfg.QUALITY_MODEL + "_" + cfg.POOLING_TYPE``)
    is not in ``feature_labels``, there is no baseline to compare, and the
    function returns without plotting.

    Args:
        X_test: test features used for the "before" SROCC (standardisation is
            monotonic, so scaled or unscaled gives the same rank correlation).
        X_test_before_scaling: unscaled test features used for the scatter plot.
        y_test: ground-truth MOS.
        regressor_name: label printed next to the "after" SROCC.
        feature_labels: column names of ``X_test``.
        answer: model predictions, reshaped to a column vector.
    """
    vqa_label = cfg.QUALITY_MODEL + "_" + cfg.POOLING_TYPE
    if vqa_label not in feature_labels:
        # Without the VQA column there is nothing to compare against.
        return
    position_vqa = feature_labels.index(vqa_label)

    answer = np.asarray(answer).reshape(-1, 1)
    srocc_before = sr(y_test, X_test[:, position_vqa].reshape(-1, 1))[0]
    srocc_after = sr(y_test, answer)[0]

    def _scatter(values, title):
        """Show one scatter plot of *values* against MOS in a roughly square frame."""
        plt.figure()
        ax = plt.subplot(1, 1, 1)
        plt.title(title)
        plt.scatter(y_test, values)
        plt.grid()
        x0, x1 = ax.get_xlim()
        y0, y1 = ax.get_ylim()
        # Aspect = data-range ratio, making the axes box roughly square.
        ax.set_aspect((x1 - x0) / (y1 - y0))
        plt.ylabel("predicted QoE")
        plt.xlabel("MOS")
        plt.show()

    _scatter(X_test_before_scaling[:, position_vqa].reshape(-1, 1),
             "before: " + format(srocc_before, '.4f'))
    _scatter(answer, "after: " + format(srocc_after, '.4f'))

    print("SROCC before (" + str(cfg.QUALITY_MODEL) + "): " + str(srocc_before))
    print("SROCC using DeepQoE (" + str(cfg.QUALITY_MODEL) + " + " + regressor_name + "): " + str(
        srocc_after))
108 |
109 |
def show_corre(y_true, y_test):
    """Scatter-plot predicted QoE against real QoE, titled with their SROCC.

    Args:
        y_true: ground-truth QoE values.
        y_test: predicted QoE values (array; reshaped to a column vector).
    """
    predicted = y_test.reshape(-1, 1)
    srocc_text = format(sr(y_true, predicted)[0], '.4f')

    plt.figure()
    axis = plt.subplot(1, 1, 1)
    plt.title("Sport SROCC: " + srocc_text)
    plt.scatter(y_true, predicted)
    plt.grid()
    # Use the data-range ratio as aspect so the plot box comes out roughly square.
    x_lo, x_hi = axis.get_xlim()
    y_lo, y_hi = axis.get_ylim()
    axis.set_aspect((x_hi - x_lo) / (y_hi - y_lo))
    plt.ylabel("Predicted QoE")
    plt.xlabel("Real QoE")
    plt.show()
--------------------------------------------------------------------------------
/scripts/min_max.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 13/9/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : min_max.py
7 |
8 |
9 | from sklearn import preprocessing
10 | import pickle
11 |
12 | with open('data/embedding_data.pkl', 'rb') as f:
13 | data = pickle.load(f)
14 | x, y = data[0], data[1]
15 | print x[..., [53]]
16 |
17 | min_max_scaler = preprocessing.MinMaxScaler()
18 | X_train_minmax = min_max_scaler.fit_transform(x[..., [53]])
19 | x[..., [53]] = X_train_minmax
20 |
21 | with open('data/embedding_data_new.pkl', 'wb') as f:
22 | pickle.dump([x, y], f)
23 | print [x, y]
24 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 28/8/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : test.py
7 |
8 | from __future__ import print_function
9 |
10 | import torch
11 | import pickle
12 | import sys
13 | import numpy as np
14 | from sklearn import svm
15 | from sklearn.metrics import accuracy_score
16 | import torch.utils.data as data_utils
17 | from DeepQoE.config import cfg, parse_arguments
18 | from DeepQoE.nets import *
19 | from DeepQoE.data_loader import QoETextDataset
20 |
21 |
def test_confusion_matrix(args):
    """Print true vs. predicted labels of the saved HybridNN on the test split.

    Weights are read from ``cfg.MODEL_SAVE_TEXT`` and data from
    ``cfg.EMBEDDING_DATA`` (a pickled ``[x, y]`` pair).  Samples after the first
    ``cfg.TRAIN_RATIO`` fraction form the test split, which is evaluated in
    batches of 10.
    """
    model = HybridNN()
    # Load onto the CPU first so checkpoints saved on a GPU also load on CPU-only hosts.
    model.load_state_dict(torch.load(cfg.MODEL_SAVE_TEXT, map_location=lambda storage, loc: storage))

    with open(cfg.EMBEDDING_DATA, 'rb') as f:
        data = pickle.load(f)
    x, y = data[0], data[1]

    train_size = int(cfg.TRAIN_RATIO * len(x))
    x_test = x[train_size:]
    y_test = y[train_size:]
    print(y_test)

    use_cuda = args.use_gpu and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu_id)
        model.cuda()

    def _to_var(batch, key):
        """Wrap ``batch[key]`` in a Variable, on the GPU when enabled."""
        tensor = batch[key]
        if use_cuda:
            tensor = tensor.cuda()
        return torch.autograd.Variable(tensor)

    test_data = QoETextDataset(x_test, y_test)
    test_loader = data_utils.DataLoader(test_data, batch_size=10, shuffle=False)

    model.eval()
    for sample_batched in test_loader:
        inputs = [_to_var(sample_batched, key) for key in ('glove', 'res', 'bitrate', 'gender', 'age')]
        target = _to_var(sample_batched, 'label')

        output, _ = model(*inputs)
        # Index of the highest score per sample = predicted class.
        pred = output.data.max(1, keepdim=True)[1]
        print("True: {}".format(target))
        print("Predict: {}".format(pred))


if __name__ == '__main__':
    test_confusion_matrix(parse_arguments(sys.argv[1:]))
--------------------------------------------------------------------------------
/train_test.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 28/8/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : train_test.py
7 |
8 | from __future__ import print_function
9 |
10 | import os
11 | import sys
12 | import numpy as np
13 | import pickle
14 | import logging
15 | import torch.optim as optim
16 | import torch.utils.data as data_utils
17 | from DeepQoE.nets import *
18 | from DeepQoE.config import cfg, parse_arguments
19 | from DeepQoE.data_loader import QoETextDataset
20 |
21 |
def main(args):
    """Train and evaluate a DeepQoE text model ten times and log the accuracy.

    ``args.model_number`` selects the network: 0 -> HybridNN,
    1 -> C3DHybridNN, anything else -> HybridRR.  ``model_number == 2`` is
    trained and evaluated with MSE; every other choice uses NLL loss.  The
    first ``cfg.TRAIN_RATIO`` fraction of ``cfg.EMBEDDING_DATA`` (a pickled
    ``[x, y]`` pair) is the training split and the rest the test split.
    Per-run results and their average are logged to
    ``log/<cfg.LOG>_<cfg.DIMENSION>.txt``.
    """
    num_runs = 10
    log_file = 'log/' + cfg.LOG + '_' + str(cfg.DIMENSION) + '.txt'
    print(log_file)
    logging.basicConfig(filename=log_file, level=logging.INFO, filemode='w')

    use_cuda = args.use_gpu and torch.cuda.is_available()
    # Use the same loss for training and testing.
    loss_fn = F.mse_loss if args.model_number == 2 else F.nll_loss

    # The data and the (unshuffled) split are identical for every run: load them once.
    with open(cfg.EMBEDDING_DATA, 'rb') as f:
        data = pickle.load(f)
    x, y = data[0], data[1]
    train_size = int(cfg.TRAIN_RATIO * len(x))
    x_train, y_train = x[:train_size], y[:train_size]
    x_test, y_test = x[train_size:], y[train_size:]

    def _to_var(batch, key):
        """Wrap ``batch[key]`` in a Variable, on the GPU when enabled."""
        tensor = batch[key]
        if use_cuda:
            tensor = tensor.cuda()
        return torch.autograd.Variable(tensor)

    def _inputs(batch):
        """Return the five model inputs of *batch* in call order."""
        return [_to_var(batch, key) for key in ('glove', 'res', 'bitrate', 'gender', 'age')]

    pid = os.getpid()
    result = 0.0
    for i in range(num_runs):
        if args.model_number == 0:
            model = HybridNN()
        elif args.model_number == 1:
            model = C3DHybridNN()
        else:
            model = HybridRR()
        if use_cuda:
            torch.cuda.set_device(args.gpu_id)
            model.cuda()
        print(model)

        model.train()
        optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum)

        train_data = QoETextDataset(x_train, y_train)
        train_loader = data_utils.DataLoader(train_data, batch_size=args.batch_size, shuffle=False)
        for epoch in range(args.epochs):
            for batch_idx, sample_batched in enumerate(train_loader):
                inputs = _inputs(sample_batched)
                target = _to_var(sample_batched, 'label')

                optimizer.zero_grad()
                prediction, _ = model(*inputs)
                loss = loss_fn(prediction, target)
                loss.backward()
                optimizer.step()

                if batch_idx % 3 == 0:
                    # Samples seen = batches * batch size (not len(data), which is the
                    # length of the pickled [x, y] pair).
                    print('{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        pid, epoch, batch_idx * args.batch_size, len(train_loader.dataset),
                        100. * batch_idx / len(train_loader), loss.data[0]))
        # One save after training stores the same final weights as saving every batch.
        torch.save(model.state_dict(), cfg.MODEL_SAVE_TEXT)

        # test processing
        test_data = QoETextDataset(x_test, y_test)
        test_loader = data_utils.DataLoader(test_data, batch_size=args.batch_size, shuffle=False)
        model.eval()
        test_loss = 0
        correct = 0
        for sample_batched in test_loader:
            inputs = _inputs(sample_batched)
            target = _to_var(sample_batched, 'label')

            output, _ = model(*inputs)
            test_loss += loss_fn(output, target, size_average=False).data[0]  # sum up batch loss
            pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
            # int() so the count stays a Python number across torch versions.
            correct += int(pred.eq(target.data.view_as(pred)).cpu().sum())

        n_test = len(test_loader.dataset)
        test_loss /= n_test
        accuracy = 100. * correct / n_test
        summary = '\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
            test_loss, correct, n_test, accuracy)
        print(summary)

        result = result + accuracy
        logging.info('The {}th results: {}'.format(i + 1, summary))

    logging.info('Average = {}'.format(result / float(num_runs)))


if __name__ == '__main__':
    main(parse_arguments(sys.argv[1:]))
140 |
141 |
142 |
143 | # model = Embedding()
144 |
--------------------------------------------------------------------------------
/train_test_MOS.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 14/12/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : train_test_MOS.py
7 |
8 | from __future__ import print_function
9 |
10 | import os
11 | import sys
12 | import numpy as np
13 | import pickle
14 | import logging
15 | import torch.optim as optim
16 | import torch.utils.data as data_utils
17 | from DeepQoE.nets import *
18 | from DeepQoE.config import cfg, parse_arguments
19 | from DeepQoE.data_loader import QoEMOSDataset
20 |
21 |
def main(args):
    """Train HybridRR to regress MOS on ``cfg.MOS_DATA`` and print the test MSE.

    Each row of the pickled array holds the model inputs followed by the MOS
    in the last column.  The first 54 rows are the training set (one batch);
    the remaining rows are evaluated in batches of 18.  The trained weights
    are saved to ``cfg.MODEL_SAVE_MOS``.
    """
    num_train = 54
    model = HybridRR()
    use_cuda = args.use_gpu and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu_id)
        model.cuda()
    print(model)

    model.train()
    optimizer = optim.Adadelta(model.parameters(), lr=1.0, rho=0.95, eps=1e-08, weight_decay=1e-6)
    with open(cfg.MOS_DATA, 'rb') as f:
        data = pickle.load(f)
    x = data[:, :-1]
    # Builtin float: np.float was only an alias of it and was removed in NumPy 1.24.
    y = np.array(data[:, -1], float)

    x_train, y_train = x[:num_train], y[:num_train]
    x_test, y_test = x[num_train:], y[num_train:]

    def _to_var(batch, key):
        """Wrap ``batch[key]`` in a Variable, on the GPU when enabled."""
        tensor = batch[key]
        if use_cuda:
            tensor = tensor.cuda()
        return torch.autograd.Variable(tensor)

    def _inputs(batch):
        """Return the three model inputs of *batch* in call order."""
        return [_to_var(batch, key) for key in ('glove', 'res', 'bitrate')]

    train_data = QoEMOSDataset(x_train, y_train)
    train_loader = data_utils.DataLoader(train_data, batch_size=num_train, shuffle=False)
    pid = os.getpid()
    for epoch in range(args.epochs):
        for batch_idx, sample_batched in enumerate(train_loader):
            inputs = _inputs(sample_batched)
            target = _to_var(sample_batched, 'label')

            optimizer.zero_grad()
            prediction, _ = model(*inputs)
            loss = F.mse_loss(prediction, target.float())
            loss.backward()
            optimizer.step()

            if batch_idx % 3 == 0:
                print('{}\tTrain Epoch: {} \tLoss: {:.6f}'.format(
                    pid, epoch, loss.data[0]))
    # One save after training stores the same final weights as saving every batch.
    torch.save(model.state_dict(), cfg.MODEL_SAVE_MOS)

    # test processing
    test_data = QoEMOSDataset(x_test, y_test)
    test_loader = data_utils.DataLoader(test_data, batch_size=18, shuffle=False)
    model.eval()
    test_loss = 0
    for sample_batched in test_loader:
        inputs = _inputs(sample_batched)
        target = _to_var(sample_batched, 'label')

        output, _ = model(*inputs)
        test_loss += F.mse_loss(output, target.float(), size_average=False).data[0]
        print(target.float())

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}'.format(test_loss))


if __name__ == '__main__':
    main(parse_arguments(sys.argv[1:]))
102 |
--------------------------------------------------------------------------------
/train_test_classifiers.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 12/9/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : train_test_classifiers.py
7 |
8 | from __future__ import print_function
9 |
10 | import torch
11 | import pickle
12 | import sys
13 | import numpy as np
14 | from sklearn import svm
15 | from sklearn.metrics import accuracy_score
16 | import torch.utils.data as data_utils
17 | from DeepQoE.config import cfg, parse_arguments
18 | from DeepQoE.nets import *
19 | from DeepQoE.data_loader import QoETextDataset
20 | import datetime
21 |
22 |
def train_test_SVM(args):
    """Compare a classifier trained on DeepQoE features with one trained on raw features.

    The pretrained HybridNN (``cfg.MODEL_SAVE_TEXT``) is used as a frozen
    feature extractor.  Its second output is collected for the train and test
    splits of ``cfg.EMBEDDING_DATA``.  The classifier ``cfg.CLASSIFIER[args.classifier]``
    is fitted on those features, and a fresh copy of the same classifier is
    fitted on the raw features.  Both accuracies are printed, together with
    the mean per-sample feature-extraction time.
    """
    from sklearn.base import clone

    model = HybridNN()
    # Load onto the CPU first so checkpoints saved on a GPU also load on CPU-only hosts.
    model.load_state_dict(torch.load(cfg.MODEL_SAVE_TEXT, map_location=lambda storage, loc: storage))

    with open(cfg.EMBEDDING_DATA, 'rb') as f:
        data = pickle.load(f)
    x, y = data[0], data[1]

    train_size = int(cfg.TRAIN_RATIO * len(x))

    x_train = x[:train_size]
    y_train = y[:train_size]
    x_test = x[train_size:]
    y_test = y[train_size:]
    print(y_test)

    use_cuda = args.use_gpu and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu_id)
        model.cuda()

    train_data = QoETextDataset(x_train, y_train)
    train_loader = data_utils.DataLoader(train_data, batch_size=args.batch_size, shuffle=False)

    test_data = QoETextDataset(x_test, y_test)
    test_loader = data_utils.DataLoader(test_data, batch_size=args.batch_size, shuffle=False)

    model.eval()

    def _extract_features(loader):
        """Run every batch of *loader* through the model and stack its second output."""
        chunks = []
        for batch in loader:
            inputs = []
            for key in ('glove', 'res', 'bitrate', 'gender', 'age'):
                tensor = batch[key]
                if use_cuda:
                    tensor = tensor.cuda()
                inputs.append(torch.autograd.Variable(tensor))
            _, features = model(*inputs)
            chunks.append(features.data.cpu().numpy())
        return np.concatenate(chunks, 0)

    start_deep = datetime.datetime.now()
    train_features = _extract_features(train_loader)
    total_deep = float((datetime.datetime.now() - start_deep).total_seconds()) / float(len(train_data))
    print("DeepQoE total cost {}s".format(total_deep))
    print(len(train_data))

    # clone() returns an unfitted copy, so the estimator stored in cfg.CLASSIFIER
    # is never mutated and the two comparisons do not share one fitted object.
    clf = clone(cfg.CLASSIFIER[args.classifier])
    clf.fit(train_features, y_train)

    start_deep = datetime.datetime.now()
    test_features = _extract_features(test_loader)
    total_deep = float((datetime.datetime.now() - start_deep).total_seconds()) / float(len(test_data))
    print("DeepQoE total cost {}s".format(total_deep))
    print(len(test_data))

    prediction = clf.predict(test_features)
    acc = accuracy_score(y_test, prediction)
    print("{} uses DeepQoE features can get {}%".format(cfg.CLASSIFIER_NAME[args.classifier], acc * 100.0))

    clf_ori = clone(cfg.CLASSIFIER[args.classifier])
    clf_ori.fit(x_train.astype(float), y_train)
    prediction_ori = clf_ori.predict(x_test.astype(float))
    acc_ori = accuracy_score(y_test, prediction_ori)
    print("{} uses original features can get {}%".format(cfg.CLASSIFIER_NAME[args.classifier], acc_ori * 100.0))


if __name__ == '__main__':
    train_test_SVM(parse_arguments(sys.argv[1:]))
117 |
--------------------------------------------------------------------------------
/train_test_streaming.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 9/3/18 3:40 PM
4 | # @Author : Huaizheng Zhang
5 | # @Site : zhanghuaizheng.info
6 | # @File : train_test_streaming.py
7 |
8 | from __future__ import print_function
9 |
10 | import os
11 | import sys
12 | import numpy as np
13 | import pickle
14 | import logging
15 | import torch.optim as optim
16 | import torch.utils.data as data_utils
17 | from DeepQoE.nets import *
18 | from DeepQoE.config import cfg, parse_arguments
19 | from scripts.generate_pretrained_models import generate_streaming_data, show_results
20 | from DeepQoE.data_loader import QoENFLXDataset
21 |
22 |
23 |
def main(args):
    """Train HybridStreaming on one random LIVE-NFLX split and plot the results.

    The model is trained for ``args.epochs`` epochs with Adadelta, evaluated
    on the test split, and the predictions are passed to ``show_results``.
    """
    model = HybridStreaming()
    if args.use_gpu and torch.cuda.is_available():
        torch.cuda.set_device(args.gpu_id)
        model.cuda()
    print(model)

    optimizer = optim.Adadelta(model.parameters(), lr=1.0, rho=0.95, eps=1e-08, weight_decay=1e-6)
    x_train, y_train, x_test, y_test, feature_labels = generate_streaming_data()

    def make_loader(features, labels):
        """Wrap features/labels in an unshuffled DataLoader with batches of 64."""
        return data_utils.DataLoader(QoENFLXDataset(features, labels), batch_size=64, shuffle=False)

    train_loader = make_loader(x_train, y_train)
    test_loader = make_loader(x_test, y_test)

    for epoch in range(args.epochs):
        train(train_loader, model, optimizer, epoch)

    predictions = test(test_loader, model).cpu().data.numpy()
    # generate_streaming_data() does not return the unscaled test features, so
    # x_test is passed for both feature arguments.
    show_results(x_test, x_test, y_test, 'DeepQoE', feature_labels, predictions)
49 |
50 |
def train(train_loader, model, optimizer, epoch):
    """Run one epoch of MSE training of *model* on *train_loader*.

    Batches are moved to the GPU only when the model's parameters are
    already there.  ``main`` moves the model to the GPU only conditionally,
    so this keeps CPU-only runs working.

    Args:
        train_loader: DataLoader yielding dicts with keys 'VQA', 'R1', 'R2',
            'Mem', 'Impair' and 'label'.
        model: network called as ``model(vqa, r1, r2, mem, impair)`` that
            returns ``(prediction, _)``.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, used only in the progress message.
    """
    model.train()
    use_cuda = next(model.parameters()).data.is_cuda
    pid = os.getpid()

    def to_var(batch, key):
        """Wrap ``batch[key]`` in a Variable on the model's device."""
        tensor = batch[key]
        if use_cuda:
            tensor = tensor.cuda()
        return torch.autograd.Variable(tensor)

    for batch_idx, sample_batched in enumerate(train_loader):
        inputs = [to_var(sample_batched, key) for key in ('VQA', 'R1', 'R2', 'Mem', 'Impair')]
        target = to_var(sample_batched, 'label')

        optimizer.zero_grad()
        prediction, _ = model(*inputs)
        loss = F.mse_loss(prediction, target.float())
        loss.backward()
        optimizer.step()

        if batch_idx % 3 == 0:
            print('{}\tTrain Epoch: {} \tLoss: {:.6f}'.format(
                pid, epoch, loss.data[0]))
73 |
74 |
def test(test_loader, model):
    """Evaluate *model* with MSE on *test_loader* and return all of its predictions.

    Prints each batch's targets and the mean per-sample MSE.  Batches are
    moved to the GPU only when the model's parameters are there.

    Returns:
        The model outputs for every test sample, concatenated along dim 0 in
        loader order, so they line up with the full ``y_test``.

    Raises:
        ValueError: if *test_loader* yields no batches.
    """
    model.eval()
    use_cuda = next(model.parameters()).data.is_cuda

    def to_var(batch, key):
        """Wrap ``batch[key]`` in a Variable on the model's device."""
        tensor = batch[key]
        if use_cuda:
            tensor = tensor.cuda()
        return torch.autograd.Variable(tensor)

    test_loss = 0
    outputs = []
    for sample_batched in test_loader:
        inputs = [to_var(sample_batched, key) for key in ('VQA', 'R1', 'R2', 'Mem', 'Impair')]
        target = to_var(sample_batched, 'label')

        output, _ = model(*inputs)
        test_loss += F.mse_loss(output, target.float(), size_average=False).data[0]
        outputs.append(output)
        print(target.float())

    if not outputs:
        raise ValueError("test_loader yielded no batches")

    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}'.format(test_loss))
    # Return every batch, not just the last one, so predictions align with y_test.
    return torch.cat(outputs, 0)


if __name__ == '__main__':
    main(parse_arguments(sys.argv[1:]))
101 |
--------------------------------------------------------------------------------
/train_test_video.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 | # @Time : 12/12/17
4 | # @Author : Huaizheng ZHANG
5 | # @Site : zhanghuaizheng.info
6 | # @File : train_test_video.py
7 |
8 | from __future__ import print_function
9 |
10 | import os
11 | import sys
12 | import numpy as np
13 | import pickle
14 | from sklearn import preprocessing
15 | import torch.optim as optim
16 | import torch.utils.data as data_utils
17 | from DeepQoE.nets import *
18 | from DeepQoE.config import cfg, parse_arguments
19 | from DeepQoE.data_loader import QoEVideoDataset
20 |
def shuffle_data(x, y):
    """Return *x* and *y* reordered by one shared random permutation of the first axis.

    Uses NumPy's global random state, so results are reproducible after
    ``np.random.seed``.
    """
    order = np.random.permutation(x.shape[0])
    return x[order], y[order]
27 |
def main(args):
    """Train and test a DeepQoE model on the shuffled JND video data.

    ``args.model_number`` selects the network: 0 -> HybridNN,
    1 -> C3DHybridNN, otherwise HybridRR.  ``model_number == 2`` is trained
    and tested with MSE; every other choice uses NLL loss.  Each row of
    ``cfg.JND_DATA`` holds the inputs followed by an integer label in the
    last column.  Labels are re-encoded to ``0..K-1`` with a LabelEncoder, and
    the rows are shuffled before the ``cfg.TRAIN_RATIO`` train/test split.
    """
    if args.model_number == 0:
        model = HybridNN()
    elif args.model_number == 1:
        model = C3DHybridNN()
    else:
        model = HybridRR()
    use_cuda = args.use_gpu and torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_device(args.gpu_id)
        model.cuda()
    print(model)

    model.train()
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, nesterov=True)
    # Use the same loss for training and testing.
    loss_fn = F.mse_loss if args.model_number == 2 else F.nll_loss

    with open(cfg.JND_DATA, 'rb') as f:
        data = pickle.load(f)

    train_test = np.array(data)
    x_temp = train_test[:, :-1]
    # Builtin int: np.int was only an alias of it and was removed in NumPy 1.24.
    y_temp = np.array(train_test[:, -1], int)
    encoded_labels = preprocessing.LabelEncoder().fit_transform(y_temp)

    print('Encoder {}'.format('lable') + '\n~~~~~~~~~~~~~~~~~~~~~~~')
    print(x_temp, y_temp)

    x, y = shuffle_data(x_temp, encoded_labels)

    train_size = int(cfg.TRAIN_RATIO * len(x))
    x_train, x_test = x[:train_size], x[train_size:]
    y_train, y_test = y[:train_size], y[train_size:]

    def _to_var(batch, key):
        """Wrap ``batch[key]`` in a Variable, on the GPU when enabled."""
        tensor = batch[key]
        if use_cuda:
            tensor = tensor.cuda()
        return torch.autograd.Variable(tensor)

    train_data = QoEVideoDataset(x_train, y_train)
    train_loader = data_utils.DataLoader(train_data, batch_size=args.batch_size, shuffle=False)
    pid = os.getpid()
    for epoch in range(args.epochs):
        for batch_idx, sample_batched in enumerate(train_loader):
            x_1 = _to_var(sample_batched, 'video')
            x_2 = _to_var(sample_batched, 'res')
            target = _to_var(sample_batched, 'label')

            optimizer.zero_grad()
            prediction, _ = model(x_1, x_2)
            loss = loss_fn(prediction, target)
            loss.backward()
            optimizer.step()

            if batch_idx % 3 == 0:
                # Samples seen = batches * batch size (not len(data), which is the
                # length of the pickled dataset).
                print('{}\tTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    pid, epoch, batch_idx * args.batch_size, len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.data[0]))
    # One save after training stores the same final weights as saving every batch.
    torch.save(model.state_dict(), cfg.MODEL_SAVE_VIDEO)

    # test processing
    test_data = QoEVideoDataset(x_test, y_test)
    test_loader = data_utils.DataLoader(test_data, batch_size=args.batch_size, shuffle=False)
    model.eval()
    test_loss = 0
    correct = 0
    for sample_batched in test_loader:
        x_1 = _to_var(sample_batched, 'video')
        x_2 = _to_var(sample_batched, 'res')
        target = _to_var(sample_batched, 'label')

        output, _ = model(x_1, x_2)
        test_loss += loss_fn(output, target, size_average=False).data[0]  # sum up batch loss
        pred = output.data.max(1, keepdim=True)[1]  # get the index of the max log-probability
        # int() so the count stays a Python number across torch versions.
        correct += int(pred.eq(target.data.view_as(pred)).cpu().sum())

    n_test = len(test_loader.dataset)
    test_loss /= n_test
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        test_loss, correct, n_test, 100. * correct / n_test))
--------------------------------------------------------------------------------