# :dolphin:DOLPHIN
2 |
3 |
17 |
18 | ## :ocean:Introduction
19 |
20 | DOLPHIN is an online writer retrieval model, designed to retrieve all online handwriting samples of a specific writer. It synergizes temporal and frequency learning to extract discriminative feature representations for online handwriting.
21 |
22 |
23 |
24 | The model architecture of DOLPHIN (figure: asset/arch.png)
25 |
26 | ## :earth_asia:Environment
27 |
28 | ```bash
29 | git clone https://github.com/SCUT-DLVCLab/DOLPHIN.git
30 | cd DOLPHIN
31 | conda create -n dolphin python=3.8.16
32 | conda activate dolphin
33 | pip install -r requirements.txt
34 | ```
35 |
36 | ## :hammer_and_pick:Data Preparation
37 |
38 | Download the three subsets, CASIA-OLHWDB2, DCOH-E, and SCUT-COUCH2009, from either of the following links:
39 |
40 | - [Baidu Cloud](https://pan.baidu.com/s/1Op917v5IM7OushQ_xPNLSg?pwd=oler)
41 | - [Google Drive](https://drive.google.com/drive/folders/1W-R78wLSJXDhK998c_zIAEFxtPE10AX4?usp=sharing)
42 |
43 | Unzip the .zip archives using the following commands:
44 |
45 | ```bash
46 | unzip OLHWDB2.zip -d .
47 | unzip DCOH-E.zip -d .
48 | unzip COUCH09.zip -d .
49 | ```
50 |
51 | The directory should look like this:
52 |
53 | ```
54 | data-raw
55 | ├── COUCH09
56 | │   ├── 001
57 | │   └── ...
58 | ├── DCOH-E
59 | │   ├── dcoh-e313
60 | │   └── ...
61 | └── OLHWDB2
62 |     ├── 001
63 |     └── ...
64 | ```
65 |
66 | Then run `preprocess.py` for data preprocessing:
67 |
68 | ```bash
69 | python preprocess.py --dataset olhwdb2
70 | python preprocess.py --dataset dcohe
71 | python preprocess.py --dataset couch
72 | ```
73 |
74 | The preprocessed data will be saved in the `data` folder.
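
If you want to verify a subset before merging, here is a minimal sketch (per `preprocess.py`, each pickle is a dict mapping writer IDs to lists of `(L, 3)` float32 point arrays):

```python
import pickle
import numpy as np

# Load one preprocessed subset and report basic statistics.
with open('data/OLHWDB2/OLHWDB2.pkl', 'rb') as f:
    writing = pickle.load(f, encoding='iso-8859-1')

print('writers:', len(writing))
writer_id, samples = next(iter(writing.items()))
print('first writer:', writer_id,
      '| samples:', len(samples),
      '| first sample shape:', np.shape(samples[0]))  # expected (L, 3)
```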
75 |
76 | Then run `divide.py` to merge the three subsets into the **OLIWER** dataset and split it into training and testing parts:
77 |
78 | ```bash
79 | python divide.py --divide
80 | python divide.py --extract
81 | ```
82 |
83 | Now all the data should be preprocessed. The final data directory should look like this:
84 |
85 | ```bash
86 | data
87 | ├── COUCH09
88 | │   └── COUCH09.pkl
89 | ├── DCOH-E
90 | │   └── DCOH-E.pkl
91 | ├── OLHWDB2
92 | │   └── OLHWDB2.pkl
93 | └── OLIWER
94 |     ├── split.json
95 |     ├── test.pkl
96 |     ├── test-tf.pkl
97 |     ├── train.pkl
98 |     └── train-tf.pkl
99 | ```
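
As a quick sanity check of the final split, the sketch below loads `split.json` and `train-tf.pkl` (layout per `divide.py`: each writer maps to a list of time-function arrays, whose second dimension is the 12 time functions noted in `dataset.py`):

```python
import json
import pickle
import numpy as np

# Inspect the writer split and the extracted time-function features.
with open('data/OLIWER/split.json') as f:
    split = json.load(f)
print('train writers:', len(split['train_writers']),
      '| test writers:', len(split['test_writers']))

with open('data/OLIWER/train-tf.pkl', 'rb') as f:
    train = pickle.load(f, encoding='iso-8859-1')
writer = split['train_writers'][0]
print(writer, '| samples:', len(train[writer]),
      '| feature shape:', np.shape(train[writer][0]))  # expected (L, 12)
```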
100 |
101 | ## :rocket:Test
102 |
103 | ```bash
104 | python test.py --weights weights/model.pth
105 | ```
106 |
107 | ## :bookmark_tabs:Citation
108 |
109 | ```bibtex
110 | @ARTICLE{dolphin2024zhang,
111 | author={Zhang, Peirong and Jin, Lianwen},
112 | journal={IEEE Transactions on Information Forensics and Security (TIFS)},
113 | title={{Online Writer Retrieval With Chinese Handwritten Phrases: A Synergistic Temporal-Frequency Representation Learning Approach}},
114 | year={2024},
115 | volume={19},
116 | number={},
117 | pages={10387-10399}
118 | }
119 | ```
120 |
121 | ## :phone:Contact
122 |
123 | Peirong Zhang: eeprzhang@mail.scut.edu.cn
124 |
125 | ## :palm_tree:Copyright
126 |
127 | Copyright 2024-2025, Deep Learning and Vision Computing (DLVC) Lab, South China University of Technology. [http://www.dlvc-lab.net](http://www.dlvc-lab.net/).
128 |
129 |
--------------------------------------------------------------------------------
/asset/arch.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCUT-DLVCLab/DOLPHIN/fc90f848ef11cbdcc2f89fabf3a9ed68343641d0/asset/arch.png
--------------------------------------------------------------------------------
/data/OLIWER/split.json:
--------------------------------------------------------------------------------
1 | {
2 | "train_writers": [
3 | "dcoh-e416",
4 | "dcoh154",
5 | "1006",
6 | "120",
7 | "dcoh-e439",
8 | "106",
9 | "couch87",
10 | "740",
11 | "couch33",
12 | "639",
13 | "821",
14 | "804",
15 | "717",
16 | "687",
17 | "995",
18 | "dcoh61",
19 | "dcoh-e373",
20 | "912",
21 | "142",
22 | "790",
23 | "dcoh-e413",
24 | "419",
25 | "couch143",
26 | "12",
27 | "565",
28 | "dcoh13",
29 | "couch51",
30 | "567",
31 | "925",
32 | "dcoh264",
33 | "dcoh122",
34 | "754",
35 | "dcoh162",
36 | "dcoh223",
37 | "381",
38 | "dcoh-e392",
39 | "225",
40 | "753",
41 | "383",
42 | "dcoh-e313",
43 | "492",
44 | "dcoh62",
45 | "dcoh180",
46 | "dcoh276",
47 | "dcoh-e490",
48 | "dcoh-e421",
49 | "dcoh-e314",
50 | "749",
51 | "dcoh75",
52 | "dcoh-e508",
53 | "386",
54 | "683",
55 | "dcoh267",
56 | "382",
57 | "dcoh15",
58 | "712",
59 | "129",
60 | "dcoh151",
61 | "468",
62 | "775",
63 | "161",
64 | "931",
65 | "53",
66 | "dcoh-e411",
67 | "312",
68 | "dcoh-e519",
69 | "dcoh41",
70 | "448",
71 | "dcoh261",
72 | "17",
73 | "250",
74 | "893",
75 | "487",
76 | "242",
77 | "couch99",
78 | "961",
79 | "264",
80 | "dcoh92",
81 | "couch130",
82 | "455",
83 | "969",
84 | "dcoh-e456",
85 | "dcoh4",
86 | "43",
87 | "761",
88 | "883",
89 | "couch109",
90 | "dcoh-e484",
91 | "dcoh-e426",
92 | "733",
93 | "739",
94 | "dcoh-e533",
95 | "322",
96 | "dcoh-e328",
97 | "147",
98 | "dcoh150",
99 | "dcoh144",
100 | "168",
101 | "633",
102 | "dcoh-e345",
103 | "dcoh212",
104 | "dcoh269",
105 | "dcoh89",
106 | "couch20",
107 | "dcoh157",
108 | "335",
109 | "328",
110 | "807",
111 | "8",
112 | "826",
113 | "couch25",
114 | "dcoh298",
115 | "couch42",
116 | "773",
117 | "couch6",
118 | "65",
119 | "182",
120 | "couch21",
121 | "dcoh54",
122 | "295",
123 | "547",
124 | "30",
125 | "817",
126 | "688",
127 | "couch84",
128 | "dcoh-e385",
129 | "938",
130 | "606",
131 | "618",
132 | "373",
133 | "couch137",
134 | "dcoh-e381",
135 | "263",
136 | "305",
137 | "851",
138 | "138",
139 | "934",
140 | "811",
141 | "425",
142 | "36",
143 | "dcoh26",
144 | "dcoh48",
145 | "662",
146 | "233",
147 | "843",
148 | "877",
149 | "couch119",
150 | "80",
151 | "couch128",
152 | "329",
153 | "dcoh128",
154 | "357",
155 | "154",
156 | "764",
157 | "18",
158 | "dcoh-e403",
159 | "couch121",
160 | "couch27",
161 | "dcoh-e344",
162 | "dcoh-e468",
163 | "454",
164 | "806",
165 | "couch44",
166 | "906",
167 | "dcoh124",
168 | "589",
169 | "137",
170 | "dcoh178",
171 | "dcoh-e509",
172 | "963",
173 | "dcoh224",
174 | "dcoh-e459",
175 | "510",
176 | "dcoh-e316",
177 | "521",
178 | "443",
179 | "317",
180 | "473",
181 | "dcoh192",
182 | "413",
183 | "277",
184 | "729",
185 | "702",
186 | "813",
187 | "dcoh-e476",
188 | "824",
189 | "dcoh198",
190 | "715",
191 | "dcoh-e470",
192 | "672",
193 | "848",
194 | "couch60",
195 | "398",
196 | "759",
197 | "701",
198 | "967",
199 | "771",
200 | "370",
201 | "812",
202 | "dcoh-e474",
203 | "dcoh103",
204 | "943",
205 | "dcoh241",
206 | "1004",
207 | "couch58",
208 | "964",
209 | "209",
210 | "780",
211 | "347",
212 | "309",
213 | "132",
214 | "623",
215 | "805",
216 | "837",
217 | "881",
218 | "dcoh170",
219 | "dcoh60",
220 | "dcoh116",
221 | "713",
222 | "853",
223 | "735",
224 | "dcoh280",
225 | "dcoh219",
226 | "757",
227 | "153",
228 | "dcoh77",
229 | "353",
230 | "469",
231 | "808",
232 | "couch115",
233 | "1007",
234 | "285",
235 | "409",
236 | "362",
237 | "dcoh215",
238 | "dcoh-e354",
239 | "couch77",
240 | "689",
241 | "612",
242 | "81",
243 | "dcoh-e551",
244 | "252",
245 | "236",
246 | "614",
247 | "couch72",
248 | "600",
249 | "648",
250 | "872",
251 | "couch10",
252 | "dcoh272",
253 | "304",
254 | "dcoh-e451",
255 | "dcoh-e529",
256 | "527",
257 | "dcoh-e431",
258 | "dcoh-e471",
259 | "dcoh210",
260 | "41",
261 | "112",
262 | "675",
263 | "dcoh93",
264 | "485",
265 | "835",
266 | "581",
267 | "couch11",
268 | "dcoh-e318",
269 | "couch117",
270 | "dcoh200",
271 | "dcoh-e493",
272 | "282",
273 | "dcoh195",
274 | "202",
275 | "913",
276 | "couch7",
277 | "dcoh-e361",
278 | "993",
279 | "dcoh-e383",
280 | "96",
281 | "666",
282 | "couch118",
283 | "519",
284 | "722",
285 | "dcoh-e418",
286 | "dcoh-e343",
287 | "dcoh-e379",
288 | "dcoh46",
289 | "dcoh-e356",
290 | "725",
291 | "56",
292 | "165",
293 | "dcoh168",
294 | "64",
295 | "dcoh31",
296 | "324",
297 | "dcoh134",
298 | "694",
299 | "196",
300 | "couch136",
301 | "163",
302 | "763",
303 | "couch22",
304 | "798",
305 | "dcoh-e353",
306 | "49",
307 | "dcoh-e472",
308 | "dcoh250",
309 | "dcoh-e547",
310 | "16",
311 | "958",
312 | "dcoh-e401",
313 | "387",
314 | "dcoh-e412",
315 | "800",
316 | "940",
317 | "dcoh-e505",
318 | "92",
319 | "125",
320 | "dcoh57",
321 | "1012",
322 | "1008",
323 | "936",
324 | "593",
325 | "219",
326 | "844",
327 | "dcoh-e410",
328 | "888",
329 | "dcoh-e522",
330 | "dcoh-e521",
331 | "514",
332 | "653",
333 | "74",
334 | "228",
335 | "211",
336 | "355",
337 | "998",
338 | "569",
339 | "224",
340 | "dcoh-e402",
341 | "334",
342 | "248",
343 | "dcoh174",
344 | "dcoh-e534",
345 | "571",
346 | "352",
347 | "116",
348 | "150",
349 | "couch135",
350 | "449",
351 | "680",
352 | "95",
353 | "776",
354 | "928",
355 | "767",
356 | "868",
357 | "699",
358 | "542",
359 | "dcoh160",
360 | "54",
361 | "947",
362 | "307",
363 | "369",
364 | "couch75",
365 | "dcoh209",
366 | "1002",
367 | "992",
368 | "dcoh-e420",
369 | "707",
370 | "399",
371 | "dcoh-e495",
372 | "656",
373 | "dcoh294",
374 | "102",
375 | "230",
376 | "dcoh105",
377 | "751",
378 | "718",
379 | "dcoh101",
380 | "340",
381 | "774",
382 | "119",
383 | "260",
384 | "439",
385 | "dcoh211",
386 | "38",
387 | "396",
388 | "dcoh-e452",
389 | "dcoh0",
390 | "couch122",
391 | "458",
392 | "dcoh113",
393 | "dcoh2",
394 | "dcoh-e545",
395 | "dcoh-e473",
396 | "dcoh177",
397 | "578",
398 | "dcoh-e475",
399 | "736",
400 | "661",
401 | "dcoh-e329",
402 | "dcoh-e520",
403 | "dcoh71",
404 | "156",
405 | "395",
406 | "dcoh19",
407 | "802",
408 | "407",
409 | "507",
410 | "dcoh-e461",
411 | "870",
412 | "177",
413 | "dcoh238",
414 | "204",
415 | "dcoh63",
416 | "954",
417 | "dcoh32",
418 | "270",
419 | "919",
420 | "dcoh-e478",
421 | "dcoh-e400",
422 | "couch35",
423 | "dcoh106",
424 | "dcoh8",
425 | "684",
426 | "dcoh88",
427 | "555",
428 | "971",
429 | "784",
430 | "dcoh-e555",
431 | "568",
432 | "247",
433 | "358",
434 | "dcoh129",
435 | "397",
436 | "889",
437 | "105",
438 | "619",
439 | "dcoh259",
440 | "986",
441 | "couch62",
442 | "dcoh-e393",
443 | "148",
444 | "830",
445 | "538",
446 | "166",
447 | "dcoh74",
448 | "39",
449 | "couch73",
450 | "842",
451 | "143",
452 | "dcoh-e325",
453 | "couch19",
454 | "218",
455 | "1018",
456 | "627",
457 | "311",
458 | "652",
459 | "768",
460 | "dcoh-e527",
461 | "447",
462 | "620",
463 | "187",
464 | "831",
465 | "676",
466 | "428",
467 | "dcoh9",
468 | "couch68",
469 | "977",
470 | "28",
471 | "dcoh234",
472 | "dcoh68",
473 | "354",
474 | "85",
475 | "86",
476 | "dcoh-e457",
477 | "dcoh312",
478 | "dcoh278",
479 | "985",
480 | "dcoh-e434",
481 | "275",
482 | "880",
483 | "306",
484 | "dcoh45",
485 | "dcoh-e557",
486 | "couch64",
487 | "908",
488 | "113",
489 | "378",
490 | "77",
491 | "302",
492 | "dcoh-e357",
493 | "320",
494 | "803",
495 | "dcoh254",
496 | "dcoh85",
497 | "535",
498 | "couch52",
499 | "dcoh-e364",
500 | "dcoh187",
501 | "642",
502 | "dcoh208",
503 | "dcoh226",
504 | "640",
505 | "511",
506 | "dcoh-e358",
507 | "390",
508 | "886",
509 | "dcoh-e396",
510 | "dcoh266",
511 | "dcoh-e566",
512 | "159",
513 | "0",
514 | "dcoh303",
515 | "dcoh-e567",
516 | "232",
517 | "dcoh-e513",
518 | "341",
519 | "601",
520 | "852",
521 | "dcoh-e377",
522 | "414",
523 | "dcoh127",
524 | "dcoh-e319",
525 | "dcoh-e481",
526 | "185",
527 | "175",
528 | "292",
529 | "269",
530 | "786",
531 | "622",
532 | "229",
533 | "207",
534 | "332",
535 | "253",
536 | "109",
537 | "dcoh283",
538 | "636",
539 | "couch34",
540 | "549",
541 | "dcoh-e515",
542 | "200",
543 | "dcoh-e324",
544 | "couch45",
545 | "dcoh-e347",
546 | "110",
547 | "dcoh-e532",
548 | "337",
549 | "789",
550 | "69",
551 | "dcoh175",
552 | "dcoh-e376",
553 | "dcoh-e510",
554 | "dcoh167",
555 | "635",
556 | "488",
557 | "647",
558 | "dcoh-e453",
559 | "dcoh43",
560 | "900",
561 | "couch90",
562 | "58",
563 | "22",
564 | "dcoh-e525",
565 | "dcoh-e367",
566 | "dcoh-e550",
567 | "dcoh227",
568 | "couch96",
569 | "84",
570 | "couch120",
571 | "couch100",
572 | "674",
573 | "dcoh-e553",
574 | "42",
575 | "dcoh-e561",
576 | "610",
577 | "288",
578 | "dcoh40",
579 | "dcoh-e443",
580 | "158",
581 | "dcoh-e503",
582 | "604",
583 | "7",
584 | "520",
585 | "62",
586 | "dcoh-e560",
587 | "dcoh-e339",
588 | "dcoh251",
589 | "408",
590 | "748",
591 | "dcoh-e494",
592 | "573",
593 | "465",
594 | "489",
595 | "195",
596 | "400",
597 | "646",
598 | "dcoh110",
599 | "617",
600 | "couch85",
601 | "345",
602 | "couch53",
603 | "559",
604 | "dcoh-e386",
605 | "couch106",
606 | "couch67",
607 | "11",
608 | "83",
609 | "couch39",
610 | "dcoh291",
611 | "dcoh18",
612 | "dcoh76",
613 | "dcoh118",
614 | "couch138",
615 | "664",
616 | "75",
617 | "4",
618 | "391",
619 | "dcoh17",
620 | "176",
621 | "609",
622 | "536",
623 | "279",
624 | "dcoh-e415",
625 | "227",
626 | "dcoh142",
627 | "719",
628 | "dcoh-e432",
629 | "dcoh-e430",
630 | "dcoh309",
631 | "dcoh-e466",
632 | "dcoh70",
633 | "88",
634 | "dcoh-e563",
635 | "278",
636 | "dcoh-e408",
637 | "dcoh-e445",
638 | "486",
639 | "dcoh277",
640 | "dcoh-e317",
641 | "couch112",
642 | "660",
643 | "894",
644 | "dcoh130",
645 | "215",
646 | "290",
647 | "44",
648 | "dcoh-e487",
649 | "603",
650 | "dcoh237",
651 | "404",
652 | "dcoh28",
653 | "371",
654 | "539",
655 | "453",
656 | "couch108",
657 | "360",
658 | "186",
659 | "990",
660 | "dcoh-e336",
661 | "162",
662 | "494",
663 | "dcoh-e539",
664 | "460",
665 | "720",
666 | "711",
667 | "331",
668 | "673",
669 | "dcoh27",
670 | "500",
671 | "452",
672 | "dcoh279",
673 | "866",
674 | "1",
675 | "823",
676 | "dcoh107",
677 | "couch82",
678 | "173",
679 | "dcoh58",
680 | "24",
681 | "dcoh-e365",
682 | "338",
683 | "628",
684 | "497",
685 | "645",
686 | "couch61",
687 | "737",
688 | "671",
689 | "778",
690 | "474",
691 | "dcoh136",
692 | "dcoh-e341",
693 | "dcoh146",
694 | "210",
695 | "dcoh-e501",
696 | "350",
697 | "467",
698 | "dcoh23",
699 | "dcoh240",
700 | "couch16",
701 | "70",
702 | "1016",
703 | "420",
704 | "dcoh100",
705 | "dcoh-e372",
706 | "dcoh184",
707 | "957",
708 | "579",
709 | "dcoh-e424",
710 | "734",
711 | "898",
712 | "dcoh64",
713 | "393",
714 | "dcoh53",
715 | "199",
716 | "651",
717 | "300",
718 | "dcoh176",
719 | "couch14",
720 | "dcoh189",
721 | "dcoh-e334",
722 | "dcoh-e435",
723 | "863",
724 | "dcoh-e331",
725 | "68",
726 | "133",
727 | "dcoh164",
728 | "107",
729 | "couch57",
730 | "273",
731 | "dcoh232",
732 | "couch101",
733 | "459",
734 | "dcoh-e399",
735 | "942",
736 | "502",
737 | "dcoh-e359",
738 | "283",
739 | "384",
740 | "97",
741 | "655",
742 | "19",
743 | "819",
744 | "827",
745 | "418",
746 | "dcoh143",
747 | "205",
748 | "dcoh161",
749 | "1014",
750 | "52",
751 | "543",
752 | "dcoh256",
753 | "424",
754 | "dcoh-e467",
755 | "203",
756 | "214",
757 | "dcoh72",
758 | "dcoh290",
759 | "289",
760 | "dcoh222",
761 | "191",
762 | "couch134",
763 | "couch29",
764 | "239",
765 | "dcoh-e500",
766 | "dcoh307",
767 | "dcoh-e371",
768 | "478",
769 | "couch94",
770 | "10",
771 | "dcoh36",
772 | "745",
773 | "910",
774 | "832",
775 | "dcoh293",
776 | "couch133",
777 | "dcoh235",
778 | "dcoh218",
779 | "dcoh-e465",
780 | "825",
781 | "dcoh6",
782 | "343",
783 | "585",
784 | "dcoh42",
785 | "93",
786 | "dcoh172",
787 | "couch17",
788 | "518",
789 | "911",
790 | "226",
791 | "920",
792 | "692",
793 | "557",
794 | "602",
795 | "607",
796 | "845",
797 | "693",
798 | "788",
799 | "644",
800 | "643",
801 | "586",
802 | "couch56",
803 | "584",
804 | "141",
805 | "couch98",
806 | "374",
807 | "659",
808 | "couch69",
809 | "801",
810 | "136",
811 | "dcoh-e544",
812 | "611",
813 | "57",
814 | "dcoh-e323",
815 | "904",
816 | "479",
817 | "dcoh-e389",
818 | "941",
819 | "933",
820 | "dcoh29",
821 | "854",
822 | "1003",
823 | "190",
824 | "705",
825 | "dcoh102",
826 | "couch1",
827 | "dcoh49",
828 | "dcoh-e441",
829 | "280",
830 | "598",
831 | "244",
832 | "dcoh-e338",
833 | "412",
834 | "55",
835 | "dcoh51",
836 | "172",
837 | "couch50",
838 | "dcoh112",
839 | "dcoh-e541",
840 | "365",
841 | "dcoh119",
842 | "couch31",
843 | "327",
844 | "dcoh-e488",
845 | "94",
846 | "dcoh271",
847 | "dcoh59",
848 | "dcoh-e395",
849 | "dcoh-e460",
850 | "dcoh56",
851 | "dcoh-e498",
852 | "couch43",
853 | "dcoh5",
854 | "342",
855 | "dcoh-e531",
856 | "dcoh260",
857 | "291",
858 | "545",
859 | "dcoh73",
860 | "couch5",
861 | "couch113",
862 | "dcoh201",
863 | "dcoh228",
864 | "192",
865 | "dcoh138",
866 | "388",
867 | "dcoh22",
868 | "563",
869 | "dcoh247",
870 | "895",
871 | "833",
872 | "476",
873 | "756",
874 | "456",
875 | "dcoh-e427",
876 | "921",
877 | "dcoh38",
878 | "445",
879 | "dcoh1",
880 | "29",
881 | "548",
882 | "975",
883 | "1017",
884 | "dcoh308",
885 | "40",
886 | "208",
887 | "couch144",
888 | "978",
889 | "dcoh-e448",
890 | "742",
891 | "795",
892 | "180",
893 | "dcoh263",
894 | "902",
895 | "268",
896 | "71",
897 | "164",
898 | "dcoh243",
899 | "couch8",
900 | "266",
901 | "670",
902 | "698",
903 | "couch9",
904 | "336",
905 | "828",
906 | "35",
907 | "686",
908 | "989",
909 | "446",
910 | "178",
911 | "149",
912 | "dcoh193",
913 | "743",
914 | "dcoh-e446",
915 | "dcoh194",
916 | "882",
917 | "530",
918 | "430",
919 | "dcoh39",
920 | "727",
921 | "couch81",
922 | "dcoh83",
923 | "980",
924 | "389",
925 | "couch66",
926 | "988",
927 | "dcoh246",
928 | "544",
929 | "dcoh248",
930 | "654",
931 | "220",
932 | "dcoh-e320",
933 | "dcoh86",
934 | "968",
935 | "couch114",
936 | "368",
937 | "344",
938 | "dcoh-e538",
939 | "dcoh253",
940 | "364",
941 | "1005",
942 | "dcoh206",
943 | "738",
944 | "892",
945 | "dcoh-e499",
946 | "907",
947 | "dcoh-e437",
948 | "dcoh242",
949 | "426",
950 | "couch15",
951 | "323",
952 | "758",
953 | "couch86",
954 | "3",
955 | "867",
956 | "dcoh-e485",
957 | "552",
958 | "695",
959 | "dcoh-e438",
960 | "615",
961 | "929",
962 | "433",
963 | "972",
964 | "dcoh299",
965 | "dcoh109",
966 | "376",
967 | "532",
968 | "937",
969 | "dcoh-e483",
970 | "126",
971 | "dcoh126",
972 | "dcoh179",
973 | "dcoh-e407",
974 | "dcoh11",
975 | "dcoh166",
976 | "213",
977 | "528",
978 | "dcoh-e506",
979 | "dcoh155",
980 | "couch129",
981 | "dcoh310",
982 | "435",
983 | "783",
984 | "483",
985 | "375",
986 | "678",
987 | "682",
988 | "657",
989 | "dcoh-e504",
990 | "dcoh-e558",
991 | "484",
992 | "couch70",
993 | "dcoh295",
994 | "couch104",
995 | "436",
996 | "couch105",
997 | "dcoh-e440",
998 | "265",
999 | "couch132",
1000 | "couch0",
1001 | "170",
1002 | "dcoh-e480",
1003 | "529",
1004 | "dcoh-e479",
1005 | "couch59",
1006 | "710",
1007 | "709",
1008 | "dcoh-e349",
1009 | "dcoh95",
1010 | "dcoh-e335",
1011 | "dcoh-e414",
1012 | "198",
1013 | "308",
1014 | "157",
1015 | "131",
1016 | "700",
1017 | "couch83",
1018 | "98",
1019 | "dcoh233",
1020 | "377",
1021 | "27",
1022 | "951",
1023 | "924",
1024 | "235",
1025 | "791",
1026 | "dcoh245",
1027 | "dcoh-e326",
1028 | "dcoh52",
1029 | "dcoh287",
1030 | "974",
1031 | "859",
1032 | "dcoh-e464",
1033 | "dcoh205",
1034 | "708",
1035 | "505",
1036 | "dcoh-e375",
1037 | "490",
1038 | "379",
1039 | "dcoh33",
1040 | "108",
1041 | "184",
1042 | "dcoh90",
1043 | "189",
1044 | "dcoh7",
1045 | "dcoh-e535",
1046 | "506",
1047 | "366",
1048 | "222",
1049 | "582",
1050 | "810",
1051 | "566",
1052 | "553",
1053 | "349",
1054 | "couch142",
1055 | "822",
1056 | "267",
1057 | "471",
1058 | "dcoh220",
1059 | "286",
1060 | "792",
1061 | "82",
1062 | "72",
1063 | "860",
1064 | "couch78",
1065 | "781",
1066 | "271",
1067 | "dcoh173",
1068 | "79",
1069 | "356",
1070 | "dcoh156",
1071 | "dcoh203",
1072 | "481",
1073 | "dcoh-e536",
1074 | "34",
1075 | "dcoh-e450",
1076 | "525",
1077 | "dcoh-e442",
1078 | "dcoh152",
1079 | "dcoh-e423",
1080 | "dcoh111",
1081 | "dcoh133",
1082 | "256",
1083 | "495",
1084 | "561",
1085 | "427",
1086 | "dcoh-e542",
1087 | "442",
1088 | "930",
1089 | "984",
1090 | "76",
1091 | "67",
1092 | "457",
1093 | "223",
1094 | "1013",
1095 | "965",
1096 | "99",
1097 | "130",
1098 | "915",
1099 | "234",
1100 | "918",
1101 | "785",
1102 | "23",
1103 | "couch55",
1104 | "dcoh183",
1105 | "127",
1106 | "dcoh-e537",
1107 | "dcoh-e562",
1108 | "73",
1109 | "319",
1110 | "dcoh-e404",
1111 | "516",
1112 | "429",
1113 | "976",
1114 | "973",
1115 | "dcoh216",
1116 | "909",
1117 | "770",
1118 | "592",
1119 | "dcoh-e454",
1120 | "dcoh296",
1121 | "dcoh66",
1122 | "dcoh196",
1123 | "691",
1124 | "590",
1125 | "272",
1126 | "193",
1127 | "dcoh-e444",
1128 | "953",
1129 | "dcoh-e337",
1130 | "923",
1131 | "884",
1132 | "dcoh-e512",
1133 | "477",
1134 | "dcoh-e321",
1135 | "46",
1136 | "534",
1137 | "dcoh286",
1138 | "dcoh-e489",
1139 | "dcoh84",
1140 | "577",
1141 | "333",
1142 | "772",
1143 | "922",
1144 | "dcoh207",
1145 | "914",
1146 | "944",
1147 | "351",
1148 | "732",
1149 | "313",
1150 | "couch13",
1151 | "dcoh82",
1152 | "dcoh-e502",
1153 | "724",
1154 | "dcoh-e390",
1155 | "couch37",
1156 | "1000",
1157 | "297",
1158 | "779",
1159 | "905",
1160 | "171",
1161 | "dcoh-e482",
1162 | "385",
1163 | "450",
1164 | "dcoh-e394",
1165 | "183",
1166 | "dcoh10",
1167 | "51",
1168 | "501",
1169 | "dcoh-e564",
1170 | "dcoh21",
1171 | "583",
1172 | "596",
1173 | "dcoh-e322",
1174 | "dcoh-e548",
1175 | "361",
1176 | "750",
1177 | "87",
1178 | "405",
1179 | "515",
1180 | "31",
1181 | "849",
1182 | "dcoh96",
1183 | "dcoh273",
1184 | "dcoh121",
1185 | "310",
1186 | "dcoh139",
1187 | "901",
1188 | "dcoh-e405",
1189 | "217",
1190 | "dcoh-e497",
1191 | "917",
1192 | "dcoh301",
1193 | "996",
1194 | "15",
1195 | "726",
1196 | "dcoh188",
1197 | "441",
1198 | "194",
1199 | "1010",
1200 | "48",
1201 | "728",
1202 | "dcoh-e425",
1203 | "dcoh-e524",
1204 | "14",
1205 | "637",
1206 | "523",
1207 | "246",
1208 | "dcoh-e492",
1209 | "13",
1210 | "318",
1211 | "983",
1212 | "dcoh104",
1213 | "152",
1214 | "249",
1215 | "410",
1216 | "299",
1217 | "461",
1218 | "576",
1219 | "6",
1220 | "couch92",
1221 | "259",
1222 | "952",
1223 | "101",
1224 | "625",
1225 | "560",
1226 | "394",
1227 | "855",
1228 | "dcoh-e429",
1229 | "787",
1230 | "dcoh114",
1231 | "994",
1232 | "dcoh229",
1233 | "dcoh305",
1234 | "dcoh149",
1235 | "couch139",
1236 | "294",
1237 | "dcoh311",
1238 | "685",
1239 | "dcoh80",
1240 | "dcoh97",
1241 | "437",
1242 | "dcoh87",
1243 | "dcoh257",
1244 | "couch97",
1245 | "couch24",
1246 | "537",
1247 | "846",
1248 | "dcoh165",
1249 | "dcoh249",
1250 | "dcoh44",
1251 | "681",
1252 | "955",
1253 | "956",
1254 | "couch26",
1255 | "111",
1256 | "123",
1257 | "dcoh-e528",
1258 | "508",
1259 | "dcoh284",
1260 | "dcoh-e428",
1261 | "181",
1262 | "899",
1263 | "690",
1264 | "380",
1265 | "dcoh-e458",
1266 | "dcoh-e496",
1267 | "838",
1268 | "dcoh-e511",
1269 | "151",
1270 | "dcoh-e327",
1271 | "221",
1272 | "556",
1273 | "926",
1274 | "couch2",
1275 | "dcoh147",
1276 | "517",
1277 | "dcoh132",
1278 | "dcoh231",
1279 | "dcoh37",
1280 | "533",
1281 | "dcoh-e355",
1282 | "703",
1283 | "dcoh-e374",
1284 | "503",
1285 | "122",
1286 | "782",
1287 | "couch49",
1288 | "couch141",
1289 | "946",
1290 | "890",
1291 | "couch48",
1292 | "839",
1293 | "564",
1294 | "864",
1295 | "755",
1296 | "470",
1297 | "546",
1298 | "dcoh50",
1299 | "145",
1300 | "262",
1301 | "dcoh-e491",
1302 | "188",
1303 | "dcoh-e540",
1304 | "856",
1305 | "couch63",
1306 | "dcoh-e340",
1307 | "66",
1308 | "60",
1309 | "631",
1310 | "301",
1311 | "777",
1312 | "dcoh304",
1313 | "167",
1314 | "dcoh159",
1315 | "couch107",
1316 | "dcoh236",
1317 | "dcoh78",
1318 | "dcoh135",
1319 | "624",
1320 | "dcoh282",
1321 | "650",
1322 | "887",
1323 | "dcoh275",
1324 | "dcoh-e348",
1325 | "couch116",
1326 | "dcoh181",
1327 | "862",
1328 | "couch124",
1329 | "509",
1330 | "dcoh117",
1331 | "dcoh25",
1332 | "959",
1333 | "couch93",
1334 | "634",
1335 | "dcoh285",
1336 | "couch88",
1337 | "339",
1338 | "dcoh131",
1339 | "179",
1340 | "595",
1341 | "861",
1342 | "dcoh281",
1343 | "697",
1344 | "816",
1345 | "103",
1346 | "201",
1347 | "dcoh-e397",
1348 | "dcoh-e369",
1349 | "367",
1350 | "dcoh65",
1351 | "608",
1352 | "couch47",
1353 | "dcoh-e409",
1354 | "677",
1355 | "dcoh67",
1356 | "499",
1357 | "dcoh306",
1358 | "117",
1359 | "dcoh-e333",
1360 | "dcoh141",
1361 | "dcoh274",
1362 | "dcoh-e342",
1363 | "997",
1364 | "dcoh169",
1365 | "316",
1366 | "couch140",
1367 | "couch4",
1368 | "1009",
1369 | "dcoh-e315",
1370 | "dcoh47",
1371 | "dcoh-e417",
1372 | "couch65",
1373 | "981",
1374 | "61",
1375 | "124",
1376 | "421",
1377 | "dcoh288",
1378 | "1001",
1379 | "dcoh148",
1380 | "760",
1381 | "621",
1382 | "927",
1383 | "281",
1384 | "dcoh91",
1385 | "858",
1386 | "dcoh-e391"
1387 | ],
1388 | "test_writers": [
1389 | "325",
1390 | "258",
1391 | "couch41",
1392 | "104",
1393 | "dcoh30",
1394 | "couch54",
1395 | "296",
1396 | "couch125",
1397 | "935",
1398 | "169",
1399 | "160",
1400 | "840",
1401 | "730",
1402 | "dcoh292",
1403 | "dcoh163",
1404 | "504",
1405 | "dcoh-e330",
1406 | "dcoh-e477",
1407 | "498",
1408 | "721",
1409 | "dcoh-e455",
1410 | "dcoh-e549",
1411 | "couch38",
1412 | "663",
1413 | "597",
1414 | "363",
1415 | "793",
1416 | "dcoh265",
1417 | "638",
1418 | "dcoh-e433",
1419 | "287",
1420 | "522",
1421 | "dcoh14",
1422 | "723",
1423 | "dcoh-e523",
1424 | "dcoh-e517",
1425 | "dcoh-e526",
1426 | "20",
1427 | "dcoh-e546",
1428 | "dcoh244",
1429 | "879",
1430 | "326",
1431 | "128",
1432 | "dcoh-e514",
1433 | "89",
1434 | "873",
1435 | "9",
1436 | "couch23",
1437 | "869",
1438 | "799",
1439 | "669",
1440 | "875",
1441 | "423",
1442 | "114",
1443 | "216",
1444 | "541",
1445 | "231",
1446 | "580",
1447 | "979",
1448 | "118",
1449 | "274",
1450 | "480",
1451 | "couch3",
1452 | "865",
1453 | "dcoh-e422",
1454 | "797",
1455 | "438",
1456 | "667",
1457 | "dcoh153",
1458 | "276",
1459 | "251",
1460 | "769",
1461 | "couch46",
1462 | "570",
1463 | "dcoh-e530",
1464 | "dcoh-e449",
1465 | "850",
1466 | "496",
1467 | "dcoh270",
1468 | "321",
1469 | "794",
1470 | "26",
1471 | "couch89",
1472 | "dcoh20",
1473 | "970",
1474 | "dcoh-e370",
1475 | "5",
1476 | "couch76",
1477 | "747",
1478 | "540",
1479 | "878",
1480 | "814",
1481 | "couch102",
1482 | "440",
1483 | "641",
1484 | "752",
1485 | "dcoh-e363",
1486 | "432",
1487 | "704",
1488 | "197",
1489 | "491",
1490 | "962",
1491 | "dcoh-e398",
1492 | "939",
1493 | "dcoh-e559",
1494 | "dcoh-e362",
1495 | "dcoh-e332",
1496 | "146",
1497 | "731",
1498 | "616",
1499 | "dcoh262",
1500 | "444",
1501 | "dcoh-e406",
1502 | "100",
1503 | "dcoh197",
1504 | "59",
1505 | "dcoh-e554",
1506 | "613",
1507 | "couch103",
1508 | "679",
1509 | "couch131",
1510 | "63",
1511 | "dcoh182",
1512 | "257",
1513 | "706",
1514 | "dcoh12",
1515 | "415",
1516 | "372",
1517 | "78",
1518 | "809",
1519 | "431",
1520 | "982",
1521 | "dcoh-e366",
1522 | "966",
1523 | "couch36",
1524 | "462",
1525 | "dcoh35",
1526 | "dcoh-e552",
1527 | "couch32",
1528 | "couch79",
1529 | "couch91",
1530 | "588",
1531 | "434",
1532 | "dcoh99",
1533 | "245",
1534 | "411",
1535 | "couch40",
1536 | "632",
1537 | "couch74",
1538 | "284",
1539 | "475",
1540 | "couch30",
1541 | "dcoh158",
1542 | "dcoh-e462",
1543 | "dcoh16",
1544 | "744",
1545 | "couch12",
1546 | "dcoh221",
1547 | "696",
1548 | "37",
1549 | "dcoh255",
1550 | "1015",
1551 | "885",
1552 | "212",
1553 | "847",
1554 | "dcoh185",
1555 | "dcoh69",
1556 | "626",
1557 | "dcoh81",
1558 | "couch18",
1559 | "dcoh98",
1560 | "21",
1561 | "dcoh258",
1562 | "dcoh268",
1563 | "422",
1564 | "dcoh302",
1565 | "dcoh252",
1566 | "144",
1567 | "714",
1568 | "298",
1569 | "665",
1570 | "820",
1571 | "couch80",
1572 | "575",
1573 | "315",
1574 | "dcoh-e382",
1575 | "466",
1576 | "dcoh145",
1577 | "359",
1578 | "558",
1579 | "916",
1580 | "134",
1581 | "dcoh-e518",
1582 | "554",
1583 | "couch28",
1584 | "562",
1585 | "241",
1586 | "dcoh94",
1587 | "330",
1588 | "746",
1589 | "dcoh204",
1590 | "dcoh125",
1591 | "950",
1592 | "couch126",
1593 | "174",
1594 | "dcoh217",
1595 | "572",
1596 | "dcoh-e507",
1597 | "402",
1598 | "348",
1599 | "591",
1600 | "303",
1601 | "551",
1602 | "829",
1603 | "dcoh-e368",
1604 | "dcoh34",
1605 | "524",
1606 | "91",
1607 | "dcoh230",
1608 | "dcoh171",
1609 | "716",
1610 | "dcoh289",
1611 | "897",
1612 | "couch71",
1613 | "dcoh-e436",
1614 | "416",
1615 | "dcoh-e378",
1616 | "574",
1617 | "dcoh225",
1618 | "dcoh191",
1619 | "493",
1620 | "629",
1621 | "dcoh-e380",
1622 | "couch123",
1623 | "dcoh140",
1624 | "741",
1625 | "dcoh-e516",
1626 | "33",
1627 | "135",
1628 | "949",
1629 | "658",
1630 | "932",
1631 | "630",
1632 | "987",
1633 | "dcoh-e447",
1634 | "dcoh-e556",
1635 | "dcoh190",
1636 | "531",
1637 | "605",
1638 | "dcoh213",
1639 | "90",
1640 | "818",
1641 | "47",
1642 | "dcoh120",
1643 | "dcoh55",
1644 | "couch95",
1645 | "243",
1646 | "406",
1647 | "dcoh115",
1648 | "dcoh-e351",
1649 | "796",
1650 | "dcoh-e360",
1651 | "dcoh-e384",
1652 | "dcoh108",
1653 | "891",
1654 | "762",
1655 | "115",
1656 | "dcoh202",
1657 | "874",
1658 | "834",
1659 | "451",
1660 | "261",
1661 | "2",
1662 | "513",
1663 | "871",
1664 | "140",
1665 | "32",
1666 | "991",
1667 | "206",
1668 | "dcoh199",
1669 | "472",
1670 | "255",
1671 | "401",
1672 | "dcoh123",
1673 | "dcoh297",
1674 | "841",
1675 | "254",
1676 | "240",
1677 | "948",
1678 | "50",
1679 | "45",
1680 | "594",
1681 | "945",
1682 | "587",
1683 | "464",
1684 | "550",
1685 | "766",
1686 | "dcoh-e419",
1687 | "dcoh137",
1688 | "couch110",
1689 | "121",
1690 | "960",
1691 | "dcoh-e387",
1692 | "dcoh300",
1693 | "482",
1694 | "dcoh239",
1695 | "dcoh-e486",
1696 | "25",
1697 | "dcoh214",
1698 | "765",
1699 | "903",
1700 | "512",
1701 | "649",
1702 | "dcoh-e346",
1703 | "417",
1704 | "896",
1705 | "403",
1706 | "dcoh-e543",
1707 | "dcoh-e469",
1708 | "857",
1709 | "526",
1710 | "dcoh-e463",
1711 | "155",
1712 | "463",
1713 | "dcoh24",
1714 | "815",
1715 | "293",
1716 | "314",
1717 | "couch111",
1718 | "999",
1719 | "couch127",
1720 | "392",
1721 | "836",
1722 | "dcoh-e352",
1723 | "dcoh-e565",
1724 | "dcoh-e350",
1725 | "668",
1726 | "139",
1727 | "238",
1728 | "1011",
1729 | "dcoh3",
1730 | "346",
1731 | "dcoh-e388",
1732 | "599",
1733 | "876",
1734 | "dcoh79",
1735 | "237"
1736 | ]
1737 | }
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | import numpy as np
4 | from torch.utils.data import Dataset
5 |
6 | class Writing(Dataset):
7 | def __init__(self,handwriting_info:dict,transform=None,train=True):
8 | super().__init__()
9 | self.users = handwriting_info.keys()
10 | self.users_cnt = len(self.users)
11 | self.train = train
12 | self.features = []
13 | self.user_labels = []
14 | for i,k in enumerate(self.users):
15 | # extract_features(handwriting_info[k],self.features)
16 | self.features.extend(handwriting_info[k])
17 | self.user_labels.extend([i] * len(handwriting_info[k]))
18 | assert len(self.user_labels) == len(self.features)
19 | self.features_cnt = len(self.features)
20 | self.feature_dims = np.shape(self.features[0])[1] # the number of time functions per point, 12 here
21 | self.transform = transform
22 |
23 | def __len__(self):
24 | return self.features_cnt
25 |
26 | def __getitem__(self,idx):
27 | if self.train:
28 | feature = self.features[idx]
29 | if self.transform is not None:
30 | feature = self.transform(feature)
31 | else:
32 | feature = self.features[idx]
33 | return feature,len(feature),self.user_labels[idx]
34 |
35 | def collate_fn(batch:list):
36 | batch_size = len(batch)
37 | handwriting = [i[0] for i in batch]
38 | hw_len = np.array([i[1] for i in batch],dtype=np.float32)
39 | user_labels = np.array([i[2] for i in batch])
40 | max_len = int(np.max(hw_len))
41 | time_function_cnts = np.shape(handwriting[0])[1]
42 | handwriting_padded = np.zeros((batch_size,max_len,time_function_cnts),dtype=np.float32)
43 | for i,hw in enumerate(handwriting):
44 | handwriting_padded[i,:hw.shape[0]] = hw
45 | return handwriting_padded,hw_len,user_labels
46 |
47 | if __name__ == '__main__':
48 | ...
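
A usage sketch for `Writing` and `collate_fn` (assuming the `train-tf.pkl` produced by `divide.py`, i.e. a dict mapping each writer to a list of `(L, 12)` time-function arrays; the batch size is arbitrary):

```python
import pickle
from torch.utils.data import DataLoader
from dataset import Writing, collate_fn

# Build the training dataset from the extracted time-function features.
with open('./data/OLIWER/train-tf.pkl', 'rb') as f:
    handwriting_info = pickle.load(f, encoding='iso-8859-1')

train_set = Writing(handwriting_info, transform=None, train=True)
loader = DataLoader(train_set, batch_size=32, shuffle=True, collate_fn=collate_fn)

# collate_fn zero-pads each batch to its longest sequence.
features, lengths, labels = next(iter(loader))
print(features.shape, lengths.shape, labels.shape)  # (32, max_len, 12), (32,), (32,)
```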
--------------------------------------------------------------------------------
/divide.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | import os,pickle,json,argparse
4 | import numpy as np
5 | from utils import time_functions,clock
6 |
7 | def divide_data(src_root='./data',tgt_root='./data/OLIWER'):
8 | if not os.path.exists(f'{tgt_root}/OLIWER.pkl'):
9 | print('Merging data.')
10 | with open(f'{src_root}/OLHWDB2/OLHWDB2.pkl','rb') as f:
11 | olhwdb2 = pickle.load(f,encoding='iso-8859-1')
12 | with open(f'{src_root}/DCOH-E/DCOH-E.pkl','rb') as f:
13 | dcohe = pickle.load(f,encoding='iso-8859-1')
14 | with open(f'{src_root}/COUCH09/COUCH09.pkl','rb') as f:
15 | couch = pickle.load(f,encoding='iso-8859-1')
16 |
17 | olhwdb_cnt,dcohe_cnt,couch_cnt = 0,0,0
18 | olhwdb_writer_cnt,dcohe_writer_cnt,couch_writer_cnt = 0,0,0
19 | data = {}
20 | for k in olhwdb2:
21 | if len(olhwdb2[k]) <= 20:
22 | continue
23 | data[str(k)] = olhwdb2[k]
24 | olhwdb_cnt += len(olhwdb2[k])
25 | olhwdb_writer_cnt += 1
26 | for k in couch:
27 | if len(couch[k]) <= 20:
28 | continue
29 | newk = f'couch{k}'
30 | data[newk] = couch[k]
31 | couch_cnt += len(couch[k])
32 |
33 | couch_writer_cnt += 1
34 | for k in dcohe:
35 | if len(dcohe[k]) <= 20:
36 | continue
37 | data[k] = dcohe[k]
38 | dcohe_cnt += len(dcohe[k])
39 | dcohe_writer_cnt += 1
40 | cnt = 0
41 | for k in data:
42 | cnt += len(data[k])
43 | print('user:',len(data),'sample:',cnt)
44 | print('dcohe samples:',dcohe_cnt,dcohe_writer_cnt)
45 | print('olhwdb2 samples:',olhwdb_cnt,olhwdb_writer_cnt)
46 | print('couch samples:',couch_cnt,couch_writer_cnt)
47 | os.makedirs(tgt_root,exist_ok=True)
48 | with open(f'{tgt_root}/OLIWER.pkl','wb') as f:
49 | pickle.dump(data,f)
50 | else:
51 | print('Loading existing data.')
52 | with open(f'{tgt_root}/OLIWER.pkl','rb') as f:
53 | data = pickle.load(f,encoding='iso-8859-1')
54 | print('user:',len(data),'sample:',np.sum([len(data[k]) for k in data.keys()]))
55 |
56 | if os.path.exists(f'./{tgt_root}/split.json'):
57 | with open(f'{tgt_root}/split.json') as f:
58 | split = json.load(f)
59 | train_writers = split['train_writers']
60 | test_writers = split['test_writers']
61 | print('Loading existing splits.')
62 | else:
63 | train_num = int(0.8 * len(data.keys()))
64 | train_writers = np.random.choice(list(data.keys()),size=train_num,replace=False)
65 | test_writers = list(set(list(data.keys())) - set(train_writers))
66 | split = {}
67 | split['train_writers'] = list(train_writers)
68 | split['test_writers'] = list(test_writers)
69 | with open(f'{tgt_root}/split.json','w',encoding='utf-8') as f:
70 | f.write(json.dumps(split,indent=4,ensure_ascii=False))
71 | print('Generating new splits.')
72 |
73 | train,test = {},{}
74 | for k in train_writers:
75 | train[k] = data[k]
76 | for k in test_writers:
77 | test[k] = data[k]
78 |
79 | with open(f'{tgt_root}/train.pkl','wb') as f:
80 | pickle.dump(train,f)
81 | with open(f'{tgt_root}/test.pkl','wb') as f:
82 | pickle.dump(test,f)
83 |
84 | @clock
85 | def extract_and_store(src_root='./data/OLIWER/train.pkl',tgt_root='./data/OLIWER/train-tf.pkl'):
86 | with open(src_root,'rb') as f:
87 | handwriting_info = pickle.load(f,encoding='iso-8859-1')
88 | writing = {}
89 | for i,k in enumerate(handwriting_info.keys()):
90 | writing[k] = []
91 | for each in handwriting_info[k]:
92 | writing[k].append(time_functions(each))
93 | with open(tgt_root,'wb') as f:
94 | pickle.dump(writing,f)
95 |
96 | if __name__ == '__main__':
97 | parser = argparse.ArgumentParser()
98 | parser.add_argument('--divide',action='store_true')
99 | parser.add_argument('--extract',action='store_true')
100 | opt = parser.parse_args()
101 | if opt.divide:
102 | divide_data('./data','./data/OLIWER')
103 | if opt.extract:
104 | extract_and_store('./data/OLIWER/train.pkl','./data/OLIWER/train-tf.pkl')
105 | extract_and_store('./data/OLIWER/test.pkl','./data/OLIWER/test-tf.pkl')
--------------------------------------------------------------------------------
/evaluate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch,os,time
4 | from torch import nn
5 | from scipy import io
6 | import numpy as np
7 | from model_utils import db_augmentation,average_query_expansion
8 |
9 | def evaluate(qf,ql,gf,gl): # q = query, g = gallery
10 | # one query at a time; the gallery holds all samples (19,732 in this setup)
11 | # print(qf.shape,ql.shape,qc.shape,gf.shape,gl.shape,gc.shape)
12 | query = qf.view(-1,1) # (512,1)
13 | score = torch.mm(gf,query) # (19732,1)
14 | # score = nn.PairwiseDistance(p=2)(gf,qf.view(1,-1))
15 | score = score.squeeze().cpu().numpy()
16 |
17 | idx = np.argsort(score)[::-1]
18 | # print(ql,gl[idx][:10],gl[idx][10:20],gl[idx][20:30])
19 | query_idx = np.argwhere(gl == ql) # the gallery always contains samples with the query's label, and more than one of them
20 | positive_idx = query_idx
21 | metrics = compute_mAP(idx,positive_idx)
22 | return metrics # (ap,CMC)
23 |
24 | def compute_mAP(idx,positive_idx):
25 | ap = 0
26 | cmc = torch.zeros((len(idx)))
27 | if positive_idx.size == 0:
28 | cmc[0] = -1
29 | return ap,cmc
30 | len_pos = len(positive_idx)
31 | mask = np.in1d(idx,positive_idx) # find the indices of the positives, i.e., gallery samples from the same writer
32 | rows_pos = np.argwhere(mask).flatten()
33 | # print(rows_pos,len_pos)
34 | cmc[rows_pos[0]:] = 1 # 1 from the first hit onward; averaging these CMC vectors over all queries gives Rank@1/5/10
35 | # note the ':' slice: every position after the first hit is set to 1, so Rank@10 is always >= Rank@1
36 | for i in range(len_pos): # len_pos varies per query, so the number of gallery samples does not matter
37 | precision = (i + 1) * 1. / (rows_pos[i] + 1) # precision at the position of the i-th positive
38 | if rows_pos[i] != 0:
39 | old_precision = i * 1.0 / rows_pos[i]
40 | else:
41 | old_precision = 1.0
42 | ap = ap + (old_precision + precision) / 2 # average of the previous and current precision (trapezoidal approximation of AP)
43 | # ap = ap + precision
44 | ap = ap / len_pos
45 |
46 | return ap,cmc
47 |
48 | def compute_metrics(res,logger,dba,device,verbose=True):
49 | query_feature = res['query_feature']
50 | query_label = res['query_label']
51 | gallery_feature = res['gallery_feature']
52 | gallery_label = res['gallery_label']
53 |
54 | if dba:
55 | time_start = time.time()
56 | query_feature,gallery_feature = db_augmentation(query_feature,gallery_feature,10)
57 | query_feature,gallery_feature = average_query_expansion(query_feature,gallery_feature,5)
58 | query_feature = query_feature / np.linalg.norm(query_feature,axis=1,keepdims=True)
59 | gallery_feature = gallery_feature / np.linalg.norm(gallery_feature,axis=1,keepdims=True)
60 | logger.info(f'DBA & AQE time consuming: {time.time() - time_start:.4f}s')
61 |
62 | query_feature = torch.FloatTensor(query_feature).to(device)
63 | gallery_feature = torch.FloatTensor(gallery_feature).to(device)
64 |
65 | CMC = torch.zeros((len(gallery_label)))
66 | # aps = []
67 | ap = 0.
68 | time_sum = 0.
69 | for i in range(len(query_label)):
70 | time_start = time.time()
71 | cur_ap,cur_CMC = evaluate(query_feature[i],query_label[i],gallery_feature,gallery_label)
72 | time_sum += (time.time() - time_start)
73 | if cur_CMC[0] == -1: continue
74 | CMC += cur_CMC
75 | ap += cur_ap
76 | # aps.append(cur_ap)
77 | # logger.info(f'evaluate time consuming: {time_sum:.4f}s')
78 | time_avg = time_sum / len(query_label)
79 |
80 | CMC /= len(query_label)
81 | ap /= len(query_label)
82 | if verbose:
83 | logger.info(f'[single query] Rank@1: {CMC[0] * 100.:.4f}% Rank@5: {CMC[4] * 100.:.4f}% Rank@10: {CMC[9] * 100.:.4f}%')
84 | logger.info(f'[single query] mAP: {ap * 100.:.4f}%')
85 | return time_avg,ap,CMC[0] * 100.,CMC[4] * 100.,CMC[9] * 100.
86 |
87 | if __name__ == '__main__':
88 | from dataset import Writing
89 | import pickle
90 | import matplotlib.pyplot as plt
91 | gallery_root = f'./data/OLER/gallery-tf-optxy2.pkl'
92 | with open(gallery_root,'rb') as f:
93 | handwriting_info = pickle.load(f,encoding='iso-8859-1')
94 | gallery_dataset = Writing(handwriting_info,train=False)
95 | _,aps = compute_metrics(None,False,'cpu')
96 | l = gallery_dataset.user_labels
97 | print(aps)
98 | k = [len(np.where(l == i)[0]) for i in np.sort(list(set(l)))]
99 | print(k)
100 | d = {i:(len(np.where(l == i)[0]),aps[i]) for i in np.sort(list(set(l)))}
101 |
102 | d1 = [(len(np.where(l == i)[0]),aps[i]) for i in np.sort(list(set(l)))]
103 | d1 = sorted(d1,key=lambda x:x[0])
104 | # plt.hist(aps,len(aps))
105 | # x_array = list(set(d.values()))
106 | # plt.bar(range(x_array),)
107 | aps = [each[1] for each in d1]
108 | idx = [each[0] for each in d1]
109 | idx1 = [idx[0],idx[len(idx) // 4],idx[len(idx) // 2],idx[len(idx) // 4 * 3],idx[-1]]
110 | idx2 = [0,len(idx) // 4,len(idx) // 2,len(idx) // 4 * 3,len(idx)]
111 | idx3 = [''] * len(idx)
112 | idx3[0] = idx[0]
113 | idx3[len(idx) // 4] = idx[len(idx) // 4]
114 | idx3[len(idx) // 2] = idx[len(idx) // 2]
115 | idx3[len(idx) // 4 * 3] = idx[len(idx) // 4 * 3]
116 | idx3[-1] = idx[-1]
117 | print(d1)
118 | plt.bar(range(len(aps)),aps)
119 | plt.xticks(range(len(idx)),idx3)
120 | plt.savefig('./kkk.png',dpi=500)
121 | # print(aps)
122 | # a = np.array([1,2,3,4,5,1,2,1,1])
123 | # b = np.array([3,2])
124 | # c = np.in1d(a,b)
125 | # print(c)
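
A tiny worked example of the ranking metrics above, using a hypothetical five-sample gallery (`idx` is the gallery sorted by descending similarity, `positive_idx` holds the gallery positions that share the query's label):

```python
import numpy as np
from evaluate import compute_mAP

gallery_labels = np.array([0, 1, 0, 1, 1])   # toy gallery; the query's label is 1
idx = np.array([1, 0, 3, 2, 4])              # gallery indices sorted by descending score
positive_idx = np.argwhere(gallery_labels == 1)

ap, cmc = compute_mAP(idx, positive_idx)
# Positives land at ranks 1, 3 and 5, so AP = (1.0 + 7/12 + 0.55) / 3 ≈ 0.711,
# and the first hit is at rank 1, so Rank@1 = 1.
print(f'AP: {ap:.4f}  Rank@1: {cmc[0].item():.0f}')
```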
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | from model_utils import DepthwiseSeparableConv,CAIR,HFGA,ContextBlock,SelectivePool1d,get_len_mask
7 |
8 | class Head(nn.Module):
9 | def __init__(self,d_in,d_hidden,num_classes,bias=True):
10 | super().__init__()
11 | self.fc1 = nn.Linear(d_in,d_hidden,bias=bias)
12 | self.dropout = nn.Dropout(0.1)
13 | self.head = nn.Linear(d_hidden,num_classes,bias=bias)
14 |
15 | def forward(self,x):
16 | x = self.dropout(self.fc1(x))
17 | y = self.head(x)
18 | return x,y
19 |
20 | class DOLPHIN(nn.Module):
21 | def __init__(self,d_in,num_classes):
22 | super().__init__()
23 | self.conv = DepthwiseSeparableConv(d_in,64,7,skip=False,se_rate=0.,drop_path_rate=0.)
24 | self.block0 = nn.Sequential(
25 | CAIR(64,96,1,kernel_size=5,stride=2,skip=False,se_rate=0.25,drop_path_rate=0.1),
26 | CAIR(96,128,3,kernel_size=3,stride=1,skip=False,se_rate=0.25,drop_path_rate=0.1)
27 | )
28 | self.block1 = nn.Sequential(
29 | CAIR(128,160,1,kernel_size=5,stride=2,skip=False,se_rate=0.25,drop_path_rate=0.1),
30 | CAIR(160,192,3,kernel_size=3,stride=1,skip=False,se_rate=0.25,drop_path_rate=0.1)
31 | )
32 | self.block2 = nn.Sequential(
33 | CAIR(192,224,1,kernel_size=5,stride=2,skip=False,se_rate=0.25,drop_path_rate=0.1),
34 | CAIR(224,256,3,kernel_size=3,stride=1,skip=False,se_rate=0.25,drop_path_rate=0.1),
35 | )
36 | self.freq1 = HFGA(128)
37 | self.freq_proj1 = DepthwiseSeparableConv(128,192,stride=2,skip=False,se_rate=0.,drop_path_rate=0.)
38 | self.freq2 = HFGA(192)
39 | self.freq_proj2 = DepthwiseSeparableConv(192,256,stride=2,skip=False,se_rate=0.,drop_path_rate=0.)
40 | self.freq3 = HFGA(256)
41 | self.up = nn.Upsample(scale_factor=2)
42 | self.conv_proj1 = nn.Conv1d(192,256,1,padding=0)
43 | self.conv_proj2 = nn.Conv1d(128,256,1,padding=0)
44 | self.context1 = ContextBlock(256,256 // 8)
45 | self.context2 = ContextBlock(256,256 // 8)
46 | self.context3 = ContextBlock(256,256 // 8)
47 | self.context4 = ContextBlock(256,256 // 8)
48 | self.head = Head(384 * 2,384,num_classes,bias=False)
49 | self.sel_pool1 = SelectivePool1d(256,d_head=24,num_heads=16)
50 | self.sel_pool2 = SelectivePool1d(256,d_head=24,num_heads=16)
51 | self.sel_pool3 = SelectivePool1d(256,d_head=24,num_heads=16)
52 | self.freq_head = nn.Sequential(
53 | nn.Linear(384,384,bias=False),
54 | nn.Dropout(0.1),
55 | )
56 | self.weights_init()
57 |
58 | def forward(self,x,feature_lens):
59 | x = x.transpose(1,2)
60 | x = self.conv(x)
61 | y0 = self.block0(x) # (n,128,l/2)
62 | freq1 = self.freq1(y0)
63 |
64 | y1 = y0 + freq1
65 | y1 = self.block1(y1) # (n,192,l/4)
66 | freq1 = self.freq_proj1(freq1)
67 | freq1 = freq1 + y1
68 | freq2 = self.freq2(freq1)
69 |
70 | y2 = y1 + freq2
71 | y2 = self.block2(y2) # (n,256,l/8)
72 | freq2 = self.freq_proj2(freq2)
73 | freq2 = freq2 + y2
74 | freq3 = self.freq3(freq2)
75 |
76 | y3 = self.context1(self.up(y2)[:,:,:y1.shape[2]]) + self.context2(self.conv_proj1(y1)) # (n,256,l/4)
77 | y4 = self.context3(self.up(y3)[:,:,:y0.shape[2]]) + self.context4(self.conv_proj2(y0)) # (n,256,l/2)
78 | y3 = F.selu(y3,inplace=True)
79 | y4 = F.selu(y4,inplace=True)
80 |
81 | feature_lens = torch.div(feature_lens + 1,2,rounding_mode='trunc')
82 | mask = get_len_mask(feature_lens)
83 | f1 = self.sel_pool1(y4,mask)
84 | mask = F.max_pool1d(mask,2,ceil_mode=True)
85 | f2 = self.sel_pool2(y3,mask)
86 | f3 = self.sel_pool3(freq3,None)
87 | y_vector = torch.cat([f1,f2],dim=1)
88 | y_vector,y_prob = self.head(y_vector)
89 | f3 = self.freq_head(f3)
90 | return y_vector,y_prob,f3
91 |
92 | def weights_init(self):
93 | for m in self.modules():
94 | if isinstance(m,nn.Conv1d):
95 | nn.init.kaiming_normal_(m.weight,mode='fan_out',a=1)
96 | if m.bias is not None:
97 | nn.init.zeros_(m.bias)
98 | elif isinstance(m,nn.Linear):
99 | nn.init.kaiming_normal_(m.weight,a=1)
100 | if m.bias is not None:
101 | nn.init.zeros_(m.bias)
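
A forward-pass sketch with dummy input (a sketch only: `d_in=12` mirrors the 12 time functions used in `dataset.py`, `num_classes` is a placeholder, and the batch is assumed to hold equal-length, un-padded sequences):

```python
import torch
from model import DOLPHIN

model = DOLPHIN(d_in=12, num_classes=100).eval()

x = torch.randn(4, 256, 12)                     # (batch, sequence length, time functions)
lens = torch.full((4,), 256, dtype=torch.long)  # true sequence lengths before padding

with torch.no_grad():
    y_vector, y_prob, freq_vector = model(x, lens)

# y_vector: retrieval embedding, y_prob: writer logits, freq_vector: frequency-branch embedding
print(y_vector.shape, y_prob.shape, freq_vector.shape)  # (4, 384), (4, 100), (4, 384)
```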
--------------------------------------------------------------------------------
/model_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 | import numpy as np
7 | from timm.models.layers import trunc_normal_
8 | from pytorch_wavelets import DWT1D,IDWT1D
9 |
10 | class SelectivePool1d(nn.Module):
11 | def __init__(self,in_features,d_head,num_heads):
12 | super().__init__()
13 | self.keys = nn.Parameter(torch.Tensor(num_heads,d_head),requires_grad=True)
14 | self.W_q = nn.Conv1d(in_features,d_head * num_heads,kernel_size=1)
15 | self.norm = 1 / np.sqrt(d_head)
16 | self.d_head = d_head
17 | self.num_heads = num_heads
18 | self.weights_init()
19 |
20 | def weights_init(self):
21 | nn.init.orthogonal_(self.keys,gain=1)
22 | nn.init.kaiming_normal_(self.W_q.weight,a=1)
23 | nn.init.zeros_(self.W_q.bias)
24 |
25 | def orthogonal_norm(self):
26 | keys = F.normalize(self.keys,dim=1)
27 | corr = torch.mm(keys,keys.transpose(0,1))
28 | return torch.sum(torch.triu(corr,1).abs_())
29 |
30 | def forward(self,x,mask):
31 | N,_,L = x.shape # (N,C,L)
32 | q = v = self.W_q(x).transpose(1,2).view(N,L,self.num_heads,self.d_head)
33 | if mask is not None:
34 | mask = mask.to(x.device)
35 | attn = F.softmax(torch.sum(q * self.keys,dim=-1) * self.norm - (1. - mask).unsqueeze(2) * 1000,dim=1)
36 | # (N,L,num_heads)
37 | else:
38 | attn = F.softmax(torch.sum(q * self.keys,dim=-1) * self.norm,dim=1)
39 | y = torch.sum(v * attn.unsqueeze(3),dim=1).view(N,-1) # (N,d_head * num_heads)
40 | return y
41 |
42 | def get_len_mask(features_lens): # the mask must be rebuilt for every batch because sequence lengths differ
43 | features_lens = features_lens
44 | batch_size = len(features_lens)
45 | max_len = torch.max(features_lens)
46 | mask = torch.zeros((batch_size,max_len),dtype=torch.float32)
47 | for i in range(batch_size):
48 | mask[i,0:features_lens[i]] = 1.0
49 | return mask
50 |
51 | class Swish(nn.Module):
52 | def forward(self,x):
53 | return x * torch.sigmoid(x)
54 |
55 | class SwishImpl(torch.autograd.Function):
56 | @staticmethod
57 | def forward(ctx,i):
58 | res = i * torch.sigmoid(i)
59 | ctx.save_for_backward(i)
60 | return res
61 |
62 | @staticmethod
63 | def backward(ctx,y_grad):
64 | i = ctx.saved_tensors[0]
65 | x_sigmoid = torch.sigmoid(i)
66 | return y_grad * (x_sigmoid * (1 + i * (1 - x_sigmoid)))
67 |
68 | class MemoryEfficientSwish(nn.Module):
69 | def forward(self,x):
70 | return SwishImpl.apply(x)
71 |
72 | class SEBlock2(nn.Module): # the hidden channel width differs
73 | def __init__(self,d_in,d_hidden,act_layer=Swish): # Swish or SiLU
74 | super().__init__()
75 | self.fc = nn.Sequential(
76 | nn.AdaptiveAvgPool1d(1),
77 | nn.Conv1d(d_in,d_hidden,kernel_size=1,padding=0,stride=1),
78 | act_layer(),
79 | nn.Conv1d(d_hidden,d_in,kernel_size=1,padding=0,stride=1),
80 | nn.Sigmoid())
81 |
82 | def forward(self,x): # x: (n,c,l)
83 | y = self.fc(x)
84 | return x * y.expand_as(x)
85 |
86 | def compute_similarity(query,gallery):
87 | query = query / np.linalg.norm(query,axis=1,keepdims=True)
88 | gallery = gallery / np.linalg.norm(gallery,axis=1,keepdims=True)
89 | return np.matmul(query,gallery.T)
90 |
91 | def db_augmentation(query,gallery,topk=10):
92 | # DBA: Database-side feature augmentation https://link.springer.com/article/10.1007/s11263-017-1016-8
93 | weights = np.logspace(0,-2.,topk + 1)
94 |
95 | # query augmentation
96 | similarity = compute_similarity(query,gallery)
97 | indices = np.argsort(-similarity,axis=1)
98 | topk_gallery = gallery[indices[:,:topk],:]
99 | query = np.tensordot(weights,np.concatenate([query[:,None],topk_gallery],axis=1),axes=(0,1))
100 |
101 | # gallery augmentation
102 | similarity = compute_similarity(gallery,gallery)
103 | indices = np.argsort(-similarity,axis=1)
104 | topk_gallery = gallery[indices[:,:topk + 1],:]
105 | gallery = np.tensordot(weights,topk_gallery,axes=(0,1))
106 | return query,gallery
107 |
108 | def average_query_expansion(query,gallery,topk=5):
109 | similarity = compute_similarity(query,gallery)
110 | indices = np.argsort(-similarity,axis=1)
111 | topk_gallery = np.mean(gallery[indices[:,:topk],:],axis=1)
112 | query = np.concatenate([query,topk_gallery],axis=1)
113 |
114 | similarity = compute_similarity(gallery,gallery)
115 | indices = np.argsort(-similarity,axis=1)
116 | topk_gallery = np.mean(gallery[indices[:,1:topk + 1],:],axis=1)
117 | gallery = np.concatenate([gallery,topk_gallery],axis=1)
118 | return query,gallery
119 | def __init__(self,d_feat,l): # NOTE: the enclosing nn.Module class declaration is missing here; this frequency-filter block is not imported by model.py
120 | super().__init__()
121 | self.conv_depth = nn.Conv1d(d_feat,d_feat,kernel_size=3,padding=1,bias=False,groups=d_feat // 2)
122 | self.complex_weight = nn.Parameter(torch.randn(d_feat,l,2,dtype=torch.float32) * 0.02)
123 | trunc_normal_(self.complex_weight,std=.02)
124 | self.head = nn.Linear(d_feat,d_feat,bias=True)
125 |
126 | def forward(self,x):
127 | x1 = x[:,:,0::2]
128 | x2 = x[:,:,1::2]
129 | x1 = self.conv_depth(x1)
130 | _,_,l = x2.shape
131 | x2 = torch.fft.rfft(x2,dim=2,norm='ortho')
132 | weight = self.complex_weight
133 | if not weight.shape[1:2] == x2.shape[2:3]:
134 | weight = F.interpolate(weight.permute(2,0,1).unsqueeze(2),size=(1,x2.shape[2]),mode='bilinear',align_corners=True).squeeze().permute(1,2,0)
135 | weight = torch.view_as_complex(weight.contiguous())
136 | x2 *= weight
137 | x2 = torch.fft.irfft(x2,n=l,dim=2,norm='ortho')
138 | y = x1 + x2
139 | y = self.head(y.transpose(1,2)).transpose(1,2)
140 | return y
141 |
142 | def channel_shuffle(x,groups):
143 | n,c,l = x.shape
144 | d_hidden = c // groups
145 | x = x.view(n,groups,d_hidden,l)
146 | x = x.transpose(1,2).contiguous()
147 | x = x.view(n,-1,l)
148 | return x
149 |
150 | class ShuffleBlock(nn.Module): # a simplified ShuffleNetV2 unit; it has no groups inside, by design
151 | def __init__(self,d_in,kernel_size=3):
152 | super().__init__()
153 | self.conv = nn.Sequential(
154 | nn.Conv1d(d_in // 2,d_in // 2,kernel_size=kernel_size,padding=kernel_size // 2,stride=1),
155 | nn.BatchNorm1d(d_in // 2),
156 | nn.Conv1d(d_in // 2,d_in // 2,kernel_size=1,stride=1,padding=0),
157 | nn.SELU(True)
158 | )
159 |
160 | def forward(self,x):
161 | x1,x2 = x.chunk(2,dim=1)
162 | y = torch.cat((x1,self.conv(x2)),dim=1)
163 | return channel_shuffle(y,2)
164 |
165 | class CBA(nn.Module):
166 | def __init__(self,d_in,d_out,kernel_size,stride=1,groups=1,bias=True,skip=False,act_layer=nn.ReLU):
167 | super().__init__()
168 | padding = kernel_size // 2
169 | self.conv = nn.Conv1d(d_in,d_out,kernel_size=kernel_size,stride=stride,padding=padding,groups=groups,bias=bias)
170 | self.bn = nn.BatchNorm1d(d_out)
171 | # self.bn = GhostBatchNorm1d(d_out)
172 | self.relu = act_layer(True)
173 | self.skip = skip and (stride == 1) and (d_in == d_out)
174 |
175 | def forward(self,x):
176 | identity = x
177 | y = self.relu(self.bn(self.conv(x)))
178 | if self.skip:
179 | y = y + identity
180 | return y
181 |
182 | class DepthwiseSeparableConv(nn.Module):
183 | def __init__(self,d_in,d_out,dw_kernel_size=3,stride=1,skip=True,se_rate=0.2,drop_path_rate=0.,group_size=1,):
184 | super().__init__()
185 | groups = d_in // group_size
186 | padding = dw_kernel_size // 2
187 | self.has_skip = (stride == 1 and d_in == d_out) and skip
188 | self.dw_conv = nn.Conv1d(d_in,d_in,dw_kernel_size,stride=stride,padding=padding,groups=groups)
189 | self.bn1 = nn.BatchNorm1d(d_in)
190 | self.relu = nn.ReLU(inplace=True)
191 | self.se = SEBlock2(d_in,int(d_in * se_rate),act_layer=nn.SELU) if se_rate else nn.Identity()
192 | self.pw_conv = nn.Conv1d(d_in,d_out,1,padding=0)
193 | self.bn2 = nn.BatchNorm1d(d_out)
194 | self.drop_path = nn.Dropout(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
195 |
196 | def forward(self,x):
197 | identity = x
198 | x = self.relu(self.bn1(self.dw_conv(x)))
199 | x = self.se(x)
200 | x = self.relu(self.bn2(self.pw_conv(x)))
201 | if self.has_skip:
202 | x = self.drop_path(x) + identity
203 | return x
204 |
205 | class CAIR(nn.Module):
206 | def __init__(self,d_in,d_out,expand,kernel_size,stride,skip,se_rate,drop_path_rate):
207 | super().__init__()
208 | d_mid = d_in // 2 * expand
209 | self.expand_conv = CBA(d_in // 2,d_mid,kernel_size=1,bias=False) if expand != 1 else nn.Identity()
210 | self.dw_conv = CBA(d_mid,d_mid,kernel_size=kernel_size,stride=stride,groups=d_mid,bias=False)
211 | self.project_conv = nn.Sequential(
212 | nn.Conv1d(d_mid,d_out // 2,kernel_size=1,stride=1,bias=False),
213 | nn.SELU(True)
214 | )
215 | self.identity_conv = CBA(d_in // 2,d_out // 2,3,stride=2,groups=1,bias=False) if stride == 2 else \
216 | (nn.Conv1d(d_in // 2,d_out // 2,1) if d_in != d_out else nn.Identity())
217 | self.se = SEBlock2(d_mid,int(d_mid * se_rate),act_layer=nn.SELU) if se_rate > 0. else nn.Identity()
218 | self.post_conv = CBA(d_out,d_out,3,1,1,act_layer=nn.SELU)
219 | self.skip = (stride == 1 and d_in == d_out) and skip
220 | self.drop_path = nn.Dropout(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
221 |
222 | def forward(self,x):
223 | identity = x.clone()
224 | x1,x2 = x.chunk(2,dim=1)
225 | expand = self.expand_conv(x1)
226 | y1 = self.dw_conv(expand)
227 | y1 = self.se(y1)
228 | y1 = self.project_conv(y1)
229 | y2 = self.identity_conv(x2)
230 | y = torch.cat((y1,y2),dim=1)
231 | y = channel_shuffle(y,2)
232 | y = self.post_conv(y)
233 | y = self.drop_path(y)
234 | return y
235 |
236 | class HFGA(nn.Module):
237 | def __init__(self,d_feat):
238 | super().__init__()
239 | self.dwt = DWT1D(J=1,wave='bior1.1',mode='symmetric')
240 | self.attn_gate = nn.Parameter(torch.Tensor([0.0]))
241 | self.to_q = nn.Conv1d(d_feat,d_feat,1)
242 | self.to_k = nn.Conv1d(d_feat,d_feat,1)
243 | self.to_v = nn.Conv1d(d_feat,d_feat,1)
244 | # self.to_out = nn.Conv1d(d_feat,d_feat,1)
245 |
246 | def compute_attn_matmul(self,q,k,v): # k and v are identical
247 | # q: (n,c,l1), k & v: (n,c,l2); l1 is longer than l2
248 | attn = k.transpose(1,2) @ q / np.sqrt(q.shape[1]) # (n,l2,l1)
249 | attn = attn - attn.amax(dim=1,keepdim=True).detach()
250 | attn = F.softmax(attn,dim=1)
251 | y = v @ attn # (n,c,l1)
252 | return y
253 |
254 | def forward(self,x):
255 | xl,xh = self.dwt(x)
256 | xh = xh[0]
257 | q = self.to_q(x)
258 | k = self.to_k(xh)
259 | v = self.to_v(xh)
260 | yh = self.compute_attn_matmul(q,k,v)
261 | y = yh * self.attn_gate.tanh() + x
262 | return y
263 |
264 | class LayerNorm(nn.Module):
265 | r""" LayerNorm that supports two data formats: channels_last (default) or channels_first.
266 | The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
267 | shape (batch_size, height, width, channels) while channels_first corresponds to inputs
268 | with shape (batch_size, channels, height, width).
269 | """
270 | def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
271 | super().__init__()
272 | self.weight = nn.Parameter(torch.ones(normalized_shape))
273 | self.bias = nn.Parameter(torch.zeros(normalized_shape))
274 | self.eps = eps
275 | self.data_format = data_format
276 | if self.data_format not in ["channels_last", "channels_first"]:
277 | raise NotImplementedError
278 | self.normalized_shape = (normalized_shape,)
279 |
280 | def forward(self, x):
281 | if self.data_format == "channels_last":
282 | return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
283 | elif self.data_format == "channels_first":
284 | u = x.mean(1, keepdim=True)
285 | s = (x - u).pow(2).mean(1, keepdim=True)
286 | x = (x - u) / torch.sqrt(s + self.eps)
287 | x = self.weight[:, None] * x + self.bias[:, None]
288 | return x
289 |
290 | class ContextBlock(nn.Module):
291 | def __init__(self,d_in,d_hidden,pooling='attn',fusions=['channel_add']):
292 | super().__init__()
293 | self.pooling = pooling
294 | self.conv_mask = nn.Conv1d(d_in,1,kernel_size=1) if pooling == 'attn' else nn.AdaptiveAvgPool1d(1)
295 | if 'channel_add' in fusions:
296 | self.channel_add_conv = nn.Sequential(
297 | nn.Conv1d(d_in,d_hidden,1),
298 | nn.LayerNorm([d_hidden,1]),
299 | nn.ReLU(True),
300 | nn.Conv1d(d_hidden,d_in,1)
301 | )
302 | else:
303 | self.channel_add_conv = None
304 | if 'channel_mul' in fusions:
305 | self.channel_mul_conv = nn.Sequential(
306 | nn.Conv1d(d_in,d_hidden,1),
307 | nn.LayerNorm([d_hidden,1]),
308 | nn.ReLU(True),
309 | nn.Conv1d(d_hidden,d_in,1)
310 | )
311 | else:
312 | self.channel_mul_conv = None
313 | self.weights_init()
314 |
315 | def weights_init(self):
316 | if self.pooling == 'attn':
317 | nn.init.kaiming_normal_(self.conv_mask.weight,a=0,mode='fan_in',nonlinearity='relu')
318 | if hasattr(self.conv_mask, 'bias') and self.conv_mask.bias is not None:
319 | nn.init.zeros_(self.conv_mask.bias)
320 | self.conv_mask.inited = True
321 | if self.channel_add_conv is not None:
322 | self.last_zero_init(self.channel_add_conv)
323 | if self.channel_mul_conv is not None:
324 | self.last_zero_init(self.channel_mul_conv)
325 |
326 | def last_zero_init(self,m):
327 | if isinstance(m,nn.Sequential):
328 | nn.init.zeros_(m[-1].weight)
329 | if hasattr(m[-1],'bias') and m[-1].bias is not None:
330 | nn.init.zeros_(m[-1].bias)
331 | else:
332 | nn.init.zeros_(m.weight)
333 | if hasattr(m,'bias') and m.bias is not None:
334 | nn.init.zeros_(m.bias)
335 |
336 | def spatial_pool(self,x):
337 | if self.pooling == 'attn':
338 | context_mask = self.conv_mask(x) # (n,1,l)
339 |             context_mask = F.softmax(context_mask,dim=2) # softmax over the length dimension l
340 |             context_mask = context_mask.squeeze(1).unsqueeze(-1) # (n,l,1); squeeze(1) keeps the batch dim even when n == 1
341 | context = torch.matmul(x,context_mask) # (n,c,l) * (n,l,1) = (n,c,1)
342 | else:
343 | context = self.conv_mask(x)
344 | return context
345 |
346 | def forward(self,x):
347 | context = self.spatial_pool(x) # (n,c,1)
348 | if self.channel_add_conv is not None:
349 | channel_add = self.channel_add_conv(context)
350 | x = x + channel_add
351 | if self.channel_mul_conv is not None:
352 | weights = torch.sigmoid(self.channel_mul_conv(context))
353 | x = x * weights
354 | return x
355 |
356 | def main():
357 | ...
358 |
359 | if __name__ == '__main__':
360 | main()
--------------------------------------------------------------------------------
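The HFGA block defined above attends from the full-length feature sequence (queries) to the shorter high-frequency sub-band produced by the 1-D DWT (keys and values), then adds the gated result back to the input. The snippet below is a minimal, shape-only sketch of that cross-length attention in plain PyTorch; the sizes `n`, `c`, and `l` are hypothetical and none of the DOLPHIN modules are needed to run it.

```python
import torch
import torch.nn.functional as F

# Hypothetical sizes: batch n, channels c, sequence length l.
n, c, l = 2, 64, 100
x  = torch.randn(n, c, l)       # full-resolution features -> queries
xh = torch.randn(n, c, l // 2)  # high-frequency sub-band  -> keys/values

q, k, v = x, xh, xh
attn = k.transpose(1, 2) @ q / (c ** 0.5)                       # (n, l/2, l)
attn = F.softmax(attn - attn.amax(dim=1, keepdim=True), dim=1)  # normalize over the sub-band length
y = v @ attn                                                    # (n, c, l): one output column per query position
print(y.shape)                                                  # torch.Size([2, 64, 100])
```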
/preprocess.py:
--------------------------------------------------------------------------------
1 |
2 | # -*- coding: utf-8 -*-
3 |
4 | import numpy as np
5 | import os,pickle,argparse
6 | from utils import centernorm_size,interpolate_torch
7 |
8 | def preprocess_OLHWDB2(src_root='./data-raw/OLHWDB2',interp=4):
9 | writing = {}
10 | writers = os.listdir(src_root)
11 | for i,w in enumerate(writers):
12 | writing[i] = []
13 | for file in os.listdir(f'{src_root}/{w}'):
14 | info = []
15 | with open(f'{src_root}/{w}/{file}','r',encoding='utf-8') as f:
16 | lines = f.readlines()
17 | lines = lines[1:]
18 | lines = [l.strip() for l in lines]
19 | info = [list(map(lambda x:float(x),l.split()[:3])) for l in lines]
20 | info = np.array(info,np.float32)
21 | info = centernorm_size(info)
22 |             if interp is not None:
23 | info = interpolate_torch(info,interp_ratio=interp)
24 | writing[i].append(info)
25 | tgt_root = src_root.replace('data-raw','data')
26 | os.makedirs(tgt_root,exist_ok=True)
27 | with open(f'{tgt_root}/OLHWDB2.pkl','wb') as f:
28 | pickle.dump(writing,f)
29 |
30 | def preprocess_DCOHE(src_root='./data-raw/DCOH-E'):
31 | writing = {}
32 | writers = os.listdir(src_root)
33 | for i,w in enumerate(writers):
34 | writing[w] = []
35 | for file in os.listdir(f'{src_root}/{w}'):
36 | info = []
37 | with open(f'{src_root}/{w}/{file}','r',encoding='utf-8') as f:
38 | lines = f.readlines()
39 | lines = lines[1:]
40 | lines = [l.strip() for l in lines]
41 | info = [list(map(lambda x:float(x),l.split()[:3])) for l in lines]
42 | info = np.array(info,np.float32)
43 | info = centernorm_size(info)
44 | if 'dcoh-e' in file:
45 | info = interpolate_torch(info,interp_ratio=2)
46 | writing[w].append(info)
47 | tgt_root = src_root.replace('data-raw','data')
48 | os.makedirs(tgt_root,exist_ok=True)
49 | with open(f'{tgt_root}/DCOH-E.pkl','wb') as f:
50 | pickle.dump(writing,f)
51 |
52 | def preprocess_COUCH(src_root='./data-raw/COUCH09',interp=4):
53 | writing = {}
54 | writers = os.listdir(src_root)
55 | for i,w in enumerate(writers):
56 | writing[i] = []
57 | for file in os.listdir(f'{src_root}/{w}'):
58 | with open(f'{src_root}/{w}/{file}','r',encoding='utf-8') as f:
59 | lines = f.readlines()
60 | lines = lines[1:]
61 | lines = [l.strip() for l in lines]
62 | info = [list(map(lambda x:float(x),l.split()[:3])) for l in lines]
63 | info = np.array(info,np.float32)
64 | info = centernorm_size(info)
65 |             if interp is not None:
66 | info = interpolate_torch(info,interp_ratio=interp)
67 | writing[i].append(info)
68 | tgt_root = src_root.replace('data-raw','data')
69 | os.makedirs(tgt_root,exist_ok=True)
70 | with open(f'{tgt_root}/COUCH09.pkl','wb') as f:
71 | pickle.dump(writing,f)
72 |
73 | if __name__ == '__main__':
74 | parser = argparse.ArgumentParser()
75 |     parser.add_argument('--dataset',type=str,default='olhwdb2',help='dataset to preprocess: one of [olhwdb2, dcohe, couch]')
76 | opt = parser.parse_args()
77 | func = globals()[f'preprocess_{opt.dataset.upper()}']
78 |     print(f'start preprocessing the {opt.dataset.upper()} dataset.')
79 |     func()
80 |     print(f'finished preprocessing the {opt.dataset.upper()} dataset.')
81 |
--------------------------------------------------------------------------------
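After running `preprocess.py`, each subset is stored as a pickled dictionary mapping a writer key to a list of trajectories, where every trajectory is a float32 array of shape `(num_points, 3)` holding `(x, y, pressure)` after center-size normalization and temporal interpolation. A minimal sketch of loading one of these files, assuming the OLHWDB2 subset has already been preprocessed:

```python
import pickle

# Path produced by `python preprocess.py --dataset olhwdb2`.
with open('./data/OLHWDB2/OLHWDB2.pkl', 'rb') as f:
    writing = pickle.load(f)

first_writer = next(iter(writing))      # writer key (an integer index for OLHWDB2)
sample = writing[first_writer][0]       # one handwriting sample of that writer
print(len(writing), len(writing[first_writer]), sample.shape, sample.dtype)
```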
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==2.1.0
2 | aiohttp==3.9.5
3 | aiosignal==1.3.1
4 | albumentations==1.4.3
5 | antlr4-python3-runtime==4.9.3
6 | asttokens @ file:///home/conda/feedstock_root/build_artifacts/asttokens_1698341106958/work
7 | async-timeout==4.0.3
8 | attrs==23.2.0
9 | backcall @ file:///home/conda/feedstock_root/build_artifacts/backcall_1592338393461/work
10 | beartype==0.17.2
11 | beautifulsoup4==4.12.3
12 | cachetools==5.3.1
13 | causal-conv1d==1.3.0.post1
14 | certifi==2023.7.22
15 | charset-normalizer==3.2.0
16 | colorama==0.4.6
17 | comm @ file:///home/conda/feedstock_root/build_artifacts/comm_1710320294760/work
18 | contourpy==1.1.0
19 | coolgpus==0.23
20 | cycler==0.11.0
21 | debugpy @ file:///croot/debugpy_1690905042057/work
22 | decorator @ file:///home/conda/feedstock_root/build_artifacts/decorator_1641555617451/work
23 | dominate==2.9.1
24 | easydict==1.11
25 | edit-distance==1.0.6
26 | einops==0.7.0
27 | ema-pytorch==0.4.3
28 | entrypoints @ file:///home/conda/feedstock_root/build_artifacts/entrypoints_1643888246732/work
29 | executing @ file:///home/conda/feedstock_root/build_artifacts/executing_1698579936712/work
30 | fastdtw==0.3.4
31 | filelock==3.12.2
32 | fonttools==4.42.0
33 | frozenlist==1.4.1
34 | fsspec==2023.6.0
35 | google-auth==2.29.0
36 | google-auth-oauthlib==1.0.0
37 | grpcio==1.62.1
38 | huggingface-hub==0.23.4
39 | idna==3.4
40 | iisignature @ file:///home/WD16T/zpr/iisignature-0.24-cp38-cp38-linux_x86_64.whl#sha256=24b2e81340e26b1a8501d721b87892c706fdcb98c4ee72d24721b76602ae80ec
41 | imagecorruptions==1.1.2
42 | imageio==2.34.0
43 | imgaug==0.4.0
44 | importlib-metadata==6.8.0
45 | importlib-resources==6.0.0
46 | ipykernel @ file:///home/conda/feedstock_root/build_artifacts/ipykernel_1708996548741/work
47 | ipython @ file:///home/conda/feedstock_root/build_artifacts/ipython_1680185408135/work
48 | jedi @ file:///home/conda/feedstock_root/build_artifacts/jedi_1696326070614/work
49 | joblib==1.3.1
50 | jsonpatch==1.33
51 | jsonpointer==2.4
52 | jupyter-client @ file:///home/conda/feedstock_root/build_artifacts/jupyter_client_1654730843242/work
53 | jupyter_core @ file:///home/conda/feedstock_root/build_artifacts/jupyter_core_1710257397447/work
54 | kiwisolver==1.4.4
55 | kornia==0.7.2
56 | kornia_rs==0.1.3
57 | lazy_loader==0.4
58 | lightning-utilities==0.11.2
59 | llvmlite==0.40.1
60 | lmdb==1.4.1
61 | mamba-ssm==2.1.0
62 | Markdown==3.6
63 | MarkupSafe==2.1.5
64 | matplotlib==3.7.2
65 | matplotlib-inline @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-inline_1660814786464/work
66 | multidict==6.0.5
67 | natsort==8.4.0
68 | nest_asyncio @ file:///home/conda/feedstock_root/build_artifacts/nest-asyncio_1705850609492/work
69 | networkx==3.1
70 | ninja==1.11.1.1
71 | numba==0.57.1
72 | numpy==1.24.4
73 | nvidia-ml-py==12.535.77
74 | nvitop==1.2.0
75 | oauthlib==3.2.2
76 | omegaconf==2.3.0
77 | opencv-contrib-python==4.9.0.80
78 | opencv-python==4.9.0.80
79 | opencv-python-headless==4.9.0.80
80 | packaging==23.1
81 | pandas==2.0.3
82 | parso @ file:///home/conda/feedstock_root/build_artifacts/parso_1638334955874/work
83 | pexpect @ file:///home/conda/feedstock_root/build_artifacts/pexpect_1706113125309/work
84 | pickleshare @ file:///home/conda/feedstock_root/build_artifacts/pickleshare_1602536217715/work
85 | Pillow==10.0.0
86 | platformdirs @ file:///home/conda/feedstock_root/build_artifacts/platformdirs_1706713388748/work
87 | pretty-errors==1.2.25
88 | prompt-toolkit @ file:///home/conda/feedstock_root/build_artifacts/prompt-toolkit_1702399386289/work
89 | protobuf==5.26.0
90 | psutil @ file:///opt/conda/conda-bld/psutil_1656431268089/work
91 | ptflops==0.7
92 | ptyprocess @ file:///home/conda/feedstock_root/build_artifacts/ptyprocess_1609419310487/work/dist/ptyprocess-0.7.0-py2.py3-none-any.whl
93 | pure-eval @ file:///home/conda/feedstock_root/build_artifacts/pure_eval_1642875951954/work
94 | pyasn1==0.5.1
95 | pyasn1-modules==0.3.0
96 | Pygments @ file:///home/conda/feedstock_root/build_artifacts/pygments_1700607939962/work
97 | pyparsing==3.0.9
98 | python-dateutil==2.8.2
99 | pytorch-lightning==2.0.0
100 | pytorch-metric-learning==2.3.0
101 | pytorch-wavelets==1.3.0
102 | pytz==2023.3
103 | PyWavelets==1.4.1
104 | PyYAML==6.0.1
105 | pyzmq @ file:///croot/pyzmq_1705605076900/work
106 | regex==2024.5.15
107 | requests==2.31.0
108 | requests-oauthlib==2.0.0
109 | rsa==4.9
110 | safetensors==0.4.3
111 | scikit-image==0.21.0
112 | scikit-learn==1.3.2
113 | scipy==1.10.1
114 | seaborn==0.13.2
115 | shapely==2.0.3
116 | six @ file:///home/conda/feedstock_root/build_artifacts/six_1620240208055/work
117 | soupsieve==2.6
118 | stack-data @ file:///home/conda/feedstock_root/build_artifacts/stack_data_1669632077133/work
119 | tensorboard==2.14.0
120 | tensorboard-data-server==0.7.2
121 | termcolor==2.3.0
122 | thop==0.1.1.post2209072238
123 | threadpoolctl==3.2.0
124 | tifffile==2023.7.10
125 | timm==0.3.2
126 | tokenizers==0.19.1
127 | torch==1.12.1+cu113
128 | torchaudio==0.12.1+cu113
129 | torchmetrics==1.4.0
130 | torchnet==0.0.4
131 | torchstat==0.0.7
132 | torchvision==0.13.1+cu113
133 | tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1648827257044/work
134 | tqdm==4.65.0
135 | traitlets @ file:///home/conda/feedstock_root/build_artifacts/traitlets_1710254411456/work
136 | transformers==4.41.2
137 | triton==2.3.1
138 | tslearn==0.6.1
139 | typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1708904622550/work
140 | tzdata==2023.3
141 | urllib3==2.0.4
142 | visdom==0.2.4
143 | wcwidth @ file:///home/conda/feedstock_root/build_artifacts/wcwidth_1704731205417/work
144 | websocket-client==1.7.0
145 | Werkzeug==3.0.1
146 | yarl==1.9.4
147 | zipp==3.16.2
148 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | import torch
4 | from torch.utils.data import DataLoader
5 | from torch.backends import cudnn
6 | import argparse,os,time,json
7 | from model import DOLPHIN
8 | from dataset import Writing,collate_fn
9 | from utils import create_logger,load_ckpt,l2_norm,fuse_all_conv_bn
10 | import numpy as np
11 | from evaluate import compute_metrics
12 | import pickle
13 | from natsort import natsorted
14 | import matplotlib.pyplot as plt
15 | from ptflops import get_model_complexity_info
16 | from thop import profile
17 | from torchstat import stat
18 |
19 | torch._C._jit_set_profiling_mode(False)
20 | torch._C._jit_set_profiling_executor(False)
21 | torch.cuda.empty_cache()
22 |
23 | parser = argparse.ArgumentParser()
24 | parser.add_argument('--batch_size',type=int,default=8)
25 | parser.add_argument('--num_classes',type=int,default=1731)
26 | parser.add_argument('--epoch',type=int,default=80)
27 | parser.add_argument('--seed',type=int,default=123)
28 | parser.add_argument('--cuda',type=lambda x: str(x).lower() in ('true','1','yes'),default=True) # plain type=bool would treat any non-empty string (e.g. 'False') as True
29 | parser.add_argument('--folder',type=str,default='./data/OLIWER')
30 | parser.add_argument('--ngpu',type=int,default=1)
31 | parser.add_argument('--gpu',type=str,default='0')
32 | parser.add_argument('--weights',type=str,default='./weights')
33 | parser.add_argument('--output_root',type=str,default='./output')
34 | parser.add_argument('--log_root',type=str,default='./logs')
35 | parser.add_argument('--dba',action='store_true')
36 | parser.add_argument('--rerank',action='store_true')
37 | parser.add_argument('--name',type=str,default='DOLPHIN')
38 | opt = parser.parse_args()
39 |
40 | # with open(f'{opt.weights}/settings.json','r',encoding='utf-8') as f:
41 | # settings = json.loads(f.read())
42 | # opt.seed = settings['seed']
43 | # opt.name = settings['name']
44 | # opt.notes = settings['notes']
45 | # opt.log_root = settings['log_root']
46 | # # opt.folder = settings['folder']
47 | # # opt.gpu = settings['gpu']
48 |
49 | np.random.seed(opt.seed)
50 | torch.manual_seed(opt.seed)
51 | torch.cuda.manual_seed_all(opt.seed)
52 |
53 | logger = create_logger(opt.log_root,name=opt.name,test=True)
54 |
55 | # query_root = f'{opt.folder}/query-tf.pkl'
56 | # with open(query_root,'rb') as f:
57 | # query_data = pickle.load(f,encoding='iso-8859-1')
58 | gallery_root = f'{opt.folder}/test-tf.pkl'
59 | with open(gallery_root,'rb') as f:
60 | gallery_data = pickle.load(f,encoding='iso-8859-1')
61 | # handwriting_info = {}
62 | # min_sample = 10000
63 | # for k in query_data:
64 | # handwriting_info[k] = query_data[k] + gallery_data[k]
65 | # min_sample = min(len(handwriting_info[k]),min_sample)
66 | # print(min_sample,len(query_data),len(gallery_data))
67 | gallery_dataset = Writing(gallery_data,train=False)
68 | d_in = gallery_dataset.feature_dims
69 |
70 | gallery_loader = DataLoader(gallery_dataset,batch_size=opt.batch_size,shuffle=False,collate_fn=collate_fn)
71 |
72 | model = DOLPHIN(d_in,opt.num_classes)
73 |
74 | if opt.cuda and torch.cuda.is_available():
75 | torch.cuda.set_device(int(opt.gpu))
76 | device = torch.device(f'cuda:{opt.gpu}')
77 | else:
78 | device = torch.device('cpu')
79 | model = model.to(device)
80 |
81 | logger.info(f'\ngallery root: {gallery_root}\n'
82 | f'gallery loader length: {len(gallery_loader)} gallery features length: {len(gallery_dataset)}\n'
83 | f'model: {model.__class__.__name__}\nDBA & AQE: {opt.dba}\nRerank: {opt.rerank}')
84 |
85 | def extract_features(model,data_loader,time_model):
86 | for i,(x,features_lens,user_labels) in enumerate(data_loader):
87 | x = torch.from_numpy(x).to(device)
88 | features_lens = torch.tensor(features_lens).long().to(device)
89 | user_labels = torch.from_numpy(user_labels).long()
90 | s = time.time()
91 | y_vector = model(x,features_lens)[0]
92 | # y_vector = model(x)[0]
93 | e = time.time()
94 | time_model += (e - s)
95 | y_vector = l2_norm(y_vector)
96 | if i == 0:
97 | features = torch.zeros(len(data_loader.dataset),y_vector.shape[1])
98 | start = i * opt.batch_size
99 | end = min((i + 1) * opt.batch_size,len(data_loader.dataset))
100 | features[start:end,:] = y_vector
101 | if i == 0:
102 | labels = user_labels
103 | else:
104 | labels = torch.cat([labels,user_labels],0)
105 | return features.cpu().numpy(),labels.cpu().numpy(),time_model
106 |
107 | def transform_user2feat(features,labels):
108 | label_indices = natsorted(np.unique(labels))
109 | user2feat = {}
110 | for i in label_indices:
111 | pos = np.where(labels == i)[0]
112 | user2feat[i] = features[pos]
113 | return user2feat
114 |
115 | @torch.no_grad()
116 | def test_impl(model):
117 | model = model.eval()
118 | model = model.to(device)
119 | # model = fuse_all_conv_bn(model)
120 |
121 | time_elapsed_start = time.time()
122 | all_features,all_labels,time_model = extract_features(model,gallery_loader,0)
123 | user2feat = transform_user2feat(all_features,all_labels)
124 | repeat_times = 1
125 | logger.info(f'repeat times: {repeat_times}')
126 | gallery_labels,query_labels = [],[]
127 | for i in natsorted(np.unique(all_labels)):
128 | gallery_labels.extend([i] * (len(user2feat[i]) - 1))
129 | query_labels.append(i)
130 | gallery_labels = np.array(gallery_labels)
131 | query_labels = np.array(query_labels)
132 | aps,top1s,top5s,top10s = [],[],[],[]
133 | for _ in range(repeat_times):
134 | gallery_features,query_features = [],[]
135 | label_indices = natsorted(np.unique(all_labels))
136 | for i in label_indices:
137 | idx = np.random.choice(len(user2feat[i]),size=1)[0]
138 | gallery_features.append(user2feat[i][:idx])
139 | gallery_features.append(user2feat[i][idx + 1:])
140 | query_features.append(user2feat[i][idx])
141 | gallery_features = np.concatenate(gallery_features)
142 | query_features = np.array(query_features)
143 | res = {
144 | 'gallery_feature':gallery_features,'gallery_label':gallery_labels,
145 | 'query_feature':query_features,'query_label':query_labels,
146 | }
147 | _,ap,top1,top5,top10 = compute_metrics(res,logger,opt.dba,device,verbose=False)
148 | aps.append(ap)
149 | top1s.append(top1)
150 | top5s.append(top5)
151 | top10s.append(top10)
152 | ap_mean,ap_std = np.mean(aps),np.std(aps)
153 | top1_mean,top1_std = np.mean(top1s),np.std(top1s)
154 | top5_mean,top5_std = np.mean(top5s),np.std(top5s)
155 | top10_mean,top10_std = np.mean(top10s),np.std(top10s)
156 | logger.info(f'[final] Rank@1: {top1_mean:.4f}% ({top1_std:.4f}%) Rank@5: {top5_mean:.4f}% ({top5_std:.4f}%) '
157 | f'Rank@10: {top10_mean:.4f}% ({top10_std:.4f}%)')
158 | logger.info(f'[final] mAP: {ap_mean * 100.:.4f}% ({ap_std * 100:.4f}%)')
159 | logger.info(f'time elapsed: {time.time() - time_elapsed_start:.5f}s\n')
160 |
161 | def test():
162 | load_ckpt(model,opt.weights,device,logger,mode='test')
163 | test_impl(model)
164 |
165 | def main():
166 | test()
167 |
168 | if __name__ == '__main__':
169 | main()
170 |
--------------------------------------------------------------------------------
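For evaluation, `test_impl` builds a leave-one-out split per writer: one randomly chosen sample becomes the query and all remaining samples of that writer form the gallery. The toy sketch below re-creates only that splitting step with hypothetical feature arrays (3 writers, 5 samples each, 8-dimensional embeddings); it uses `np.delete` instead of the two slice-and-append calls in the original code, which is equivalent.

```python
import numpy as np

rng = np.random.default_rng(0)
user2feat = {w: rng.normal(size=(5, 8)).astype(np.float32) for w in range(3)}

gallery_features, query_features = [], []
for w, feats in user2feat.items():
    idx = rng.integers(len(feats))                       # held-out query sample for this writer
    gallery_features.append(np.delete(feats, idx, axis=0))
    query_features.append(feats[idx])

gallery_features = np.concatenate(gallery_features)      # (3 * 4, 8)
query_features = np.stack(query_features)                # (3, 8)
print(gallery_features.shape, query_features.shape)
```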
/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding:utf-8 -*-
2 |
3 | from scipy import signal
4 | from torch.nn.utils import fuse_conv_bn_eval
5 | import numpy as np
6 | import cv2,logging,sys,time,os
7 | from functools import wraps,lru_cache,reduce
8 | from termcolor import colored
9 | import torch
10 | import torch.nn as nn
11 | import torch.nn.functional as F
12 | import matplotlib.pyplot as plt
13 | from PIL import Image
14 |
15 | class ButterWorthLPF:
16 |     # Butterworth low-pass filter
17 |     def __init__(self,order=3,half_pnt=15.0,fnyquist=100.0):
18 |         '''
19 |         scipy.signal.butter(N, Wn, btype='low', analog=False, output='ba', fs=None)
20 |         designs a low-pass / high-pass / band-pass filter.
21 |         N is the filter order; Wn is the 3 dB cutoff given as a normalized frequency (radians/frequency).
22 |         The 3 dB point is where the power drops to one half, hence the name half_pnt (point).
23 |         btype selects the pass type; analog=False gives a digital filter, analog=True an analog one.
24 |         output='ba' returns numerator and denominator coefficients; 'zpk' returns zeros, poles and gain; 'sos' returns second-order sections.
25 |         The default 'ba' gives the coefficients b_k (numerator) and a_k (denominator) of the transfer function.
26 |         https://wenku.baidu.com/view/adec241352d380eb62946d82.html
27 |         order: filter order
28 |         half_pnt: 3 dB point, where the power drops to one half
29 |         fnyquist: sampling frequency (the sampling rate must exceed twice the highest signal frequency); called fs in textbooks
30 |         '''
31 |         fM = 0.5 * fnyquist # fM is the highest signal frequency; naming follows Signals and Systems (2nd ed.), Chapter 7
32 |         half_pnt /= fM
33 |         b,a = signal.butter(order,half_pnt,'low')
34 |         self.b = b # numerator coefficients
35 |         self.a = a # denominator coefficients
36 | 
37 |     def __call__(self,x): # x: the data to be filtered
38 |         return signal.filtfilt(self.b,self.a,x)
39 |         # filtfilt applies zero-phase filtering: the output has exactly the same phase as the input (zero phase shift)
40 |         # using a zero-phase filter here is an empirical choice
41 |
42 | lpf = ButterWorthLPF()
43 |
44 | def difference(x): # central difference: half the difference between the points two steps apart
45 |     '''
46 |     numpy.convolve(a, v, mode='full'): a has length N and v has length M.
47 |     mode can be 'full', 'same' or 'valid'. 'full' returns length N+M-1, 'same' returns length max(M,N),
48 |     'valid' returns length max(M,N) - min(M,N) + 1, keeping only the fully overlapping points
49 |     (edge points, where one sequence sticks out beyond the other, are dropped).
50 |     With mode='same', e.g. [5,8,6,9,1,2]*[0.5,0,-0.5]: 'full' would have length 8, and trimming one point
51 |     from each side leaves exactly 6; [0.5,0,-0.5] is flipped by the convolution, giving [4., 0.5, 0.5, -2.5, -3.5, -0.5].
52 |     delta_x[0] = delta_x[1]
53 |     delta_x[-1] = delta_x[-2]
54 |     These two lines replace 4 with 0.5 and -0.5 with -3.5, because those boundary points have no neighbour on one side.
55 | '''
56 | delta_x = np.convolve(x,[0.5,0,-0.5],mode='same')
57 | delta_x[0] = delta_x[1]
58 | delta_x[-1] = delta_x[-2]
59 | return delta_x
60 |
61 | def difference_theta(x): # x is a sequence of angles in radians
62 | delta_x = np.zeros_like(x)
63 | delta_x[1:-1] = x[2:] - x[:-2]
64 | delta_x[-1] = delta_x[-2]
65 | delta_x[0] = delta_x[1]
66 | t = np.where(np.abs(delta_x) > np.pi)
67 | delta_x[t] = np.sign(delta_x[t]) * 2 * np.pi
68 | delta_x *= 0.5
69 | return delta_x
70 |
71 | def extract_features(handwritings,features,gpnoise=None,num=2,transform=False):
72 | '''
73 |     handwritings: list of trajectories; the first dimension of each trajectory is the number of points
74 |     features: output list; the feature array of each sample is appended to it
75 |     num: number of coordinate channels (x, y, pressure, ...); num=2 uses columns 0 and 1 as (x,y) and column 2 as pressure
76 |     gpnoise: unused here
77 |     transform: unused here
78 |     use_finger: whether the sample was written with a finger (not a parameter of this function)
79 | '''
80 | for handwriting in handwritings:
81 | pressure = handwriting[:,num]
82 |         handwriting = handwriting[:,0:num] # keep only the (x,y) coordinate columns
83 | handwriting[:,0] = lpf(handwriting[:,0])
84 | handwriting[:,1] = lpf(handwriting[:,1])
85 | delta_x = difference(handwriting[:,0])
86 | delta_y = difference(handwriting[:,1])
87 |         v = np.sqrt(delta_x ** 2 + delta_y ** 2) # speed
88 | theta = np.arctan2(delta_y,delta_x)
89 | cos_theta = np.cos(theta)
90 | sin_theta = np.sin(theta)
91 | delta_v = difference(v)
92 | delta_theta = np.abs(difference_theta(theta))
93 |         log_curve_radius = np.log((v + 0.05) / (delta_theta + 0.05)) # log curvature radius
94 | delta_v2 = np.abs(v * delta_theta)
95 | acceleration = np.sqrt(delta_v ** 2 + delta_v2 ** 2)
96 |
97 |         # indexing with [:,None] adds an axis: a (L,) array becomes (L,1), so the features can be concatenated along axis 1
98 | single_feature = np.concatenate((delta_x[:,None],delta_y[:,None],v[:,None],
99 | cos_theta[:,None],sin_theta[:,None],theta[:,None],log_curve_radius[:,None],
100 | acceleration[:,None],delta_v[:,None],delta_v2[:,None],delta_theta[:,None],
101 | pressure[:,None]),axis=1).astype(np.float32)
102 | single_feature[:,:-1] = (single_feature[:,:-1] - np.mean(single_feature[:,:-1],axis=0)) / \
103 | np.std(single_feature[:,:-1],axis=0)
104 | # single_feature[:,:-1] = regression_based_norm(single_feature[:,:-1])
105 | features.append(single_feature)
106 |
107 | def time_functions(handwriting,num=2):
108 | # handwriting = deepcopy(handwriting_org)
109 | pressure = handwriting[:,num]
110 | # pressure = np.ones_like(pressure)
111 |     handwriting = handwriting[:,0:num] # keep only the (x,y) coordinate columns
112 | handwriting[:,0] = lpf(handwriting[:,0])
113 | handwriting[:,1] = lpf(handwriting[:,1])
114 | delta_x = difference(handwriting[:,0])
115 | delta_y = difference(handwriting[:,1])
116 |     v = np.sqrt(delta_x ** 2 + delta_y ** 2) # speed
117 | theta = np.arctan2(delta_y,delta_x)
118 | cos_theta = np.cos(theta)
119 | sin_theta = np.sin(theta)
120 | delta_v = difference(v)
121 | delta_theta = np.abs(difference_theta(theta))
122 |     log_curve_radius = np.log((v + 0.05) / (delta_theta + 0.05)) # log curvature radius
123 | delta_v2 = np.abs(v * delta_theta)
124 | acceleration = np.sqrt(delta_v ** 2 + delta_v2 ** 2)
125 | delta_x2 = difference(delta_x)
126 | delta_y2 = difference(delta_y)
127 |     # indexing with [:,None] adds an axis: a (L,) array becomes (L,1), so the features can be concatenated along axis 1
128 | single_feature = np.concatenate((delta_x[:,None],delta_y[:,None],delta_x2[:,None],delta_y2[:,None],v[:,None],
129 | cos_theta[:,None],sin_theta[:,None],theta[:,None],log_curve_radius[:,None],
130 | acceleration[:,None],delta_v[:,None],delta_v2[:,None],delta_theta[:,None],
131 | pressure[:,None]),axis=1).astype(np.float32)
132 |
133 | single_feature[:,:-1] = (single_feature[:,:-1] - np.mean(single_feature[:,:-1],axis=0)) / np.std(single_feature[:,:-1],axis=0)
134 | return single_feature
135 |
136 | def letterbox_image(img,target_h,target_w):
137 | img_h,img_w = img.shape
138 | scale = min(target_h / img_h,target_w / img_w)
139 |     # inputs smaller than the target size get up-scaled, though up-scaling is not strictly necessary
140 |     new_w,new_h = int(img_w * scale),int(img_h * scale) # scaling both sides by the same factor keeps the aspect ratio
141 | img = cv2.resize(img,(new_w,new_h),interpolation=cv2.INTER_AREA)
142 | new_img = np.ones((target_h,target_w),dtype=np.uint8) * 255
143 | up = (target_h - new_h) // 2
144 | left = (target_w - new_w) // 2
145 | new_img[up:up + new_h,left:left + new_w] = img
146 | return new_img
147 |
148 | def interpolate_torch(org_info,interp_ratio):
149 | l = len(org_info)
150 | org_info = torch.tensor(org_info).view(1,1,l,-1)
151 | new_info = F.interpolate(org_info,size=(l * interp_ratio,3),mode='bicubic').squeeze().numpy()
152 | return new_info
153 |
154 | def load_ckpt(model,pretrained_root,device,logger,optimizer=None,scheduler=None,mode='train',resume=False):
155 |     # pretrained_root: checkpoint path, used either to resume training or to initialize from weights pretrained on another task
156 | state_dict = torch.load(pretrained_root,map_location=device)
157 | if mode == 'train':
158 | if resume:
159 | optimizer.load_state_dict(state_dict['optimizer'])
160 | scheduler.load_state_dict(state_dict['lr_scheduler'])
161 | print(model.load_state_dict(state_dict['model']))
162 | start_epoch = state_dict['epoch'] + 1
163 | logger.info(f'mode: "{mode} + resume" {pretrained_root} successfully loaded.')
164 | else:
165 | state_dict = state_dict['model'] if 'model' in state_dict else state_dict
166 | state_dict = {k:v for k,v in state_dict.items() if k in model.state_dict().keys() and v.numel() == model.state_dict()[k].numel()}
167 | print(model.load_state_dict(state_dict,strict=False))
168 | logger.info(f'mode: "{mode} + pretrained" {pretrained_root} successfully loaded.')
169 | start_epoch = 0
170 | return start_epoch
171 | else:
172 | state_dict = state_dict['model'] if 'model' in state_dict else state_dict
173 | state_dict = {k:v for k,v in state_dict.items() if k in model.state_dict().keys() and v.numel() == model.state_dict()[k].numel()}
174 | print(model.load_state_dict(state_dict,strict=False))
175 | # model.load_state_dict(state_dict['model'])
176 | logger.info(f'mode: "{mode}" {pretrained_root} successfully loaded.')
177 |
178 | @lru_cache()
179 | def create_logger(log_root,name='',test=False):
180 | os.makedirs(f'{log_root}',exist_ok=True)
181 | logger = logging.getLogger(name)
182 | logger.setLevel(logging.INFO)
183 | logger.propagate = False
184 |
185 | color_fmt = colored('[%(asctime)s %(name)s]','green') + \
186 | colored('(%(filename)s %(lineno)d)','yellow') + ': %(levelname)s %(message)s'
187 | console_handler = logging.StreamHandler(sys.stderr)
188 |     console_handler.setLevel(logging.INFO) # console log level
189 | console_handler.setFormatter(logging.Formatter(fmt=color_fmt,datefmt='%Y-%m-%d %H:%M:%S'))
190 | logger.addHandler(console_handler)
191 |
192 | fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s'
193 | date = time.strftime('%Y-%m-%d') if not test else time.strftime('%Y-%m-%d') + '-test'
194 | file_handler = logging.FileHandler(f'{log_root}/log-{date}.txt',mode='a')
195 | file_handler.setLevel(logging.INFO)
196 | file_handler.setFormatter(logging.Formatter(fmt=fmt,datefmt='%Y-%m-%d %H:%M:%S'))
197 | logger.addHandler(file_handler)
198 | return logger
199 |
200 | def l2_norm(x): # x: (batch_size, feature_dim)
201 | org_size = x.size()
202 | x_pow = torch.pow(x,2)
203 | x_pow = torch.sum(x_pow,1).add_(1e-6)
204 | x_pow = torch.sqrt(x_pow)
205 | y = torch.div(x,x_pow.view(-1,1).expand_as(x)).view(org_size)
206 | return y
207 |
208 | def centernorm_size(handwriting,coord_idx=[0,1]):
209 |     # coord_idx holds the column indices of x and y in the handwriting array (0 and 1 by default)
210 | assert len(coord_idx) == 2
211 | pos = handwriting[:,coord_idx]
212 | minx = np.min(pos,axis=0)
213 | maxn = np.max(pos,axis=0)
214 |     pos = (pos - (maxn + minx) / 2.) / np.max(maxn - minx) # center at the bounding-box midpoint and scale by its longer side (empirical choice)
215 | handwriting[:,coord_idx] = pos
216 | return handwriting
217 |
218 | def norm_pressure(handwriting,pressure_idx=2): # simply scales pressure to [0,1]; strictly optional
219 | pressure = handwriting[:,pressure_idx]
220 | maxn = np.max(pressure)
221 | pressure /= maxn
222 | handwriting[:,pressure_idx] = pressure
223 | return handwriting
224 |
225 | def fuse_all_conv_bn(model):
226 | stack = []
227 | for name,module in model.named_children():
228 | if list(module.named_children()):
229 | fuse_all_conv_bn(module)
230 | if isinstance(module,nn.BatchNorm1d):
231 |             if not stack: # stack is empty
232 | continue
233 | if isinstance(stack[-1][1],nn.Conv1d):
234 | setattr(model,stack[-1][0],fuse_conv_bn_eval(stack[-1][1],module))
235 | setattr(model,name,nn.Identity())
236 | else:
237 | stack.append((name,module))
238 | return model
239 |
240 | def extract_and_store(src_root='./data/OCNOLHW-granular/train.pkl'):
241 | import pickle
242 | with open(src_root,'rb') as f:
243 | handwriting_info = pickle.load(f,encoding='iso-8859-1')
244 | writing = {}
245 | print(len(handwriting_info))
246 | for i,k in enumerate(handwriting_info.keys()):
247 | writing[k] = []
248 | a = time.time()
249 | for each in handwriting_info[k]:
250 | writing[k].append(time_functions(each))
251 | print(time.time() - a)
252 | break
253 |
254 | def clock(func):
255 | @wraps(func)
256 | def impl(*args,**kwargs):
257 | start = time.perf_counter()
258 | res = func(*args,**kwargs)
259 | end = time.perf_counter()
260 | args_list = []
261 | if args:
262 | args_list.extend([repr(arg) for arg in args])
263 | if kwargs:
264 | args_list.extend([f'{key}={value}' for key,value in kwargs.items()])
265 | args_str = ','.join(i for i in args_list)
266 | print(f'[executed in {(end - start):.5f}s, '
267 | f'{func.__name__}({args_str}) -> {res}]')
268 | return res
269 | return impl
270 |
271 | def main():
272 | ...
273 |
274 | if __name__ == '__main__':
275 | main()
--------------------------------------------------------------------------------
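`utils.py` implements the per-point kinematic features (differences, speed, direction, curvature, acceleration, pressure) that the model consumes. A minimal usage sketch on a dummy trajectory, assuming it is run from the repository root so that `utils.py` is importable; the random-walk trajectory and its length are hypothetical:

```python
import numpy as np
from utils import centernorm_size, time_functions

rng = np.random.default_rng(0)
xy = np.cumsum(rng.normal(size=(200, 2)), axis=0)        # dummy pen trajectory (x, y)
pressure = rng.random((200, 1))
traj = np.concatenate([xy, pressure], axis=1).astype(np.float32)

traj = centernorm_size(traj)   # center the coordinates and scale by the bounding-box size
feats = time_functions(traj)   # per-point features: dx, dy, d2x, d2y, speed, cos/sin/theta, ..., pressure
print(feats.shape)             # (200, 14)
```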
/weights/model.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/SCUT-DLVCLab/DOLPHIN/fc90f848ef11cbdcc2f89fabf3a9ed68343641d0/weights/model.pth
--------------------------------------------------------------------------------