├── .gitignore
├── Demo.gif
├── LICENSE.md
├── NxN.jpg
├── README.md
├── configs
├── .DS_Store
├── latest-fp16.yaml
└── latest.yaml
├── data.py
├── models
└── .gitkeep
├── networks.py
├── outputs
└── .gitkeep
├── prepare-market.py
├── random_erasing.py
├── reIDfolder.py
├── reIDmodel.py
├── reid_eval
├── README.md
├── evaluate_gpu.py
└── test_2label.py
├── train.py
├── trainer.py
├── utils.py
├── visual_data
├── .DS_Store
├── demo
│ ├── .DS_Store
│ └── train
│ │ ├── 0010_c6s4_002427_02.jpg
│ │ └── 0042_c3s3_064169_01.jpg
├── inputs_many_test
│ ├── 0
│ │ ├── 0000_c3s1_107067_00.jpg
│ │ ├── 0200_c3s3_056178_00.jpg
│ │ ├── 0223_c1s6_014296_00.jpg
│ │ ├── 0387_c3s1_091167_00.jpg
│ │ ├── 0400_c2s1_046426_00.jpg
│ │ ├── 0531_c2s1_149316_00.jpg
│ │ ├── 1196_c3s1_038576_00.jpg
│ │ ├── b0280_c3s3_065619_00.jpg
│ │ ├── b0387_c2s1_090996_00.jpg
│ │ ├── b0878_c2s2_106457_00.jpg
│ │ ├── b1324_c1s6_007741_00.jpg
│ │ ├── back0100_c3s2_115894_00.jpg
│ │ ├── l0155_c3s1_025176_00.jpg
│ │ ├── l0345_c1s2_006466_00.jpg
│ │ └── l0736_c6s1_022526_00.jpg
│ ├── 1
│ │ ├── 0051_c4s1_005526_00.jpg
│ │ ├── 0055_c3s3_076469_00.jpg
│ │ ├── 0089_c6s1_013501_00.jpg
│ │ ├── 0280_c5s1_078148_00.jpg
│ │ ├── 0715_c3s2_059153_00.jpg
│ │ ├── 0825_c6s2_097943_00.jpg
│ │ ├── 0829_c4s4_043560_00.jpg
│ │ ├── 1147_c2s2_158702_00.jpg
│ │ ├── 1226_c6s3_038167_00.jpg
│ │ ├── 1251_c3s3_020428_00.jpg
│ │ ├── 1255_c3s3_021678_00.jpg
│ │ ├── 1459_c2s3_052107_00.jpg
│ │ ├── bb.jpg
│ │ └── bbb.jpg
│ └── .DS_Store
├── inputs_many_test_duke
│ ├── 1
│ │ ├── 00000.jpg
│ │ ├── 0001_c7_f0106449.jpg
│ │ ├── 0100_c1_f0100590.jpg
│ │ ├── 0120_c7_f0112912.jpg
│ │ ├── 0180_c6_f0074316.jpg
│ │ ├── 0304_c4_f0069917.jpg
│ │ ├── 5842_c7_f0053326.jpg
│ │ └── data
│ │ │ ├── 0163_c5_f0084927.jpg
│ │ │ ├── 0279_c2_f0098340.jpg
│ │ │ ├── 0377_c1_f0110464.jpg
│ │ │ ├── 0459_c1_f0124423.jpg
│ │ │ ├── 0654_c1_f0165081.jpg
│ │ │ ├── 0749_c1_f0188863.jpg
│ │ │ └── 4141_c8_f0022224.jpg
│ └── 2
│ │ └── data
│ │ ├── 0201_c2_f0089204.jpg
│ │ ├── 0243_c1_f0095564.jpg
│ │ ├── 0371_c1_f0110190.jpg
│ │ ├── 0690_c4_f0142152.jpg
│ │ ├── 0712_c6_f0165928.jpg
│ │ ├── 1108_c6_f0176009.jpg
│ │ └── 1486_c4_f0069129.jpg
├── inputs_many_test_market
│ ├── 1
│ │ ├── .DS_Store
│ │ └── data
│ │ │ ├── 0051_c3s1_005551_00.jpg
│ │ │ ├── 0227_c2s1_046426_00.jpg
│ │ │ ├── 0304_c5s1_068523_00.jpg
│ │ │ ├── 0378_c2s1_089071_00.jpg
│ │ │ ├── 0447_c3s1_111408_00.jpg
│ │ │ ├── 1285_c1s5_053991_00.jpg
│ │ │ └── l0387_c6s2_006743_00.jpg
│ ├── 2
│ │ ├── .DS_Store
│ │ └── data
│ │ │ ├── 0146_c2s1_023126_00.jpg
│ │ │ ├── 0291_c3s3_079969_00.jpg
│ │ │ ├── 1154_c5s3_001368_00.jpg
│ │ │ ├── 1285_c2s3_021432_00.jpg
│ │ │ ├── 1394_c6s3_071892_00.jpg
│ │ │ ├── 1429_c2s3_048607_00.jpg
│ │ │ └── l1255_c3s3_021678_00.jpg
│ └── .DS_Store
├── inputs_many_test_msmt
│ ├── 1
│ │ └── data
│ │ │ ├── 0367_010_10_0303noon_0606_2.jpg
│ │ │ ├── 0434_020_07_0303noon_1785_5_ex.jpg
│ │ │ ├── 0494_029_01_0303afternoon_0435_2_ex.jpg
│ │ │ ├── 0514_012_05_0303afternoon_0585_1.jpg
│ │ │ ├── 0720_027_01_0303afternoon_0771_4.jpg
│ │ │ ├── 2744_001_12_0114noon_0960_1_ex.jpg
│ │ │ └── 2944_020_01_0114afternoon_1429_2_ex.jpg
│ └── 2
│ │ └── data
│ │ ├── 0168_018_14_0303morning_1118_1_ex.jpg
│ │ ├── 0318_001_01_0303noon_0111_4_ex.jpg
│ │ ├── 0771_019_01_0303afternoon_1000_9.jpg
│ │ ├── 0917_021_15_0113morning_0174_0.jpg
│ │ ├── 2744_037_06_0114noon_0915_0_ex.jpg
│ │ ├── 2790_015_05_0114afternoon_0244_1.jpg
│ │ ├── 2869_011_05_0114afternoon_0836_0_ex.jpg
│ │ └── 3032_043_07_0114afternoon_1088_3.jpg
├── inputs_many_train
│ ├── 1
│ │ ├── 0002_c2s1_000351_01.jpg
│ │ ├── 0028_c2s1_001751_02.jpg
│ │ ├── 0068_c6s1_012351_01.jpg
│ │ ├── 0070_c4s1_010576_01.jpg
│ │ ├── 0127_c6s1_021426_06.jpg
│ │ ├── 0142_c3s3_020203_02.jpg
│ │ ├── 0197_c6s1_045076_05.jpg
│ │ ├── 0259_c6s1_055126_01.jpg
│ │ ├── 0301_c6s1_076101_04.jpg
│ │ ├── 0348_c2s1_080671_01.jpg
│ │ ├── 0696_c6s2_054993_02.jpg
│ │ ├── 0711_c2s2_056057_03.jpg
│ │ ├── 0754_c2s2_068182_01.jpg
│ │ ├── 0820_c3s2_104428_05.jpg
│ │ ├── 0843_c2s2_104907_04.jpg
│ │ └── 0901_c6s2_119343_02.jpg
│ └── .DS_Store
├── inputs_two
│ ├── 1
│ │ ├── 00.jpg
│ │ ├── 000.jpg
│ │ ├── 0006_c2s3_069327_00.jpg
│ │ ├── 01.jpg
│ │ ├── 0289_c3s1_063567_00.jpg
│ │ └── back0100_c3s2_115894_00.jpg
│ ├── 2
│ │ ├── .DS_Store
│ │ └── data
│ │ │ ├── .DS_Store
│ │ │ ├── 0051_c3s1_005551_00.jpg
│ │ │ └── 0227_c2s1_046426_00.jpg
│ ├── 3
│ │ ├── .DS_Store
│ │ └── data
│ │ │ ├── .DS_Store
│ │ │ ├── 0000_c3s1_107067_00.jpg
│ │ │ └── 0531_c2s1_149316_00.jpg
│ ├── 4
│ │ ├── .DS_Store
│ │ └── data
│ │ │ ├── .DS_Store
│ │ │ ├── b1324_c1s6_007741_00.jpg
│ │ │ └── back0100_c3s2_115894_00.jpg
│ └── .DS_Store
└── train_sample
│ ├── 045_388_gan0734_c3s2_064678_06.jpg
│ ├── 104_188_gan0371_c5s1_088523_02.jpg
│ ├── 115_000_gan0002_c1s1_069056_02.jpg
│ ├── 258_750_gan1500_c5s3_063312_02.jpg
│ ├── 344_283_gan0549_c1s3_009921_05.jpg
│ ├── 400_344_gan0656_c6s2_045393_01.jpg
│ ├── 406_090_gan0179_c3s3_078044_01.jpg
│ ├── 496_722_gan1437_c1s6_007216_02.jpg
│ ├── 620_357_gan0673_c5s2_045155_01.jpg
│ ├── 643_269_gan0519_c6s2_087418_01.jpg
│ └── 643_375_gan0706_c6s2_062118_04.jpg
└── visual_tools
├── README.md
├── show1by1.py
├── show_rainbow.py
├── show_smooth.py
├── show_smooth_structure.py
├── show_swap.py
└── test_folder.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Don't track content of these folders
2 | outputs/
3 | models/
4 | logs/
5 | __pycache__/
6 | __MACOSX/
7 | configs/cifs9a31
8 | core
9 | temp/
10 | temp_output/
11 | center/
12 | rainbow/
13 |
14 | # Compiled source #
15 | ###################
16 | *.com
17 | *.class
18 | *.dll
19 | *.exe
20 | *.o
21 | *.so
22 | *.pyc
23 |
24 | # Packages #
25 | ############
26 | # it's better to unpack these files and commit the raw source
27 | # git has its own built in compression methods
28 | *.7z
29 | *.dmg
30 | *.gz
31 | *.iso
32 | *.jar
33 | *.rar
34 | *.tar
35 | *.zip
36 | *.mat
37 | *.npy
38 | *.pt
39 | *.pth
40 |
41 |
--------------------------------------------------------------------------------
/Demo.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/Demo.gif
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | ## creative commons
2 |
3 | # Attribution-NonCommercial-ShareAlike 4.0 International
4 |
5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
6 |
7 | ### Using Creative Commons Public Licenses
8 |
9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
10 |
11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
12 |
13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
14 |
15 | ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
16 |
17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
18 |
19 | ### Section 1 – Definitions.
20 |
21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
22 |
23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
24 |
25 | c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License.
26 |
27 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
28 |
29 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
30 |
31 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
32 |
33 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
34 |
35 | h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
36 |
37 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
38 |
39 | j. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
40 |
41 | k. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
42 |
43 | l. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
44 |
45 | m. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
46 |
47 | n. __You__ means the individual or entity exercising the Licensed Rights under this Public License. __Your__ has a corresponding meaning.
48 |
49 | ### Section 2 – Scope.
50 |
51 | a. ___License grant.___
52 |
53 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
54 |
55 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
56 |
57 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
58 |
59 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
60 |
61 | 3. __Term.__ The term of this Public License is specified in Section 6(a).
62 |
63 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
64 |
65 | 5. __Downstream recipients.__
66 |
67 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
68 |
69 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply.
70 |
71 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
72 |
73 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
74 |
75 | b. ___Other rights.___
76 |
77 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
78 |
79 | 2. Patent and trademark rights are not licensed under this Public License.
80 |
81 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
82 |
83 | ### Section 3 – License Conditions.
84 |
85 | Your exercise of the Licensed Rights is expressly made subject to the following conditions.
86 |
87 | a. ___Attribution.___
88 |
89 | 1. If You Share the Licensed Material (including in modified form), You must:
90 |
91 | A. retain the following if it is supplied by the Licensor with the Licensed Material:
92 |
93 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
94 |
95 | ii. a copyright notice;
96 |
97 | iii. a notice that refers to this Public License;
98 |
99 | iv. a notice that refers to the disclaimer of warranties;
100 |
101 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
102 |
103 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
104 |
105 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
106 |
107 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
108 |
109 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
110 |
111 | b. ___ShareAlike.___
112 |
113 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
114 |
115 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
116 |
117 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
118 |
119 | 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.
120 |
121 | ### Section 4 – Sui Generis Database Rights.
122 |
123 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
124 |
125 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
126 |
127 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
128 |
129 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
130 |
131 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
132 |
133 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
134 |
135 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
136 |
137 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
138 |
139 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
140 |
141 | ### Section 6 – Term and Termination.
142 |
143 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
144 |
145 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
146 |
147 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
148 |
149 | 2. upon express reinstatement by the Licensor.
150 |
151 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
152 |
153 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
154 |
155 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
156 |
157 | ### Section 7 – Other Terms and Conditions.
158 |
159 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
160 |
161 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
162 |
163 | ### Section 8 – Interpretation.
164 |
165 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
166 |
167 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
168 |
169 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
170 |
171 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
172 |
173 |
174 | > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
175 | >
176 | > Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org/).
177 |
178 |
179 |
--------------------------------------------------------------------------------
/NxN.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/NxN.jpg
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | [License: CC BY-NC-SA 4.0](https://raw.githubusercontent.com/nvlabs/SPADE/master/LICENSE.md)
2 | 
3 | [Language grade: Python](https://lgtm.com/projects/g/NVlabs/DG-Net/context:python)
4 |
5 | ## Joint Discriminative and Generative Learning for Person Re-identification
6 | 
7 | 
8 |
9 | [[Project]](http://zdzheng.xyz/DG-Net/) [[Paper]](https://arxiv.org/abs/1904.07223) [[YouTube]](https://www.youtube.com/watch?v=ubCrEAIpQs4) [[Bilibili]](https://www.bilibili.com/video/av51439240) [[Poster]](http://zdzheng.xyz/files/DGNet_poster.pdf)
10 | [[Supp]](http://jankautz.com/publications/JointReID_CVPR19_supp.pdf)
11 |
12 | Joint Discriminative and Generative Learning for Person Re-identification, CVPR 2019 (Oral)
13 | [Zhedong Zheng](http://zdzheng.xyz/), [Xiaodong Yang](https://xiaodongyang.org/), [Zhiding Yu](https://chrisding.github.io/), [Liang Zheng](http://liangzheng.com.cn/), [Yi Yang](https://www.uts.edu.au/staff/yi.yang), [Jan Kautz](http://jankautz.com/)
14 |
15 | ## Table of contents
16 | * [News](#news)
17 | * [Features](#features)
18 | * [Prerequisites](#prerequisites)
19 | * [Getting Started](#getting-started)
20 | * [Installation](#installation)
21 | * [Dataset Preparation](#dataset-preparation)
22 | * [Testing](#testing)
23 | * [Training](#training)
24 | * [DG-Market](#dg-market)
25 | * [Tips](#tips)
26 | * [Citation](#citation)
27 | * [Related Work](#related-work)
28 | * [License](#license)
29 |
30 | ## News
31 | - 02/18/2021: We released [DG-Net++](https://github.com/NVlabs/DG-Net-PP), the extension of DG-Net for unsupervised cross-domain re-id.
32 | - 08/24/2019: We added the direct transfer learning results of DG-Net [here](https://github.com/NVlabs/DG-Net#person-re-id-evaluation).
33 | - 08/01/2019: We added support for multi-GPU training: `python train.py --config configs/latest.yaml --gpu_ids 0,1`.
34 |
35 | ## Features
36 | We support:
37 | - Multi-GPU training (fp32)
38 | - [APEX](https://github.com/NVIDIA/apex) to save GPU memory (fp16/fp32)
39 | - Multi-query evaluation
40 | - Random erasing
41 | - Visualize training curves
42 | - Generate all figures in the paper
43 |
44 | ## Prerequisites
45 |
46 | - Python 3.6
47 | - GPU memory >= 15G (fp32)
48 | - GPU memory >= 10G (fp16/fp32)
49 | - NumPy
50 | - PyTorch 1.0+
51 | - [Optional] APEX (fp16/fp32)
52 |
53 | ## Getting Started
54 | ### Installation
55 | - Install [PyTorch](http://pytorch.org/)
56 | - Install torchvision from the source:
57 | ```
58 | git clone https://github.com/pytorch/vision
59 | cd vision
60 | python setup.py install
61 | ```
62 | - [Optional] Install APEX from source (you may skip this step):
63 | ```
64 | git clone https://github.com/NVIDIA/apex.git
65 | cd apex
66 | python setup.py install --cuda_ext --cpp_ext
67 | ```
68 | - Clone this repo:
69 | ```bash
70 | git clone https://github.com/NVlabs/DG-Net.git
71 | cd DG-Net/
72 | ```
73 |
74 | Our code has been tested on PyTorch 1.0.0+ and torchvision 0.2.1+.
75 |
76 | ### Dataset Preparation
77 | Download the dataset [Market-1501](http://www.liangzheng.com.cn/Project/project_reid.html) [[Google Drive]](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view) [[Baidu Disk]](https://pan.baidu.com/s/1ntIi2Op)
78 |
79 | Preparation: put the images with the same id in one folder. You may use
80 | ```bash
81 | python prepare-market.py # for Market-1501
82 | ```
83 | Note: change the dataset path in the script to your own path.
84 |
85 | ### Testing
86 |
87 | #### Download the trained model
88 | We provide our trained model. You may download it from [Google Drive](https://drive.google.com/open?id=1lL18FZX1uZMWKzaZOuPe3IuAdfUYyJKH) (or [Baidu Disk](https://pan.baidu.com/s/1503831XfW0y4g3PHir91yw) password: rqvf) and move it into the `outputs` folder:
89 | ```
90 | ├── outputs/
91 | │ ├── E0.5new_reid0.5_w30000
92 | ├── models
93 | │ ├── best/
94 | ```
95 | #### Person re-id evaluation
96 | - Supervised learning
97 |
98 | | | Market-1501 | DukeMTMC-reID | MSMT17 | CUHK03-NP |
99 | |---|--------------|----------------|----------|-----------|
100 | | Rank@1 | 94.8% | 86.6% | 77.2% | 65.6% |
101 | | mAP | 86.0% | 74.8% | 52.3% | 61.1% |
102 |
103 |
104 | - Direct transfer learning
105 | To verify the generalizability of DG-Net, we train the model on dataset A and directly test the model on dataset B (with no adaptation).
106 | We denote the direct transfer learning protocol as `A→B`.
107 |
108 | | |Market→Duke|Duke→Market|Market→MSMT|MSMT→Market|Duke→MSMT|MSMT→Duke|
109 | |---|----------------|----------------| -------------- |----------------| -------------- |----------------|
110 | | Rank@1 | 42.62% | 56.12% | 17.11% | 61.76% | 20.59% | 61.89% |
111 | | Rank@5 | 58.57% | 72.18% | 26.66% | 77.67% | 31.67% | 75.81% |
112 | | Rank@10 | 64.63% | 78.12% | 31.62% | 83.25% | 37.04% | 80.34% |
113 | | mAP | 24.25% | 26.83% | 5.41% | 33.62% | 6.35% | 40.69% |
114 |
115 |
116 | #### Image generation evaluation
117 |
118 | Please check `README.md` in `./visual_tools`.
119 |
120 | You may use `./visual_tools/test_folder.py` to generate many images and then evaluate them. The only thing you need to modify is the data path in [SSIM](https://github.com/layumi/PerceptualSimilarity) and [FID](https://github.com/layumi/TTUR).
121 |
122 | ### Training
123 |
124 | #### Train a teacher model
125 | You may directly download our trained teacher model from [Google Drive](https://drive.google.com/open?id=1lL18FZX1uZMWKzaZOuPe3IuAdfUYyJKH) (or [Baidu Disk](https://pan.baidu.com/s/1503831XfW0y4g3PHir91yw) password: rqvf).
126 | If you want to train it yourself, please check the [person re-id baseline](https://github.com/layumi/Person_reID_baseline_pytorch) repository to train a teacher model, then copy it into `./models`.
127 | ```
128 | ├── models/
129 | │ ├── best/ /* teacher model for Market-1501
130 | │ ├── net_last.pth /* model file
131 | │ ├── ...
132 | ```
133 |
134 | #### Train DG-Net
135 | 1. Setup the yaml file. Check out `configs/latest.yaml`. Change the data_root field to the path of your prepared folder-based dataset, e.g. `../Market-1501/pytorch`.
136 |
137 |
138 | 2. Start training
139 | ```
140 | python train.py --config configs/latest.yaml
141 | ```
142 | Or train with low precision (fp16)
143 | ```
144 | python train.py --config configs/latest-fp16.yaml
145 | ```
146 | Intermediate image outputs and model binary files are saved in `outputs/latest`.
147 |
148 | 3. Check the loss log
149 | ```
150 | tensorboard --logdir logs/latest
151 | ```
152 |
153 | ## DG-Market
154 | 
155 |
156 | We provide our generated images and make a large-scale synthetic dataset called DG-Market. This dataset is generated by our DG-Net and consists of 128,307 images (613MB), about 10 times larger than the training set of original Market-1501 (even much more can be generated with DG-Net). It can be used as a source of unlabeled training data for semi-supervised learning. You may download the dataset from [Google Drive](https://drive.google.com/file/d/126Gn90Tzpk3zWp2c7OBYPKc-ZjhptKDo/view?usp=sharing) (or [Baidu Disk](https://pan.baidu.com/s/1n4M6s-qvE08J8SOOWtWfgw) password: qxyh).
157 |
158 | | | DG-Market | Market-1501 (training) |
159 | |---|--------------|-------------|
160 | | #identity| - | 751 |
161 | | #images| 128,307 | 12,936 |
162 |
163 | Quick Download via [gdrive](https://github.com/prasmussen/gdrive)
164 | ```bash
165 | wget https://github.com/prasmussen/gdrive/releases/download/2.1.1/gdrive_2.1.1_linux_386.tar.gz
166 | tar -xzvf gdrive_2.1.1_linux_386.tar.gz
167 | gdrive download 126Gn90Tzpk3zWp2c7OBYPKc-ZjhptKDo
168 | unzip DG-Market.zip
169 | ```
170 |
171 | ## Tips
172 | Note the format of camera id and number of cameras. For some datasets (e.g., MSMT17), there are more than 10 cameras. You need to modify the preparation and evaluation code to read the double-digit camera id. For some vehicle re-id datasets (e.g., VeRi) having different naming rules, you also need to modify the preparation and evaluation code.
173 |
174 | ## Citation
175 | Please cite this paper if it helps your research:
176 | ```bibtex
177 | @inproceedings{zheng2019joint,
178 | title={Joint discriminative and generative learning for person re-identification},
179 | author={Zheng, Zhedong and Yang, Xiaodong and Yu, Zhiding and Zheng, Liang and Yang, Yi and Kautz, Jan},
180 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
181 | year={2019}
182 | }
183 | ```
184 |
185 | ## Related Work
186 | Other GAN-based methods compared in the paper include [LSGAN](https://github.com/layumi/DCGAN-pytorch), [FDGAN](https://github.com/layumi/FD-GAN) and [PG2GAN](https://github.com/charliememory/Pose-Guided-Person-Image-Generation). We forked their code and made some changes for evaluation, and we thank the authors for their great work. We would also like to thank the great projects [person re-id baseline](https://github.com/layumi/Person_reID_baseline_pytorch), [MUNIT](https://github.com/NVlabs/MUNIT) and [DRIT](https://github.com/HsinYingLee/DRIT).
187 |
188 | ## License
189 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. Licensed under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) (**Attribution-NonCommercial-ShareAlike 4.0 International**). The code is released for academic research use only. For commercial use, please contact [researchinquiries@nvidia.com](mailto:researchinquiries@nvidia.com).
190 |
--------------------------------------------------------------------------------
/configs/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/configs/.DS_Store
--------------------------------------------------------------------------------
/configs/latest-fp16.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | apex: true # Set True to use float16.
5 | B_w: 0.2 # The loss weight of fine-grained loss, which is named as `alpha` in the paper.
6 | ID_class: 751 # The number of ID classes in the dataset. For example, 751 for Market, 702 for DukeMTMC
7 | ID_stride: 1 # Stride in Appearance encoder
8 | ID_style: AB # For time being, we only support AB. In the future, we will support PCB.
9 | batch_size: 8 # BatchSize
10 | beta1: 0 # Adam hyperparameter
11 | beta2: 0.999 # Adam hyperparameter
12 | crop_image_height: 256 # Input height
13 | crop_image_width: 128 # Input width
14 | data_root: ../Market/pytorch/ # Dataset Root
15 | dis:
16 | LAMBDA: 0.01 # the hyperparameter for the regularization term
17 | activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh]
18 | dim: 32 # number of filters in the bottommost layer
19 | gan_type: lsgan # GAN loss [lsgan/nsgan]
20 | n_layer: 2 # number of layers in D
21 | n_res: 4 # number of layers in D
22 | non_local: 0 # number of non_local layers
23 | norm: none # normalization layer [none/bn/in/ln]
24 | num_scales: 3 # number of scales
25 | pad_type: reflect # padding type [zero/reflect]
26 | display_size: 16 # How many images to display
27 | erasing_p: 0.5 # Random erasing probability [0-1]
28 | gamma: 0.1 # Learning Rate Decay (except appearance encoder)
29 | gamma2: 0.1 # Learning Rate Decay (for appearance encoder)
30 | gan_w: 1 # the weight of gan loss
31 | gen:
32 | activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh]
33 | dec: basic # [basic/parallel/series]
34 | dim: 16 # number of filters in the bottommost layer
35 | dropout: 0 # use dropout in the generator
36 | id_dim: 2048 # length of appearance code
37 | mlp_dim: 512 # number of filters in MLP
38 | mlp_norm: none # norm in mlp [none/bn/in/ln]
39 | n_downsample: 2 # number of downsampling layers in content encoder
40 | n_res: 4 # number of residual blocks in content encoder/decoder
41 | non_local: 0 # number of non_local layer
42 | pad_type: reflect # padding type [zero/reflect]
43 | tanh: false # use tanh or not at the last layer
44 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
45 | id_w: 1.0 # the weight of ID loss
46 | image_display_iter: 5000 # How often do you want to display output images during training
47 | image_save_iter: 5000 # How often do you want to save output images during training
48 | input_dim_a: 1 # We use the gray-scale input, so the input dim is 1
49 | input_dim_b: 1 # We use the gray-scale input, so the input dim is 1
50 | log_iter: 1 # How often do you want to log the training stats
51 | lr2: 0.002 # initial appearance encoder learning rate
52 | lr_d: 0.0001 # initial discriminator learning rate
53 | lr_g: 0.0001 # initial generator (except appearance encoder) learning rate
54 | lr_policy: multistep # learning rate scheduler [multistep|constant|step]
55 | max_cyc_w: 2 # the maximum weight for cycle loss
56 | max_iter: 100000 # When you end the training
57 | max_teacher_w: 2 # the maximum weight for prime loss (teacher KL loss)
58 | max_w: 1 # the maximum weight for feature reconstruction losses
59 | new_size: 128 # the resized size
60 | norm_id: false # Do we normalize the appearance code
61 | num_workers: 8 # number of workers used to load the data
62 | pid_w: 1.0 # positive ID loss
63 | pool: max # pooling layer for the appearance encoder
64 | recon_s_w: 0 # the initial weight for structure code reconstruction
65 | recon_f_w: 0 # the initial weight for appearance code reconstruction
66 | recon_id_w: 0.5 # the initial weight for ID reconstruction
67 | recon_x_cyc_w: 0 # the initial weight for cycle reconstruction
68 | recon_x_w: 5 # the initial weight for self-reconstruction
69 | recon_xp_w: 5 # the initial weight for self-identity reconstruction
70 | single: gray # make input to gray-scale
71 | snapshot_save_iter: 10000 # How often to save the checkpoint
72 | sqrt: false # whether use square loss.
73 | step_size: 60000 # when to decay the learning rate
74 | teacher: best # teacher model name. For DukeMTMC, you may set `best-duke`
75 | teacher_w: 0 # the initial weight for prime loss (teacher KL loss)
76 | teacher_style: 0 # select teacher style.[0-4] # 0: Our smooth dynamic label# 1: Pseudo label, hard dynamic label# 2: Conditional label, hard static label # 3: LSRO, static smooth label# 4: Dynamic Soft Two-label
77 | train_bn: true # whether we train the bn for the generated image.
78 | use_decoder_again: true # whether we train the decoder on the generated image.
79 | use_encoder_again: 0.5 # the probability we train the structure encoder on the generated image.
80 | vgg_w: 0 # We do not use vgg as one kind of inception loss.
81 | warm_iter: 30000 # when to start warm up the losses (fine-grained/feature reconstruction losses).
82 | warm_scale: 0.0005 # how fast to warm up
83 | warm_teacher_iter: 30000 # when to start warm up the prime loss
84 | weight_decay: 0.0005 # weight decay
85 |
--------------------------------------------------------------------------------
/configs/latest.yaml:
--------------------------------------------------------------------------------
1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved.
2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
3 |
4 | apex: false # Set True to use float16.
5 | B_w: 0.2 # The loss weight of fine-grained loss, which is named as `alpha` in the paper.
6 | ID_class: 751 # The number of ID classes in the dataset. For example, 751 for Market, 702 for DukeMTMC
7 | ID_stride: 1 # Stride in Appearance encoder
8 | ID_style: AB # For time being, we only support AB. In the future, we will support PCB.
9 | batch_size: 8 # BatchSize
10 | beta1: 0 # Adam hyperparameter
11 | beta2: 0.999 # Adam hyperparameter
12 | crop_image_height: 256 # Input height
13 | crop_image_width: 128 # Input width
14 | data_root: ../Market/pytorch/ # Dataset Root
15 | dis:
16 | LAMBDA: 0.01 # the hyperparameter for the regularization term
17 | activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh]
18 | dim: 32 # number of filters in the bottommost layer
19 | gan_type: lsgan # GAN loss [lsgan/nsgan]
20 | n_layer: 2 # number of layers in D
21 | n_res: 4 # number of layers in D
22 | non_local: 0 # number of non_local layers
23 | norm: none # normalization layer [none/bn/in/ln]
24 | num_scales: 3 # number of scales
25 | pad_type: reflect # padding type [zero/reflect]
26 | display_size: 16 # How many images to display
27 | erasing_p: 0.5 # Random erasing probability [0-1]
28 | gamma: 0.1 # Learning Rate Decay (except appearance encoder)
29 | gamma2: 0.1 # Learning Rate Decay (for appearance encoder)
30 | gan_w: 1 # the weight of gan loss
31 | gen:
32 | activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh]
33 | dec: basic # [basic/parallel/series]
34 | dim: 16 # number of filters in the bottommost layer
35 | dropout: 0 # use dropout in the generator
36 | id_dim: 2048 # length of appearance code
37 | mlp_dim: 512 # number of filters in MLP
38 | mlp_norm: none # norm in mlp [none/bn/in/ln]
39 | n_downsample: 2 # number of downsampling layers in content encoder
40 | n_res: 4 # number of residual blocks in content encoder/decoder
41 | non_local: 0 # number of non_local layer
42 | pad_type: reflect # padding type [zero/reflect]
43 | tanh: false # use tanh or not at the last layer
44 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal]
45 | id_w: 1.0 # the weight of ID loss
46 | image_display_iter: 5000 # How often do you want to display output images during training
47 | image_save_iter: 5000 # How often do you want to save output images during training
48 | input_dim_a: 1 # We use the gray-scale input, so the input dim is 1
49 | input_dim_b: 1 # We use the gray-scale input, so the input dim is 1
50 | log_iter: 1 # How often do you want to log the training stats
51 | lr2: 0.002 # initial appearance encoder learning rate
52 | lr_d: 0.0001 # initial discriminator learning rate
53 | lr_g: 0.0001 # initial generator (except appearance encoder) learning rate
54 | lr_policy: multistep # learning rate scheduler [multistep|constant|step]
55 | max_cyc_w: 2 # the maximum weight for cycle loss
56 | max_iter: 100000 # When you end the training
57 | max_teacher_w: 2 # the maximum weight for prime loss (teacher KL loss)
58 | max_w: 1 # the maximum weight for feature reconstruction losses
59 | new_size: 128 # the resized size
60 | norm_id: false # Do we normalize the appearance code
61 | num_workers: 8 # number of workers used to load the data
62 | pid_w: 1.0 # positive ID loss
63 | pool: max # pooling layer for the appearance encoder
64 | recon_s_w: 0 # the initial weight for structure code reconstruction
65 | recon_f_w: 0 # the initial weight for appearance code reconstruction
66 | recon_id_w: 0.5 # the initial weight for ID reconstruction
67 | recon_x_cyc_w: 0 # the initial weight for cycle reconstruction
68 | recon_x_w: 5 # the initial weight for self-reconstruction
69 | recon_xp_w: 5 # the initial weight for self-identity reconstruction
70 | single: gray # make input to gray-scale
71 | snapshot_save_iter: 10000 # How often to save the checkpoint
72 | sqrt: false # whether use square loss.
73 | step_size: 60000 # when to decay the learning rate
74 | teacher: best # teacher model name. For DukeMTMC, you may set `best-duke`
75 | teacher_w: 0 # the initial weight for prime loss (teacher KL loss)
76 | teacher_style: 0 # select teacher style.[0-4] # 0: Our smooth dynamic label# 1: Pseudo label, hard dynamic label# 2: Conditional label, hard static label # 3: LSRO, static smooth label# 4: Dynamic Soft Two-label
77 | train_bn: true # whether we train the bn for the generated image.
78 | use_decoder_again: true # whether we train the decoder on the generated image.
79 | use_encoder_again: 0.5 # the probability we train the structure encoder on the generated image.
80 | vgg_w: 0 # We do not use vgg as one kind of inception loss.
81 | warm_iter: 30000 # when to start warm up the losses (fine-grained/feature reconstruction losses).
82 | warm_scale: 0.0005 # how fast to warm up
83 | warm_teacher_iter: 30000 # when to start warm up the prime loss
84 | weight_decay: 0.0005 # weight decay
85 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | import torch.utils.data as data
6 | import os.path
7 |
def default_loader(path):
    """Open the image file at ``path`` with PIL and return an RGB copy of it."""
    pil_image = Image.open(path)
    rgb_image = pil_image.convert('RGB')
    return rgb_image
10 |
11 |
def default_flist_reader(flist):
    """Read an image-list file and return one image path per non-empty line.

    Each line of ``flist`` is stripped of surrounding whitespace and kept
    verbatim (one entry per line, like caffe's file lists). The whole stripped
    line is returned: a trailing label (``impath label``) is NOT split off.

    Blank lines (e.g. an empty last line) are skipped. An empty path would
    otherwise be joined to the dataset root and fail when loaded as an image.

    Args:
        flist: path of the list file.

    Returns:
        list of str: the stripped, non-empty lines in file order.
    """
    with open(flist, 'r') as rf:
        stripped = (line.strip() for line in rf)
        return [impath for impath in stripped if impath]
23 |
24 |
class ImageFilelist(data.Dataset):
    """Dataset of unlabeled images whose relative paths come from a list file.

    Args:
        root: directory that every listed path is joined to.
        flist: path of the list file, passed unchanged to ``flist_reader``.
        transform: optional callable applied to every loaded image.
        flist_reader: callable returning the list of relative image paths.
        loader: callable loading an image from its full path.
    """

    def __init__(self, root, flist, transform=None,
                 flist_reader=default_flist_reader, loader=default_loader):
        self.root = root
        self.imlist = flist_reader(flist)
        self.transform = transform
        self.loader = loader

    def __getitem__(self, index):
        """Load the ``index``-th listed image, applying ``transform`` if set."""
        full_path = os.path.join(self.root, self.imlist[index])
        image = self.loader(full_path)
        return image if self.transform is None else self.transform(image)

    def __len__(self):
        """Number of listed images."""
        return len(self.imlist)
43 |
44 |
class ImageLabelFilelist(data.Dataset):
    """Dataset of labeled images whose relative paths come from a list file.

    The label of an image is derived from the first component of its relative
    path (text before the first ``/``). The sorted distinct first components
    form ``self.classes`` and are numbered from 0 in ``self.class_to_idx``.

    Unlike ``ImageFilelist``, ``flist`` itself is resolved relative to ``root``.
    """

    def __init__(self, root, flist, transform=None,
                 flist_reader=default_flist_reader, loader=default_loader):
        self.root = root
        self.imlist = flist_reader(os.path.join(self.root, flist))
        self.transform = transform
        self.loader = loader
        # first path component of every entry == its class name
        first_dirs = [entry.split('/')[0] for entry in self.imlist]
        self.classes = sorted(set(first_dirs))
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}
        self.imgs = [(entry, self.class_to_idx[cls_name])
                     for entry, cls_name in zip(self.imlist, first_dirs)]

    def __getitem__(self, index):
        """Return ``(image, label)`` for the ``index``-th listed entry."""
        relative_path, label = self.imgs[index]
        image = self.loader(os.path.join(self.root, relative_path))
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def __len__(self):
        """Number of labeled entries."""
        return len(self.imgs)
65 |
66 | ###############################################################################
67 | # Code from
68 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py
69 | # Modified the original code so that it also loads images from the current
70 | # directory as well as the subdirectories
71 | ###############################################################################
72 |
73 | import torch.utils.data as data
74 |
75 | from PIL import Image
76 | import os
77 | import os.path
78 |
# File-name suffixes treated as images (matching is case-sensitive).
IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
]


def is_image_file(filename):
    """Return True when ``filename`` ends with one of ``IMG_EXTENSIONS``.

    Only the exact spellings listed are accepted, so e.g. ``.Jpg`` is not.
    """
    return filename.endswith(tuple(IMG_EXTENSIONS))
87 |
88 |
def make_dataset(dir):
    """Collect the paths of all image files under ``dir``, recursively.

    Files directly inside ``dir`` and in every subdirectory are included.

    Args:
        dir: directory to walk.

    Returns:
        list of str: ``os.path.join(root, fname)`` for each file accepted by
        ``is_image_file``, visiting directories in sorted walk order.

    Raises:
        AssertionError: if ``dir`` is not an existing directory.
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    return [
        os.path.join(root, fname)
        for root, _, fnames in sorted(os.walk(dir))
        for fname in fnames
        if is_image_file(fname)
    ]
100 |
101 |
class ImageFolder(data.Dataset):
    """Dataset over every image file found (recursively) under ``root``.

    Args:
        root: directory scanned with ``make_dataset``; paths are sorted.
        transform: optional callable applied to each loaded image.
        return_paths: when True, items are ``(image, path)`` instead of ``image``.
        loader: callable loading an image from its path.

    Raises:
        RuntimeError: if no image file is found under ``root``.
    """

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = sorted(make_dataset(root))
        if not imgs:
            message = ("Found 0 images in: " + root + "\n"
                       "Supported image extensions are: " +
                       ",".join(IMG_EXTENSIONS))
            raise RuntimeError(message)

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        """Load the ``index``-th image, optionally paired with its path."""
        path = self.imgs[index]
        image = self.loader(path)
        if self.transform is not None:
            image = self.transform(image)
        return (image, path) if self.return_paths else image

    def __len__(self):
        """Number of image files found."""
        return len(self.imgs)
130 |
--------------------------------------------------------------------------------
/models/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/outputs/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/prepare-market.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import os
7 | from shutil import copyfile
8 |
# You only need to change this line to your dataset download path
download_path = '../Market'


def copy_images_by_id(src_dir, dst_dir, val_dir=None):
    """Copy every ``*jpg`` file found under ``src_dir`` into ``dst_dir/<ID>/``.

    The ID is the part of the file name before the first ``_`` (e.g. ``0002``
    for ``0002_c1s1_000451_03.jpg``). Each file is read from the directory
    ``os.walk`` found it in, so nested source layouts work as well.

    If ``val_dir`` is given, the first image met for each ID (the one that
    creates ``dst_dir/<ID>``) is copied to ``val_dir/<ID>/`` instead. All later
    images of that ID go to ``dst_dir/<ID>/``; this is the train/val split.

    A missing ``src_dir`` is a no-op, since ``os.walk`` yields nothing for it.
    Both ``dst_dir`` and (if given) ``val_dir`` must already exist.
    """
    for root, _dirs, files in os.walk(src_dir, topdown=True):
        for name in files:
            if name[-3:] != 'jpg':
                continue
            person_id = name.split('_')[0]
            dst_path = os.path.join(dst_dir, person_id)
            if not os.path.isdir(dst_path):
                os.mkdir(dst_path)
                if val_dir is not None:
                    # first image is used as val image
                    dst_path = os.path.join(val_dir, person_id)
                    os.makedirs(dst_path, exist_ok=True)
            copyfile(os.path.join(root, name), os.path.join(dst_path, name))


def prepare_split(src_name, dst_name, val_name=None):
    """Reorganise ``download_path/<src_name>`` into ``download_path/pytorch/<dst_name>``.

    When ``val_name`` is given, one image per ID goes to
    ``download_path/pytorch/<val_name>`` (see ``copy_images_by_id``).
    """
    src_dir = os.path.join(download_path, src_name)
    dst_dir = os.path.join(download_path, 'pytorch', dst_name)
    os.makedirs(dst_dir, exist_ok=True)
    val_dir = None
    if val_name is not None:
        val_dir = os.path.join(download_path, 'pytorch', val_name)
        os.makedirs(val_dir, exist_ok=True)
    copy_images_by_id(src_dir, dst_dir, val_dir)


def main():
    """Build the folder-per-ID ``pytorch/`` layout from a raw Market-1501 download."""
    if not os.path.isdir(download_path):
        # Everything below needs the raw dataset; stop with a clear message
        # instead of crashing on the first mkdir.
        raise SystemExit('please change the download_path')

    os.makedirs(os.path.join(download_path, 'pytorch'), exist_ok=True)

    # query
    prepare_split('query', 'query')
    # multi-query: for dukemtmc-reid, we do not need multi-query (no gt_bbox)
    if os.path.isdir(os.path.join(download_path, 'gt_bbox')):
        prepare_split('gt_bbox', 'multi-query')
    # gallery
    prepare_split('bounding_box_test', 'gallery')
    # train_all
    prepare_split('bounding_box_train', 'train_all')
    # train_val: first image of every ID goes to val, the rest to train
    prepare_split('bounding_box_train', 'train', val_name='val')


if __name__ == '__main__':
    main()
114 |
--------------------------------------------------------------------------------
/random_erasing.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from __future__ import absolute_import
7 |
8 | from torchvision.transforms import *
9 |
10 | import random
11 | import math
12 |
class RandomErasing(object):
    """Randomly select a rectangle region in an image and erase its pixels.

    'Random Erasing Data Augmentation' by Zhong et al.
    See https://arxiv.org/pdf/1708.04896.pdf

    Args:
        probability: The probability that the Random Erasing operation will be performed.
        sl: Minimum proportion of erased area against input image.
        sh: Maximum proportion of erased area against input image.
        r1: Minimum aspect ratio of erased area (the maximum is ``1 / r1``).
        mean: Erasing value per channel (only ``mean[0]`` is used for
            non-3-channel input).
    """

    def __init__(self, probability=0.5, sl=0.02, sh=0.4, r1=0.3, mean=(0.4914, 0.4822, 0.4465)):
        self.probability = probability
        # Copy into a per-instance list. The default used to be one shared
        # mutable list, so mutating one instance's ``mean`` silently changed
        # the default of every later instance.
        self.mean = list(mean)
        self.sl = sl
        self.sh = sh
        self.r1 = r1
        # NOTE(review): this reseeds the *global* ``random`` module every time
        # an instance is built, which also affects any other code relying on
        # ``random``. Kept for reproducibility with existing runs.
        random.seed(7)

    def __call__(self, img):
        """Erase one random rectangle of ``img`` in place and return it.

        ``img`` is indexed as ``img[channel, row, col]`` and must provide
        ``size()`` and ``detach()`` (a (C, H, W) torch tensor is assumed).
        With probability ``1 - self.probability`` the input is returned
        untouched. Otherwise up to 100 rectangles are sampled until one fits
        strictly inside the image; if none fits, nothing is erased.
        3-channel input gets ``mean[c]`` in channel ``c``. Any other channel
        count only has channel 0 filled with ``mean[0]``.
        """
        if random.uniform(0, 1) > self.probability:
            return img

        height, width = img.size()[1], img.size()[2]
        area = height * width
        for _attempt in range(100):
            target_area = random.uniform(self.sl, self.sh) * area
            aspect_ratio = random.uniform(self.r1, 1 / self.r1)

            h = int(round(math.sqrt(target_area * aspect_ratio)))
            w = int(round(math.sqrt(target_area / aspect_ratio)))

            if w < width and h < height:
                x1 = random.randint(0, height - h)
                y1 = random.randint(0, width - w)
                n_filled = 3 if img.size()[0] == 3 else 1
                for c in range(n_filled):
                    img[c, x1:x1 + h, y1:y1 + w] = self.mean[c]
                return img.detach()

        return img.detach()
59 |
--------------------------------------------------------------------------------
/reIDfolder.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from torchvision import datasets
7 | import os
8 | import numpy as np
9 | import random
10 |
class ReIDFolder(datasets.ImageFolder):
    """torchvision ``ImageFolder`` for re-id that also returns a positive sample.

    ``__getitem__`` returns ``(sample, target, pos)`` where ``pos`` is another
    image with the same class index (or the same image when its class has a
    single sample).

    Args:
        root: folder-per-ID dataset root, as expected by ``ImageFolder``.
        transform: transform applied to both ``sample`` and ``pos``.
    """

    def __init__(self, root, transform):
        super(ReIDFolder, self).__init__(root, transform)
        # class index of every sample, kept as an array for vectorised lookups
        targets = np.asarray([s[1] for s in self.samples])
        self.targets = targets
        self.img_num = len(self.samples)
        print(self.img_num)

    def _get_cam_id(self, path):
        """Return the 0-based camera id parsed from the file name of ``path``.

        The id is the run of digits after the first ``c`` that is directly
        followed by a digit. For example, ``0010_c6s4_002427_02.jpg`` gives 5
        and ``0001_c7_f0106449.jpg`` gives 6. Multi-digit ids such as ``c12``
        are read in full, and prefixes like ``back0100_c3...`` are handled.

        Raises:
            ValueError: if the file name contains no ``c<digits>`` pattern.
        """
        import re
        filename = os.path.basename(path)
        match = re.search(r'c(\d+)', filename)
        if match is None:
            raise ValueError('cannot parse camera id from %s' % filename)
        return int(match.group(1)) - 1

    def _get_pos_sample(self, target, index, path):
        """Return the path of a random *other* sample whose class is ``target``.

        Falls back to ``path`` itself when sample ``index`` is the only one of
        its class (e.g. in the query set).
        """
        pos_index = np.argwhere(self.targets == target)
        pos_index = pos_index.flatten()
        pos_index = np.setdiff1d(pos_index, index)
        if len(pos_index) == 0:  # in the query set, only one sample
            return path
        rand = random.randint(0, len(pos_index) - 1)
        return self.samples[pos_index[rand]][0]

    def _get_neg_sample(self, target):
        """Return a random ``self.samples`` entry whose class differs from ``target``.

        Note: unlike ``_get_pos_sample`` this returns the whole sample entry
        (path and class index), not just the path.
        """
        neg_index = np.argwhere(self.targets != target)
        neg_index = neg_index.flatten()
        rand = random.randint(0, len(neg_index) - 1)
        return self.samples[neg_index[rand]]

    def __getitem__(self, index):
        """Return ``(sample, target, pos)`` for sample ``index``."""
        path, target = self.samples[index]
        sample = self.loader(path)

        pos_path = self._get_pos_sample(target, index, path)
        pos = self.loader(pos_path)

        if self.transform is not None:
            sample = self.transform(sample)
            pos = self.transform(pos)

        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target, pos
57 |
58 |
--------------------------------------------------------------------------------
/reIDmodel.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | from torch.nn import init
9 | from torchvision import models
10 |
11 | ######################################################################
def weights_init_kaiming(m):
    """Initialise ``m`` by class name; intended for ``nn.Module.apply``.

    * ``Conv*``: kaiming-normal weight (``fan_in``); the bias is left as is.
    * ``Linear*``: kaiming-normal weight (``fan_out``) and zero bias.
    * ``InstanceNorm1d``: weight drawn from N(1, 0.02) and zero bias.

    Other modules (including ``BatchNorm1d``) are not touched. Parameters that
    do not exist, such as ``bias=False`` layers or non-affine norms, are
    skipped. Previously they raised ``AttributeError`` on ``None.data``.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
    elif 'Linear' in classname:
        init.kaiming_normal_(m.weight.data, a=0, mode='fan_out')
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)
    elif 'InstanceNorm1d' in classname:
        if m.weight is not None:
            init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)
22 |
def weights_init_classifier(m):
    """Initialise ``Linear`` layers for classification; intended for ``Module.apply``.

    Weights are drawn from N(0, 0.001) and the bias (when present) is set to
    0. Non-``Linear`` modules are left untouched. A ``bias=False`` layer no
    longer raises ``AttributeError`` on ``None.data``.
    """
    if 'Linear' in m.__class__.__name__:
        init.normal_(m.weight.data, std=0.001)
        if m.bias is not None:
            init.constant_(m.bias.data, 0.0)
28 |
def fix_bn(m):
    """Switch BatchNorm modules to eval mode; intended for ``Module.apply``.

    Any module whose class name contains ``BatchNorm`` is put in eval mode.
    Every other module keeps its current train/eval state.
    """
    is_batchnorm = 'BatchNorm' in type(m).__name__
    if is_batchnorm:
        m.eval()
33 |
34 | # Defines the new fc layer and classification layer
35 | # |--Linear--|--bn--|--relu--|--Linear--|
class ClassBlock(nn.Module):
    """Bottleneck + classifier head.

    Layout: Linear -> BatchNorm1d [-> LeakyReLU(0.1)] [-> Dropout] -> Linear.

    Args:
        input_dim: size of the incoming feature vector.
        class_num: number of output logits.
        droprate: dropout probability after the bottleneck; ``0`` disables it.
        relu: insert ``LeakyReLU(0.1)`` after the batch norm when True.
        num_bottleneck: width of the bottleneck layer.
    """

    def __init__(self, input_dim, class_num, droprate=0.5, relu=False, num_bottleneck=512):
        super(ClassBlock, self).__init__()
        bottleneck_layers = [
            nn.Linear(input_dim, num_bottleneck),
            nn.BatchNorm1d(num_bottleneck, affine=True),
        ]
        if relu:
            bottleneck_layers.append(nn.LeakyReLU(0.1))
        if droprate > 0:
            bottleneck_layers.append(nn.Dropout(p=droprate))
        self.add_block = nn.Sequential(*bottleneck_layers)
        self.add_block.apply(weights_init_kaiming)

        self.classifier = nn.Sequential(nn.Linear(num_bottleneck, class_num))
        self.classifier.apply(weights_init_classifier)

    def forward(self, x):
        """Map features ``x`` to class logits through bottleneck and classifier."""
        return self.classifier(self.add_block(x))
61 |
62 | # Define the ResNet50-based Model
class ft_net(nn.Module):
    """ResNet-50 re-id backbone returning a part feature and class logits.

    Args:
        class_num: number of identity classes for the ``ClassBlock`` head.
        norm: when truthy, ``forward`` L2-normalises the part feature.
        pool: ``'max'`` selects max pooling; any other value average pooling.
        stride: ``1`` removes the final downsampling of ``layer4``.
    """

    def __init__(self, class_num, norm=False, pool='avg', stride=2):
        super(ft_net, self).__init__()
        self.norm = bool(norm)
        model_ft = models.resnet50(pretrained=True)
        # avg pooling to global pooling; partpool keeps `self.part` horizontal stripes
        self.part = 4
        pool_layer = nn.AdaptiveMaxPool2d if pool == 'max' else nn.AdaptiveAvgPool2d
        model_ft.partpool = pool_layer((self.part, 1))
        model_ft.avgpool = pool_layer((1, 1))
        if stride == 1:
            # remove the final downsample
            first_block = model_ft.layer4[0]
            first_block.downsample[0].stride = (1, 1)
            first_block.conv2.stride = (1, 1)

        self.model = model_ft
        self.classifier = ClassBlock(2048, class_num)

    def forward(self, x):
        """Return ``(f, logits)``.

        ``f`` is the part-pooled feature flattened to ``(batch, channels * part)``
        and L2-normalised when ``self.norm`` is set. ``logits`` come from the
        globally pooled feature passed through ``self.classifier``.
        """
        backbone = self.model
        for stage in (backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
                      backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4):
            x = stage(x)
        f = backbone.partpool(x)
        x = backbone.avgpool(x)

        x = x.view(x.size(0), x.size(1))
        f = f.view(f.size(0), f.size(1) * self.part)
        if self.norm:
            f_norm = torch.norm(f, p=2, dim=1, keepdim=True) + 1e-8
            f = f.div(f_norm.expand_as(f))
        logits = self.classifier(x)
        return f, logits
107 |
108 | # Define the AB Model
class ft_netAB(nn.Module):
    """ResNet-50 appearance encoder with two classifier heads.

    Args:
        class_num: number of identity classes for both heads.
        norm: accepted for signature compatibility; not used.
        stride: ``1`` removes the final downsampling of ``layer4``.
        droprate: accepted for signature compatibility; the heads use fixed
            dropout rates of 0.5 and 0.75.
        pool: ``'max'`` selects max pooling; any other value average pooling.
    """

    def __init__(self, class_num, norm=False, stride=2, droprate=0.5, pool='avg'):
        super(ft_netAB, self).__init__()
        model_ft = models.resnet50(pretrained=True)
        self.part = 4
        pool_layer = nn.AdaptiveMaxPool2d if pool == 'max' else nn.AdaptiveAvgPool2d
        model_ft.partpool = pool_layer((self.part, 1))
        model_ft.avgpool = pool_layer((1, 1))

        self.model = model_ft

        if stride == 1:
            first_block = self.model.layer4[0]
            first_block.downsample[0].stride = (1, 1)
            first_block.conv2.stride = (1, 1)

        self.classifier1 = ClassBlock(2048, class_num, 0.5)
        self.classifier2 = ClassBlock(2048, class_num, 0.75)

    def forward(self, x):
        """Return ``(f, [logits1, logits2])``.

        ``f`` is the part-pooled feature flattened to ``(batch, channels * part)``
        and detached, so no gradient flows through it. Both logits are computed
        from the same globally pooled feature.
        """
        backbone = self.model
        for stage in (backbone.conv1, backbone.bn1, backbone.relu, backbone.maxpool,
                      backbone.layer1, backbone.layer2, backbone.layer3, backbone.layer4):
            x = stage(x)
        f = backbone.partpool(x)
        f = f.view(f.size(0), f.size(1) * self.part).detach()  # no gradient
        pooled = backbone.avgpool(x)
        pooled = pooled.view(pooled.size(0), pooled.size(1))
        logits1 = self.classifier1(pooled)
        logits2 = self.classifier2(pooled)
        return f, [logits1, logits2]
151 |
152 | # Define the DenseNet121-based Model
class ft_net_dense(nn.Module):
    """DenseNet-121 backbone with a single ClassBlock head on a 1024-dim feature.

    ``forward(x)`` returns the classifier output for the globally pooled
    ``features`` map.
    """

    def __init__(self, class_num):
        super().__init__()
        model_ft = models.densenet121(pretrained=True)
        # registered on `features`, presumably appended to that Sequential so
        # features() ends with global average pooling -- confirm
        model_ft.features.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # NOTE(review): torchvision's DenseNet head is named `classifier`, so
        # this `fc` attribute is likely unused
        model_ft.fc = nn.Sequential()
        self.model = model_ft
        # For DenseNet, the feature dim is 1024
        self.classifier = ClassBlock(1024, class_num)

    def forward(self, x):
        x = self.model.features(x)
        # Flatten (N, C, 1, 1) -> (N, C). torch.squeeze would also drop the
        # batch dimension when N == 1 and break the classifier.
        x = x.view(x.size(0), -1)
        return self.classifier(x)
169 |
170 | # Define the ResNet50-based Model (Middle-Concat)
171 | # In the spirit of "The Devil is in the Middle: Exploiting Mid-level Representations for Cross-Domain Instance Matching." Yu, Qian, et al. arXiv:1711.08106 (2017).
class ft_net_middle(nn.Module):
    """ResNet-50 classifier on concatenated layer3 + layer4 global features.

    The layer3 output (1024 channels) and the layer4 output (2048 channels)
    are each globally average-pooled, concatenated, and fed to a
    ``ClassBlock(2048 + 1024, class_num)``.
    """

    def __init__(self, class_num):
        super(ft_net_middle, self).__init__()
        model_ft = models.resnet50(pretrained=True)
        # avg pooling to global pooling
        model_ft.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.model = model_ft
        self.classifier = ClassBlock(2048 + 1024, class_num)

    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        # x0 n*1024*1*1
        x0 = self.model.avgpool(x)
        x = self.model.layer4(x)
        # x1 n*2048*1*1
        x1 = self.model.avgpool(x)
        x = torch.cat((x0, x1), 1)
        # Flatten to (N, 3072); torch.squeeze would drop the batch dim when N == 1
        x = x.view(x.size(0), -1)
        return self.classifier(x)
199 |
200 | # Part Model proposed in Yifan Sun etal. (2018)
class PCB(nn.Module):
    """Part-based model (Sun et al., 2018) with ``self.part`` horizontal stripes.

    The ResNet-50 layer4 output is average-pooled into ``self.part`` stripes
    (layer4 down-sampling removed). Each stripe goes through its own
    ``classifier<i>``.

    ``forward(x)`` returns ``(f, y)``. ``f`` is the pooled map flattened to
    ``channels * self.part`` columns, taken before dropout. ``y`` is the list
    of per-stripe classifier outputs, computed after dropout.
    """

    def __init__(self, class_num):
        super(PCB, self).__init__()

        self.part = 4  # pool5 is cut into 4 parts
        self.model = models.resnet50(pretrained=True)
        self.avgpool = nn.AdaptiveAvgPool2d((self.part, 1))
        self.dropout = nn.Dropout(p=0.5)
        # remove the final down-sampling of layer4
        first_block = self.model.layer4[0]
        first_block.downsample[0].stride = (1, 1)
        first_block.conv2.stride = (1, 1)
        self.softmax = nn.Softmax(dim=1)
        # one classifier per stripe: classifier0 .. classifier{part-1}
        for idx in range(self.part):
            setattr(self, 'classifier%d' % idx, ClassBlock(2048, class_num, True, False, 256))

    def forward(self, x):
        m = self.model
        for stage in (m.conv1, m.bn1, m.relu, m.maxpool,
                      m.layer1, m.layer2, m.layer3, m.layer4):
            x = stage(x)
        x = self.avgpool(x)
        f = x.view(x.size(0), x.size(1) * self.part)
        x = self.dropout(x)
        y = []
        for idx in range(self.part):
            # stripe idx of the (N, C, part, 1) map, flattened to (N, C)
            stripe = x[:, :, idx].contiguous().view(x.size(0), x.size(1))
            y.append(getattr(self, 'classifier%d' % idx)(stripe))
        return f, y
248 |
class PCB_test(nn.Module):
    """Test-time wrapper returning the stripe-pooled backbone map of a PCB model.

    The wrapper shares (does not copy) ``model.model`` and mutates it. It sets
    the first-block down-sampling and conv2 strides of both layer3 and layer4
    to 1. ``forward(x)`` returns the map average-pooled into ``self.part``
    stripes, shaped ``(N, C, self.part)``.

    NOTE(review): PCB trains with 4 stripes and only layer4's stride changed,
    while this wrapper uses 6 stripes and also changes layer3. Confirm this
    mismatch is intended.
    """

    def __init__(self, model):
        super(PCB_test, self).__init__()
        self.part = 6
        self.model = model.model
        self.avgpool = nn.AdaptiveAvgPool2d((self.part, 1))
        # remove the down-sampling at the start of layer3 and layer4
        for layer in (self.model.layer3, self.model.layer4):
            first_block = layer[0]
            first_block.downsample[0].stride = (1, 1)
            first_block.conv2.stride = (1, 1)

    def forward(self, x):
        m = self.model
        for stage in (m.conv1, m.bn1, m.relu, m.maxpool,
                      m.layer1, m.layer2, m.layer3, m.layer4):
            x = stage(x)
        pooled = self.avgpool(x)
        return pooled.view(pooled.size(0), pooled.size(1), pooled.size(2))
275 |
276 |
277 |
--------------------------------------------------------------------------------
/reid_eval/README.md:
--------------------------------------------------------------------------------
1 | ## Evaluation
2 |
3 | - For Market-1501
4 | ```bash
5 | python test_2label.py --name E0.5new_reid0.5_w30000 --which_epoch 100000 --multi
6 | ```
7 | The result is `Rank@1:0.9477 Rank@5:0.9798 Rank@10:0.9878 mAP:0.8609`.
8 | `--name` experiment name: the folder under `../outputs/` that holds `config.yaml` and the `checkpoints/` directory.
9 |
10 | `--which_epoch` selects the checkpoint saved at training iteration i, i.e. `checkpoints/id_%08d.pt` (e.g. `100000`).
11 |
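12 | `--multi` also extracts features for the `multi-query` set and evaluates with multiple queries (all multi-query images of the same identity and camera are averaged into one query). The result is `multi Rank@1:0.9608 Rank@5:0.9860 Rank@10:0.9923 mAP:0.9044`.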
13 |
--------------------------------------------------------------------------------
/reid_eval/evaluate_gpu.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
4 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
5 | """
6 |
7 | import scipy.io
8 | import torch
9 | import numpy as np
10 | import os
11 | import matplotlib
12 | matplotlib.use('agg')
13 | import matplotlib.pyplot as plt
14 | #######################################################################
15 | # Evaluate
16 |
def evaluate(qf, ql, qc, gf, gl, gc):
    """Rank the gallery for one query and score the ranking.

    Args:
        qf: query feature; flattened to a column and scored against every row
            of ``gf`` by dot product.
        ql, qc: query label and camera id.
        gf: gallery feature matrix, one row per gallery image.
        gl, gc: gallery labels and camera ids (numpy arrays).

    Returns:
        ``compute_mAP``'s result, ``(ap, cmc)``.
    """
    scores = torch.mm(gf, qf.view(-1, 1)).squeeze(1).cpu().numpy()
    # gallery indices from highest to lowest score
    index = np.argsort(scores)[::-1]
    same_id = np.argwhere(gl == ql)
    same_cam = np.argwhere(gc == qc)
    # same identity seen by a different camera counts as a correct match
    good_index = np.setdiff1d(same_id, same_cam, assume_unique=True)
    # same identity + same camera, and label -1 images, are ignored
    junk_index = np.append(np.intersect1d(same_id, same_cam), np.argwhere(gl == -1))
    return compute_mAP(index, qc, good_index, junk_index)
37 |
38 |
def compute_mAP(index, qc, good_index, junk_index):
    """Compute average precision and the CMC curve for one ranked query.

    The function is self-contained. The previous version read the module
    global ``gallery_cam`` to build a ranking that was never used.

    Args:
        index: gallery indices ordered from best to worst match (numpy array).
        qc: query camera id; unused, kept for call compatibility.
        good_index: gallery indices that count as correct matches.
        junk_index: gallery indices removed from the ranking before scoring.

    Returns:
        ``(ap, cmc)``. ``cmc`` is an IntTensor of length ``len(index)`` that
        is 1 from the first correct match onwards. If ``good_index`` is
        empty, ``ap`` is 0 and ``cmc[0]`` is -1.
    """
    ap = 0
    cmc = torch.IntTensor(len(index)).zero_()
    if good_index.size == 0:  # no valid match for this query
        cmc[0] = -1
        return ap, cmc

    # remove junk entries from the ranking
    index = index[np.isin(index, junk_index, invert=True)]

    # positions of the correct matches in the cleaned ranking
    ngood = len(good_index)
    rows_good = np.argwhere(np.isin(index, good_index)).flatten()

    cmc[rows_good[0]:] = 1
    d_recall = 1.0 / ngood
    for i, row in enumerate(rows_good):
        precision = (i + 1) * 1.0 / (row + 1)
        # precision just before this hit (1.0 when the hit is at rank 0)
        old_precision = i * 1.0 / row if row != 0 else 1.0
        ap = ap + d_recall * (old_precision + precision) / 2

    return ap, cmc
70 |
71 | ######################################################################
# Single-query features, labels and cameras written by test_2label.py.
# loadmat returns the label/camera lists as 2-D arrays, hence the [0] row.
result = scipy.io.loadmat('pytorch_result.mat')
query_cam, query_label = result['query_cam'][0], result['query_label'][0]
gallery_cam, gallery_label = result['gallery_cam'][0], result['gallery_label'][0]
query_feature = torch.FloatTensor(result['query_f']).cuda()
gallery_feature = torch.FloatTensor(result['gallery_f']).cuda()

# Multi-query data is optional; it exists only if test_2label.py ran with --multi.
multi = os.path.isfile('multi_query.mat')
if multi:
    m_result = scipy.io.loadmat('multi_query.mat')
    mquery_cam, mquery_label = m_result['mquery_cam'][0], m_result['mquery_label'][0]
    mquery_feature = torch.FloatTensor(m_result['mquery_f']).cuda()

print(query_feature.shape)

# Each weight rescales the second 512-dim half of the query feature;
# -1 instead zeroes the first half.
alpha = [0, 0.5, -1]
for weight in alpha:
    CMC = torch.IntTensor(len(gallery_label)).zero_()
    ap = 0.0
    for qi in range(len(query_label)):
        qf = query_feature[qi].clone()
        if weight == -1:
            qf[0:512] *= 0
        else:
            qf[512:1024] *= weight
        ap_tmp, CMC_tmp = evaluate(qf, query_label[qi], query_cam[qi],
                                   gallery_feature, gallery_label, gallery_cam)
        if CMC_tmp[0] == -1:  # query without any valid gallery match
            continue
        CMC = CMC + CMC_tmp
        ap += ap_tmp

    CMC = CMC.float() / len(query_label)  # average CMC
    print('Alpha:%.2f Rank@1:%.4f Rank@5:%.4f Rank@10:%.4f mAP:%.4f' % (weight, CMC[0], CMC[4], CMC[9], ap / len(query_label)))

# multiple-query: average all multi-query features sharing the query's
# identity and camera into a single query feature
CMC = torch.IntTensor(len(gallery_label)).zero_()
ap = 0.0
if multi:
    malpha = 0.5  ######
    for qi in range(len(query_label)):
        same_label = np.argwhere(mquery_label == query_label[qi])
        same_cam = np.argwhere(mquery_cam == query_cam[qi])
        mquery_index = np.intersect1d(same_label, same_cam)
        mq = torch.mean(mquery_feature[mquery_index, :], dim=0)
        mq[512:1024] *= malpha
        ap_tmp, CMC_tmp = evaluate(mq, query_label[qi], query_cam[qi],
                                   gallery_feature, gallery_label, gallery_cam)
        if CMC_tmp[0] == -1:
            continue
        CMC = CMC + CMC_tmp
        ap += ap_tmp
    CMC = CMC.float() / len(query_label)  # average CMC
    print('multi Rank@1:%.4f Rank@5:%.4f Rank@10:%.4f mAP:%.4f' % (CMC[0], CMC[4], CMC[9], ap / len(query_label)))
132 |
--------------------------------------------------------------------------------
/reid_eval/test_2label.py:
--------------------------------------------------------------------------------
1 |
2 | """
3 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
4 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
5 | """
6 |
7 | from __future__ import print_function, division
8 |
9 | import sys
10 | sys.path.append('..')
11 | import argparse
12 | import torch
13 | import torch.nn as nn
14 | import torch.optim as optim
15 | from torch.optim import lr_scheduler
16 | from torch.autograd import Variable
17 | import numpy as np
18 | import torchvision
19 | from torchvision import datasets, models, transforms
20 | import time
21 | import os
22 | import scipy.io
23 | import yaml
24 | from reIDmodel import ft_net, ft_netAB, ft_net_dense, PCB, PCB_test
25 |
26 | ######################################################################
27 | # Options
28 | # --------
parser = argparse.ArgumentParser(description='Training')
parser.add_argument('--gpu_ids', default='0', type=str, help='gpu_ids: e.g. 0 0,1,2 0,2')
parser.add_argument('--which_epoch', default=90000, type=int, help='80000')
parser.add_argument('--test_dir', default='../../Market/pytorch', type=str, help='./test_data')
parser.add_argument('--name', default='test', type=str, help='save model path')
parser.add_argument('--batchsize', default=80, type=int, help='batchsize')
parser.add_argument('--use_dense', action='store_true', help='use densenet121')
parser.add_argument('--PCB', action='store_true', help='use PCB')
parser.add_argument('--multi', action='store_true', help='use multiple query')

opt = parser.parse_args()

str_ids = opt.gpu_ids.split(',')
which_epoch = opt.which_epoch
name = opt.name
test_dir = opt.test_dir

# Negative ids are dropped. A comprehension avoids rebinding the builtin
# `id` at module scope, as the previous loop did.
gpu_ids = [int(str_id) for str_id in str_ids if int(str_id) >= 0]

# set gpu ids
if len(gpu_ids) > 0:
    torch.cuda.set_device(gpu_ids[0])

######################################################################
# Load Data
# ---------
# torchvision ImageFolder datasets and DataLoaders for each split.
# PCB uses a larger input resolution than the default model.
input_size = (384, 192) if opt.PCB else (256, 128)
data_transforms = transforms.Compose([
    transforms.Resize(input_size, interpolation=3),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

data_dir = test_dir
splits = ['gallery', 'query', 'multi-query']
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms) for x in splits}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize,
                                              shuffle=False, num_workers=16) for x in splits}

class_names = image_datasets['query'].classes
use_gpu = torch.cuda.is_available()
84 |
85 | ######################################################################
86 | # Load model
87 | #---------------------------
def load_network(network):
    """Load the ID-network weights of checkpoint ``opt.which_epoch`` into *network*.

    Reads ``../outputs/<name>/checkpoints/id_<which_epoch:08d>.pt`` and loads
    its ``'a'`` entry non-strictly, so keys missing on either side are
    ignored. Returns *network*.
    """
    checkpoint_file = 'checkpoints/id_%08d.pt' % opt.which_epoch
    checkpoint = torch.load(os.path.join('../outputs', name, checkpoint_file))
    network.load_state_dict(checkpoint['a'], strict=False)
    return network
93 |
94 |
95 | ######################################################################
96 | # Extract feature
97 | # ----------------------
98 | #
99 | # Extract feature from a trained model.
100 | #
def fliplr(img):
    '''flip horizontal'''
    # reverse the last (width) axis of an N x C x H x W batch
    width = img.size(3)
    reversed_cols = torch.arange(width - 1, -1, -1).long()
    return img.index_select(3, reversed_cols)
106 |
def norm(f):
    """L2-normalise each row of a batch of features.

    Size-1 dimensions other than the batch dimension are squeezed first, e.g.
    ``(N, C, 1, 1) -> (N, C)``. The batch dimension is kept even when
    ``N == 1``; the previous ``f.squeeze()`` dropped it and then failed on
    ``dim=1``. The row norm is clamped to a tiny minimum, so all-zero rows
    stay zero instead of becoming NaN.
    """
    for dim in range(f.dim() - 1, 0, -1):
        if f.size(dim) == 1:
            f = f.squeeze(dim)
    fnorm = torch.norm(f, p=2, dim=1, keepdim=True).clamp(min=1e-12)
    return f.div(fnorm.expand_as(f))
112 |
def extract_feature(model, dataloaders):
    """Extract flip-averaged two-head features for every image of one loader.

    For each batch, the model runs on the images and on their horizontal
    flips. For each pass, the two heads ``x[0]`` and ``x[1]`` of the model's
    second output are L2-normalised and concatenated. The two passes are
    summed, then columns 0:512 and 512:1024 are re-normalised separately.

    Args:
        model: network returning ``(f, [head0, head1])``, such as ft_netAB with
            its final classifiers removed. Each head is assumed to be 512-dim,
            as the original 1024-wide buffer implied.
        dataloaders: a single iterable of ``(img, label)`` batches, despite
            the name.

    Returns:
        CPU float tensor with one row per image, in loader order.

    Raises:
        NotImplementedError: if ``--PCB`` is set. The previous PCB branch
            added a ``(n, 2048, 6)`` buffer to this two-head output and always
            failed with a shape error.
    """
    if opt.PCB:
        raise NotImplementedError(
            'PCB feature extraction is not supported: extract_feature expects the two-head output of ft_netAB')

    chunks = []
    for img, _ in dataloaders:
        ff = None
        for flipped in (False, True):
            if flipped:
                img = fliplr(img)
            _, x = model(img.cuda())
            # normalise each head before concatenating (2 x 512 -> 1024)
            f = torch.cat((norm(x[0]), norm(x[1])), dim=1).data.cpu().float()
            ff = f if ff is None else ff + f

        # re-normalise each head after summing the plain and flipped passes
        ff[:, 0:512] = norm(ff[:, 0:512])
        ff[:, 512:1024] = norm(ff[:, 512:1024])
        chunks.append(ff)

    # a single concatenation instead of growing the result batch by batch
    if not chunks:
        return torch.FloatTensor()
    return torch.cat(chunks, 0)
151 |
def get_id(img_path):
    """Parse camera ids and person labels from ``(path, class_index)`` pairs.

    Filenames look like ``0002_c1s1_000451_03.jpg``. The first four
    characters are the label; a name starting with ``-1`` gets label -1. The
    camera id is the digit right after the first ``'c'``.

    Returns:
        ``(camera_id, labels)`` as two lists of ints, in input order.
    """
    camera_id = []
    labels = []
    for path, _ in img_path:
        stem = os.path.basename(path)
        cam_field = stem.split('c')[1]
        labels.append(-1 if stem[:2] == '-1' else int(stem[:4]))
        camera_id.append(int(cam_field[0]))
    return camera_id, labels
166 |
gallery_path = image_datasets['gallery'].imgs
query_path = image_datasets['query'].imgs
mquery_path = image_datasets['multi-query'].imgs

gallery_cam, gallery_label = get_id(gallery_path)
query_cam, query_label = get_id(query_path)
mquery_cam, mquery_label = get_id(mquery_path)

######################################################################
# Load Collected data Trained model
print('-------test-----------')

###load config###
config_path = os.path.join('../outputs', name, 'config.yaml')
with open(config_path, 'r') as stream:
    config = yaml.safe_load(stream)

# Build only the network that is evaluated. Previously an ft_netAB, with its
# pretrained backbone, was built and then discarded when --PCB was set.
if opt.PCB:
    model_structure = PCB(config['ID_class'])
else:
    model_structure = ft_netAB(config['ID_class'], norm=config['norm_id'],
                               stride=config['ID_stride'], pool=config['pool'])

model = load_network(model_structure)

# Remove the final fc layer and classifier layer
model.model.fc = nn.Sequential()
model.classifier1.classifier = nn.Sequential()
model.classifier2.classifier = nn.Sequential()

# Change to test mode
model = model.eval()
if use_gpu:
    model = model.cuda()

# Extract feature
since = time.time()
with torch.no_grad():
    gallery_feature = extract_feature(model, dataloaders['gallery'])
    query_feature = extract_feature(model, dataloaders['query'])
    time_elapsed = time.time() - since
    print('Extract features complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    if opt.multi:
        mquery_feature = extract_feature(model, dataloaders['multi-query'])

# Save to Matlab for check
result = {'gallery_f': gallery_feature.numpy(), 'gallery_label': gallery_label, 'gallery_cam': gallery_cam,
          'query_f': query_feature.numpy(), 'query_label': query_label, 'query_cam': query_cam}
scipy.io.savemat('pytorch_result.mat', result)
if opt.multi:
    result = {'mquery_f': mquery_feature.numpy(), 'mquery_label': mquery_label, 'mquery_cam': mquery_cam}
    scipy.io.savemat('multi_query.mat', result)
elif os.path.isfile('multi_query.mat'):
    # evaluate_gpu.py runs multi-query evaluation whenever this file exists.
    # Remove one left by an earlier --multi run so features from a different
    # model are not evaluated.
    os.remove('multi_query.mat')

os.system('python evaluate_gpu.py')
220 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | from utils import get_all_data_loaders, prepare_sub_folder, write_loss, get_config, write_2images, Timer
6 | import argparse
7 | from trainer import DGNet_Trainer
8 | import torch.backends.cudnn as cudnn
9 | import torch
10 | import numpy.random as random
11 | try:
12 | from itertools import izip as zip
13 | except ImportError: # will be 3.x series
14 | pass
15 | import os
16 | import sys
17 | import tensorboardX
18 | import shutil
19 |
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, default='configs/latest.yaml', help='Path to the config file.')
parser.add_argument('--output_path', type=str, default='.', help="outputs path")
parser.add_argument('--name', type=str, default='latest_ablation', help="outputs path")
parser.add_argument("--resume", action="store_true")
parser.add_argument('--trainer', type=str, default='DGNet', help="DGNet")
parser.add_argument('--gpu_ids', default='0', type=str, help='gpu_ids: e.g. 0 0,1,2 0,2')
opts = parser.parse_args()

str_ids = opts.gpu_ids.split(',')
gpu_ids = [int(str_id) for str_id in str_ids]
num_gpu = len(gpu_ids)

cudnn.benchmark = True

# Load experiment setting. A resumed run reads the config.yaml copied into
# <output_path>/outputs/<name>/, the same directory its checkpoints are
# resumed from. The previous hard-coded './outputs/' ignored --output_path;
# it is identical for the default output_path '.'.
if opts.resume:
    config = get_config(os.path.join(opts.output_path, 'outputs', opts.name, 'config.yaml'))
else:
    config = get_config(opts.config)
max_iter = config['max_iter']
display_size = config['display_size']
config['vgg_model_path'] = opts.output_path

# Setup model and data loader. Fail early on an unknown trainer instead of
# hitting a NameError on `trainer` later.
if opts.trainer != 'DGNet':
    sys.exit('Unsupported trainer: %s (only DGNet is available)' % opts.trainer)
trainer = DGNet_Trainer(config, gpu_ids)
trainer.cuda()

random.seed(7)  # fix random result
train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config)
train_a_rand = random.permutation(train_loader_a.dataset.img_num)[0:display_size]
train_b_rand = random.permutation(train_loader_b.dataset.img_num)[0:display_size]
test_a_rand = random.permutation(test_loader_a.dataset.img_num)[0:display_size]
test_b_rand = random.permutation(test_loader_b.dataset.img_num)[0:display_size]


def _display_batch(loader, picks, slot):
    """Stack field ``slot`` of the selected samples of ``loader.dataset`` into one CUDA batch."""
    return torch.stack([loader.dataset[i][slot] for i in picks]).cuda()


# Slot 0 is the image and slot 2 the paired `pos` image. The call order is
# kept fixed because dataset access may consume the seeded RNG.
train_display_images_a = _display_batch(train_loader_a, train_a_rand, 0)
train_display_images_ap = _display_batch(train_loader_a, train_a_rand, 2)
train_display_images_b = _display_batch(train_loader_b, train_b_rand, 0)
train_display_images_bp = _display_batch(train_loader_b, train_b_rand, 2)
test_display_images_a = _display_batch(test_loader_a, test_a_rand, 0)
test_display_images_ap = _display_batch(test_loader_a, test_a_rand, 2)
test_display_images_b = _display_batch(test_loader_b, test_b_rand, 0)
test_display_images_bp = _display_batch(test_loader_b, test_b_rand, 2)

# Setup logger and output folders
if not opts.resume:
    model_name = os.path.splitext(os.path.basename(opts.config))[0]
    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", model_name))
    output_directory = os.path.join(opts.output_path + "/outputs", model_name)
    checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
    shutil.copyfile(opts.config, os.path.join(output_directory, 'config.yaml'))  # copy config file to output folder
    # snapshot the model/trainer sources next to the run
    for source_file in ('trainer.py', 'reIDmodel.py', 'networks.py'):
        shutil.copyfile(source_file, os.path.join(output_directory, source_file))
else:
    train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", opts.name))
    output_directory = os.path.join(opts.output_path + "/outputs", opts.name)
    checkpoint_directory, image_directory = prepare_sub_folder(output_directory)
# Start training
iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0
config['epoch_iteration'] = round(train_loader_a.dataset.img_num / config['batch_size'])
print('Every epoch need %d iterations' % config['epoch_iteration'])
nepoch = 0
86 |
print('Note that dataloader may hang with too much nworkers.')

if num_gpu > 1:
    print('Now you are using %d gpus.' % num_gpu)
    trainer.dis_a = torch.nn.DataParallel(trainer.dis_a, gpu_ids)
    trainer.dis_b = trainer.dis_a
    trainer = torch.nn.DataParallel(trainer, gpu_ids)

# `core` is the bare DGNet_Trainer whether or not it is wrapped in
# DataParallel. forward() still goes through `trainer` so it can be scattered.
core = trainer.module if num_gpu > 1 else trainer

while True:
    for (images_a, labels_a, pos_a), (images_b, labels_b, pos_b) in zip(train_loader_a, train_loader_b):
        core.update_learning_rate()
        images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach()
        pos_a, pos_b = pos_a.cuda().detach(), pos_b.cuda().detach()
        labels_a, labels_b = labels_a.cuda().detach(), labels_b.cuda().detach()

        with Timer("Elapsed time in update: %f"):
            # Main training code: discriminator step, then generator step
            (x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
             x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p) = trainer.forward(images_a, images_b, pos_a, pos_b)
            core.dis_update(x_ab.clone(), x_ba.clone(), images_a, images_b, config, num_gpu=num_gpu)
            core.gen_update(x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b,
                            x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p,
                            images_a, images_b, pos_a, pos_b, labels_a, labels_b,
                            config, iterations, num_gpu=num_gpu)
            torch.cuda.synchronize()

        step = iterations + 1

        # Dump training stats in log file
        if step % config['log_iter'] == 0:
            print("\033[1m Epoch: %02d Iteration: %08d/%08d \033[0m" % (nepoch, step, max_iter), end=" ")
            write_loss(iterations, core, train_writer)

        # Write images
        if step % config['image_save_iter'] == 0:
            with torch.no_grad():
                test_image_outputs = core.sample(test_display_images_a, test_display_images_b)
            write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % step)
            del test_image_outputs

        if step % config['image_display_iter'] == 0:
            with torch.no_grad():
                image_outputs = core.sample(train_display_images_a, train_display_images_b)
            write_2images(image_outputs, display_size, image_directory, 'train_%08d' % step)
            del image_outputs

        # Save network weights
        if step % config['snapshot_save_iter'] == 0:
            core.save(checkpoint_directory, iterations)

        iterations += 1
        if iterations >= max_iter:
            sys.exit('Finish training')

    # Save network weights by epoch number
    nepoch = nepoch + 1
    if (nepoch + 1) % 10 == 0:
        core.save(checkpoint_directory, iterations)
162 |
163 |
--------------------------------------------------------------------------------
/trainer.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 | from networks import AdaINGen, MsImageDis
6 | from reIDmodel import ft_net, ft_netAB, PCB
7 | from utils import get_model_list, vgg_preprocess, load_vgg16, get_scheduler
8 | from torch.autograd import Variable
9 | import torch
10 | import torch.nn as nn
11 | import copy
12 | import os
13 | import cv2
14 | import numpy as np
15 | from random_erasing import RandomErasing
16 | import random
17 | import yaml
18 |
19 | #fp16
20 | try:
21 | from apex import amp
22 | from apex.fp16_utils import *
23 | except ImportError:
24 | print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0')
25 |
26 |
def to_gray(half=False):  # simple
    """Return a function that averages a batch over dim 1 (channels for NCHW input).

    The returned function keeps dim 1 as size 1 and casts the result to
    float16 when ``half`` is true.
    """
    def forward(x):
        gray = x.mean(dim=1, keepdim=True)
        return gray.half() if half else gray
    return forward
34 |
def to_edge(x):
    """Compute a noisy Canny edge map for each image of a normalised batch.

    Each image is de-normalised with ``recover``, converted to grayscale, and
    passed through ``cv2.Canny(…, 10, 200)``. The result is rescaled to
    {-0.5, 0.5} and Gaussian noise (std 0.1) is added. Returns an
    ``(N, 1, H, W)`` float tensor moved to the GPU.
    """
    batch = x.data.cpu()
    out = torch.FloatTensor(batch.size(0), batch.size(2), batch.size(3))
    for i in range(batch.size(0)):
        gray = cv2.cvtColor(recover(batch[i]), cv2.COLOR_RGB2GRAY)
        edges = cv2.Canny(gray, 10, 200) / 255.0 - 0.5
        edges += np.random.randn(edges.shape[0], edges.shape[1]) * 0.1  # add random noise
        out[i, :, :] = torch.from_numpy(edges.astype(np.float32))
    return out.unsqueeze(1).cuda()
48 |
def scale2(x):
    """Upsample a 4-D batch 2x with nearest-neighbour interpolation.

    The input is returned unchanged when ``x.size(2)`` (the height for NCHW
    input) is already greater than 128.
    """
    if x.size(2) > 128:  # do not need to scale the input
        return x
    # `upsample` is deprecated; `interpolate` gives the same nearest-neighbour result
    return torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')
54 |
def recover(inp):
    """Undo ImageNet mean/std normalisation of a (C, H, W) tensor.

    Returns an (H, W, C) uint8 numpy image with values clipped to [0, 255]
    and truncated to integers.
    """
    img = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img = np.clip((img * std + mean) * 255.0, 0, 255)
    return img.astype(np.uint8)
64 |
def train_bn(m):
    """Switch *m* to training mode if its class name contains 'BatchNorm'.

    Intended for ``module.apply(train_bn)``; other modules are left unchanged.
    """
    if 'BatchNorm' in type(m).__name__:
        m.train()
69 |
def fliplr(img):
    '''flip horizontal'''
    # Build the reversed column index on the image's own device. The
    # hard-coded .cuda() failed for CPU tensors and for tensors on a
    # non-current GPU.
    inv_idx = torch.arange(img.size(3) - 1, -1, -1, device=img.device).long()  # N x C x H x W
    return img.index_select(3, inv_idx)
75 |
def update_teacher(model_s, model_t, alpha=0.999):
    """Move the teacher's parameters toward the student's in place.

    For each parameter pair (matched by order):
    ``t = alpha * t + (1 - alpha) * s``.
    """
    for param_s, param_t in zip(model_s.parameters(), model_t.parameters()):
        # keyword `alpha=` form; the positional add_(scalar, tensor) overload is deprecated
        param_t.data.mul_(alpha).add_(param_s.data, alpha=1 - alpha)
79 |
def _ensemble_softmax(teacher_models, inputs, alpha):
    """Sum each teacher's softmax on ``inputs`` and on their horizontal flip.

    The running sum is multiplied by ``alpha`` (old model decay) before each
    later teacher is added. Raises ValueError if there are no teachers, or if
    there are several teachers and ``alpha`` is None.
    """
    sm = nn.Softmax(dim=1)
    outputs_t = None
    for teacher_model in teacher_models:
        _, out_plain = teacher_model(inputs)
        _, out_flip = teacher_model(fliplr(inputs))
        current = sm(out_plain.detach()) + sm(out_flip.detach())
        if outputs_t is None:
            outputs_t = current
        else:
            if alpha is None:
                raise ValueError('alpha (old model decay) is required when more than one teacher model is given')
            outputs_t = outputs_t * alpha + current
    if outputs_t is None:
        raise ValueError('teacher_models must contain at least one model')
    return outputs_t


def predict_label(teacher_models, inputs, num_class, alabel, slabel, teacher_style=0, alpha=None):
    """Build a soft target distribution (rows sum to 1) for ``inputs``.

    teacher_style:
        0: our smooth dynamic label (ensemble softmax of the teachers)
        1: pseudo label, hard dynamic label (one-hot argmax of the ensemble)
        2: conditional label, hard static label (one-hot ``alabel``)
        3: LSRO, static smooth label (uniform over ``num_class``)
        4: dynamic soft two-label (ensemble restricted to ``alabel``/``slabel``)

    Args:
        teacher_models: models returning ``(_, logits)``; used by styles 0, 1, 4.
        inputs: input batch; its first dimension is the batch size.
        num_class: number of classes for styles 1-3.
        alabel: appearance label per sample (styles 2 and 4).
        slabel: second label per sample (style 4).
        alpha: decay applied to the running ensemble before each additional
            teacher. The old code read an undefined global ``opt.alpha``, so
            it is now required when more than one teacher is given.

    Raises:
        ValueError: for an unknown ``teacher_style``. Previously the code
            printed a message and then failed with UnboundLocalError.
    """
    n = inputs.size(0)
    if teacher_style == 0:
        outputs_t = _ensemble_softmax(teacher_models, inputs, alpha)
    elif teacher_style == 1:  # dynamic one-hot label
        probs = _ensemble_softmax(teacher_models, inputs, alpha)
        _, dlabel = torch.max(probs.data, 1)
        outputs_t = torch.zeros(n, num_class, device=inputs.device)
        for i in range(n):
            outputs_t[i, dlabel[i]] = 1
    elif teacher_style == 2:  # appearance label
        outputs_t = torch.zeros(n, num_class, device=inputs.device)
        for i in range(n):
            outputs_t[i, alabel[i]] = 1
    elif teacher_style == 3:  # LSRO
        outputs_t = torch.ones(n, num_class, device=inputs.device)
    elif teacher_style == 4:  # two-label
        outputs_t = _ensemble_softmax(teacher_models, inputs, alpha)
        mask = torch.zeros(outputs_t.shape, device=outputs_t.device)
        for i in range(n):
            mask[i, alabel[i]] = 1
            mask[i, slabel[i]] = 1
        outputs_t = outputs_t * mask
    else:
        raise ValueError('not valid style. teacher-style is in [0-4].')

    # normalise every row to sum to one
    s = torch.sum(outputs_t, dim=1, keepdim=True)
    return outputs_t / s.expand_as(outputs_t)
153 |
154 | ######################################################################
155 | # Load model
156 | #---------------------------
def load_network(network, name):
    """Load ``./models/<name>/net_last.pth`` into ``network`` in place.

    Args:
        network: module whose ``load_state_dict`` receives the checkpoint.
        name: sub-directory of ``./models`` holding the checkpoint.

    Returns:
        The same ``network`` object, for call chaining.
    """
    checkpoint_file = os.path.join('./models', name, 'net_last.pth')
    weights = torch.load(checkpoint_file)
    network.load_state_dict(weights)
    return network
161 |
def load_config(name):
    """Return the parsed ``./models/<name>/opts.yaml`` (via ``yaml.safe_load``)."""
    opts_file = os.path.join('./models', name, 'opts.yaml')
    with open(opts_file, 'r') as handle:
        return yaml.safe_load(handle)
167 |
168 |
class DGNet_Trainer(nn.Module):
    """Joint trainer for the DG-Net generator, discriminator and re-ID encoder.

    The two "domains" share one set of networks: ``gen_b``, ``dis_b`` and
    ``id_b`` are aliases of ``gen_a``, ``dis_a`` and ``id_a``.  When
    ``hyperparameters['teacher']`` is non-empty, one or more teacher re-ID
    models are loaded from ``./models/<name>`` (comma-separated names), put in
    eval mode, and used through ``predict_label`` as targets for the student.

    Args:
        hyperparameters: configuration dict.  ``__init__`` writes defaults
            back into it ('apex', 'ID_stride', 'T_w') and ``gen_update``
            increases the warm-up loss weights in it in place.
        gpu_ids: unused; kept for backward compatibility.
    """

    def __init__(self, hyperparameters, gpu_ids=None):
        super(DGNet_Trainer, self).__init__()
        lr_g = hyperparameters['lr_g']
        lr_d = hyperparameters['lr_d']
        ID_class = hyperparameters['ID_class']
        hyperparameters.setdefault('apex', False)
        self.fp16 = hyperparameters['apex']

        # Initiate the networks.
        # The new apex does not need fp16 set inside the networks, so fp16=False;
        # mixed precision is applied with amp.initialize at the end of __init__.
        self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16=False)  # auto-encoder for domain a
        self.gen_b = self.gen_a  # auto-encoder for domain b (shared weights)

        hyperparameters.setdefault('ID_stride', 2)

        if hyperparameters['ID_style'] == 'PCB':
            self.id_a = PCB(ID_class)
        elif hyperparameters['ID_style'] == 'AB':
            self.id_a = ft_netAB(ID_class, stride=hyperparameters['ID_stride'],
                                 norm=hyperparameters['norm_id'], pool=hyperparameters['pool'])
        else:
            self.id_a = ft_net(ID_class, norm=hyperparameters['norm_id'],
                               pool=hyperparameters['pool'])  # return 2048 now
        self.id_b = self.id_a

        self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16=False)  # discriminator for domain a
        self.dis_b = self.dis_a  # discriminator for domain b (shared weights)

        # Load teachers: comma-separated model names under ./models.
        if hyperparameters['teacher'] != "":
            teacher_name = hyperparameters['teacher']
            print(teacher_name)
            teacher_model = nn.ModuleList()
            for name in teacher_name.split(','):
                config_tmp = load_config(name)
                stride = config_tmp.get('stride', 2)
                teacher_model_tmp = load_network(ft_net(ID_class, stride=stride), name)
                teacher_model_tmp.model.fc = nn.Sequential()  # remove the original fc layer in ImageNet
                teacher_model_tmp = teacher_model_tmp.cuda()
                if self.fp16:
                    teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1")
                teacher_model.append(teacher_model_tmp.cuda().eval())
            self.teacher_model = teacher_model
            if hyperparameters['train_bn']:
                self.teacher_model = self.teacher_model.apply(train_bn)

        self.instancenorm = nn.InstanceNorm2d(512, affine=False)

        # RGB to one channel
        if hyperparameters['single'] == 'edge':
            self.single = to_edge
        else:
            self.single = to_gray(False)

        # Random Erasing when training (probability 0 disables it)
        self.erasing_p = hyperparameters.get('erasing_p', 0)
        self.single_re = RandomErasing(probability=self.erasing_p, mean=[0.0, 0.0, 0.0])

        hyperparameters.setdefault('T_w', 1)

        # Setup the optimizers (dis_b / gen_b share these parameters)
        beta1 = hyperparameters['beta1']
        beta2 = hyperparameters['beta2']
        dis_params = list(self.dis_a.parameters())
        gen_params = list(self.gen_a.parameters())
        self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad],
                                        lr=lr_d, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad],
                                        lr=lr_g, betas=(beta1, beta2),
                                        weight_decay=hyperparameters['weight_decay'])
        self.id_opt = self._build_id_optimizer(hyperparameters)

        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters)
        # get_scheduler may return None (update_learning_rate checks for it);
        # only a real scheduler gets the ID-specific decay factor.
        if self.id_scheduler is not None:
            self.id_scheduler.gamma = hyperparameters['gamma2']

        # ID loss
        self.id_criterion = nn.CrossEntropyLoss()
        # Summed KL divergence; gen_update divides by the batch size itself.
        self.criterion_teacher = nn.KLDivLoss(reduction='sum')
        # Load VGG model if needed (frozen, eval mode)
        if hyperparameters.get('vgg_w', 0) > 0:
            self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models')
            self.vgg.eval()
            for param in self.vgg.parameters():
                param.requires_grad = False

        # Mixed precision with apex amp to save memory.
        if self.fp16:
            assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled."
            self.gen_a = self.gen_a.cuda()
            self.dis_a = self.dis_a.cuda()
            self.id_a = self.id_a.cuda()

            self.gen_b = self.gen_a
            self.dis_b = self.dis_a
            self.id_b = self.id_a

            self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1")
            self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1")
            self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1")

    def _build_id_optimizer(self, hyperparameters):
        """Create the SGD optimizer for ``id_a``: backbone at ``lr2``, classifier heads at ``10 * lr2``."""
        if hyperparameters['ID_style'] == 'PCB':
            heads = [self.id_a.classifier0, self.id_a.classifier1,
                     self.id_a.classifier2, self.id_a.classifier3]
        elif hyperparameters['ID_style'] == 'AB':
            heads = [self.id_a.classifier1, self.id_a.classifier2]
        else:
            heads = [self.id_a.classifier]
        head_param_ids = {id(p) for head in heads for p in head.parameters()}
        base_params = [p for p in self.id_a.parameters() if id(p) not in head_param_ids]
        lr2 = hyperparameters['lr2']
        param_groups = [{'params': base_params, 'lr': lr2}]
        param_groups += [{'params': head.parameters(), 'lr': lr2 * 10} for head in heads]
        return torch.optim.SGD(param_groups, weight_decay=hyperparameters['weight_decay'],
                               momentum=0.9, nesterov=True)

    def to_re(self, x):
        """Apply ``self.single_re`` (random erasing) to every sample of a 4-D batch.

        Returns a new CUDA float tensor of the same size as ``x``.
        """
        out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3))
        out = out.cuda()
        for i in range(x.size(0)):
            out[i, :, :, :] = self.single_re(x[i, :, :, :])
        return out

    def recon_criterion(self, input, target):
        """Mean absolute error; no gradient flows into ``target``."""
        return torch.mean(torch.abs(input - target.detach()))

    def recon_criterion_sqrt(self, input, target):
        """Mean of sqrt(|input - target| + 1e-8)."""
        diff = input - target
        return torch.mean(torch.sqrt(torch.abs(diff) + 1e-8))

    def recon_criterion2(self, input, target):
        """Mean squared error."""
        diff = input - target
        return torch.mean(diff ** 2)

    def recon_cos(self, input, target):
        """Mean cosine distance (1 - cosine similarity) along dim 1."""
        cos = torch.nn.CosineSimilarity()
        cos_dis = 1 - cos(input, target)
        return torch.mean(cos_dis)

    def forward(self, x_a, x_b, xp_a, xp_b):
        """Encode both batches and decode swapped and reconstructed images.

        Structure codes (``s_*``) come from the generator encoder applied to
        ``self.single`` of the input.  Appearance features (``f_*``) and ID
        logits (``p_*``) come from the ID network on ``scale2`` of the input.
        ``x_ba`` combines b's structure with a's appearance (and vice versa for
        ``x_ab``).  ``xp_a`` / ``xp_b`` are the same-person images whose features
        drive ``x_*_recon_p``.  With random erasing enabled, ``p_*`` and
        ``pp_*`` are recomputed from erased inputs.
        """
        s_a = self.gen_a.encode(self.single(x_a))
        s_b = self.gen_b.encode(self.single(x_b))
        f_a, p_a = self.id_a(scale2(x_a))
        f_b, p_b = self.id_b(scale2(x_b))
        x_ba = self.gen_a.decode(s_b, f_a)
        x_ab = self.gen_b.decode(s_a, f_b)
        x_a_recon = self.gen_a.decode(s_a, f_a)
        x_b_recon = self.gen_b.decode(s_b, f_b)
        fp_a, pp_a = self.id_a(scale2(xp_a))
        fp_b, pp_b = self.id_b(scale2(xp_b))
        # decode the same person
        x_a_recon_p = self.gen_a.decode(s_a, fp_a)
        x_b_recon_p = self.gen_b.decode(s_b, fp_b)

        # Random Erasing only affects the ID and PID loss.
        if self.erasing_p > 0:
            x_a_re = self.to_re(scale2(x_a.clone()))
            x_b_re = self.to_re(scale2(x_b.clone()))
            xp_a_re = self.to_re(scale2(xp_a.clone()))
            xp_b_re = self.to_re(scale2(xp_b.clone()))
            _, p_a = self.id_a(x_a_re)
            _, p_b = self.id_b(x_b_re)
            # encode the same ID different photo
            _, pp_a = self.id_a(xp_a_re)
            _, pp_b = self.id_b(xp_b_re)

        return x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p

    def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon,
                   x_a_recon_p, x_b_recon_p, x_a, x_b, xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu):
        """One joint optimisation step of the generator and the ID network.

        Takes the outputs of ``forward`` plus the real images (``x_a``,
        ``x_b``), the same-person images (``xp_a``, ``xp_b``) and the labels
        (``l_a``, ``l_b``).  It builds the weighted total loss, back-propagates
        it and steps ``gen_opt`` and ``id_opt``.

        Side effects: every loss term is stored on ``self``.  After
        ``warm_iter`` / ``warm_teacher_iter`` iterations, the reconstruction
        and teacher weights in ``hyperparameters`` grow by ``warm_scale`` per
        call, clamped at their maxima.
        """
        # pp_a, pp_b come from another photo of the same person
        self.gen_opt.zero_grad()
        self.id_opt.zero_grad()

        # Detached copies of the generated images (no gradient to the generator).
        x_ba_copy = x_ba.detach()
        x_ab_copy = x_ab.detach()

        rand_num = random.uniform(0, 1)
        #################################
        # encode structure
        if hyperparameters['use_encoder_again'] >= rand_num:
            # encode again (encoder is tuned, input is fixed)
            s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy))
            s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy))
        else:
            # copy the encoder
            self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content)
            self.enc_content_copy = self.enc_content_copy.eval()
            # encode again (encoder is fixed, input is tuned)
            s_a_recon = self.enc_content_copy(self.single(x_ab))
            s_b_recon = self.enc_content_copy(self.single(x_ba))

        #################################
        # encode appearance
        self.id_a_copy = copy.deepcopy(self.id_a)
        self.id_a_copy = self.id_a_copy.eval()
        if hyperparameters['train_bn']:
            self.id_a_copy = self.id_a_copy.apply(train_bn)
        self.id_b_copy = self.id_a_copy
        # encode again (encoder is fixed, input is tuned)
        f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba))
        f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab))

        # teacher loss: tune the ID model towards the teachers' labels
        log_sm = nn.LogSoftmax(dim=1)
        if hyperparameters['teacher_w'] > 0 and hyperparameters['teacher'] != "":
            if hyperparameters['ID_style'] == 'normal':
                _, p_a_student = self.id_a(scale2(x_ba_copy))
                p_a_student = log_sm(p_a_student)
                p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy), num_class=hyperparameters['ID_class'],
                                            alabel=l_a, slabel=l_b, teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_b_student = self.id_b(scale2(x_ab_copy))
                p_b_student = log_sm(p_b_student)
                p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy), num_class=hyperparameters['ID_class'],
                                            alabel=l_b, slabel=l_a, teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)
            elif hyperparameters['ID_style'] == 'AB':
                # normal teacher-student loss
                # BA -> LabelA(smooth) + LabelB(batchB)
                _, p_ba_student = self.id_a(scale2(x_ba_copy))  # f_a, s_b
                p_a_student = log_sm(p_ba_student[0])
                with torch.no_grad():
                    p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy), num_class=hyperparameters['ID_class'],
                                                alabel=l_a, slabel=l_b, teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0)

                _, p_ab_student = self.id_b(scale2(x_ab_copy))  # f_b, s_a
                p_b_student = log_sm(p_ab_student[0])
                with torch.no_grad():
                    p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy), num_class=hyperparameters['ID_class'],
                                                alabel=l_b, slabel=l_a, teacher_style=hyperparameters['teacher_style'])
                self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0)

                # branch b loss
                # here we give different label
                loss_B = self.id_criterion(p_ba_student[1], l_b) + self.id_criterion(p_ab_student[1], l_a)
                self.loss_teacher = hyperparameters['T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B
            # NOTE(review): any other ID_style (e.g. 'PCB') leaves self.loss_teacher unset here.
        else:
            self.loss_teacher = 0.0

        # auto-encoder image reconstruction
        self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a)
        self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b)
        self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a)
        self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b)

        # feature reconstruction
        self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0
        self.loss_gen_recon_f_a = self.recon_criterion(f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0
        self.loss_gen_recon_f_b = self.recon_criterion(f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0

        x_aba = self.gen_a.decode(s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None
        x_bab = self.gen_b.decode(s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None

        # ID loss AND tune the generated image
        if hyperparameters['ID_style'] == 'PCB':
            self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b)
            self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b)
            self.loss_gen_recon_id = self.PCB_loss(p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b)
        elif hyperparameters['ID_style'] == 'AB':
            weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w']
            self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \
                + weight_B * (self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b))
            self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(pp_b[0], l_b)
            self.loss_gen_recon_id = self.id_criterion(p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b)
        else:
            self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(p_b, l_b)
            self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(pp_b, l_b)
            self.loss_gen_recon_id = self.id_criterion(p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b)

        self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0
        # GAN loss (with several GPUs the discriminator's helpers live on `.module`)
        if num_gpu > 1:
            self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss(self.dis_b, x_ab)
        else:
            self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba)
            self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab)
        # domain-invariant perceptual loss; a missing 'vgg_w' means "disabled",
        # matching how __init__ decides whether to load VGG at all.
        vgg_w = hyperparameters.get('vgg_w', 0)
        self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if vgg_w > 0 else 0
        self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if vgg_w > 0 else 0

        # Warm-up: grow the loss weights a little every iteration, clamped at their maxima.
        if iteration > hyperparameters['warm_iter']:
            hyperparameters['recon_f_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w'])
            hyperparameters['recon_s_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w'])
            hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale']
            hyperparameters['recon_x_cyc_w'] = min(hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w'])

        if iteration > hyperparameters['warm_teacher_iter']:
            hyperparameters['teacher_w'] += hyperparameters['warm_scale']
            hyperparameters['teacher_w'] = min(hyperparameters['teacher_w'], hyperparameters['max_teacher_w'])
        # total loss
        self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \
            hyperparameters['gan_w'] * self.loss_gen_adv_b + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \
            hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \
            hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \
            hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \
            hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \
            hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \
            hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \
            hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \
            hyperparameters['id_w'] * self.loss_id + \
            hyperparameters['pid_w'] * self.loss_pid + \
            hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \
            vgg_w * self.loss_gen_vgg_a + \
            vgg_w * self.loss_gen_vgg_b + \
            hyperparameters['teacher_w'] * self.loss_teacher
        if self.fp16:
            with amp.scale_loss(self.loss_gen_total, [self.gen_opt, self.id_opt]) as scaled_loss:
                scaled_loss.backward()
        else:
            self.loss_gen_total.backward()
        self.gen_opt.step()
        self.id_opt.step()
        print("L_total: %.4f, L_gan: %.4f, Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f" % (
            self.loss_gen_total,
            hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b),
            hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b),
            hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b),
            hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b),
            hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b),
            hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b),
            hyperparameters['recon_id_w'] * self.loss_gen_recon_id,
            hyperparameters['id_w'] * self.loss_id,
            hyperparameters['pid_w'] * self.loss_pid,
            hyperparameters['teacher_w'] * self.loss_teacher))

    def compute_vgg_loss(self, vgg, img, target):
        """Mean squared difference of instance-normalised VGG features of ``img`` and ``target``."""
        img_vgg = vgg_preprocess(img)
        target_vgg = vgg_preprocess(target)
        img_fea = vgg(img_vgg)
        target_fea = vgg(target_vgg)
        return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2)

    def PCB_loss(self, inputs, labels):
        """Average cross-entropy over the per-part logits in ``inputs``."""
        loss = 0.0
        for part in inputs:
            loss += self.id_criterion(part, labels)
        return loss / len(inputs)

    def sample(self, x_a, x_b):
        """Produce reconstructions, swaps and cycle reconstructions one sample at a time.

        Runs with every submodule in eval mode.  Afterwards each submodule's
        previous train/eval flag is restored.  A blanket ``self.train()`` would
        also switch the teacher models and VGG (kept in eval mode by
        ``__init__``) into training mode.

        Returns:
            (x_a, x_a_recon, x_aba, x_ab, x_b, x_b_recon, x_bab, x_ba)
        """
        saved_modes = [(module, module.training) for module in self.modules()]
        self.eval()
        x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], []
        for i in range(x_a.size(0)):
            s_a = self.gen_a.encode(self.single(x_a[i].unsqueeze(0)))
            s_b = self.gen_b.encode(self.single(x_b[i].unsqueeze(0)))
            f_a, _ = self.id_a(scale2(x_a[i].unsqueeze(0)))
            f_b, _ = self.id_b(scale2(x_b[i].unsqueeze(0)))
            x_a_recon.append(self.gen_a.decode(s_a, f_a))
            x_b_recon.append(self.gen_b.decode(s_b, f_b))
            x_ba = self.gen_a.decode(s_b, f_a)
            x_ab = self.gen_b.decode(s_a, f_b)
            x_ba1.append(x_ba)
            x_ab1.append(x_ab)
            # cycle
            s_b_recon = self.gen_a.enc_content(self.single(x_ba))
            s_a_recon = self.gen_b.enc_content(self.single(x_ab))
            f_a_recon, _ = self.id_a(scale2(x_ba))
            f_b_recon, _ = self.id_b(scale2(x_ab))
            x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon))
            x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon))

        x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon)
        x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab)
        x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1)
        for module, was_training in saved_modes:
            module.training = was_training

        return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1

    def dis_update(self, x_ab, x_ba, x_a, x_b, hyperparameters, num_gpu):
        """One discriminator step on detached fakes (``x_ba`` vs ``x_a``, ``x_ab`` vs ``x_b``)."""
        self.dis_opt.zero_grad()
        # D loss (with several GPUs the discriminator's helpers live on `.module`)
        if num_gpu > 1:
            self.loss_dis_a, reg_a = self.dis_a.module.calc_dis_loss(self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.module.calc_dis_loss(self.dis_b, x_ab.detach(), x_b)
        else:
            self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(self.dis_a, x_ba.detach(), x_a)
            self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(self.dis_b, x_ab.detach(), x_b)
        self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b
        print("DLoss: %.4f" % self.loss_dis_total, "Reg: %.4f" % (reg_a + reg_b))
        if self.fp16:
            with amp.scale_loss(self.loss_dis_total, self.dis_opt) as scaled_loss:
                scaled_loss.backward()
        else:
            self.loss_dis_total.backward()
        self.dis_opt.step()

    def update_learning_rate(self):
        """Step every scheduler that exists (``get_scheduler`` may return None)."""
        for scheduler in (self.dis_scheduler, self.gen_scheduler, self.id_scheduler):
            if scheduler is not None:
                scheduler.step()

    def resume(self, checkpoint_dir, hyperparameters):
        """Restore networks, optimizers and schedulers from ``checkpoint_dir``.

        Returns:
            The iteration count parsed from the latest generator file name.
        """
        # Load generators
        last_model_name = get_model_list(checkpoint_dir, "gen")
        state_dict = torch.load(last_model_name)
        self.gen_a.load_state_dict(state_dict['a'])
        self.gen_b = self.gen_a
        # save() writes gen_%08d.pt, so [-11:-3] is the 8-digit iteration counter.
        iterations = int(last_model_name[-11:-3])
        # Load discriminators
        last_model_name = get_model_list(checkpoint_dir, "dis")
        state_dict = torch.load(last_model_name)
        self.dis_a.load_state_dict(state_dict['a'])
        self.dis_b = self.dis_a
        # Load ID dis
        last_model_name = get_model_list(checkpoint_dir, "id")
        state_dict = torch.load(last_model_name)
        self.id_a.load_state_dict(state_dict['a'])
        self.id_b = self.id_a
        # Load optimizers (best effort: training continues with fresh state on failure)
        try:
            state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt'))
            self.dis_opt.load_state_dict(state_dict['dis'])
            self.gen_opt.load_state_dict(state_dict['gen'])
            self.id_opt.load_state_dict(state_dict['id'])
        except Exception as err:
            print('Could not restore optimizer state (%s); using freshly initialised optimizers.' % err)
        # Reinitialise all schedulers so they continue from `iterations`,
        # including the ID scheduler and its own decay factor (as in __init__).
        self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations)
        self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations)
        self.id_scheduler = get_scheduler(self.id_opt, hyperparameters, iterations)
        if self.id_scheduler is not None:
            self.id_scheduler.gamma = hyperparameters['gamma2']
        print('Resume from iteration %d' % iterations)
        return iterations

    def save(self, snapshot_dir, iterations, num_gpu=1):
        """Save generator, discriminator, ID network and optimizer states.

        Network files are numbered ``iterations + 1``; the optimizer file
        ``optimizer.pt`` is overwritten every time.
        """
        gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1))
        dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1))
        id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1))
        opt_name = os.path.join(snapshot_dir, 'optimizer.pt')
        torch.save({'a': self.gen_a.state_dict()}, gen_name)
        if num_gpu > 1:
            torch.save({'a': self.dis_a.module.state_dict()}, dis_name)
        else:
            torch.save({'a': self.dis_a.state_dict()}, dis_name)
        torch.save({'a': self.id_a.state_dict()}, id_name)
        torch.save({'gen': self.gen_opt.state_dict(), 'id': self.id_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name)
648 |
649 |
650 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
4 | """
5 |
6 | from torch.utils.data import DataLoader
7 | from torch.autograd import Variable
8 | from torch.optim import lr_scheduler
9 | from torchvision import transforms
10 | from data import ImageFilelist
11 | from reIDfolder import ReIDFolder
12 | import torch
13 | import os
14 | import math
15 | import torchvision.utils as vutils
16 | import yaml
17 | import numpy as np
18 | import torch.nn.init as init
19 | import time
20 | # Methods
21 | # get_all_data_loaders : primary data loader interface (load trainA, testA, trainB, testB)
22 | # get_data_loader_list : list-based data loader
23 | # get_data_loader_folder : folder-based data loader
24 | # get_config : load yaml file
25 | # eformat : format a float in scientific notation with an unpadded exponent (e.g. 1.00e-4)
26 | # write_2images : save output image
27 | # prepare_sub_folder : create checkpoints and images folders for saving outputs
28 | # write_one_row_html : write one row of the html file for output images
29 | # write_html : create the html file.
30 | # write_loss
31 | # slerp
32 | # get_slerp_interp
33 | # get_model_list
34 | # load_vgg16
35 | # vgg_preprocess
36 | # get_scheduler
37 | # weights_init
38 |
def get_all_data_loaders(conf):
    """Build the four data loaders used for training and sampling.

    If ``conf`` has ``'data_root'``, both domains read the ReID folder layout:
    ``<data_root>/train_all`` for training and ``<data_root>/query`` for testing.
    Otherwise each loader is built from an image folder plus a file list,
    taken from the ``conf['data_folder_*']`` / ``conf['data_list_*']`` entries.
    Training loaders get augmentation and random cropping; test loaders get neither.

    Returns:
        (train_loader_a, train_loader_b, test_loader_a, test_loader_b)
    """
    batch_size = conf['batch_size']
    num_workers = conf['num_workers']
    if 'new_size' in conf:
        new_size_a = new_size_b = conf['new_size']
    else:
        new_size_a = conf['new_size_a']
        new_size_b = conf['new_size_b']
    height = conf['crop_image_height']
    width = conf['crop_image_width']

    # (is_train, new_size) for train_a, test_a, train_b, test_b, in this order.
    # Cropping is enabled exactly for the training loaders.
    plan = ((True, new_size_a), (False, new_size_a), (True, new_size_b), (False, new_size_b))

    if 'data_root' in conf:
        train_dir = os.path.join(conf['data_root'], 'train_all')
        query_dir = os.path.join(conf['data_root'], 'query')
        loaders = [
            get_data_loader_folder(train_dir if is_train else query_dir, batch_size, is_train,
                                   size, height, width, num_workers, is_train)
            for is_train, size in plan
        ]
    else:
        list_keys = (
            ('data_folder_train_a', 'data_list_train_a'),
            ('data_folder_test_a', 'data_list_test_a'),
            ('data_folder_train_b', 'data_list_train_b'),
            ('data_folder_test_b', 'data_list_test_b'),
        )
        loaders = [
            get_data_loader_list(conf[folder_key], conf[list_key], batch_size, is_train,
                                 size, height, width, num_workers, is_train)
            for (folder_key, list_key), (is_train, size) in zip(list_keys, plan)
        ]
    train_loader_a, test_loader_a, train_loader_b, test_loader_b = loaders
    return train_loader_a, train_loader_b, test_loader_a, test_loader_b
70 |
71 |
def get_data_loader_list(root, file_list, batch_size, train, new_size=None,
                         height=256, width=128, num_workers=4, crop=True):
    """Create a shuffled, drop-last DataLoader over an ``ImageFilelist`` dataset.

    Pipeline, in order:
    1. [flip if train]
    2. [resize to (height, width) if new_size is not None]
    3. [edge pad 10 if train]
    4. [random crop (height, width) if crop]
    5. ToTensor
    6. ImageNet-normalise

    ``new_size`` is only tested against None; its value is not used.
    """
    steps = []
    if train:
        steps.append(transforms.RandomHorizontalFlip())
    if new_size is not None:
        # interpolation=3 is PIL bicubic
        steps.append(transforms.Resize((height, width), interpolation=3))
    if train:
        steps.append(transforms.Pad(10, padding_mode='edge'))
    if crop:
        steps.append(transforms.RandomCrop((height, width)))
    steps.append(transforms.ToTensor())
    steps.append(transforms.Normalize((0.485, 0.456, 0.406),
                                      (0.229, 0.224, 0.225)))
    dataset = ImageFilelist(root, file_list, transform=transforms.Compose(steps))
    return DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True,
                      drop_last=True, num_workers=num_workers)
85 |
def get_data_loader_folder(input_folder, batch_size, train, new_size=None,
                           height=256, width=128, num_workers=4, crop=True):
    """Create a shuffled, drop-last DataLoader over a ``ReIDFolder`` dataset.

    The transform pipeline is:
    1. optional horizontal flip (train)
    2. optional resize to (height, width) (when new_size is not None)
    3. optional 10-pixel edge padding (train)
    4. optional random crop (crop)
    5. ToTensor
    6. ImageNet normalisation

    Only ``new_size``'s None-ness matters; images are resized to (height, width).
    """
    pipeline = (
        ([transforms.RandomHorizontalFlip()] if train else [])
        + ([transforms.Resize((height, width), interpolation=3)] if new_size is not None else [])
        + ([transforms.Pad(10, padding_mode='edge')] if train else [])
        + ([transforms.RandomCrop((height, width))] if crop else [])
        + [transforms.ToTensor(),
           transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))]
    )
    reid_dataset = ReIDFolder(input_folder, transform=transforms.Compose(pipeline))
    return DataLoader(dataset=reid_dataset, batch_size=batch_size, shuffle=True,
                      drop_last=True, num_workers=num_workers)
99 |
100 |
def get_config(config):
    """Parse the YAML file at path ``config`` with ``yaml.safe_load`` and return the result."""
    with open(config, 'r') as config_file:
        parsed = yaml.safe_load(config_file)
    return parsed
104 |
105 |
def eformat(f, prec):
    """Format ``f`` in scientific notation with ``prec`` fractional mantissa digits.

    Unlike ``'%e'``, the exponent carries no '+' sign and no zero padding,
    e.g. ``eformat(0.0001, 2) == '1.00e-4'``.
    """
    mantissa, _, exponent = ("%.*e" % (prec, f)).partition('e')
    return "%se%d" % (mantissa, int(exponent))
111 |
112 |
def __write_images(image_outputs, display_image_num, file_name):
    """Save a grid image with one row per batch in ``image_outputs``.

    Each batch contributes its first ``display_image_num`` images.
    Single-channel batches are expanded to 3 channels first.  Every image is
    normalised separately (``scale_each=True``).
    """
    rows = []
    for batch in image_outputs:
        rows.append(batch.expand(-1, 3, -1, -1)[:display_image_num])
    stacked = torch.cat(rows, 0)
    grid = vutils.make_grid(stacked.data, nrow=display_image_num, padding=0,
                            normalize=True, scale_each=True)
    vutils.save_image(grid, file_name, nrow=1)
118 |
119 |
def write_2images(image_outputs, display_image_num, image_directory, postfix):
    """Write two grids: the first half of ``image_outputs`` to gen_a2b_<postfix>.jpg,
    the second half to gen_b2a_<postfix>.jpg, both inside ``image_directory``."""
    half = len(image_outputs) // 2
    first_half, second_half = image_outputs[:half], image_outputs[half:]
    __write_images(first_half, display_image_num, '%s/gen_a2b_%s.jpg' % (image_directory, postfix))
    __write_images(second_half, display_image_num, '%s/gen_b2a_%s.jpg' % (image_directory, postfix))
124 |
125 |
def prepare_sub_folder(output_directory):
    """Ensure ``<output_directory>/images`` and ``<output_directory>/checkpoints`` exist.

    Each directory that has to be created is announced on stdout.

    Returns:
        (checkpoint_directory, image_directory)
    """
    ensured = []
    for sub_name in ('images', 'checkpoints'):
        path = os.path.join(output_directory, sub_name)
        if not os.path.exists(path):
            print("Creating directory: {}".format(path))
            os.makedirs(path)
        ensured.append(path)
    image_directory, checkpoint_directory = ensured
    return checkpoint_directory, image_directory
136 |
137 |
138 | def write_one_row_html(html_file, iterations, img_filename, all_size):
139 | html_file.write("
142 |
143 |
144 |
145 | """ % (img_filename, img_filename, all_size)) 146 | return 147 | 148 | 149 | def write_html(filename, iterations, image_save_iterations, image_directory, all_size=1536): 150 | html_file = open(filename, "w") 151 | html_file.write(''' 152 | 153 | 154 |
155 |