├── LICENSE
├── README.md
├── classification
│   └── pruner.py
├── examples
│   ├── cifar100_prune.py
│   ├── cifar100_resnet.py
│   ├── cifar_resnet.py
│   └── docs
│       ├── concat.png
│       ├── conv-conv.png
│       ├── conv-fc.png
│       ├── dep1.png
│       ├── dep2.png
│       ├── dep3.png
│       ├── residual.png
│       └── split.png
├── prune_demo.py
├── quant_demo.py
├── requirements.txt
├── setup.py
├── torch_pruning
│   ├── __init__.py
│   ├── autoslim.py
│   ├── autoslim_test.py
│   ├── dependency.py
│   ├── flops_counter.py
│   ├── prune
│   │   ├── __init__.py
│   │   ├── __pycache__
│   │   │   ├── __init__.cpython-36.pyc
│   │   │   ├── __init__.cpython-38.pyc
│   │   │   ├── structured.cpython-36.pyc
│   │   │   ├── structured.cpython-38.pyc
│   │   │   ├── unstructured.cpython-36.pyc
│   │   │   └── unstructured.cpython-38.pyc
│   │   ├── structured.py
│   │   └── unstructured.py
│   ├── resnet_small.py
│   ├── sensitivity_analysis.py
│   └── utils.py
└── torch_quanting
    ├── __init__.py
    ├── autoquant.py
    └── quantizer.py
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 
49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. 
"Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 
174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year>  <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail. 
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program>  Copyright (C) <year>  <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # pytorch-Autoslim2.0
2 | 
3 | A PyTorch toolkit for fully automatic structured neural network pruning
4 | 
5 | ## 1 Introduction
6 | 
7 | ### ① Architecture
8 | 
9 | **User layer**: a pruning tool anyone can use; fully automated pruning takes just two lines of code
10 | 
11 | **Middle layer**: a unified interface that lets developers wrap their own SOTA pruning algorithms and keep extending the toolkit
12 | 
13 | **Bottom layer**: automatically analyses the network structure and builds the pruning dependency graph; a minimal sketch of what that bookkeeping involves follows below
14 | 
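To see why a dependency graph is needed at all, consider removing a single output channel from a convolution: the following BatchNorm and the next convolution's input side must shrink with it, or the network no longer runs. Below is a minimal, hand-rolled sketch of that coupling, using hypothetical toy layers rather than this library's own code:

```python
import torch
import torch.nn as nn

# Toy chain conv1 -> bn1 -> conv2 (assumed for illustration only).
conv1 = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn1 = nn.BatchNorm2d(8)
conv2 = nn.Conv2d(8, 16, 3, padding=1, bias=False)

keep = torch.tensor([i for i in range(8) if i != 2])  # drop filter 2 of conv1

# Pruning one output channel of conv1 forces three coupled edits:
conv1.weight = nn.Parameter(conv1.weight.data[keep])     # 1. conv1 output dim
conv1.out_channels = len(keep)
bn1.weight = nn.Parameter(bn1.weight.data[keep])         # 2. bn1 parameters/stats
bn1.bias = nn.Parameter(bn1.bias.data[keep])
bn1.running_mean = bn1.running_mean[keep]
bn1.running_var = bn1.running_var[keep]
bn1.num_features = len(keep)
conv2.weight = nn.Parameter(conv2.weight.data[:, keep])  # 3. conv2 *input* dim
conv2.in_channels = len(keep)

x = torch.randn(1, 3, 32, 32)
print(conv2(torch.relu(bn1(conv1(x)))).shape)  # torch.Size([1, 16, 32, 32])
```

The `torch_pruning.dependency` module discovers these couplings for an arbitrary network, including the residual, concat and split patterns illustrated in `examples/docs/`, so the user never performs this surgery by hand.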
15 | ## 2 Support
16 | 
17 | ### ① Supported Models
18 | 
19 | | Model type | Supported | Tested |
20 | | --- | --- | --- |
21 | | Classification | √ | AlexNet, VGG, the ResNet family, etc. |
22 | | Detection | √ | CenterNet, the YOLO family, etc. |
23 | | Segmentation | √ | testing in progress |
24 | 
25 | ### ② Pruning Algorithms
26 | 
27 | | Function | Pruning algorithm |
28 | | --- | --- |
29 | | l1_norm_pruning | [Pruning Filters for Efficient ConvNets](https://arxiv.org/abs/1608.08710) |
30 | | l2_norm_pruning | [Pruning Filters for Efficient ConvNets](https://arxiv.org/abs/1608.08710) |
31 | | fpgm_pruning | [Filter Pruning via Geometric Median for Deep Convolutional Neural Networks Acceleration](https://arxiv.org/abs/1811.00250) |
32 | 
33 | Some adjustments were made on top of the original pruning algorithms (the selection criteria are sketched below), and more algorithms will be supported in future updates.
34 | 
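The two selection criteria behind these functions fit in a few lines. The sketch below is illustrative only, standalone helpers applied to a made-up weight tensor rather than the package's internal implementation:

```python
import torch

def l1_scores(weight: torch.Tensor) -> torch.Tensor:
    # One score per filter: the L1 norm of its weights. Filters with the
    # smallest norms are treated as least important (Li et al., 2017).
    return weight.abs().sum(dim=(1, 2, 3))

def fpgm_scores(weight: torch.Tensor) -> torch.Tensor:
    # FPGM scores each filter by its total distance to all other filters in
    # the layer; filters near the geometric median are the most replaceable
    # (He et al., 2019), so low scores again mark pruning candidates.
    flat = weight.flatten(1)                    # [out_channels, in * k * k]
    return torch.cdist(flat, flat, p=2).sum(dim=1)

w = torch.randn(16, 8, 3, 3)                    # hypothetical conv weight
k = int(w.shape[0] * 0.5)                       # 50% compression target
prune_idx = torch.argsort(l1_scores(w))[:k]     # the 8 weakest filters
print(sorted(prune_idx.tolist()))
```

In short, l1_norm_pruning and l2_norm_pruning drop the filters with the smallest norms, while fpgm_pruning drops the filters closest to the geometric median of the layer, that is, the most redundant ones.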
35 | ## 3 Installation
36 | 
37 | ```bash
38 | pip install -e ./
39 | ```
40 | 
41 | ## 4 Usage
42 | 
43 | **The model can come from torchvision or be any model you build yourself in PyTorch.**
44 | 
45 | ### Automatic Pruning
46 | 
47 | ```python
48 | import torch_pruning as pruning
49 | from torchvision.models import resnet18
50 | import torch
51 | 
52 | # build the model
53 | model = resnet18()
54 | flops_raw, params_raw = pruning.get_model_complexity_info(
55 |     model, (3, 224, 224), as_strings=True, print_per_layer_stat=False)
56 | print('-[INFO] before pruning flops: ' + flops_raw)
57 | print('-[INFO] before pruning params: ' + params_raw)
58 | # choose the pruning criterion
59 | mod = 'fpgm'
60 | 
61 | # create the pruning engine
62 | slim = pruning.Autoslim(model, inputs=torch.randn(
63 |     1, 3, 224, 224), compression_ratio=0.5)
64 | 
65 | if mod == 'fpgm':
66 |     config = {
67 |         'layer_compression_ratio': None,
68 |         'norm_rate': 1.0, 'prune_shortcut': 1,
69 |         'dist_type': 'l1', 'pruning_func': 'fpgm'
70 |     }
71 | elif mod == 'l1':
72 |     config = {
73 |         'layer_compression_ratio': None,
74 |         'norm_rate': 1.0, 'prune_shortcut': 1,
75 |         'global_pruning': False, 'pruning_func': 'l1'
76 |     }
77 | slim.base_prunging(config)
78 | flops_new, params_new = pruning.get_model_complexity_info(
79 |     model, (3, 224, 224), as_strings=True, print_per_layer_stat=False)
80 | print('\n-[INFO] after pruning flops: ' + flops_new)
81 | print('-[INFO] after pruning params: ' + params_new)
82 | ```
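Because `base_prunging` (so spelled in this library) rewrites layer shapes in place, a checkpoint saved as a bare `state_dict` will no longer load into a freshly constructed `resnet18()`. Continuing the snippet above, the simplest approach, and the one the bundled examples use, is to save the whole module; the file name here is arbitrary:

```python
torch.save(model, 'resnet18-pruned.pth')   # keeps the pruned architecture and weights together
# Later; on PyTorch 2.6+ this additionally needs weights_only=False:
model = torch.load('resnet18-pruned.pth')  # no need to rebuild the class by hand
```

Expect an accuracy drop immediately after pruning; a short fine-tuning run (see `examples/cifar100_prune.py`) normally recovers most of it.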
83 | 
84 | ## 5 Examples
85 | 
86 | ### ① ResNet on CIFAR
87 | 
88 | #### Training
89 | 
90 | ```bash
91 | python prune_resnet18_cifar10.py --mode train --round 0
92 | ```
93 | 
94 | #### Pruning
95 | 
96 | ```bash
97 | python prune_resnet18_cifar10.py --mode prune --round 1 --total_epochs 60
98 | ```
99 | 
100 | #### Fine-tuning
101 | 
102 | ```bash
103 | python cifar100_prune.py --mode train --round 2 --total_epochs 10 --batch_size 512
104 | ```
105 | 
106 | ## 6 Acknowledgements
107 | 
108 | Thanks to the following repositories:
109 | 
110 | [https://github.com/TD-wzw/Autoslim](https://github.com/TD-wzw/Autoslim)
111 | 
112 | [https://github.com/microsoft/nni](https://github.com/microsoft/nni)
--------------------------------------------------------------------------------
/classification/pruner.py:
--------------------------------------------------------------------------------
1 | from torchvision import models
2 | import torch
3 | 
4 | # quick sanity check that torchvision exposes the model we want to prune
5 | print(hasattr(models, 'resnet18'))
--------------------------------------------------------------------------------
/examples/cifar100_prune.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | # Put the project root on sys.path before `import torch_pruning`, so the
4 | # script also runs without `pip install -e ./`.
5 | sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))
6 | import argparse
7 | import numpy as np
8 | import torch
9 | import torch.nn as nn
10 | import torch.nn.functional as F
11 | from torchvision import transforms
12 | from torchvision.datasets import CIFAR100
13 | import torch_pruning as pruning
14 | import cifar100_resnet as resnet
15 | from cifar_resnet import ResNet18
16 | from resnet import resnext50_32x4d, resnext101_32x8d  # expects a local resnet.py with the ResNeXt models
17 | parser = argparse.ArgumentParser()
18 | 
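# Three workflows: 'train' fits the network from scratch, 'prune' loads the
# previous round's checkpoint, prunes and fine-tunes it, and 'test' only
# evaluates a saved checkpoint.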
parser.add_argument('--mode', type=str, required=True, 19 | choices=['train', 'prune', 'test']) 20 | parser.add_argument('--batch_size', type=int, default=256) 21 | parser.add_argument('--verbose', action='store_true', default=False) 22 | parser.add_argument('--total_epochs', type=int, default=200) 23 | parser.add_argument('--step_size', type=int, default=70) 24 | parser.add_argument('--round', type=int, default=1) 25 | 26 | args = parser.parse_args() 27 | 28 | 29 | def get_dataloader(): 30 | 31 | train_loader = torch.utils.data.DataLoader( 32 | CIFAR100('./data', train=True, transform=transforms.Compose([ 33 | transforms.RandomCrop(32, padding=4), 34 | transforms.RandomHorizontalFlip(), 35 | transforms.ToTensor(), 36 | ]), download=True), batch_size=args.batch_size, num_workers=2) 37 | test_loader = torch.utils.data.DataLoader( 38 | CIFAR100('./data', train=False, transform=transforms.Compose([ 39 | transforms.ToTensor(), 40 | ]), download=True), batch_size=args.batch_size, num_workers=2) 41 | return train_loader, test_loader 42 | 43 | 44 | def eval(model, test_loader): 45 | correct = 0 46 | total = 0 47 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 48 | model.to(device) 49 | model.eval() 50 | with torch.no_grad(): 51 | for i, (img, target) in enumerate(test_loader): 52 | img = img.to(device) 53 | out = model(img) 54 | pred = out.max(1)[1].detach().cpu().numpy() 55 | target = target.cpu().numpy() 56 | correct += (pred == target).sum() 57 | total += len(target) 58 | return correct / total 59 | 60 | 61 | def train_model(model, train_loader, test_loader): 62 | 63 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 64 | optimizer = torch.optim.SGD( 65 | model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4) 66 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.step_size, 0.1) 67 | model.to(device) 68 | 69 | best_acc = -1 70 | for epoch in range(args.total_epochs): 71 | model.train() 72 | for i, (img, target) in enumerate(train_loader): 73 | img, target = img.to(device), target.to(device) 74 | optimizer.zero_grad() 75 | out = model(img) 76 | loss = F.cross_entropy(out, target) 77 | loss.backward() 78 | optimizer.step() 79 | if i % 10 == 0 and args.verbose: 80 | print("Epoch %d/%d, iter %d/%d, loss=%.4f" % 81 | (epoch, args.total_epochs, i, len(train_loader), loss.item())) 82 | model.eval() 83 | acc = eval(model, test_loader) 84 | print("Epoch %d/%d, Acc=%.4f" % (epoch, args.total_epochs, acc)) 85 | if best_acc < acc: 86 | torch.save(model, 'resnet18-round%d.pth' % (args.round)) 87 | best_acc = acc 88 | scheduler.step() 89 | print("Best Acc=%.4f" % (best_acc)) 90 | 91 | 92 | def prune_model_with_shortcut(model): 93 | model.cpu() 94 | config = { 95 | 'layer_compression_ratio': None, 96 | 'norm_rate': 1.0, 'prune_shortcut': 1, 97 | 'global_pruning': False, 'pruning_func': 'l1' 98 | } 99 | slim = pruning.Autoslim(model, inputs=torch.randn( 100 | 1, 3, 32, 32), compression_ratio=0.5) 101 | slim.base_prunging(config) 102 | return model 103 | 104 | 105 | def prune_model_without_shortcut(model): 106 | model.cpu() 107 | slim = pruning.Autoslim(model, inputs=torch.randn( 108 | 1, 3, 32, 32), compression_ratio=0) 109 | 110 | # print(model) 111 | layer_compression_rate = {5: 0.2, 11: 0.5, 18: 0.5, 112 | 26: 0.5, 33: 0.75, 41: 0.75, 48: 0.875, 56: 0.875} 113 | slim.l1_norm_pruning(layer_compression_ratio=layer_compression_rate) 114 | # print(model) 115 | return model 116 | 117 | 118 | def prune_model_mixed(model): 119 | 120 | model.cpu() 121 | 
    config = {
122 |         'layer_compression_ratio': None,
123 |         'norm_rate': 1.0, 'prune_shortcut': 1,
124 |         'global_pruning': False, 'pruning_func': 'l1'
125 |     }
126 |     slim = pruning.Autoslim(model, inputs=torch.randn(
127 |         1, 3, 32, 32), compression_ratio=0.5)
128 |     slim.base_prunging(config)
129 |     return model
130 | 
131 | 
132 | def main():
133 |     train_loader, test_loader = get_dataloader()
134 |     if args.mode == 'train':
135 |         args.round = 0
136 |         model = resnext50_32x4d(classnum=100)
137 |         train_model(model, train_loader, test_loader)
138 |     elif args.mode == 'prune':
139 |         previous_ckpt = 'resnet18-round%d.pth' % (args.round-1)
140 |         print("Pruning round %d, load model from %s" %
141 |               (args.round, previous_ckpt))
142 |         model = torch.load(previous_ckpt)
143 |         params_ori = sum([np.prod(p.size()) for p in model.parameters()])
144 |         print("Number of ori_Parameters: %.1fM" % (params_ori/1e6))
145 |         # prune_model_with_shortcut(model)
146 |         # prune_model_without_shortcut(model)
147 |         prune_model_mixed(model)
148 |         print(model)
149 |         params = sum([np.prod(p.size()) for p in model.parameters()])
150 |         print("Number of Parameters: %.1fM" % (params/1e6))
151 |         train_model(model, train_loader, test_loader)
152 | 
153 |     elif args.mode == 'test':
154 |         ckpt = 'resnet18-round%d.pth' % (args.round)
155 |         print("Load model from %s" % (ckpt))
156 |         model = torch.load(ckpt)
157 |         params = sum([np.prod(p.size()) for p in model.parameters()])
158 |         print("Number of Parameters: %.1fM" % (params/1e6))
159 |         acc = eval(model, test_loader)
160 |         print("Acc=%.4f\n" % (acc))
161 | 
162 | 
163 | if __name__ == '__main__':
164 |     main()
--------------------------------------------------------------------------------
/examples/cifar100_resnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | 
5 | 
6 | class BasicBlock(nn.Module):
7 |     expansion = 1
8 | 
9 |     def __init__(self, in_planes, planes, stride=1):
10 |         super(BasicBlock, self).__init__()
11 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
12 |         self.bn1 = nn.BatchNorm2d(planes)
13 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
14 |         self.bn2 = nn.BatchNorm2d(planes)
15 | 
16 |         self.shortcut = nn.Sequential()
17 |         if stride != 1 or in_planes != self.expansion*planes:
18 |             self.shortcut = nn.Sequential(
19 |                 nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
20 |                 nn.BatchNorm2d(self.expansion*planes)
21 |             )
22 | 
23 |     def forward(self, x):
24 |         out = F.relu(self.bn1(self.conv1(x)))
25 |         out = self.bn2(self.conv2(out))
26 |         out += self.shortcut(x)
27 |         out = F.relu(out)
28 |         return out
29 | 
30 | 
31 | class Bottleneck(nn.Module):
32 |     expansion = 4
33 | 
34 |     def __init__(self, in_planes, planes, stride=1):
35 |         super(Bottleneck, self).__init__()
36 |         self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
37 |         self.bn1 = nn.BatchNorm2d(planes)
38 |         self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
39 |         self.bn2 = nn.BatchNorm2d(planes)
40 |         self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
41 |         self.bn3 = nn.BatchNorm2d(self.expansion*planes)
42 | 
43 |         self.shortcut = nn.Sequential()
44 |         if stride != 1 or in_planes != self.expansion*planes:
45 | 
self.shortcut = nn.Sequential( 46 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 47 | nn.BatchNorm2d(self.expansion*planes) 48 | ) 49 | 50 | def forward(self, x): 51 | out = F.relu(self.bn1(self.conv1(x))) 52 | out = F.relu(self.bn2(self.conv2(out))) 53 | out = self.bn3(self.conv3(out)) 54 | out += self.shortcut(x) 55 | out = F.relu(out) 56 | return out 57 | 58 | 59 | class ResNet(nn.Module): 60 | def __init__(self, block, num_blocks, num_classes=100): 61 | super(ResNet, self).__init__() 62 | self.in_planes = 64 63 | 64 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(64) 66 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 67 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=1) 68 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 69 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 70 | self.linear = nn.Linear(512*block.expansion, num_classes) 71 | 72 | for m in self.modules(): 73 | if isinstance(m, nn.Conv2d): 74 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 75 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 76 | nn.init.constant_(m.weight, 1) 77 | nn.init.constant_(m.bias, 0) 78 | 79 | def _make_layer(self, block, planes, num_blocks, stride): 80 | strides = [stride] + [1]*(num_blocks-1) 81 | layers = [] 82 | for stride in strides: 83 | layers.append(block(self.in_planes, planes, stride)) 84 | self.in_planes = planes * block.expansion 85 | return nn.Sequential(*layers) 86 | 87 | def forward(self, x, out_feature=False): 88 | out = F.relu(self.bn1(self.conv1(x))) 89 | out = self.layer1(out) 90 | out = self.layer2(out) 91 | out = self.layer3(out) 92 | out = self.layer4(out) 93 | out = F.avg_pool2d(out, 4) 94 | feature = out.view(out.size(0), -1) 95 | out = self.linear(feature) 96 | if out_feature == False: 97 | return out 98 | else: 99 | return out,feature 100 | 101 | 102 | def ResNet18(num_classes=100): 103 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 104 | 105 | def ResNet34(num_classes=100): 106 | return ResNet(BasicBlock, [3,4,6,3], num_classes) 107 | 108 | def ResNet50(num_classes=100): 109 | return ResNet(Bottleneck, [3,4,6,3], num_classes) 110 | 111 | def ResNet101(num_classes=100): 112 | return ResNet(Bottleneck, [3,4,23,3], num_classes) 113 | 114 | def ResNet152(num_classes=100): 115 | return ResNet(Bottleneck, [3,8,36,3], num_classes) 116 | -------------------------------------------------------------------------------- /examples/cifar_resnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class BasicBlock(nn.Module): 7 | expansion = 1 8 | 9 | def __init__(self, in_planes, planes, stride=1): 10 | super(BasicBlock, self).__init__() 11 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 12 | self.bn1 = nn.BatchNorm2d(planes) 13 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False) 14 | self.bn2 = nn.BatchNorm2d(planes) 15 | 16 | self.shortcut = nn.Sequential() 17 | if stride != 1 or in_planes != self.expansion*planes: 18 | self.shortcut = nn.Sequential( 19 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 20 | nn.BatchNorm2d(self.expansion*planes) 21 | ) 22 | 23 | def forward(self, x): 24 | out = 
F.relu(self.bn1(self.conv1(x))) 25 | out = self.bn2(self.conv2(out)) 26 | out += self.shortcut(x) 27 | out = F.relu(out) 28 | return out 29 | 30 | 31 | class Bottleneck(nn.Module): 32 | expansion = 4 33 | 34 | def __init__(self, in_planes, planes, stride=1): 35 | super(Bottleneck, self).__init__() 36 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False) 37 | self.bn1 = nn.BatchNorm2d(planes) 38 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False) 39 | self.bn2 = nn.BatchNorm2d(planes) 40 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False) 41 | self.bn3 = nn.BatchNorm2d(self.expansion*planes) 42 | 43 | self.shortcut = nn.Sequential() 44 | if stride != 1 or in_planes != self.expansion*planes: 45 | self.shortcut = nn.Sequential( 46 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False), 47 | nn.BatchNorm2d(self.expansion*planes) 48 | ) 49 | 50 | def forward(self, x): 51 | out = F.relu(self.bn1(self.conv1(x))) 52 | out = F.relu(self.bn2(self.conv2(out))) 53 | out = self.bn3(self.conv3(out)) 54 | out += self.shortcut(x) 55 | out = F.relu(out) 56 | return out 57 | 58 | 59 | class ResNet(nn.Module): 60 | def __init__(self, block, num_blocks, num_classes=10): 61 | super(ResNet, self).__init__() 62 | self.in_planes = 64 63 | 64 | self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False) 65 | self.bn1 = nn.BatchNorm2d(64) 66 | self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1) 67 | self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=1) 68 | self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2) 69 | self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2) 70 | self.linear = nn.Linear(512*block.expansion, num_classes) 71 | 72 | for m in self.modules(): 73 | if isinstance(m, nn.Conv2d): 74 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 75 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 76 | nn.init.constant_(m.weight, 1) 77 | nn.init.constant_(m.bias, 0) 78 | 79 | def _make_layer(self, block, planes, num_blocks, stride): 80 | strides = [stride] + [1]*(num_blocks-1) 81 | layers = [] 82 | for stride in strides: 83 | layers.append(block(self.in_planes, planes, stride)) 84 | self.in_planes = planes * block.expansion 85 | return nn.Sequential(*layers) 86 | 87 | def forward(self, x, out_feature=False): 88 | out = F.relu(self.bn1(self.conv1(x))) 89 | out = self.layer1(out) 90 | out = self.layer2(out) 91 | out = self.layer3(out) 92 | out = self.layer4(out) 93 | out = F.avg_pool2d(out, 4) 94 | feature = out.view(out.size(0), -1) 95 | out = self.linear(feature) 96 | if out_feature == False: 97 | return out 98 | else: 99 | return out,feature 100 | 101 | 102 | def ResNet18(num_classes=10): 103 | return ResNet(BasicBlock, [2,2,2,2], num_classes) 104 | 105 | def ResNet34(num_classes=10): 106 | return ResNet(BasicBlock, [3,4,6,3], num_classes) 107 | 108 | def ResNet50(num_classes=10): 109 | return ResNet(Bottleneck, [3,4,6,3], num_classes) 110 | 111 | def ResNet101(num_classes=10): 112 | return ResNet(Bottleneck, [3,4,23,3], num_classes) 113 | 114 | def ResNet152(num_classes=10): 115 | return ResNet(Bottleneck, [3,8,36,3], num_classes) 116 | -------------------------------------------------------------------------------- /examples/docs/concat.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/examples/docs/concat.png
--------------------------------------------------------------------------------
/examples/docs/conv-conv.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/examples/docs/conv-conv.png
--------------------------------------------------------------------------------
/examples/docs/conv-fc.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/examples/docs/conv-fc.png
--------------------------------------------------------------------------------
/examples/docs/dep1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/examples/docs/dep1.png
--------------------------------------------------------------------------------
/examples/docs/dep2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/examples/docs/dep2.png
--------------------------------------------------------------------------------
/examples/docs/dep3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/examples/docs/dep3.png
--------------------------------------------------------------------------------
/examples/docs/residual.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/examples/docs/residual.png
--------------------------------------------------------------------------------
/examples/docs/split.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/examples/docs/split.png
--------------------------------------------------------------------------------
/prune_demo.py:
--------------------------------------------------------------------------------
1 | import torch_pruning as pruning
2 | from torchvision.models import resnet18
3 | import torch
4 | 
5 | # build the model
6 | model = resnet18()
7 | flops_raw, params_raw = pruning.get_model_complexity_info(
8 |     model, (3, 224, 224), as_strings=True, print_per_layer_stat=False)
9 | print('-[INFO] before pruning flops: ' + flops_raw)
10 | print('-[INFO] before pruning params: ' + params_raw)
11 | # choose the pruning criterion
12 | mod = 'fpgm'
13 | 
14 | # create the pruning engine
15 | slim = pruning.Autoslim(model, inputs=torch.randn(
16 |     1, 3, 224, 224), compression_ratio=0.5)
17 | 
18 | if mod == 'fpgm':
19 |     config = {
20 |         'layer_compression_ratio': None,
21 |         'norm_rate': 1.0, 'prune_shortcut': 1,
22 |         'dist_type': 'l1', 'pruning_func': 'fpgm'
23 |     }
24 | elif mod == 'l1':
25 |     config = {
26 |         'layer_compression_ratio': None,
27 |         'norm_rate': 1.0, 'prune_shortcut': 1,
28 |         'global_pruning': False, 'pruning_func': 'l1'
29 |     }
30 | slim.base_prunging(config)
31 | flops_new, params_new = pruning.get_model_complexity_info(
32 |     model, (3, 224, 224), as_strings=True, print_per_layer_stat=False)
33 | print('\n-[INFO] after pruning flops: ' + flops_new)
34 | print('-[INFO] after pruning params: ' + params_new)
35 | 
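The demo above drives every convolution with one global `compression_ratio`. A per-layer schedule can be passed instead through `layer_compression_ratio`, whose keys are module indices as enumerated by `Autoslim.index_of_layer()`. A minimal sketch (the flat 0.3 ratio and the variable names are illustrative assumptions, not part of the demo):

    # hypothetical per-layer schedule; keys come from index_of_layer()
    idx2conv = slim.index_of_layer()       # {module_index: conv module}
    ratios = {i: 0.3 for i in idx2conv}    # assumed: prune 30% of filters in every conv
    config = {
        'layer_compression_ratio': ratios,
        'norm_rate': 1.0, 'prune_shortcut': 1,
        'dist_type': 'l1', 'pruning_func': 'fpgm'
    }
    slim.base_prunging(config)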
--------------------------------------------------------------------------------
/quant_demo.py:
--------------------------------------------------------------------------------
1 | import torch_pruning as pruning
2 | import torch
3 | from torch_quanting import AutoQuant
4 | from torchvision.models import resnet18
5 | model = resnet18()
6 | flops_raw, params_raw = pruning.get_model_complexity_info(
7 |     model, (3, 224, 224), as_strings=True, print_per_layer_stat=False)
8 | print('\n-[INFO] before quantization flops: ' + flops_raw)
9 | print('-[INFO] before quantization params: ' + params_raw)
10 | torch.save(model, 'a.pth')
11 | config_list = [{
12 |     'quant_types': ['weight'],
13 |     'quant_bits': {
14 |         'weight': 8,
15 |     },  # a plain `int` also works here, since every entry in `quant_types` uses the same bit width
16 |     'op_types': ['Conv2d', 'Linear']
17 | }]
18 | quantizer = AutoQuant(model, config_list)
19 | model = quantizer.compress()
20 | 
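As the comment in quant_demo.py notes, when every entry in `quant_types` shares one bit width, the `quant_bits` mapping can be collapsed to a plain `int`. A minimal sketch under that assumption (same `AutoQuant` call as the demo above):

    config_list = [{
        'quant_types': ['weight'],
        'quant_bits': 8,                  # shorthand for {'weight': 8}
        'op_types': ['Conv2d', 'Linear']
    }]
    quantizer = AutoQuant(model, config_list)
    model = quantizer.compress()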
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch>=1.0
2 | # numpy and scipy are imported by torch_pruning.autoslim
3 | numpy
4 | scipy
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | import setuptools
2 | 
3 | with open("README.md", "r") as fh:
4 |     long_description = fh.read()
5 | 
6 | setuptools.setup(
7 |     name="torch-pruning",
8 |     version="1.0",
9 |     author="liangyingping",
10 |     author_email="1691608003@qq.com",
11 |     description="A PyTorch toolkit for automatic model pruning",
12 |     long_description=long_description,
13 |     long_description_content_type="text/markdown",
14 |     url="https://github.com",
15 |     packages=setuptools.find_packages(),
16 |     classifiers=[
17 |         "Programming Language :: Python :: 3",
18 |         "License :: OSI Approved :: GNU General Public License v3 (GPLv3)",
19 |         "Operating System :: OS Independent",
20 |     ],
21 |     install_requires=['torch', 'numpy', 'scipy'],
22 |     python_requires='>=3.6',
23 | )
--------------------------------------------------------------------------------
/torch_pruning/__init__.py:
--------------------------------------------------------------------------------
1 | from .dependency import *
2 | from .prune import *
3 | from .autoslim import *
4 | from . import utils
5 | from .flops_counter import get_model_complexity_info
6 | import warnings
7 | 
8 | warnings.filterwarnings('ignore')  # silences every warning globally; comment out when debugging
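`torch_pruning` re-exports the dependency machinery at the top level, so structured pruning can also be driven by hand instead of through `Autoslim`. A minimal sketch built only from the classes defined in `dependency.py` further below (the layer choice and the indices are arbitrary examples):

    import torch
    import torch_pruning as pruning
    from torchvision.models import resnet18

    model = resnet18()
    DG = pruning.DependencyGraph()
    DG.build_dependency(model, example_inputs=torch.randn(1, 3, 224, 224))
    # the plan covers every layer whose channels depend on conv1's filters
    plan = DG.get_pruning_plan(model.conv1, pruning.prune_conv, idxs=[0, 1])
    if plan is not None:  # None is returned e.g. for output layers
        plan.exec()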
--------------------------------------------------------------------------------
/torch_pruning/autoslim.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from itertools import chain
5 | from .dependency import *
6 | from . import prune
7 | import math
8 | from scipy.spatial import distance
9 | 
10 | __all__ = ['Autoslim']
11 | 
12 | 
13 | class Autoslim(object):
14 |     def __init__(self, model, inputs, compression_ratio):
15 |         self.model = model  # a torchvision.models-style model
16 |         self.inputs = inputs  # example input, e.g. torch.randn(1, 3, 224, 224)
17 |         self.compression_ratio = compression_ratio  # target compression ratio
18 |         self.DG = DependencyGraph()
19 |         # build the inter-layer dependency graph
20 |         self.DG.build_dependency(model, example_inputs=inputs)
21 |         self.model_modules = list(model.modules())
22 |         self.pruning_func = {
23 |             'l1': self._base_l1_pruning,
24 |             'l2': self._base_l2_pruning,
25 |             'fpgm': self._base_fpgm_pruning
26 |         }
27 | 
28 |     def index_of_layer(self):
29 |         dicts = {}
30 |         for i, m in enumerate(self.model_modules):
31 |             if isinstance(m, nn.modules.conv._ConvNd):
32 |                 dicts[i] = m
33 |         return dicts
34 | 
35 |     def base_prunging(self, config):
36 |         if not config['pruning_func'] in self.pruning_func:
37 |             raise KeyError(
38 |                 "-[ERROR] {} pruning not supported.".format(config['pruning_func']))
39 | 
40 |         ori_output = {}
41 |         for i, m in enumerate(self.model_modules):
42 |             if isinstance(m, nn.modules.conv._ConvNd):
43 |                 ori_output[i] = m.out_channels
44 | 
45 |         if config['layer_compression_ratio'] is None and config['prune_shortcut'] == 1:
46 |             config['layer_compression_ratio'] = self._compute_auto_ratios()
47 | 
48 |         prune_indexes = self.pruning_func[config['pruning_func']](config)
49 | 
50 |         for i, m in enumerate(self.model_modules):
51 |             if i in prune_indexes and m.out_channels == ori_output[i]:
52 |                 pruning_plan = self.DG.get_pruning_plan(
53 |                     m, prune.prune_conv, idxs=prune_indexes[i])
54 |                 if pruning_plan and config['prune_shortcut'] == 1:
55 |                     pruning_plan.exec()
56 |                 elif pruning_plan and not pruning_plan.is_in_shortcut:
57 |                     pruning_plan.exec()
58 | 
59 |     def _base_fpgm_pruning(self, config):
60 |         prune_indexes = {}
61 |         for i, m in enumerate(self.model_modules):
62 |             # _ConvNd covers both convolutions and transposed convolutions
63 |             if isinstance(m, nn.modules.conv._ConvNd):
64 |                 weight_torch = m.weight.detach()  # keep the weight's own device
65 |                 if isinstance(m, nn.modules.conv._ConvTransposeMixin):
66 |                     weight_vec = weight_torch.view(weight_torch.size()[1], -1)
67 |                     out_channels = weight_torch.size()[1]
68 |                 else:
69 |                     weight_vec = weight_torch.view(
70 |                         weight_torch.size()[0], -1)  # [512,64,3,3] -> [512, 64*3*3]
71 |                     out_channels = weight_torch.size()[0]
72 | 
73 |                 if config['layer_compression_ratio'] and i in config['layer_compression_ratio']:
74 |                     similar_pruned_num = int(
75 |                         out_channels * config['layer_compression_ratio'][i])
76 |                 # in fully automatic mode, shortcut layers are left alone
77 |                 else:
78 |                     similar_pruned_num = int(
79 |                         out_channels * self.compression_ratio)
80 | 
81 |                 filter_pruned_num = int(
82 |                     out_channels * (1 - config['norm_rate']))
83 | 
84 |                 if config['dist_type'] in ("l2", "cos"):
85 |                     norm = torch.norm(weight_vec, 2, 1)
86 |                     norm_np = norm.cpu().numpy()
87 |                 elif config['dist_type'] == "l1":
88 |                     norm = torch.norm(weight_vec, 1, 1)
89 |                     norm_np = norm.cpu().numpy()
90 | 
91 |                 # keep the filters whose norms are largest
92 |                 filter_large_index = norm_np.argsort()[filter_pruned_num:]
93 | 
94 |                 indices = torch.LongTensor(filter_large_index).to(weight_vec.device)
95 |                 weight_vec_after_norm = torch.index_select(
96 |                     weight_vec, 0, indices).cpu().numpy()
97 | 
98 |                 # for euclidean distance
99 |                 if config['dist_type'] in ("l2", "l1"):
100 |                     similar_matrix = distance.cdist(
101 |                         weight_vec_after_norm, weight_vec_after_norm, 'euclidean')
102 |                 elif config['dist_type'] == "cos":  # for cos similarity
103 |                     similar_matrix = 1 - \
104 |                         distance.cdist(weight_vec_after_norm,
105 |                                        weight_vec_after_norm, 'cosine')
106 | 
107 |                 # sum each filter's distance to all the other filters
108 |                 similar_sum = np.sum(np.abs(similar_matrix), axis=0)
109 | 
110 |                 # smallest summed distance == closest to the geometric median
111 |                 # == most redundant; those filters are pruned
112 |                 similar_small_index = similar_sum.argsort()[
113 |                     :similar_pruned_num]
114 |                 prune_index = [filter_large_index[j]
115 |                                for j in similar_small_index]
116 |                 prune_indexes[i] = prune_index
117 |         return prune_indexes
118 | 
119 |     def _base_l1_pruning(self, config):
120 |         return self.__base_lx_norm_pruning(config, norm='l1')
121 | 
122 |     def _base_l2_pruning(self, config):
123 |         return self.__base_lx_norm_pruning(config, norm='l2')
124 | 
125 |     def __base_lx_norm_pruning(self, config, norm='l1'):
126 |         prune_indexes = {}
127 | 
128 |         def _compute_lx_norm(m, norm):
129 |             weight = m.weight.detach().cpu().numpy()
130 |             if isinstance(m, nn.modules.conv._ConvTransposeMixin):
131 |                 # dim 1 is the output dimension of a transposed conv
132 |                 if norm == 'l1':
133 |                     Lx_norm = np.sum(np.abs(weight), axis=(0, 2, 3))
134 |                 else:
135 |                     Lx_norm = np.sqrt(
136 |                         np.sum(weight ** 2, axis=(0, 2, 3)))
137 |             else:
138 |                 if norm == 'l1':
139 |                     Lx_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
140 |                 else:
141 |                     # true per-filter L2 norm
142 |                     Lx_norm = np.sqrt(
143 |                         np.sum(weight ** 2, axis=(1, 2, 3)))
144 |             return Lx_norm
145 | 
146 |         # global-threshold pruning (tends to work poorly; avoid if possible)
147 |         if config['global_pruning']:
148 |             filter_record = []
149 |             for i, m in enumerate(self.model_modules):
150 |                 if isinstance(m, nn.modules.conv._ConvNd):
151 |                     Lx_norm = _compute_lx_norm(m, norm)
152 |                     filter_record.append(Lx_norm.tolist())  # record every conv's norms
153 | 
154 |             filter_record = list(chain.from_iterable(filter_record))
155 |             total = len(filter_record)
156 |             filter_record.sort()  # global sort
157 |             thre_index = int(total * self.compression_ratio)
158 |             thre = filter_record[thre_index]  # threshold implied by the ratio
159 |             # `max_ratio` is an assumed cap so that no layer loses every filter
160 |             max_ratio = 0.9
161 |             for i, m in enumerate(self.model_modules):
162 |                 if isinstance(m, nn.modules.conv._ConvNd):
163 |                     Lx_norm = _compute_lx_norm(m, norm)
164 |                     num_pruned = min(int(max_ratio * len(Lx_norm)),
165 |                                      len(Lx_norm[Lx_norm < thre]))
166 |                     # drop the filters below the global threshold
167 |                     prune_index = np.argsort(Lx_norm)[:num_pruned].tolist()
168 |                     prune_indexes[i] = prune_index
169 | 
170 |         # per-layer thresholds, optionally with user-specified layers
171 |         else:
172 |             if config['layer_compression_ratio'] is None and config['prune_shortcut'] == 1:
173 |                 # shortcuts should be pruned but no per-layer ratios were given
174 |                 config['layer_compression_ratio'] = self._compute_auto_ratios()
175 | 
176 |             for i, m in enumerate(self.model_modules):
177 |                 # layer by layer; _ConvNd covers conv and transposed conv
178 |                 if isinstance(m, nn.modules.conv._ConvNd):
179 |                     Lx_norm = _compute_lx_norm(m, norm)
180 |                     out_channels = len(Lx_norm)  # one norm per output channel
181 | 
182 |                     # custom or fully automatic mode prunes shortcut layers too
183 |                     if config['layer_compression_ratio'] and i in config['layer_compression_ratio']:
184 |                         num_pruned = int(
185 |                             out_channels * config['layer_compression_ratio'][i])
186 |                     # in fully automatic mode, shortcut layers are left alone
187 |                     else:
188 |                         num_pruned = int(out_channels * self.compression_ratio)
189 | 
190 |                     # remove filters with a small Lx norm
191 |                     prune_index = np.argsort(Lx_norm)[:num_pruned].tolist()
192 |                     prune_indexes[i] = prune_index
193 |         return prune_indexes
194 | 
195 |     def _compute_auto_ratios(self):
196 |         # generate per-layer ratios when none were given
197 |         layer_compression_ratio = {}
198 |         mid_value = self.compression_ratio
199 | 
200 |         one_value = (1-mid_value)/4 if mid_value >= 0.43 else mid_value/4
201 |         values = [mid_value-one_value*3, mid_value-one_value*2, mid_value-one_value,
202 |                   mid_value, mid_value+one_value, mid_value+one_value*2, mid_value+one_value*3]
203 |         # seven levels from shallow to deep, small to large;
204 |         # their mean equals the requested compression ratio
205 |         layer_cnt = 0
206 |         for i, m in enumerate(self.model_modules):
207 |             if isinstance(m, nn.modules.conv._ConvNd):
208 |                 layer_compression_ratio[i] = 0
209 |                 layer_cnt += 1
210 |         layers_of_class = layer_cnt/7
211 |         conv_cnt = 0
212 |         for i, m in enumerate(self.model_modules):
213 |             if isinstance(m, nn.modules.conv._ConvNd):
214 |                 layer_compression_ratio[i] = values[math.floor(
215 |                     conv_cnt/layers_of_class)]
216 |                 conv_cnt += 1
217 |         return layer_compression_ratio
218 | 
219 | 
220 | if __name__ == "__main__":
221 |     from resnet_small import resnet_small
222 |     model = resnet_small()
223 |     slim = Autoslim(model, inputs=torch.randn(
224 |         1, 3, 224, 224), compression_ratio=0.5)
225 |     # run the l1 criterion through `base_prunging`
226 |     # (Autoslim has no `l1_norm_pruning` method)
227 |     config = {'layer_compression_ratio': None, 'norm_rate': 1.0,
228 |               'prune_shortcut': 1, 'global_pruning': False,
229 |               'pruning_func': 'l1'}
230 |     slim.base_prunging(config)
231 |     print(model)
232 | 
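`_compute_auto_ratios` above spreads the requested `compression_ratio` over seven levels, so shallow convolutions are pruned less aggressively than deep ones while the mean stays at the requested ratio. A worked example of the ladder for `compression_ratio=0.5`:

    mid = 0.5
    step = (1 - mid) / 4 if mid >= 0.43 else mid / 4   # 0.125 here
    values = [mid + k * step for k in range(-3, 4)]
    # -> [0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875]; mean == 0.5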
--------------------------------------------------------------------------------
/torch_pruning/autoslim_test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from itertools import chain
5 | from .dependency import *
6 | from . import prune
7 | import math
8 | from scipy.spatial import distance
9 | 
10 | __all__ = ['Autoslim']
11 | 
12 | 
13 | class Autoslim(object):
14 |     def __init__(self, model, inputs, compression_ratio):
15 |         self.model = model  # a torchvision.models-style model
16 |         self.inputs = inputs  # example input, e.g. torch.randn(1, 3, 224, 224)
17 |         self.compression_ratio = compression_ratio  # target compression ratio
18 |         self.DG = DependencyGraph()
19 |         # build the inter-layer dependency graph
20 |         self.DG.build_dependency(model, example_inputs=inputs)
21 |         self.model_modules = list(model.modules())
22 |         self.pruning_func = {
23 |             'l1': self._base_l1_pruning,
24 |             'fpgm': self._base_fpgm_pruning
25 |         }
26 | 
27 |     def index_of_layer(self):
28 |         dicts = {}
29 |         for i, m in enumerate(self.model_modules):
30 |             if isinstance(m, nn.modules.conv._ConvNd):
31 |                 dicts[i] = m
32 |         return dicts
33 | 
34 |     def base_prunging(self, config):
35 |         if not config['pruning_func'] in self.pruning_func:
36 |             raise KeyError(
37 |                 "-[ERROR] {} not supported.".format(config['pruning_func']))
38 | 
39 |         ori_output = {}
40 |         for i, m in enumerate(self.model_modules):
41 |             if isinstance(m, nn.modules.conv._ConvNd):
42 |                 ori_output[i] = m.out_channels
43 | 
44 |         if config['layer_compression_ratio'] is None and config['prune_shortcut'] == 1:
45 |             config['layer_compression_ratio'] = self._compute_auto_ratios()
46 | 
47 |         prune_indexes = self.pruning_func[config['pruning_func']](config)
48 | 
49 |         for i, m in enumerate(self.model_modules):
50 |             if i in prune_indexes and m.out_channels == ori_output[i]:
51 |                 pruning_plan = self.DG.get_pruning_plan(
52 |                     m, prune.prune_conv, idxs=prune_indexes[i])
53 |                 if pruning_plan and config['prune_shortcut'] == 1:
54 |                     pruning_plan.exec()
55 |                 elif pruning_plan and not pruning_plan.is_in_shortcut:
56 |                     pruning_plan.exec()
57 | 
58 |     def _base_fpgm_pruning(self, config):
59 |         prune_indexes = {}
60 |         for i, m in enumerate(self.model_modules):
61 |             # _ConvNd covers both convolutions and transposed convolutions
62 |             if isinstance(m, nn.modules.conv._ConvNd):
63 |                 weight_torch = m.weight.detach()  # keep the weight's own device
64 |                 if isinstance(m, nn.modules.conv._ConvTransposeMixin):
65 |                     weight_vec = weight_torch.view(weight_torch.size()[1], -1)
66 |                     out_channels = weight_torch.size()[1]
67 |                 else:
68 |                     weight_vec = weight_torch.view(
                        weight_torch.size()[0], -1)  # [512,64,3,3] -> [512, 64*3*3]
69 | 
70 |                     out_channels = weight_torch.size()[0]
71 | 
72 |                 if config['layer_compression_ratio'] and i in config['layer_compression_ratio']:
73 |                     similar_pruned_num = int(
74 |                         out_channels * config['layer_compression_ratio'][i])
75 |                 # in fully automatic mode, shortcut layers are left alone
76 |                 else:
77 |                     similar_pruned_num = int(
78 |                         out_channels * self.compression_ratio)
79 | 
80 |                 filter_pruned_num = int(
81 |                     out_channels * (1 - config['norm_rate']))
82 | 
83 |                 if config['dist_type'] in ("l2", "cos"):
84 |                     norm = torch.norm(weight_vec, 2, 1)
85 |                     norm_np = norm.cpu().numpy()
86 |                 elif config['dist_type'] == "l1":
87 |                     norm = torch.norm(weight_vec, 1, 1)
88 |                     norm_np = norm.cpu().numpy()
89 | 
90 |                 # keep the filters whose norms are largest
91 |                 filter_large_index = norm_np.argsort()[filter_pruned_num:]
92 | 
93 |                 indices = torch.LongTensor(filter_large_index).to(weight_vec.device)
94 |                 weight_vec_after_norm = torch.index_select(
95 |                     weight_vec, 0, indices).cpu().numpy()
96 | 
97 |                 # for euclidean distance
98 |                 if config['dist_type'] in ("l2", "l1"):
99 |                     similar_matrix = distance.cdist(
100 |                         weight_vec_after_norm, weight_vec_after_norm, 'euclidean')
101 |                 elif config['dist_type'] == "cos":  # for cos similarity
102 |                     similar_matrix = 1 - \
103 |                         distance.cdist(weight_vec_after_norm,
104 |                                        weight_vec_after_norm, 'cosine')
105 | 
106 |                 # sum each filter's distance to all the other filters
107 |                 similar_sum = np.sum(np.abs(similar_matrix), axis=0)
108 | 
109 |                 # smallest summed distance == closest to the geometric median
110 |                 # == most redundant; those filters are pruned
111 |                 similar_small_index = similar_sum.argsort()[
112 |                     :similar_pruned_num]
113 |                 prune_index = [filter_large_index[j]
114 |                                for j in similar_small_index]
115 |                 prune_indexes[i] = prune_index
116 |         return prune_indexes
117 | 
118 |     def _base_l1_pruning(self, config):
119 |         prune_indexes = {}
120 |         # global-threshold pruning (tends to work poorly; avoid if possible)
121 |         if config['global_pruning']:
122 |             filter_record = []
123 |             for i, m in enumerate(self.model_modules):
124 |                 if isinstance(m, nn.modules.conv._ConvNd):
125 |                     weight = m.weight.detach().cpu().numpy()
126 |                     if isinstance(m, nn.modules.conv._ConvTransposeMixin):
127 |                         L1_norm = np.sum(np.abs(weight), axis=(
128 |                             0, 2, 3))  # dim 1 is the output dim of a transposed conv
129 |                     else:
130 |                         L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
131 |                     filter_record.append(L1_norm.tolist())  # record every conv's norms
132 | 
133 |             filter_record = list(chain.from_iterable(filter_record))
134 |             total = len(filter_record)
135 |             filter_record.sort()  # global sort
136 |             thre_index = int(total * self.compression_ratio)
137 |             thre = filter_record[thre_index]  # threshold implied by the ratio
138 |             # `max_ratio` is an assumed cap so that no layer loses every filter
139 |             max_ratio = 0.9
140 |             for i, m in enumerate(self.model_modules):
141 |                 if isinstance(m, nn.modules.conv._ConvNd):
142 |                     weight = m.weight.detach().cpu().numpy()
143 |                     # _ConvTransposeMixin matches transposed convolutions only
144 |                     if isinstance(m, nn.modules.conv._ConvTransposeMixin):
145 |                         L1_norm = np.sum(np.abs(weight), axis=(
146 |                             0, 2, 3))  # dim 1 is the output dim of a transposed conv
147 |                     else:
148 |                         L1_norm = np.sum(np.abs(weight), axis=(1, 2, 3))
149 |                     num_pruned = min(int(max_ratio*len(L1_norm)),
150 |                                      len(L1_norm[L1_norm < thre]))
151 |                     # drop the filters below the global threshold
152 |                     prune_index = np.argsort(L1_norm)[:num_pruned].tolist()
153 |                     prune_indexes[i] = prune_index
154 | 
155 |         # per-layer thresholds, optionally with user-specified layers
156 |         else:
157 |             if config['layer_compression_ratio'] is None and config['prune_shortcut'] == 1:
158 |                 # shortcuts should be pruned but no per-layer ratios were given
159 |                 config['layer_compression_ratio'] = self._compute_auto_ratios()
160 | 
161 |             for i, m in
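A `PruningPlan` can be inspected before it touches the model: its `__str__` evaluates every dependency with `dry_run=True` and reports the per-layer index lists and parameter counts, and `exec(dry_run=True)` returns the total without pruning anything. A short sketch reusing the `DG` and `model` names from the earlier example:

    plan = DG.get_pruning_plan(model.conv1, pruning.prune_conv, idxs=[0, 1])
    if plan is not None:
        print(plan)                    # dry-run report; the model is untouched
        n = plan.exec(dry_run=True)    # number of parameters that would be removed
        plan.exec()                    # apply the pruning for real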
 enumerate(self.model_modules):
162 |                 # layer by layer; _ConvNd covers conv and transposed conv
163 |                 if isinstance(m, nn.modules.conv._ConvNd):
164 |                     weight = m.weight.detach().cpu().numpy()
165 |                     # _ConvTransposeMixin matches transposed convolutions only
166 |                     if isinstance(m, nn.modules.conv._ConvTransposeMixin):
167 |                         out_channels = weight.shape[1]
168 |                         L1_norm = np.sum(np.abs(weight), axis=(0, 2, 3))
169 |                     else:
170 |                         out_channels = weight.shape[0]
171 |                         L1_norm = np.sum(
172 |                             np.abs(weight), axis=(1, 2, 3))  # per-filter L1 norm
173 | 
174 |                     # custom or fully automatic mode prunes shortcut layers too
175 |                     if config['layer_compression_ratio'] and i in config['layer_compression_ratio']:
176 |                         num_pruned = int(
177 |                             out_channels * config['layer_compression_ratio'][i])
178 |                     # in fully automatic mode, shortcut layers are left alone
179 |                     else:
180 |                         num_pruned = int(out_channels * self.compression_ratio)
181 | 
182 |                     # remove filters with small L1-Norm
183 |                     prune_index = np.argsort(L1_norm)[:num_pruned].tolist()
184 |                     prune_indexes[i] = prune_index
185 |         return prune_indexes
186 | 
187 |     def _compute_auto_ratios(self):
188 |         layer_compression_ratio = {}
189 |         mid_value = self.compression_ratio
190 | 
191 |         one_value = (1-mid_value)/4 if mid_value >= 0.43 else mid_value/4
192 |         values = [mid_value-one_value*3, mid_value-one_value*2, mid_value-one_value,
193 |                   mid_value, mid_value+one_value, mid_value+one_value*2, mid_value+one_value*3]
194 |         layer_cnt = 0
195 |         for i, m in enumerate(self.model_modules):
196 |             if isinstance(m, nn.modules.conv._ConvNd):
197 |                 layer_compression_ratio[i] = 0
198 |                 layer_cnt += 1
199 |         layers_of_class = layer_cnt/7
200 |         conv_cnt = 0
201 |         for i, m in enumerate(self.model_modules):
202 |             if isinstance(m, nn.modules.conv._ConvNd):
203 |                 layer_compression_ratio[i] = values[math.floor(
204 |                     conv_cnt/layers_of_class)]
205 |                 conv_cnt += 1
206 |         return layer_compression_ratio
207 | 
208 | 
209 | if __name__ == "__main__":
210 |     from resnet_small import resnet_small
211 |     model = resnet_small()
212 |     slim = Autoslim(model, inputs=torch.randn(
213 |         1, 3, 224, 224), compression_ratio=0.5)
214 |     # run the l1 criterion through `base_prunging`
215 |     # (Autoslim has no `l1_norm_pruning` method)
216 |     config = {'layer_compression_ratio': None, 'norm_rate': 1.0,
217 |               'prune_shortcut': 1, 'global_pruning': False,
218 |               'pruning_func': 'l1'}
219 |     slim.base_prunging(config)
220 |     print(model)
221 | 
--------------------------------------------------------------------------------
/torch_pruning/dependency.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import typing
4 | from functools import reduce
5 | from operator import mul
6 | from . 
import prune 7 | from enum import IntEnum 8 | 9 | __all__ = ['PruningPlan', 'Dependency', 'DependencyGraph'] 10 | 11 | TORCH_CONV = nn.modules.conv._ConvNd 12 | TORCH_BATCHNORM = nn.modules.batchnorm._BatchNorm 13 | TORCH_PRELU = nn.PReLU 14 | TORCH_LINEAR = nn.Linear 15 | 16 | 17 | class OPTYPE(IntEnum): 18 | # 枚举类 19 | CONV = 0 20 | BN = 1 21 | LINEAR = 2 22 | PRELU = 3 23 | GROUP_CONV = 4 24 | 25 | CONCAT = 5 26 | SPLIT = 6 27 | ELEMENTWISE = 7 28 | 29 | 30 | def _get_module_type(module): 31 | if isinstance(module, TORCH_CONV): 32 | if module.groups > 1: 33 | return OPTYPE.GROUP_CONV 34 | else: 35 | return OPTYPE.CONV 36 | elif isinstance(module, TORCH_BATCHNORM): 37 | return OPTYPE.BN 38 | elif isinstance(module, TORCH_PRELU): 39 | return OPTYPE.PRELU 40 | elif isinstance(module, TORCH_LINEAR): 41 | return OPTYPE.LINEAR 42 | elif isinstance(module, _ConcatOp): 43 | return OPTYPE.CONCAT 44 | elif isinstance(module, _SplitOP): 45 | return OPTYPE.SPLIT 46 | else: 47 | return OPTYPE.ELEMENTWISE 48 | 49 | 50 | def _get_node_out_channel(node): 51 | if node.type == OPTYPE.CONV or node.type == OPTYPE.GROUP_CONV: 52 | return node.module.out_channels 53 | elif node.type == OPTYPE.BN: 54 | return node.module.num_features 55 | elif node.type == OPTYPE.LINEAR: 56 | return node.module.out_features 57 | elif node.type == OPTYPE.PRELU: 58 | if node.module.num_parameters == 1: 59 | return None 60 | else: 61 | return node.module.num_parameters 62 | else: 63 | return None 64 | 65 | 66 | def _get_node_in_channel(node): 67 | if node.type == OPTYPE.CONV or node.type == OPTYPE.GROUP_CONV: 68 | return node.module.in_channels 69 | elif node.type == OPTYPE.BN: 70 | return node.module.num_features 71 | elif node.type == OPTYPE.LINEAR: 72 | return node.module.in_features 73 | elif node.type == OPTYPE.PRELU: 74 | if node.module.num_parameters == 1: 75 | return None 76 | else: 77 | return node.module.num_parameters 78 | else: 79 | return None 80 | 81 | # Dummy Pruning fn 82 | 83 | 84 | def _prune_concat(layer, *args, **kargs): 85 | return layer, 0 86 | 87 | 88 | def _prune_split(layer, *args, **kargs): 89 | return layer, 0 90 | 91 | 92 | def _prune_elementwise_op(layer, *args, **kargs): 93 | return layer, 0 94 | 95 | # Dummy module 96 | 97 | 98 | class _ConcatOp(nn.Module): 99 | def __init__(self): 100 | super(_ConcatOp, self).__init__() 101 | self.offsets = None 102 | 103 | def __repr__(self): 104 | return "_ConcatOp(%s)" % (self.offsets) 105 | 106 | 107 | class _SplitOP(nn.Module): 108 | def __init__(self): 109 | super(_SplitOP, self).__init__() 110 | self.offsets = None 111 | 112 | def __repr__(self): 113 | return "_SplitOP(%s)" % (self.offsets) 114 | 115 | 116 | class _ElementWiseOp(nn.Module): 117 | def __init__(self): 118 | super(_ElementWiseOp, self).__init__() 119 | 120 | def __repr__(self): 121 | return "_ElementWiseOp()" 122 | 123 | 124 | class _FlattenIndexTransform(object): 125 | def __init__(self, stride=1, reverse=False): 126 | self._stride = stride 127 | self.reverse = reverse 128 | 129 | def __call__(self, idxs): 130 | new_idxs = [] 131 | if self.reverse == True: 132 | for i in idxs: 133 | new_idxs.append(i//self._stride) 134 | new_idxs = list(set(new_idxs)) 135 | else: 136 | for i in idxs: 137 | new_idxs.extend( 138 | list(range(i*self._stride, (i+1)*self._stride))) 139 | return new_idxs 140 | 141 | 142 | class _ConcatIndexTransform(object): 143 | def __init__(self, offset, reverse=False): 144 | self.offset = offset 145 | self.reverse = reverse 146 | 147 | def __call__(self, idxs): 148 | if 
self.reverse:
149 |             new_idxs = [i-self.offset[0]
150 |                         for i in idxs if (i >= self.offset[0] and i < self.offset[1])]
151 |         else:
152 |             new_idxs = [i+self.offset[0] for i in idxs]
153 |         return new_idxs
154 | 
155 | 
156 | class _SplitIndexTransform(object):
157 |     def __init__(self, offset, reverse=False):
158 |         self.offset = offset
159 |         self.reverse = reverse
160 | 
161 |     def __call__(self, idxs):
162 |         if self.reverse:
163 |             new_idxs = [i+self.offset[0] for i in idxs]
164 |         else:
165 |             new_idxs = [i-self.offset[0]
166 |                         for i in idxs if (i >= self.offset[0] and i < self.offset[1])]
167 |         return new_idxs
168 | 
169 | 
170 | class Node(object):
171 |     def __init__(self, module, grad_fn, node_name=None):
172 |         self.module = module
173 |         self.grad_fn = grad_fn
174 |         self.inputs = []
175 |         self.outputs = []
176 |         self.dependencies = []
177 |         self._node_name = node_name
178 |         self.type = _get_module_type(module)
179 | 
180 |     @property
181 |     def node_name(self):
182 |         return "%s (%s)" % (self._node_name, str(self.module)) if self._node_name is not None else str(self.module)
183 | 
184 |     def add_input(self, node):
185 |         if node not in self.inputs:
186 |             self.inputs.append(node)
187 | 
188 |     def add_output(self, node):
189 |         if node not in self.outputs:
190 |             self.outputs.append(node)
191 | 
192 |     def __repr__(self):
193 |         return "<Node: (%s, %s)>" % (self.node_name, self.grad_fn)
194 | 
195 |     def __str__(self):
196 |         return "<Node: (%s, %s)>" % (self.node_name, self.grad_fn)
197 | 
198 |     def details(self):
199 |         fmt = "<Node: (%s, %s)>\n" % (self.node_name, self.grad_fn)
200 |         fmt += ' '*4+'IN:\n'
201 |         for in_node in self.inputs:
202 |             fmt += ' '*8+'%s\n' % (in_node)
203 |         fmt += ' '*4+'OUT:\n'
204 |         for out_node in self.outputs:
205 |             fmt += ' '*8+'%s\n' % (out_node)
206 | 
207 |         fmt += ' '*4+'DEP:\n'
208 |         for dep in self.dependencies:
209 |             fmt += ' '*8+"%s\n" % (dep)
210 |         return fmt
211 | 
212 | 
213 | class Dependency(object):
214 |     def __init__(self, trigger, handler, broken_node: Node, index_transform: typing.Callable = None):
215 |         """ Layer dependency in structured neural network pruning.
216 | 
217 |         Parameters:
218 |             trigger (Callable or None): a pruning function which will break the dependency
219 |             handler (Callable): a pruning function to fix the broken dependency
220 |             broken_node (Node): the broken layer
221 |         """
222 |         self.trigger = trigger
223 |         self.handler = handler
224 |         self.broken_node = broken_node
225 |         self.index_transform = index_transform
226 | 
227 |     def __call__(self, idxs: list, dry_run: bool = False):
228 |         result = self.handler(self.broken_node.module, idxs, dry_run=dry_run)
229 |         return result
230 | 
231 |     def __repr__(self):
232 |         return str(self)
233 | 
234 |     def __str__(self):
235 |         return "<DEP: %s => %s on %s>" % ("None" if self.trigger is None else self.trigger.__name__, self.handler.__name__, self.broken_node.node_name)
236 | 
237 |     def is_triggered_by(self, pruning_fn):
238 |         return pruning_fn == self.trigger
239 | 
240 |     def __eq__(self, other):
241 |         return ((self.trigger == other.trigger) and
242 |                 self.handler == other.handler and
243 |                 self.broken_node == other.broken_node)
244 | 
245 | 
246 | class PruningPlan(object):
247 |     """ Pruning plan.
248 | 
249 |     A list of (Dependency, indices) pairs collected by DependencyGraph.
250 |     Call exec() to apply them, or exec(dry_run=True) to only count the
251 |     parameters that would be pruned.
252 | """ 253 | 254 | def __init__(self): 255 | self._plans = list() 256 | 257 | def add_plan(self, dep, idxs): 258 | self._plans.append((dep, idxs)) 259 | 260 | @property 261 | def plan(self): 262 | return self._plans 263 | 264 | def exec(self, dry_run=False): 265 | num_pruned = 0 266 | for dep, idxs in self._plans: 267 | _, n = dep(idxs, dry_run=dry_run) 268 | num_pruned += n 269 | return num_pruned 270 | 271 | def has_dep(self, dep): 272 | for _dep, _ in self._plans: 273 | if dep == _dep: 274 | return True 275 | return False 276 | 277 | def has_pruning_op(self, dep, idxs): 278 | for _dep, _idxs in self._plans: 279 | if _dep.broken_node == dep.broken_node and _dep.handler == dep.handler and _idxs == idxs: 280 | return True 281 | return False 282 | 283 | @property 284 | def is_in_shortcut(self): 285 | prune_conv_cnt = 0 286 | for _dep, _idxs in self._plans: 287 | if _dep.handler.__name__ == 'prune_conv': 288 | prune_conv_cnt += 1 289 | if prune_conv_cnt > 1: 290 | return True 291 | else: 292 | return False 293 | 294 | def add_plan_and_merge(self, dep, idxs): 295 | for i, (_dep, _idxs) in enumerate(self._plans): 296 | if _dep.broken_node == dep.broken_node and _dep.handler == dep.handler: 297 | self._plans[i] = (_dep, list(set(_idxs+idxs))) 298 | return 299 | self.add_plan(dep, idxs) 300 | 301 | def __str__(self): 302 | fmt = "" 303 | fmt += "\n-------------\n" 304 | totally_pruned = 0 305 | for dep, idxs in self._plans: 306 | _, n_pruned = dep(idxs, dry_run=True) 307 | totally_pruned += n_pruned 308 | fmt += "[ %s, Index=%s, NumPruned=%d]\n" % (dep, idxs, n_pruned) 309 | fmt += "%d parameters will be pruned\n" % (totally_pruned) 310 | fmt += "-------------\n" 311 | return fmt 312 | 313 | 314 | class DependencyGraph(object): 315 | 316 | PRUNABLE_MODULES = (nn.modules.conv._ConvNd, 317 | nn.modules.batchnorm._BatchNorm, nn.Linear, nn.PReLU) # 可裁剪的层 318 | 319 | HANDLER = { # prune in_channel # prune out_channel 320 | OPTYPE.CONV: (prune.prune_related_conv, prune.prune_conv), 321 | OPTYPE.BN: (prune.prune_batchnorm, prune.prune_batchnorm), 322 | OPTYPE.PRELU: (prune.prune_prelu, prune.prune_prelu), 323 | OPTYPE.LINEAR: (prune.prune_related_linear, prune.prune_linear), 324 | OPTYPE.GROUP_CONV: (prune.prune_group_conv, prune.prune_group_conv), 325 | OPTYPE.CONCAT: (_prune_concat, _prune_concat), 326 | OPTYPE.SPLIT: (_prune_split, _prune_split), 327 | OPTYPE.ELEMENTWISE: (_prune_elementwise_op, _prune_elementwise_op), 328 | } 329 | OUTPUT_NODE_RULES = {} 330 | INPUT_NODE_RULES = {} 331 | for t1 in HANDLER.keys(): 332 | for t2 in HANDLER.keys(): 333 | # change in_channels of output layer 334 | OUTPUT_NODE_RULES[(t1, t2)] = (HANDLER[t1][1], HANDLER[t2][0]) 335 | # change out_channels of input layer 336 | INPUT_NODE_RULES[(t1, t2)] = (HANDLER[t1][0], HANDLER[t2][1]) 337 | 338 | def build_dependency(self, model: torch.nn.Module, example_inputs: torch.Tensor, output_transform: callable = None, verbose: bool = True): 339 | self.verbose = verbose # 显示细节 340 | 341 | self._module_to_name = {module: name for ( 342 | name, module) in model.named_modules()} 343 | # 获取每层的名称: 344 | # conv1.weight 345 | # bn1.weight 346 | # bn1.bias 347 | # layer1.0.conv1.weight 348 | # layer1.0.bn1.weight 349 | # layer1.0.bn1.bias ... 
350 | 351 | # build dependency graph 352 | self.module_to_node, self.output_grad_fn = self._obtain_forward_graph( 353 | model, example_inputs, output_transform=output_transform) 354 | self._build_dependency(self.module_to_node) 355 | self.update_index() 356 | return self 357 | 358 | def update_index(self): 359 | for module, node in self.module_to_node.items(): 360 | if node.type == OPTYPE.LINEAR: 361 | self._set_fc_index_transform(node) 362 | if node.type == OPTYPE.CONCAT: 363 | self._set_concat_index_transform(node) 364 | if node.type == OPTYPE.SPLIT: 365 | self._set_split_index_transform(node) 366 | 367 | def get_pruning_plan(self, module, pruning_fn, idxs): 368 | 369 | cur_plan_is_group_conv = False 370 | if isinstance(module, TORCH_CONV) and module.groups > 1: 371 | # 只剪枝深度卷积,不剪枝分组卷积 372 | if module.groups == module.in_channels and module.groups == module.out_channels: 373 | pruning_fn = prune.prune_group_conv 374 | cur_plan_is_group_conv = True 375 | else: 376 | return None 377 | 378 | self.update_index() 379 | plan = PruningPlan() 380 | # the user pruning operation 381 | # oot_node = self.module_to_node[module] 382 | root_node = self.module_to_node.get(module, None) 383 | if not root_node: 384 | return None 385 | # 如果是神经网络的输出层,那么不剪枝 386 | if root_node.grad_fn in self.output_grad_fn: 387 | return None 388 | 389 | plan.add_plan(Dependency(pruning_fn, pruning_fn, root_node), idxs) 390 | 391 | visited = set() 392 | 393 | def _fix_denpendency_graph(node, fn, indices): 394 | visited.add(node) 395 | for dep in node.dependencies: 396 | # and dep.broken_node not in visited: 397 | if dep.is_triggered_by(fn): 398 | if dep.index_transform is not None: 399 | new_indices = dep.index_transform(indices) 400 | else: 401 | new_indices = indices 402 | 403 | if len(new_indices) == 0: 404 | continue 405 | if dep.broken_node in visited and plan.has_pruning_op(dep, new_indices): 406 | continue 407 | else: 408 | plan.add_plan(dep, new_indices) 409 | _fix_denpendency_graph( 410 | dep.broken_node, dep.handler, new_indices) 411 | 412 | _fix_denpendency_graph(root_node, pruning_fn, idxs) 413 | 414 | # merge pruning ops 415 | merged_plan = PruningPlan() 416 | for dep, idxs in plan.plan: 417 | merged_plan.add_plan_and_merge(dep, idxs) 418 | 419 | # 如果剪枝计划中有prune_group_conv,但当前节点不是group_conv,则不剪枝,取消计划。 420 | prune_group_conv_cnt = 0 421 | for _dep, _idxs in merged_plan._plans: 422 | if _dep.handler.__name__ == 'prune_group_conv': 423 | prune_group_conv_cnt += 1 424 | if prune_group_conv_cnt > 0: 425 | if not cur_plan_is_group_conv: 426 | return None 427 | 428 | return merged_plan 429 | 430 | def _build_dependency(self, module_to_node): 431 | for module, node in module_to_node.items(): 432 | for in_node in node.inputs: 433 | in_node_rule = self.INPUT_NODE_RULES.get( 434 | (node.type, in_node.type), None) 435 | if in_node_rule is not None: 436 | dep = Dependency( 437 | trigger=in_node_rule[0], handler=in_node_rule[1], broken_node=in_node) 438 | node.dependencies.append(dep) 439 | 440 | for out_node in node.outputs: 441 | out_node_rule = self.OUTPUT_NODE_RULES.get( 442 | (node.type, out_node.type), None) 443 | if out_node_rule is not None: 444 | dep = Dependency( 445 | trigger=out_node_rule[0], handler=out_node_rule[1], broken_node=out_node) 446 | node.dependencies.append(dep) 447 | 448 | def _obtain_forward_graph(self, model, example_inputs, output_transform): 449 | # module_to_node = { m: Node( m ) for m in model.modules() if isinstance( m, self.PRUNABLE_MODULES ) } 450 | model.eval().cpu() 451 | # Get grad_fn from 
prunable modules 452 | grad_fn_to_module = {} 453 | visited = {} 454 | 455 | def _record_module_grad_fn(module, inputs, outputs): 456 | # 记录中间层是否有重复使用的 457 | # 有重复使用的往往和其他层存在依赖关系 458 | if module not in visited: 459 | visited[module] = 1 460 | else: 461 | visited[module] += 1 462 | grad_fn_to_module[outputs.grad_fn] = module # 463 | 464 | hooks = [m.register_forward_hook(_record_module_grad_fn) for m in model.modules( 465 | ) if isinstance(m, self.PRUNABLE_MODULES)] 466 | # 获取模型中可裁剪层的输入和输出 467 | # hook作用: 468 | # 用来获取某些变量的中间结果的。 469 | # Pytorch会自动舍弃图计算的中间结果,所以想要获取这些数值就需要使用hook函数。 470 | # hook函数在使用后应及时删除,以避免每次都运行钩子增加运行负载。 471 | out = model(example_inputs) # 示例输入的输出(包括注册hook的中间层输出) 472 | for hook in hooks: 473 | hook.remove() # 删除hook增加运行负载 474 | reused = [m for (m, count) in visited.items() if count > 1] 475 | # 创建节点和虚拟模块 476 | module_to_node = {} 477 | # 记录神经网络的最后一层,因为这些层不剪枝 478 | output_grad_fn = [] 479 | 480 | def _build_graph(grad_fn, search_final_conv=0): 481 | # print('grad_fn',grad_fn) grad_fn指向Function对象,用于反向传播的梯度计算之用 482 | 483 | search_final_conv = search_final_conv 484 | 485 | module = grad_fn_to_module.get(grad_fn, None) 486 | if module is not None and module in module_to_node and module not in reused: 487 | return module_to_node[module] 488 | 489 | if module is None: 490 | if not hasattr(grad_fn, 'name'): 491 | module = _ElementWiseOp() # skip customized modules 492 | if self.verbose: 493 | print( 494 | "[Warning] Unrecognized operation: %s. It will be treated as element-wise op" % (str(grad_fn))) 495 | elif 'catbackward' in grad_fn.name().lower(): # concat op 496 | module = _ConcatOp() 497 | elif 'splitbackward' in grad_fn.name().lower(): 498 | module = _SplitOP() 499 | else: 500 | module = _ElementWiseOp() # All other ops are treated as element-wise ops 501 | grad_fn_to_module[grad_fn] = module # record grad_fn 502 | 503 | if module not in module_to_node: 504 | node = Node(module, grad_fn, 505 | self._module_to_name.get(module, None)) 506 | module_to_node[module] = node 507 | else: 508 | node = module_to_node[module] 509 | 510 | if search_final_conv and grad_fn is not None and hasattr(grad_fn, 'name') and ('MkldnnConvolutionBackward' in grad_fn.name() or 'AddmmBackward' in grad_fn.name()): 511 | search_final_conv = 0 512 | output_grad_fn.append(grad_fn) 513 | 514 | if hasattr(grad_fn, 'next_functions'): 515 | for f in grad_fn.next_functions: 516 | # print(f) 517 | if f[0] is not None: 518 | # skip leaf variables 519 | if hasattr(f[0], 'name') and 'accumulategrad' in f[0].name().lower(): 520 | continue 521 | input_node = _build_graph(f[0], search_final_conv) 522 | node.add_input(input_node) 523 | input_node.add_output(node) 524 | return node 525 | 526 | if output_transform is not None: 527 | out = output_transform(out) 528 | 529 | if isinstance(out, (list, tuple)): 530 | 531 | for o in out: 532 | # print('start1---------------------------------------') 533 | if isinstance(o, dict): 534 | # print('if1---------------------------------------') 535 | for key in o: 536 | # print('if1---------------------------------------') 537 | # print(o[key]) 538 | # if o[key].grad_fn is not None: 539 | if o[key].grad_fn is not None and hasattr(o[key].grad_fn, 'name') and ('MkldnnConvolutionBackward' in o[key].grad_fn.name() or 'AddmmBackward' in o[key].grad_fn.name()): 540 | output_grad_fn.append(o[key].grad_fn) 541 | _build_graph(o[key].grad_fn, search_final_conv=0) 542 | else: 543 | _build_graph(o[key].grad_fn, search_final_conv=1) 544 | 545 | elif isinstance(o, (list, tuple)): 546 | 547 | for 
new_value in o: 548 | # print('if2---------------------------------------') 549 | # print(new_value) 550 | # if new_value.grad_fn is not None: 551 | if new_value.grad_fn is not None and hasattr(new_value.grad_fn, 'name') and ('MkldnnConvolutionBackward' in new_value.grad_fn.name() or 'AddmmBackward' in new_value.grad_fn.name()): 552 | output_grad_fn.append(new_value.grad_fn) 553 | _build_graph(new_value.grad_fn, 554 | search_final_conv=0) 555 | else: 556 | _build_graph(new_value.grad_fn, 557 | search_final_conv=1) 558 | else: 559 | # print('if3---------------------------------------') 560 | # print(o) 561 | # if o.grad_fn is not None: 562 | if o.grad_fn is not None and hasattr(o.grad_fn, 'name') and ('MkldnnConvolutionBackward' in o.grad_fn.name() or 'AddmmBackward' in o.grad_fn.name()): 563 | output_grad_fn.append(o.grad_fn) 564 | _build_graph(o.grad_fn, search_final_conv=0) 565 | else: 566 | _build_graph(o.grad_fn, search_final_conv=1) 567 | 568 | else: 569 | 570 | if out.grad_fn is not None and hasattr(out.grad_fn, 'name') and ('MkldnnConvolutionBackward' in out.grad_fn.name() or 'AddmmBackward' in out.grad_fn.name()): 571 | output_grad_fn.append(out.grad_fn) 572 | _build_graph(out.grad_fn, search_final_conv=0) 573 | else: 574 | _build_graph(out.grad_fn, search_final_conv=1) 575 | return module_to_node, output_grad_fn 576 | 577 | def _set_fc_index_transform(self, fc_node: Node): 578 | if fc_node.type != OPTYPE.LINEAR: 579 | return 580 | visited = set() 581 | fc_in_features = fc_node.module.in_features 582 | feature_channels = _get_in_node_out_channels(fc_node.inputs[0]) 583 | stride = fc_in_features // feature_channels 584 | if stride > 1: 585 | for in_node in fc_node.inputs: 586 | for dep in fc_node.dependencies: 587 | if dep.broken_node == in_node: 588 | dep.index_transform = _FlattenIndexTransform( 589 | stride=stride, reverse=True) 590 | 591 | for dep in in_node.dependencies: 592 | if dep.broken_node == fc_node: 593 | dep.index_transform = _FlattenIndexTransform( 594 | stride=stride, reverse=False) 595 | 596 | def _set_concat_index_transform(self, cat_node: Node): 597 | if cat_node.type != OPTYPE.CONCAT: 598 | return 599 | 600 | chs = [] 601 | for n in cat_node.inputs: 602 | chs.append(_get_in_node_out_channels(n)) 603 | 604 | offsets = [0] 605 | for ch in chs: 606 | offsets.append(offsets[-1]+ch) 607 | cat_node.module.offsets = offsets 608 | 609 | for i, in_node in enumerate(cat_node.inputs): 610 | for dep in cat_node.dependencies: 611 | if dep.broken_node == in_node: 612 | dep.index_transform = _ConcatIndexTransform( 613 | offset=offsets[i:i+2], reverse=True) 614 | 615 | for dep in in_node.dependencies: 616 | if dep.broken_node == cat_node: 617 | dep.index_transform = _ConcatIndexTransform( 618 | offset=offsets[i:i+2], reverse=False) 619 | 620 | def _set_split_index_transform(self, split_node: Node): 621 | if split_node.type != OPTYPE.SPLIT: 622 | return 623 | 624 | chs = [] 625 | for n in split_node.outputs: 626 | chs.append(_get_out_node_in_channels(n)) 627 | 628 | offsets = [0] 629 | for ch in chs: 630 | offsets.append(offsets[-1]+ch) 631 | split_node.module.offsets = offsets 632 | for i, out_node in enumerate(split_node.outputs): 633 | for dep in split_node.dependencies: 634 | if dep.broken_node == out_node: 635 | dep.index_transform = _SplitIndexTransform( 636 | offset=offsets[i:i+2], reverse=False) 637 | 638 | for dep in out_node.dependencies: 639 | if dep.broken_node == split_node: 640 | dep.index_transform = _SplitIndexTransform( 641 | offset=offsets[i:i+2], reverse=True) 
642 | 643 | 644 | def _get_in_node_out_channels(node): 645 | ch = _get_node_out_channel(node) 646 | if ch is None: 647 | ch = 0 648 | for in_node in node.inputs: 649 | if node.type == OPTYPE.CONCAT: 650 | ch += _get_in_node_out_channels(in_node) 651 | else: 652 | ch = _get_in_node_out_channels(in_node) 653 | return ch 654 | 655 | 656 | def _get_out_node_in_channels(node): 657 | ch = _get_node_in_channel(node) 658 | if ch is None: 659 | ch = 0 660 | for out_node in node.outputs: 661 | if node.type == OPTYPE.SPLIT: 662 | ch += _get_out_node_in_channels(out_node) 663 | else: 664 | ch = _get_out_node_in_channels(out_node) 665 | return ch 666 | -------------------------------------------------------------------------------- /torch_pruning/flops_counter.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Copyright (C) 2019 Sovrasov V. - All Rights Reserved 3 | * You may use, distribute and modify this code under the 4 | * terms of the MIT license. 5 | * You should have received a copy of the MIT license with 6 | * this file. If not visit https://opensource.org/licenses/MIT 7 | ''' 8 | 9 | import sys 10 | from functools import partial 11 | 12 | import numpy as np 13 | import torch 14 | import torch.nn as nn 15 | 16 | 17 | def get_model_complexity_info(model, input_res, 18 | print_per_layer_stat=True, 19 | as_strings=True, 20 | input_constructor=None, ost=sys.stdout, 21 | verbose=False, ignore_modules=[], 22 | custom_modules_hooks={}): 23 | assert type(input_res) is tuple 24 | assert len(input_res) >= 1 25 | assert isinstance(model, nn.Module) 26 | global CUSTOM_MODULES_MAPPING 27 | CUSTOM_MODULES_MAPPING = custom_modules_hooks 28 | flops_model = add_flops_counting_methods(model) 29 | flops_model.eval() 30 | flops_model.start_flops_count(ost=ost, verbose=verbose, 31 | ignore_list=ignore_modules) 32 | if input_constructor: 33 | input = input_constructor(input_res) 34 | _ = flops_model(**input) 35 | else: 36 | try: 37 | batch = torch.ones(()).new_empty((1, *input_res), 38 | dtype=next(flops_model.parameters()).dtype, 39 | device=next(flops_model.parameters()).device) 40 | except StopIteration: 41 | batch = torch.ones(()).new_empty((1, *input_res)) 42 | 43 | _ = flops_model(batch) 44 | 45 | flops_count, params_count = flops_model.compute_average_flops_cost() 46 | if print_per_layer_stat: 47 | print_model_with_flops(flops_model, flops_count, params_count, ost=ost) 48 | flops_model.stop_flops_count() 49 | CUSTOM_MODULES_MAPPING = {} 50 | 51 | if as_strings: 52 | return flops_to_string(flops_count), params_to_string(params_count) 53 | 54 | return flops_count, params_count 55 | 56 | 57 | def flops_to_string(flops, units='GMac', precision=2): 58 | if units is None: 59 | if flops // 10**9 > 0: 60 | return str(round(flops / 10.**9, precision)) + ' GMac' 61 | elif flops // 10**6 > 0: 62 | return str(round(flops / 10.**6, precision)) + ' MMac' 63 | elif flops // 10**3 > 0: 64 | return str(round(flops / 10.**3, precision)) + ' KMac' 65 | else: 66 | return str(flops) + ' Mac' 67 | else: 68 | if units == 'GMac': 69 | return str(round(flops / 10.**9, precision)) + ' ' + units 70 | elif units == 'MMac': 71 | return str(round(flops / 10.**6, precision)) + ' ' + units 72 | elif units == 'KMac': 73 | return str(round(flops / 10.**3, precision)) + ' ' + units 74 | else: 75 | return str(flops) + ' Mac' 76 | 77 | 78 | def params_to_string(params_num, units=None, precision=2): 79 | if units is None: 80 | if params_num // 10 ** 6 > 0: 81 | return str(round(params_num / 10 ** 6, 
2)) + ' M' 82 | elif params_num // 10 ** 3: 83 | return str(round(params_num / 10 ** 3, 2)) + ' k' 84 | else: 85 | return str(params_num) 86 | else: 87 | if units == 'M': 88 | return str(round(params_num / 10.**6, precision)) + ' ' + units 89 | elif units == 'K': 90 | return str(round(params_num / 10.**3, precision)) + ' ' + units 91 | else: 92 | return str(params_num) 93 | 94 | 95 | def print_model_with_flops(model, total_flops, total_params, units='GMac', 96 | precision=3, ost=sys.stdout): 97 | 98 | def accumulate_params(self): 99 | if is_supported_instance(self): 100 | return self.__params__ 101 | else: 102 | sum = 0 103 | for m in self.children(): 104 | sum += m.accumulate_params() 105 | return sum 106 | 107 | def accumulate_flops(self): 108 | if is_supported_instance(self): 109 | return self.__flops__ / model.__batch_counter__ 110 | else: 111 | sum = 0 112 | for m in self.children(): 113 | sum += m.accumulate_flops() 114 | return sum 115 | 116 | def flops_repr(self): 117 | accumulated_params_num = self.accumulate_params() 118 | accumulated_flops_cost = self.accumulate_flops() 119 | return ', '.join([params_to_string(accumulated_params_num, 120 | units='M', precision=precision), 121 | '{:.3%} Params'.format(accumulated_params_num / total_params), 122 | flops_to_string(accumulated_flops_cost, 123 | units=units, precision=precision), 124 | '{:.3%} MACs'.format(accumulated_flops_cost / total_flops), 125 | self.original_extra_repr()]) 126 | 127 | def add_extra_repr(m): 128 | m.accumulate_flops = accumulate_flops.__get__(m) 129 | m.accumulate_params = accumulate_params.__get__(m) 130 | flops_extra_repr = flops_repr.__get__(m) 131 | if m.extra_repr != flops_extra_repr: 132 | m.original_extra_repr = m.extra_repr 133 | m.extra_repr = flops_extra_repr 134 | assert m.extra_repr != m.original_extra_repr 135 | 136 | def del_extra_repr(m): 137 | if hasattr(m, 'original_extra_repr'): 138 | m.extra_repr = m.original_extra_repr 139 | del m.original_extra_repr 140 | if hasattr(m, 'accumulate_flops'): 141 | del m.accumulate_flops 142 | 143 | model.apply(add_extra_repr) 144 | print(repr(model), file=ost) 145 | model.apply(del_extra_repr) 146 | 147 | 148 | def get_model_parameters_number(model): 149 | params_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 150 | return params_num 151 | 152 | 153 | def add_flops_counting_methods(net_main_module): 154 | # adding additional methods to the existing module object, 155 | # this is done this way so that each function has access to self object 156 | net_main_module.start_flops_count = start_flops_count.__get__(net_main_module) 157 | net_main_module.stop_flops_count = stop_flops_count.__get__(net_main_module) 158 | net_main_module.reset_flops_count = reset_flops_count.__get__(net_main_module) 159 | net_main_module.compute_average_flops_cost = compute_average_flops_cost.__get__( 160 | net_main_module) 161 | 162 | net_main_module.reset_flops_count() 163 | 164 | return net_main_module 165 | 166 | 167 | def compute_average_flops_cost(self): 168 | """ 169 | A method that will be available after add_flops_counting_methods() is called 170 | on a desired net object. 171 | 172 | Returns current mean flops consumption per image. 
173 | 174 | """ 175 | 176 | batches_count = self.__batch_counter__ 177 | flops_sum = 0 178 | params_sum = 0 179 | for module in self.modules(): 180 | if is_supported_instance(module): 181 | flops_sum += module.__flops__ 182 | params_sum = get_model_parameters_number(self) 183 | return flops_sum / batches_count, params_sum 184 | 185 | 186 | def start_flops_count(self, **kwargs): 187 | """ 188 | A method that will be available after add_flops_counting_methods() is called 189 | on a desired net object. 190 | 191 | Activates the computation of mean flops consumption per image. 192 | Call it before you run the network. 193 | 194 | """ 195 | add_batch_counter_hook_function(self) 196 | 197 | seen_types = set() 198 | 199 | def add_flops_counter_hook_function(module, ost, verbose, ignore_list): 200 | if type(module) in ignore_list: 201 | seen_types.add(type(module)) 202 | if is_supported_instance(module): 203 | module.__params__ = 0 204 | elif is_supported_instance(module): 205 | if hasattr(module, '__flops_handle__'): 206 | return 207 | if type(module) in CUSTOM_MODULES_MAPPING: 208 | handle = module.register_forward_hook( 209 | CUSTOM_MODULES_MAPPING[type(module)]) 210 | else: 211 | handle = module.register_forward_hook(MODULES_MAPPING[type(module)]) 212 | module.__flops_handle__ = handle 213 | seen_types.add(type(module)) 214 | else: 215 | if verbose and not type(module) in (nn.Sequential, nn.ModuleList) and \ 216 | not type(module) in seen_types: 217 | print('Warning: module ' + type(module).__name__ + 218 | ' is treated as a zero-op.', file=ost) 219 | seen_types.add(type(module)) 220 | 221 | self.apply(partial(add_flops_counter_hook_function, **kwargs)) 222 | 223 | 224 | def stop_flops_count(self): 225 | """ 226 | A method that will be available after add_flops_counting_methods() is called 227 | on a desired net object. 228 | 229 | Stops computing the mean flops consumption per image. 230 | Call whenever you want to pause the computation. 231 | 232 | """ 233 | remove_batch_counter_hook_function(self) 234 | self.apply(remove_flops_counter_hook_function) 235 | 236 | 237 | def reset_flops_count(self): 238 | """ 239 | A method that will be available after add_flops_counting_methods() is called 240 | on a desired net object. 241 | 242 | Resets statistics computed so far. 
243 | 244 | """ 245 | add_batch_counter_variables_or_reset(self) 246 | self.apply(add_flops_counter_variable_or_reset) 247 | 248 | 249 | # ---- Internal functions 250 | def empty_flops_counter_hook(module, input, output): 251 | module.__flops__ += 0 252 | 253 | 254 | def upsample_flops_counter_hook(module, input, output): 255 | output_size = output[0] 256 | batch_size = output_size.shape[0] 257 | output_elements_count = batch_size 258 | for val in output_size.shape[1:]: 259 | output_elements_count *= val 260 | module.__flops__ += int(output_elements_count) 261 | 262 | 263 | def relu_flops_counter_hook(module, input, output): 264 | active_elements_count = output.numel() 265 | module.__flops__ += int(active_elements_count) 266 | 267 | 268 | def linear_flops_counter_hook(module, input, output): 269 | input = input[0] 270 | # pytorch checks dimensions, so here we don't care much 271 | output_last_dim = output.shape[-1] 272 | bias_flops = output_last_dim if module.bias is not None else 0 273 | module.__flops__ += int(np.prod(input.shape) * output_last_dim + bias_flops) 274 | 275 | 276 | def pool_flops_counter_hook(module, input, output): 277 | input = input[0] 278 | module.__flops__ += int(np.prod(input.shape)) 279 | 280 | 281 | def bn_flops_counter_hook(module, input, output): 282 | input = input[0] 283 | 284 | batch_flops = np.prod(input.shape) 285 | if module.affine: 286 | batch_flops *= 2 287 | module.__flops__ += int(batch_flops) 288 | 289 | 290 | def conv_flops_counter_hook(conv_module, input, output): 291 | # Can have multiple inputs, getting the first one 292 | input = input[0] 293 | 294 | batch_size = input.shape[0] 295 | output_dims = list(output.shape[2:]) 296 | 297 | kernel_dims = list(conv_module.kernel_size) 298 | in_channels = conv_module.in_channels 299 | out_channels = conv_module.out_channels 300 | groups = conv_module.groups 301 | 302 | filters_per_channel = out_channels // groups 303 | conv_per_position_flops = int(np.prod(kernel_dims)) * \ 304 | in_channels * filters_per_channel 305 | 306 | active_elements_count = batch_size * int(np.prod(output_dims)) 307 | 308 | overall_conv_flops = conv_per_position_flops * active_elements_count 309 | 310 | bias_flops = 0 311 | 312 | if conv_module.bias is not None: 313 | 314 | bias_flops = out_channels * active_elements_count 315 | 316 | overall_flops = overall_conv_flops + bias_flops 317 | 318 | conv_module.__flops__ += int(overall_flops) 319 | 320 | 321 | def batch_counter_hook(module, input, output): 322 | batch_size = 1 323 | if len(input) > 0: 324 | # Can have multiple inputs, getting the first one 325 | input = input[0] 326 | batch_size = len(input) 327 | else: 328 | pass 329 | print('Warning! 
No positional inputs found for a module,' 330 | ' assuming batch size is 1.') 331 | module.__batch_counter__ += batch_size 332 | 333 | 334 | def rnn_flops(flops, rnn_module, w_ih, w_hh, input_size): 335 | # matrix matrix mult ih state and internal state 336 | flops += w_ih.shape[0]*w_ih.shape[1] 337 | # matrix matrix mult hh state and internal state 338 | flops += w_hh.shape[0]*w_hh.shape[1] 339 | if isinstance(rnn_module, (nn.RNN, nn.RNNCell)): 340 | # add both operations 341 | flops += rnn_module.hidden_size 342 | elif isinstance(rnn_module, (nn.GRU, nn.GRUCell)): 343 | # hadamard of r 344 | flops += rnn_module.hidden_size 345 | # adding operations from both states 346 | flops += rnn_module.hidden_size*3 347 | # last two hadamard product and add 348 | flops += rnn_module.hidden_size*3 349 | elif isinstance(rnn_module, (nn.LSTM, nn.LSTMCell)): 350 | # adding operations from both states 351 | flops += rnn_module.hidden_size*4 352 | # two hadamard product and add for C state 353 | flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size 354 | # final hadamard 355 | flops += rnn_module.hidden_size + rnn_module.hidden_size + rnn_module.hidden_size 356 | return flops 357 | 358 | 359 | def rnn_flops_counter_hook(rnn_module, input, output): 360 | """ 361 | Takes into account batch goes at first position, contrary 362 | to pytorch common rule (but actually it doesn't matter). 363 | IF sigmoid and tanh are made hard, only a comparison FLOPS should be accurate 364 | """ 365 | flops = 0 366 | # input is a tuple containing a sequence to process and (optionally) hidden state 367 | inp = input[0] 368 | batch_size = inp.shape[0] 369 | seq_length = inp.shape[1] 370 | num_layers = rnn_module.num_layers 371 | 372 | for i in range(num_layers): 373 | w_ih = rnn_module.__getattr__('weight_ih_l' + str(i)) 374 | w_hh = rnn_module.__getattr__('weight_hh_l' + str(i)) 375 | if i == 0: 376 | input_size = rnn_module.input_size 377 | else: 378 | input_size = rnn_module.hidden_size 379 | flops = rnn_flops(flops, rnn_module, w_ih, w_hh, input_size) 380 | if rnn_module.bias: 381 | b_ih = rnn_module.__getattr__('bias_ih_l' + str(i)) 382 | b_hh = rnn_module.__getattr__('bias_hh_l' + str(i)) 383 | flops += b_ih.shape[0] + b_hh.shape[0] 384 | 385 | flops *= batch_size 386 | flops *= seq_length 387 | if rnn_module.bidirectional: 388 | flops *= 2 389 | rnn_module.__flops__ += int(flops) 390 | 391 | 392 | def rnn_cell_flops_counter_hook(rnn_cell_module, input, output): 393 | flops = 0 394 | inp = input[0] 395 | batch_size = inp.shape[0] 396 | w_ih = rnn_cell_module.__getattr__('weight_ih') 397 | w_hh = rnn_cell_module.__getattr__('weight_hh') 398 | input_size = inp.shape[1] 399 | flops = rnn_flops(flops, rnn_cell_module, w_ih, w_hh, input_size) 400 | if rnn_cell_module.bias: 401 | b_ih = rnn_cell_module.__getattr__('bias_ih') 402 | b_hh = rnn_cell_module.__getattr__('bias_hh') 403 | flops += b_ih.shape[0] + b_hh.shape[0] 404 | 405 | flops *= batch_size 406 | rnn_cell_module.__flops__ += int(flops) 407 | 408 | 409 | def add_batch_counter_variables_or_reset(module): 410 | 411 | module.__batch_counter__ = 0 412 | 413 | 414 | def add_batch_counter_hook_function(module): 415 | if hasattr(module, '__batch_counter_handle__'): 416 | return 417 | 418 | handle = module.register_forward_hook(batch_counter_hook) 419 | module.__batch_counter_handle__ = handle 420 | 421 | 422 | def remove_batch_counter_hook_function(module): 423 | if hasattr(module, '__batch_counter_handle__'): 424 | 
module.__batch_counter_handle__.remove() 425 | del module.__batch_counter_handle__ 426 | 427 | 428 | def add_flops_counter_variable_or_reset(module): 429 | if is_supported_instance(module): 430 | # if hasattr(module, '__flops__') or hasattr(module, '__params__'): 431 | # print('Warning: variables __flops__ or __params__ are already ' 432 | # 'defined for the module' + type(module).__name__ + 433 | # ' ptflops can affect your code!') 434 | module.__flops__ = 0 435 | module.__params__ = get_model_parameters_number(module) 436 | 437 | 438 | CUSTOM_MODULES_MAPPING = {} 439 | 440 | MODULES_MAPPING = { 441 | # convolutions 442 | nn.Conv1d: conv_flops_counter_hook, 443 | nn.Conv2d: conv_flops_counter_hook, 444 | nn.Conv3d: conv_flops_counter_hook, 445 | # activations 446 | nn.ReLU: relu_flops_counter_hook, 447 | nn.PReLU: relu_flops_counter_hook, 448 | nn.ELU: relu_flops_counter_hook, 449 | nn.LeakyReLU: relu_flops_counter_hook, 450 | nn.ReLU6: relu_flops_counter_hook, 451 | # poolings 452 | nn.MaxPool1d: pool_flops_counter_hook, 453 | nn.AvgPool1d: pool_flops_counter_hook, 454 | nn.AvgPool2d: pool_flops_counter_hook, 455 | nn.MaxPool2d: pool_flops_counter_hook, 456 | nn.MaxPool3d: pool_flops_counter_hook, 457 | nn.AvgPool3d: pool_flops_counter_hook, 458 | nn.AdaptiveMaxPool1d: pool_flops_counter_hook, 459 | nn.AdaptiveAvgPool1d: pool_flops_counter_hook, 460 | nn.AdaptiveMaxPool2d: pool_flops_counter_hook, 461 | nn.AdaptiveAvgPool2d: pool_flops_counter_hook, 462 | nn.AdaptiveMaxPool3d: pool_flops_counter_hook, 463 | nn.AdaptiveAvgPool3d: pool_flops_counter_hook, 464 | # BNs 465 | nn.BatchNorm1d: bn_flops_counter_hook, 466 | nn.BatchNorm2d: bn_flops_counter_hook, 467 | nn.BatchNorm3d: bn_flops_counter_hook, 468 | # FC 469 | nn.Linear: linear_flops_counter_hook, 470 | # Upscale 471 | nn.Upsample: upsample_flops_counter_hook, 472 | # Deconvolution 473 | nn.ConvTranspose1d: conv_flops_counter_hook, 474 | nn.ConvTranspose2d: conv_flops_counter_hook, 475 | nn.ConvTranspose3d: conv_flops_counter_hook, 476 | # RNN 477 | nn.RNN: rnn_flops_counter_hook, 478 | nn.GRU: rnn_flops_counter_hook, 479 | nn.LSTM: rnn_flops_counter_hook, 480 | nn.RNNCell: rnn_cell_flops_counter_hook, 481 | nn.LSTMCell: rnn_cell_flops_counter_hook, 482 | nn.GRUCell: rnn_cell_flops_counter_hook 483 | } 484 | 485 | 486 | def is_supported_instance(module): 487 | if type(module) in MODULES_MAPPING or type(module) in CUSTOM_MODULES_MAPPING: 488 | return True 489 | return False 490 | 491 | 492 | def remove_flops_counter_hook_function(module): 493 | if is_supported_instance(module): 494 | if hasattr(module, '__flops_handle__'): 495 | module.__flops_handle__.remove() 496 | del module.__flops_handle__ 497 | -------------------------------------------------------------------------------- /torch_pruning/prune/__init__.py: -------------------------------------------------------------------------------- 1 | from .structured import * 2 | from .unstructured import * -------------------------------------------------------------------------------- /torch_pruning/prune/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/torch_pruning/prune/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /torch_pruning/prune/__pycache__/__init__.cpython-38.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/torch_pruning/prune/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /torch_pruning/prune/__pycache__/structured.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/torch_pruning/prune/__pycache__/structured.cpython-36.pyc -------------------------------------------------------------------------------- /torch_pruning/prune/__pycache__/structured.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/torch_pruning/prune/__pycache__/structured.cpython-38.pyc -------------------------------------------------------------------------------- /torch_pruning/prune/__pycache__/unstructured.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/torch_pruning/prune/__pycache__/unstructured.cpython-36.pyc -------------------------------------------------------------------------------- /torch_pruning/prune/__pycache__/unstructured.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yanggui19891007/Pytorch-Auto-Slim-Tools/67f02ea0e04f23bb9c39586aebbfa19aa4ac007b/torch_pruning/prune/__pycache__/unstructured.cpython-38.pyc -------------------------------------------------------------------------------- /torch_pruning/prune/structured.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from copy import deepcopy 4 | from functools import reduce 5 | from operator import mul 6 | 7 | __all__=['prune_conv', 'prune_related_conv', 'prune_linear', 'prune_related_linear', 'prune_batchnorm', 'prune_prelu', 'prune_group_conv'] 8 | 9 | def prune_group_conv(layer: nn.modules.conv._ConvNd, idxs: list, inplace: bool=True, dry_run: bool=False): 10 | """Prune `filters` for the convolutional layer, e.g. [256 x 128 x 3 x 3] => [192 x 128 x 3 x 3] 11 | 12 | Args: 13 | - layer: a convolution layer. 14 | - idxs: pruning index. 
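    Example (an illustrative sketch; assumes a depthwise conv, i.e. in_channels == groups == out_channels):
        conv = nn.Conv2d(8, 8, 3, groups=8)
        conv, num_pruned = prune_group_conv(conv, idxs=[0, 1])  # -> 6-channel depthwise conv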
15 | """ 16 | 17 | if layer.groups>1: 18 | assert layer.groups==layer.in_channels and layer.groups==layer.out_channels, "only group conv with in_channel==groups==out_channels is supported" 19 | 20 | 21 | idxs = list(set(idxs)) 22 | num_pruned = len(idxs) * reduce(mul, layer.weight.shape[1:]) + (len(idxs) if layer.bias is not None else 0) 23 | if dry_run: 24 | return layer, num_pruned 25 | if not inplace: 26 | layer = deepcopy(layer) 27 | keep_idxs = [idx for idx in range(layer.out_channels) if idx not in idxs] 28 | layer.out_channels = layer.out_channels-len(idxs) 29 | layer.in_channels = layer.in_channels-len(idxs) 30 | 31 | layer.groups = layer.groups-len(idxs) 32 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs]) 33 | if layer.bias is not None: 34 | layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs]) 35 | 36 | return layer, num_pruned 37 | 38 | def prune_conv(layer: nn.modules.conv._ConvNd, idxs: list, inplace: bool=True, dry_run: bool=False): 39 | """Prune `filters` for the convolutional layer, e.g. [256 x 128 x 3 x 3] => [192 x 128 x 3 x 3] 40 | 41 | Args: 42 | - layer: a convolution layer. 43 | - idxs: pruning index. 44 | """ 45 | idxs = list(set(idxs)) 46 | num_pruned = len(idxs) * reduce(mul, layer.weight.shape[1:]) + (len(idxs) if layer.bias is not None else 0) 47 | if dry_run: 48 | return layer, num_pruned 49 | 50 | if not inplace: 51 | layer = deepcopy(layer) 52 | 53 | keep_idxs = [idx for idx in range(layer.out_channels) if idx not in idxs] 54 | layer.out_channels = layer.out_channels-len(idxs) 55 | if isinstance(layer,(nn.ConvTranspose2d,nn.ConvTranspose3d)): 56 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[:, keep_idxs]) 57 | else: 58 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs]) 59 | if layer.bias is not None: 60 | layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs]) 61 | return layer, num_pruned 62 | 63 | def prune_related_conv(layer: nn.modules.conv._ConvNd, idxs: list, inplace: bool=True, dry_run: bool=False): 64 | """Prune `kernels` for the related (affected) convolutional layer, e.g. [256 x 128 x 3 x 3] => [256 x 96 x 3 x 3] 65 | 66 | Args: 67 | layer: a convolutional layer. 68 | idxs: pruning index. 69 | """ 70 | idxs = list(set(idxs)) 71 | num_pruned = len(idxs) * layer.weight.shape[0] * reduce(mul ,layer.weight.shape[2:]) 72 | if dry_run: 73 | return layer, num_pruned 74 | if not inplace: 75 | layer = deepcopy(layer) 76 | 77 | 78 | keep_idxs = [i for i in range(layer.in_channels) if i not in idxs] 79 | 80 | layer.in_channels = layer.in_channels - len(idxs) 81 | 82 | if isinstance(layer,(nn.ConvTranspose2d,nn.ConvTranspose3d)): 83 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs,:]) 84 | else: 85 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[:, keep_idxs]) 86 | # no bias pruning because it does not change the output size 87 | return layer, num_pruned 88 | 89 | def prune_linear(layer: nn.modules.linear.Linear, idxs: list, inplace: list=True, dry_run: list=False): 90 | """Prune neurons for the fully-connected layer, e.g. [256 x 128] => [192 x 128] 91 | 92 | Args: 93 | layer: a fully-connected layer. 94 | idxs: pruning index. 
95 | """ 96 | num_pruned = len(idxs)*layer.weight.shape[1] + (len(idxs) if layer.bias is not None else 0) 97 | if dry_run: 98 | return layer, num_pruned 99 | 100 | if not inplace: 101 | layer = deepcopy(layer) 102 | keep_idxs = [i for i in range(layer.out_features) if i not in idxs] 103 | layer.out_features = layer.out_features-len(idxs) 104 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs]) 105 | if layer.bias is not None: 106 | layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs]) 107 | return layer, num_pruned 108 | 109 | def prune_related_linear(layer: nn.modules.linear.Linear, idxs: list, inplace: list=True, dry_run: list=False): 110 | """Prune weights for the related (affected) fully-connected layer, e.g. [256 x 128] => [256 x 96] 111 | 112 | Args: 113 | layer: a fully-connected layer. 114 | idxs: pruning index. 115 | """ 116 | num_pruned = len(idxs) * layer.weight.shape[0] 117 | if dry_run: 118 | return layer, num_pruned 119 | 120 | if not inplace: 121 | layer = deepcopy(layer) 122 | keep_idxs = [i for i in range(layer.in_features) if i not in idxs] 123 | layer.in_features = layer.in_features-len(idxs) 124 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[:, keep_idxs]) 125 | return layer, num_pruned 126 | 127 | def prune_batchnorm(layer: nn.modules.batchnorm._BatchNorm, idxs: list, inplace: bool=True, dry_run: bool=False ): 128 | """Prune batch normalization layers, e.g. [128] => [64] 129 | 130 | Args: 131 | layer: a batch normalization layer. 132 | idxs: pruning index. 133 | """ 134 | 135 | num_pruned = len(idxs)* ( 2 if layer.affine else 1) 136 | if dry_run: 137 | return layer, num_pruned 138 | 139 | if not inplace: 140 | layer = deepcopy(layer) 141 | 142 | keep_idxs = [i for i in range(layer.num_features) if i not in idxs] 143 | layer.num_features = layer.num_features-len(idxs) 144 | layer.running_mean = layer.running_mean.data.clone()[keep_idxs] 145 | layer.running_var = layer.running_var.data.clone()[keep_idxs] 146 | if layer.affine: 147 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs]) 148 | layer.bias = torch.nn.Parameter(layer.bias.data.clone()[keep_idxs]) 149 | return layer, num_pruned 150 | 151 | def prune_prelu(layer: nn.PReLU, idxs: list, inplace: bool=True, dry_run: bool=False): 152 | """Prune PReLU layers, e.g. [128] => [64] or [1] => [1] (no pruning if prelu has only 1 parameter) 153 | 154 | Args: 155 | layer: a PReLU layer. 156 | idxs: pruning index. 
157 | """ 158 | num_pruned = 0 if layer.num_parameters==1 else len(idxs) 159 | if dry_run: 160 | return layer, num_pruned 161 | if not inplace: 162 | layer = deepcopy(layer) 163 | if layer.num_parameters==1: return layer, num_pruned 164 | keep_idxs = [i for i in range(layer.num_parameters) if i not in idxs] 165 | layer.num_parameters = layer.num_parameters-len(idxs) 166 | layer.weight = torch.nn.Parameter(layer.weight.data.clone()[keep_idxs]) 167 | return layer, num_pruned 168 | 169 | 170 | -------------------------------------------------------------------------------- /torch_pruning/prune/unstructured.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from copy import deepcopy 4 | 5 | __all__=['mask_weight', 'mask_bias'] 6 | 7 | def _mask_weight_hook(module, input): 8 | if hasattr(module, 'weight_mask'): 9 | module.weight.data *= module.weight_mask 10 | 11 | def _mask_bias_hook(module, input): 12 | if module.bias is not None and hasattr(module, 'bias_mask'): 13 | module.bias.data *= module.bias_mask 14 | 15 | def mask_weight(layer, mask, inplace=True): 16 | """Unstructed pruning for convolution layer 17 | 18 | Args: 19 | layer: a convolution layer. 20 | mask: 0-1 mask. 21 | """ 22 | if not inplace: 23 | layer = deepcopy(layer) 24 | if mask.shape != layer.weight.shape: 25 | return layer 26 | mask = torch.tensor( mask, dtype=layer.weight.dtype, device=layer.weight.device, requires_grad=False ) 27 | if hasattr(layer, 'weight_mask'): 28 | mask = mask + layer.weight_mask 29 | mask[mask>0]=1 30 | layer.weight_mask = mask 31 | else: 32 | layer.register_buffer( 'weight_mask', mask ) 33 | 34 | layer.register_forward_pre_hook( _mask_weight_hook ) 35 | return layer 36 | 37 | def mask_bias(layer, mask, inplace=True): 38 | """Unstructed pruning for convolution layer 39 | 40 | Args: 41 | layer: a convolution layer. 42 | mask: 0-1 mask. 
43 | """ 44 | if not inplace: 45 | layer = deepcopy(layer) 46 | if layer.bias is None or mask.shape != layer.bias.shape: 47 | return layer 48 | 49 | mask = torch.tensor( mask, dtype=layer.weight.dtype, device=layer.weight.device, requires_grad=False ) 50 | if hasattr(layer, 'bias_mask'): 51 | mask = mask + layer.bias_mask 52 | mask[mask>0]=1 53 | layer.bias_mask = mask 54 | else: 55 | layer.register_buffer( 'bias_mask', mask ) 56 | layer.register_forward_pre_hook( _mask_bias_hook ) 57 | return layer 58 | -------------------------------------------------------------------------------- /torch_pruning/resnet_small.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=dilation, groups=groups, bias=False, dilation=dilation) 9 | 10 | 11 | def conv1x1(in_planes, out_planes, stride=1): 12 | """1x1 convolution""" 13 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 14 | 15 | 16 | class BasicBlock(nn.Module): 17 | expansion = 1 18 | 19 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 20 | base_width=64, dilation=1, norm_layer=None): 21 | super(BasicBlock, self).__init__() 22 | if norm_layer is None: 23 | norm_layer = nn.BatchNorm2d 24 | if groups != 1 or base_width != 64: 25 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 26 | if dilation > 1: 27 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 28 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 29 | self.conv1 = conv3x3(inplanes, planes, stride) 30 | self.bn1 = norm_layer(planes) 31 | self.relu = nn.ReLU(inplace=True) 32 | self.conv2 = conv3x3(planes, planes) 33 | self.bn2 = norm_layer(planes) 34 | self.downsample = downsample 35 | self.stride = stride 36 | 37 | def forward(self, x): 38 | identity = x 39 | 40 | out = self.conv1(x) 41 | out = self.bn1(out) 42 | out = self.relu(out) 43 | 44 | out = self.conv2(out) 45 | out = self.bn2(out) 46 | 47 | if self.downsample is not None: 48 | identity = self.downsample(x) 49 | 50 | out += identity 51 | out = self.relu(out) 52 | 53 | return out 54 | 55 | 56 | class Bottleneck(nn.Module): 57 | # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2) 58 | # while original implementation places the stride at the first 1x1 convolution(self.conv1) 59 | # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385. 60 | # This variant is also known as ResNet V1.5 and improves accuracy according to 61 | # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch. 
62 | 63 | expansion = 4 64 | 65 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 66 | base_width=64, dilation=1, norm_layer=None): 67 | super(Bottleneck, self).__init__() 68 | if norm_layer is None: 69 | norm_layer = nn.BatchNorm2d 70 | width = int(planes * (base_width / 64.)) * groups 71 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 72 | self.conv1 = conv1x1(inplanes, width) 73 | self.bn1 = norm_layer(width) 74 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 75 | self.bn2 = norm_layer(width) 76 | self.conv3 = conv1x1(width, planes * self.expansion) 77 | self.bn3 = norm_layer(planes * self.expansion) 78 | self.relu = nn.ReLU(inplace=True) 79 | self.downsample = downsample 80 | self.stride = stride 81 | 82 | def forward(self, x): 83 | identity = x 84 | 85 | out = self.conv1(x) 86 | out = self.bn1(out) 87 | out = self.relu(out) 88 | 89 | out = self.conv2(out) 90 | out = self.bn2(out) 91 | out = self.relu(out) 92 | 93 | out = self.conv3(out) 94 | out = self.bn3(out) 95 | 96 | if self.downsample is not None: 97 | identity = self.downsample(x) 98 | 99 | out += identity 100 | out = self.relu(out) 101 | 102 | return out 103 | 104 | 105 | class ResNet(nn.Module): 106 | 107 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 108 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 109 | norm_layer=None): 110 | super(ResNet, self).__init__() 111 | if norm_layer is None: 112 | norm_layer = nn.BatchNorm2d 113 | self._norm_layer = norm_layer 114 | 115 | self.inplanes = 64 116 | self.dilation = 1 117 | if replace_stride_with_dilation is None: 118 | # each element in the tuple indicates if we should replace 119 | # the 2x2 stride with a dilated convolution instead 120 | replace_stride_with_dilation = [False, False, False] 121 | if len(replace_stride_with_dilation) != 3: 122 | raise ValueError("replace_stride_with_dilation should be None " 123 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 124 | self.groups = groups 125 | self.base_width = width_per_group 126 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 127 | bias=False) 128 | self.bn1 = norm_layer(self.inplanes) 129 | self.relu = nn.ReLU(inplace=True) 130 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 131 | self.layer1 = self._make_layer(block, 64, layers[0]) 132 | # self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 133 | # self.fc = nn.Linear(64 * block.expansion, num_classes) 134 | self.final_conv=nn.Conv2d(64, 1024, kernel_size=3, stride=1, padding=1, 135 | bias=False) 136 | 137 | for m in self.modules(): 138 | if isinstance(m, nn.Conv2d): 139 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 140 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 141 | nn.init.constant_(m.weight, 1) 142 | nn.init.constant_(m.bias, 0) 143 | 144 | # Zero-initialize the last BN in each residual branch, 145 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
146 |         # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
147 |         if zero_init_residual:
148 |             for m in self.modules():
149 |                 if isinstance(m, Bottleneck):
150 |                     nn.init.constant_(m.bn3.weight, 0)
151 |                 elif isinstance(m, BasicBlock):
152 |                     nn.init.constant_(m.bn2.weight, 0)
153 | 
154 |     def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
155 |         norm_layer = self._norm_layer
156 |         downsample = None
157 |         previous_dilation = self.dilation
158 |         if dilate:
159 |             self.dilation *= stride
160 |             stride = 1
161 |         if stride != 1 or self.inplanes != planes * block.expansion:
162 |             downsample = nn.Sequential(
163 |                 conv1x1(self.inplanes, planes * block.expansion, stride),
164 |                 norm_layer(planes * block.expansion),
165 |             )
166 | 
167 |         layers = []
168 |         layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
169 |                             self.base_width, previous_dilation, norm_layer))
170 |         self.inplanes = planes * block.expansion
171 |         for _ in range(1, blocks):
172 |             layers.append(block(self.inplanes, planes, groups=self.groups,
173 |                                 base_width=self.base_width, dilation=self.dilation,
174 |                                 norm_layer=norm_layer))
175 | 
176 |         return nn.Sequential(*layers)
177 | 
178 |     def _forward_impl(self, x):
179 |         # See note [TorchScript super()]
180 |         x = self.conv1(x)
181 |         x = self.bn1(x)
182 |         x = self.relu(x)
183 |         x = self.maxpool(x)
184 | 
185 |         x = self.layer1(x)
186 |         # x = self.avgpool(x)
187 |         # x = torch.flatten(x, 1)
188 |         # x = self.fc(x)
189 |         x = self.final_conv(x)
190 |         return x
191 | 
192 |     def forward(self, x):
193 |         return self._forward_impl(x)
194 | 
195 | def _resnet(arch, block, layers, pretrained, progress, **kwargs):
196 |     model = ResNet(block, layers, **kwargs)
197 |     return model
198 | 
199 | def resnet_small(pretrained=False, progress=True, **kwargs):
200 |     r"""A small ResNet variant (a single BasicBlock stage plus a final 3x3 conv), adapted from the ResNet-18 model of
201 |     `"Deep Residual Learning for Image Recognition" <https://arxiv.org/abs/1512.03385>`_
202 | 
203 |     Args:
204 |         pretrained (bool): kept for API compatibility; no pretrained weights are loaded here
205 |         progress (bool): kept for API compatibility with torchvision's resnet constructors
206 |     """
207 |     return _resnet('resnet_small', BasicBlock, [1], pretrained, progress,
208 |                    **kwargs)
209 | 
--------------------------------------------------------------------------------
/torch_pruning/sensitivity_analysis.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) Microsoft Corporation.
2 | # Licensed under the MIT license.
3 | 
4 | import copy
5 | import csv
6 | import logging
7 | from collections import OrderedDict
8 | 
9 | import numpy as np
10 | import torch.nn as nn
11 | 
12 | # FIXME: I don't know where "utils" should be
13 | SUPPORTED_OP_NAME = ['Conv2d', 'Conv1d']
14 | SUPPORTED_OP_TYPE = [getattr(nn, name) for name in SUPPORTED_OP_NAME]
15 | 
16 | logger = logging.getLogger('Sensitivity_Analysis')
17 | logger.setLevel(logging.INFO)
18 | 
19 | 
20 | class SensitivityAnalysis:
21 |     def __init__(self, model, val_func, sparsities=None, prune_type='l1', early_stop_mode=None, early_stop_value=None):
22 |         """
23 |         Perform sensitivity analysis for this model.
24 |         Parameters
25 |         ----------
26 |         model : torch.nn.Module
27 |             the model to perform sensitivity analysis on
28 |         val_func : function
29 |             validation function for the model. Since different models
30 |             may need different datasets/criteria, the user needs to
31 |             cover this part themselves.
32 |             In the val_func, the model should be tested on the validation dataset,
33 |             and the validation accuracy/loss should be returned as the output of val_func.
34 |             There are no restrictions on the input parameters of the val_func.
35 |             Users can use the val_args and val_kwargs parameters of analysis()
36 |             to pass all the parameters that val_func needs.
37 |         sparsities : list
38 |             The sparsity list provided by users. This parameter is set when the user
39 |             only wants to test some specific sparsities. In the sparsity list, each element
40 |             is a sparsity value which means how much weight the pruner should prune. Taking
41 |             [0.25, 0.5, 0.75] as an example, SensitivityAnalysis will gradually prune
42 |             25%, 50% and 75% of the weights for each layer.
43 |         prune_type : str
44 |             The pruner type used to prune the conv layers; the default is 'l1',
45 |             and 'l2' and 'fine-grained' are also supported.
46 |         early_stop_mode : str
47 |             If this flag is set, the sensitivity analysis
48 |             for a conv layer stops early once the validation metric
49 |             (for example, accuracy/loss) has already met the threshold. We
50 |             support four different early stop modes: minimize, maximize, dropped,
51 |             raised. The default value is None, which means the analysis won't stop
52 |             until all given sparsities are tested. This option should be used
53 |             together with early_stop_value.
54 | 
55 |             minimize: The analysis stops when the validation metric returned by the val_func
56 |             is lower than early_stop_value.
57 |             maximize: The analysis stops when the validation metric returned by the val_func
58 |             is larger than early_stop_value.
59 |             dropped: The analysis stops when the validation metric has dropped by early_stop_value.
60 |             raised: The analysis stops when the validation metric has risen by early_stop_value.
61 |         early_stop_value : float
62 |             This value is used as the threshold for the different early stop modes.
63 |             It is effective only when early_stop_mode is set.
64 | 
65 |         """
66 |         from nni.algorithms.compression.pytorch.pruning.constants_pruner import PRUNER_DICT
67 | 
68 |         self.model = model
69 |         self.val_func = val_func
70 |         self.target_layer = OrderedDict()
71 |         self.ori_state_dict = copy.deepcopy(self.model.state_dict())
72 |         # (self.target_layer is populated by model_parse below)
73 |         self.sensitivities = {}
74 |         if sparsities is not None:
75 |             self.sparsities = sorted(sparsities)
76 |         else:
77 |             self.sparsities = np.arange(0.1, 1.0, 0.1)
78 |         self.sparsities = [np.round(x, 2) for x in self.sparsities]
79 |         self.Pruner = PRUNER_DICT[prune_type]
80 |         self.early_stop_mode = early_stop_mode
81 |         self.early_stop_value = early_stop_value
82 |         self.ori_metric = None  # original validation metric for the model
83 |         # already_pruned is for the iterative sensitivity analysis
84 |         # For example, sensitivity_pruner iteratively prunes the target
85 |         # model according to the sensitivity. After each round of
86 |         # pruning, the sensitivity_pruner tests the new sensitivity
87 |         # for each layer
88 |         self.already_pruned = {}
89 |         self.model_parse()
90 | 
91 |     @property
92 |     def layers_count(self):
93 |         return len(self.target_layer)
94 | 
95 |     def model_parse(self):
96 |         for name, submodel in self.model.named_modules():
97 |             for op_type in SUPPORTED_OP_TYPE:
98 |                 if isinstance(submodel, op_type):
99 |                     self.target_layer[name] = submodel
100 |                     self.already_pruned[name] = 0
101 | 
102 |     def _need_to_stop(self, ori_metric, cur_metric):
103 |         """
104 |         Judge whether a stop condition (early_stop, min_threshold,
105 |         max_threshold) is met.
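        For example (illustrative), with early_stop_mode='dropped' and early_stop_value=0.05,
        the analysis of a layer stops as soon as cur_metric falls more than 0.05 below ori_metric.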
106 |         Parameters
107 |         ----------
108 |         ori_metric : float
109 |             original validation metric
110 |         cur_metric : float
111 |             current validation metric
112 | 
113 |         Returns
114 |         -------
115 |         stop : bool
116 |             whether to stop the sensitivity analysis
117 |         """
118 |         if self.early_stop_mode is None:
119 |             # early stop mode is not enabled
120 |             return False
121 |         assert self.early_stop_value is not None
122 |         if self.early_stop_mode == 'minimize':
123 |             if cur_metric < self.early_stop_value:
124 |                 return True
125 |         elif self.early_stop_mode == 'maximize':
126 |             if cur_metric > self.early_stop_value:
127 |                 return True
128 |         elif self.early_stop_mode == 'dropped':
129 |             if cur_metric < ori_metric - self.early_stop_value:
130 |                 return True
131 |         elif self.early_stop_mode == 'raised':
132 |             if cur_metric > ori_metric + self.early_stop_value:
133 |                 return True
134 |         return False
135 | 
136 |     def analysis(self, val_args=None, val_kwargs=None, specified_layers=None):
137 |         """
138 |         This function analyzes the sensitivity to pruning of
139 |         each conv layer in the target model.
140 |         If specified_layers is not set, all the conv
141 |         layers are analyzed by default. Users can restrict the analysis
142 |         to certain layers, or parallelize the analysis process, through
143 |         the specified_layers parameter.
144 | 
145 |         Parameters
146 |         ----------
147 |         val_args : list
148 |             args for the val_func
149 |         val_kwargs : dict
150 |             kwargs for the val_func
151 |         specified_layers : list
152 |             list of layer names to analyze sensitivity for.
153 |             If this variable is set, then only the conv
154 |             layers specified in the list are analyzed.
155 |             Users can also use this option to parallelize
156 |             the sensitivity analysis easily.
157 |         Returns
158 |         -------
159 |         sensitivities : dict
160 |             dict object that stores the trajectory of the
161 |             accuracy/loss as the prune ratio changes
162 |         """
163 |         if val_args is None:
164 |             val_args = []
165 |         if val_kwargs is None:
166 |             val_kwargs = {}
167 |         # Get the original validation metric (accuracy/loss) before pruning;
168 |         # this is the baseline for the analysis.
169 |         self.ori_metric = self.val_func(*val_args, **val_kwargs)
170 |         namelist = list(self.target_layer.keys())
171 |         if specified_layers is not None:
172 |             # only analyze the specified conv layers
173 |             namelist = list(filter(lambda x: x in specified_layers, namelist))
174 |         for name in namelist:
175 |             self.sensitivities[name] = {}
176 |             for sparsity in self.sparsities:
177 |                 # here the sparsity is relative to the
178 |                 # weights that remain after previous pruning rounds
179 |                 # Calculate the actual prune ratio based on the already pruned ratio
180 |                 real_sparsity = (
181 |                     1.0 - self.already_pruned[name]) * sparsity + self.already_pruned[name]
182 |                 # TODO In the current L1/L2 Filter Pruner, 'op_types' is still necessary
183 |                 # I think the L1/L2 Pruner should specify the op_types automatically
184 |                 # according to the op_names
185 |                 cfg = [{'sparsity': real_sparsity, 'op_names': [
186 |                     name], 'op_types': ['Conv2d']}]
187 |                 pruner = self.Pruner(self.model, cfg)
188 |                 pruner.compress()
189 |                 val_metric = self.val_func(*val_args, **val_kwargs)
190 |                 logger.info('Layer: %s Sparsity: %.2f Validation Metric: %.4f',
191 |                             name, real_sparsity, val_metric)
192 | 
193 |                 self.sensitivities[name][sparsity] = val_metric
194 |                 pruner._unwrap_model()
195 |                 del pruner
196 |                 # check if the current metric meets the stop condition
197 |                 if self._need_to_stop(self.ori_metric, val_metric):
198 |                     break
199 | 
200 |             # reset the weights pruned by the pruner; because the
201 |             # input sparsities are sorted, we do not need to reset the
202 |             # weights of a layer when the sparsity changes; instead,
203 |             # we only need to reset the weights when the pruned layer changes.
204 |             self.model.load_state_dict(self.ori_state_dict)
205 | 
206 |         return self.sensitivities
207 | 
208 |     def export(self, filepath):
209 |         """
210 |         Export the results of the sensitivity analysis
211 |         to a csv file. The first line of the csv file describes the content
212 |         structure: it consists of 'layername' followed by the sparsity
213 |         list. Each line below records the validation metric returned by val_func
214 |         when the corresponding layer is pruned at different sparsities. Note that, due to the early_stop
215 |         option, some layers may not have metrics for all sparsities.
216 | 
217 |         layername, 0.25, 0.5, 0.75
218 |         conv1, 0.6, 0.55
219 |         conv2, 0.61, 0.57, 0.56
220 | 
221 |         Parameters
222 |         ----------
223 |         filepath : str
224 |             Path of the output file
225 |         """
226 |         str_sparsities = [str(x) for x in self.sparsities]
227 |         header = ['layername'] + str_sparsities
228 |         with open(filepath, 'w') as csvf:
229 |             csv_w = csv.writer(csvf)
230 |             csv_w.writerow(header)
231 |             for layername in self.sensitivities:
232 |                 row = []
233 |                 row.append(layername)
234 |                 for sparsity in sorted(self.sensitivities[layername].keys()):
235 |                     row.append(self.sensitivities[layername][sparsity])
236 |                 csv_w.writerow(row)
237 | 
238 |     def update_already_pruned(self, layername, ratio):
239 |         """
240 |         Set the already pruned ratio for the target layer.
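        Example (illustrative): analyzer.update_already_pruned('layer1.0.conv1', 0.5)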
241 | """ 242 | self.already_pruned[layername] = ratio 243 | 244 | def load_state_dict(self, state_dict): 245 | """ 246 | Update the weight of the model 247 | """ 248 | self.ori_state_dict = copy.deepcopy(state_dict) 249 | self.model.load_state_dict(self.ori_state_dict) 250 | -------------------------------------------------------------------------------- /torch_pruning/utils.py: -------------------------------------------------------------------------------- 1 | from .dependency import TORCH_CONV, TORCH_BATCHNORM, TORCH_PRELU, TORCH_LINEAR 2 | 3 | def count_prunable_params(module): 4 | if isinstance( module, ( TORCH_CONV, TORCH_LINEAR) ): 5 | num_params = module.weight.numel() 6 | if module.bias is not None: 7 | num_params += module.bias.numel() 8 | return num_params 9 | elif isinstance( module, TORCH_BATCHNORM ): 10 | num_params = module.running_mean.numel() + module.running_var.numel() 11 | if module.affine: 12 | num_params+= module.weight.numel() + module.bias.numel() 13 | return num_params 14 | elif isinstance( module, TORCH_PRELU ): 15 | if len( module.weight )==1: 16 | return 0 17 | else: 18 | return module.weight.numel 19 | else: 20 | return 0 21 | 22 | def count_prunable_channels(module): 23 | if isinstance( module, TORCH_CONV ): 24 | return module.weight.shape[0] 25 | elif isinstance( module, TORCH_LINEAR ): 26 | return module.out_features 27 | elif isinstance( module, TORCH_BATCHNORM ): 28 | return module.num_features 29 | elif isinstance( module, TORCH_PRELU ): 30 | if len( module.weight )==1: 31 | return 0 32 | else: 33 | return len(module.weight) 34 | else: 35 | return 0 36 | 37 | def count_params(module): 38 | return sum([ p.numel() for p in module.parameters() ]) 39 | -------------------------------------------------------------------------------- /torch_quanting/__init__.py: -------------------------------------------------------------------------------- 1 | from .autoquant import AutoQuant -------------------------------------------------------------------------------- /torch_quanting/autoquant.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import copy 4 | from .quantizer import Quantizer 5 | 6 | __all__ = ['AutoQuant'] 7 | 8 | class AutoQuant(Quantizer): 9 | """quantize weight to 8 bits 10 | """ 11 | 12 | def __init__(self, model, config_list): 13 | super().__init__(model, config_list) 14 | self.layer_scale = {} 15 | 16 | def quantize_weight(self, wrapper, **kwargs): 17 | weight = copy.deepcopy(wrapper.module.old_weight.data) 18 | new_scale = weight.abs().max() / 127 19 | scale = max(self.layer_scale.get(wrapper.name, 0), new_scale) 20 | self.layer_scale[wrapper.name] = scale 21 | orig_type = weight.type() # TODO: user layer 22 | weight = weight.div(scale).type(torch.int8).type(orig_type).mul(scale) 23 | wrapper.module.weight = weight 24 | return weight 25 | 26 | -------------------------------------------------------------------------------- /torch_quanting/quantizer.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class LayerInfo: 5 | def __init__(self, name, module): 6 | self.module = module 7 | self.name = name 8 | self.type = type(module).__name__ 9 | 10 | 11 | def _setattr(model, name, module): 12 | name_list = name.split(".") 13 | for name in name_list[:-1]: 14 | model = getattr(model, name) 15 | setattr(model, name_list[-1], module) 16 | 17 | 18 | def _check_weight(module): 19 | try: 20 | return 
isinstance(module.weight.data, torch.Tensor) 21 | except AttributeError: 22 | return False 23 | 24 | 25 | class QuantizerModuleWrapper(torch.nn.Module): 26 | def __init__(self, module, module_name, module_type, config, quantizer): 27 | 28 | super().__init__() 29 | # origin layer information 30 | self.module = module 31 | self.name = module_name 32 | self.type = module_type 33 | # config and pruner 34 | self.config = config 35 | self.quantizer = quantizer 36 | 37 | if 'weight' in config['quant_types']: 38 | if not _check_weight(self.module): 39 | _logger.warning( 40 | 'Module %s does not have parameter "weight"', self.name) 41 | else: 42 | self.module.register_parameter( 43 | 'old_weight', torch.nn.Parameter(self.module.weight)) 44 | delattr(self.module, 'weight') 45 | self.module.register_buffer('weight', self.module.old_weight) 46 | 47 | def forward(self, *inputs): 48 | if 'input' in self.config['quant_types']: 49 | inputs = self.quantizer.quant_grad.apply( 50 | inputs, 51 | QuantType.QUANT_INPUT, 52 | self) 53 | 54 | if 'weight' in self.config['quant_types'] and _check_weight(self.module): 55 | self.quantizer.quant_grad.apply( 56 | self.module.old_weight, 57 | QuantType.QUANT_WEIGHT, 58 | self, inputs[0]) 59 | result = self.module(*inputs) 60 | else: 61 | result = self.module(*inputs) 62 | 63 | if 'output' in self.config['quant_types']: 64 | result = self.quantizer.quant_grad.apply( 65 | result, 66 | QuantType.QUANT_OUTPUT, 67 | self) 68 | return result 69 | 70 | 71 | class Compressor(object): 72 | """ 73 | Abstract base PyTorch compressor 74 | """ 75 | 76 | def __init__(self, model, config_list): 77 | 78 | assert isinstance(model, torch.nn.Module) 79 | 80 | self.bound_model = model 81 | self.config_list = config_list 82 | 83 | self.modules_to_compress = None 84 | self.modules_wrapper = [] 85 | self.is_wrapped = False 86 | 87 | self._fwd_hook_handles = {} 88 | self._fwd_hook_id = 0 89 | 90 | self.reset() 91 | 92 | def reset(self, checkpoint=None): 93 | """ 94 | reset model state dict and model wrapper 95 | """ 96 | self._unwrap_model() 97 | if checkpoint is not None: 98 | self.bound_model.load_state_dict(checkpoint) 99 | 100 | self.modules_to_compress = None 101 | self.modules_wrapper = [] 102 | 103 | for layer, config in self._detect_modules_to_compress(): 104 | wrapper = self._wrap_modules(layer, config) 105 | self.modules_wrapper.append(wrapper) 106 | 107 | self._wrap_model() 108 | 109 | def _detect_modules_to_compress(self): 110 | """ 111 | detect all modules should be compressed, and save the result in `self.modules_to_compress`. 112 | The model will be instrumented and user should never edit it after calling this method. 
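        Example config_list entry (an illustrative sketch; the exact keys a concrete
        compressor expects are defined by its subclass, e.g. the Quantizer below):
            [{'quant_types': ['weight'], 'quant_bits': 8, 'op_types': ['Conv2d', 'Linear']}]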
113 | """ 114 | if self.modules_to_compress is None: 115 | self.modules_to_compress = [] 116 | for name, module in self.bound_model.named_modules(): 117 | if module == self.bound_model: 118 | continue 119 | layer = LayerInfo(name, module) 120 | config = self.select_config(layer) 121 | if config is not None: 122 | self.modules_to_compress.append((layer, config)) 123 | return self.modules_to_compress 124 | 125 | def _wrap_model(self): 126 | """ 127 | wrap all modules that needed to be compressed 128 | 129 | """ 130 | for wrapper in reversed(self.get_modules_wrapper()): 131 | _setattr(self.bound_model, wrapper.name, wrapper) 132 | self.is_wrapped = True 133 | 134 | def _unwrap_model(self): 135 | """ 136 | unwrap all modules that needed to be compressed 137 | 138 | """ 139 | for wrapper in self.get_modules_wrapper(): 140 | _setattr(self.bound_model, wrapper.name, wrapper.module) 141 | self.is_wrapped = False 142 | 143 | def compress(self): 144 | """ 145 | Compress the model with algorithm implemented by subclass. 146 | 147 | The model will be instrumented and user should never edit it after calling this method. 148 | `self.modules_to_compress` records all the to-be-compressed layers 149 | 150 | Returns 151 | ------- 152 | torch.nn.Module 153 | model with specified modules compressed. 154 | """ 155 | return self.bound_model 156 | 157 | def set_wrappers_attribute(self, name, value): 158 | """ 159 | To register attributes used in wrapped module's forward method. 160 | If the type of the value is Torch.tensor, then this value is registered as a buffer in wrapper, 161 | which will be saved by model.state_dict. Otherwise, this value is just a regular variable in wrapper. 162 | 163 | Parameters 164 | ---------- 165 | name : str 166 | name of the variable 167 | value: any 168 | value of the variable 169 | """ 170 | for wrapper in self.get_modules_wrapper(): 171 | if isinstance(value, torch.Tensor): 172 | wrapper.register_buffer(name, value.clone()) 173 | else: 174 | setattr(wrapper, name, value) 175 | 176 | def get_modules_to_compress(self): 177 | """ 178 | To obtain all the to-be-compressed modules. 179 | 180 | Returns 181 | ------- 182 | list 183 | a list of the layers, each of which is a tuple (`layer`, `config`), 184 | `layer` is `LayerInfo`, `config` is a `dict` 185 | """ 186 | return self.modules_to_compress 187 | 188 | def get_modules_wrapper(self): 189 | """ 190 | To obtain all the wrapped modules. 
191 | 192 | Returns 193 | ------- 194 | list 195 | a list of the wrapped modules 196 | """ 197 | return self.modules_wrapper 198 | 199 | def select_config(self, layer): 200 | """ 201 | Find the configuration for `layer` by parsing `self.config_list` 202 | 203 | Parameters 204 | ---------- 205 | layer : LayerInfo 206 | one layer 207 | 208 | Returns 209 | ------- 210 | config or None 211 | the retrieved configuration for this layer, if None, this layer should 212 | not be compressed 213 | """ 214 | ret = None 215 | for config in self.config_list: 216 | config = config.copy() 217 | # expand config if key `default` is in config['op_types'] 218 | if 'op_types' in config and 'default' in config['op_types']: 219 | expanded_op_types = [] 220 | for op_type in config['op_types']: 221 | if op_type == 'default': 222 | expanded_op_types.extend( 223 | default_layers.weighted_modules) 224 | else: 225 | expanded_op_types.append(op_type) 226 | config['op_types'] = expanded_op_types 227 | 228 | # check if condition is satisified 229 | if 'op_types' in config and layer.type not in config['op_types']: 230 | continue 231 | if 'op_names' in config and layer.name not in config['op_names']: 232 | continue 233 | 234 | ret = config 235 | if ret is None or 'exclude' in ret: 236 | return None 237 | return ret 238 | 239 | def update_epoch(self, epoch): 240 | """ 241 | If user want to update model every epoch, user can override this method. 242 | This method should be called at the beginning of each epoch 243 | 244 | Parameters 245 | ---------- 246 | epoch : num 247 | the current epoch number 248 | """ 249 | pass 250 | 251 | def _wrap_modules(self, layer, config): 252 | """ 253 | This method is implemented in the subclasses, i.e., `Pruner` and `Quantizer` 254 | 255 | Parameters 256 | ---------- 257 | layer : LayerInfo 258 | the layer to instrument the compression operation 259 | config : dict 260 | the configuration for compressing this layer 261 | """ 262 | raise NotImplementedError() 263 | 264 | def add_activation_collector(self, collector): 265 | self._fwd_hook_id += 1 266 | self._fwd_hook_handles[self._fwd_hook_id] = [] 267 | for wrapper in self.get_modules_wrapper(): 268 | handle = wrapper.register_forward_hook(collector) 269 | self._fwd_hook_handles[self._fwd_hook_id].append(handle) 270 | return self._fwd_hook_id 271 | 272 | def remove_activation_collector(self, fwd_hook_id): 273 | if fwd_hook_id not in self._fwd_hook_handles: 274 | raise ValueError("%s is not a valid collector id" % 275 | str(fwd_hook_id)) 276 | for handle in self._fwd_hook_handles[fwd_hook_id]: 277 | handle.remove() 278 | del self._fwd_hook_handles[fwd_hook_id] 279 | 280 | 281 | class Quantizer(Compressor): 282 | """ 283 | Base quantizer for pytorch quantizer 284 | """ 285 | 286 | def __init__(self, model, config_list): 287 | super().__init__(model, config_list) 288 | self.quant_grad = QuantGrad 289 | 290 | def quantize_weight(self, wrapper, **kwargs): 291 | """ 292 | quantize should overload this method to quantize weight. 293 | This method is effectively hooked to :meth:`forward` of the model. 294 | Parameters 295 | ---------- 296 | wrapper : QuantizerModuleWrapper 297 | the wrapper for origin module 298 | """ 299 | raise NotImplementedError('Quantizer must overload quantize_weight()') 300 | 301 | def quantize_output(self, output, wrapper, **kwargs): 302 | """ 303 | quantize should overload this method to quantize output. 304 | This method is effectively hooked to :meth:`forward` of the model. 
305 | Parameters 306 | ---------- 307 | output : Tensor 308 | output that needs to be quantized 309 | wrapper : QuantizerModuleWrapper 310 | the wrapper for origin module 311 | """ 312 | raise NotImplementedError('Quantizer must overload quantize_output()') 313 | 314 | def quantize_input(self, *inputs, wrapper, **kwargs): 315 | """ 316 | quantize should overload this method to quantize input. 317 | This method is effectively hooked to :meth:`forward` of the model. 318 | Parameters 319 | ---------- 320 | inputs : Tensor 321 | inputs that needs to be quantized 322 | wrapper : QuantizerModuleWrapper 323 | the wrapper for origin module 324 | """ 325 | raise NotImplementedError('Quantizer must overload quantize_input()') 326 | 327 | def _wrap_modules(self, layer, config): 328 | """ 329 | Create a wrapper forward function to replace the original one. 330 | Parameters 331 | ---------- 332 | layer : LayerInfo 333 | the layer to instrument the mask 334 | config : dict 335 | the configuration for quantization 336 | """ 337 | assert 'quant_types' in config, 'must provide quant_types in config' 338 | assert isinstance(config['quant_types'], 339 | list), 'quant_types must be list type' 340 | assert 'quant_bits' in config, 'must provide quant_bits in config' 341 | assert isinstance(config['quant_bits'], int) or isinstance( 342 | config['quant_bits'], dict), 'quant_bits must be dict type or int type' 343 | 344 | if isinstance(config['quant_bits'], dict): 345 | for quant_type in config['quant_types']: 346 | assert quant_type in config['quant_bits'], 'bits length for %s must be specified in quant_bits dict' % quant_type 347 | 348 | return QuantizerModuleWrapper(layer.module, layer.name, layer.type, config, self) 349 | 350 | 351 | class QuantType: 352 | """ 353 | Enum class for quantization type. 354 | """ 355 | QUANT_INPUT = 0 356 | QUANT_WEIGHT = 1 357 | QUANT_OUTPUT = 2 358 | 359 | QType_Dict = { 360 | 0: "input", 361 | 1: "weight", 362 | 2: "output" 363 | } 364 | 365 | class QuantGrad(torch.autograd.Function): 366 | """ 367 | Base class for overriding backward function of quantization operation. 368 | """ 369 | @classmethod 370 | def _quantize(cls, x, scale, zero_point): 371 | """ 372 | Reference function for quantizing x -- non-clamped. 
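        Equivalent to round(x / scale + zero_point); the result is deliberately not
        clamped to the [qmin, qmax] range.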
373 |         Parameters
374 |         ----------
375 |         x : Tensor
376 |             tensor to be quantized
377 |         scale : Tensor
378 |             scale for quantizing x
379 |         zero_point : Tensor
380 |             zero_point for quantizing x
381 |         Returns
382 |         -------
383 |         tensor
384 |             quantized x, not clamped
385 |         """
386 |         return ((x / scale) + zero_point).round()
387 |     @classmethod
388 |     def get_bits_length(cls, config, quant_type):
389 |         """
390 |         Get the bit width for a quantization config
391 |         Parameters
392 |         ----------
393 |         config : Dict
394 |             the configuration for quantization
395 |         quant_type : str
396 |             quant type
397 |         Returns
398 |         -------
399 |         int
400 |             n-bits for the quantization configuration
401 |         """
402 |         if isinstance(config["quant_bits"], int):
403 |             return config["quant_bits"]
404 |         else:
405 |             return config["quant_bits"].get(quant_type)
406 | 
407 |     @staticmethod
408 |     def quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax):
409 |         """
410 |         This method should be overridden by subclasses to provide a customized backward function;
411 |         the default implementation is the Straight-Through Estimator
412 |         Parameters
413 |         ----------
414 |         tensor : Tensor
415 |             input of the quantization operation
416 |         grad_output : Tensor
417 |             gradient of the output of the quantization operation
418 |         quant_type : int
419 |             the type of quantization, it can be `QuantType.QUANT_INPUT`, `QuantType.QUANT_WEIGHT`, `QuantType.QUANT_OUTPUT`;
420 |             you can define different behavior for different types.
421 |         scale : Tensor
422 |             scale for quantizing tensor
423 |         zero_point : Tensor
424 |             zero_point for quantizing tensor
425 |         qmin, qmax : Tensor
426 |             quant_min / quant_max for quantizing tensor
427 |         Returns
428 |         -------
429 |         tensor
430 |             gradient of the input of the quantization operation
431 |         """
432 |         return grad_output
433 | 
434 |     @staticmethod
435 |     def forward(ctx, tensor, quant_type, wrapper, input_tensor=None, **kwargs):
436 |         if quant_type == QuantType.QUANT_INPUT:
437 |             output = wrapper.quantizer.quantize_input(tensor, wrapper, **kwargs)
438 |         elif quant_type == QuantType.QUANT_WEIGHT:
439 |             output = wrapper.quantizer.quantize_weight(wrapper, input_tensor=input_tensor, **kwargs)
440 |         elif quant_type == QuantType.QUANT_OUTPUT:
441 |             output = wrapper.quantizer.quantize_output(tensor, wrapper, **kwargs)
442 |         else:
443 |             raise ValueError("unrecognized QuantType.")
444 | 
445 | 
446 |         bits = QuantGrad.get_bits_length(wrapper.config, QType_Dict[quant_type])
447 |         qmin, qmax = torch.Tensor([0]).to(tensor.device), torch.Tensor([(1 << bits) - 1]).to(tensor.device)
448 |         if hasattr(wrapper.module, 'scale') and hasattr(wrapper.module, 'zero_point'):
449 |             scale = wrapper.module.scale
450 |             zero_point = wrapper.module.zero_point
451 |         else:
452 |             scale, zero_point = None, None
453 |         ctx.save_for_backward(tensor, torch.Tensor([quant_type]), scale, zero_point, qmin, qmax)
454 |         return output
455 | 
456 |     @classmethod
457 |     def backward(cls, ctx, grad_output):
458 |         tensor, quant_type, scale, zero_point, qmin, qmax = ctx.saved_variables
459 |         output = cls.quant_backward(tensor, grad_output, quant_type, scale, zero_point, qmin, qmax)
460 |         return output, None, None, None
461 | 
--------------------------------------------------------------------------------