├── .gitignore ├── Demo.gif ├── LICENSE.md ├── NxN.jpg ├── README.md ├── configs ├── .DS_Store ├── latest-fp16.yaml └── latest.yaml ├── data.py ├── models └── .gitkeep ├── networks.py ├── outputs └── .gitkeep ├── prepare-market.py ├── random_erasing.py ├── reIDfolder.py ├── reIDmodel.py ├── reid_eval ├── README.md ├── evaluate_gpu.py └── test_2label.py ├── train.py ├── trainer.py ├── utils.py ├── visual_data ├── .DS_Store ├── demo │ ├── .DS_Store │ └── train │ │ ├── 0010_c6s4_002427_02.jpg │ │ └── 0042_c3s3_064169_01.jpg ├── inputs_many_test │ ├── 0 │ │ ├── 0000_c3s1_107067_00.jpg │ │ ├── 0200_c3s3_056178_00.jpg │ │ ├── 0223_c1s6_014296_00.jpg │ │ ├── 0387_c3s1_091167_00.jpg │ │ ├── 0400_c2s1_046426_00.jpg │ │ ├── 0531_c2s1_149316_00.jpg │ │ ├── 1196_c3s1_038576_00.jpg │ │ ├── b0280_c3s3_065619_00.jpg │ │ ├── b0387_c2s1_090996_00.jpg │ │ ├── b0878_c2s2_106457_00.jpg │ │ ├── b1324_c1s6_007741_00.jpg │ │ ├── back0100_c3s2_115894_00.jpg │ │ ├── l0155_c3s1_025176_00.jpg │ │ ├── l0345_c1s2_006466_00.jpg │ │ └── l0736_c6s1_022526_00.jpg │ ├── 1 │ │ ├── 0051_c4s1_005526_00.jpg │ │ ├── 0055_c3s3_076469_00.jpg │ │ ├── 0089_c6s1_013501_00.jpg │ │ ├── 0280_c5s1_078148_00.jpg │ │ ├── 0715_c3s2_059153_00.jpg │ │ ├── 0825_c6s2_097943_00.jpg │ │ ├── 0829_c4s4_043560_00.jpg │ │ ├── 1147_c2s2_158702_00.jpg │ │ ├── 1226_c6s3_038167_00.jpg │ │ ├── 1251_c3s3_020428_00.jpg │ │ ├── 1255_c3s3_021678_00.jpg │ │ ├── 1459_c2s3_052107_00.jpg │ │ ├── bb.jpg │ │ └── bbb.jpg │ └── .DS_Store ├── inputs_many_test_duke │ ├── 1 │ │ ├── 00000.jpg │ │ ├── 0001_c7_f0106449.jpg │ │ ├── 0100_c1_f0100590.jpg │ │ ├── 0120_c7_f0112912.jpg │ │ ├── 0180_c6_f0074316.jpg │ │ ├── 0304_c4_f0069917.jpg │ │ ├── 5842_c7_f0053326.jpg │ │ └── data │ │ │ ├── 0163_c5_f0084927.jpg │ │ │ ├── 0279_c2_f0098340.jpg │ │ │ ├── 0377_c1_f0110464.jpg │ │ │ ├── 0459_c1_f0124423.jpg │ │ │ ├── 0654_c1_f0165081.jpg │ │ │ ├── 0749_c1_f0188863.jpg │ │ │ └── 4141_c8_f0022224.jpg │ └── 2 │ │ └── data │ │ ├── 0201_c2_f0089204.jpg │ │ ├── 0243_c1_f0095564.jpg │ │ ├── 0371_c1_f0110190.jpg │ │ ├── 0690_c4_f0142152.jpg │ │ ├── 0712_c6_f0165928.jpg │ │ ├── 1108_c6_f0176009.jpg │ │ └── 1486_c4_f0069129.jpg ├── inputs_many_test_market │ ├── 1 │ │ ├── .DS_Store │ │ └── data │ │ │ ├── 0051_c3s1_005551_00.jpg │ │ │ ├── 0227_c2s1_046426_00.jpg │ │ │ ├── 0304_c5s1_068523_00.jpg │ │ │ ├── 0378_c2s1_089071_00.jpg │ │ │ ├── 0447_c3s1_111408_00.jpg │ │ │ ├── 1285_c1s5_053991_00.jpg │ │ │ └── l0387_c6s2_006743_00.jpg │ ├── 2 │ │ ├── .DS_Store │ │ └── data │ │ │ ├── 0146_c2s1_023126_00.jpg │ │ │ ├── 0291_c3s3_079969_00.jpg │ │ │ ├── 1154_c5s3_001368_00.jpg │ │ │ ├── 1285_c2s3_021432_00.jpg │ │ │ ├── 1394_c6s3_071892_00.jpg │ │ │ ├── 1429_c2s3_048607_00.jpg │ │ │ └── l1255_c3s3_021678_00.jpg │ └── .DS_Store ├── inputs_many_test_msmt │ ├── 1 │ │ └── data │ │ │ ├── 0367_010_10_0303noon_0606_2.jpg │ │ │ ├── 0434_020_07_0303noon_1785_5_ex.jpg │ │ │ ├── 0494_029_01_0303afternoon_0435_2_ex.jpg │ │ │ ├── 0514_012_05_0303afternoon_0585_1.jpg │ │ │ ├── 0720_027_01_0303afternoon_0771_4.jpg │ │ │ ├── 2744_001_12_0114noon_0960_1_ex.jpg │ │ │ └── 2944_020_01_0114afternoon_1429_2_ex.jpg │ └── 2 │ │ └── data │ │ ├── 0168_018_14_0303morning_1118_1_ex.jpg │ │ ├── 0318_001_01_0303noon_0111_4_ex.jpg │ │ ├── 0771_019_01_0303afternoon_1000_9.jpg │ │ ├── 0917_021_15_0113morning_0174_0.jpg │ │ ├── 2744_037_06_0114noon_0915_0_ex.jpg │ │ ├── 2790_015_05_0114afternoon_0244_1.jpg │ │ ├── 2869_011_05_0114afternoon_0836_0_ex.jpg │ │ └── 3032_043_07_0114afternoon_1088_3.jpg ├── inputs_many_train │ ├── 1 │ │ ├── 0002_c2s1_000351_01.jpg │ │ ├── 0028_c2s1_001751_02.jpg │ │ ├── 0068_c6s1_012351_01.jpg │ │ ├── 0070_c4s1_010576_01.jpg │ │ ├── 0127_c6s1_021426_06.jpg │ │ ├── 0142_c3s3_020203_02.jpg │ │ ├── 0197_c6s1_045076_05.jpg │ │ ├── 0259_c6s1_055126_01.jpg │ │ ├── 0301_c6s1_076101_04.jpg │ │ ├── 0348_c2s1_080671_01.jpg │ │ ├── 0696_c6s2_054993_02.jpg │ │ ├── 0711_c2s2_056057_03.jpg │ │ ├── 0754_c2s2_068182_01.jpg │ │ ├── 0820_c3s2_104428_05.jpg │ │ ├── 0843_c2s2_104907_04.jpg │ │ └── 0901_c6s2_119343_02.jpg │ └── .DS_Store ├── inputs_two │ ├── 1 │ │ ├── 00.jpg │ │ ├── 000.jpg │ │ ├── 0006_c2s3_069327_00.jpg │ │ ├── 01.jpg │ │ ├── 0289_c3s1_063567_00.jpg │ │ └── back0100_c3s2_115894_00.jpg │ ├── 2 │ │ ├── .DS_Store │ │ └── data │ │ │ ├── .DS_Store │ │ │ ├── 0051_c3s1_005551_00.jpg │ │ │ └── 0227_c2s1_046426_00.jpg │ ├── 3 │ │ ├── .DS_Store │ │ └── data │ │ │ ├── .DS_Store │ │ │ ├── 0000_c3s1_107067_00.jpg │ │ │ └── 0531_c2s1_149316_00.jpg │ ├── 4 │ │ ├── .DS_Store │ │ └── data │ │ │ ├── .DS_Store │ │ │ ├── b1324_c1s6_007741_00.jpg │ │ │ └── back0100_c3s2_115894_00.jpg │ └── .DS_Store └── train_sample │ ├── 045_388_gan0734_c3s2_064678_06.jpg │ ├── 104_188_gan0371_c5s1_088523_02.jpg │ ├── 115_000_gan0002_c1s1_069056_02.jpg │ ├── 258_750_gan1500_c5s3_063312_02.jpg │ ├── 344_283_gan0549_c1s3_009921_05.jpg │ ├── 400_344_gan0656_c6s2_045393_01.jpg │ ├── 406_090_gan0179_c3s3_078044_01.jpg │ ├── 496_722_gan1437_c1s6_007216_02.jpg │ ├── 620_357_gan0673_c5s2_045155_01.jpg │ ├── 643_269_gan0519_c6s2_087418_01.jpg │ └── 643_375_gan0706_c6s2_062118_04.jpg └── visual_tools ├── README.md ├── show1by1.py ├── show_rainbow.py ├── show_smooth.py ├── show_smooth_structure.py ├── show_swap.py └── test_folder.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Don't track content of these folders 2 | outputs/ 3 | models/ 4 | logs/ 5 | __pycache__/ 6 | __MACOSX/ 7 | configs/cifs9a31 8 | core 9 | temp/ 10 | temp_output/ 11 | center/ 12 | rainbow/ 13 | 14 | # Compiled source # 15 | ################### 16 | *.com 17 | *.class 18 | *.dll 19 | *.exe 20 | *.o 21 | *.so 22 | *.pyc 23 | 24 | # Packages # 25 | ############ 26 | # it's better to unpack these files and commit the raw source 27 | # git has its own built in compression methods 28 | *.7z 29 | *.dmg 30 | *.gz 31 | *.iso 32 | *.jar 33 | *.rar 34 | *.tar 35 | *.zip 36 | *.mat 37 | *.npy 38 | *.pt 39 | *.pth 40 | 41 | -------------------------------------------------------------------------------- /Demo.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/Demo.gif -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | ## creative commons 2 | 3 | # Attribution-NonCommercial-ShareAlike 4.0 International 4 | 5 | Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible. 6 | 7 | ### Using Creative Commons Public Licenses 8 | 9 | Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses. 10 | 11 | * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors). 12 | 13 | * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees). 14 | 15 | ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License 16 | 17 | By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions. 18 | 19 | ### Section 1 – Definitions. 20 | 21 | a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image. 22 | 23 | b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License. 24 | 25 | c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License. 26 | 27 | d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights. 28 | 29 | e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements. 30 | 31 | f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material. 32 | 33 | g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike. 34 | 35 | h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License. 36 | 37 | i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license. 38 | 39 | h. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License. 40 | 41 | i. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange. 42 | 43 | j. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them. 44 | 45 | k. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world. 46 | 47 | l. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning. 48 | 49 | ### Section 2 – Scope. 50 | 51 | a. ___License grant.___ 52 | 53 | 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to: 54 | 55 | A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and 56 | 57 | B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only. 58 | 59 | 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions. 60 | 61 | 3. __Term.__ The term of this Public License is specified in Section 6(a). 62 | 63 | 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material. 64 | 65 | 5. __Downstream recipients.__ 66 | 67 | A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License. 68 | 69 | B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply. 70 | 71 | C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material. 72 | 73 | 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i). 74 | 75 | b. ___Other rights.___ 76 | 77 | 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise. 78 | 79 | 2. Patent and trademark rights are not licensed under this Public License. 80 | 81 | 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes. 82 | 83 | ### Section 3 – License Conditions. 84 | 85 | Your exercise of the Licensed Rights is expressly made subject to the following conditions. 86 | 87 | a. ___Attribution.___ 88 | 89 | 1. If You Share the Licensed Material (including in modified form), You must: 90 | 91 | A. retain the following if it is supplied by the Licensor with the Licensed Material: 92 | 93 | i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated); 94 | 95 | ii. a copyright notice; 96 | 97 | iii. a notice that refers to this Public License; 98 | 99 | iv. a notice that refers to the disclaimer of warranties; 100 | 101 | v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable; 102 | 103 | B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and 104 | 105 | C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License. 106 | 107 | 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information. 108 | 109 | 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable. 110 | 111 | b. ___ShareAlike.___ 112 | 113 | In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply. 114 | 115 | 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License. 116 | 117 | 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material. 118 | 119 | 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply. 120 | 121 | ### Section 4 – Sui Generis Database Rights. 122 | 123 | Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material: 124 | 125 | a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only; 126 | 127 | b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and 128 | 129 | c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database. 130 | 131 | For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights. 132 | 133 | ### Section 5 – Disclaimer of Warranties and Limitation of Liability. 134 | 135 | a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__ 136 | 137 | b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__ 138 | 139 | c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability. 140 | 141 | ### Section 6 – Term and Termination. 142 | 143 | a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically. 144 | 145 | b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates: 146 | 147 | 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 148 | 149 | 2. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or 150 | 151 | For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License. 152 | 153 | c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License. 154 | 155 | d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License. 156 | 157 | ### Section 7 – Other Terms and Conditions. 158 | 159 | a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed. 160 | 161 | b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License. 162 | 163 | ### Section 8 – Interpretation. 164 | 165 | a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License. 166 | 167 | b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions. 168 | 169 | c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor. 170 | 171 | d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority. 172 | 173 | ``` 174 | Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at [creativecommons.org/policies](http://creativecommons.org/policies), Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses. 175 | 176 | Creative Commons may be contacted at [creativecommons.org](http://creativecommons.org/). 177 | ``` 178 | 179 | -------------------------------------------------------------------------------- /NxN.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/NxN.jpg -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | [![License CC BY-NC-SA 4.0](https://img.shields.io/badge/license-CC4.0-blue.svg)](https://raw.githubusercontent.com/nvlabs/SPADE/master/LICENSE.md) 2 | ![Python 3.6](https://img.shields.io/badge/python-3.6-green.svg) 3 | [![Language grade: Python](https://img.shields.io/lgtm/grade/python/g/NVlabs/DG-Net.svg?logo=lgtm&logoWidth=18)](https://lgtm.com/projects/g/NVlabs/DG-Net/context:python) 4 | 5 | ## Joint Discriminative and Generative Learning for Person Re-identification 6 | ![](NxN.jpg) 7 | ![](Demo.gif) 8 | 9 | [[Project]](http://zdzheng.xyz/DG-Net/) [[Paper]](https://arxiv.org/abs/1904.07223) [[YouTube]](https://www.youtube.com/watch?v=ubCrEAIpQs4) [[Bilibili]](https://www.bilibili.com/video/av51439240) [[Poster]](http://zdzheng.xyz/files/DGNet_poster.pdf) 10 | [[Supp]](http://jankautz.com/publications/JointReID_CVPR19_supp.pdf) 11 | 12 | Joint Discriminative and Generative Learning for Person Re-identification, CVPR 2019 (Oral)
13 | [Zhedong Zheng](http://zdzheng.xyz/), [Xiaodong Yang](https://xiaodongyang.org/), [Zhiding Yu](https://chrisding.github.io/), [Liang Zheng](http://liangzheng.com.cn/), [Yi Yang](https://www.uts.edu.au/staff/yi.yang), [Jan Kautz](http://jankautz.com/)
14 | 15 | ## Table of contents 16 | * [News](#news) 17 | * [Features](#features) 18 | * [Prerequisites](#prerequisites) 19 | * [Getting Started](#getting-started) 20 | * [Installation](#installation) 21 | * [Dataset Preparation](#dataset-preparation) 22 | * [Testing](#testing) 23 | * [Training](#training) 24 | * [DG-Market](#dg-market) 25 | * [Tips](#tips) 26 | * [Citation](#citation) 27 | * [Related Work](#related-work) 28 | * [License](#license) 29 | 30 | ## News 31 | - 02/18/2021: We release [DG-Net++](https://github.com/NVlabs/DG-Net-PP): the extention of DG-Net for unsupervised cross-domain re-id. 32 | - 08/24/2019: We add the direct transfer learning results of DG-Net [here](https://github.com/NVlabs/DG-Net#person-re-id-evaluation). 33 | - 08/01/2019: We add the support of multi-GPU training: `python train.py --config configs/latest.yaml --gpu_ids 0,1`. 34 | 35 | ## Features 36 | We have supported: 37 | - Multi-GPU training (fp32) 38 | - [APEX](https://github.com/NVIDIA/apex) to save GPU memory (fp16/fp32) 39 | - Multi-query evaluation 40 | - Random erasing 41 | - Visualize training curves 42 | - Generate all figures in the paper 43 | 44 | ## Prerequisites 45 | 46 | - Python 3.6 47 | - GPU memory >= 15G (fp32) 48 | - GPU memory >= 10G (fp16/fp32) 49 | - NumPy 50 | - PyTorch 1.0+ 51 | - [Optional] APEX (fp16/fp32) 52 | 53 | ## Getting Started 54 | ### Installation 55 | - Install [PyTorch](http://pytorch.org/) 56 | - Install torchvision from the source: 57 | ``` 58 | git clone https://github.com/pytorch/vision 59 | cd vision 60 | python setup.py install 61 | ``` 62 | - [Optional] You may skip it. Install APEX from the source: 63 | ``` 64 | git clone https://github.com/NVIDIA/apex.git 65 | cd apex 66 | python setup.py install --cuda_ext --cpp_ext 67 | ``` 68 | - Clone this repo: 69 | ```bash 70 | git clone https://github.com/NVlabs/DG-Net.git 71 | cd DG-Net/ 72 | ``` 73 | 74 | Our code is tested on PyTorch 1.0.0+ and torchvision 0.2.1+ . 75 | 76 | ### Dataset Preparation 77 | Download the dataset [Market-1501](http://www.liangzheng.com.cn/Project/project_reid.html) [[Google Drive]](https://drive.google.com/file/d/0B8-rUzbwVRk0c054eEozWG9COHM/view) [[Baidu Disk]](https://pan.baidu.com/s/1ntIi2Op) 78 | 79 | Preparation: put the images with the same id in one folder. You may use 80 | ```bash 81 | python prepare-market.py # for Market-1501 82 | ``` 83 | Note to modify the dataset path to your own path. 84 | 85 | ### Testing 86 | 87 | #### Download the trained model 88 | We provide our trained model. You may download it from [Google Drive](https://drive.google.com/open?id=1lL18FZX1uZMWKzaZOuPe3IuAdfUYyJKH) (or [Baidu Disk](https://pan.baidu.com/s/1503831XfW0y4g3PHir91yw) password: rqvf). You may download and move it to the `outputs`. 89 | ``` 90 | ├── outputs/ 91 | │ ├── E0.5new_reid0.5_w30000 92 | ├── models 93 | │ ├── best/ 94 | ``` 95 | #### Person re-id evaluation 96 | - Supervised learning 97 | 98 | | | Market-1501 | DukeMTMC-reID | MSMT17 | CUHK03-NP | 99 | |---|--------------|----------------|----------|-----------| 100 | | Rank@1 | 94.8% | 86.6% | 77.2% | 65.6% | 101 | | mAP | 86.0% | 74.8% | 52.3% | 61.1% | 102 | 103 | 104 | - Direct transfer learning 105 | To verify the generalizability of DG-Net, we train the model on dataset A and directly test the model on dataset B (with no adaptation). 106 | We denote the direct transfer learning protocol as `A→B`. 107 | 108 | | |Market→Duke|Duke→Market|Market→MSMT|MSMT→Market|Duke→MSMT|MSMT→Duke| 109 | |---|----------------|----------------| -------------- |----------------| -------------- |----------------| 110 | | Rank@1 | 42.62% | 56.12% | 17.11% | 61.76% | 20.59% | 61.89% | 111 | | Rank@5 | 58.57% | 72.18% | 26.66% | 77.67% | 31.67% | 75.81% | 112 | | Rank@10 | 64.63% | 78.12% | 31.62% | 83.25% | 37.04% | 80.34% | 113 | | mAP | 24.25% | 26.83% | 5.41% | 33.62% | 6.35% | 40.69% | 114 | 115 | 116 | #### Image generation evaluation 117 | 118 | Please check the `README.md` in the `./visual_tools`. 119 | 120 | You may use the `./visual_tools/test_folder.py` to generate lots of images and then do the evaluation. The only thing you need to modify is the data path in [SSIM](https://github.com/layumi/PerceptualSimilarity) and [FID](https://github.com/layumi/TTUR). 121 | 122 | ### Training 123 | 124 | #### Train a teacher model 125 | You may directly download our trained teacher model from [Google Drive](https://drive.google.com/open?id=1lL18FZX1uZMWKzaZOuPe3IuAdfUYyJKH) (or [Baidu Disk](https://pan.baidu.com/s/1503831XfW0y4g3PHir91yw) password: rqvf). 126 | If you want to have it trained by yourself, please check the [person re-id baseline](https://github.com/layumi/Person_reID_baseline_pytorch) repository to train a teacher model, then copy and put it in the `./models`. 127 | ``` 128 | ├── models/ 129 | │ ├── best/ /* teacher model for Market-1501 130 | │ ├── net_last.pth /* model file 131 | │ ├── ... 132 | ``` 133 | 134 | #### Train DG-Net 135 | 1. Setup the yaml file. Check out `configs/latest.yaml`. Change the data_root field to the path of your prepared folder-based dataset, e.g. `../Market-1501/pytorch`. 136 | 137 | 138 | 2. Start training 139 | ``` 140 | python train.py --config configs/latest.yaml 141 | ``` 142 | Or train with low precision (fp16) 143 | ``` 144 | python train.py --config configs/latest-fp16.yaml 145 | ``` 146 | Intermediate image outputs and model binary files are saved in `outputs/latest`. 147 | 148 | 3. Check the loss log 149 | ``` 150 | tensorboard --logdir logs/latest 151 | ``` 152 | 153 | ## DG-Market 154 | ![](https://github.com/layumi/DG-Net/blob/gh-pages/index_files/DGMarket-logo.png) 155 | 156 | We provide our generated images and make a large-scale synthetic dataset called DG-Market. This dataset is generated by our DG-Net and consists of 128,307 images (613MB), about 10 times larger than the training set of original Market-1501 (even much more can be generated with DG-Net). It can be used as a source of unlabeled training dataset for semi-supervised learning. You may download the dataset from [Google Drive](https://drive.google.com/file/d/126Gn90Tzpk3zWp2c7OBYPKc-ZjhptKDo/view?usp=sharing) (or [Baidu Disk](https://pan.baidu.com/s/1n4M6s-qvE08J8SOOWtWfgw) password: qxyh). 157 | 158 | | | DG-Market | Market-1501 (training) | 159 | |---|--------------|-------------| 160 | | #identity| - | 751 | 161 | | #images| 128,307 | 12,936 | 162 | 163 | Quick Download via [gdrive](https://github.com/prasmussen/gdrive) 164 | ```bash 165 | wget https://github.com/prasmussen/gdrive/releases/download/2.1.1/gdrive_2.1.1_linux_386.tar.gz 166 | tar -xzvf gdrive_2.1.1_linux_386.tar.gz 167 | gdrive download 126Gn90Tzpk3zWp2c7OBYPKc-ZjhptKDo 168 | unzip DG-Market.zip 169 | ``` 170 | 171 | ## Tips 172 | Note the format of camera id and number of cameras. For some datasets (e.g., MSMT17), there are more than 10 cameras. You need to modify the preparation and evaluation code to read the double-digit camera id. For some vehicle re-id datasets (e.g., VeRi) having different naming rules, you also need to modify the preparation and evaluation code. 173 | 174 | ## Citation 175 | Please cite this paper if it helps your research: 176 | ```bibtex 177 | @inproceedings{zheng2019joint, 178 | title={Joint discriminative and generative learning for person re-identification}, 179 | author={Zheng, Zhedong and Yang, Xiaodong and Yu, Zhiding and Zheng, Liang and Yang, Yi and Kautz, Jan}, 180 | booktitle={IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 181 | year={2019} 182 | } 183 | ``` 184 | 185 | ## Related Work 186 | Other GAN-based methods compared in the paper include [LSGAN](https://github.com/layumi/DCGAN-pytorch), [FDGAN](https://github.com/layumi/FD-GAN) and [PG2GAN](https://github.com/charliememory/Pose-Guided-Person-Image-Generation). We forked the code and made some changes for evaluatation, thank the authors for their great work. We would also like to thank to the great projects in [person re-id baseline](https://github.com/layumi/Person_reID_baseline_pytorch), [MUNIT](https://github.com/NVlabs/MUNIT) and [DRIT](https://github.com/HsinYingLee/DRIT). 187 | 188 | ## License 189 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. Licensed under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode) (**Attribution-NonCommercial-ShareAlike 4.0 International**). The code is released for academic research use only. For commercial use, please contact [researchinquiries@nvidia.com](researchinquiries@nvidia.com). 190 | -------------------------------------------------------------------------------- /configs/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/configs/.DS_Store -------------------------------------------------------------------------------- /configs/latest-fp16.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | 4 | apex: true # Set True to use float16. 5 | B_w: 0.2 # The loss weight of fine-grained loss, which is named as `alpha` in the paper. 6 | ID_class: 751 # The number of ID classes in the dataset. For example, 751 for Market, 702 for DukeMTMC 7 | ID_stride: 1 # Stride in Appearance encoder 8 | ID_style: AB # For time being, we only support AB. In the future, we will support PCB. 9 | batch_size: 8 # BatchSize 10 | beta1: 0 # Adam hyperparameter 11 | beta2: 0.999 # Adam hyperparameter 12 | crop_image_height: 256 # Input height 13 | crop_image_width: 128 # Input width 14 | data_root: ../Market/pytorch/ # Dataset Root 15 | dis: 16 | LAMBDA: 0.01 # the hyperparameter for the regularization term 17 | activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh] 18 | dim: 32 # number of filters in the bottommost layer 19 | gan_type: lsgan # GAN loss [lsgan/nsgan] 20 | n_layer: 2 # number of layers in D 21 | n_res: 4 # number of layers in D 22 | non_local: 0 # number of non_local layers 23 | norm: none # normalization layer [none/bn/in/ln] 24 | num_scales: 3 # number of scales 25 | pad_type: reflect # padding type [zero/reflect] 26 | display_size: 16 # How much display images 27 | erasing_p: 0.5 # Random erasing probability [0-1] 28 | gamma: 0.1 # Learning Rate Decay (except appearance encoder) 29 | gamma2: 0.1 # Learning Rate Decay (for appearance encoder) 30 | gan_w: 1 # the weight of gan loss 31 | gen: 32 | activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh] 33 | dec: basic # [basic/parallel/series] 34 | dim: 16 # number of filters in the bottommost layer 35 | dropout: 0 # use dropout in the generator 36 | id_dim: 2048 # length of appearance code 37 | mlp_dim: 512 # number of filters in MLP 38 | mlp_norm: none # norm in mlp [none/bn/in/ln] 39 | n_downsample: 2 # number of downsampling layers in content encoder 40 | n_res: 4 # number of residual blocks in content encoder/decoder 41 | non_local: 0 # number of non_local layer 42 | pad_type: reflect # padding type [zero/reflect] 43 | tanh: false # use tanh or not at the last layer 44 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 45 | id_w: 1.0 # the weight of ID loss 46 | image_display_iter: 5000 # How often do you want to display output images during training 47 | image_save_iter: 5000 # How often do you want to save output images during training 48 | input_dim_a: 1 # We use the gray-scale input, so the input dim is 1 49 | input_dim_b: 1 # We use the gray-scale input, so the input dim is 1 50 | log_iter: 1 # How often do you want to log the training stats 51 | lr2: 0.002 # initial appearance encoder learning rate 52 | lr_d: 0.0001 # initial discriminator learning rate 53 | lr_g: 0.0001 # initial generator (except appearance encoder) learning rate 54 | lr_policy: multistep # learning rate scheduler [multistep|constant|step] 55 | max_cyc_w: 2 # the maximum weight for cycle loss 56 | max_iter: 100000 # When you end the training 57 | max_teacher_w: 2 # the maximum weight for prime loss (teacher KL loss) 58 | max_w: 1 # the maximum weight for feature reconstruction losses 59 | new_size: 128 # the resized size 60 | norm_id: false # Do we normalize the appearance code 61 | num_workers: 8 # nworks to load the data 62 | pid_w: 1.0 # positive ID loss 63 | pool: max # pooling layer for the appearance encoder 64 | recon_s_w: 0 # the initial weight for structure code reconstruction 65 | recon_f_w: 0 # the initial weight for appearance code reconstruction 66 | recon_id_w: 0.5 # the initial weight for ID reconstruction 67 | recon_x_cyc_w: 0 # the initial weight for cycle reconstruction 68 | recon_x_w: 5 # the initial weight for self-reconstruction 69 | recon_xp_w: 5 # the initial weight for self-identity reconstruction 70 | single: gray # make input to gray-scale 71 | snapshot_save_iter: 10000 # How often to save the checkpoint 72 | sqrt: false # whether use square loss. 73 | step_size: 60000 # when to decay the learning rate 74 | teacher: best # teacher model name. For DukeMTMC, you may set `best-duke` 75 | teacher_w: 0 # the initial weight for prime loss (teacher KL loss) 76 | teacher_style: 0 # select teacher style.[0-4] # 0: Our smooth dynamic label# 1: Pseudo label, hard dynamic label# 2: Conditional label, hard static label # 3: LSRO, static smooth label# 4: Dynamic Soft Two-label 77 | train_bn: true # whether we train the bn for the generated image. 78 | use_decoder_again: true # whether we train the decoder on the generatd image. 79 | use_encoder_again: 0.5 # the probability we train the structure encoder on the generatd image. 80 | vgg_w: 0 # We do not use vgg as one kind of inception loss. 81 | warm_iter: 30000 # when to start warm up the losses (fine-grained/feature reconstruction losses). 82 | warm_scale: 0.0005 # how fast to warm up 83 | warm_teacher_iter: 30000 # when to start warm up the prime loss 84 | weight_decay: 0.0005 # weight decay 85 | -------------------------------------------------------------------------------- /configs/latest.yaml: -------------------------------------------------------------------------------- 1 | # Copyright (C) 2017 NVIDIA Corporation. All rights reserved. 2 | # Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 3 | 4 | apex: false # Set True to use float16. 5 | B_w: 0.2 # The loss weight of fine-grained loss, which is named as `alpha` in the paper. 6 | ID_class: 751 # The number of ID classes in the dataset. For example, 751 for Market, 702 for DukeMTMC 7 | ID_stride: 1 # Stride in Appearance encoder 8 | ID_style: AB # For time being, we only support AB. In the future, we will support PCB. 9 | batch_size: 8 # BatchSize 10 | beta1: 0 # Adam hyperparameter 11 | beta2: 0.999 # Adam hyperparameter 12 | crop_image_height: 256 # Input height 13 | crop_image_width: 128 # Input width 14 | data_root: ../Market/pytorch/ # Dataset Root 15 | dis: 16 | LAMBDA: 0.01 # the hyperparameter for the regularization term 17 | activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh] 18 | dim: 32 # number of filters in the bottommost layer 19 | gan_type: lsgan # GAN loss [lsgan/nsgan] 20 | n_layer: 2 # number of layers in D 21 | n_res: 4 # number of layers in D 22 | non_local: 0 # number of non_local layers 23 | norm: none # normalization layer [none/bn/in/ln] 24 | num_scales: 3 # number of scales 25 | pad_type: reflect # padding type [zero/reflect] 26 | display_size: 16 # How much display images 27 | erasing_p: 0.5 # Random erasing probability [0-1] 28 | gamma: 0.1 # Learning Rate Decay (except appearance encoder) 29 | gamma2: 0.1 # Learning Rate Decay (for appearance encoder) 30 | gan_w: 1 # the weight of gan loss 31 | gen: 32 | activ: lrelu # activation function style [relu/lrelu/prelu/selu/tanh] 33 | dec: basic # [basic/parallel/series] 34 | dim: 16 # number of filters in the bottommost layer 35 | dropout: 0 # use dropout in the generator 36 | id_dim: 2048 # length of appearance code 37 | mlp_dim: 512 # number of filters in MLP 38 | mlp_norm: none # norm in mlp [none/bn/in/ln] 39 | n_downsample: 2 # number of downsampling layers in content encoder 40 | n_res: 4 # number of residual blocks in content encoder/decoder 41 | non_local: 0 # number of non_local layer 42 | pad_type: reflect # padding type [zero/reflect] 43 | tanh: false # use tanh or not at the last layer 44 | init: kaiming # initialization [gaussian/kaiming/xavier/orthogonal] 45 | id_w: 1.0 # the weight of ID loss 46 | image_display_iter: 5000 # How often do you want to display output images during training 47 | image_save_iter: 5000 # How often do you want to save output images during training 48 | input_dim_a: 1 # We use the gray-scale input, so the input dim is 1 49 | input_dim_b: 1 # We use the gray-scale input, so the input dim is 1 50 | log_iter: 1 # How often do you want to log the training stats 51 | lr2: 0.002 # initial appearance encoder learning rate 52 | lr_d: 0.0001 # initial discriminator learning rate 53 | lr_g: 0.0001 # initial generator (except appearance encoder) learning rate 54 | lr_policy: multistep # learning rate scheduler [multistep|constant|step] 55 | max_cyc_w: 2 # the maximum weight for cycle loss 56 | max_iter: 100000 # When you end the training 57 | max_teacher_w: 2 # the maximum weight for prime loss (teacher KL loss) 58 | max_w: 1 # the maximum weight for feature reconstruction losses 59 | new_size: 128 # the resized size 60 | norm_id: false # Do we normalize the appearance code 61 | num_workers: 8 # nworks to load the data 62 | pid_w: 1.0 # positive ID loss 63 | pool: max # pooling layer for the appearance encoder 64 | recon_s_w: 0 # the initial weight for structure code reconstruction 65 | recon_f_w: 0 # the initial weight for appearance code reconstruction 66 | recon_id_w: 0.5 # the initial weight for ID reconstruction 67 | recon_x_cyc_w: 0 # the initial weight for cycle reconstruction 68 | recon_x_w: 5 # the initial weight for self-reconstruction 69 | recon_xp_w: 5 # the initial weight for self-identity reconstruction 70 | single: gray # make input to gray-scale 71 | snapshot_save_iter: 10000 # How often to save the checkpoint 72 | sqrt: false # whether use square loss. 73 | step_size: 60000 # when to decay the learning rate 74 | teacher: best # teacher model name. For DukeMTMC, you may set `best-duke` 75 | teacher_w: 0 # the initial weight for prime loss (teacher KL loss) 76 | teacher_style: 0 # select teacher style.[0-4] # 0: Our smooth dynamic label# 1: Pseudo label, hard dynamic label# 2: Conditional label, hard static label # 3: LSRO, static smooth label# 4: Dynamic Soft Two-label 77 | train_bn: true # whether we train the bn for the generated image. 78 | use_decoder_again: true # whether we train the decoder on the generatd image. 79 | use_encoder_again: 0.5 # the probability we train the structure encoder on the generatd image. 80 | vgg_w: 0 # We do not use vgg as one kind of inception loss. 81 | warm_iter: 30000 # when to start warm up the losses (fine-grained/feature reconstruction losses). 82 | warm_scale: 0.0005 # how fast to warm up 83 | warm_teacher_iter: 30000 # when to start warm up the prime loss 84 | weight_decay: 0.0005 # weight decay 85 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | import torch.utils.data as data 6 | import os.path 7 | 8 | def default_loader(path): 9 | return Image.open(path).convert('RGB') 10 | 11 | 12 | def default_flist_reader(flist): 13 | """ 14 | flist format: impath label\nimpath label\n ...(same to caffe's filelist) 15 | """ 16 | imlist = [] 17 | with open(flist, 'r') as rf: 18 | for line in rf.readlines(): 19 | impath = line.strip() 20 | imlist.append(impath) 21 | 22 | return imlist 23 | 24 | 25 | class ImageFilelist(data.Dataset): 26 | def __init__(self, root, flist, transform=None, 27 | flist_reader=default_flist_reader, loader=default_loader): 28 | self.root = root 29 | self.imlist = flist_reader(flist) 30 | self.transform = transform 31 | self.loader = loader 32 | 33 | def __getitem__(self, index): 34 | impath = self.imlist[index] 35 | img = self.loader(os.path.join(self.root, impath)) 36 | if self.transform is not None: 37 | img = self.transform(img) 38 | 39 | return img 40 | 41 | def __len__(self): 42 | return len(self.imlist) 43 | 44 | 45 | class ImageLabelFilelist(data.Dataset): 46 | def __init__(self, root, flist, transform=None, 47 | flist_reader=default_flist_reader, loader=default_loader): 48 | self.root = root 49 | self.imlist = flist_reader(os.path.join(self.root, flist)) 50 | self.transform = transform 51 | self.loader = loader 52 | self.classes = sorted(list(set([path.split('/')[0] for path in self.imlist]))) 53 | self.class_to_idx = {self.classes[i]: i for i in range(len(self.classes))} 54 | self.imgs = [(impath, self.class_to_idx[impath.split('/')[0]]) for impath in self.imlist] 55 | 56 | def __getitem__(self, index): 57 | impath, label = self.imgs[index] 58 | img = self.loader(os.path.join(self.root, impath)) 59 | if self.transform is not None: 60 | img = self.transform(img) 61 | return img, label 62 | 63 | def __len__(self): 64 | return len(self.imgs) 65 | 66 | ############################################################################### 67 | # Code from 68 | # https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py 69 | # Modified the original code so that it also loads images from the current 70 | # directory as well as the subdirectories 71 | ############################################################################### 72 | 73 | import torch.utils.data as data 74 | 75 | from PIL import Image 76 | import os 77 | import os.path 78 | 79 | IMG_EXTENSIONS = [ 80 | '.jpg', '.JPG', '.jpeg', '.JPEG', 81 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 82 | ] 83 | 84 | 85 | def is_image_file(filename): 86 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 87 | 88 | 89 | def make_dataset(dir): 90 | images = [] 91 | assert os.path.isdir(dir), '%s is not a valid directory' % dir 92 | 93 | for root, _, fnames in sorted(os.walk(dir)): 94 | for fname in fnames: 95 | if is_image_file(fname): 96 | path = os.path.join(root, fname) 97 | images.append(path) 98 | 99 | return images 100 | 101 | 102 | class ImageFolder(data.Dataset): 103 | 104 | def __init__(self, root, transform=None, return_paths=False, 105 | loader=default_loader): 106 | imgs = sorted(make_dataset(root)) 107 | if len(imgs) == 0: 108 | raise(RuntimeError("Found 0 images in: " + root + "\n" 109 | "Supported image extensions are: " + 110 | ",".join(IMG_EXTENSIONS))) 111 | 112 | self.root = root 113 | self.imgs = imgs 114 | self.transform = transform 115 | self.return_paths = return_paths 116 | self.loader = loader 117 | 118 | def __getitem__(self, index): 119 | path = self.imgs[index] 120 | img = self.loader(path) 121 | if self.transform is not None: 122 | img = self.transform(img) 123 | if self.return_paths: 124 | return img, path 125 | else: 126 | return img 127 | 128 | def __len__(self): 129 | return len(self.imgs) 130 | -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /outputs/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /prepare-market.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import os 7 | from shutil import copyfile 8 | 9 | # You only need to change this line to your dataset download path 10 | download_path = '../Market' 11 | 12 | if not os.path.isdir(download_path): 13 | print('please change the download_path') 14 | 15 | save_path = download_path + '/pytorch' 16 | if not os.path.isdir(save_path): 17 | os.mkdir(save_path) 18 | #----------------------------------------- 19 | #query 20 | query_path = download_path + '/query' 21 | query_save_path = download_path + '/pytorch/query' 22 | if not os.path.isdir(query_save_path): 23 | os.mkdir(query_save_path) 24 | 25 | for root, dirs, files in os.walk(query_path, topdown=True): 26 | for name in files: 27 | if not name[-3:]=='jpg': 28 | continue 29 | ID = name.split('_') 30 | src_path = query_path + '/' + name 31 | dst_path = query_save_path + '/' + ID[0] 32 | if not os.path.isdir(dst_path): 33 | os.mkdir(dst_path) 34 | copyfile(src_path, dst_path + '/' + name) 35 | 36 | #----------------------------------------- 37 | #multi-query 38 | query_path = download_path + '/gt_bbox' 39 | # for dukemtmc-reid, we do not need multi-query 40 | if os.path.isdir(query_path): 41 | query_save_path = download_path + '/pytorch/multi-query' 42 | if not os.path.isdir(query_save_path): 43 | os.mkdir(query_save_path) 44 | 45 | for root, dirs, files in os.walk(query_path, topdown=True): 46 | for name in files: 47 | if not name[-3:]=='jpg': 48 | continue 49 | ID = name.split('_') 50 | src_path = query_path + '/' + name 51 | dst_path = query_save_path + '/' + ID[0] 52 | if not os.path.isdir(dst_path): 53 | os.mkdir(dst_path) 54 | copyfile(src_path, dst_path + '/' + name) 55 | 56 | #----------------------------------------- 57 | #gallery 58 | gallery_path = download_path + '/bounding_box_test' 59 | gallery_save_path = download_path + '/pytorch/gallery' 60 | if not os.path.isdir(gallery_save_path): 61 | os.mkdir(gallery_save_path) 62 | 63 | for root, dirs, files in os.walk(gallery_path, topdown=True): 64 | for name in files: 65 | if not name[-3:]=='jpg': 66 | continue 67 | ID = name.split('_') 68 | src_path = gallery_path + '/' + name 69 | dst_path = gallery_save_path + '/' + ID[0] 70 | if not os.path.isdir(dst_path): 71 | os.mkdir(dst_path) 72 | copyfile(src_path, dst_path + '/' + name) 73 | 74 | #--------------------------------------- 75 | #train_all 76 | train_path = download_path + '/bounding_box_train' 77 | train_save_path = download_path + '/pytorch/train_all' 78 | if not os.path.isdir(train_save_path): 79 | os.mkdir(train_save_path) 80 | 81 | for root, dirs, files in os.walk(train_path, topdown=True): 82 | for name in files: 83 | if not name[-3:]=='jpg': 84 | continue 85 | ID = name.split('_') 86 | src_path = train_path + '/' + name 87 | dst_path = train_save_path + '/' + ID[0] 88 | if not os.path.isdir(dst_path): 89 | os.mkdir(dst_path) 90 | copyfile(src_path, dst_path + '/' + name) 91 | 92 | 93 | #--------------------------------------- 94 | #train_val 95 | train_path = download_path + '/bounding_box_train' 96 | train_save_path = download_path + '/pytorch/train' 97 | val_save_path = download_path + '/pytorch/val' 98 | if not os.path.isdir(train_save_path): 99 | os.mkdir(train_save_path) 100 | os.mkdir(val_save_path) 101 | 102 | for root, dirs, files in os.walk(train_path, topdown=True): 103 | for name in files: 104 | if not name[-3:]=='jpg': 105 | continue 106 | ID = name.split('_') 107 | src_path = train_path + '/' + name 108 | dst_path = train_save_path + '/' + ID[0] 109 | if not os.path.isdir(dst_path): 110 | os.mkdir(dst_path) 111 | dst_path = val_save_path + '/' + ID[0] #first image is used as val image 112 | os.mkdir(dst_path) 113 | copyfile(src_path, dst_path + '/' + name) 114 | -------------------------------------------------------------------------------- /random_erasing.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from __future__ import absolute_import 7 | 8 | from torchvision.transforms import * 9 | 10 | import random 11 | import math 12 | 13 | class RandomErasing(object): 14 | """ Randomly selects a rectangle region in an image and erases its pixels. 15 | 'Random Erasing Data Augmentation' by Zhong et al. 16 | See https://arxiv.org/pdf/1708.04896.pdf 17 | Args: 18 | probability: The probability that the Random Erasing operation will be performed. 19 | sl: Minimum proportion of erased area against input image. 20 | sh: Maximum proportion of erased area against input image. 21 | r1: Minimum aspect ratio of erased area. 22 | mean: Erasing value. 23 | """ 24 | 25 | def __init__(self, probability = 0.5, sl = 0.02, sh = 0.4, r1 = 0.3, mean=[0.4914, 0.4822, 0.4465]): 26 | self.probability = probability 27 | self.mean = mean 28 | self.sl = sl 29 | self.sh = sh 30 | self.r1 = r1 31 | random.seed(7) 32 | 33 | def __call__(self, img): 34 | 35 | if random.uniform(0, 1) > self.probability: 36 | return img 37 | 38 | for attempt in range(100): 39 | area = img.size()[1] * img.size()[2] 40 | 41 | target_area = random.uniform(self.sl, self.sh) * area 42 | aspect_ratio = random.uniform(self.r1, 1/self.r1) 43 | 44 | h = int(round(math.sqrt(target_area * aspect_ratio))) 45 | w = int(round(math.sqrt(target_area / aspect_ratio))) 46 | 47 | if w < img.size()[2] and h < img.size()[1]: 48 | x1 = random.randint(0, img.size()[1] - h) 49 | y1 = random.randint(0, img.size()[2] - w) 50 | if img.size()[0] == 3: 51 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 52 | img[1, x1:x1+h, y1:y1+w] = self.mean[1] 53 | img[2, x1:x1+h, y1:y1+w] = self.mean[2] 54 | else: 55 | img[0, x1:x1+h, y1:y1+w] = self.mean[0] 56 | return img.detach() 57 | 58 | return img.detach() 59 | -------------------------------------------------------------------------------- /reIDfolder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from torchvision import datasets 7 | import os 8 | import numpy as np 9 | import random 10 | 11 | class ReIDFolder(datasets.ImageFolder): 12 | 13 | def __init__(self, root, transform): 14 | super(ReIDFolder, self).__init__(root, transform) 15 | targets = np.asarray([s[1] for s in self.samples]) 16 | self.targets = targets 17 | self.img_num = len(self.samples) 18 | print(self.img_num) 19 | 20 | def _get_cam_id(self, path): 21 | camera_id = [] 22 | filename = os.path.basename(path) 23 | camera_id = filename.split('c')[1][0] 24 | return int(camera_id)-1 25 | 26 | def _get_pos_sample(self, target, index, path): 27 | pos_index = np.argwhere(self.targets == target) 28 | pos_index = pos_index.flatten() 29 | pos_index = np.setdiff1d(pos_index, index) 30 | if len(pos_index)==0: # in the query set, only one sample 31 | return path 32 | else: 33 | rand = random.randint(0,len(pos_index)-1) 34 | return self.samples[pos_index[rand]][0] 35 | 36 | def _get_neg_sample(self, target): 37 | neg_index = np.argwhere(self.targets != target) 38 | neg_index = neg_index.flatten() 39 | rand = random.randint(0,len(neg_index)-1) 40 | return self.samples[neg_index[rand]] 41 | 42 | def __getitem__(self, index): 43 | path, target = self.samples[index] 44 | sample = self.loader(path) 45 | 46 | pos_path = self._get_pos_sample(target, index, path) 47 | pos = self.loader(pos_path) 48 | 49 | if self.transform is not None: 50 | sample = self.transform(sample) 51 | pos = self.transform(pos) 52 | 53 | if self.target_transform is not None: 54 | target = self.target_transform(target) 55 | 56 | return sample, target, pos 57 | 58 | -------------------------------------------------------------------------------- /reIDmodel.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | from torch.nn import init 9 | from torchvision import models 10 | 11 | ###################################################################### 12 | def weights_init_kaiming(m): 13 | classname = m.__class__.__name__ 14 | if classname.find('Conv') != -1: 15 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 16 | elif classname.find('Linear') != -1: 17 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_out') 18 | init.constant_(m.bias.data, 0.0) 19 | elif classname.find('InstanceNorm1d') != -1: 20 | init.normal_(m.weight.data, 1.0, 0.02) 21 | init.constant_(m.bias.data, 0.0) 22 | 23 | def weights_init_classifier(m): 24 | classname = m.__class__.__name__ 25 | if classname.find('Linear') != -1: 26 | init.normal_(m.weight.data, std=0.001) 27 | init.constant_(m.bias.data, 0.0) 28 | 29 | def fix_bn(m): 30 | classname = m.__class__.__name__ 31 | if classname.find('BatchNorm') != -1: 32 | m.eval() 33 | 34 | # Defines the new fc layer and classification layer 35 | # |--Linear--|--bn--|--relu--|--Linear--| 36 | class ClassBlock(nn.Module): 37 | def __init__(self, input_dim, class_num, droprate=0.5, relu=False, num_bottleneck=512): 38 | super(ClassBlock, self).__init__() 39 | add_block = [] 40 | add_block += [nn.Linear(input_dim, num_bottleneck)] 41 | #num_bottleneck = input_dim # We remove the input_dim 42 | add_block += [nn.BatchNorm1d(num_bottleneck, affine=True)] 43 | if relu: 44 | add_block += [nn.LeakyReLU(0.1)] 45 | if droprate>0: 46 | add_block += [nn.Dropout(p=droprate)] 47 | add_block = nn.Sequential(*add_block) 48 | add_block.apply(weights_init_kaiming) 49 | 50 | classifier = [] 51 | classifier += [nn.Linear(num_bottleneck, class_num)] 52 | classifier = nn.Sequential(*classifier) 53 | classifier.apply(weights_init_classifier) 54 | 55 | self.add_block = add_block 56 | self.classifier = classifier 57 | def forward(self, x): 58 | x = self.add_block(x) 59 | x = self.classifier(x) 60 | return x 61 | 62 | # Define the ResNet50-based Model 63 | class ft_net(nn.Module): 64 | 65 | def __init__(self, class_num, norm=False, pool='avg', stride=2): 66 | super(ft_net, self).__init__() 67 | if norm: 68 | self.norm = True 69 | else: 70 | self.norm = False 71 | model_ft = models.resnet50(pretrained=True) 72 | # avg pooling to global pooling 73 | self.part = 4 74 | if pool=='max': 75 | model_ft.partpool = nn.AdaptiveMaxPool2d((self.part,1)) 76 | model_ft.avgpool = nn.AdaptiveMaxPool2d((1,1)) 77 | else: 78 | model_ft.partpool = nn.AdaptiveAvgPool2d((self.part,1)) 79 | model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1)) 80 | # remove the final downsample 81 | if stride == 1: 82 | model_ft.layer4[0].downsample[0].stride = (1,1) 83 | model_ft.layer4[0].conv2.stride = (1,1) 84 | 85 | self.model = model_ft 86 | self.classifier = ClassBlock(2048, class_num) 87 | 88 | def forward(self, x): 89 | x = self.model.conv1(x) 90 | x = self.model.bn1(x) 91 | x = self.model.relu(x) 92 | x = self.model.maxpool(x) 93 | x = self.model.layer1(x) 94 | x = self.model.layer2(x) # -> 512 32*16 95 | x = self.model.layer3(x) 96 | x = self.model.layer4(x) 97 | f = self.model.partpool(x) # 8 * 2048 4*1 98 | x = self.model.avgpool(x) # 8 * 2048 1*1 99 | 100 | x = x.view(x.size(0),x.size(1)) 101 | f = f.view(f.size(0),f.size(1)*self.part) 102 | if self.norm: 103 | fnorm = torch.norm(f, p=2, dim=1, keepdim=True) + 1e-8 104 | f = f.div(fnorm.expand_as(f)) 105 | x = self.classifier(x) 106 | return f, x 107 | 108 | # Define the AB Model 109 | class ft_netAB(nn.Module): 110 | 111 | def __init__(self, class_num, norm=False, stride=2, droprate=0.5, pool='avg'): 112 | super(ft_netAB, self).__init__() 113 | model_ft = models.resnet50(pretrained=True) 114 | self.part = 4 115 | if pool=='max': 116 | model_ft.partpool = nn.AdaptiveMaxPool2d((self.part,1)) 117 | model_ft.avgpool = nn.AdaptiveMaxPool2d((1,1)) 118 | else: 119 | model_ft.partpool = nn.AdaptiveAvgPool2d((self.part,1)) 120 | model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1)) 121 | 122 | self.model = model_ft 123 | 124 | if stride == 1: 125 | self.model.layer4[0].downsample[0].stride = (1,1) 126 | self.model.layer4[0].conv2.stride = (1,1) 127 | 128 | self.classifier1 = ClassBlock(2048, class_num, 0.5) 129 | self.classifier2 = ClassBlock(2048, class_num, 0.75) 130 | 131 | def forward(self, x): 132 | x = self.model.conv1(x) 133 | x = self.model.bn1(x) 134 | x = self.model.relu(x) 135 | x = self.model.maxpool(x) 136 | x = self.model.layer1(x) 137 | x = self.model.layer2(x) 138 | x = self.model.layer3(x) 139 | x = self.model.layer4(x) 140 | f = self.model.partpool(x) 141 | f = f.view(f.size(0),f.size(1)*self.part) 142 | f = f.detach() # no gradient 143 | x = self.model.avgpool(x) 144 | x = x.view(x.size(0), x.size(1)) 145 | x1 = self.classifier1(x) 146 | x2 = self.classifier2(x) 147 | x=[] 148 | x.append(x1) 149 | x.append(x2) 150 | return f, x 151 | 152 | # Define the DenseNet121-based Model 153 | class ft_net_dense(nn.Module): 154 | 155 | def __init__(self, class_num ): 156 | super().__init__() 157 | model_ft = models.densenet121(pretrained=True) 158 | model_ft.features.avgpool = nn.AdaptiveAvgPool2d((1,1)) 159 | model_ft.fc = nn.Sequential() 160 | self.model = model_ft 161 | # For DenseNet, the feature dim is 1024 162 | self.classifier = ClassBlock(1024, class_num) 163 | 164 | def forward(self, x): 165 | x = self.model.features(x) 166 | x = torch.squeeze(x) 167 | x = self.classifier(x) 168 | return x 169 | 170 | # Define the ResNet50-based Model (Middle-Concat) 171 | # In the spirit of "The Devil is in the Middle: Exploiting Mid-level Representations for Cross-Domain Instance Matching." Yu, Qian, et al. arXiv:1711.08106 (2017). 172 | class ft_net_middle(nn.Module): 173 | 174 | def __init__(self, class_num ): 175 | super(ft_net_middle, self).__init__() 176 | model_ft = models.resnet50(pretrained=True) 177 | # avg pooling to global pooling 178 | model_ft.avgpool = nn.AdaptiveAvgPool2d((1,1)) 179 | self.model = model_ft 180 | self.classifier = ClassBlock(2048+1024, class_num) 181 | 182 | def forward(self, x): 183 | x = self.model.conv1(x) 184 | x = self.model.bn1(x) 185 | x = self.model.relu(x) 186 | x = self.model.maxpool(x) 187 | x = self.model.layer1(x) 188 | x = self.model.layer2(x) 189 | x = self.model.layer3(x) 190 | # x0 n*1024*1*1 191 | x0 = self.model.avgpool(x) 192 | x = self.model.layer4(x) 193 | # x1 n*2048*1*1 194 | x1 = self.model.avgpool(x) 195 | x = torch.cat((x0,x1),1) 196 | x = torch.squeeze(x) 197 | x = self.classifier(x) 198 | return x 199 | 200 | # Part Model proposed in Yifan Sun etal. (2018) 201 | class PCB(nn.Module): 202 | def __init__(self, class_num ): 203 | super(PCB, self).__init__() 204 | 205 | self.part = 4 # We cut the pool5 to 4 parts 206 | model_ft = models.resnet50(pretrained=True) 207 | self.model = model_ft 208 | self.avgpool = nn.AdaptiveAvgPool2d((self.part,1)) 209 | self.dropout = nn.Dropout(p=0.5) 210 | # remove the final downsample 211 | self.model.layer4[0].downsample[0].stride = (1,1) 212 | self.model.layer4[0].conv2.stride = (1,1) 213 | self.softmax = nn.Softmax(dim=1) 214 | # define 4 classifiers 215 | for i in range(self.part): 216 | name = 'classifier'+str(i) 217 | setattr(self, name, ClassBlock(2048, class_num, True, False, 256)) 218 | 219 | def forward(self, x): 220 | x = self.model.conv1(x) 221 | x = self.model.bn1(x) 222 | x = self.model.relu(x) 223 | x = self.model.maxpool(x) 224 | 225 | x = self.model.layer1(x) 226 | x = self.model.layer2(x) 227 | x = self.model.layer3(x) 228 | x = self.model.layer4(x) 229 | x = self.avgpool(x) 230 | f = x 231 | f = f.view(f.size(0),f.size(1)*self.part) 232 | x = self.dropout(x) 233 | part = {} 234 | predict = {} 235 | # get part feature batchsize*2048*4 236 | for i in range(self.part): 237 | part[i] = x[:,:,i].contiguous() 238 | part[i] = part[i].view(x.size(0), x.size(1)) 239 | name = 'classifier'+str(i) 240 | c = getattr(self,name) 241 | predict[i] = c(part[i]) 242 | 243 | y=[] 244 | for i in range(self.part): 245 | y.append(predict[i]) 246 | 247 | return f, y 248 | 249 | class PCB_test(nn.Module): 250 | def __init__(self,model): 251 | super(PCB_test,self).__init__() 252 | self.part = 6 253 | self.model = model.model 254 | self.avgpool = nn.AdaptiveAvgPool2d((self.part,1)) 255 | # remove the final downsample 256 | self.model.layer3[0].downsample[0].stride = (1,1) 257 | self.model.layer3[0].conv2.stride = (1,1) 258 | 259 | self.model.layer4[0].downsample[0].stride = (1,1) 260 | self.model.layer4[0].conv2.stride = (1,1) 261 | 262 | def forward(self, x): 263 | x = self.model.conv1(x) 264 | x = self.model.bn1(x) 265 | x = self.model.relu(x) 266 | x = self.model.maxpool(x) 267 | 268 | x = self.model.layer1(x) 269 | x = self.model.layer2(x) 270 | x = self.model.layer3(x) 271 | x = self.model.layer4(x) 272 | x = self.avgpool(x) 273 | y = x.view(x.size(0),x.size(1),x.size(2)) 274 | return y 275 | 276 | 277 | -------------------------------------------------------------------------------- /reid_eval/README.md: -------------------------------------------------------------------------------- 1 | ## Evaluation 2 | 3 | - For Market-1501 4 | ```bash 5 | python test_2label.py --name E0.5new_reid0.5_w30000 --which_epoch 100000 --multi 6 | ``` 7 | The result is `Rank@1:0.9477 Rank@5:0.9798 Rank@10:0.9878 mAP:0.8609`. 8 | `--name` model name 9 | 10 | `--which_epoch` selects the i-th model 11 | 12 | `--multi` extracts and evaluates the multiply query. The result is `multi Rank@1:0.9608 Rank@5:0.9860 Rank@10:0.9923 mAP:0.9044`. 13 | -------------------------------------------------------------------------------- /reid_eval/evaluate_gpu.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 4 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 5 | """ 6 | 7 | import scipy.io 8 | import torch 9 | import numpy as np 10 | import os 11 | import matplotlib 12 | matplotlib.use('agg') 13 | import matplotlib.pyplot as plt 14 | ####################################################################### 15 | # Evaluate 16 | 17 | def evaluate(qf,ql,qc,gf,gl,gc): 18 | query = qf.view(-1,1) 19 | score = torch.mm(gf,query) 20 | score = score.squeeze(1).cpu() 21 | score = score.numpy() 22 | # predict index 23 | index = np.argsort(score) #from small to large 24 | index = index[::-1] 25 | # good index 26 | query_index = np.argwhere(gl==ql) 27 | #same camera 28 | camera_index = np.argwhere(gc==qc) 29 | 30 | good_index = np.setdiff1d(query_index, camera_index, assume_unique=True) 31 | junk_index1 = np.argwhere(gl==-1) 32 | junk_index2 = np.intersect1d(query_index, camera_index) 33 | junk_index = np.append(junk_index2, junk_index1) #.flatten()) 34 | 35 | CMC_tmp = compute_mAP(index, qc, good_index, junk_index) 36 | return CMC_tmp 37 | 38 | 39 | def compute_mAP(index, qc, good_index, junk_index): 40 | ap = 0 41 | cmc = torch.IntTensor(len(index)).zero_() 42 | if good_index.size==0: # if empty 43 | cmc[0] = -1 44 | return ap,cmc 45 | 46 | # remove junk_index 47 | ranked_camera = gallery_cam[index] 48 | mask = np.in1d(index, junk_index, invert=True) 49 | mask2 = np.in1d(index, np.append(good_index,junk_index), invert=True) 50 | index = index[mask] 51 | ranked_camera = ranked_camera[mask] 52 | 53 | # find good_index index 54 | ngood = len(good_index) 55 | mask = np.in1d(index, good_index) 56 | rows_good = np.argwhere(mask==True) 57 | rows_good = rows_good.flatten() 58 | 59 | cmc[rows_good[0]:] = 1 60 | for i in range(ngood): 61 | d_recall = 1.0/ngood 62 | precision = (i+1)*1.0/(rows_good[i]+1) 63 | if rows_good[i]!=0: 64 | old_precision = i*1.0/rows_good[i] 65 | else: 66 | old_precision=1.0 67 | ap = ap + d_recall*(old_precision + precision)/2 68 | 69 | return ap, cmc 70 | 71 | ###################################################################### 72 | result = scipy.io.loadmat('pytorch_result.mat') 73 | query_feature = torch.FloatTensor(result['query_f']) 74 | query_cam = result['query_cam'][0] 75 | query_label = result['query_label'][0] 76 | gallery_feature = torch.FloatTensor(result['gallery_f']) 77 | gallery_cam = result['gallery_cam'][0] 78 | gallery_label = result['gallery_label'][0] 79 | 80 | multi = os.path.isfile('multi_query.mat') 81 | 82 | if multi: 83 | m_result = scipy.io.loadmat('multi_query.mat') 84 | mquery_feature = torch.FloatTensor(m_result['mquery_f']) 85 | mquery_cam = m_result['mquery_cam'][0] 86 | mquery_label = m_result['mquery_label'][0] 87 | mquery_feature = mquery_feature.cuda() 88 | 89 | query_feature = query_feature.cuda() 90 | gallery_feature = gallery_feature.cuda() 91 | 92 | print(query_feature.shape) 93 | alpha = [0, 0.5, -1] 94 | for j in range(len(alpha)): 95 | CMC = torch.IntTensor(len(gallery_label)).zero_() 96 | ap = 0.0 97 | for i in range(len(query_label)): 98 | qf = query_feature[i].clone() 99 | if alpha[j] == -1: 100 | qf[0:512] *= 0 101 | else: 102 | qf[512:1024] *= alpha[j] 103 | ap_tmp, CMC_tmp = evaluate(qf,query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam) 104 | if CMC_tmp[0]==-1: 105 | continue 106 | CMC = CMC + CMC_tmp 107 | ap += ap_tmp 108 | 109 | CMC = CMC.float() 110 | CMC = CMC/len(query_label) #average CMC 111 | print('Alpha:%.2f Rank@1:%.4f Rank@5:%.4f Rank@10:%.4f mAP:%.4f'%(alpha[j], CMC[0],CMC[4],CMC[9],ap/len(query_label))) 112 | 113 | # multiple-query 114 | CMC = torch.IntTensor(len(gallery_label)).zero_() 115 | ap = 0.0 116 | if multi: 117 | malpha = 0.5 ###### 118 | for i in range(len(query_label)): 119 | mquery_index1 = np.argwhere(mquery_label==query_label[i]) 120 | mquery_index2 = np.argwhere(mquery_cam==query_cam[i]) 121 | mquery_index = np.intersect1d(mquery_index1, mquery_index2) 122 | mq = torch.mean(mquery_feature[mquery_index,:], dim=0) 123 | mq[512:1024] *= malpha 124 | ap_tmp, CMC_tmp = evaluate(mq,query_label[i],query_cam[i],gallery_feature,gallery_label,gallery_cam) 125 | if CMC_tmp[0]==-1: 126 | continue 127 | CMC = CMC + CMC_tmp 128 | ap += ap_tmp 129 | CMC = CMC.float() 130 | CMC = CMC/len(query_label) #average CMC 131 | print('multi Rank@1:%.4f Rank@5:%.4f Rank@10:%.4f mAP:%.4f'%(CMC[0],CMC[4],CMC[9],ap/len(query_label))) 132 | -------------------------------------------------------------------------------- /reid_eval/test_2label.py: -------------------------------------------------------------------------------- 1 | 2 | """ 3 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 4 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 5 | """ 6 | 7 | from __future__ import print_function, division 8 | 9 | import sys 10 | sys.path.append('..') 11 | import argparse 12 | import torch 13 | import torch.nn as nn 14 | import torch.optim as optim 15 | from torch.optim import lr_scheduler 16 | from torch.autograd import Variable 17 | import numpy as np 18 | import torchvision 19 | from torchvision import datasets, models, transforms 20 | import time 21 | import os 22 | import scipy.io 23 | import yaml 24 | from reIDmodel import ft_net, ft_netAB, ft_net_dense, PCB, PCB_test 25 | 26 | ###################################################################### 27 | # Options 28 | # -------- 29 | parser = argparse.ArgumentParser(description='Training') 30 | parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2') 31 | parser.add_argument('--which_epoch',default=90000, type=int, help='80000') 32 | parser.add_argument('--test_dir',default='../../Market/pytorch',type=str, help='./test_data') 33 | parser.add_argument('--name', default='test', type=str, help='save model path') 34 | parser.add_argument('--batchsize', default=80, type=int, help='batchsize') 35 | parser.add_argument('--use_dense', action='store_true', help='use densenet121' ) 36 | parser.add_argument('--PCB', action='store_true', help='use PCB' ) 37 | parser.add_argument('--multi', action='store_true', help='use multiple query' ) 38 | 39 | opt = parser.parse_args() 40 | 41 | str_ids = opt.gpu_ids.split(',') 42 | which_epoch = opt.which_epoch 43 | name = opt.name 44 | test_dir = opt.test_dir 45 | 46 | gpu_ids = [] 47 | for str_id in str_ids: 48 | id = int(str_id) 49 | if id >=0: 50 | gpu_ids.append(id) 51 | 52 | # set gpu ids 53 | if len(gpu_ids)>0: 54 | torch.cuda.set_device(gpu_ids[0]) 55 | 56 | ###################################################################### 57 | # Load Data 58 | # --------- 59 | # 60 | # We will use torchvision and torch.utils.data packages for loading the 61 | # data. 62 | # 63 | data_transforms = transforms.Compose([ 64 | transforms.Resize((256,128), interpolation=3), 65 | transforms.ToTensor(), 66 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 67 | ]) 68 | 69 | if opt.PCB: 70 | data_transforms = transforms.Compose([ 71 | transforms.Resize((384,192), interpolation=3), 72 | transforms.ToTensor(), 73 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 74 | ]) 75 | 76 | 77 | data_dir = test_dir 78 | image_datasets = {x: datasets.ImageFolder( os.path.join(data_dir,x) ,data_transforms) for x in ['gallery','query','multi-query']} 79 | dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=opt.batchsize, 80 | shuffle=False, num_workers=16) for x in ['gallery','query','multi-query']} 81 | 82 | class_names = image_datasets['query'].classes 83 | use_gpu = torch.cuda.is_available() 84 | 85 | ###################################################################### 86 | # Load model 87 | #--------------------------- 88 | def load_network(network): 89 | save_path = os.path.join('../outputs',name,'checkpoints/id_%08d.pt'%opt.which_epoch) 90 | state_dict = torch.load(save_path) 91 | network.load_state_dict(state_dict['a'], strict=False) 92 | return network 93 | 94 | 95 | ###################################################################### 96 | # Extract feature 97 | # ---------------------- 98 | # 99 | # Extract feature from a trained model. 100 | # 101 | def fliplr(img): 102 | '''flip horizontal''' 103 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W 104 | img_flip = img.index_select(3,inv_idx) 105 | return img_flip 106 | 107 | def norm(f): 108 | f = f.squeeze() 109 | fnorm = torch.norm(f, p=2, dim=1, keepdim=True) 110 | f = f.div(fnorm.expand_as(f)) 111 | return f 112 | 113 | def extract_feature(model,dataloaders): 114 | features = torch.FloatTensor() 115 | count = 0 116 | for data in dataloaders: 117 | img, label = data 118 | n, c, h, w = img.size() 119 | count += n 120 | if opt.use_dense: 121 | ff = torch.FloatTensor(n,1024).zero_() 122 | else: 123 | ff = torch.FloatTensor(n,1024).zero_() 124 | if opt.PCB: 125 | ff = torch.FloatTensor(n,2048,6).zero_() # we have six parts 126 | for i in range(2): 127 | if(i==1): 128 | img = fliplr(img) 129 | input_img = Variable(img.cuda()) 130 | f, x = model(input_img) 131 | x[0] = norm(x[0]) 132 | x[1] = norm(x[1]) 133 | f = torch.cat((x[0],x[1]), dim=1) #use 512-dim feature 134 | f = f.data.cpu() 135 | ff = ff+f 136 | 137 | ff[:, 0:512] = norm(ff[:, 0:512]) 138 | ff[:, 512:1024] = norm(ff[:, 512:1024]) 139 | 140 | # norm feature 141 | if opt.PCB: 142 | # feature size (n,2048,6) 143 | # 1. To treat every part equally, I calculate the norm for every 2048-dim part feature. 144 | # 2. To keep the cosine score==1, sqrt(6) is added to norm the whole feature (2048*6). 145 | fnorm = torch.norm(ff, p=2, dim=1, keepdim=True) * np.sqrt(6) 146 | ff = ff.div(fnorm.expand_as(ff)) 147 | ff = ff.view(ff.size(0), -1) 148 | 149 | features = torch.cat((features,ff), 0) 150 | return features 151 | 152 | def get_id(img_path): 153 | camera_id = [] 154 | labels = [] 155 | for path, v in img_path: 156 | #filename = path.split('/')[-1] 157 | filename = os.path.basename(path) 158 | label = filename[0:4] 159 | camera = filename.split('c')[1] 160 | if label[0:2]=='-1': 161 | labels.append(-1) 162 | else: 163 | labels.append(int(label)) 164 | camera_id.append(int(camera[0])) 165 | return camera_id, labels 166 | 167 | gallery_path = image_datasets['gallery'].imgs 168 | query_path = image_datasets['query'].imgs 169 | mquery_path = image_datasets['multi-query'].imgs 170 | 171 | gallery_cam,gallery_label = get_id(gallery_path) 172 | query_cam,query_label = get_id(query_path) 173 | mquery_cam,mquery_label = get_id(mquery_path) 174 | 175 | ###################################################################### 176 | # Load Collected data Trained model 177 | print('-------test-----------') 178 | 179 | ###load config### 180 | config_path = os.path.join('../outputs',name,'config.yaml') 181 | with open(config_path, 'r') as stream: 182 | config = yaml.safe_load(stream) 183 | 184 | model_structure = ft_netAB(config['ID_class'], norm=config['norm_id'], stride=config['ID_stride'], pool=config['pool']) 185 | 186 | if opt.PCB: 187 | model_structure = PCB(config['ID_class']) 188 | 189 | model = load_network(model_structure) 190 | 191 | # Remove the final fc layer and classifier layer 192 | model.model.fc = nn.Sequential() 193 | model.classifier1.classifier = nn.Sequential() 194 | model.classifier2.classifier = nn.Sequential() 195 | 196 | # Change to test mode 197 | model = model.eval() 198 | if use_gpu: 199 | model = model.cuda() 200 | 201 | # Extract feature 202 | since = time.time() 203 | with torch.no_grad(): 204 | gallery_feature = extract_feature(model,dataloaders['gallery']) 205 | query_feature = extract_feature(model,dataloaders['query']) 206 | time_elapsed = time.time() - since 207 | print('Extract features complete in {:.0f}m {:.0f}s'.format( 208 | time_elapsed // 60, time_elapsed % 60)) 209 | if opt.multi: 210 | mquery_feature = extract_feature(model,dataloaders['multi-query']) 211 | 212 | # Save to Matlab for check 213 | result = {'gallery_f':gallery_feature.numpy(),'gallery_label':gallery_label,'gallery_cam':gallery_cam,'query_f':query_feature.numpy(),'query_label':query_label,'query_cam':query_cam} 214 | scipy.io.savemat('pytorch_result.mat',result) 215 | if opt.multi: 216 | result = {'mquery_f':mquery_feature.numpy(),'mquery_label':mquery_label,'mquery_cam':mquery_cam} 217 | scipy.io.savemat('multi_query.mat',result) 218 | 219 | os.system('python evaluate_gpu.py') 220 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from utils import get_all_data_loaders, prepare_sub_folder, write_loss, get_config, write_2images, Timer 6 | import argparse 7 | from trainer import DGNet_Trainer 8 | import torch.backends.cudnn as cudnn 9 | import torch 10 | import numpy.random as random 11 | try: 12 | from itertools import izip as zip 13 | except ImportError: # will be 3.x series 14 | pass 15 | import os 16 | import sys 17 | import tensorboardX 18 | import shutil 19 | 20 | parser = argparse.ArgumentParser() 21 | parser.add_argument('--config', type=str, default='configs/latest.yaml', help='Path to the config file.') 22 | parser.add_argument('--output_path', type=str, default='.', help="outputs path") 23 | parser.add_argument('--name', type=str, default='latest_ablation', help="outputs path") 24 | parser.add_argument("--resume", action="store_true") 25 | parser.add_argument('--trainer', type=str, default='DGNet', help="DGNet") 26 | parser.add_argument('--gpu_ids',default='0', type=str,help='gpu_ids: e.g. 0 0,1,2 0,2') 27 | opts = parser.parse_args() 28 | 29 | str_ids = opts.gpu_ids.split(',') 30 | gpu_ids = [] 31 | for str_id in str_ids: 32 | gpu_ids.append(int(str_id)) 33 | num_gpu = len(gpu_ids) 34 | 35 | cudnn.benchmark = True 36 | 37 | # Load experiment setting 38 | if opts.resume: 39 | config = get_config('./outputs/'+opts.name+'/config.yaml') 40 | else: 41 | config = get_config(opts.config) 42 | max_iter = config['max_iter'] 43 | display_size = config['display_size'] 44 | config['vgg_model_path'] = opts.output_path 45 | 46 | # Setup model and data loader 47 | if opts.trainer == 'DGNet': 48 | trainer = DGNet_Trainer(config, gpu_ids) 49 | trainer.cuda() 50 | 51 | random.seed(7) #fix random result 52 | train_loader_a, train_loader_b, test_loader_a, test_loader_b = get_all_data_loaders(config) 53 | train_a_rand = random.permutation(train_loader_a.dataset.img_num)[0:display_size] 54 | train_b_rand = random.permutation(train_loader_b.dataset.img_num)[0:display_size] 55 | test_a_rand = random.permutation(test_loader_a.dataset.img_num)[0:display_size] 56 | test_b_rand = random.permutation(test_loader_b.dataset.img_num)[0:display_size] 57 | 58 | train_display_images_a = torch.stack([train_loader_a.dataset[i][0] for i in train_a_rand]).cuda() 59 | train_display_images_ap = torch.stack([train_loader_a.dataset[i][2] for i in train_a_rand]).cuda() 60 | train_display_images_b = torch.stack([train_loader_b.dataset[i][0] for i in train_b_rand]).cuda() 61 | train_display_images_bp = torch.stack([train_loader_b.dataset[i][2] for i in train_b_rand]).cuda() 62 | test_display_images_a = torch.stack([test_loader_a.dataset[i][0] for i in test_a_rand]).cuda() 63 | test_display_images_ap = torch.stack([test_loader_a.dataset[i][2] for i in test_a_rand]).cuda() 64 | test_display_images_b = torch.stack([test_loader_b.dataset[i][0] for i in test_b_rand]).cuda() 65 | test_display_images_bp = torch.stack([test_loader_b.dataset[i][2] for i in test_b_rand]).cuda() 66 | 67 | # Setup logger and output folders 68 | if not opts.resume: 69 | model_name = os.path.splitext(os.path.basename(opts.config))[0] 70 | train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", model_name)) 71 | output_directory = os.path.join(opts.output_path + "/outputs", model_name) 72 | checkpoint_directory, image_directory = prepare_sub_folder(output_directory) 73 | shutil.copyfile(opts.config, os.path.join(output_directory, 'config.yaml')) # copy config file to output folder 74 | shutil.copyfile('trainer.py', os.path.join(output_directory, 'trainer.py')) # copy file to output folder 75 | shutil.copyfile('reIDmodel.py', os.path.join(output_directory, 'reIDmodel.py')) # copy file to output folder 76 | shutil.copyfile('networks.py', os.path.join(output_directory, 'networks.py')) # copy file to output folder 77 | else: 78 | train_writer = tensorboardX.SummaryWriter(os.path.join(opts.output_path + "/logs", opts.name)) 79 | output_directory = os.path.join(opts.output_path + "/outputs", opts.name) 80 | checkpoint_directory, image_directory = prepare_sub_folder(output_directory) 81 | # Start training 82 | iterations = trainer.resume(checkpoint_directory, hyperparameters=config) if opts.resume else 0 83 | config['epoch_iteration'] = round( train_loader_a.dataset.img_num / config['batch_size'] ) 84 | print('Every epoch need %d iterations'%config['epoch_iteration']) 85 | nepoch = 0 86 | 87 | print('Note that dataloader may hang with too much nworkers.') 88 | 89 | if num_gpu>1: 90 | print('Now you are using %d gpus.'%num_gpu) 91 | trainer.dis_a = torch.nn.DataParallel(trainer.dis_a, gpu_ids) 92 | trainer.dis_b = trainer.dis_a 93 | trainer = torch.nn.DataParallel(trainer, gpu_ids) 94 | 95 | while True: 96 | for it, ((images_a,labels_a, pos_a), (images_b, labels_b, pos_b)) in enumerate(zip(train_loader_a, train_loader_b)): 97 | if num_gpu>1: 98 | trainer.module.update_learning_rate() 99 | else: 100 | trainer.update_learning_rate() 101 | images_a, images_b = images_a.cuda().detach(), images_b.cuda().detach() 102 | pos_a, pos_b = pos_a.cuda().detach(), pos_b.cuda().detach() 103 | labels_a, labels_b = labels_a.cuda().detach(), labels_b.cuda().detach() 104 | 105 | with Timer("Elapsed time in update: %f"): 106 | # Main training code 107 | x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p = \ 108 | trainer.forward(images_a, images_b, pos_a, pos_b) 109 | if num_gpu>1: 110 | trainer.module.dis_update(x_ab.clone(), x_ba.clone(), images_a, images_b, config, num_gpu) 111 | trainer.module.gen_update(x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, images_a, images_b, pos_a, pos_b, labels_a, labels_b, config, iterations, num_gpu) 112 | else: 113 | trainer.dis_update(x_ab.clone(), x_ba.clone(), images_a, images_b, config, num_gpu=1) 114 | trainer.gen_update(x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, images_a, images_b, pos_a, pos_b, labels_a, labels_b, config, iterations, num_gpu=1) 115 | 116 | torch.cuda.synchronize() 117 | 118 | # Dump training stats in log file 119 | if (iterations + 1) % config['log_iter'] == 0: 120 | print("\033[1m Epoch: %02d Iteration: %08d/%08d \033[0m" % (nepoch, iterations + 1, max_iter), end=" ") 121 | if num_gpu==1: 122 | write_loss(iterations, trainer, train_writer) 123 | else: 124 | write_loss(iterations, trainer.module, train_writer) 125 | 126 | # Write images 127 | if (iterations + 1) % config['image_save_iter'] == 0: 128 | with torch.no_grad(): 129 | if num_gpu>1: 130 | test_image_outputs = trainer.module.sample(test_display_images_a, test_display_images_b) 131 | else: 132 | test_image_outputs = trainer.sample(test_display_images_a, test_display_images_b) 133 | write_2images(test_image_outputs, display_size, image_directory, 'test_%08d' % (iterations + 1)) 134 | del test_image_outputs 135 | 136 | if (iterations + 1) % config['image_display_iter'] == 0: 137 | with torch.no_grad(): 138 | if num_gpu>1: 139 | image_outputs = trainer.module.sample(train_display_images_a, train_display_images_b) 140 | else: 141 | image_outputs = trainer.sample(train_display_images_a, train_display_images_b) 142 | write_2images(image_outputs, display_size, image_directory, 'train_%08d' % (iterations + 1)) 143 | del image_outputs 144 | # Save network weights 145 | if (iterations + 1) % config['snapshot_save_iter'] == 0: 146 | if num_gpu>1: 147 | trainer.module.save(checkpoint_directory, iterations) 148 | else: 149 | trainer.save(checkpoint_directory, iterations) 150 | 151 | iterations += 1 152 | if iterations >= max_iter: 153 | sys.exit('Finish training') 154 | 155 | # Save network weights by epoch number 156 | nepoch = nepoch+1 157 | if(nepoch + 1) % 10 == 0: 158 | if num_gpu>1: 159 | trainer.module.save(checkpoint_directory, iterations) 160 | else: 161 | trainer.save(checkpoint_directory, iterations) 162 | 163 | -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from networks import AdaINGen, MsImageDis 6 | from reIDmodel import ft_net, ft_netAB, PCB 7 | from utils import get_model_list, vgg_preprocess, load_vgg16, get_scheduler 8 | from torch.autograd import Variable 9 | import torch 10 | import torch.nn as nn 11 | import copy 12 | import os 13 | import cv2 14 | import numpy as np 15 | from random_erasing import RandomErasing 16 | import random 17 | import yaml 18 | 19 | #fp16 20 | try: 21 | from apex import amp 22 | from apex.fp16_utils import * 23 | except ImportError: 24 | print('This is not an error. If you want to use low precision, i.e., fp16, please install the apex with cuda support (https://github.com/NVIDIA/apex) and update pytorch to 1.0') 25 | 26 | 27 | def to_gray(half=False): #simple 28 | def forward(x): 29 | x = torch.mean(x, dim=1, keepdim=True) 30 | if half: 31 | x = x.half() 32 | return x 33 | return forward 34 | 35 | def to_edge(x): 36 | x = x.data.cpu() 37 | out = torch.FloatTensor(x.size(0), x.size(2), x.size(3)) 38 | for i in range(x.size(0)): 39 | xx = recover(x[i,:,:,:]) # 3 channel, 256x128x3 40 | xx = cv2.cvtColor(xx, cv2.COLOR_RGB2GRAY) # 256x128x1 41 | xx = cv2.Canny(xx, 10, 200) #256x128 42 | xx = xx/255.0 - 0.5 # {-0.5,0.5} 43 | xx += np.random.randn(xx.shape[0],xx.shape[1])*0.1 #add random noise 44 | xx = torch.from_numpy(xx.astype(np.float32)) 45 | out[i,:,:] = xx 46 | out = out.unsqueeze(1) 47 | return out.cuda() 48 | 49 | def scale2(x): 50 | if x.size(2) > 128: # do not need to scale the input 51 | return x 52 | x = torch.nn.functional.upsample(x, scale_factor=2, mode='nearest') #bicubic is not available for the time being. 53 | return x 54 | 55 | def recover(inp): 56 | inp = inp.numpy().transpose((1, 2, 0)) 57 | mean = np.array([0.485, 0.456, 0.406]) 58 | std = np.array([0.229, 0.224, 0.225]) 59 | inp = std * inp + mean 60 | inp = inp * 255.0 61 | inp = np.clip(inp, 0, 255) 62 | inp = inp.astype(np.uint8) 63 | return inp 64 | 65 | def train_bn(m): 66 | classname = m.__class__.__name__ 67 | if classname.find('BatchNorm') != -1: 68 | m.train() 69 | 70 | def fliplr(img): 71 | '''flip horizontal''' 72 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long().cuda() # N x C x H x W 73 | img_flip = img.index_select(3,inv_idx) 74 | return img_flip 75 | 76 | def update_teacher(model_s, model_t, alpha=0.999): 77 | for param_s, param_t in zip(model_s.parameters(), model_t.parameters()): 78 | param_t.data.mul_(alpha).add_(1 - alpha, param_s.data) 79 | 80 | def predict_label(teacher_models, inputs, num_class, alabel, slabel, teacher_style=0): 81 | # teacher_style: 82 | # 0: Our smooth dynamic label 83 | # 1: Pseudo label, hard dynamic label 84 | # 2: Conditional label, hard static label 85 | # 3: LSRO, static smooth label 86 | # 4: Dynamic Soft Two-label 87 | # alabel is appearance label 88 | if teacher_style == 0: 89 | count = 0 90 | sm = nn.Softmax(dim=1) 91 | for teacher_model in teacher_models: 92 | _, outputs_t1 = teacher_model(inputs) 93 | outputs_t1 = sm(outputs_t1.detach()) 94 | _, outputs_t2 = teacher_model(fliplr(inputs)) 95 | outputs_t2 = sm(outputs_t2.detach()) 96 | if count==0: 97 | outputs_t = outputs_t1 + outputs_t2 98 | else: 99 | outputs_t = outputs_t * opt.alpha # old model decay 100 | outputs_t += outputs_t1 + outputs_t2 101 | count +=2 102 | elif teacher_style == 1: # dynamic one-hot label 103 | count = 0 104 | sm = nn.Softmax(dim=1) 105 | for teacher_model in teacher_models: 106 | _, outputs_t1 = teacher_model(inputs) 107 | outputs_t1 = sm(outputs_t1.detach()) # change softmax to max 108 | _, outputs_t2 = teacher_model(fliplr(inputs)) 109 | outputs_t2 = sm(outputs_t2.detach()) 110 | if count==0: 111 | outputs_t = outputs_t1 + outputs_t2 112 | else: 113 | outputs_t = outputs_t * opt.alpha # old model decay 114 | outputs_t += outputs_t1 + outputs_t2 115 | count +=2 116 | _, dlabel = torch.max(outputs_t.data, 1) 117 | outputs_t = torch.zeros(inputs.size(0), num_class).cuda() 118 | for i in range(inputs.size(0)): 119 | outputs_t[i, dlabel[i]] = 1 120 | elif teacher_style == 2: # appearance label 121 | outputs_t = torch.zeros(inputs.size(0), num_class).cuda() 122 | for i in range(inputs.size(0)): 123 | outputs_t[i, alabel[i]] = 1 124 | elif teacher_style == 3: # LSRO 125 | outputs_t = torch.ones(inputs.size(0), num_class).cuda() 126 | elif teacher_style == 4: #Two-label 127 | count = 0 128 | sm = nn.Softmax(dim=1) 129 | for teacher_model in teacher_models: 130 | _, outputs_t1 = teacher_model(inputs) 131 | outputs_t1 = sm(outputs_t1.detach()) 132 | _, outputs_t2 = teacher_model(fliplr(inputs)) 133 | outputs_t2 = sm(outputs_t2.detach()) 134 | if count==0: 135 | outputs_t = outputs_t1 + outputs_t2 136 | else: 137 | outputs_t = outputs_t * opt.alpha # old model decay 138 | outputs_t += outputs_t1 + outputs_t2 139 | count +=2 140 | mask = torch.zeros(outputs_t.shape) 141 | mask = mask.cuda() 142 | for i in range(inputs.size(0)): 143 | mask[i, alabel[i]] = 1 144 | mask[i, slabel[i]] = 1 145 | outputs_t = outputs_t*mask 146 | else: 147 | print('not valid style. teacher-style is in [0-3].') 148 | 149 | s = torch.sum(outputs_t, dim=1, keepdim=True) 150 | s = s.expand_as(outputs_t) 151 | outputs_t = outputs_t/s 152 | return outputs_t 153 | 154 | ###################################################################### 155 | # Load model 156 | #--------------------------- 157 | def load_network(network, name): 158 | save_path = os.path.join('./models',name,'net_last.pth') 159 | network.load_state_dict(torch.load(save_path)) 160 | return network 161 | 162 | def load_config(name): 163 | config_path = os.path.join('./models',name,'opts.yaml') 164 | with open(config_path, 'r') as stream: 165 | config = yaml.safe_load(stream) 166 | return config 167 | 168 | 169 | class DGNet_Trainer(nn.Module): 170 | def __init__(self, hyperparameters, gpu_ids=[0]): 171 | super(DGNet_Trainer, self).__init__() 172 | lr_g = hyperparameters['lr_g'] 173 | lr_d = hyperparameters['lr_d'] 174 | ID_class = hyperparameters['ID_class'] 175 | if not 'apex' in hyperparameters.keys(): 176 | hyperparameters['apex'] = False 177 | self.fp16 = hyperparameters['apex'] 178 | # Initiate the networks 179 | # We do not need to manually set fp16 in the network for the new apex. So here I set fp16=False. 180 | self.gen_a = AdaINGen(hyperparameters['input_dim_a'], hyperparameters['gen'], fp16 = False) # auto-encoder for domain a 181 | self.gen_b = self.gen_a # auto-encoder for domain b 182 | 183 | if not 'ID_stride' in hyperparameters.keys(): 184 | hyperparameters['ID_stride'] = 2 185 | 186 | if hyperparameters['ID_style']=='PCB': 187 | self.id_a = PCB(ID_class) 188 | elif hyperparameters['ID_style']=='AB': 189 | self.id_a = ft_netAB(ID_class, stride = hyperparameters['ID_stride'], norm=hyperparameters['norm_id'], pool=hyperparameters['pool']) 190 | else: 191 | self.id_a = ft_net(ID_class, norm=hyperparameters['norm_id'], pool=hyperparameters['pool']) # return 2048 now 192 | 193 | self.id_b = self.id_a 194 | self.dis_a = MsImageDis(3, hyperparameters['dis'], fp16 = False) # discriminator for domain a 195 | self.dis_b = self.dis_a # discriminator for domain b 196 | 197 | # load teachers 198 | if hyperparameters['teacher'] != "": 199 | teacher_name = hyperparameters['teacher'] 200 | print(teacher_name) 201 | teacher_names = teacher_name.split(',') 202 | teacher_model = nn.ModuleList() 203 | teacher_count = 0 204 | for teacher_name in teacher_names: 205 | config_tmp = load_config(teacher_name) 206 | if 'stride' in config_tmp: 207 | stride = config_tmp['stride'] 208 | else: 209 | stride = 2 210 | model_tmp = ft_net(ID_class, stride = stride) 211 | teacher_model_tmp = load_network(model_tmp, teacher_name) 212 | teacher_model_tmp.model.fc = nn.Sequential() # remove the original fc layer in ImageNet 213 | teacher_model_tmp = teacher_model_tmp.cuda() 214 | if self.fp16: 215 | teacher_model_tmp = amp.initialize(teacher_model_tmp, opt_level="O1") 216 | teacher_model.append(teacher_model_tmp.cuda().eval()) 217 | teacher_count +=1 218 | self.teacher_model = teacher_model 219 | if hyperparameters['train_bn']: 220 | self.teacher_model = self.teacher_model.apply(train_bn) 221 | 222 | self.instancenorm = nn.InstanceNorm2d(512, affine=False) 223 | 224 | # RGB to one channel 225 | if hyperparameters['single']=='edge': 226 | self.single = to_edge 227 | else: 228 | self.single = to_gray(False) 229 | 230 | # Random Erasing when training 231 | if not 'erasing_p' in hyperparameters.keys(): 232 | self.erasing_p = 0 233 | else: 234 | self.erasing_p = hyperparameters['erasing_p'] 235 | self.single_re = RandomErasing(probability = self.erasing_p, mean=[0.0, 0.0, 0.0]) 236 | 237 | if not 'T_w' in hyperparameters.keys(): 238 | hyperparameters['T_w'] = 1 239 | # Setup the optimizers 240 | beta1 = hyperparameters['beta1'] 241 | beta2 = hyperparameters['beta2'] 242 | dis_params = list(self.dis_a.parameters()) #+ list(self.dis_b.parameters()) 243 | gen_params = list(self.gen_a.parameters()) #+ list(self.gen_b.parameters()) 244 | 245 | self.dis_opt = torch.optim.Adam([p for p in dis_params if p.requires_grad], 246 | lr=lr_d, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) 247 | self.gen_opt = torch.optim.Adam([p for p in gen_params if p.requires_grad], 248 | lr=lr_g, betas=(beta1, beta2), weight_decay=hyperparameters['weight_decay']) 249 | # id params 250 | if hyperparameters['ID_style']=='PCB': 251 | ignored_params = (list(map(id, self.id_a.classifier0.parameters() )) 252 | +list(map(id, self.id_a.classifier1.parameters() )) 253 | +list(map(id, self.id_a.classifier2.parameters() )) 254 | +list(map(id, self.id_a.classifier3.parameters() )) 255 | ) 256 | base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) 257 | lr2 = hyperparameters['lr2'] 258 | self.id_opt = torch.optim.SGD([ 259 | {'params': base_params, 'lr': lr2}, 260 | {'params': self.id_a.classifier0.parameters(), 'lr': lr2*10}, 261 | {'params': self.id_a.classifier1.parameters(), 'lr': lr2*10}, 262 | {'params': self.id_a.classifier2.parameters(), 'lr': lr2*10}, 263 | {'params': self.id_a.classifier3.parameters(), 'lr': lr2*10} 264 | ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) 265 | elif hyperparameters['ID_style']=='AB': 266 | ignored_params = (list(map(id, self.id_a.classifier1.parameters())) 267 | + list(map(id, self.id_a.classifier2.parameters()))) 268 | base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) 269 | lr2 = hyperparameters['lr2'] 270 | self.id_opt = torch.optim.SGD([ 271 | {'params': base_params, 'lr': lr2}, 272 | {'params': self.id_a.classifier1.parameters(), 'lr': lr2*10}, 273 | {'params': self.id_a.classifier2.parameters(), 'lr': lr2*10} 274 | ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) 275 | else: 276 | ignored_params = list(map(id, self.id_a.classifier.parameters() )) 277 | base_params = filter(lambda p: id(p) not in ignored_params, self.id_a.parameters()) 278 | lr2 = hyperparameters['lr2'] 279 | self.id_opt = torch.optim.SGD([ 280 | {'params': base_params, 'lr': lr2}, 281 | {'params': self.id_a.classifier.parameters(), 'lr': lr2*10} 282 | ], weight_decay=hyperparameters['weight_decay'], momentum=0.9, nesterov=True) 283 | 284 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters) 285 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters) 286 | self.id_scheduler = get_scheduler(self.id_opt, hyperparameters) 287 | self.id_scheduler.gamma = hyperparameters['gamma2'] 288 | 289 | #ID Loss 290 | self.id_criterion = nn.CrossEntropyLoss() 291 | self.criterion_teacher = nn.KLDivLoss(size_average=False) 292 | # Load VGG model if needed 293 | if 'vgg_w' in hyperparameters.keys() and hyperparameters['vgg_w'] > 0: 294 | self.vgg = load_vgg16(hyperparameters['vgg_model_path'] + '/models') 295 | self.vgg.eval() 296 | for param in self.vgg.parameters(): 297 | param.requires_grad = False 298 | 299 | # save memory 300 | if self.fp16: 301 | # Name the FP16_Optimizer instance to replace the existing optimizer 302 | assert torch.backends.cudnn.enabled, "fp16 mode requires cudnn backend to be enabled." 303 | self.gen_a = self.gen_a.cuda() 304 | self.dis_a = self.dis_a.cuda() 305 | self.id_a = self.id_a.cuda() 306 | 307 | self.gen_b = self.gen_a 308 | self.dis_b = self.dis_a 309 | self.id_b = self.id_a 310 | 311 | self.gen_a, self.gen_opt = amp.initialize(self.gen_a, self.gen_opt, opt_level="O1") 312 | self.dis_a, self.dis_opt = amp.initialize(self.dis_a, self.dis_opt, opt_level="O1") 313 | self.id_a, self.id_opt = amp.initialize(self.id_a, self.id_opt, opt_level="O1") 314 | 315 | def to_re(self, x): 316 | out = torch.FloatTensor(x.size(0), x.size(1), x.size(2), x.size(3)) 317 | out = out.cuda() 318 | for i in range(x.size(0)): 319 | out[i,:,:,:] = self.single_re(x[i,:,:,:]) 320 | return out 321 | 322 | def recon_criterion(self, input, target): 323 | diff = input - target.detach() 324 | return torch.mean(torch.abs(diff[:])) 325 | 326 | def recon_criterion_sqrt(self, input, target): 327 | diff = input - target 328 | return torch.mean(torch.sqrt(torch.abs(diff[:])+1e-8)) 329 | 330 | def recon_criterion2(self, input, target): 331 | diff = input - target 332 | return torch.mean(diff[:]**2) 333 | 334 | def recon_cos(self, input, target): 335 | cos = torch.nn.CosineSimilarity() 336 | cos_dis = 1 - cos(input, target) 337 | return torch.mean(cos_dis[:]) 338 | 339 | def forward(self, x_a, x_b, xp_a, xp_b): 340 | s_a = self.gen_a.encode(self.single(x_a)) 341 | s_b = self.gen_b.encode(self.single(x_b)) 342 | f_a, p_a = self.id_a(scale2(x_a)) 343 | f_b, p_b = self.id_b(scale2(x_b)) 344 | x_ba = self.gen_a.decode(s_b, f_a) 345 | x_ab = self.gen_b.decode(s_a, f_b) 346 | x_a_recon = self.gen_a.decode(s_a, f_a) 347 | x_b_recon = self.gen_b.decode(s_b, f_b) 348 | fp_a, pp_a = self.id_a(scale2(xp_a)) 349 | fp_b, pp_b = self.id_b(scale2(xp_b)) 350 | # decode the same person 351 | x_a_recon_p = self.gen_a.decode(s_a, fp_a) 352 | x_b_recon_p = self.gen_b.decode(s_b, fp_b) 353 | 354 | # Random Erasing only effect the ID and PID loss. 355 | if self.erasing_p > 0: 356 | x_a_re = self.to_re(scale2(x_a.clone())) 357 | x_b_re = self.to_re(scale2(x_b.clone())) 358 | xp_a_re = self.to_re(scale2(xp_a.clone())) 359 | xp_b_re = self.to_re(scale2(xp_b.clone())) 360 | _, p_a = self.id_a(x_a_re) 361 | _, p_b = self.id_b(x_b_re) 362 | # encode the same ID different photo 363 | _, pp_a = self.id_a(xp_a_re) 364 | _, pp_b = self.id_b(xp_b_re) 365 | 366 | return x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p 367 | 368 | def gen_update(self, x_ab, x_ba, s_a, s_b, f_a, f_b, p_a, p_b, pp_a, pp_b, x_a_recon, x_b_recon, x_a_recon_p, x_b_recon_p, x_a, x_b, xp_a, xp_b, l_a, l_b, hyperparameters, iteration, num_gpu): 369 | # ppa, ppb is the same person 370 | self.gen_opt.zero_grad() 371 | self.id_opt.zero_grad() 372 | 373 | # no gradient 374 | x_ba_copy = Variable(x_ba.data, requires_grad=False) 375 | x_ab_copy = Variable(x_ab.data, requires_grad=False) 376 | 377 | rand_num = random.uniform(0,1) 378 | ################################# 379 | # encode structure 380 | if hyperparameters['use_encoder_again']>=rand_num: 381 | # encode again (encoder is tuned, input is fixed) 382 | s_a_recon = self.gen_b.enc_content(self.single(x_ab_copy)) 383 | s_b_recon = self.gen_a.enc_content(self.single(x_ba_copy)) 384 | else: 385 | # copy the encoder 386 | self.enc_content_copy = copy.deepcopy(self.gen_a.enc_content) 387 | self.enc_content_copy = self.enc_content_copy.eval() 388 | # encode again (encoder is fixed, input is tuned) 389 | s_a_recon = self.enc_content_copy(self.single(x_ab)) 390 | s_b_recon = self.enc_content_copy(self.single(x_ba)) 391 | 392 | ################################# 393 | # encode appearance 394 | self.id_a_copy = copy.deepcopy(self.id_a) 395 | self.id_a_copy = self.id_a_copy.eval() 396 | if hyperparameters['train_bn']: 397 | self.id_a_copy = self.id_a_copy.apply(train_bn) 398 | self.id_b_copy = self.id_a_copy 399 | # encode again (encoder is fixed, input is tuned) 400 | f_a_recon, p_a_recon = self.id_a_copy(scale2(x_ba)) 401 | f_b_recon, p_b_recon = self.id_b_copy(scale2(x_ab)) 402 | 403 | # teacher Loss 404 | # Tune the ID model 405 | log_sm = nn.LogSoftmax(dim=1) 406 | if hyperparameters['teacher_w'] >0 and hyperparameters['teacher'] != "": 407 | if hyperparameters['ID_style'] == 'normal': 408 | _, p_a_student = self.id_a(scale2(x_ba_copy)) 409 | p_a_student = log_sm(p_a_student) 410 | p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy), num_class = hyperparameters['ID_class'], alabel = l_a, slabel = l_b, teacher_style = hyperparameters['teacher_style']) 411 | self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0) 412 | 413 | _, p_b_student = self.id_b(scale2(x_ab_copy)) 414 | p_b_student = log_sm(p_b_student) 415 | p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy), num_class = hyperparameters['ID_class'], alabel = l_b, slabel = l_a, teacher_style = hyperparameters['teacher_style']) 416 | self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0) 417 | elif hyperparameters['ID_style'] == 'AB': 418 | # normal teacher-student loss 419 | # BA -> LabelA(smooth) + LabelB(batchB) 420 | _, p_ba_student = self.id_a(scale2(x_ba_copy))# f_a, s_b 421 | p_a_student = log_sm(p_ba_student[0]) 422 | with torch.no_grad(): 423 | p_a_teacher = predict_label(self.teacher_model, scale2(x_ba_copy), num_class = hyperparameters['ID_class'], alabel = l_a, slabel = l_b, teacher_style = hyperparameters['teacher_style']) 424 | self.loss_teacher = self.criterion_teacher(p_a_student, p_a_teacher) / p_a_student.size(0) 425 | 426 | _, p_ab_student = self.id_b(scale2(x_ab_copy)) # f_b, s_a 427 | p_b_student = log_sm(p_ab_student[0]) 428 | with torch.no_grad(): 429 | p_b_teacher = predict_label(self.teacher_model, scale2(x_ab_copy), num_class = hyperparameters['ID_class'], alabel = l_b, slabel = l_a, teacher_style = hyperparameters['teacher_style']) 430 | self.loss_teacher += self.criterion_teacher(p_b_student, p_b_teacher) / p_b_student.size(0) 431 | 432 | # branch b loss 433 | # here we give different label 434 | loss_B = self.id_criterion(p_ba_student[1], l_b) + self.id_criterion(p_ab_student[1], l_a) 435 | self.loss_teacher = hyperparameters['T_w'] * self.loss_teacher + hyperparameters['B_w'] * loss_B 436 | else: 437 | self.loss_teacher = 0.0 438 | 439 | # auto-encoder image reconstruction 440 | self.loss_gen_recon_x_a = self.recon_criterion(x_a_recon, x_a) 441 | self.loss_gen_recon_x_b = self.recon_criterion(x_b_recon, x_b) 442 | self.loss_gen_recon_xp_a = self.recon_criterion(x_a_recon_p, x_a) 443 | self.loss_gen_recon_xp_b = self.recon_criterion(x_b_recon_p, x_b) 444 | 445 | # feature reconstruction 446 | self.loss_gen_recon_s_a = self.recon_criterion(s_a_recon, s_a) if hyperparameters['recon_s_w'] > 0 else 0 447 | self.loss_gen_recon_s_b = self.recon_criterion(s_b_recon, s_b) if hyperparameters['recon_s_w'] > 0 else 0 448 | self.loss_gen_recon_f_a = self.recon_criterion(f_a_recon, f_a) if hyperparameters['recon_f_w'] > 0 else 0 449 | self.loss_gen_recon_f_b = self.recon_criterion(f_b_recon, f_b) if hyperparameters['recon_f_w'] > 0 else 0 450 | 451 | x_aba = self.gen_a.decode(s_a_recon, f_a_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None 452 | x_bab = self.gen_b.decode(s_b_recon, f_b_recon) if hyperparameters['recon_x_cyc_w'] > 0 else None 453 | 454 | # ID loss AND Tune the Generated image 455 | if hyperparameters['ID_style']=='PCB': 456 | self.loss_id = self.PCB_loss(p_a, l_a) + self.PCB_loss(p_b, l_b) 457 | self.loss_pid = self.PCB_loss(pp_a, l_a) + self.PCB_loss(pp_b, l_b) 458 | self.loss_gen_recon_id = self.PCB_loss(p_a_recon, l_a) + self.PCB_loss(p_b_recon, l_b) 459 | elif hyperparameters['ID_style']=='AB': 460 | weight_B = hyperparameters['teacher_w'] * hyperparameters['B_w'] 461 | self.loss_id = self.id_criterion(p_a[0], l_a) + self.id_criterion(p_b[0], l_b) \ 462 | + weight_B * ( self.id_criterion(p_a[1], l_a) + self.id_criterion(p_b[1], l_b) ) 463 | self.loss_pid = self.id_criterion(pp_a[0], l_a) + self.id_criterion(pp_b[0], l_b) #+ weight_B * ( self.id_criterion(pp_a[1], l_a) + self.id_criterion(pp_b[1], l_b) ) 464 | self.loss_gen_recon_id = self.id_criterion(p_a_recon[0], l_a) + self.id_criterion(p_b_recon[0], l_b) 465 | else: 466 | self.loss_id = self.id_criterion(p_a, l_a) + self.id_criterion(p_b, l_b) 467 | self.loss_pid = self.id_criterion(pp_a, l_a) + self.id_criterion(pp_b, l_b) 468 | self.loss_gen_recon_id = self.id_criterion(p_a_recon, l_a) + self.id_criterion(p_b_recon, l_b) 469 | 470 | #print(f_a_recon, f_a) 471 | self.loss_gen_cycrecon_x_a = self.recon_criterion(x_aba, x_a) if hyperparameters['recon_x_cyc_w'] > 0 else 0 472 | self.loss_gen_cycrecon_x_b = self.recon_criterion(x_bab, x_b) if hyperparameters['recon_x_cyc_w'] > 0 else 0 473 | # GAN loss 474 | if num_gpu>1: 475 | self.loss_gen_adv_a = self.dis_a.module.calc_gen_loss(self.dis_a, x_ba) 476 | self.loss_gen_adv_b = self.dis_b.module.calc_gen_loss(self.dis_b, x_ab) 477 | else: 478 | self.loss_gen_adv_a = self.dis_a.calc_gen_loss(self.dis_a, x_ba) 479 | self.loss_gen_adv_b = self.dis_b.calc_gen_loss(self.dis_b, x_ab) 480 | # domain-invariant perceptual loss 481 | self.loss_gen_vgg_a = self.compute_vgg_loss(self.vgg, x_ba, x_b) if hyperparameters['vgg_w'] > 0 else 0 482 | self.loss_gen_vgg_b = self.compute_vgg_loss(self.vgg, x_ab, x_a) if hyperparameters['vgg_w'] > 0 else 0 483 | 484 | if iteration > hyperparameters['warm_iter']: 485 | hyperparameters['recon_f_w'] += hyperparameters['warm_scale'] 486 | hyperparameters['recon_f_w'] = min(hyperparameters['recon_f_w'], hyperparameters['max_w']) 487 | hyperparameters['recon_s_w'] += hyperparameters['warm_scale'] 488 | hyperparameters['recon_s_w'] = min(hyperparameters['recon_s_w'], hyperparameters['max_w']) 489 | hyperparameters['recon_x_cyc_w'] += hyperparameters['warm_scale'] 490 | hyperparameters['recon_x_cyc_w'] = min(hyperparameters['recon_x_cyc_w'], hyperparameters['max_cyc_w']) 491 | 492 | if iteration > hyperparameters['warm_teacher_iter']: 493 | hyperparameters['teacher_w'] += hyperparameters['warm_scale'] 494 | hyperparameters['teacher_w'] = min(hyperparameters['teacher_w'], hyperparameters['max_teacher_w']) 495 | # total loss 496 | self.loss_gen_total = hyperparameters['gan_w'] * self.loss_gen_adv_a + \ 497 | hyperparameters['gan_w'] * self.loss_gen_adv_b + \ 498 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_a + \ 499 | hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_a + \ 500 | hyperparameters['recon_f_w'] * self.loss_gen_recon_f_a + \ 501 | hyperparameters['recon_s_w'] * self.loss_gen_recon_s_a + \ 502 | hyperparameters['recon_x_w'] * self.loss_gen_recon_x_b + \ 503 | hyperparameters['recon_xp_w'] * self.loss_gen_recon_xp_b + \ 504 | hyperparameters['recon_f_w'] * self.loss_gen_recon_f_b + \ 505 | hyperparameters['recon_s_w'] * self.loss_gen_recon_s_b + \ 506 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_a + \ 507 | hyperparameters['recon_x_cyc_w'] * self.loss_gen_cycrecon_x_b + \ 508 | hyperparameters['id_w'] * self.loss_id + \ 509 | hyperparameters['pid_w'] * self.loss_pid + \ 510 | hyperparameters['recon_id_w'] * self.loss_gen_recon_id + \ 511 | hyperparameters['vgg_w'] * self.loss_gen_vgg_a + \ 512 | hyperparameters['vgg_w'] * self.loss_gen_vgg_b + \ 513 | hyperparameters['teacher_w'] * self.loss_teacher 514 | if self.fp16: 515 | with amp.scale_loss(self.loss_gen_total, [self.gen_opt,self.id_opt]) as scaled_loss: 516 | scaled_loss.backward() 517 | self.gen_opt.step() 518 | self.id_opt.step() 519 | else: 520 | self.loss_gen_total.backward() 521 | self.gen_opt.step() 522 | self.id_opt.step() 523 | print("L_total: %.4f, L_gan: %.4f, Lx: %.4f, Lxp: %.4f, Lrecycle:%.4f, Lf: %.4f, Ls: %.4f, Recon-id: %.4f, id: %.4f, pid:%.4f, teacher: %.4f"%( self.loss_gen_total, \ 524 | hyperparameters['gan_w'] * (self.loss_gen_adv_a + self.loss_gen_adv_b), \ 525 | hyperparameters['recon_x_w'] * (self.loss_gen_recon_x_a + self.loss_gen_recon_x_b), \ 526 | hyperparameters['recon_xp_w'] * (self.loss_gen_recon_xp_a + self.loss_gen_recon_xp_b), \ 527 | hyperparameters['recon_x_cyc_w'] * (self.loss_gen_cycrecon_x_a + self.loss_gen_cycrecon_x_b), \ 528 | hyperparameters['recon_f_w'] * (self.loss_gen_recon_f_a + self.loss_gen_recon_f_b), \ 529 | hyperparameters['recon_s_w'] * (self.loss_gen_recon_s_a + self.loss_gen_recon_s_b), \ 530 | hyperparameters['recon_id_w'] * self.loss_gen_recon_id, \ 531 | hyperparameters['id_w'] * self.loss_id,\ 532 | hyperparameters['pid_w'] * self.loss_pid,\ 533 | hyperparameters['teacher_w'] * self.loss_teacher ) ) 534 | 535 | def compute_vgg_loss(self, vgg, img, target): 536 | img_vgg = vgg_preprocess(img) 537 | target_vgg = vgg_preprocess(target) 538 | img_fea = vgg(img_vgg) 539 | target_fea = vgg(target_vgg) 540 | return torch.mean((self.instancenorm(img_fea) - self.instancenorm(target_fea)) ** 2) 541 | 542 | def PCB_loss(self, inputs, labels): 543 | loss = 0.0 544 | for part in inputs: 545 | loss += self.id_criterion(part, labels) 546 | return loss/len(inputs) 547 | 548 | def sample(self, x_a, x_b): 549 | self.eval() 550 | x_a_recon, x_b_recon, x_ba1, x_ab1, x_aba, x_bab = [], [], [], [], [], [] 551 | for i in range(x_a.size(0)): 552 | s_a = self.gen_a.encode( self.single(x_a[i].unsqueeze(0)) ) 553 | s_b = self.gen_b.encode( self.single(x_b[i].unsqueeze(0)) ) 554 | f_a, _ = self.id_a( scale2(x_a[i].unsqueeze(0))) 555 | f_b, _ = self.id_b( scale2(x_b[i].unsqueeze(0))) 556 | x_a_recon.append(self.gen_a.decode(s_a, f_a)) 557 | x_b_recon.append(self.gen_b.decode(s_b, f_b)) 558 | x_ba = self.gen_a.decode(s_b, f_a) 559 | x_ab = self.gen_b.decode(s_a, f_b) 560 | x_ba1.append(x_ba) 561 | x_ab1.append(x_ab) 562 | #cycle 563 | s_b_recon = self.gen_a.enc_content(self.single(x_ba)) 564 | s_a_recon = self.gen_b.enc_content(self.single(x_ab)) 565 | f_a_recon, _ = self.id_a(scale2(x_ba)) 566 | f_b_recon, _ = self.id_b(scale2(x_ab)) 567 | x_aba.append(self.gen_a.decode(s_a_recon, f_a_recon)) 568 | x_bab.append(self.gen_b.decode(s_b_recon, f_b_recon)) 569 | 570 | x_a_recon, x_b_recon = torch.cat(x_a_recon), torch.cat(x_b_recon) 571 | x_aba, x_bab = torch.cat(x_aba), torch.cat(x_bab) 572 | x_ba1, x_ab1 = torch.cat(x_ba1), torch.cat(x_ab1) 573 | self.train() 574 | 575 | return x_a, x_a_recon, x_aba, x_ab1, x_b, x_b_recon, x_bab, x_ba1 576 | 577 | def dis_update(self, x_ab, x_ba, x_a, x_b, hyperparameters, num_gpu): 578 | self.dis_opt.zero_grad() 579 | # D loss 580 | if num_gpu>1: 581 | self.loss_dis_a, reg_a = self.dis_a.module.calc_dis_loss(self.dis_a, x_ba.detach(), x_a) 582 | self.loss_dis_b, reg_b = self.dis_b.module.calc_dis_loss(self.dis_b, x_ab.detach(), x_b) 583 | else: 584 | self.loss_dis_a, reg_a = self.dis_a.calc_dis_loss(self.dis_a, x_ba.detach(), x_a) 585 | self.loss_dis_b, reg_b = self.dis_b.calc_dis_loss(self.dis_b, x_ab.detach(), x_b) 586 | self.loss_dis_total = hyperparameters['gan_w'] * self.loss_dis_a + hyperparameters['gan_w'] * self.loss_dis_b 587 | print("DLoss: %.4f"%self.loss_dis_total, "Reg: %.4f"%(reg_a+reg_b) ) 588 | if self.fp16: 589 | with amp.scale_loss(self.loss_dis_total, self.dis_opt) as scaled_loss: 590 | scaled_loss.backward() 591 | else: 592 | self.loss_dis_total.backward() 593 | self.dis_opt.step() 594 | 595 | def update_learning_rate(self): 596 | if self.dis_scheduler is not None: 597 | self.dis_scheduler.step() 598 | if self.gen_scheduler is not None: 599 | self.gen_scheduler.step() 600 | if self.id_scheduler is not None: 601 | self.id_scheduler.step() 602 | 603 | def resume(self, checkpoint_dir, hyperparameters): 604 | # Load generators 605 | last_model_name = get_model_list(checkpoint_dir, "gen") 606 | state_dict = torch.load(last_model_name) 607 | self.gen_a.load_state_dict(state_dict['a']) 608 | self.gen_b = self.gen_a 609 | iterations = int(last_model_name[-11:-3]) 610 | # Load discriminators 611 | last_model_name = get_model_list(checkpoint_dir, "dis") 612 | state_dict = torch.load(last_model_name) 613 | self.dis_a.load_state_dict(state_dict['a']) 614 | self.dis_b = self.dis_a 615 | # Load ID dis 616 | last_model_name = get_model_list(checkpoint_dir, "id") 617 | state_dict = torch.load(last_model_name) 618 | self.id_a.load_state_dict(state_dict['a']) 619 | self.id_b = self.id_a 620 | # Load optimizers 621 | try: 622 | state_dict = torch.load(os.path.join(checkpoint_dir, 'optimizer.pt')) 623 | self.dis_opt.load_state_dict(state_dict['dis']) 624 | self.gen_opt.load_state_dict(state_dict['gen']) 625 | self.id_opt.load_state_dict(state_dict['id']) 626 | except: 627 | pass 628 | # Reinitilize schedulers 629 | self.dis_scheduler = get_scheduler(self.dis_opt, hyperparameters, iterations) 630 | self.gen_scheduler = get_scheduler(self.gen_opt, hyperparameters, iterations) 631 | print('Resume from iteration %d' % iterations) 632 | return iterations 633 | 634 | def save(self, snapshot_dir, iterations, num_gpu=1): 635 | # Save generators, discriminators, and optimizers 636 | gen_name = os.path.join(snapshot_dir, 'gen_%08d.pt' % (iterations + 1)) 637 | dis_name = os.path.join(snapshot_dir, 'dis_%08d.pt' % (iterations + 1)) 638 | id_name = os.path.join(snapshot_dir, 'id_%08d.pt' % (iterations + 1)) 639 | opt_name = os.path.join(snapshot_dir, 'optimizer.pt') 640 | torch.save({'a': self.gen_a.state_dict()}, gen_name) 641 | if num_gpu>1: 642 | torch.save({'a': self.dis_a.module.state_dict()}, dis_name) 643 | else: 644 | torch.save({'a': self.dis_a.state_dict()}, dis_name) 645 | torch.save({'a': self.id_a.state_dict()}, id_name) 646 | torch.save({'gen': self.gen_opt.state_dict(), 'id': self.id_opt.state_dict(), 'dis': self.dis_opt.state_dict()}, opt_name) 647 | 648 | 649 | 650 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | 6 | from torch.utils.data import DataLoader 7 | from torch.autograd import Variable 8 | from torch.optim import lr_scheduler 9 | from torchvision import transforms 10 | from data import ImageFilelist 11 | from reIDfolder import ReIDFolder 12 | import torch 13 | import os 14 | import math 15 | import torchvision.utils as vutils 16 | import yaml 17 | import numpy as np 18 | import torch.nn.init as init 19 | import time 20 | # Methods 21 | # get_all_data_loaders : primary data loader interface (load trainA, testA, trainB, testB) 22 | # get_data_loader_list : list-based data loader 23 | # get_data_loader_folder : folder-based data loader 24 | # get_config : load yaml file 25 | # eformat : 26 | # write_2images : save output image 27 | # prepare_sub_folder : create checkpoints and images folders for saving outputs 28 | # write_one_row_html : write one row of the html file for output images 29 | # write_html : create the html file. 30 | # write_loss 31 | # slerp 32 | # get_slerp_interp 33 | # get_model_list 34 | # load_vgg16 35 | # vgg_preprocess 36 | # get_scheduler 37 | # weights_init 38 | 39 | def get_all_data_loaders(conf): 40 | batch_size = conf['batch_size'] 41 | num_workers = conf['num_workers'] 42 | if 'new_size' in conf: 43 | new_size_a= conf['new_size'] 44 | new_size_b = conf['new_size'] 45 | else: 46 | new_size_a = conf['new_size_a'] 47 | new_size_b = conf['new_size_b'] 48 | height = conf['crop_image_height'] 49 | width = conf['crop_image_width'] 50 | 51 | if 'data_root' in conf: 52 | train_loader_a = get_data_loader_folder(os.path.join(conf['data_root'], 'train_all'), batch_size, True, 53 | new_size_a, height, width, num_workers, True) 54 | test_loader_a = get_data_loader_folder(os.path.join(conf['data_root'], 'query'), batch_size, False, 55 | new_size_a, height, width, num_workers, False) 56 | train_loader_b = get_data_loader_folder(os.path.join(conf['data_root'], 'train_all'), batch_size, True, 57 | new_size_b, height, width, num_workers, True) 58 | test_loader_b = get_data_loader_folder(os.path.join(conf['data_root'], 'query'), batch_size, False, 59 | new_size_b, height, width, num_workers, False) 60 | else: 61 | train_loader_a = get_data_loader_list(conf['data_folder_train_a'], conf['data_list_train_a'], batch_size, True, 62 | new_size_a, height, width, num_workers, True) 63 | test_loader_a = get_data_loader_list(conf['data_folder_test_a'], conf['data_list_test_a'], batch_size, False, 64 | new_size_a, height, width, num_workers, False) 65 | train_loader_b = get_data_loader_list(conf['data_folder_train_b'], conf['data_list_train_b'], batch_size, True, 66 | new_size_b, height, width, num_workers, True) 67 | test_loader_b = get_data_loader_list(conf['data_folder_test_b'], conf['data_list_test_b'], batch_size, False, 68 | new_size_b, height, width, num_workers, False) 69 | return train_loader_a, train_loader_b, test_loader_a, test_loader_b 70 | 71 | 72 | def get_data_loader_list(root, file_list, batch_size, train, new_size=None, 73 | height=256, width=128, num_workers=4, crop=True): 74 | transform_list = [transforms.ToTensor(), 75 | transforms.Normalize((0.485, 0.456, 0.406), 76 | (0.229, 0.224, 0.225))] 77 | transform_list = [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list 78 | transform_list = [transforms.Pad(10, padding_mode='edge')] + transform_list if train else transform_list 79 | transform_list = [transforms.Resize((height, width), interpolation=3)] + transform_list if new_size is not None else transform_list 80 | transform_list = [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list 81 | transform = transforms.Compose(transform_list) 82 | dataset = ImageFilelist(root, file_list, transform=transform) 83 | loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers) 84 | return loader 85 | 86 | def get_data_loader_folder(input_folder, batch_size, train, new_size=None, 87 | height=256, width=128, num_workers=4, crop=True): 88 | transform_list = [transforms.ToTensor(), 89 | transforms.Normalize((0.485, 0.456, 0.406), 90 | (0.229, 0.224, 0.225))] 91 | transform_list = [transforms.RandomCrop((height, width))] + transform_list if crop else transform_list 92 | transform_list = [transforms.Pad(10, padding_mode='edge')] + transform_list if train else transform_list 93 | transform_list = [transforms.Resize((height,width), interpolation=3)] + transform_list if new_size is not None else transform_list 94 | transform_list = [transforms.RandomHorizontalFlip()] + transform_list if train else transform_list 95 | transform = transforms.Compose(transform_list) 96 | dataset = ReIDFolder(input_folder, transform=transform) 97 | loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True, drop_last=True, num_workers=num_workers) 98 | return loader 99 | 100 | 101 | def get_config(config): 102 | with open(config, 'r') as stream: 103 | return yaml.safe_load(stream) 104 | 105 | 106 | def eformat(f, prec): 107 | s = "%.*e"%(prec, f) 108 | mantissa, exp = s.split('e') 109 | # add 1 to digits as 1 is taken by sign +/- 110 | return "%se%d"%(mantissa, int(exp)) 111 | 112 | 113 | def __write_images(image_outputs, display_image_num, file_name): 114 | image_outputs = [images.expand(-1, 3, -1, -1) for images in image_outputs] # expand gray-scale images to 3 channels 115 | image_tensor = torch.cat([images[:display_image_num] for images in image_outputs], 0) 116 | image_grid = vutils.make_grid(image_tensor.data, nrow=display_image_num, padding=0, normalize=True, scale_each=True) 117 | vutils.save_image(image_grid, file_name, nrow=1) 118 | 119 | 120 | def write_2images(image_outputs, display_image_num, image_directory, postfix): 121 | n = len(image_outputs) 122 | __write_images(image_outputs[0:n//2], display_image_num, '%s/gen_a2b_%s.jpg' % (image_directory, postfix)) 123 | __write_images(image_outputs[n//2:n], display_image_num, '%s/gen_b2a_%s.jpg' % (image_directory, postfix)) 124 | 125 | 126 | def prepare_sub_folder(output_directory): 127 | image_directory = os.path.join(output_directory, 'images') 128 | if not os.path.exists(image_directory): 129 | print("Creating directory: {}".format(image_directory)) 130 | os.makedirs(image_directory) 131 | checkpoint_directory = os.path.join(output_directory, 'checkpoints') 132 | if not os.path.exists(checkpoint_directory): 133 | print("Creating directory: {}".format(checkpoint_directory)) 134 | os.makedirs(checkpoint_directory) 135 | return checkpoint_directory, image_directory 136 | 137 | 138 | def write_one_row_html(html_file, iterations, img_filename, all_size): 139 | html_file.write("

iteration [%d] (%s)

" % (iterations,img_filename.split('/')[-1])) 140 | html_file.write(""" 141 |

142 | 143 |
144 |

145 | """ % (img_filename, img_filename, all_size)) 146 | return 147 | 148 | 149 | def write_html(filename, iterations, image_save_iterations, image_directory, all_size=1536): 150 | html_file = open(filename, "w") 151 | html_file.write(''' 152 | 153 | 154 | 155 | Experiment name = %s 156 | 157 | 158 | 159 | ''' % os.path.basename(filename)) 160 | html_file.write("

current

") 161 | write_one_row_html(html_file, iterations, '%s/gen_a2b_train_current.jpg' % (image_directory), all_size) 162 | write_one_row_html(html_file, iterations, '%s/gen_b2a_train_current.jpg' % (image_directory), all_size) 163 | for j in range(iterations, image_save_iterations-1, -1): 164 | if j % image_save_iterations == 0: 165 | write_one_row_html(html_file, j, '%s/gen_a2b_test_%08d.jpg' % (image_directory, j), all_size) 166 | write_one_row_html(html_file, j, '%s/gen_b2a_test_%08d.jpg' % (image_directory, j), all_size) 167 | write_one_row_html(html_file, j, '%s/gen_a2b_train_%08d.jpg' % (image_directory, j), all_size) 168 | write_one_row_html(html_file, j, '%s/gen_b2a_train_%08d.jpg' % (image_directory, j), all_size) 169 | html_file.write("") 170 | html_file.close() 171 | 172 | 173 | def write_loss(iterations, trainer, train_writer): 174 | members = [attr for attr in dir(trainer) \ 175 | if not callable(getattr(trainer, attr)) and not attr.startswith("__") and ('loss' in attr or 'grad' in attr or 'nwd' in attr)] 176 | for m in members: 177 | train_writer.add_scalar(m, getattr(trainer, m), iterations + 1) 178 | 179 | 180 | def slerp(val, low, high): 181 | """ 182 | original: Animating Rotation with Quaternion Curves, Ken Shoemake 183 | https://arxiv.org/abs/1609.04468 184 | Code: https://github.com/soumith/dcgan.torch/issues/14, Tom White 185 | """ 186 | omega = np.arccos(np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high))) 187 | so = np.sin(omega) 188 | return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high 189 | 190 | 191 | def get_slerp_interp(nb_latents, nb_interp, z_dim): 192 | """ 193 | modified from: PyTorch inference for "Progressive Growing of GANs" with CelebA snapshot 194 | https://github.com/ptrblck/prog_gans_pytorch_inference 195 | """ 196 | 197 | latent_interps = np.empty(shape=(0, z_dim), dtype=np.float32) 198 | for _ in range(nb_latents): 199 | low = np.random.randn(z_dim) 200 | high = np.random.randn(z_dim) # low + np.random.randn(512) * 0.7 201 | interp_vals = np.linspace(0, 1, num=nb_interp) 202 | latent_interp = np.array([slerp(v, low, high) for v in interp_vals], 203 | dtype=np.float32) 204 | latent_interps = np.vstack((latent_interps, latent_interp)) 205 | 206 | return latent_interps[:, :, np.newaxis, np.newaxis] 207 | 208 | 209 | # Get model list for resume 210 | def get_model_list(dirname, key): 211 | if os.path.exists(dirname) is False: 212 | return None 213 | gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname) if 214 | os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f] 215 | if gen_models is None: 216 | return None 217 | gen_models.sort() 218 | last_model_name = gen_models[-1] 219 | return last_model_name 220 | 221 | 222 | def load_vgg16(model_dir): 223 | """ Use the model from https://github.com/abhiskk/fast-neural-style/blob/master/neural_style/utils.py """ 224 | if not os.path.exists(model_dir): 225 | os.mkdir(model_dir) 226 | if not os.path.exists(os.path.join(model_dir, 'vgg16.weight')): 227 | if not os.path.exists(os.path.join(model_dir, 'vgg16.t7')): 228 | os.system('wget https://www.dropbox.com/s/76l3rt4kyi3s8x7/vgg16.t7?dl=1 -O ' + os.path.join(model_dir, 'vgg16.t7')) 229 | vgglua = load_lua(os.path.join(model_dir, 'vgg16.t7')) 230 | vgg = Vgg16() 231 | for (src, dst) in zip(vgglua.parameters()[0], vgg.parameters()): 232 | dst.data[:] = src 233 | torch.save(vgg.state_dict(), os.path.join(model_dir, 'vgg16.weight')) 234 | vgg = Vgg16() 235 | vgg.load_state_dict(torch.load(os.path.join(model_dir, 'vgg16.weight'))) 236 | return vgg 237 | 238 | 239 | def vgg_preprocess(batch): 240 | tensortype = type(batch.data) 241 | (r, g, b) = torch.chunk(batch, 3, dim = 1) 242 | batch = torch.cat((b, g, r), dim = 1) # convert RGB to BGR 243 | batch = (batch + 1) * 255 * 0.5 # [-1, 1] -> [0, 255] 244 | mean = tensortype(batch.data.size()) 245 | mean[:, 0, :, :] = 103.939 246 | mean[:, 1, :, :] = 116.779 247 | mean[:, 2, :, :] = 123.680 248 | batch = batch.sub(Variable(mean)) # subtract mean 249 | return batch 250 | 251 | 252 | def get_scheduler(optimizer, hyperparameters, iterations=-1): 253 | if 'lr_policy' not in hyperparameters or hyperparameters['lr_policy'] == 'constant': 254 | scheduler = None # constant scheduler 255 | elif hyperparameters['lr_policy'] == 'step': 256 | scheduler = lr_scheduler.StepLR(optimizer, step_size=hyperparameters['step_size'], 257 | gamma=hyperparameters['gamma'], last_epoch=iterations) 258 | elif hyperparameters['lr_policy'] == 'multistep': 259 | #50000 -- 75000 -- 260 | step = hyperparameters['step_size'] 261 | scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[step, step+step//2, step+step//2+step//4], 262 | gamma=hyperparameters['gamma'], last_epoch=iterations) 263 | else: 264 | return NotImplementedError('learning rate policy [%s] is not implemented', hyperparameters['lr_policy']) 265 | return scheduler 266 | 267 | 268 | def weights_init(init_type='gaussian'): 269 | def init_fun(m): 270 | classname = m.__class__.__name__ 271 | if (classname.find('Conv') == 0 or classname.find('Linear') == 0) and hasattr(m, 'weight'): 272 | # print m.__class__.__name__ 273 | if init_type == 'gaussian': 274 | init.normal_(m.weight.data, 0.0, 0.02) 275 | elif init_type == 'xavier': 276 | init.xavier_normal_(m.weight.data, gain=math.sqrt(2)) 277 | elif init_type == 'kaiming': 278 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 279 | elif init_type == 'orthogonal': 280 | init.orthogonal_(m.weight.data, gain=math.sqrt(2)) 281 | elif init_type == 'default': 282 | pass 283 | else: 284 | assert 0, "Unsupported initialization: {}".format(init_type) 285 | if hasattr(m, 'bias') and m.bias is not None: 286 | init.constant_(m.bias.data, 0.0) 287 | 288 | return init_fun 289 | 290 | 291 | class Timer: 292 | def __init__(self, msg): 293 | self.msg = msg 294 | self.start_time = None 295 | 296 | def __enter__(self): 297 | self.start_time = time.time() 298 | 299 | def __exit__(self, exc_type, exc_value, exc_tb): 300 | print(self.msg % (time.time() - self.start_time)) 301 | 302 | 303 | -------------------------------------------------------------------------------- /visual_data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/.DS_Store -------------------------------------------------------------------------------- /visual_data/demo/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/demo/.DS_Store -------------------------------------------------------------------------------- /visual_data/demo/train/0010_c6s4_002427_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/demo/train/0010_c6s4_002427_02.jpg -------------------------------------------------------------------------------- /visual_data/demo/train/0042_c3s3_064169_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/demo/train/0042_c3s3_064169_01.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/.DS_Store -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/0000_c3s1_107067_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/0000_c3s1_107067_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/0200_c3s3_056178_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/0200_c3s3_056178_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/0223_c1s6_014296_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/0223_c1s6_014296_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/0387_c3s1_091167_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/0387_c3s1_091167_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/0400_c2s1_046426_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/0400_c2s1_046426_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/0531_c2s1_149316_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/0531_c2s1_149316_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/1196_c3s1_038576_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/1196_c3s1_038576_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/b0280_c3s3_065619_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/b0280_c3s3_065619_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/b0387_c2s1_090996_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/b0387_c2s1_090996_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/b0878_c2s2_106457_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/b0878_c2s2_106457_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/b1324_c1s6_007741_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/b1324_c1s6_007741_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/back0100_c3s2_115894_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/back0100_c3s2_115894_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/l0155_c3s1_025176_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/l0155_c3s1_025176_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/l0345_c1s2_006466_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/l0345_c1s2_006466_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/0/l0736_c6s1_022526_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/0/l0736_c6s1_022526_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/0051_c4s1_005526_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/0051_c4s1_005526_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/0055_c3s3_076469_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/0055_c3s3_076469_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/0089_c6s1_013501_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/0089_c6s1_013501_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/0280_c5s1_078148_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/0280_c5s1_078148_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/0715_c3s2_059153_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/0715_c3s2_059153_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/0825_c6s2_097943_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/0825_c6s2_097943_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/0829_c4s4_043560_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/0829_c4s4_043560_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/1147_c2s2_158702_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/1147_c2s2_158702_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/1226_c6s3_038167_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/1226_c6s3_038167_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/1251_c3s3_020428_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/1251_c3s3_020428_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/1255_c3s3_021678_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/1255_c3s3_021678_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/1459_c2s3_052107_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/1459_c2s3_052107_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/bb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/bb.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test/1/bbb.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test/1/bbb.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/00000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/00000.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/0001_c7_f0106449.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/0001_c7_f0106449.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/0100_c1_f0100590.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/0100_c1_f0100590.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/0120_c7_f0112912.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/0120_c7_f0112912.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/0180_c6_f0074316.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/0180_c6_f0074316.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/0304_c4_f0069917.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/0304_c4_f0069917.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/5842_c7_f0053326.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/5842_c7_f0053326.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/data/0163_c5_f0084927.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/data/0163_c5_f0084927.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/data/0279_c2_f0098340.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/data/0279_c2_f0098340.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/data/0377_c1_f0110464.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/data/0377_c1_f0110464.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/data/0459_c1_f0124423.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/data/0459_c1_f0124423.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/data/0654_c1_f0165081.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/data/0654_c1_f0165081.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/data/0749_c1_f0188863.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/data/0749_c1_f0188863.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/1/data/4141_c8_f0022224.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/1/data/4141_c8_f0022224.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/2/data/0201_c2_f0089204.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/2/data/0201_c2_f0089204.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/2/data/0243_c1_f0095564.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/2/data/0243_c1_f0095564.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/2/data/0371_c1_f0110190.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/2/data/0371_c1_f0110190.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/2/data/0690_c4_f0142152.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/2/data/0690_c4_f0142152.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/2/data/0712_c6_f0165928.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/2/data/0712_c6_f0165928.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/2/data/1108_c6_f0176009.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/2/data/1108_c6_f0176009.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_duke/2/data/1486_c4_f0069129.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_duke/2/data/1486_c4_f0069129.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/.DS_Store -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/1/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/1/.DS_Store -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/1/data/0051_c3s1_005551_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/1/data/0051_c3s1_005551_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/1/data/0227_c2s1_046426_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/1/data/0227_c2s1_046426_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/1/data/0304_c5s1_068523_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/1/data/0304_c5s1_068523_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/1/data/0378_c2s1_089071_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/1/data/0378_c2s1_089071_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/1/data/0447_c3s1_111408_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/1/data/0447_c3s1_111408_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/1/data/1285_c1s5_053991_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/1/data/1285_c1s5_053991_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/1/data/l0387_c6s2_006743_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/1/data/l0387_c6s2_006743_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/2/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/2/.DS_Store -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/2/data/0146_c2s1_023126_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/2/data/0146_c2s1_023126_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/2/data/0291_c3s3_079969_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/2/data/0291_c3s3_079969_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/2/data/1154_c5s3_001368_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/2/data/1154_c5s3_001368_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/2/data/1285_c2s3_021432_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/2/data/1285_c2s3_021432_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/2/data/1394_c6s3_071892_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/2/data/1394_c6s3_071892_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/2/data/1429_c2s3_048607_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/2/data/1429_c2s3_048607_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_market/2/data/l1255_c3s3_021678_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_market/2/data/l1255_c3s3_021678_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/1/data/0367_010_10_0303noon_0606_2.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/1/data/0367_010_10_0303noon_0606_2.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/1/data/0434_020_07_0303noon_1785_5_ex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/1/data/0434_020_07_0303noon_1785_5_ex.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/1/data/0494_029_01_0303afternoon_0435_2_ex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/1/data/0494_029_01_0303afternoon_0435_2_ex.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/1/data/0514_012_05_0303afternoon_0585_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/1/data/0514_012_05_0303afternoon_0585_1.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/1/data/0720_027_01_0303afternoon_0771_4.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/1/data/0720_027_01_0303afternoon_0771_4.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/1/data/2744_001_12_0114noon_0960_1_ex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/1/data/2744_001_12_0114noon_0960_1_ex.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/1/data/2944_020_01_0114afternoon_1429_2_ex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/1/data/2944_020_01_0114afternoon_1429_2_ex.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/2/data/0168_018_14_0303morning_1118_1_ex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/2/data/0168_018_14_0303morning_1118_1_ex.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/2/data/0318_001_01_0303noon_0111_4_ex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/2/data/0318_001_01_0303noon_0111_4_ex.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/2/data/0771_019_01_0303afternoon_1000_9.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/2/data/0771_019_01_0303afternoon_1000_9.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/2/data/0917_021_15_0113morning_0174_0.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/2/data/0917_021_15_0113morning_0174_0.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/2/data/2744_037_06_0114noon_0915_0_ex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/2/data/2744_037_06_0114noon_0915_0_ex.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/2/data/2790_015_05_0114afternoon_0244_1.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/2/data/2790_015_05_0114afternoon_0244_1.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/2/data/2869_011_05_0114afternoon_0836_0_ex.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/2/data/2869_011_05_0114afternoon_0836_0_ex.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_test_msmt/2/data/3032_043_07_0114afternoon_1088_3.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_test_msmt/2/data/3032_043_07_0114afternoon_1088_3.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/.DS_Store -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0002_c2s1_000351_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0002_c2s1_000351_01.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0028_c2s1_001751_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0028_c2s1_001751_02.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0068_c6s1_012351_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0068_c6s1_012351_01.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0070_c4s1_010576_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0070_c4s1_010576_01.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0127_c6s1_021426_06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0127_c6s1_021426_06.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0142_c3s3_020203_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0142_c3s3_020203_02.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0197_c6s1_045076_05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0197_c6s1_045076_05.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0259_c6s1_055126_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0259_c6s1_055126_01.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0301_c6s1_076101_04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0301_c6s1_076101_04.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0348_c2s1_080671_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0348_c2s1_080671_01.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0696_c6s2_054993_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0696_c6s2_054993_02.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0711_c2s2_056057_03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0711_c2s2_056057_03.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0754_c2s2_068182_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0754_c2s2_068182_01.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0820_c3s2_104428_05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0820_c3s2_104428_05.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0843_c2s2_104907_04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0843_c2s2_104907_04.jpg -------------------------------------------------------------------------------- /visual_data/inputs_many_train/1/0901_c6s2_119343_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_many_train/1/0901_c6s2_119343_02.jpg -------------------------------------------------------------------------------- /visual_data/inputs_two/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/.DS_Store -------------------------------------------------------------------------------- /visual_data/inputs_two/1/00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/1/00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_two/1/000.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/1/000.jpg -------------------------------------------------------------------------------- /visual_data/inputs_two/1/0006_c2s3_069327_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/1/0006_c2s3_069327_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_two/1/01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/1/01.jpg -------------------------------------------------------------------------------- /visual_data/inputs_two/1/0289_c3s1_063567_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/1/0289_c3s1_063567_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_two/1/back0100_c3s2_115894_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/1/back0100_c3s2_115894_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_two/2/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/2/.DS_Store -------------------------------------------------------------------------------- /visual_data/inputs_two/2/data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/2/data/.DS_Store -------------------------------------------------------------------------------- /visual_data/inputs_two/2/data/0051_c3s1_005551_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/2/data/0051_c3s1_005551_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_two/2/data/0227_c2s1_046426_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/2/data/0227_c2s1_046426_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_two/3/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/3/.DS_Store -------------------------------------------------------------------------------- /visual_data/inputs_two/3/data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/3/data/.DS_Store -------------------------------------------------------------------------------- /visual_data/inputs_two/3/data/0000_c3s1_107067_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/3/data/0000_c3s1_107067_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_two/3/data/0531_c2s1_149316_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/3/data/0531_c2s1_149316_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_two/4/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/4/.DS_Store -------------------------------------------------------------------------------- /visual_data/inputs_two/4/data/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/4/data/.DS_Store -------------------------------------------------------------------------------- /visual_data/inputs_two/4/data/b1324_c1s6_007741_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/4/data/b1324_c1s6_007741_00.jpg -------------------------------------------------------------------------------- /visual_data/inputs_two/4/data/back0100_c3s2_115894_00.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/inputs_two/4/data/back0100_c3s2_115894_00.jpg -------------------------------------------------------------------------------- /visual_data/train_sample/045_388_gan0734_c3s2_064678_06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/train_sample/045_388_gan0734_c3s2_064678_06.jpg -------------------------------------------------------------------------------- /visual_data/train_sample/104_188_gan0371_c5s1_088523_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/train_sample/104_188_gan0371_c5s1_088523_02.jpg -------------------------------------------------------------------------------- /visual_data/train_sample/115_000_gan0002_c1s1_069056_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/train_sample/115_000_gan0002_c1s1_069056_02.jpg -------------------------------------------------------------------------------- /visual_data/train_sample/258_750_gan1500_c5s3_063312_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/train_sample/258_750_gan1500_c5s3_063312_02.jpg -------------------------------------------------------------------------------- /visual_data/train_sample/344_283_gan0549_c1s3_009921_05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/train_sample/344_283_gan0549_c1s3_009921_05.jpg -------------------------------------------------------------------------------- /visual_data/train_sample/400_344_gan0656_c6s2_045393_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/train_sample/400_344_gan0656_c6s2_045393_01.jpg -------------------------------------------------------------------------------- /visual_data/train_sample/406_090_gan0179_c3s3_078044_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/train_sample/406_090_gan0179_c3s3_078044_01.jpg -------------------------------------------------------------------------------- /visual_data/train_sample/496_722_gan1437_c1s6_007216_02.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/train_sample/496_722_gan1437_c1s6_007216_02.jpg -------------------------------------------------------------------------------- /visual_data/train_sample/620_357_gan0673_c5s2_045155_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/train_sample/620_357_gan0673_c5s2_045155_01.jpg -------------------------------------------------------------------------------- /visual_data/train_sample/643_269_gan0519_c6s2_087418_01.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/train_sample/643_269_gan0519_c6s2_087418_01.jpg -------------------------------------------------------------------------------- /visual_data/train_sample/643_375_gan0706_c6s2_062118_04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/NVlabs/DG-Net/9855f08711df1d7ebdf976885b0fddec8e7d4a37/visual_data/train_sample/643_375_gan0706_c6s2_062118_04.jpg -------------------------------------------------------------------------------- /visual_tools/README.md: -------------------------------------------------------------------------------- 1 | ## Image Generation 2 | - `test_folder.py` generates samples and calculates the FID score. (You need to download the [TTUR](https://github.com/layumi/TTUR) for the FID evaluation.) 3 | ```bash 4 | python visual_tools/test_folder.py --name E0.5new_reid0.5_w30000 --which_epoch 100000 5 | ``` 6 | 7 | - `show_rainbow.py` generates Figure 1 of the paper. 8 | ```bash 9 | python visual_tools/show_rainbow.py 10 | ``` 11 | 12 | - `show_swap.py` swaps the codes on Market-1501. (Figure 6) 13 | 14 | - `show_smooth.py` generates Figure 5 of the paper. 15 | 16 | - `show_smooth_structure.py` generates Figure 9 of the paper. 17 | 18 | - `show1by1.py` swaps the codes of all images in one folder (many to many) and save the generated image one by one. 19 | -------------------------------------------------------------------------------- /visual_tools/show1by1.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from __future__ import print_function 6 | import sys 7 | sys.path.append('.') 8 | from utils import get_config 9 | from trainer import DGNet_Trainer, to_gray 10 | import argparse 11 | from torch.autograd import Variable 12 | import sys 13 | import torch 14 | import os 15 | import numpy as np 16 | from torchvision import datasets, transforms 17 | from PIL import Image 18 | 19 | name = 'E0.5new_reid0.5_w30000' 20 | 21 | if not os.path.isdir('./outputs/%s'%name): 22 | assert 0, "please change the name to your model name" 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--output_folder', type=str, default="./", help="output image path") 26 | parser.add_argument('--input_folder', type=str, default="./visual_data/inputs_many_train", help="input image path") 27 | 28 | parser.add_argument('--config', type=str, default='./outputs/%s/config.yaml'%name, help="net configuration") 29 | parser.add_argument('--checkpoint_gen', type=str, default="./outputs/%s/checkpoints/gen_00100000.pt"%name, help="checkpoint of autoencoders") 30 | parser.add_argument('--checkpoint_id', type=str, default="./outputs/%s/checkpoints/id_00100000.pt"%name, help="checkpoint of autoencoders") 31 | parser.add_argument('--batchsize', default=1, type=int, help='batchsize') 32 | parser.add_argument('--a2b', type=int, default=1, help="1 for a2b and others for b2a") 33 | parser.add_argument('--seed', type=int, default=10, help="random seed") 34 | parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not") 35 | parser.add_argument('--output_only', action='store_true', help="whether use synchronized style code or not") 36 | parser.add_argument('--trainer', type=str, default='DGNet', help="DGNet") 37 | 38 | 39 | opts = parser.parse_args() 40 | 41 | torch.manual_seed(opts.seed) 42 | torch.cuda.manual_seed(opts.seed) 43 | if not os.path.exists(opts.output_folder): 44 | os.makedirs(opts.output_folder) 45 | 46 | # Load experiment setting 47 | config = get_config(opts.config) 48 | opts.num_style = 1 49 | 50 | # Setup model and data loader 51 | if opts.trainer == 'DGNet': 52 | trainer = DGNet_Trainer(config) 53 | else: 54 | sys.exit("Only support DGNet") 55 | 56 | state_dict_gen = torch.load(opts.checkpoint_gen) 57 | trainer.gen_a.load_state_dict(state_dict_gen['a'], strict=False) 58 | trainer.gen_b = trainer.gen_a 59 | 60 | state_dict_id = torch.load(opts.checkpoint_id) 61 | trainer.id_a.load_state_dict(state_dict_id['a']) 62 | trainer.id_b = trainer.id_a 63 | 64 | trainer.cuda() 65 | trainer.eval() 66 | encode = trainer.gen_a.encode # encode function 67 | style_encode = trainer.gen_a.encode # encode function 68 | id_encode = trainer.id_a # encode function 69 | decode = trainer.gen_a.decode # decode function 70 | 71 | data_transforms = transforms.Compose([ 72 | transforms.Resize((256,128), interpolation=3), 73 | transforms.ToTensor(), 74 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 75 | ]) 76 | 77 | image_datasets = datasets.ImageFolder(opts.input_folder, data_transforms) 78 | dataloader_content = torch.utils.data.DataLoader(image_datasets, batch_size=1, shuffle=False, num_workers=1) 79 | dataloader_structure = torch.utils.data.DataLoader(image_datasets, batch_size=16, shuffle=False, num_workers=1) 80 | image_paths = image_datasets.imgs 81 | 82 | ###################################################################### 83 | # recover image 84 | # ----------------- 85 | def recover(inp): 86 | """Imshow for Tensor.""" 87 | inp = inp.numpy().transpose((1, 2, 0)) 88 | mean = np.array([0.485, 0.456, 0.406]) 89 | std = np.array([0.229, 0.224, 0.225]) 90 | inp = std * inp + mean 91 | inp = inp * 255.0 92 | inp = np.clip(inp, 0, 255) 93 | return inp 94 | 95 | save_path = './visual_data/rainbow' 96 | if not os.path.isdir(save_path): 97 | os.mkdir(save_path) 98 | 99 | im = {} 100 | count = 0 101 | data2 = next(iter(dataloader_structure)) 102 | bg_img, _ = data2 103 | gray = to_gray(False) 104 | bg_img = gray(bg_img) 105 | bg_img = Variable(bg_img.cuda()) 106 | with torch.no_grad(): 107 | for data in dataloader_content: 108 | id_img, _ = data 109 | id_img = Variable(id_img.cuda()) 110 | n, c, h, w = id_img.size() 111 | # Start testing 112 | s = encode(bg_img) 113 | f, _ = id_encode(id_img) 114 | input1 = recover(data[0].squeeze()) 115 | im[count] = input1 116 | for i in range(s.size(0)): 117 | s_tmp = s[i,:,:,:] 118 | outputs = decode(s_tmp.unsqueeze(0), f) 119 | tmp = recover(outputs[0].data.cpu()) 120 | pic = Image.fromarray(tmp.astype('uint8')) 121 | pic.save('%s/rainbow_%d_%d.jpg'%(save_path,i,count)) 122 | count +=1 123 | 124 | 125 | -------------------------------------------------------------------------------- /visual_tools/show_rainbow.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from __future__ import print_function 6 | import sys 7 | sys.path.append('.') 8 | from utils import get_config 9 | from trainer import DGNet_Trainer, to_gray 10 | import argparse 11 | from torch.autograd import Variable 12 | import sys 13 | import torch 14 | import os 15 | import numpy as np 16 | from torchvision import datasets, transforms 17 | from PIL import Image 18 | 19 | name = 'E0.5new_reid0.5_w30000' 20 | 21 | if not os.path.isdir('./outputs/%s'%name): 22 | assert 0, "please change the name to your model name" 23 | 24 | parser = argparse.ArgumentParser() 25 | parser.add_argument('--output_folder', type=str, default="./", help="output image path") 26 | parser.add_argument('--input_folder', type=str, default="./visual_data/inputs_many_test", help="input image path") 27 | 28 | parser.add_argument('--config', type=str, default='./outputs/%s/config.yaml'%name, help="net configuration") 29 | parser.add_argument('--checkpoint_gen', type=str, default="./outputs/%s/checkpoints/gen_00100000.pt"%name, help="checkpoint of autoencoders") 30 | parser.add_argument('--checkpoint_id', type=str, default="./outputs/%s/checkpoints/id_00100000.pt"%name, help="checkpoint of autoencoders") 31 | parser.add_argument('--batchsize', default=1, type=int, help='batchsize') 32 | parser.add_argument('--a2b', type=int, default=1, help="1 for a2b and others for b2a") 33 | parser.add_argument('--seed', type=int, default=10, help="random seed") 34 | parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not") 35 | parser.add_argument('--output_only', action='store_true', help="whether use synchronized style code or not") 36 | parser.add_argument('--trainer', type=str, default='DGNet', help="DGNet") 37 | 38 | 39 | opts = parser.parse_args() 40 | 41 | torch.manual_seed(opts.seed) 42 | torch.cuda.manual_seed(opts.seed) 43 | if not os.path.exists(opts.output_folder): 44 | os.makedirs(opts.output_folder) 45 | 46 | # Load experiment setting 47 | config = get_config(opts.config) 48 | opts.num_style = 1 49 | 50 | # Setup model and data loader 51 | if opts.trainer == 'DGNet': 52 | trainer = DGNet_Trainer(config) 53 | else: 54 | sys.exit("Only support DGNet") 55 | 56 | state_dict_gen = torch.load(opts.checkpoint_gen) 57 | trainer.gen_a.load_state_dict(state_dict_gen['a'], strict=False) 58 | trainer.gen_b = trainer.gen_a 59 | 60 | state_dict_id = torch.load(opts.checkpoint_id) 61 | trainer.id_a.load_state_dict(state_dict_id['a']) 62 | trainer.id_b = trainer.id_a 63 | 64 | trainer.cuda() 65 | trainer.eval() 66 | encode = trainer.gen_a.encode # encode function 67 | style_encode = trainer.gen_a.encode # encode function 68 | id_encode = trainer.id_a # encode function 69 | decode = trainer.gen_a.decode # decode function 70 | 71 | data_transforms = transforms.Compose([ 72 | transforms.Resize((256,128), interpolation=3), 73 | transforms.ToTensor(), 74 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 75 | ]) 76 | 77 | 78 | num = 7 79 | image_datasets = datasets.ImageFolder(opts.input_folder, data_transforms) 80 | dataloader_content = torch.utils.data.DataLoader(image_datasets, batch_size=1, shuffle=False, num_workers=1) 81 | dataloader_structure = torch.utils.data.DataLoader(image_datasets, batch_size=num, shuffle=False, num_workers=1) 82 | image_paths = image_datasets.imgs 83 | 84 | ###################################################################### 85 | # recover image 86 | # ----------------- 87 | def recover(inp): 88 | """Imshow for Tensor.""" 89 | inp = inp.numpy().transpose((1, 2, 0)) 90 | mean = np.array([0.485, 0.456, 0.406]) 91 | std = np.array([0.229, 0.224, 0.225]) 92 | inp = std * inp + mean 93 | inp = inp * 255.0 94 | inp = np.clip(inp, 0, 255) 95 | return inp 96 | 97 | def pad(inp, pad = 3): 98 | h = inp.shape[0] 99 | w = inp.shape[1] 100 | bg = np.zeros((h+2*pad, w+2*pad, inp.shape[2])) 101 | bg[pad:pad+h, pad:pad+w, :] = inp 102 | return bg 103 | 104 | im = {} 105 | npad = 3 106 | count = 0 107 | data2 = next(iter(dataloader_structure)) 108 | bg_img, _ = data2 109 | 110 | gray = to_gray(False) 111 | bg_img = gray(bg_img) 112 | bg_img = Variable(bg_img.cuda()) 113 | white_col = np.ones( (256+2*npad,24,3))*255 114 | with torch.no_grad(): 115 | for data in dataloader_content: 116 | id_img, _ = data 117 | id_img = Variable(id_img.cuda()) 118 | n, c, h, w = id_img.size() 119 | # Start testing 120 | s = encode(bg_img) 121 | f, _ = id_encode(id_img) 122 | input1 = recover(data[0].squeeze()) 123 | im[count] = pad(input1, pad= npad) 124 | for i in range( s.size(0)): 125 | s_tmp = s[i,:,:,:] 126 | outputs = decode(s_tmp.unsqueeze(0), f) 127 | tmp = recover(outputs[0].data.cpu()) 128 | tmp = pad(tmp, pad=npad) 129 | im[count] = np.concatenate((im[count], white_col, tmp), axis=1) 130 | count +=1 131 | 132 | first_row = np.ones((256+2*npad,128+2*npad,3))*255 133 | white_row = np.ones( (12,im[0].shape[1],3))*255 134 | for i in range(num): 135 | if i == 0: 136 | pic = im[0] 137 | else: 138 | pic = np.concatenate((pic, im[i]), axis=0) 139 | pic = np.concatenate((pic, white_row), axis=0) 140 | first_row = np.concatenate((first_row, white_col, im[i][0:256+2*npad, 0:128+2*npad, 0:3]), axis=1) 141 | 142 | pic = np.concatenate((first_row, white_row, pic), axis=0) 143 | pic = Image.fromarray(pic.astype('uint8')) 144 | pic.save('rainbow_%d.jpg'%num) 145 | 146 | -------------------------------------------------------------------------------- /visual_tools/show_smooth.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from __future__ import print_function 6 | import sys 7 | sys.path.append('.') 8 | from utils import get_config 9 | from trainer import DGNet_Trainer, to_gray 10 | import argparse 11 | from torch.autograd import Variable 12 | import sys 13 | import torch 14 | import os 15 | import numpy as np 16 | import imageio 17 | from torchvision import datasets, transforms 18 | from PIL import Image 19 | 20 | name = 'E0.5new_reid0.5_w30000' 21 | 22 | if not os.path.isdir('./outputs/%s'%name): 23 | assert 0, "please change the name to your model name" 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--output_folder', type=str, default="./", help="output image path") 27 | parser.add_argument('--input_folder', type=str, default="./visual_data/inputs_two", help="input image path") 28 | 29 | parser.add_argument('--config', type=str, default='./outputs/%s/config.yaml'%name, help="net configuration") 30 | parser.add_argument('--checkpoint_gen', type=str, default="./outputs/%s/checkpoints/gen_00100000.pt"%name, help="checkpoint of autoencoders") 31 | parser.add_argument('--checkpoint_id', type=str, default="./outputs/%s/checkpoints/id_00100000.pt"%name, help="checkpoint of autoencoders") 32 | parser.add_argument('--batchsize', default=1, type=int, help='batchsize') 33 | parser.add_argument('--a2b', type=int, default=1, help="1 for a2b and others for b2a") 34 | parser.add_argument('--seed', type=int, default=10, help="random seed") 35 | parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not") 36 | parser.add_argument('--output_only', action='store_true', help="whether use synchronized style code or not") 37 | parser.add_argument('--trainer', type=str, default='DGNet', help="DGNet") 38 | 39 | 40 | opts = parser.parse_args() 41 | 42 | torch.manual_seed(opts.seed) 43 | torch.cuda.manual_seed(opts.seed) 44 | if not os.path.exists(opts.output_folder): 45 | os.makedirs(opts.output_folder) 46 | 47 | # Load experiment setting 48 | config = get_config(opts.config) 49 | opts.num_style = 1 50 | 51 | # Setup model and data loader 52 | if opts.trainer == 'DGNet': 53 | trainer = DGNet_Trainer(config) 54 | else: 55 | sys.exit("Only support DGNet") 56 | 57 | state_dict_gen = torch.load(opts.checkpoint_gen) 58 | trainer.gen_a.load_state_dict(state_dict_gen['a'], strict=False) 59 | trainer.gen_b = trainer.gen_a 60 | 61 | state_dict_id = torch.load(opts.checkpoint_id) 62 | trainer.id_a.load_state_dict(state_dict_id['a']) 63 | trainer.id_b = trainer.id_a 64 | 65 | trainer.cuda() 66 | trainer.eval() 67 | encode = trainer.gen_a.encode # encode function 68 | style_encode = trainer.gen_a.encode # encode function 69 | id_encode = trainer.id_a # encode function 70 | decode = trainer.gen_a.decode # decode function 71 | 72 | data_transforms = transforms.Compose([ 73 | transforms.Resize((256,128), interpolation=3), 74 | transforms.ToTensor(), 75 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 76 | ]) 77 | 78 | image_datasets = datasets.ImageFolder(opts.input_folder, data_transforms) 79 | dataloader_content = torch.utils.data.DataLoader(image_datasets, batch_size=2, shuffle=False, num_workers=1) 80 | dataloader_structure = torch.utils.data.DataLoader(image_datasets, batch_size=2, shuffle=False, num_workers=1) 81 | image_paths = image_datasets.imgs 82 | 83 | ###################################################################### 84 | # recover image 85 | # ----------------- 86 | def recover(inp): 87 | """Imshow for Tensor.""" 88 | inp = inp.numpy().transpose((1, 2, 0)) 89 | mean = np.array([0.485, 0.456, 0.406]) 90 | std = np.array([0.229, 0.224, 0.225]) 91 | inp = std * inp + mean 92 | inp = inp * 255.0 93 | inp = np.clip(inp, 0, 255) 94 | return inp 95 | 96 | im = {} 97 | data2 = next(iter(dataloader_structure)) 98 | bg_img, _ = data2 99 | gray = to_gray(False) 100 | bg_img = gray(bg_img) 101 | bg_img = Variable(bg_img.cuda()) 102 | ff = [] 103 | gif = [] 104 | with torch.no_grad(): 105 | for data in dataloader_content: 106 | id_img, _ = data 107 | id_img = Variable(id_img.cuda()) 108 | n, c, h, w = id_img.size() 109 | # Start testing 110 | s = encode(bg_img) 111 | f, _ = id_encode(id_img) 112 | for count in range(2): 113 | input1 = recover(id_img[count].squeeze().data.cpu()) 114 | im[count] = input1 115 | gif.append(input1) 116 | for i in range(11): 117 | s_tmp = s[count,:,:,:] 118 | tmp_f = 0.1*i*f[0] + (1-0.1*i)*f[1] 119 | tmp_f = tmp_f.view(1, -1) 120 | outputs = decode(s_tmp.unsqueeze(0), tmp_f) 121 | tmp = recover(outputs[0].data.cpu()) 122 | im[count] = np.concatenate((im[count], tmp), axis=1) 123 | gif.append(tmp) 124 | break 125 | 126 | # save long image 127 | pic = np.concatenate( (im[0], im[1]) , axis=0) 128 | pic = Image.fromarray(pic.astype('uint8')) 129 | pic.save('smooth.jpg') 130 | 131 | # save gif 132 | imageio.mimsave('./smooth.gif', gif) 133 | 134 | -------------------------------------------------------------------------------- /visual_tools/show_smooth_structure.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from __future__ import print_function 6 | import sys 7 | sys.path.append('.') 8 | from utils import get_config 9 | from trainer import DGNet_Trainer, to_gray 10 | import argparse 11 | from torch.autograd import Variable 12 | import sys 13 | import torch 14 | import os 15 | import numpy as np 16 | import imageio 17 | from torchvision import datasets, transforms 18 | from PIL import Image 19 | 20 | name = 'E0.5new_reid0.5_w30000' 21 | 22 | if not os.path.isdir('./outputs/%s'%name): 23 | assert 0, "please change the name to your model name" 24 | 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument('--output_folder', type=str, default="./", help="output image path") 27 | parser.add_argument('--input_folder', type=str, default="./visual_data/inputs_two", help="input image path") 28 | 29 | parser.add_argument('--config', type=str, default='./outputs/%s/config.yaml'%name, help="net configuration") 30 | parser.add_argument('--checkpoint_gen', type=str, default="./outputs/%s/checkpoints/gen_00100000.pt"%name, help="checkpoint of autoencoders") 31 | parser.add_argument('--checkpoint_id', type=str, default="./outputs/%s/checkpoints/id_00100000.pt"%name, help="checkpoint of autoencoders") 32 | parser.add_argument('--batchsize', default=1, type=int, help='batchsize') 33 | parser.add_argument('--a2b', type=int, default=1, help="1 for a2b and others for b2a") 34 | parser.add_argument('--seed', type=int, default=10, help="random seed") 35 | parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not") 36 | parser.add_argument('--output_only', action='store_true', help="whether use synchronized style code or not") 37 | parser.add_argument('--trainer', type=str, default='DGNet', help="DGNet") 38 | 39 | 40 | opts = parser.parse_args() 41 | 42 | torch.manual_seed(opts.seed) 43 | torch.cuda.manual_seed(opts.seed) 44 | if not os.path.exists(opts.output_folder): 45 | os.makedirs(opts.output_folder) 46 | 47 | # Load experiment setting 48 | config = get_config(opts.config) 49 | opts.num_style = 1 50 | 51 | # Setup model and data loader 52 | if opts.trainer == 'DGNet': 53 | trainer = DGNet_Trainer(config) 54 | else: 55 | sys.exit("Only support DGNet") 56 | 57 | state_dict_gen = torch.load(opts.checkpoint_gen) 58 | trainer.gen_a.load_state_dict(state_dict_gen['a'], strict=False) 59 | trainer.gen_b = trainer.gen_a 60 | 61 | state_dict_id = torch.load(opts.checkpoint_id) 62 | trainer.id_a.load_state_dict(state_dict_id['a']) 63 | trainer.id_b = trainer.id_a 64 | 65 | trainer.cuda() 66 | trainer.eval() 67 | encode = trainer.gen_a.encode # encode function 68 | style_encode = trainer.gen_a.encode # encode function 69 | id_encode = trainer.id_a # encode function 70 | decode = trainer.gen_a.decode # decode function 71 | 72 | data_transforms = transforms.Compose([ 73 | transforms.Resize((256,128), interpolation=3), 74 | transforms.ToTensor(), 75 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 76 | ]) 77 | 78 | image_datasets = datasets.ImageFolder(opts.input_folder, data_transforms) 79 | dataloader_content = torch.utils.data.DataLoader(image_datasets, batch_size=2, shuffle=False, num_workers=1) 80 | dataloader_structure = torch.utils.data.DataLoader(image_datasets, batch_size=2, shuffle=False, num_workers=1) 81 | image_paths = image_datasets.imgs 82 | 83 | ###################################################################### 84 | # recover image 85 | # ----------------- 86 | def recover(inp): 87 | """Imshow for Tensor.""" 88 | inp = inp.numpy().transpose((1, 2, 0)) 89 | mean = np.array([0.485, 0.456, 0.406]) 90 | std = np.array([0.229, 0.224, 0.225]) 91 | inp = std * inp + mean 92 | inp = inp * 255.0 93 | inp = np.clip(inp, 0, 255) 94 | return inp 95 | 96 | im = {} 97 | data2 = next(iter(dataloader_structure)) 98 | bg_img, _ = data2 99 | gray = to_gray(False) 100 | bg_img = gray(bg_img) 101 | bg_img = Variable(bg_img.cuda()) 102 | ff = [] 103 | gif = [] 104 | with torch.no_grad(): 105 | for data in dataloader_content: 106 | id_img, _ = data 107 | id_img = Variable(id_img.cuda()) 108 | n, c, h, w = id_img.size() 109 | # Start testing 110 | s = encode(bg_img) 111 | f, _ = id_encode(id_img) 112 | for count in range(2): 113 | input1 = recover(id_img[count].squeeze().data.cpu()) 114 | im[count] = input1 115 | gif.append(input1) 116 | for i in range(11): 117 | f_tmp = f[count,:] 118 | f_tmp = f_tmp.view(1,-1) 119 | tmp_s = 0.1*i*s[0,:,:,:] + (1-0.1*i)*s[1,:,:,:] 120 | tmp_s = tmp_s.unsqueeze(0) 121 | outputs = decode(tmp_s, f_tmp) 122 | tmp = recover(outputs[0].data.cpu()) 123 | im[count] = np.concatenate((im[count], tmp), axis=1) 124 | gif.append(tmp) 125 | break 126 | 127 | # save long image 128 | pic = np.concatenate( (im[0], im[1]) , axis=0) 129 | pic = Image.fromarray(pic.astype('uint8')) 130 | pic.save('smooth-s.jpg') 131 | 132 | # save gif 133 | imageio.mimsave('./smooth-s.gif', gif) 134 | -------------------------------------------------------------------------------- /visual_tools/show_swap.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from __future__ import print_function 6 | import sys 7 | sys.path.append('.') 8 | from utils import get_config 9 | from trainer import DGNet_Trainer, to_gray 10 | import argparse 11 | from torch.autograd import Variable 12 | import sys 13 | import torch 14 | import os 15 | import numpy as np 16 | from torchvision import datasets, transforms 17 | from PIL import Image 18 | try: 19 | from itertools import izip as zip 20 | except ImportError: # will be 3.x series 21 | pass 22 | 23 | name = 'E0.5new_reid0.5_w30000' 24 | 25 | if not os.path.isdir('./outputs/%s'%name): 26 | assert 0, "please change the name to your model name" 27 | 28 | parser = argparse.ArgumentParser() 29 | parser.add_argument('--output_folder', type=str, default="./", help="output image path") 30 | parser.add_argument('--input_folder', type=str, default="./visual_data/inputs_many_test_market", help="input image path") 31 | 32 | parser.add_argument('--config', type=str, default='./outputs/%s/config.yaml'%name, help="net configuration") 33 | parser.add_argument('--checkpoint_gen', type=str, default="./outputs/%s/checkpoints/gen_00100000.pt"%name, help="checkpoint of autoencoders") 34 | parser.add_argument('--checkpoint_id', type=str, default="./outputs/%s/checkpoints/id_00100000.pt"%name, help="checkpoint of autoencoders") 35 | parser.add_argument('--batchsize', default=1, type=int, help='batchsize') 36 | parser.add_argument('--a2b', type=int, default=1, help="1 for a2b and others for b2a") 37 | parser.add_argument('--seed', type=int, default=10, help="random seed") 38 | parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not") 39 | parser.add_argument('--output_only', action='store_true', help="whether use synchronized style code or not") 40 | parser.add_argument('--trainer', type=str, default='DGNet', help="DGNet") 41 | 42 | 43 | opts = parser.parse_args() 44 | 45 | torch.manual_seed(opts.seed) 46 | torch.cuda.manual_seed(opts.seed) 47 | if not os.path.exists(opts.output_folder): 48 | os.makedirs(opts.output_folder) 49 | 50 | # Load experiment setting 51 | config = get_config(opts.config) 52 | opts.num_style = 1 53 | 54 | # Setup model and data loader 55 | if opts.trainer == 'DGNet': 56 | trainer = DGNet_Trainer(config) 57 | else: 58 | sys.exit("Only support DGNet") 59 | 60 | state_dict_gen = torch.load(opts.checkpoint_gen) 61 | trainer.gen_a.load_state_dict(state_dict_gen['a'], strict=False) 62 | trainer.gen_b = trainer.gen_a 63 | 64 | state_dict_id = torch.load(opts.checkpoint_id) 65 | trainer.id_a.load_state_dict(state_dict_id['a'], strict=False) 66 | trainer.id_b = trainer.id_a 67 | 68 | trainer.cuda() 69 | trainer.eval() 70 | encode = trainer.gen_a.encode # encode function 71 | style_encode = trainer.gen_a.encode # encode function 72 | id_encode = trainer.id_a # encode function 73 | decode = trainer.gen_a.decode # decode function 74 | 75 | data_transforms = transforms.Compose([ 76 | transforms.Resize((256,128), interpolation=3), 77 | transforms.ToTensor(), 78 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 79 | ]) 80 | 81 | image_datasets1 = datasets.ImageFolder(opts.input_folder+'/1', data_transforms) 82 | image_datasets2 = datasets.ImageFolder(opts.input_folder+'/2', data_transforms) 83 | dataloader_content = torch.utils.data.DataLoader(image_datasets1, batch_size=1, shuffle=False, num_workers=1) 84 | dataloader_structure = torch.utils.data.DataLoader(image_datasets2, batch_size=1, shuffle=False, num_workers=1) 85 | 86 | ###################################################################### 87 | # recover image 88 | # ----------------- 89 | def recover(inp): 90 | """Imshow for Tensor.""" 91 | inp = inp.numpy().transpose((1, 2, 0)) 92 | mean = np.array([0.485, 0.456, 0.406]) 93 | std = np.array([0.229, 0.224, 0.225]) 94 | inp = std * inp + mean 95 | inp = inp * 255.0 96 | inp = np.clip(inp, 0, 255) 97 | return inp 98 | 99 | def pad(inp, pad = 3): 100 | h = inp.shape[0] 101 | w = inp.shape[1] 102 | bg = np.zeros((h+2*pad, w+2*pad, inp.shape[2])) 103 | bg[pad:pad+h, pad:pad+w, :] = inp 104 | return bg 105 | 106 | im = {} 107 | npad = 3 108 | count = 0 109 | gray = to_gray(False) 110 | 111 | def generate(data, data2): 112 | bg_img, _ = data2 113 | bg_img = gray(bg_img) 114 | bg_img = Variable(bg_img.cuda()) 115 | id_img, _ = data 116 | id_img = Variable(id_img.cuda()) 117 | # Start testing 118 | s = encode(bg_img) 119 | f, _ = id_encode(id_img) 120 | output = decode(s, f) 121 | return output.squeeze().data.cpu() 122 | 123 | w = np.ones( (16, 128+2*npad,3))*255 #white row 124 | w2 = np.ones( (32, 128+2*npad,3))*255 #white row 125 | with torch.no_grad(): 126 | for data, data2 in zip(dataloader_content, dataloader_structure): 127 | im1 = pad(recover(data[0].squeeze()), pad= npad) 128 | im2 = pad(recover(data2[0].squeeze()), pad= npad) 129 | output1 = pad(recover(generate(data, data2)), pad= npad) 130 | output2 = pad(recover(generate(data2, data)), pad= npad) 131 | im[count] = np.concatenate((im1, w, im2, w2, output1, w, output2), axis=0) 132 | count +=1 133 | print(count) 134 | 135 | white_col = np.ones( (im[0].shape[0], 16, 3))*255 136 | 137 | for i in range(count): 138 | if i == 0: 139 | pic = im[0] 140 | else: 141 | pic = np.concatenate((pic, im[i]), axis=1) 142 | pic = np.concatenate((pic, white_col), axis=1) 143 | 144 | pic = Image.fromarray(pic.astype('uint8')) 145 | pic.save('swap.jpg') 146 | 147 | -------------------------------------------------------------------------------- /visual_tools/test_folder.py: -------------------------------------------------------------------------------- 1 | """ 2 | Copyright (C) 2018 NVIDIA Corporation. All rights reserved. 3 | Licensed under the CC BY-NC-SA 4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 4 | """ 5 | from __future__ import print_function 6 | import sys 7 | sys.path.append('.') 8 | from trainer import DGNet_Trainer, to_gray 9 | from utils import get_config 10 | import argparse 11 | from torch.autograd import Variable 12 | import torchvision.utils as vutils 13 | import sys 14 | import torch 15 | import random 16 | import os 17 | import numpy as np 18 | from torchvision import datasets, models, transforms 19 | from PIL import Image 20 | from shutil import copyfile 21 | 22 | parser = argparse.ArgumentParser() 23 | parser.add_argument('--output_folder', type=str, default="../Market/pytorch/off-gan_id1/", help="output image path") 24 | parser.add_argument('--output_folder2', type=str, default="../Market/pytorch/off-gan_bg1/", help="output image path") 25 | parser.add_argument('--input_folder', type=str, default="../Market/pytorch/train_all/", help="input image path") 26 | 27 | parser.add_argument('--name', type=str, default="E0.5new_reid0.5_w30000", help="model name") 28 | parser.add_argument('--which_epoch', default=100000, type=int, help='iteration') 29 | 30 | parser.add_argument('--batchsize', default=1, type=int, help='batchsize') 31 | parser.add_argument('--a2b', type=int, default=1, help="1 for a2b and others for b2a") 32 | parser.add_argument('--seed', type=int, default=10, help="random seed") 33 | parser.add_argument('--synchronized', action='store_true', help="whether use synchronized style code or not") 34 | parser.add_argument('--output_only', action='store_true', help="whether use synchronized style code or not") 35 | parser.add_argument('--trainer', type=str, default='DGNet', help="DGNet") 36 | 37 | 38 | opts = parser.parse_args() 39 | opts.checkpoint_gen = "./outputs/%s/checkpoints/gen_00%06d.pt"%(opts.name, opts.which_epoch) 40 | opts.checkpoint_id = "./outputs/%s/checkpoints/id_00%06d.pt"%(opts.name, opts.which_epoch) 41 | opts.config = './outputs/%s/config.yaml'%opts.name 42 | 43 | if not os.path.exists(opts.output_folder): 44 | os.makedirs(opts.output_folder) 45 | else: 46 | os.system('rm -rf %s/*'%opts.output_folder) 47 | 48 | if not os.path.exists(opts.output_folder2): 49 | os.makedirs(opts.output_folder2) 50 | else: 51 | os.system('rm -rf %s/*'%opts.output_folder2) 52 | 53 | # Load experiment setting 54 | config = get_config(opts.config) 55 | # we use config 56 | config['apex'] = False 57 | opts.num_style = 1 58 | 59 | # Setup model and data loader 60 | if opts.trainer == 'DGNet': 61 | trainer = DGNet_Trainer(config) 62 | else: 63 | sys.exit("Only support DGNet") 64 | 65 | state_dict_gen = torch.load(opts.checkpoint_gen) 66 | trainer.gen_a.load_state_dict(state_dict_gen['a'], strict=False) 67 | trainer.gen_b = trainer.gen_a 68 | 69 | state_dict_id = torch.load(opts.checkpoint_id) 70 | trainer.id_a.load_state_dict(state_dict_id['a']) 71 | trainer.id_b = trainer.id_a 72 | 73 | trainer.cuda() 74 | trainer.eval() 75 | encode = trainer.gen_a.encode # encode function 76 | style_encode = trainer.gen_a.encode # encode function 77 | id_encode = trainer.id_a # encode function 78 | decode = trainer.gen_a.decode # decode function 79 | 80 | data_transforms = transforms.Compose([ 81 | transforms.Resize(( config['crop_image_height'], config['crop_image_width']), interpolation=3), 82 | transforms.ToTensor(), 83 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) 84 | ]) 85 | 86 | image_datasets = datasets.ImageFolder(opts.input_folder, data_transforms) 87 | dataloader_content = torch.utils.data.DataLoader(image_datasets, batch_size=opts.batchsize, shuffle=False, pin_memory=True, num_workers=1) 88 | dataloader_structure = torch.utils.data.DataLoader(image_datasets, batch_size=opts.batchsize, shuffle=True, pin_memory=True, num_workers=1) 89 | image_paths = image_datasets.imgs 90 | 91 | ###################################################################### 92 | # recover image 93 | # ----------------- 94 | def recover(inp): 95 | """Imshow for Tensor.""" 96 | inp = inp.numpy().transpose((1, 2, 0)) 97 | mean = np.array([0.485, 0.456, 0.406]) 98 | std = np.array([0.229, 0.224, 0.225]) 99 | inp = std * inp + mean 100 | inp = inp * 255.0 101 | inp = np.clip(inp, 0, 255) 102 | return inp 103 | 104 | def fliplr(img): 105 | '''flip horizontal''' 106 | inv_idx = torch.arange(img.size(3)-1,-1,-1).long() # N x C x H x W 107 | img_flip = img.index_select(3,inv_idx) 108 | return img_flip 109 | 110 | ########################################### 111 | # ID with different background (run 10 times) 112 | #------------------------------------------ 113 | 114 | gray = to_gray(False) 115 | 116 | torch.manual_seed(opts.seed) 117 | 118 | with torch.no_grad(): 119 | for i in range(1): 120 | for data, data2, path in zip(dataloader_content, dataloader_structure, image_paths): 121 | name = os.path.basename(path[0]) 122 | id_img, label = data 123 | id_img_flip = Variable(fliplr(id_img).cuda()) 124 | id_img = Variable(id_img.cuda()) 125 | bg_img, label2 = data2 126 | if config['single'] == 'gray': 127 | bg_img = gray(bg_img) 128 | bg_img = Variable(bg_img.cuda()) 129 | 130 | n, c, h, w = id_img.size() 131 | # Start testing 132 | c = encode(bg_img) 133 | f, _ = id_encode(id_img) 134 | 135 | if opts.trainer == 'DGNet': 136 | outputs = decode(c, f) 137 | im = recover(outputs[0].data.cpu()) 138 | im = Image.fromarray(im.astype('uint8')) 139 | ID = name.split('_') 140 | dst_path = opts.output_folder + '/%03d'%label 141 | dst_path2 = opts.output_folder2 + '/%03d'%label2 142 | if not os.path.isdir(dst_path): 143 | os.mkdir(dst_path) 144 | if not os.path.isdir(dst_path2): 145 | os.mkdir(dst_path2) 146 | im.save(dst_path + '/%03d_%03d_gan%s.jpg'%(label2, label, name[:-4])) 147 | im.save(dst_path2 + '/%03d_%03d_gan%s.jpg'%(label2, label, name[:-4])) 148 | else: 149 | pass 150 | print('---- start fid evaluation ------') 151 | os.system('cd ../TTUR; python fid.py ../Market/pytorch/train_all ../Market/pytorch/off-gan_id1 --gpu 0') 152 | 153 | --------------------------------------------------------------------------------