├── .gitignore ├── 4in1 ├── 01_PSPNet.png ├── 01_crfasrnn.png ├── 01_deeplabv3plus.png ├── 01_original.jpg ├── 02_PSPNet.png ├── 02_crfasrnn.png ├── 02_deeplabv3plus.png ├── 02_original.jpg ├── 03_PSPNet.png ├── 03_crfasrnn.png ├── 03_deeplabv3plus.png ├── 03_original.jpg ├── 04_PSPNet.png ├── 04_crfasrnn.png ├── 04_deeplabv3plus.png ├── 04_original.jpg ├── 05_PSPNet.png ├── 05_crfasrnn.png ├── 05_deeplabv3plus.png ├── 05_original.jpg ├── 06_PSPNet.png ├── 06_crfasrnn.png ├── 06_deeplabv3plus.png └── 06_original.jpg ├── LICENSE ├── README.md ├── deeplabv3plus_xception.py ├── deepllabv3plus.py ├── demo.ipynb ├── function.py ├── google_images_download.py ├── hair_colab.ipynb ├── images ├── 10.png ├── 11.png ├── 12.png ├── 5.png ├── 6.png ├── 7.png ├── 8.png └── 9.png ├── matting_result ├── 01_matting_01_01.png ├── 01_matting_25_05.png ├── 01_matting_30_10.png ├── 01_matting_30_20.png ├── 01_matting_40_20.png ├── 01_trimap_01_01.png ├── 01_trimap_25_05.png ├── 01_trimap_30_10.png ├── 01_trimap_30_20.png ├── 01_trimap_40_20.png ├── 02_matting_01_01.png ├── 02_matting_25_05.png ├── 02_matting_30_10.png ├── 02_matting_30_20.png ├── 02_matting_40_20.png ├── 02_trimap_01_01.png ├── 02_trimap_25_05.png ├── 02_trimap_30_10.png ├── 02_trimap_30_20.png ├── 02_trimap_40_20.png ├── 03_matting_01_01.png ├── 03_matting_25_05.png ├── 03_matting_30_10.png ├── 03_matting_30_20.png ├── 03_matting_40_20.png ├── 03_trimap_01_01.png ├── 03_trimap_25_05.png ├── 03_trimap_30_10.png ├── 03_trimap_30_20.png ├── 03_trimap_40_20.png ├── 04_matting_01_01.png ├── 04_matting_25_05.png ├── 04_matting_30_10.png ├── 04_matting_30_20.png ├── 04_matting_40_20.png ├── 04_trimap_01_01.png ├── 04_trimap_25_05.png ├── 04_trimap_30_10.png ├── 04_trimap_30_20.png ├── 04_trimap_40_20.png ├── 05_matting_01_01.png ├── 05_matting_25_05.png ├── 05_matting_30_10.png ├── 05_matting_30_20.png ├── 05_matting_40_20.png ├── 05_trimap_01_01.png ├── 05_trimap_25_05.png ├── 05_trimap_30_10.png ├── 05_trimap_30_20.png ├── 
05_trimap_40_20.png ├── 06_matting_01_01.png ├── 06_matting_25_05.png ├── 06_matting_30_10.png ├── 06_matting_30_20.png ├── 06_matting_40_20.png ├── 06_trimap_01_01.png ├── 06_trimap_25_05.png ├── 06_trimap_30_10.png ├── 06_trimap_30_20.png └── 06_trimap_40_20.png ├── net.py ├── pytorch_deep_image_matting ├── .ipynb_checkpoints │ ├── Untitled-checkpoint.ipynb │ └── segmentation_trimap-checkpoint.ipynb ├── README.md └── deep_image_matting.py ├── sample ├── 01.jpg ├── 02.jpg ├── 03.jpg ├── 04.jpg ├── 05.jpg └── 06.jpg ├── segmentation_result ├── 01_mask.png ├── 01_seq.png ├── 02_mask.png ├── 02_seq.png ├── 03_mask.png ├── 03_seq.png ├── 04_mask.png ├── 04_seq.png ├── 05_mask.png ├── 05_seq.png ├── 06_mask.png ├── 06_seq.png └── 1. 34423890523_f8c5b3741c.png └── trimap.py /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | *.swp 3 | *.sqlite3 4 | *.DS_Store 5 | model_xception65_coco_voc_trainval.tar.gz 6 | stage1_sad_54.4.pth 7 | -------------------------------------------------------------------------------- /4in1/01_PSPNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/01_PSPNet.png -------------------------------------------------------------------------------- /4in1/01_crfasrnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/01_crfasrnn.png -------------------------------------------------------------------------------- /4in1/01_deeplabv3plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/01_deeplabv3plus.png -------------------------------------------------------------------------------- 
/4in1/01_original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/01_original.jpg -------------------------------------------------------------------------------- /4in1/02_PSPNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/02_PSPNet.png -------------------------------------------------------------------------------- /4in1/02_crfasrnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/02_crfasrnn.png -------------------------------------------------------------------------------- /4in1/02_deeplabv3plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/02_deeplabv3plus.png -------------------------------------------------------------------------------- /4in1/02_original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/02_original.jpg -------------------------------------------------------------------------------- /4in1/03_PSPNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/03_PSPNet.png -------------------------------------------------------------------------------- /4in1/03_crfasrnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/03_crfasrnn.png 
-------------------------------------------------------------------------------- /4in1/03_deeplabv3plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/03_deeplabv3plus.png -------------------------------------------------------------------------------- /4in1/03_original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/03_original.jpg -------------------------------------------------------------------------------- /4in1/04_PSPNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/04_PSPNet.png -------------------------------------------------------------------------------- /4in1/04_crfasrnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/04_crfasrnn.png -------------------------------------------------------------------------------- /4in1/04_deeplabv3plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/04_deeplabv3plus.png -------------------------------------------------------------------------------- /4in1/04_original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/04_original.jpg -------------------------------------------------------------------------------- /4in1/05_PSPNet.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/05_PSPNet.png -------------------------------------------------------------------------------- /4in1/05_crfasrnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/05_crfasrnn.png -------------------------------------------------------------------------------- /4in1/05_deeplabv3plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/05_deeplabv3plus.png -------------------------------------------------------------------------------- /4in1/05_original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/05_original.jpg -------------------------------------------------------------------------------- /4in1/06_PSPNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/06_PSPNet.png -------------------------------------------------------------------------------- /4in1/06_crfasrnn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/06_crfasrnn.png -------------------------------------------------------------------------------- /4in1/06_deeplabv3plus.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/06_deeplabv3plus.png -------------------------------------------------------------------------------- 
/4in1/06_original.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/4in1/06_original.jpg -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 2, June 1991 3 | 4 | Copyright (C) 1989, 1991 Free Software Foundation, Inc., 5 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA 6 | Everyone is permitted to copy and distribute verbatim copies 7 | of this license document, but changing it is not allowed. 8 | 9 | Preamble 10 | 11 | The licenses for most software are designed to take away your 12 | freedom to share and change it. By contrast, the GNU General Public 13 | License is intended to guarantee your freedom to share and change free 14 | software--to make sure the software is free for all its users. This 15 | General Public License applies to most of the Free Software 16 | Foundation's software and to any other program whose authors commit to 17 | using it. (Some other Free Software Foundation software is covered by 18 | the GNU Lesser General Public License instead.) You can apply it to 19 | your programs, too. 20 | 21 | When we speak of free software, we are referring to freedom, not 22 | price. Our General Public Licenses are designed to make sure that you 23 | have the freedom to distribute copies of free software (and charge for 24 | this service if you wish), that you receive source code or can get it 25 | if you want it, that you can change the software or use pieces of it 26 | in new free programs; and that you know you can do these things. 27 | 28 | To protect your rights, we need to make restrictions that forbid 29 | anyone to deny you these rights or to ask you to surrender the rights. 
30 | These restrictions translate to certain responsibilities for you if you 31 | distribute copies of the software, or if you modify it. 32 | 33 | For example, if you distribute copies of such a program, whether 34 | gratis or for a fee, you must give the recipients all the rights that 35 | you have. You must make sure that they, too, receive or can get the 36 | source code. And you must show them these terms so they know their 37 | rights. 38 | 39 | We protect your rights with two steps: (1) copyright the software, and 40 | (2) offer you this license which gives you legal permission to copy, 41 | distribute and/or modify the software. 42 | 43 | Also, for each author's protection and ours, we want to make certain 44 | that everyone understands that there is no warranty for this free 45 | software. If the software is modified by someone else and passed on, we 46 | want its recipients to know that what they have is not the original, so 47 | that any problems introduced by others will not reflect on the original 48 | authors' reputations. 49 | 50 | Finally, any free program is threatened constantly by software 51 | patents. We wish to avoid the danger that redistributors of a free 52 | program will individually obtain patent licenses, in effect making the 53 | program proprietary. To prevent this, we have made it clear that any 54 | patent must be licensed for everyone's free use or not licensed at all. 55 | 56 | The precise terms and conditions for copying, distribution and 57 | modification follow. 58 | 59 | GNU GENERAL PUBLIC LICENSE 60 | TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION 61 | 62 | 0. This License applies to any program or other work which contains 63 | a notice placed by the copyright holder saying it may be distributed 64 | under the terms of this General Public License. 
The "Program", below, 65 | refers to any such program or work, and a "work based on the Program" 66 | means either the Program or any derivative work under copyright law: 67 | that is to say, a work containing the Program or a portion of it, 68 | either verbatim or with modifications and/or translated into another 69 | language. (Hereinafter, translation is included without limitation in 70 | the term "modification".) Each licensee is addressed as "you". 71 | 72 | Activities other than copying, distribution and modification are not 73 | covered by this License; they are outside its scope. The act of 74 | running the Program is not restricted, and the output from the Program 75 | is covered only if its contents constitute a work based on the 76 | Program (independent of having been made by running the Program). 77 | Whether that is true depends on what the Program does. 78 | 79 | 1. You may copy and distribute verbatim copies of the Program's 80 | source code as you receive it, in any medium, provided that you 81 | conspicuously and appropriately publish on each copy an appropriate 82 | copyright notice and disclaimer of warranty; keep intact all the 83 | notices that refer to this License and to the absence of any warranty; 84 | and give any other recipients of the Program a copy of this License 85 | along with the Program. 86 | 87 | You may charge a fee for the physical act of transferring a copy, and 88 | you may at your option offer warranty protection in exchange for a fee. 89 | 90 | 2. You may modify your copy or copies of the Program or any portion 91 | of it, thus forming a work based on the Program, and copy and 92 | distribute such modifications or work under the terms of Section 1 93 | above, provided that you also meet all of these conditions: 94 | 95 | a) You must cause the modified files to carry prominent notices 96 | stating that you changed the files and the date of any change. 
97 | 98 | b) You must cause any work that you distribute or publish, that in 99 | whole or in part contains or is derived from the Program or any 100 | part thereof, to be licensed as a whole at no charge to all third 101 | parties under the terms of this License. 102 | 103 | c) If the modified program normally reads commands interactively 104 | when run, you must cause it, when started running for such 105 | interactive use in the most ordinary way, to print or display an 106 | announcement including an appropriate copyright notice and a 107 | notice that there is no warranty (or else, saying that you provide 108 | a warranty) and that users may redistribute the program under 109 | these conditions, and telling the user how to view a copy of this 110 | License. (Exception: if the Program itself is interactive but 111 | does not normally print such an announcement, your work based on 112 | the Program is not required to print an announcement.) 113 | 114 | These requirements apply to the modified work as a whole. If 115 | identifiable sections of that work are not derived from the Program, 116 | and can be reasonably considered independent and separate works in 117 | themselves, then this License, and its terms, do not apply to those 118 | sections when you distribute them as separate works. But when you 119 | distribute the same sections as part of a whole which is a work based 120 | on the Program, the distribution of the whole must be on the terms of 121 | this License, whose permissions for other licensees extend to the 122 | entire whole, and thus to each and every part regardless of who wrote it. 123 | 124 | Thus, it is not the intent of this section to claim rights or contest 125 | your rights to work written entirely by you; rather, the intent is to 126 | exercise the right to control the distribution of derivative or 127 | collective works based on the Program. 
128 | 129 | In addition, mere aggregation of another work not based on the Program 130 | with the Program (or with a work based on the Program) on a volume of 131 | a storage or distribution medium does not bring the other work under 132 | the scope of this License. 133 | 134 | 3. You may copy and distribute the Program (or a work based on it, 135 | under Section 2) in object code or executable form under the terms of 136 | Sections 1 and 2 above provided that you also do one of the following: 137 | 138 | a) Accompany it with the complete corresponding machine-readable 139 | source code, which must be distributed under the terms of Sections 140 | 1 and 2 above on a medium customarily used for software interchange; or, 141 | 142 | b) Accompany it with a written offer, valid for at least three 143 | years, to give any third party, for a charge no more than your 144 | cost of physically performing source distribution, a complete 145 | machine-readable copy of the corresponding source code, to be 146 | distributed under the terms of Sections 1 and 2 above on a medium 147 | customarily used for software interchange; or, 148 | 149 | c) Accompany it with the information you received as to the offer 150 | to distribute corresponding source code. (This alternative is 151 | allowed only for noncommercial distribution and only if you 152 | received the program in object code or executable form with such 153 | an offer, in accord with Subsection b above.) 154 | 155 | The source code for a work means the preferred form of the work for 156 | making modifications to it. For an executable work, complete source 157 | code means all the source code for all modules it contains, plus any 158 | associated interface definition files, plus the scripts used to 159 | control compilation and installation of the executable. 
However, as a 160 | special exception, the source code distributed need not include 161 | anything that is normally distributed (in either source or binary 162 | form) with the major components (compiler, kernel, and so on) of the 163 | operating system on which the executable runs, unless that component 164 | itself accompanies the executable. 165 | 166 | If distribution of executable or object code is made by offering 167 | access to copy from a designated place, then offering equivalent 168 | access to copy the source code from the same place counts as 169 | distribution of the source code, even though third parties are not 170 | compelled to copy the source along with the object code. 171 | 172 | 4. You may not copy, modify, sublicense, or distribute the Program 173 | except as expressly provided under this License. Any attempt 174 | otherwise to copy, modify, sublicense or distribute the Program is 175 | void, and will automatically terminate your rights under this License. 176 | However, parties who have received copies, or rights, from you under 177 | this License will not have their licenses terminated so long as such 178 | parties remain in full compliance. 179 | 180 | 5. You are not required to accept this License, since you have not 181 | signed it. However, nothing else grants you permission to modify or 182 | distribute the Program or its derivative works. These actions are 183 | prohibited by law if you do not accept this License. Therefore, by 184 | modifying or distributing the Program (or any work based on the 185 | Program), you indicate your acceptance of this License to do so, and 186 | all its terms and conditions for copying, distributing or modifying 187 | the Program or works based on it. 188 | 189 | 6. 
Each time you redistribute the Program (or any work based on the 190 | Program), the recipient automatically receives a license from the 191 | original licensor to copy, distribute or modify the Program subject to 192 | these terms and conditions. You may not impose any further 193 | restrictions on the recipients' exercise of the rights granted herein. 194 | You are not responsible for enforcing compliance by third parties to 195 | this License. 196 | 197 | 7. If, as a consequence of a court judgment or allegation of patent 198 | infringement or for any other reason (not limited to patent issues), 199 | conditions are imposed on you (whether by court order, agreement or 200 | otherwise) that contradict the conditions of this License, they do not 201 | excuse you from the conditions of this License. If you cannot 202 | distribute so as to satisfy simultaneously your obligations under this 203 | License and any other pertinent obligations, then as a consequence you 204 | may not distribute the Program at all. For example, if a patent 205 | license would not permit royalty-free redistribution of the Program by 206 | all those who receive copies directly or indirectly through you, then 207 | the only way you could satisfy both it and this License would be to 208 | refrain entirely from distribution of the Program. 209 | 210 | If any portion of this section is held invalid or unenforceable under 211 | any particular circumstance, the balance of the section is intended to 212 | apply and the section as a whole is intended to apply in other 213 | circumstances. 214 | 215 | It is not the purpose of this section to induce you to infringe any 216 | patents or other property right claims or to contest validity of any 217 | such claims; this section has the sole purpose of protecting the 218 | integrity of the free software distribution system, which is 219 | implemented by public license practices. 
Many people have made 220 | generous contributions to the wide range of software distributed 221 | through that system in reliance on consistent application of that 222 | system; it is up to the author/donor to decide if he or she is willing 223 | to distribute software through any other system and a licensee cannot 224 | impose that choice. 225 | 226 | This section is intended to make thoroughly clear what is believed to 227 | be a consequence of the rest of this License. 228 | 229 | 8. If the distribution and/or use of the Program is restricted in 230 | certain countries either by patents or by copyrighted interfaces, the 231 | original copyright holder who places the Program under this License 232 | may add an explicit geographical distribution limitation excluding 233 | those countries, so that distribution is permitted only in or among 234 | countries not thus excluded. In such case, this License incorporates 235 | the limitation as if written in the body of this License. 236 | 237 | 9. The Free Software Foundation may publish revised and/or new versions 238 | of the General Public License from time to time. Such new versions will 239 | be similar in spirit to the present version, but may differ in detail to 240 | address new problems or concerns. 241 | 242 | Each version is given a distinguishing version number. If the Program 243 | specifies a version number of this License which applies to it and "any 244 | later version", you have the option of following the terms and conditions 245 | either of that version or of any later version published by the Free 246 | Software Foundation. If the Program does not specify a version number of 247 | this License, you may choose any version ever published by the Free Software 248 | Foundation. 249 | 250 | 10. If you wish to incorporate parts of the Program into other free 251 | programs whose distribution conditions are different, write to the author 252 | to ask for permission. 
For software which is copyrighted by the Free 253 | Software Foundation, write to the Free Software Foundation; we sometimes 254 | make exceptions for this. Our decision will be guided by the two goals 255 | of preserving the free status of all derivatives of our free software and 256 | of promoting the sharing and reuse of software generally. 257 | 258 | NO WARRANTY 259 | 260 | 11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY 261 | FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN 262 | OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES 263 | PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED 264 | OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF 265 | MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS 266 | TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE 267 | PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, 268 | REPAIR OR CORRECTION. 269 | 270 | 12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 271 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR 272 | REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, 273 | INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING 274 | OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED 275 | TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY 276 | YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER 277 | PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE 278 | POSSIBILITY OF SUCH DAMAGES. 
279 | 280 | END OF TERMS AND CONDITIONS 281 | 282 | How to Apply These Terms to Your New Programs 283 | 284 | If you develop a new program, and you want it to be of the greatest 285 | possible use to the public, the best way to achieve this is to make it 286 | free software which everyone can redistribute and change under these terms. 287 | 288 | To do so, attach the following notices to the program. It is safest 289 | to attach them to the start of each source file to most effectively 290 | convey the exclusion of warranty; and each file should have at least 291 | the "copyright" line and a pointer to where the full notice is found. 292 | 293 | 294 | Copyright (C) 295 | 296 | This program is free software; you can redistribute it and/or modify 297 | it under the terms of the GNU General Public License as published by 298 | the Free Software Foundation; either version 2 of the License, or 299 | (at your option) any later version. 300 | 301 | This program is distributed in the hope that it will be useful, 302 | but WITHOUT ANY WARRANTY; without even the implied warranty of 303 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 304 | GNU General Public License for more details. 305 | 306 | You should have received a copy of the GNU General Public License along 307 | with this program; if not, write to the Free Software Foundation, Inc., 308 | 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. 309 | 310 | Also add information on how to contact you by electronic and paper mail. 311 | 312 | If the program is interactive, make it output a short notice like this 313 | when it starts in an interactive mode: 314 | 315 | Gnomovision version 69, Copyright (C) year name of author 316 | Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 317 | This is free software, and you are welcome to redistribute it 318 | under certain conditions; type `show c' for details. 
319 | 320 | The hypothetical commands `show w' and `show c' should show the appropriate 321 | parts of the General Public License. Of course, the commands you use may 322 | be called something other than `show w' and `show c'; they could even be 323 | mouse-clicks or menu items--whatever suits your program. 324 | 325 | You should also get your employer (if you work as a programmer) or your 326 | school, if any, to sign a "copyright disclaimer" for the program, if 327 | necessary. Here is a sample; alter the names: 328 | 329 | Yoyodyne, Inc., hereby disclaims all copyright interest in the program 330 | `Gnomovision' (which makes passes at compilers) written by James Hacker. 331 | 332 | , 1 April 1989 333 | Ty Coon, President of Vice 334 | 335 | This General Public License does not permit incorporating your program into 336 | proprietary programs. If your program is a subroutine library, you may 337 | consider it more useful to permit linking proprietary applications with the 338 | library. If this is what you want to do, use the GNU Lesser General 339 | Public License instead of this License. 
340 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## 簡介 2 | 開源智造使用Google最新版的DeepLab v3+,搭配Automatic Trimap Generator,以及使用目前開源代碼中最高名次的matting方法(下表) Deep Image Matting,成功整合Image Segmentation, Trimap自動生成器與Image Matting,創造出一套全自動髮絲去背AI模型,我們將之取名為Auto Hair。 3 | 4 | | method | github | python file | 5 | |-------|:-----:|------:| 6 | | Deep-Image-Matting | | data_generator.py | 7 | | Automatic Trimap Generator | | trimap_module.py | 8 | | unet-gan-matting | | combine.py | 9 | | Semantic Human Matting | | gen_trimap.py | 10 | 11 | ## 髮絲圖片去背技術架構及執行方法 12 | 圖片去背通常有兩種類型,分別是Image Segmentation (IS) 與Image Mat-ting (IM)。IS主要著重在對每個像素的語義理解,並得到該像素分類結果,但缺點是很難滿足標的物邊緣高精度的切割效果。Image Matting主要著重在找出前景與背景的顏色,以及它們之間的融合程度,在像素邊緣的分割效果更加自然,但其缺點是需要人工耗時繪製Trimap圖,來確定圖像中前景、背景與不確定區域。經過研究和實驗,開源智造提出了一個全新的模式,以達成「髮絲」的精準去背。我們 (1) 使用Google最新版的DeepLab v3+ (2) 搭配 Automatic Trimap Generator [1],以及(3) 使用目前開源代碼中最高名次的matting方法Deep Image Matting [2]。最終我們成功整合Image Segmentation, Trimap自動生成器與Image Matting,創造出一套全自動髮絲去背演算法,我們將之取名為Auto Hair。 13 | 在我們研究的文獻中,尚未有人嘗試將此三種方法進行結合。以下我們將針對(A) Mask, (B) Trimap, (C) Image Matting的執行方法做細部說明。 14 | 15 | #### (A) Mask 16 | 目前已超過百個Image Segmentation的架構,針對其中有Open Source的30多種演算法 我們做了全面的測試及實作上的比較。我們在此列舉三個演算法 (CRF-RNN、PSPNET和Deeplab v3+) 的測試效果[3]。如下圖,CRF-RNN、PSPNET和DeepLab v3+,在6種範例圖檔中去背效果均不錯。 17 | ![image1](images/5.png) 18 | 19 | 仔細觀察可發現DeepLab v3+的結果又比另外兩種效果更好。仔細觀察可以發現,在沙發的範例中,椅腳在CRF-RNN的結果中完全消失,而在PSPNET則消失了一部分。再來,在車子的範例中,車底陰影的去背效果也是以DeepLab v3+為最佳。針對有髮絲的女性圖片來做比較,將三種去背結果的放大後 (如下圖),可發現DeepLab v3+最貼近標的物邊緣,背景成分也更少。 20 | ![image2](images/6.png) 21 | 22 | 即使DeepLab v3+的去背成效已經不錯,可以看到頭髮之間仍然存在背景圖案。為了達成「有髮絲效果」的精準去背,我們利用DeepLab v3+所產生的遮罩圖(mask,如下圖),進行第二次去背的優化。 23 | ![image3](images/7.png) 24 | 25 | #### (B) Trimap 26 | Image Matting是人工手動描繪出需要被提取的物件邊緣,產生Trimap,再利用公式求得圖片的相關參數。如下圖,一旦有了Trimap之後,Image Matting要做的分析就是針對灰色區域中的未知像素,判斷這些未知像素哪些是前景跟背景。 27 | 
![image4](images/8.png) 28 | 29 | 通常Trimap是人為繪製黑灰白的三向圖。黑色的地方是確定的背景(Definite Background),白色的是確定的前景(Definite Foreground),灰色則是不確定的部分(Unknown)。要繪製出精確的Trimap並不容易,需要設計師花費相當多的時間。經過大量的搜尋,我們找到4種可產生Trimap的開源代碼。我們團隊做了深入的比較,最後決定採用Automatic Trimap Generator的方法,其特點是遮罩圖(mask)邊緣的內縮與外擴單位向量可自行決定,詳下圖。 30 | ![image5](images/9.png) 31 | 32 | #### (C) Image Matting 33 | Image Matting已經有不少學術論文和開源演算法,我們團隊將開源的python程式碼做了整理,詳見下表。 34 | 35 | | method | link | 36 | |-------|:-----:| 37 | | Deep Image Matting | | 38 | | Indexnet matting | | 39 | | Fusion Matting | | 40 | | Learning-based Sampling | | 41 | | KNN-matting | | 42 | | bayesian-matting | | 43 | | learning-based-matting | | 44 | | poisson-matting | | 45 | | mishima-matting | | 46 | | closed-form-matting | | 47 | | Lkm matting | | 48 | | LFM matting | | 49 | | AlphaGAN-Matting | | 50 | | unet-gan-matting | | 51 | 52 | 53 | 由於不同形式的Trimap配上不同的matting方法效果差異極大,我們投入了大量的時間,並善用TWCC的運算資源,針對各種組合進行分析, 下圖為比較結果的範例之一。透過仔細的觀察,最終我們找出當前效果最佳的組合,即用DeepLab v3+萃取出的遮罩,自動產生的Trimap,並搭配Deep Image Matting,這樣的架構可以達成自動化髮絲去背的細緻效果,而不再需要人工介入。 54 | ![image6](images/10.png) 55 | 56 | 如下圖所示, Auto Hair的髮絲去背成果比DeepLab v3+細緻和自然很多。而且,Auto Hair不只是邊緣的效果增強了,甚至還原了原本Image Segmentation (即DeepLab v3+)步驟中丟失的捲髮區塊 (如圖的綠色圓圈部位)。 57 | ![image7](images/11.png) 58 | ![image8](images/12.png) 59 | 60 | [1] https://github.com/lnugraha/trimap_generator 61 | [2] https://github.com/huochaitiantang/pytorch-deep-image-matting 62 | [3] http://host.robots.ox.ac.uk/leaderboard/displaylb.php?challengeid=11 63 | 64 | ## download large size model from here: 65 | https://celloai-my.sharepoint.com/:u:/g/personal/jeremylai_openaifab_com/ETP6v1ct3vNAnouZvBEYWLQBevXuGVmO5Yf1tSrzvA5Hfg?e=yTyfFB 66 | https://celloai-my.sharepoint.com/:u:/g/personal/jeremylai_openaifab_com/EQtY9hQtOl1On3GSRBDw-H4B3D7oxIi5lilIklCZRBku3Q?e=HMfJJ6 67 | -------------------------------------------------------------------------------- /deeplabv3plus_xception.py: 
"""First half of deeplabv3plus_xception.py: loads a frozen DeepLab v3+ (TF1)
graph from a tarball and runs semantic segmentation on PIL images."""
import os
from io import BytesIO
import tarfile
import tempfile
from six.moves import urllib
from matplotlib import gridspec
from matplotlib import pyplot as plt
import numpy as np
from PIL import Image
import tensorflow as tf
#from scipy.misc import imread
#import imageio


class DeepLabModel(object):
    """Wraps a frozen DeepLab TF1 inference graph extracted from a .tar.gz export."""
    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 513  # max side length (px) the image is resized to before inference
    FROZEN_GRAPH_NAME = 'frozen'  # substring matched against the tar member names (presumably 'frozen_inference_graph')
    def __init__(self, tarball_path):
        # Locate the frozen graph inside the tarball and import it into a fresh Graph.
        # Raises RuntimeError if no member name contains FROZEN_GRAPH_NAME.
        self.graph = tf.Graph()
        graph_def = None
        tar_file = tarfile.open(tarball_path)
        for tar_info in tar_file.getmembers():
            if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
                file_handle = tar_file.extractfile(tar_info)
                graph_def = tf.GraphDef.FromString(file_handle.read())
                break
        tar_file.close()
        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')
        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')
        # TF1-style persistent session bound to the imported graph.
        self.sess = tf.Session(graph=self.graph)
    def run(self, image):
        # Downscale so the longer side equals INPUT_SIZE, then run the graph.
        # Returns (resized RGB PIL image, 2-D per-pixel class-label array).
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map
def create_pascal_label_colormap():
    """Build the 256-entry PASCAL VOC label colormap (one RGB row per class id)."""
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)
    # Standard bit-interleaving scheme: spread the 3 lowest label bits across
    # the high bits of the R/G/B channels, 8 passes.
    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3
    return colormap
def label_to_color_image(label):
    """Map a 2-D array of class ids to an RGB image via the PASCAL colormap.

    Raises:
        ValueError: if `label` is not 2-D or contains ids >= 256.
    """
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')
    colormap = create_pascal_label_colormap()
    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')
    return colormap[label]

def run_deeplabv3plus_xception(model_input,photo_input,output_file,show=False):
    """Segment one photo with a DeepLabModel and save a white-background cutout.

    Args:
        model_input: a loaded DeepLabModel instance.
        photo_input: path to the input image.
        output_file: path where the composited result is written.
        show: when True, also return the resulting PIL image.
    """
    MODEL = model_input
    original_im = Image.open(photo_input)
    width, height = original_im.size
    resized_im, seg_map = MODEL.run(original_im)
    cm = seg_map
    img = np.array(resized_im)
    # Vectorized replacement of the original O(W*H) per-pixel Python loop:
    # every pixel labelled 0 (background) becomes white. Same uint8 result.
    img[cm == 0] = np.array([255, 255, 255], dtype='uint8')
    img = Image.fromarray(img)
    img_convert = img.resize((width, height), Image.ANTIALIAS)
    img_convert.save(output_file)
    if width > 513 or height > 513:
        # the model resizes its input down to 513 px max, so larger photos lose detail
        print("deeplabv3+ xception/mobilenet 成功:但圖片被壓縮")
    else:
        print("deeplabv3+ xception/mobilenet 成功")
    if show:
        return img_convert

# -------------------------------------------------------------------- file: deepllabv3plus.py
import os
import tarfile
import numpy as np
from PIL import Image
import tensorflow as tf

class DeepLabModel(object):
    """Wraps a frozen DeepLab v3+ TF1 inference graph loaded from a tarball.

    NOTE: duplicated in deeplabv3plus_xception.py; kept in sync by hand.
    """
    INPUT_TENSOR_NAME = 'ImageTensor:0'
    OUTPUT_TENSOR_NAME = 'SemanticPredictions:0'
    INPUT_SIZE = 513  # max side length (px) fed to the network
    FROZEN_GRAPH_NAME = 'frozen'  # substring of the frozen-graph member name in the tarball
    def __init__(self, tarball_path):
        """Extract the frozen graph from `tarball_path` and open a TF1 session on it."""
        self.graph = tf.Graph()
        graph_def = None
        tar_file = tarfile.open(tarball_path)
        for tar_info in tar_file.getmembers():
            if self.FROZEN_GRAPH_NAME in os.path.basename(tar_info.name):
                file_handle = tar_file.extractfile(tar_info)
                graph_def = tf.GraphDef.FromString(file_handle.read())
                break
        tar_file.close()
        if graph_def is None:
            raise RuntimeError('Cannot find inference graph in tar archive.')
        with self.graph.as_default():
            tf.import_graph_def(graph_def, name='')
        self.sess = tf.Session(graph=self.graph)
    def run(self, image):
        """Resize `image` so its longer side is INPUT_SIZE and run inference.

        Returns:
            (resized RGB PIL image, 2-D per-pixel class-label array)
        """
        width, height = image.size
        resize_ratio = 1.0 * self.INPUT_SIZE / max(width, height)
        target_size = (int(resize_ratio * width), int(resize_ratio * height))
        resized_image = image.convert('RGB').resize(target_size, Image.ANTIALIAS)
        batch_seg_map = self.sess.run(
            self.OUTPUT_TENSOR_NAME,
            feed_dict={self.INPUT_TENSOR_NAME: [np.asarray(resized_image)]})
        seg_map = batch_seg_map[0]
        return resized_image, seg_map

def create_pascal_label_colormap():
    """Build the 256-entry PASCAL VOC label colormap."""
    colormap = np.zeros((256, 3), dtype=int)
    ind = np.arange(256, dtype=int)
    # bit-interleaving: spread the 3 lowest label bits across R/G/B high bits
    for shift in reversed(range(8)):
        for channel in range(3):
            colormap[:, channel] |= ((ind >> channel) & 1) << shift
        ind >>= 3
    return colormap

def label_to_color_image(label):
    """Map a 2-D array of class ids to an RGB image via the PASCAL colormap."""
    if label.ndim != 2:
        raise ValueError('Expect 2-D input label')
    colormap = create_pascal_label_colormap()
    if np.max(label) >= len(colormap):
        raise ValueError('label value too large.')
    return colormap[label]

# Module-level side effect: the model is loaded once at import time.
MODEL_xception65_trainval = DeepLabModel("model_xception65_coco_voc_trainval.tar.gz")
print("deeplabv3+ model loading")

def seg_result(rows, cols, cm, img):
    """Whiten background pixels (label 0) of `img` in place; return a PIL image.

    `rows`/`cols` are retained for backward compatibility; they equal cm.shape.
    """
    # vectorized form of the original per-pixel loop
    img[cm == 0] = np.array([255, 255, 255], dtype='uint8')
    return Image.fromarray(img)

def mask_result(rows, cols, cm, img):
    """Overwrite `img` in place with a binary mask: background black, foreground white."""
    img[cm == 0] = np.array([0, 0, 0], dtype='uint8')
    img[cm != 0] = np.array([255, 255, 255], dtype='uint8')
    return Image.fromarray(img)

def run_deeplabv3plus(photo_input, seq_file, mask_file, ouput_folder, show, save):
    """Segment one photo; optionally save and/or return the cutout and its mask.

    Args:
        photo_input: path to the input image.
        seq_file: filename for the white-background cutout.
        mask_file: filename for the binary mask.
        ouput_folder: output directory (parameter name typo kept for API compatibility).
        show: when True, return [cutout, mask] as PIL images.
        save: when True, write both images to ouput_folder.
    """
    MODEL = MODEL_xception65_trainval
    original_im = Image.open(photo_input)
    width, height = original_im.size
    resized_im, seg_map = MODEL.run(original_im)
    cm = seg_map
    img = np.array(resized_im)
    rows = cm.shape[0]
    cols = cm.shape[1]

    # Order matters: seg_result mutates `img` in place and the cutout is resized
    # (copied) before mask_result overwrites the same buffer entirely.
    img_seq = seg_result(rows, cols, cm, img)
    img_seq = img_seq.resize((width, height), Image.ANTIALIAS)
    img_mask = mask_result(rows, cols, cm, img)
    img_mask = img_mask.resize((width, height), Image.ANTIALIAS)

    if save:
        img_seq.save(ouput_folder + "/" + seq_file)
        img_mask.save(ouput_folder + "/" + mask_file)

    if show:
        return [img_seq, img_mask]

# -------------------------------------------------------------------- file: function.py
import deepllabv3plus, google_images_download
import requests, cv2, torch, urllib.request
import pandas as pd
from PIL import Image
from trimap import trimap
from pytorch_deep_image_matting.deep_image_matting import model_dim_fn, matting_result
cuda = torch.cuda.is_available()
print("cuda: " + str(cuda))
deep_image_matting_model = model_dim_fn(cuda)
print("matting model loading")


def google_img(search_name, output_folder, num=10, download=False):
    """Search Google Images for `search_name`; return {search_name: [image links]}.

    When `download` is True the images are also saved into `output_folder`.
    """
    response = google_images_download.googleimagesdownload()
    no_download = not download  # simplified from the original if/else
    data = response.download({
        "keywords": search_name,
        "color_type": "full-color",
        "format": "jpg",
        "size": "medium",
        "limit": num,
        'no_directory': True,
        'no_download': no_download,
        'silent_mode': True,
        "output_directory": output_folder})
    data = list(map(lambda x: x['image_link'], data[2]))
    return {search_name: data}

def flickr_img(search_name, output_folder, num=10, download=False):
    """Search Flickr for `search_name`; return {search_name: [medium-size URLs]}.

    Keeps only photos whose medium rendition is between 1 and 500 px on both
    sides; returns {search_name: "no data"} when nothing qualifies.
    """
    url_flickr = "https://www.flickr.com/"
    data_flickr = requests.get(url_flickr)
    # Scrape a fresh API key from the Flickr home page
    # (fragile: depends on the page containing "api.site_key").
    api_key = str(data_flickr.content).split("api.site_key")[1]
    api_key = api_key.split('"')[1]
    url = "https://api.flickr.com/services/rest?sort=relevance&parse_tags=1&content_type=7&extras=can_comment%2Ccount_comments%2Ccount_faves%2Cdescription%2Cisfavorite%2Clicense%2Cmedia%2Cneeds_interstitial%2Cowner_name%2Cpath_alias%2Crealname%2Crotation%2Curl_m&per_page=" + str(num) +"&page=1&lang=zh-Hant-HK&text=" + search_name + "&viewerNSID=&method=flickr.photos.search&csrf=&api_key=" + api_key + "&format=json&hermes=1&hermesClient=1&reqId=3405dc98&nojsoncallback=1"
    data = requests.get(url)
    data = data.json()
    data = pd.DataFrame.from_dict(data['photos']['photo'])
    data = data.filter(["height_m", "width_m", "url_m"])
    data = data.fillna(0)
    data['height_m'] = data['height_m'].astype(int)
    data['width_m'] = data['width_m'].astype(int)
    # merged the four single-condition filters of the original into two
    data = data[(data['height_m'] > 0) & (data['height_m'] <= 500)]
    data = data[(data['width_m'] > 0) & (data['width_m'] <= 500)]
    data = data.reset_index(drop=True)
    if len(data) > 0:
        if download:
            for i in range(len(data)):
                urllib.request.urlretrieve(data.url_m[i], output_folder + "/" + str(i+1) + ". " + data.url_m[i].split("/")[-1])
        return {search_name: list(data.url_m)}
    else:
        return {search_name: "no data"}

def seg_img(photo_input, seq_file, mask_file, ouput_folder, show, save):
    """Thin wrapper around deepllabv3plus.run_deeplabv3plus."""
    result = deepllabv3plus.run_deeplabv3plus(photo_input, seq_file, mask_file, ouput_folder, show, save)
    if show:
        return result

def trimap_output(original_input, mask_input, output_folder, output_name,
                  trimap_save = True, seg_show = True, seg_save = True):
    """Generate five trimap/matting variants for one image.

    Args:
        original_input: path to the original photo.
        mask_input: path to the binary mask produced by the segmentation step.
        output_folder, output_name: where/what to save results as.
        trimap_save: save the trimaps as PNGs.
        seg_show: return {tag: matting PIL image}.
        seg_save: save the mattings as PNGs.
    """
    mask = cv2.imread(mask_input, cv2.IMREAD_GRAYSCALE)
    # (size_tag, erosion_tag, size, erosion) — zero-padded tags are reused in
    # filenames and progress messages exactly as in the original code.
    configs = [("01", "01", 1, 1),
               ("25", "05", 25, 5),
               ("30", "10", 30, 10),
               ("30", "20", 30, 20),
               ("40", "20", 40, 20)]

    # All trimaps are generated first, then all mattings (original ordering).
    trimaps = {}
    for s_tag, e_tag, size, erosion in configs:
        trimaps[s_tag + "_" + e_tag] = trimap(mask, size=size, erosion=erosion)

    mattings = {}
    for s_tag, e_tag, size, erosion in configs:
        tag = s_tag + "_" + e_tag
        mattings[tag] = matting_result(original_input, trimaps[tag], deep_image_matting_model, cuda)
        print("generate matting(size: " + s_tag + ", erosion: " + e_tag + ")")

    if trimap_save:
        for tag, tri in trimaps.items():
            Image.fromarray(tri.astype('uint8')).save(output_folder + "/" + output_name + "_trimap_" + tag + ".png")

    if seg_save:
        for tag, matting_img in mattings.items():
            matting_img.save(output_folder + "/" + output_name + "_matting_" + tag + ".png")

    if seg_show:
        return mattings

# -------------------------------------------------------------------- file: google_images_download.py
#!/usr/bin/env python
# In[ ]:
# coding: utf-8

###### Searching and Downloading Google Images to the local disk ######

# Import Libraries
import sys
version = (3, 0)
cur_version = sys.version_info
if cur_version >= version:  # If the Current Version of Python is 3.0 or above
    import urllib.request
    from urllib.request import Request, urlopen
    from urllib.request import URLError, HTTPError
    from urllib.parse import quote
    import http.client
    from http.client import IncompleteRead, BadStatusLine
    http.client._MAXHEADERS = 1000
else:  # If the Current Version of Python is 2.x
    import urllib2
    from urllib2 import Request, urlopen
    from urllib2 import URLError, HTTPError
    from urllib import quote
    import httplib
    from httplib import IncompleteRead, BadStatusLine
    httplib._MAXHEADERS = 1000
import time # Importing the time library to check the time of code
execution 28 | import os 29 | import argparse 30 | import ssl 31 | import datetime 32 | import json 33 | import re 34 | import codecs 35 | import socket 36 | 37 | args_list = ["keywords", "keywords_from_file", "prefix_keywords", "suffix_keywords", 38 | "limit", "format", "color", "color_type", "usage_rights", "size", 39 | "exact_size", "aspect_ratio", "type", "time", "time_range", "delay", "url", "single_image", 40 | "output_directory", "image_directory", "no_directory", "proxy", "similar_images", "specific_site", 41 | "print_urls", "print_size", "print_paths", "metadata", "extract_metadata", "socket_timeout", 42 | "thumbnail", "thumbnail_only", "language", "prefix", "chromedriver", "related_images", "safe_search", "no_numbering", 43 | "offset", "no_download","save_source","silent_mode","ignore_urls"] 44 | 45 | 46 | def user_input(): 47 | config = argparse.ArgumentParser() 48 | config.add_argument('-cf', '--config_file', help='config file name', default='', type=str, required=False) 49 | config_file_check = config.parse_known_args() 50 | object_check = vars(config_file_check[0]) 51 | 52 | if object_check['config_file'] != '': 53 | records = [] 54 | json_file = json.load(open(config_file_check[0].config_file)) 55 | for record in range(0,len(json_file['Records'])): 56 | arguments = {} 57 | for i in args_list: 58 | arguments[i] = None 59 | for key, value in json_file['Records'][record].items(): 60 | arguments[key] = value 61 | records.append(arguments) 62 | records_count = len(records) 63 | else: 64 | # Taking command line arguments from users 65 | parser = argparse.ArgumentParser() 66 | parser.add_argument('-k', '--keywords', help='delimited list input', type=str, required=False) 67 | parser.add_argument('-kf', '--keywords_from_file', help='extract list of keywords from a text file', type=str, required=False) 68 | parser.add_argument('-sk', '--suffix_keywords', help='comma separated additional words added after to main keyword', type=str, required=False) 69 | 
parser.add_argument('-pk', '--prefix_keywords', help='comma separated additional words added before main keyword', type=str, required=False) 70 | parser.add_argument('-l', '--limit', help='delimited list input', type=str, required=False) 71 | parser.add_argument('-f', '--format', help='download images with specific format', type=str, required=False, 72 | choices=['jpg', 'gif', 'png', 'bmp', 'svg', 'webp', 'ico']) 73 | parser.add_argument('-u', '--url', help='search with google image URL', type=str, required=False) 74 | parser.add_argument('-x', '--single_image', help='downloading a single image from URL', type=str, required=False) 75 | parser.add_argument('-o', '--output_directory', help='download images in a specific main directory', type=str, required=False) 76 | parser.add_argument('-i', '--image_directory', help='download images in a specific sub-directory', type=str, required=False) 77 | parser.add_argument('-n', '--no_directory', default=False, help='download images in the main directory but no sub-directory', action="store_true") 78 | parser.add_argument('-d', '--delay', help='delay in seconds to wait between downloading two images', type=int, required=False) 79 | parser.add_argument('-co', '--color', help='filter on color', type=str, required=False, 80 | choices=['red', 'orange', 'yellow', 'green', 'teal', 'blue', 'purple', 'pink', 'white', 'gray', 'black', 'brown']) 81 | parser.add_argument('-ct', '--color_type', help='filter on color', type=str, required=False, 82 | choices=['full-color', 'black-and-white', 'transparent']) 83 | parser.add_argument('-r', '--usage_rights', help='usage rights', type=str, required=False, 84 | choices=['labeled-for-reuse-with-modifications','labeled-for-reuse','labeled-for-noncommercial-reuse-with-modification','labeled-for-nocommercial-reuse']) 85 | parser.add_argument('-s', '--size', help='image size', type=str, required=False, 86 | 
choices=['large','medium','icon','>400*300','>640*480','>800*600','>1024*768','>2MP','>4MP','>6MP','>8MP','>10MP','>12MP','>15MP','>20MP','>40MP','>70MP']) 87 | parser.add_argument('-es', '--exact_size', help='exact image resolution "WIDTH,HEIGHT"', type=str, required=False) 88 | parser.add_argument('-t', '--type', help='image type', type=str, required=False, 89 | choices=['face','photo','clipart','line-drawing','animated']) 90 | parser.add_argument('-w', '--time', help='image age', type=str, required=False, 91 | choices=['past-24-hours','past-7-days','past-month','past-year']) 92 | parser.add_argument('-wr', '--time_range', help='time range for the age of the image. should be in the format {"time_min":"MM/DD/YYYY","time_max":"MM/DD/YYYY"}', type=str, required=False) 93 | parser.add_argument('-a', '--aspect_ratio', help='comma separated additional words added to keywords', type=str, required=False, 94 | choices=['tall', 'square', 'wide', 'panoramic']) 95 | parser.add_argument('-si', '--similar_images', help='downloads images very similar to the image URL you provide', type=str, required=False) 96 | parser.add_argument('-ss', '--specific_site', help='downloads images that are indexed from a specific website', type=str, required=False) 97 | parser.add_argument('-p', '--print_urls', default=False, help="Print the URLs of the images", action="store_true") 98 | parser.add_argument('-ps', '--print_size', default=False, help="Print the size of the images on disk", action="store_true") 99 | parser.add_argument('-pp', '--print_paths', default=False, help="Prints the list of absolute paths of the images",action="store_true") 100 | parser.add_argument('-m', '--metadata', default=False, help="Print the metadata of the image", action="store_true") 101 | parser.add_argument('-e', '--extract_metadata', default=False, help="Dumps all the logs into a text file", action="store_true") 102 | parser.add_argument('-st', '--socket_timeout', default=False, help="Connection timeout waiting 
for the image to download", type=float) 103 | parser.add_argument('-th', '--thumbnail', default=False, help="Downloads image thumbnail along with the actual image", action="store_true") 104 | parser.add_argument('-tho', '--thumbnail_only', default=False, help="Downloads only thumbnail without downloading actual images", action="store_true") 105 | parser.add_argument('-la', '--language', default=False, help="Defines the language filter. The search results are authomatically returned in that language", type=str, required=False, 106 | choices=['Arabic','Chinese (Simplified)','Chinese (Traditional)','Czech','Danish','Dutch','English','Estonian','Finnish','French','German','Greek','Hebrew','Hungarian','Icelandic','Italian','Japanese','Korean','Latvian','Lithuanian','Norwegian','Portuguese','Polish','Romanian','Russian','Spanish','Swedish','Turkish']) 107 | parser.add_argument('-pr', '--prefix', default=False, help="A word that you would want to prefix in front of each image name", type=str, required=False) 108 | parser.add_argument('-px', '--proxy', help='specify a proxy address and port', type=str, required=False) 109 | parser.add_argument('-cd', '--chromedriver', help='specify the path to chromedriver executable in your local machine', type=str, required=False) 110 | parser.add_argument('-ri', '--related_images', default=False, help="Downloads images that are similar to the keyword provided", action="store_true") 111 | parser.add_argument('-sa', '--safe_search', default=False, help="Turns on the safe search filter while searching for images", action="store_true") 112 | parser.add_argument('-nn', '--no_numbering', default=False, help="Allows you to exclude the default numbering of images", action="store_true") 113 | parser.add_argument('-of', '--offset', help="Where to start in the fetched links", type=str, required=False) 114 | parser.add_argument('-nd', '--no_download', default=False, help="Prints the URLs of the images and/or thumbnails without downloading them", 
action="store_true") 115 | parser.add_argument('-iu', '--ignore_urls', default=False, help="delimited list input of image urls/keywords to ignore", type=str) 116 | parser.add_argument('-sil', '--silent_mode', default=False, help="Remains silent. Does not print notification messages on the terminal", action="store_true") 117 | parser.add_argument('-is', '--save_source', help="creates a text file containing a list of downloaded images along with source page url", type=str, required=False) 118 | 119 | args = parser.parse_args() 120 | arguments = vars(args) 121 | records = [] 122 | records.append(arguments) 123 | return records 124 | 125 | 126 | class googleimagesdownload: 127 | def __init__(self): 128 | pass 129 | 130 | # Downloading entire Web Document (Raw Page Content) 131 | def download_page(self,url): 132 | version = (3, 0) 133 | cur_version = sys.version_info 134 | if cur_version >= version: # If the Current Version of Python is 3.0 or above 135 | try: 136 | headers = {} 137 | headers['User-Agent'] = "Mozilla/5.0 (Windows NT 6.1) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/41.0.2228.0 Safari/537.36" 138 | req = urllib.request.Request(url, headers=headers) 139 | resp = urllib.request.urlopen(req) 140 | respData = str(resp.read()) 141 | return respData 142 | except Exception as e: 143 | print("Could not open URL. 
Please check your internet connection and/or ssl settings \n" 144 | "If you are using proxy, make sure your proxy settings is configured correctly") 145 | sys.exit() 146 | else: # If the Current Version of Python is 2.x 147 | try: 148 | headers = {} 149 | headers['User-Agent'] = "Mozilla/5.0 (X11; Linux i686) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.27 Safari/537.17" 150 | req = urllib2.Request(url, headers=headers) 151 | try: 152 | response = urllib2.urlopen(req) 153 | except URLError: # Handling SSL certificate failed 154 | context = ssl._create_unverified_context() 155 | response = urlopen(req, context=context) 156 | page = response.read() 157 | return page 158 | except: 159 | print("Could not open URL. Please check your internet connection and/or ssl settings \n" 160 | "If you are using proxy, make sure your proxy settings is configured correctly") 161 | sys.exit() 162 | return "Page Not found" 163 | 164 | 165 | # Download Page for more than 100 images 166 | def download_extended_page(self,url,chromedriver): 167 | from selenium import webdriver 168 | from selenium.webdriver.common.keys import Keys 169 | if sys.version_info[0] < 3: 170 | reload(sys) 171 | sys.setdefaultencoding('utf8') 172 | options = webdriver.ChromeOptions() 173 | options.add_argument('--no-sandbox') 174 | options.add_argument("--headless") 175 | 176 | try: 177 | browser = webdriver.Chrome(chromedriver, chrome_options=options) 178 | except Exception as e: 179 | print("Looks like we cannot locate the path the 'chromedriver' (use the '--chromedriver' " 180 | "argument to specify the path to the executable.) or google chrome browser is not " 181 | "installed on your machine (exception: %s)" % e) 182 | sys.exit() 183 | browser.set_window_size(1024, 768) 184 | 185 | # Open the link 186 | browser.get(url) 187 | time.sleep(1) 188 | print("Getting you a lot of images. 
This may take a few moments...") 189 | 190 | element = browser.find_element_by_tag_name("body") 191 | # Scroll down 192 | for i in range(30): 193 | element.send_keys(Keys.PAGE_DOWN) 194 | time.sleep(0.3) 195 | 196 | try: 197 | browser.find_element_by_id("smb").click() 198 | for i in range(50): 199 | element.send_keys(Keys.PAGE_DOWN) 200 | time.sleep(0.3) # bot id protection 201 | except: 202 | for i in range(10): 203 | element.send_keys(Keys.PAGE_DOWN) 204 | time.sleep(0.3) # bot id protection 205 | 206 | print("Reached end of Page.") 207 | time.sleep(0.5) 208 | 209 | source = browser.page_source #page source 210 | #close the browser 211 | browser.close() 212 | 213 | return source 214 | 215 | 216 | #Correcting the escape characters for python2 217 | def replace_with_byte(self,match): 218 | return chr(int(match.group(0)[1:], 8)) 219 | 220 | def repair(self,brokenjson): 221 | invalid_escape = re.compile(r'\\[0-7]{1,3}') # up to 3 digits for byte values up to FF 222 | return invalid_escape.sub(self.replace_with_byte, brokenjson) 223 | 224 | 225 | # Finding 'Next Image' from the given raw page 226 | def get_next_tab(self,s): 227 | start_line = s.find('class="dtviD"') 228 | if start_line == -1: # If no links are found then give an error! 
229 | end_quote = 0 230 | link = "no_tabs" 231 | return link,'',end_quote 232 | else: 233 | start_line = s.find('class="dtviD"') 234 | start_content = s.find('href="', start_line + 1) 235 | end_content = s.find('">', start_content + 1) 236 | url_item = "https://www.google.com" + str(s[start_content + 6:end_content]) 237 | url_item = url_item.replace('&', '&') 238 | 239 | start_line_2 = s.find('class="dtviD"') 240 | s = s.replace('&', '&') 241 | start_content_2 = s.find(':', start_line_2 + 1) 242 | end_content_2 = s.find('&usg=', start_content_2 + 1) 243 | url_item_name = str(s[start_content_2 + 1:end_content_2]) 244 | 245 | chars = url_item_name.find(',g_1:') 246 | chars_end = url_item_name.find(":", chars + 6) 247 | if chars_end == -1: 248 | updated_item_name = (url_item_name[chars + 5:]).replace("+", " ") 249 | else: 250 | updated_item_name = (url_item_name[chars+5:chars_end]).replace("+", " ") 251 | 252 | return url_item, updated_item_name, end_content 253 | 254 | 255 | # Getting all links with the help of '_images_get_next_image' 256 | def get_all_tabs(self,page): 257 | tabs = {} 258 | while True: 259 | item,item_name,end_content = self.get_next_tab(page) 260 | if item == "no_tabs": 261 | break 262 | else: 263 | if len(item_name) > 100 or item_name == "background-color": 264 | break 265 | else: 266 | tabs[item_name] = item # Append all the links in the list named 'Links' 267 | time.sleep(0.1) # Timer could be used to slow down the request for image downloads 268 | page = page[end_content:] 269 | return tabs 270 | 271 | 272 | #Format the object in readable format 273 | def format_object(self,object): 274 | formatted_object = {} 275 | formatted_object['image_format'] = object['ity'] 276 | formatted_object['image_height'] = object['oh'] 277 | formatted_object['image_width'] = object['ow'] 278 | formatted_object['image_link'] = object['ou'] 279 | formatted_object['image_description'] = object['pt'] 280 | formatted_object['image_host'] = object['rh'] 281 | 
formatted_object['image_source'] = object['ru'] 282 | formatted_object['image_thumbnail_url'] = object['tu'] 283 | return formatted_object 284 | 285 | 286 | #function to download single image 287 | def single_image(self,image_url): 288 | main_directory = "downloads" 289 | extensions = (".jpg", ".gif", ".png", ".bmp", ".svg", ".webp", ".ico") 290 | url = image_url 291 | try: 292 | os.makedirs(main_directory) 293 | except OSError as e: 294 | if e.errno != 17: 295 | raise 296 | pass 297 | req = Request(url, headers={ 298 | "User-Agent": "Mozilla/5.0 (X11; Linux i686) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.27 Safari/537.17"}) 299 | 300 | response = urlopen(req, None, 10) 301 | data = response.read() 302 | response.close() 303 | 304 | image_name = str(url[(url.rfind('/')) + 1:]) 305 | if '?' in image_name: 306 | image_name = image_name[:image_name.find('?')] 307 | # if ".jpg" in image_name or ".gif" in image_name or ".png" in image_name or ".bmp" in image_name or ".svg" in image_name or ".webp" in image_name or ".ico" in image_name: 308 | if any(map(lambda extension: extension in image_name, extensions)): 309 | file_name = main_directory + "/" + image_name 310 | else: 311 | file_name = main_directory + "/" + image_name + ".jpg" 312 | image_name = image_name + ".jpg" 313 | 314 | try: 315 | output_file = open(file_name, 'wb') 316 | output_file.write(data) 317 | output_file.close() 318 | except IOError as e: 319 | raise e 320 | except OSError as e: 321 | raise e 322 | print("completed ====> " + image_name.encode('raw_unicode_escape').decode('utf-8')) 323 | return 324 | 325 | def similar_images(self,similar_images): 326 | version = (3, 0) 327 | cur_version = sys.version_info 328 | if cur_version >= version: # If the Current Version of Python is 3.0 or above 329 | try: 330 | searchUrl = 'https://www.google.com/searchbyimage?site=search&sa=X&image_url=' + similar_images 331 | headers = {} 332 | headers['User-Agent'] = "Mozilla/5.0 (Windows NT 6.1) 
#Building URL parameters
def build_url_parameters(self, arguments):
    """Translate user-facing filter arguments into Google Images URL query fragments.

    Builds the '&tbs=' filter string (color, size, type, time, aspect ratio,
    format, usage rights, ...) plus the language ('&lr='), exact-size and
    time-range fragments, and returns them concatenated.

    Raises KeyError if a filter argument carries a value outside the
    supported set for that filter.
    """
    if arguments['language']:
        lang = "&lr="
        # NOTE(review): "lang_iw " for Hebrew carries a trailing space in the
        # original table — kept as-is; verify against the endpoint before changing.
        lang_param = {"Arabic":"lang_ar","Chinese (Simplified)":"lang_zh-CN","Chinese (Traditional)":"lang_zh-TW","Czech":"lang_cs","Danish":"lang_da","Dutch":"lang_nl","English":"lang_en","Estonian":"lang_et","Finnish":"lang_fi","French":"lang_fr","German":"lang_de","Greek":"lang_el","Hebrew":"lang_iw ","Hungarian":"lang_hu","Icelandic":"lang_is","Italian":"lang_it","Japanese":"lang_ja","Korean":"lang_ko","Latvian":"lang_lv","Lithuanian":"lang_lt","Norwegian":"lang_no","Portuguese":"lang_pt","Polish":"lang_pl","Romanian":"lang_ro","Russian":"lang_ru","Spanish":"lang_es","Swedish":"lang_sv","Turkish":"lang_tr"}
        lang_url = lang + lang_param[arguments['language']]
    else:
        lang_url = ''

    if arguments['time_range']:
        # time_range arrives as quasi-JSON with single quotes,
        # e.g. "{'time_min':'01/01/2020','time_max':'02/01/2020'}"
        json_acceptable_string = arguments['time_range'].replace("'", "\"")
        d = json.loads(json_acceptable_string)
        time_range = ',cdr:1,cd_min:' + d['time_min'] + ',cd_max:' + d['time_max']
    else:
        time_range = ''

    if arguments['exact_size']:
        # "WIDTH,HEIGHT" -> isz:ex width/height parameters
        size_array = [x.strip() for x in arguments['exact_size'].split(',')]
        exact_size = ",isz:ex,iszw:" + str(size_array[0]) + ",iszh:" + str(size_array[1])
    else:
        exact_size = ''

    built_url = "&tbs="
    counter = 0
    # Each entry maps: filter name -> [user-selected value, {value -> tbs token}]
    # BUGFIX: 'teal' previously emitted 'isc:teel', '>1024*768' emitted
    # 'visz:lt,...', and 'webp' was missing its 'ift:' prefix.
    params = {'color':[arguments['color'],{'red':'ic:specific,isc:red', 'orange':'ic:specific,isc:orange', 'yellow':'ic:specific,isc:yellow', 'green':'ic:specific,isc:green', 'teal':'ic:specific,isc:teal', 'blue':'ic:specific,isc:blue', 'purple':'ic:specific,isc:purple', 'pink':'ic:specific,isc:pink', 'white':'ic:specific,isc:white', 'gray':'ic:specific,isc:gray', 'black':'ic:specific,isc:black', 'brown':'ic:specific,isc:brown'}],
              'color_type':[arguments['color_type'],{'full-color':'ic:color', 'black-and-white':'ic:gray','transparent':'ic:trans'}],
              'usage_rights':[arguments['usage_rights'],{'labeled-for-reuse-with-modifications':'sur:fmc','labeled-for-reuse':'sur:fc','labeled-for-noncommercial-reuse-with-modification':'sur:fm','labeled-for-nocommercial-reuse':'sur:f'}],
              'size':[arguments['size'],{'large':'isz:l','medium':'isz:m','icon':'isz:i','>400*300':'isz:lt,islt:qsvga','>640*480':'isz:lt,islt:vga','>800*600':'isz:lt,islt:svga','>1024*768':'isz:lt,islt:xga','>2MP':'isz:lt,islt:2mp','>4MP':'isz:lt,islt:4mp','>6MP':'isz:lt,islt:6mp','>8MP':'isz:lt,islt:8mp','>10MP':'isz:lt,islt:10mp','>12MP':'isz:lt,islt:12mp','>15MP':'isz:lt,islt:15mp','>20MP':'isz:lt,islt:20mp','>40MP':'isz:lt,islt:40mp','>70MP':'isz:lt,islt:70mp'}],
              'type':[arguments['type'],{'face':'itp:face','photo':'itp:photo','clipart':'itp:clipart','line-drawing':'itp:lineart','animated':'itp:animated'}],
              'time':[arguments['time'],{'past-24-hours':'qdr:d','past-7-days':'qdr:w','past-month':'qdr:m','past-year':'qdr:y'}],
              'aspect_ratio':[arguments['aspect_ratio'],{'tall':'iar:t','square':'iar:s','wide':'iar:w','panoramic':'iar:xw'}],
              'format':[arguments['format'],{'jpg':'ift:jpg','gif':'ift:gif','png':'ift:png','bmp':'ift:bmp','svg':'ift:svg','webp':'ift:webp','ico':'ift:ico','raw':'ift:craw'}]}
    for key, value in params.items():
        if value[0] is not None:
            ext_param = value[1][value[0]]
            # counter tells whether this is the first tbs token (no leading comma)
            if counter == 0:
                built_url = built_url + ext_param
            else:
                built_url = built_url + ',' + ext_param
            counter += 1
    built_url = lang_url + built_url + exact_size + time_range
    return built_url
'&espv=2&biw=1366&bih=667&site=webhp&source=lnms&tbm=isch&sa=X&ei=XosDVaCXD8TasATItgE&ved=0CAcQ_AUoAg' 431 | elif specific_site: 432 | url = 'https://www.google.com/search?q=' + quote( 433 | search_term.encode('utf-8')) + '&as_sitesearch=' + specific_site + '&espv=2&biw=1366&bih=667&site=webhp&source=lnms&tbm=isch' + params + '&sa=X&ei=XosDVaCXD8TasATItgE&ved=0CAcQ_AUoAg' 434 | else: 435 | url = 'https://www.google.com/search?q=' + quote( 436 | search_term.encode('utf-8')) + '&espv=2&biw=1366&bih=667&site=webhp&source=lnms&tbm=isch' + params + '&sa=X&ei=XosDVaCXD8TasATItgE&ved=0CAcQ_AUoAg' 437 | 438 | #safe search check 439 | if safe_search: 440 | url = url + safe_search_string 441 | 442 | return url 443 | 444 | 445 | #measures the file size 446 | def file_size(self,file_path): 447 | if os.path.isfile(file_path): 448 | file_info = os.stat(file_path) 449 | size = file_info.st_size 450 | for x in ['bytes', 'KB', 'MB', 'GB', 'TB']: 451 | if size < 1024.0: 452 | return "%3.1f %s" % (size, x) 453 | size /= 1024.0 454 | return size 455 | 456 | #keywords from file 457 | def keywords_from_file(self,file_name): 458 | search_keyword = [] 459 | with codecs.open(file_name, 'r', encoding='utf-8-sig') as f: 460 | if '.csv' in file_name: 461 | for line in f: 462 | if line in ['\n', '\r\n']: 463 | pass 464 | else: 465 | search_keyword.append(line.replace('\n', '').replace('\r', '')) 466 | elif '.txt' in file_name: 467 | for line in f: 468 | if line in ['\n', '\r\n']: 469 | pass 470 | else: 471 | search_keyword.append(line.replace('\n', '').replace('\r', '')) 472 | else: 473 | print("Invalid file type: Valid file types are either .txt or .csv \n" 474 | "exiting...") 475 | sys.exit() 476 | return search_keyword 477 | 478 | # make directories 479 | def create_directories(self,main_directory, dir_name,thumbnail,thumbnail_only): 480 | dir_name_thumbnail = dir_name + " - thumbnail" 481 | # make a search keyword directory 482 | try: 483 | if not os.path.exists(main_directory): 484 | 
os.makedirs(main_directory) 485 | time.sleep(0.2) 486 | path = (dir_name) 487 | sub_directory = os.path.join(main_directory, path) 488 | if not os.path.exists(sub_directory): 489 | os.makedirs(sub_directory) 490 | if thumbnail or thumbnail_only: 491 | sub_directory_thumbnail = os.path.join(main_directory, dir_name_thumbnail) 492 | if not os.path.exists(sub_directory_thumbnail): 493 | os.makedirs(sub_directory_thumbnail) 494 | else: 495 | path = (dir_name) 496 | sub_directory = os.path.join(main_directory, path) 497 | if not os.path.exists(sub_directory): 498 | os.makedirs(sub_directory) 499 | if thumbnail or thumbnail_only: 500 | sub_directory_thumbnail = os.path.join(main_directory, dir_name_thumbnail) 501 | if not os.path.exists(sub_directory_thumbnail): 502 | os.makedirs(sub_directory_thumbnail) 503 | except OSError as e: 504 | if e.errno != 17: 505 | raise 506 | pass 507 | return 508 | 509 | 510 | # Download Image thumbnails 511 | def download_image_thumbnail(self,image_url,main_directory,dir_name,return_image_name,print_urls,socket_timeout,print_size,no_download,save_source,img_src,ignore_urls): 512 | if print_urls or no_download: 513 | print("Image URL: " + image_url) 514 | if no_download: 515 | return "success","Printed url without downloading" 516 | try: 517 | req = Request(image_url, headers={ 518 | "User-Agent": "Mozilla/5.0 (X11; Linux i686) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.27 Safari/537.17"}) 519 | try: 520 | # timeout time to download an image 521 | if socket_timeout: 522 | timeout = float(socket_timeout) 523 | else: 524 | timeout = 10 525 | 526 | response = urlopen(req, None, timeout) 527 | data = response.read() 528 | response.close() 529 | 530 | path = main_directory + "/" + dir_name + " - thumbnail" + "/" + return_image_name 531 | 532 | try: 533 | output_file = open(path, 'wb') 534 | output_file.write(data) 535 | output_file.close() 536 | if save_source: 537 | list_path = main_directory + "/" + save_source + ".txt" 538 | 
list_file = open(list_path,'a') 539 | list_file.write(path + '\t' + img_src + '\n') 540 | list_file.close() 541 | except OSError as e: 542 | download_status = 'fail' 543 | download_message = "OSError on an image...trying next one..." + " Error: " + str(e) 544 | except IOError as e: 545 | download_status = 'fail' 546 | download_message = "IOError on an image...trying next one..." + " Error: " + str(e) 547 | 548 | download_status = 'success' 549 | download_message = "Completed Image Thumbnail ====> " + return_image_name 550 | 551 | # image size parameter 552 | if print_size: 553 | print("Image Size: " + str(self.file_size(path))) 554 | 555 | except UnicodeEncodeError as e: 556 | download_status = 'fail' 557 | download_message = "UnicodeEncodeError on an image...trying next one..." + " Error: " + str(e) 558 | 559 | except HTTPError as e: # If there is any HTTPError 560 | download_status = 'fail' 561 | download_message = "HTTPError on an image...trying next one..." + " Error: " + str(e) 562 | 563 | except URLError as e: 564 | download_status = 'fail' 565 | download_message = "URLError on an image...trying next one..." + " Error: " + str(e) 566 | 567 | except ssl.CertificateError as e: 568 | download_status = 'fail' 569 | download_message = "CertificateError on an image...trying next one..." + " Error: " + str(e) 570 | 571 | except IOError as e: # If there is any IOError 572 | download_status = 'fail' 573 | download_message = "IOError on an image...trying next one..." 
# Download Images
def download_image(self, image_url, image_format, main_directory, dir_name, count, print_urls, socket_timeout, prefix, print_size, no_numbering, no_download, save_source, img_src, silent_mode, thumbnail_only, format, ignore_urls):
    """Download a single image to '<main_directory>/<dir_name>/'.

    Returns (status, message, image_name, absolute_path); status is
    'success' or 'fail'. Short-circuits without network access when the URL
    matches `ignore_urls`, when `thumbnail_only` is set, or when
    `no_download` is set. `format` (shadowing the builtin — name kept for
    interface compatibility) rejects images whose reported format differs.
    """
    if not silent_mode:
        if print_urls or no_download:
            print("Image URL: " + image_url)
    if ignore_urls:
        if any(url in image_url for url in ignore_urls.split(',')):
            return "fail", "Image ignored due to 'ignore url' parameter", None, image_url
    if thumbnail_only:
        # only the thumbnail will be fetched; report the bare file name
        return "success", "Skipping image download...", str(image_url[(image_url.rfind('/')) + 1:]), image_url
    if no_download:
        return "success", "Printed url without downloading", None, image_url
    try:
        req = Request(image_url, headers={
            "User-Agent": "Mozilla/5.0 (X11; Linux i686) AppleWebKit/537.17 (KHTML, like Gecko) Chrome/24.0.1312.27 Safari/537.17"})
        try:
            # timeout time to download an image
            if socket_timeout:
                timeout = float(socket_timeout)
            else:
                timeout = 10

            response = urlopen(req, None, timeout)
            data = response.read()
            response.close()

            extensions = [".jpg", ".jpeg", ".gif", ".png", ".bmp", ".svg", ".webp", ".ico"]
            # keep everything after the last '/'
            image_name = str(image_url[(image_url.rfind('/')) + 1:])
            if format:
                if not image_format or image_format != format:
                    download_status = 'fail'
                    download_message = "Wrong image format returned. Skipping..."
                    return_image_name = ''
                    absolute_path = ''
                    return download_status, download_message, return_image_name, absolute_path

            if image_format == "" or not image_format or "." + image_format not in extensions:
                download_status = 'fail'
                download_message = "Invalid or missing image format. Skipping..."
                return_image_name = ''
                absolute_path = ''
                return download_status, download_message, return_image_name, absolute_path
            elif image_name.lower().find("." + image_format) < 0:
                image_name = image_name + "." + image_format
            else:
                # truncate any URL junk after the extension
                image_name = image_name[:image_name.lower().find("." + image_format) + (len(image_format) + 1)]

            # prefix name in image
            if prefix:
                prefix = prefix + " "
            else:
                prefix = ''

            if no_numbering:
                path = main_directory + "/" + dir_name + "/" + prefix + image_name
            else:
                path = main_directory + "/" + dir_name + "/" + prefix + str(count) + "." + image_name

            try:
                with open(path, 'wb') as output_file:
                    output_file.write(data)
                if save_source:
                    list_path = main_directory + "/" + save_source + ".txt"
                    with open(list_path, 'a') as list_file:
                        list_file.write(path + '\t' + img_src + '\n')
                absolute_path = os.path.abspath(path)
            except OSError as e:
                download_status = 'fail'
                download_message = "OSError on an image...trying next one..." + " Error: " + str(e)
                return_image_name = ''
                absolute_path = ''
            else:
                # BUGFIX: the original set 'success' unconditionally here,
                # clobbering the OSError failure status recorded just above.
                #return image name back to calling method to use it for thumbnail downloads
                download_status = 'success'
                download_message = "Completed Image ====> " + prefix + str(count) + "." + image_name
                return_image_name = prefix + str(count) + "." + image_name

            # image size parameter
            if not silent_mode:
                if print_size:
                    print("Image Size: " + str(self.file_size(path)))

        except UnicodeEncodeError as e:
            download_status = 'fail'
            download_message = "UnicodeEncodeError on an image...trying next one..." + " Error: " + str(e)
            return_image_name = ''
            absolute_path = ''

        except URLError as e:
            download_status = 'fail'
            download_message = "URLError on an image...trying next one..." + " Error: " + str(e)
            return_image_name = ''
            absolute_path = ''

        except BadStatusLine as e:
            download_status = 'fail'
            download_message = "BadStatusLine on an image...trying next one..." + " Error: " + str(e)
            return_image_name = ''
            absolute_path = ''

    except HTTPError as e:  # If there is any HTTPError
        download_status = 'fail'
        download_message = "HTTPError on an image...trying next one..." + " Error: " + str(e)
        return_image_name = ''
        absolute_path = ''

    except URLError as e:
        download_status = 'fail'
        download_message = "URLError on an image...trying next one..." + " Error: " + str(e)
        return_image_name = ''
        absolute_path = ''

    except ssl.CertificateError as e:
        download_status = 'fail'
        download_message = "CertificateError on an image...trying next one..." + " Error: " + str(e)
        return_image_name = ''
        absolute_path = ''

    except IOError as e:  # If there is any IOError
        download_status = 'fail'
        download_message = "IOError on an image...trying next one..." + " Error: " + str(e)
        return_image_name = ''
        absolute_path = ''

    except IncompleteRead as e:
        download_status = 'fail'
        download_message = "IncompleteReadError on an image...trying next one..." + " Error: " + str(e)
        return_image_name = ''
        absolute_path = ''

    return download_status, download_message, return_image_name, absolute_path
717 | end_quote = 0 718 | link = "no_links" 719 | return link, end_quote 720 | else: 721 | start_line = s.find('class="rg_meta notranslate">') 722 | start_object = s.find('{', start_line + 1) 723 | end_object = s.find('', start_object + 1) 724 | object_raw = str(s[start_object:end_object]) 725 | #remove escape characters based on python version 726 | version = (3, 0) 727 | cur_version = sys.version_info 728 | if cur_version >= version: #python3 729 | try: 730 | object_decode = bytes(object_raw, "utf-8").decode("unicode_escape") 731 | final_object = json.loads(object_decode) 732 | except: 733 | final_object = "" 734 | else: #python2 735 | try: 736 | final_object = (json.loads(self.repair(object_raw))) 737 | except: 738 | final_object = "" 739 | return final_object, end_object 740 | 741 | 742 | # Getting all links with the help of '_images_get_next_image' 743 | def _get_all_items(self,page,main_directory,dir_name,limit,arguments): 744 | items = [] 745 | abs_path = [] 746 | errorCount = 0 747 | i = 0 748 | count = 1 749 | while count < limit+1: 750 | object, end_content = self._get_next_item(page) 751 | if object == "no_links": 752 | break 753 | elif object == "": 754 | page = page[end_content:] 755 | elif arguments['offset'] and count < int(arguments['offset']): 756 | count += 1 757 | page = page[end_content:] 758 | else: 759 | #format the item for readability 760 | object = self.format_object(object) 761 | if arguments['metadata']: 762 | if not arguments["silent_mode"]: 763 | print("\nImage Metadata: " + str(object)) 764 | 765 | #download the images 766 | download_status,download_message,return_image_name,absolute_path = 
self.download_image(object['image_link'],object['image_format'],main_directory,dir_name,count,arguments['print_urls'],arguments['socket_timeout'],arguments['prefix'],arguments['print_size'],arguments['no_numbering'],arguments['no_download'],arguments['save_source'],object['image_source'],arguments["silent_mode"],arguments["thumbnail_only"],arguments['format'],arguments['ignore_urls']) 767 | if not arguments["silent_mode"]: 768 | print(download_message) 769 | if download_status == "success": 770 | 771 | # download image_thumbnails 772 | if arguments['thumbnail'] or arguments["thumbnail_only"]: 773 | download_status, download_message_thumbnail = self.download_image_thumbnail(object['image_thumbnail_url'],main_directory,dir_name,return_image_name,arguments['print_urls'],arguments['socket_timeout'],arguments['print_size'],arguments['no_download'],arguments['save_source'],object['image_source'],arguments['ignore_urls']) 774 | if not arguments["silent_mode"]: 775 | print(download_message_thumbnail) 776 | 777 | count += 1 778 | object['image_filename'] = return_image_name 779 | items.append(object) # Append all the links in the list named 'Links' 780 | abs_path.append(absolute_path) 781 | else: 782 | errorCount += 1 783 | 784 | #delay param 785 | if arguments['delay']: 786 | time.sleep(int(arguments['delay'])) 787 | 788 | page = page[end_content:] 789 | i += 1 790 | if count < limit: 791 | print("\n\nUnfortunately all " + str( 792 | limit) + " could not be downloaded because some images were not downloadable. 
" + str( 793 | count-1) + " is all we got for this search filter!") 794 | return items,errorCount,abs_path 795 | 796 | 797 | # Bulk Download 798 | def download(self,arguments): 799 | paths_agg = {} 800 | # for input coming from other python files 801 | if __name__ != "__main__": 802 | # if the calling file contains config_file param 803 | if 'config_file' in arguments: 804 | records = [] 805 | json_file = json.load(open(arguments['config_file'])) 806 | for record in range(0, len(json_file['Records'])): 807 | arguments = {} 808 | for i in args_list: 809 | arguments[i] = None 810 | for key, value in json_file['Records'][record].items(): 811 | arguments[key] = value 812 | records.append(arguments) 813 | total_errors = 0 814 | for rec in records: 815 | paths, errors, items = self.download_executor(rec) 816 | for i in paths: 817 | paths_agg[i] = paths[i] 818 | if not arguments["silent_mode"]: 819 | if arguments['print_paths']: 820 | print(paths.encode('raw_unicode_escape').decode('utf-8')) 821 | total_errors = total_errors + errors 822 | return paths_agg,total_errors, items 823 | # if the calling file contains params directly 824 | else: 825 | paths, errors, items = self.download_executor(arguments) 826 | for i in paths: 827 | paths_agg[i] = paths[i] 828 | if not arguments["silent_mode"]: 829 | if arguments['print_paths']: 830 | print(paths.encode('raw_unicode_escape').decode('utf-8')) 831 | return paths_agg, errors, items 832 | # for input coming from CLI 833 | else: 834 | paths, errors, items = self.download_executor(arguments) 835 | for i in paths: 836 | paths_agg[i] = paths[i] 837 | if not arguments["silent_mode"]: 838 | if arguments['print_paths']: 839 | print(paths.encode('raw_unicode_escape').decode('utf-8')) 840 | return paths_agg, errors, items 841 | 842 | def download_executor(self,arguments): 843 | paths = {} 844 | errorCount = None 845 | for arg in args_list: 846 | if arg not in arguments: 847 | arguments[arg] = None 848 | ######Initialization and Validation 
def download_executor(self, arguments):
    """Validate arguments and run the full download pipeline.

    Expands prefix/suffix keyword combinations, builds the search URL,
    scrapes the result page, and downloads images for every combination.
    Returns (paths, total_errors, items).

    Raises ValueError for mutually exclusive argument pairs; exits the
    process when no keyword source is supplied at all.
    """
    paths = {}
    errorCount = None
    # make sure every known argument key exists so lookups below never KeyError
    for arg in args_list:
        if arg not in arguments:
            arguments[arg] = None
    ###### Initialization and Validation of user arguments
    # ROBUSTNESS: initialize both so the keyword loop and the final return are
    # well-defined even when no keyword source is supplied (the original could
    # hit NameError, e.g. on a single_image-only invocation).
    search_keyword = []
    items = []
    if arguments['keywords']:
        search_keyword = [str(item) for item in arguments['keywords'].split(',')]

    if arguments['keywords_from_file']:
        search_keyword = self.keywords_from_file(arguments['keywords_from_file'])

    # both time and time range should not be allowed in the same query
    if arguments['time'] and arguments['time_range']:
        raise ValueError('Either time or time range should be used in a query. Both cannot be used at the same time.')

    # both size and exact size should not be allowed in the same query
    if arguments['size'] and arguments['exact_size']:
        raise ValueError('Either "size" or "exact_size" should be used in a query. Both cannot be used at the same time.')

    # both image directory and no image directory should not be allowed in the same query
    if arguments['image_directory'] and arguments['no_directory']:
        raise ValueError('You can either specify image directory or specify no image directory, not both!')

    # Additional words added to keywords
    if arguments['suffix_keywords']:
        suffix_keywords = [" " + str(sk) for sk in arguments['suffix_keywords'].split(',')]
    else:
        suffix_keywords = ['']

    # Additional words added to keywords
    if arguments['prefix_keywords']:
        prefix_keywords = [str(sk) + " " for sk in arguments['prefix_keywords'].split(',')]
    else:
        prefix_keywords = ['']

    # Setting limit on number of images to be downloaded
    if arguments['limit']:
        limit = int(arguments['limit'])
    else:
        limit = 100

    # url / similar_images runs are keyed by a timestamp instead of a keyword
    if arguments['url']:
        current_time = str(datetime.datetime.now()).split('.')[0]
        search_keyword = [current_time.replace(":", "_")]

    if arguments['similar_images']:
        current_time = str(datetime.datetime.now()).split('.')[0]
        search_keyword = [current_time.replace(":", "_")]

    # If single_image or url argument not present then keywords is mandatory argument
    if arguments['single_image'] is None and arguments['url'] is None and arguments['similar_images'] is None and \
            arguments['keywords'] is None and arguments['keywords_from_file'] is None:
        print('-------------------------------\n'
              'Uh oh! Keywords is a required argument \n\n'
              'Please refer to the documentation on guide to writing queries \n'
              'https://github.com/hardikvasa/google-images-download#examples'
              '\n\nexiting!\n'
              '-------------------------------')
        sys.exit()

    # If this argument is present, set the custom output directory
    if arguments['output_directory']:
        main_directory = arguments['output_directory']
    else:
        main_directory = "downloads"

    # Proxy settings
    if arguments['proxy']:
        os.environ["http_proxy"] = arguments['proxy']
        os.environ["https_proxy"] = arguments['proxy']
    ###### Initialization Complete
    total_errors = 0
    for pky in prefix_keywords:  # 1.for every prefix keyword
        for sky in suffix_keywords:  # 2.for every suffix keyword
            i = 0
            while i < len(search_keyword):  # 3.for every main keyword
                iteration = "\n" + "Item no.: " + str(i + 1) + " -->" + " Item name = " + (pky) + (search_keyword[i]) + (sky)
                if not arguments["silent_mode"]:
                    print(iteration.encode('raw_unicode_escape').decode('utf-8'))
                    print("Evaluating...")
                else:
                    print("Downloading images for: " + (pky) + (search_keyword[i]) + (sky) + " ...")
                search_term = pky + search_keyword[i] + sky

                if arguments['image_directory']:
                    dir_name = arguments['image_directory']
                elif arguments['no_directory']:
                    dir_name = ''
                else:
                    dir_name = search_term + ('-' + arguments['color'] if arguments['color'] else '')  # sub-directory

                if not arguments["no_download"]:
                    self.create_directories(main_directory, dir_name, arguments['thumbnail'], arguments['thumbnail_only'])  # create directories in OS

                params = self.build_url_parameters(arguments)  # building URL with params

                url = self.build_search_url(search_term, params, arguments['url'], arguments['similar_images'], arguments['specific_site'], arguments['safe_search'])  # building main search url

                # >100 results requires the selenium-driven extended page
                if limit < 101:
                    raw_html = self.download_page(url)  # download page
                else:
                    raw_html = self.download_extended_page(url, arguments['chromedriver'])

                if not arguments["silent_mode"]:
                    if arguments['no_download']:
                        print("Getting URLs without downloading images...")
                    else:
                        print("Starting Download...")
                items, errorCount, abs_path = self._get_all_items(raw_html, main_directory, dir_name, limit, arguments)  # get all image items and download images
                paths[pky + search_keyword[i] + sky] = abs_path

                # dumps metadata into a json file
                if arguments['extract_metadata']:
                    try:
                        if not os.path.exists("logs"):
                            os.makedirs("logs")
                    except OSError as e:
                        print(e)
                    with open("logs/" + search_keyword[i] + ".json", "w") as json_file:
                        json.dump(items, json_file, indent=4, sort_keys=True)

                # Related images
                if arguments['related_images']:
                    print("\nGetting list of related keywords...this may take a few moments")
                    tabs = self.get_all_tabs(raw_html)
                    for key, value in tabs.items():
                        final_search_term = (search_term + " - " + key)
                        print("\nNow Downloading - " + final_search_term)
                        if limit < 101:
                            new_raw_html = self.download_page(value)  # download page
                        else:
                            new_raw_html = self.download_extended_page(value, arguments['chromedriver'])
                        self.create_directories(main_directory, final_search_term, arguments['thumbnail'], arguments['thumbnail_only'])
                        self._get_all_items(new_raw_html, main_directory, search_term + " - " + key, limit, arguments)

                i += 1
                total_errors = total_errors + errorCount
                if not arguments["silent_mode"]:
                    print("\nErrors: " + str(errorCount) + "\n")
    return paths, total_errors, items
search_term + " - " + key, limit,arguments) 980 | 981 | i += 1 982 | total_errors = total_errors + errorCount 983 | if not arguments["silent_mode"]: 984 | print("\nErrors: " + str(errorCount) + "\n") 985 | return paths, total_errors, items 986 | 987 | #------------- Main Program -------------# 988 | def main(): 989 | records = user_input() 990 | total_errors = 0 991 | t0 = time.time() # start the timer 992 | for arguments in records: 993 | if arguments['single_image']: # Download Single Image using a URL 994 | response = googleimagesdownload() 995 | response.single_image(arguments['single_image']) 996 | else: # or download multiple images based on keywords/keyphrase search 997 | response = googleimagesdownload() 998 | paths,errors,items = response.download(arguments) #wrapping response in a variable just for consistency 999 | total_errors = total_errors + errors 1000 | 1001 | t1 = time.time() # stop the timer 1002 | total_time = t1 - t0 # Calculating the total time required to crawl, find and download all the links of 60,000 images 1003 | if not arguments["silent_mode"]: 1004 | print("\nEverything downloaded!") 1005 | print("Total errors: " + str(total_errors)) 1006 | print("Total time taken: " + str(total_time) + " Seconds") 1007 | 1008 | if __name__ == "__main__": 1009 | main() 1010 | 1011 | # In[ ]: -------------------------------------------------------------------------------- /images/10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/images/10.png -------------------------------------------------------------------------------- /images/11.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/images/11.png -------------------------------------------------------------------------------- /images/12.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/images/12.png -------------------------------------------------------------------------------- /images/5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/images/5.png -------------------------------------------------------------------------------- /images/6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/images/6.png -------------------------------------------------------------------------------- /images/7.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/images/7.png -------------------------------------------------------------------------------- /images/8.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/images/8.png -------------------------------------------------------------------------------- /images/9.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/images/9.png -------------------------------------------------------------------------------- /matting_result/01_matting_01_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/01_matting_01_01.png -------------------------------------------------------------------------------- 
/matting_result/01_matting_25_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/01_matting_25_05.png -------------------------------------------------------------------------------- /matting_result/01_matting_30_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/01_matting_30_10.png -------------------------------------------------------------------------------- /matting_result/01_matting_30_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/01_matting_30_20.png -------------------------------------------------------------------------------- /matting_result/01_matting_40_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/01_matting_40_20.png -------------------------------------------------------------------------------- /matting_result/01_trimap_01_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/01_trimap_01_01.png -------------------------------------------------------------------------------- /matting_result/01_trimap_25_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/01_trimap_25_05.png -------------------------------------------------------------------------------- /matting_result/01_trimap_30_10.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/01_trimap_30_10.png -------------------------------------------------------------------------------- /matting_result/01_trimap_30_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/01_trimap_30_20.png -------------------------------------------------------------------------------- /matting_result/01_trimap_40_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/01_trimap_40_20.png -------------------------------------------------------------------------------- /matting_result/02_matting_01_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/02_matting_01_01.png -------------------------------------------------------------------------------- /matting_result/02_matting_25_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/02_matting_25_05.png -------------------------------------------------------------------------------- /matting_result/02_matting_30_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/02_matting_30_10.png -------------------------------------------------------------------------------- /matting_result/02_matting_30_20.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/02_matting_30_20.png -------------------------------------------------------------------------------- /matting_result/02_matting_40_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/02_matting_40_20.png -------------------------------------------------------------------------------- /matting_result/02_trimap_01_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/02_trimap_01_01.png -------------------------------------------------------------------------------- /matting_result/02_trimap_25_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/02_trimap_25_05.png -------------------------------------------------------------------------------- /matting_result/02_trimap_30_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/02_trimap_30_10.png -------------------------------------------------------------------------------- /matting_result/02_trimap_30_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/02_trimap_30_20.png -------------------------------------------------------------------------------- /matting_result/02_trimap_40_20.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/02_trimap_40_20.png -------------------------------------------------------------------------------- /matting_result/03_matting_01_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/03_matting_01_01.png -------------------------------------------------------------------------------- /matting_result/03_matting_25_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/03_matting_25_05.png -------------------------------------------------------------------------------- /matting_result/03_matting_30_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/03_matting_30_10.png -------------------------------------------------------------------------------- /matting_result/03_matting_30_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/03_matting_30_20.png -------------------------------------------------------------------------------- /matting_result/03_matting_40_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/03_matting_40_20.png -------------------------------------------------------------------------------- /matting_result/03_trimap_01_01.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/03_trimap_01_01.png -------------------------------------------------------------------------------- /matting_result/03_trimap_25_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/03_trimap_25_05.png -------------------------------------------------------------------------------- /matting_result/03_trimap_30_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/03_trimap_30_10.png -------------------------------------------------------------------------------- /matting_result/03_trimap_30_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/03_trimap_30_20.png -------------------------------------------------------------------------------- /matting_result/03_trimap_40_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/03_trimap_40_20.png -------------------------------------------------------------------------------- /matting_result/04_matting_01_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/04_matting_01_01.png -------------------------------------------------------------------------------- /matting_result/04_matting_25_05.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/04_matting_25_05.png -------------------------------------------------------------------------------- /matting_result/04_matting_30_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/04_matting_30_10.png -------------------------------------------------------------------------------- /matting_result/04_matting_30_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/04_matting_30_20.png -------------------------------------------------------------------------------- /matting_result/04_matting_40_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/04_matting_40_20.png -------------------------------------------------------------------------------- /matting_result/04_trimap_01_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/04_trimap_01_01.png -------------------------------------------------------------------------------- /matting_result/04_trimap_25_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/04_trimap_25_05.png -------------------------------------------------------------------------------- /matting_result/04_trimap_30_10.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/04_trimap_30_10.png -------------------------------------------------------------------------------- /matting_result/04_trimap_30_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/04_trimap_30_20.png -------------------------------------------------------------------------------- /matting_result/04_trimap_40_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/04_trimap_40_20.png -------------------------------------------------------------------------------- /matting_result/05_matting_01_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/05_matting_01_01.png -------------------------------------------------------------------------------- /matting_result/05_matting_25_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/05_matting_25_05.png -------------------------------------------------------------------------------- /matting_result/05_matting_30_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/05_matting_30_10.png -------------------------------------------------------------------------------- /matting_result/05_matting_30_20.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/05_matting_30_20.png -------------------------------------------------------------------------------- /matting_result/05_matting_40_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/05_matting_40_20.png -------------------------------------------------------------------------------- /matting_result/05_trimap_01_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/05_trimap_01_01.png -------------------------------------------------------------------------------- /matting_result/05_trimap_25_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/05_trimap_25_05.png -------------------------------------------------------------------------------- /matting_result/05_trimap_30_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/05_trimap_30_10.png -------------------------------------------------------------------------------- /matting_result/05_trimap_30_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/05_trimap_30_20.png -------------------------------------------------------------------------------- /matting_result/05_trimap_40_20.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/05_trimap_40_20.png -------------------------------------------------------------------------------- /matting_result/06_matting_01_01.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/06_matting_01_01.png -------------------------------------------------------------------------------- /matting_result/06_matting_25_05.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/06_matting_25_05.png -------------------------------------------------------------------------------- /matting_result/06_matting_30_10.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/06_matting_30_10.png -------------------------------------------------------------------------------- /matting_result/06_matting_30_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/06_matting_30_20.png -------------------------------------------------------------------------------- /matting_result/06_matting_40_20.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/matting_result/06_matting_40_20.png -------------------------------------------------------------------------------- /matting_result/06_trimap_01_01.png: 
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

class VGG16(nn.Module):
    """VGG16-style encoder-decoder for Deep Image Matting (Xu et al., CVPR 2017).

    The network consumes a 4-channel input — RGB image concatenated with a
    trimap scaled to [0, 1] — and predicts a single-channel alpha matte.

    Args:
        stage: 1 trains/runs only the encoder-decoder; 2 freezes the
            encoder-decoder and trains only the refinement head; 3 trains
            both the encoder-decoder and the refinement head.

    ``forward`` returns ``(pred_mattes, 0)`` when ``stage <= 1`` and
    ``(pred_mattes, pred_alpha)`` otherwise.
    """

    def __init__(self, stage):
        super(VGG16, self).__init__()
        self.stage = stage

        # Encoder: VGG16 conv stack; conv1_1 takes 4 channels (RGB + trimap).
        self.conv1_1 = nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, padding=1, bias=True)
        self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, padding=1, bias=True)
        self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, padding=1, bias=True)
        self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
        self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, padding=1, bias=True)
        self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, padding=1, bias=True)
        self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)
        self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)

        # Models released before 2019.09.09 used kernel_size=1 & padding=0 here;
        # keep this in mind when loading old checkpoints.
        self.conv6_1 = nn.Conv2d(512, 512, kernel_size=3, padding=1, bias=True)

        # Decoder: mirrors the encoder; spatial upsampling is done with
        # max_unpool2d in forward() using the pooling indices saved there.
        self.deconv6_1 = nn.Conv2d(512, 512, kernel_size=1, bias=True)
        self.deconv5_1 = nn.Conv2d(512, 512, kernel_size=5, padding=2, bias=True)
        self.deconv4_1 = nn.Conv2d(512, 256, kernel_size=5, padding=2, bias=True)
        self.deconv3_1 = nn.Conv2d(256, 128, kernel_size=5, padding=2, bias=True)
        self.deconv2_1 = nn.Conv2d(128, 64, kernel_size=5, padding=2, bias=True)
        self.deconv1_1 = nn.Conv2d(64, 64, kernel_size=5, padding=2, bias=True)

        # Final 1-channel raw alpha prediction.
        self.deconv1 = nn.Conv2d(64, 1, kernel_size=5, padding=2, bias=True)

        if stage == 2:
            # Stage-2 training: freeze the encoder-decoder.  The refinement
            # layers are created *after* this loop, so they remain trainable.
            for p in self.parameters():
                p.requires_grad = False

        if self.stage == 2 or self.stage == 3:
            # Refinement head: input is RGB + stage-1 alpha (4 channels).
            self.refine_conv1 = nn.Conv2d(4, 64, kernel_size=3, padding=1, bias=True)
            self.refine_conv2 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True)
            self.refine_conv3 = nn.Conv2d(64, 64, kernel_size=3, padding=1, bias=True)
            self.refine_pred = nn.Conv2d(64, 1, kernel_size=3, padding=1, bias=True)

    def forward(self, x):
        """Predict the alpha matte for a (N, 4, H, W) image+trimap tensor.

        H and W must be divisible by 32 (five 2x2 max-pool stages).  Pooling
        indices are kept so the decoder can unpool to the exact input size.
        """
        # Encoder stage 1
        x11 = F.relu(self.conv1_1(x))
        x12 = F.relu(self.conv1_2(x11))
        x1p, id1 = F.max_pool2d(x12, kernel_size=(2, 2), stride=(2, 2), return_indices=True)

        # Encoder stage 2
        x21 = F.relu(self.conv2_1(x1p))
        x22 = F.relu(self.conv2_2(x21))
        x2p, id2 = F.max_pool2d(x22, kernel_size=(2, 2), stride=(2, 2), return_indices=True)

        # Encoder stage 3
        x31 = F.relu(self.conv3_1(x2p))
        x32 = F.relu(self.conv3_2(x31))
        x33 = F.relu(self.conv3_3(x32))
        x3p, id3 = F.max_pool2d(x33, kernel_size=(2, 2), stride=(2, 2), return_indices=True)

        # Encoder stage 4
        x41 = F.relu(self.conv4_1(x3p))
        x42 = F.relu(self.conv4_2(x41))
        x43 = F.relu(self.conv4_3(x42))
        x4p, id4 = F.max_pool2d(x43, kernel_size=(2, 2), stride=(2, 2), return_indices=True)

        # Encoder stage 5
        x51 = F.relu(self.conv5_1(x4p))
        x52 = F.relu(self.conv5_2(x51))
        x53 = F.relu(self.conv5_3(x52))
        x5p, id5 = F.max_pool2d(x53, kernel_size=(2, 2), stride=(2, 2), return_indices=True)

        # Encoder stage 6 (bottleneck)
        x61 = F.relu(self.conv6_1(x5p))

        # Decoder stage 6
        x61d = F.relu(self.deconv6_1(x61))

        # Decoder stage 5
        x5d = F.max_unpool2d(x61d, id5, kernel_size=2, stride=2)
        x51d = F.relu(self.deconv5_1(x5d))

        # Decoder stage 4
        x4d = F.max_unpool2d(x51d, id4, kernel_size=2, stride=2)
        x41d = F.relu(self.deconv4_1(x4d))

        # Decoder stage 3
        x3d = F.max_unpool2d(x41d, id3, kernel_size=2, stride=2)
        x31d = F.relu(self.deconv3_1(x3d))

        # Decoder stage 2
        x2d = F.max_unpool2d(x31d, id2, kernel_size=2, stride=2)
        x21d = F.relu(self.deconv2_1(x2d))

        # Decoder stage 1
        x1d = F.max_unpool2d(x21d, id1, kernel_size=2, stride=2)
        x12d = F.relu(self.deconv1_1(x1d))

        # Squash the raw prediction into [0, 1].  torch.sigmoid replaces the
        # long-deprecated F.sigmoid (identical function, supported API).
        raw_alpha = self.deconv1(x12d)
        pred_mattes = torch.sigmoid(raw_alpha)

        if self.stage <= 1:
            return pred_mattes, 0

        # Stage 2/3: refinement head on RGB + stage-1 alpha.
        refine0 = torch.cat((x[:, :3, :, :], pred_mattes), 1)
        refine1 = F.relu(self.refine_conv1(refine0))
        refine2 = F.relu(self.refine_conv2(refine1))
        refine3 = F.relu(self.refine_conv3(refine2))
        # No sigmoid on the residual itself: applying it here makes the
        # refinement output collapse towards 0 during training.
        pred_refine = self.refine_pred(refine3)

        pred_alpha = torch.sigmoid(raw_alpha + pred_refine)

        return pred_mattes, pred_alpha
**The performance of stage1 is as good as the paper's**. When using a model released before this date, please change conv6 to kernel_size=1 and padding=0 in file core/net.py. 16 | * 2019.08.24: Fix cv2.dilate and cv2.erode iterations (previously left at the default of 1) and set the trimap dilate and erode parameters to match the test 1k trimap (k_size:2-5, iterations:5-15). Get [Stage1-SAD=57.1](https://github.com/huochaitiantang/pytorch-deep-image-matting/releases/download/v1.3/stage1_sad_57.1.pth). 17 | * 2019.07.05: Training with refine stage, fixed encoder-decoder. Get [Stage2-SAD=57.7](https://github.com/huochaitiantang/pytorch-deep-image-matting/releases/download/v1.2/stage2_norm_balance_sad_57.9.pth). 18 | * 2019.06.23: Training with alpha loss and composite loss. Get [Stage1-SAD=58.7](https://github.com/huochaitiantang/pytorch-deep-image-matting/releases/download/v1.2/stage1_norm_balance_sad_58.7.pth). 19 | * 2019.06.17: Training trimap generated by erode as well as dilate to balance the 0 and 1 values. Get [Stage0-SAD=62.0](https://github.com/huochaitiantang/pytorch-deep-image-matting/releases/download/v1.2/stage0_norm_balance_sad_62.0.pth). 20 | * 2019.04.22: Input image is normalized by mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] and fix crop error. Get [Stage0-SAD=69.1](https://github.com/huochaitiantang/pytorch-deep-image-matting/releases/download/v1.1/stage0_norm_e12_sad_69.1.pth). 21 | * 2018.12.14: Initial. Get [Stage0-SAD=72.9](https://github.com/huochaitiantang/pytorch-deep-image-matting/releases/download/v1.0/my_stage0_sad_72.9.pth). 22 | 23 | ## Installation 24 | * Python 2.7.12 or 3.6.5 25 | * Pytorch 0.4.0 or 1.0.0 26 | * OpenCV 3.4.3 27 | 28 | ## Demo 29 | Download our model to `./model` and run the following command. The predicted alpha mattes will then be located in the folder `./result/example/pred`. 30 | 31 | python core/demo.py 32 | 33 | ## Training 34 | ### Adobe-Deep-Image-Matting-Dataset 35 | Please contact the author for availability. 
36 | ### MSCOCO-2017-Train-Dataset 37 | [Download](http://images.cocodataset.org/zips/train2017.zip) 38 | ### PASCAL-VOC-2012 39 | [Download](http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar) 40 | ### Composite-Dataset 41 | Run the following command and the composite training and test dataset will locate in `Combined_Dataset/Training_set/comp` and `Combined_Dataset/Test_set/comp`, `Combined_Dataset` is the extracted folder of Adobe-Deep-Image-Matting-Dataset 42 | 43 | python tools/composite.py 44 | 45 | ### Pretrained-Model 46 | Run the following command and the pretrained model will locate in `./model/vgg_state_dict.pth` 47 | 48 | python tools/chg_model.py 49 | 50 | ### Start Training 51 | Run the following command and start the training 52 | 53 | bash train.sh 54 | 55 | ## Test 56 | Run the following command and start the test of Adobe-1k-Composite-Dataset 57 | 58 | bash deploy.sh 59 | 60 | ## Evaluation 61 | Please eval with [official Matlab Code](https://docs.google.com/uc?export=download&id=1euP9WmWve3c7EgOwRqgHfnp2H8NXH3OM). and get the SAD, MSE, Grad Conn. 62 | 63 | ### Visualization 64 | Running model is Stage1-SAD=57.1, please click to view whole images. 
65 | 66 | | Image | Trimap | Pred-Alpha | GT-Alpha | 67 | |---|---|---|---| 68 | |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/image/boy-1518482_1920_12.png) |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/trimap/boy-1518482_1920_12.png) |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/pred/boy-1518482_1920_12.png) |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/alpha/boy-1518482_1920_12.png) 69 | |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/image/dandelion-1335575_1920_1.png) |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/trimap/dandelion-1335575_1920_1.png) |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/pred/dandelion-1335575_1920_1.png) |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/alpha/dandelion-1335575_1920_1.png) 70 | |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/image/light-bulb-376930_1920_11.png) |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/trimap/light-bulb-376930_1920_11.png) |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/pred/light-bulb-376930_1920_11.png) |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/alpha/light-bulb-376930_1920_11.png) 71 | |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/image/sieve-641426_1920_1.png) |![image](https://github.com/huochaitiantang/pytorch-deep-image-matting/blob/master/result/example/trimap/sieve-641426_1920_1.png) 
import torch, sys, cv2, os
import numpy as np
import argparse
sys.path.insert(1, 'pytorch_deep_image_matting/core')
import net
from torchvision import transforms
from PIL import Image
import imageio


def model_dim_fn(cuda):
    """Build the stage-1 Deep Image Matting VGG16 and load its checkpoint.

    The model is moved to the GPU when ``cuda`` is truthy; otherwise the
    checkpoint is mapped to the CPU.
    """
    stage = 1
    resume = "pytorch_deep_image_matting/model/stage1_sad_54.4.pth"
    model = net.VGG16(stage)
    checkpoint = torch.load(resume) if cuda else torch.load(resume, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=True)
    return model.cuda() if cuda else model


def inference_once(model, scale_img, scale_trimap, cuda):
    """Run one forward pass and return the predicted HxW alpha matte.

    ``scale_img`` is an HxWx3 uint8 image and ``scale_trimap`` an HxW uint8
    trimap; both must already be sized for the network.
    """
    stage = 1

    # ToTensor maps 0-255 to 0-1 and HWC to CHW; Normalize applies the
    # standard ImageNet mean/std.
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # NOTE(review): the image is treated as BGR here (cv2 convention), while
    # matting_result reads with imageio (RGB) — confirm the intended order.
    rgb = cv2.cvtColor(scale_img, cv2.COLOR_BGR2RGB)
    tensor_img = preprocess(rgb).unsqueeze(0)

    gradient = compute_gradient(scale_img)
    tensor_trimap = torch.from_numpy(scale_trimap.astype(np.float32)[np.newaxis, np.newaxis, :, :])
    tensor_grad = torch.from_numpy(gradient.astype(np.float32)[np.newaxis, np.newaxis, :, :])

    if cuda:
        tensor_img = tensor_img.cuda()
        tensor_trimap = tensor_trimap.cuda()
        tensor_grad = tensor_grad.cuda()

    # 4-channel network input: normalized RGB + trimap scaled to [0, 1].
    input_t = torch.cat((tensor_img, tensor_trimap / 255.), 1)

    # Stage 1 uses the encoder-decoder output; stages 2/3 the refined one.
    coarse, refined = model(input_t)
    pred_mattes = coarse if stage <= 1 else refined

    result = pred_mattes.data
    if cuda:
        result = result.cpu()
    return result.numpy()[0, 0, :, :]


def compute_gradient(img):
    """Return a single-channel Sobel gradient-magnitude image (uint8)."""
    grad_x = cv2.Sobel(img, cv2.CV_16S, 1, 0)
    grad_y = cv2.Sobel(img, cv2.CV_16S, 0, 1)
    # Average the absolute x/y responses, then collapse to one channel.
    combined = cv2.addWeighted(cv2.convertScaleAbs(grad_x), 0.5,
                               cv2.convertScaleAbs(grad_y), 0.5, 0)
    return cv2.cvtColor(combined, cv2.COLOR_BGR2GRAY)


def inference_img_whole(max_size, model, img, trimap, cuda):
    """Predict the alpha matte for a whole image at its original size.

    The image/trimap are shrunk so each side is divisible by 32 (the network
    pools five times) and capped at ``max_size``, then the prediction is
    resized back to the original resolution.
    """
    h, w = img.shape[:2]
    net_h = min(max_size, h - (h % 32))
    net_w = min(max_size, w - (w % 32))

    scaled_img = cv2.resize(img, (net_w, net_h), interpolation=cv2.INTER_LINEAR)
    scaled_trimap = cv2.resize(trimap, (net_w, net_h), interpolation=cv2.INTER_LINEAR)

    pred_mattes = inference_once(model, scaled_img, scaled_trimap, cuda)

    # Back to the original size.
    origin_pred_mattes = cv2.resize(pred_mattes, (w, h), interpolation=cv2.INTER_LINEAR)
    assert(origin_pred_mattes.shape == trimap.shape)
    return origin_pred_mattes


def deep_image_matting_final(model, image, trimap, cuda):
    """Inference entry point: no grad tracking, max network side 1600 px."""
    torch.cuda.empty_cache()
    with torch.no_grad():
        return inference_img_whole(1600, model, image, trimap, cuda)


def composite4(fg, bg, a, w, h):
    """Alpha-composite ``fg`` over a random w x h crop of ``bg``.

    ``a`` is an HxW alpha map in [0, 1]; the result is a uint8 image.
    """
    fg = np.array(fg, np.float32)
    bg_h, bg_w = bg.shape[:2]
    # Random crop offset only when the background is larger than needed.
    x = np.random.randint(0, bg_w - w) if bg_w > w else 0
    y = np.random.randint(0, bg_h - h) if bg_h > h else 0
    bg = np.array(bg[y:y + h, x:x + w], np.float32)

    alpha = np.zeros((h, w, 1), np.float32)
    alpha[:, :, 0] = a
    blended = alpha * fg + (1 - alpha) * bg
    return blended.astype(np.uint8)


def matting_result(pic_input, tri_input, model, cuda, website = False):
    """Matte the input photo against a white background; return a PIL image.

    ``pic_input`` is an in-memory array when ``website`` is truthy, otherwise
    a path readable by imageio.  ``tri_input`` is the trimap (extra channels
    beyond the first are ignored).
    """
    img = pic_input if website else imageio.imread(pic_input)[:, :, :3]

    trimap = tri_input
    if len(trimap.shape) > 2:
        trimap = trimap[:, :, 0]

    alpha = deep_image_matting_final(model, img, trimap, cuda)
    # Pin the alpha where the trimap is certain background/foreground.
    alpha[trimap == 0] = 0.0
    alpha[trimap == 255] = 1.0

    h, w = img.shape[:2]
    white_bg = np.array(np.full((h, w, 3), 255), dtype='uint8')
    return Image.fromarray(composite4(img, white_bg, alpha, w, h))
https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/sample/02.jpg -------------------------------------------------------------------------------- /sample/03.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/sample/03.jpg -------------------------------------------------------------------------------- /sample/04.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/sample/04.jpg -------------------------------------------------------------------------------- /sample/05.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/sample/05.jpg -------------------------------------------------------------------------------- /sample/06.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/sample/06.jpg -------------------------------------------------------------------------------- /segmentation_result/01_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/segmentation_result/01_mask.png -------------------------------------------------------------------------------- /segmentation_result/01_seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/segmentation_result/01_seq.png -------------------------------------------------------------------------------- /segmentation_result/02_mask.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/segmentation_result/02_mask.png -------------------------------------------------------------------------------- /segmentation_result/02_seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/segmentation_result/02_seq.png -------------------------------------------------------------------------------- /segmentation_result/03_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/segmentation_result/03_mask.png -------------------------------------------------------------------------------- /segmentation_result/03_seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/segmentation_result/03_seq.png -------------------------------------------------------------------------------- /segmentation_result/04_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/segmentation_result/04_mask.png -------------------------------------------------------------------------------- /segmentation_result/04_seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/segmentation_result/04_seq.png -------------------------------------------------------------------------------- /segmentation_result/05_mask.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/segmentation_result/05_mask.png -------------------------------------------------------------------------------- /segmentation_result/05_seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/segmentation_result/05_seq.png -------------------------------------------------------------------------------- /segmentation_result/06_mask.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/segmentation_result/06_mask.png -------------------------------------------------------------------------------- /segmentation_result/06_seq.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/segmentation_result/06_seq.png -------------------------------------------------------------------------------- /segmentation_result/1. 34423890523_f8c5b3741c.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/openaifab/hair/5dc587d3a70d5c8a656ea280758d6e05ebea9d76/segmentation_result/1. 
import cv2
import numpy as np


def trimap(image, size, erosion=False):
    """Generate a trimap (0=background, 128=unknown, 255=foreground) from a mask.

    Args:
        image: single-channel uint8 mask; foreground pixels are non-zero.
        size: half-width of the unknown band; the dilation kernel is
            (2*size+1) x (2*size+1).
        erosion: False to skip, otherwise the number of erosion iterations
            applied to the mask before dilation.

    Returns:
        Array of the same shape containing only the values 0, 128 and 255.
    """
    # BUGFIX(review): the original called sys.exit() without importing sys,
    # which raised NameError instead of exiting on the fatal-erosion path.
    import sys

    pixels = 2 * size + 1  # odd-sized kernel so the unknown band is symmetric
    kernel = np.ones((pixels, pixels), np.uint8)

    if erosion is not False:
        erosion = int(erosion)
        erosion_kernel = np.ones((3, 3), np.uint8)
        image = cv2.erode(image, erosion_kernel, iterations=erosion)
        image = np.where(image > 0, 255, image)  # re-binarize: any gray -> white
        # Guard: too many erosions can wipe out the whole foreground.
        if cv2.countNonZero(image) == 0:
            print("ERROR: foreground has been entirely eroded")
            sys.exit()

    dilation = cv2.dilate(image, kernel, iterations=1)

    dilation = np.where(dilation == 255, 128, dilation)  # dilated area -> GRAY
    # BUGFIX(review): the original computed
    #   remake = np.where(dilation != 128, 0, dilation)
    # and then immediately overwrote `remake` reading from `dilation` again,
    # so that assignment was dead code — removed; output is unchanged.
    remake = np.where(image > 128, 200, dilation)  # tag sure-foreground as 200

    remake = np.where(remake < 128, 0, remake)     # sub-gray noise -> background
    remake = np.where(remake > 200, 0, remake)     # out-of-range -> background
    remake = np.where(remake == 200, 255, remake)  # sure-foreground -> WHITE

    # Force every remaining pixel to one of the three canonical values.
    # (Vectorized replacement of the original O(rows*cols) Python loop.)
    remake = np.where((remake != 0) & (remake != 255), 128, remake)

    print("generate trimap(size: " + str(size) + ", erosion: " + str(erosion) + ")")
    return remake
--------------------------------------------------------------------------------