├── LICENSE
├── README.md
├── checkpoints
│   └── log.txt
├── datasets
│   ├── images.txt
│   ├── voc2007.txt
│   ├── voc2007test.txt
│   └── voc2012.txt
├── eval_voc.py
├── imgs
│   ├── person.jpg
│   └── person_result.jpg
├── models
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   └── resnet_yolo.cpython-37.pyc
│   ├── resnet_yolo.py
│   ├── vgg_yolo.py
│   └── yoloLoss.py
├── predict.py
├── requirements.txt
├── train.py
└── utils
    ├── __init__.py
    ├── dataset.py
    ├── piplist2equal.py
    └── xml2txt.py
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 |  Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 |
635 |     Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 |     along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 |     <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/philosophy/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # YOLO v1 (PyTorch)
2 | ## Project Structure
3 | ```text
4 | |--checkpoints          logs and weights
5 |    |--log.txt
6 |    |--best.pth
7 | |--datasets             dataset lists
8 |    |--images.txt
9 |    |--voc2007.txt
10 |    |--voc2012.txt
11 |    |--voc2007test.txt
12 |    |--images/
13 |       |--XXXX.jpg
14 |       |--...
15 | |--imgs                 test images
16 |    |--person.jpg
17 | |--models               model definitions
18 |    |--resnet_yolo.py
19 |    |--vgg_yolo.py
20 |    |--yoloLoss.py
21 | |--utils                utilities
22 |    |--dataset.py
23 |    |--piplist2equal.py
24 |    |--xml2txt.py
25 | |--train.py
26 | |--predict.py
27 | |--eval_voc.py
28 | |--requirements.txt     environment
29 |
30 |
31 | ```
32 |
33 | ## 1. Environment Setup
34 | See [requirements.txt](requirements.txt) for details.
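
A minimal setup sketch, assuming `pip` and a Python 3.7 environment (the bundled `__pycache__` files were compiled for CPython 3.7):
```shell
pip install -r requirements.txt
```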
35 |
36 | ## 2. Dataset Preparation
37 | **Download the dataset**
38 | Link: https://pan.baidu.com/s/1hturxvztlt_ePnZt3TTzWQ  Password: 6qgn
39 |
40 | **Extract the dataset**
41 | 1. Put all of the VOC2007 and VOC2012 images into the `datasets/images` directory.
42 |
43 | 2. Then run `utils/xml2txt.py` to convert the XML annotation files to txt format. The result of this step is already saved, so it can be skipped.
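
Judging from how `eval_voc.py` parses these txt files, each line holds an image filename followed by one `x1 y1 x2 y2 class_index` group per object; a made-up line for illustration:
```text
000012.jpg 156 97 351 270 6
```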
44 |
45 | ## 3. Training
46 | ```shell
47 | python train.py
48 | ```
49 | ## 4. Prediction
50 | ```shell
51 | python predict.py
52 | ```
53 | ## 5. Computing mAP
54 | ```shell
55 | python eval_voc.py
56 | ```
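
Per the print statements in `eval_voc.py`, this reports one `---class {name} ap {value}---` line per class, followed by `---map {value}---` for the mean over all 20 VOC classes.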
57 |
58 | ![person_result](imgs/person_result.jpg)
59 |
60 | ## Appendix
61 | For beginners, or when hardware constraints make training impractical, a trained weights file (.pth) is provided.
62 |
63 | [How to get it]
64 | Follow the WeChat official account [OAOA] and reply [0813] to receive the Baidu Pan link.
65 |
66 | ## References
67 | Original project: https://github.com/FelixFu520/yolov1
--------------------------------------------------------------------------------
/checkpoints/log.txt:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isbrycee/yolov1_pytorch/ecb0ef944ec38fa5b887263d61654f714ca07108/checkpoints/log.txt
--------------------------------------------------------------------------------
/eval_voc.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2020/08/12 18:30
4 | @Author : Bryce
5 | @File : eval_voc.py
6 | @Notice :
7 | @Modification :
8 | @Author :
9 | @Time :
10 | @Detail :
11 | """
12 | import os
13 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
14 | import numpy as np
15 |
16 |
17 | VOC_CLASSES = ( # always index 0
18 | 'aeroplane', 'bicycle', 'bird', 'boat',
19 | 'bottle', 'bus', 'car', 'cat', 'chair',
20 | 'cow', 'diningtable', 'dog', 'horse',
21 | 'motorbike', 'person', 'pottedplant',
22 | 'sheep', 'sofa', 'train', 'tvmonitor')
23 | Color = [
24 | [0, 0, 0],
25 | [128, 0, 0],
26 | [0, 128, 0],
27 | [128, 128, 0],
28 | [0, 0, 128],
29 | [128, 0, 128],
30 | [0, 128, 128],
31 | [128, 128, 128],
32 | [64, 0, 0],
33 | [192, 0, 0],
34 | [64, 128, 0],
35 | [192, 128, 0],
36 | [64, 0, 128],
37 | [192, 0, 128],
38 | [64, 128, 128],
39 | [192, 128, 128],
40 | [0, 64, 0],
41 | [128, 64, 0],
42 | [0, 192, 0],
43 | [128, 192, 0],
44 | [0, 64, 128]
45 | ]
46 |
47 |
48 | def voc_ap(rec, prec, use_07_metric=False):
49 | if use_07_metric:
50 | # 11 point metric
51 | ap = 0.
52 | for t in np.arange(0., 1.1, 0.1):
53 | if np.sum(rec >= t) == 0:
54 | p = 0
55 | else:
56 | p = np.max(prec[rec >= t])
57 | ap = ap + p/11.
58 | else:
59 |         # correct AP calculation
60 | mrec = np.concatenate(([0.], rec, [1.]))
61 | mpre = np.concatenate(([0.], prec, [0.]))
62 |
63 | for i in range(mpre.size - 1, 0, -1):
64 | mpre[i-1] = np.maximum(mpre[i-1], mpre[i])
65 |
66 | i = np.where(mrec[1:] != mrec[:-1])[0]
67 |
68 | ap = np.sum((mrec[i + 1] - mrec[i]) * mpre[i + 1])
69 |
70 | return ap
71 |
72 |
73 | def voc_eval(preds, target, VOC_CLASSES=VOC_CLASSES, threshold=0.5, use_07_metric=False,):
74 | '''
75 | preds {'cat':[[image_id,confidence,x1,y1,x2,y2],...],'dog':[[],...]}
76 | target {(image_id,class):[[],]}
77 | '''
78 | aps = []
79 | for i,class_ in enumerate(VOC_CLASSES):
80 | pred = preds[class_] # [[image_id,confidence,x1,y1,x2,y2],...]
81 |         if len(pred) == 0: # edge case: this class was never detected
82 |             ap = -1
83 |             print('---class {} ap {}---'.format(class_,ap))
84 |             aps += [ap]
85 |             continue  # move on to the next class; 'break' would skip all remaining classes
86 | #print(pred)
87 | image_ids = [x[0] for x in pred]
88 | confidence = np.array([float(x[1]) for x in pred])
89 | BB = np.array([x[2:] for x in pred])
90 | # sort by confidence
91 | sorted_ind = np.argsort(-confidence)
92 | sorted_scores = np.sort(-confidence)
93 | BB = BB[sorted_ind, :]
94 | image_ids = [image_ids[x] for x in sorted_ind]
95 |
96 | # go down dets and mark TPs and FPs
97 | npos = 0.
98 | for (key1,key2) in target:
99 | if key2 == class_:
100 |                 npos += len(target[(key1,key2)]) # count this class's positives here so none are missed
101 | nd = len(image_ids)
102 | tp = np.zeros(nd)
103 | fp = np.zeros(nd)
104 | for d,image_id in enumerate(image_ids):
105 |             bb = BB[d] # predicted box
106 |             if (image_id,class_) in target:
107 |                 BBGT = target[(image_id,class_)] # ground-truth boxes: [[x1,y1,x2,y2],...]
108 |                 for bbgt in BBGT:
109 |                     # compute overlaps
110 |                     # intersection
111 |                     ixmin = np.maximum(bbgt[0], bb[0])
112 |                     iymin = np.maximum(bbgt[1], bb[1])
113 |                     ixmax = np.minimum(bbgt[2], bb[2])
114 |                     iymax = np.minimum(bbgt[3], bb[3])
115 |                     iw = np.maximum(ixmax - ixmin + 1., 0.)
116 |                     ih = np.maximum(iymax - iymin + 1., 0.)
117 |                     inters = iw * ih
118 |
119 |                     union = (bb[2]-bb[0]+1.)*(bb[3]-bb[1]+1.) + (bbgt[2]-bbgt[0]+1.)*(bbgt[3]-bbgt[1]+1.) - inters
120 |                     if union == 0:
121 |                         print(bb,bbgt) # degenerate boxes
122 |
123 |                     overlaps = inters/union
124 |                     if overlaps > threshold:
125 |                         tp[d] = 1
126 |                         BBGT.remove(bbgt) # this ground-truth box is matched and cannot match again
127 |                         if len(BBGT) == 0:
128 |                             del target[(image_id,class_)] # drop the key once it has no boxes left
129 |                         break
130 |                 fp[d] = 1 - tp[d]  # set after the loop: a false positive only if no ground-truth box matched
131 |             else:
132 |                 fp[d] = 1
133 | fp = np.cumsum(fp)
134 | tp = np.cumsum(tp)
135 | rec = tp/float(npos)
136 | prec = tp/np.maximum(tp + fp, np.finfo(np.float64).eps)
137 | #print(rec,prec)
138 | ap = voc_ap(rec, prec, use_07_metric)
139 | print('---class {} ap {}---'.format(class_,ap))
140 | aps += [ap]
141 | print('---map {}---'.format(np.mean(aps)))
142 |
143 |
144 | def test_eval():
145 | preds = {'cat':[['image01',0.9,20,20,40,40],['image01',0.8,20,20,50,50],['image02',0.8,30,30,50,50]],'dog':[['image01',0.78,60,60,90,90]]}
146 | target = {('image01','cat'):[[20,20,41,41]],('image01','dog'):[[60,60,91,91]],('image02','cat'):[[30,30,51,51]]}
147 | voc_eval(preds,target,VOC_CLASSES=['cat','dog'])
148 |
149 |
150 | if __name__ == '__main__':
151 | #test_eval()
152 | from predict import *
153 | from collections import defaultdict
154 | from tqdm import tqdm
155 |
156 | target = defaultdict(list)
157 | preds = defaultdict(list)
158 | image_list = [] #image path list
159 |
160 | f = open('datasets/voc2007test.txt')
161 | lines = f.readlines()
162 | file_list = []
163 | for line in lines:
164 | splited = line.strip().split()
165 | file_list.append(splited)
166 | f.close()
167 | print('---prepare target---')
168 | for index,image_file in enumerate(file_list):
169 | image_id = image_file[0]
170 |
171 | image_list.append(image_id)
172 | num_obj = (len(image_file) - 1) // 5
173 | for i in range(num_obj):
174 | x1 = int(image_file[1+5*i])
175 | y1 = int(image_file[2+5*i])
176 | x2 = int(image_file[3+5*i])
177 | y2 = int(image_file[4+5*i])
178 | c = int(image_file[5+5*i])
179 | class_name = VOC_CLASSES[c]
180 | target[(image_id,class_name)].append([x1,y1,x2,y2])
181 | #
182 | #start test
183 | #
184 | print('---start test---')
185 | # model = vgg16_bn(pretrained=False)
186 | model = resnet50()
187 | # model.classifier = nn.Sequential(
188 | # nn.Linear(512 * 7 * 7, 4096),
189 | # nn.ReLU(True),
190 | # nn.Dropout(),
191 | # #nn.Linear(4096, 4096),
192 | # #nn.ReLU(True),
193 | # #nn.Dropout(),
194 | # nn.Linear(4096, 1470),
195 | # )
196 | model.load_state_dict(torch.load('checkpoints/best.pth'))
197 | model.eval()
198 | model.cuda()
199 | count = 0
200 | for image_path in tqdm(image_list):
201 | result = predict_gpu(model,image_path,root_path='datasets/images/') #result[[left_up,right_bottom,class_name,image_path],]
202 | for (x1,y1),(x2,y2),class_name,image_id,prob in result: #image_id is actually image_path
203 | preds[class_name].append([image_id,prob,x1,y1,x2,y2])
204 | # print(image_path)
205 | # image = cv2.imread('/home/xzh/data/VOCdevkit/VOC2012/allimgs/'+image_path)
206 | # for left_up,right_bottom,class_name,_,prob in result:
207 | # color = Color[VOC_CLASSES.index(class_name)]
208 | # cv2.rectangle(image,left_up,right_bottom,color,2)
209 | # label = class_name+str(round(prob,2))
210 | # text_size, baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
211 | # p1 = (left_up[0], left_up[1]- text_size[1])
212 | # cv2.rectangle(image, (p1[0] - 2//2, p1[1] - 2 - baseline), (p1[0] + text_size[0], p1[1] + text_size[1]), color, -1)
213 | # cv2.putText(image, label, (p1[0], p1[1] + baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255,255,255), 1, 8)
214 |
215 | # cv2.imwrite('testimg/'+image_path,image)
216 | # count += 1
217 | # if count == 100:
218 | # break
219 |
220 | print('---start evaluate---')
221 | voc_eval(preds,target,VOC_CLASSES=VOC_CLASSES)
--------------------------------------------------------------------------------
/imgs/person.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isbrycee/yolov1_pytorch/ecb0ef944ec38fa5b887263d61654f714ca07108/imgs/person.jpg
--------------------------------------------------------------------------------
/imgs/person_result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isbrycee/yolov1_pytorch/ecb0ef944ec38fa5b887263d61654f714ca07108/imgs/person_result.jpg
--------------------------------------------------------------------------------
/models/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding:utf-8 -*-
3 | # author: Felix Fu
--------------------------------------------------------------------------------
/models/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isbrycee/yolov1_pytorch/ecb0ef944ec38fa5b887263d61654f714ca07108/models/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/models/__pycache__/resnet_yolo.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/isbrycee/yolov1_pytorch/ecb0ef944ec38fa5b887263d61654f714ca07108/models/__pycache__/resnet_yolo.cpython-37.pyc
--------------------------------------------------------------------------------
/models/resnet_yolo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2020/08/12 18:30
4 | @Author : Bryce
5 | @File : resnet_yolo.py
6 | @Notice :
7 | @Modification :
8 | @Author :
9 | @Time :
10 | @Detail :
11 | """
12 |
13 | import torch.nn as nn
14 | import math
15 | import torch.utils.model_zoo as model_zoo
16 | import torch.nn.functional as F
17 | import torch
18 |
19 |
20 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152']
21 |
22 |
23 | model_urls = {
24 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
25 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
26 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
27 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
28 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
29 | }
30 |
31 |
32 | def conv3x3(in_planes, out_planes, stride=1):
33 | """3x3 convolution with padding
34 | :param in_planes:
35 | :param out_planes:
36 | :param stride:
37 | :return:
38 | """
39 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
40 | padding=1, bias=False)
41 |
42 |
43 | class BasicBlock(nn.Module):
44 | expansion = 1
45 |
46 | def __init__(self, inplanes, planes, stride=1, downsample=None):
47 | super(BasicBlock, self).__init__()
48 | self.conv1 = conv3x3(inplanes, planes, stride)
49 | self.bn1 = nn.BatchNorm2d(planes)
50 | self.relu = nn.ReLU(inplace=True)
51 | self.conv2 = conv3x3(planes, planes)
52 | self.bn2 = nn.BatchNorm2d(planes)
53 | self.downsample = downsample
54 | self.stride = stride
55 |
56 | def forward(self, x):
57 | residual = x
58 |
59 | out = self.conv1(x)
60 | out = self.bn1(out)
61 | out = self.relu(out)
62 |
63 | out = self.conv2(out)
64 | out = self.bn2(out)
65 |
66 | if self.downsample is not None:
67 | residual = self.downsample(x)
68 |
69 | out += residual
70 | out = self.relu(out)
71 |
72 | return out
73 |
74 |
75 | class Bottleneck(nn.Module):
76 | expansion = 4
77 |
78 | def __init__(self, inplanes, planes, stride=1, downsample=None):
79 | super(Bottleneck, self).__init__()
80 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
81 | self.bn1 = nn.BatchNorm2d(planes)
82 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
83 | padding=1, bias=False)
84 | self.bn2 = nn.BatchNorm2d(planes)
85 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
86 | self.bn3 = nn.BatchNorm2d(planes * 4)
87 | self.relu = nn.ReLU(inplace=True)
88 | self.downsample = downsample
89 | self.stride = stride
90 |
91 | def forward(self, x):
92 | residual = x
93 |
94 | out = self.conv1(x)
95 | out = self.bn1(out)
96 | out = self.relu(out)
97 |
98 | out = self.conv2(out)
99 | out = self.bn2(out)
100 | out = self.relu(out)
101 |
102 | out = self.conv3(out)
103 | out = self.bn3(out)
104 |
105 | if self.downsample is not None:
106 | residual = self.downsample(x)
107 |
108 | out += residual
109 | out = self.relu(out)
110 |
111 | return out
112 |
113 |
114 | class detnet_bottleneck(nn.Module):
115 | # no expansion
116 | # dilation = 2
117 | # type B use 1x1 conv
118 | expansion = 1
119 |
120 | def __init__(self, in_planes, planes, stride=1, block_type='A'):
121 | super(detnet_bottleneck, self).__init__()
122 | self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
123 | self.bn1 = nn.BatchNorm2d(planes)
124 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=2, bias=False, dilation=2)
125 | self.bn2 = nn.BatchNorm2d(planes)
126 | self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)
127 | self.bn3 = nn.BatchNorm2d(self.expansion*planes)
128 |
129 | self.downsample = nn.Sequential()
130 | if stride != 1 or in_planes != self.expansion*planes or block_type == 'B':
131 | self.downsample = nn.Sequential(
132 | nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
133 | nn.BatchNorm2d(self.expansion*planes)
134 | )
135 |
136 | def forward(self, x):
137 | out = F.relu(self.bn1(self.conv1(x)))
138 | out = F.relu(self.bn2(self.conv2(out)))
139 | out = self.bn3(self.conv3(out))
140 | out += self.downsample(x)
141 | out = F.relu(out)
142 | return out
143 |
144 |
145 | class ResNet(nn.Module):
146 |
147 | def __init__(self, block, layers, num_classes=1470):
148 | self.inplanes = 64
149 | super(ResNet, self).__init__()
150 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
151 | bias=False)
152 | self.bn1 = nn.BatchNorm2d(64)
153 | self.relu = nn.ReLU(inplace=True)
154 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
155 | self.layer1 = self._make_layer(block, 64, layers[0])
156 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
157 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
158 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
159 | self.layer5 = self._make_detnet_layer(in_channels=2048)
160 | # self.avgpool = nn.AvgPool2d(14) #fit 448 input size
161 | # self.fc = nn.Linear(512 * block.expansion, num_classes)
162 | self.conv_end = nn.Conv2d(256, 30, kernel_size=3, stride=1, padding=1, bias=False)
163 | self.bn_end = nn.BatchNorm2d(30)
164 |         for m in self.modules():  # iterate over all submodules
165 |             if isinstance(m, nn.Conv2d):  # convolution layers
166 |                 n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
167 |                 m.weight.data.normal_(0, math.sqrt(2. / n))  # He-style normal initialization
168 |             elif isinstance(m, nn.BatchNorm2d):  # batchnorm layers
169 |                 m.weight.data.fill_(1)  # weights initialized to 1
170 |                 m.bias.data.zero_()  # biases initialized to 0
171 |
172 | def _make_layer(self, block, planes, blocks, stride=1):
173 | downsample = None
174 | if stride != 1 or self.inplanes != planes * block.expansion:
175 |             # add a downsample branch when stride is 2 (the first block of a stage),
176 |             # or when inplanes (input channels) != planes * block.expansion (output channels)
177 | downsample = nn.Sequential(
178 | nn.Conv2d(self.inplanes, planes * block.expansion,
179 | kernel_size=1, stride=stride, bias=False),
180 | nn.BatchNorm2d(planes * block.expansion),
181 | )
182 |
183 | layers = []
184 | layers.append(block(self.inplanes, planes, stride, downsample))
185 | self.inplanes = planes * block.expansion
186 | for i in range(1, blocks):
187 | layers.append(block(self.inplanes, planes))
188 |
189 | return nn.Sequential(*layers)
190 |
191 | def _make_detnet_layer(self, in_channels):
192 | layers = []
193 | layers.append(detnet_bottleneck(in_planes=in_channels, planes=256, block_type='B'))
194 | layers.append(detnet_bottleneck(in_planes=256, planes=256, block_type='A'))
195 | layers.append(detnet_bottleneck(in_planes=256, planes=256, block_type='A'))
196 | return nn.Sequential(*layers)
197 |
198 | def forward(self, x):
199 | x = self.conv1(x)
200 | x = self.bn1(x)
201 | x = self.relu(x)
202 | x = self.maxpool(x)
203 |
204 | x = self.layer1(x)
205 | x = self.layer2(x)
206 | x = self.layer3(x)
207 | x = self.layer4(x)
208 | x = self.layer5(x)
209 | # x = self.avgpool(x)
210 | # x = x.view(x.size(0), -1)
211 | # x = self.fc(x)
212 | x = self.conv_end(x)
213 | x = self.bn_end(x)
214 |         # x = F.sigmoid(x)  # normalize to 0-1 (deprecated form)
215 | x = torch.sigmoid(x)
216 | # x = x.view(-1,7,7,30)
217 | x = x.permute(0, 2, 3, 1) # (-1,7,7,30)
218 |
219 | return x
220 |
221 |
222 | def resnet18(pretrained=False, **kwargs):
223 | """Constructs a ResNet-18 model.
224 |
225 | Args:
226 | pretrained (bool): If True, returns a model pre-trained on ImageNet
227 | """
228 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
229 | if pretrained:
230 | model.load_state_dict(model_zoo.load_url(model_urls['resnet18']))
231 | return model
232 |
233 |
234 | def resnet34(pretrained=False, **kwargs):
235 | """Constructs a ResNet-34 model.
236 |
237 | Args:
238 | pretrained (bool): If True, returns a model pre-trained on ImageNet
239 | """
240 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
241 | if pretrained:
242 | model.load_state_dict(model_zoo.load_url(model_urls['resnet34']))
243 | return model
244 |
245 |
246 | def resnet50(pretrained=False, **kwargs):
247 | """Constructs a ResNet-50 model.
248 |
249 | Args:
250 | pretrained (bool): If True, returns a model pre-trained on ImageNet
251 | """
252 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
253 | if pretrained:
254 | model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
255 | return model
256 |
257 |
258 | def resnet101(pretrained=False, **kwargs):
259 | """Constructs a ResNet-101 model.
260 |
261 | Args:
262 | pretrained (bool): If True, returns a model pre-trained on ImageNet
263 | """
264 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
265 | if pretrained:
266 | model.load_state_dict(model_zoo.load_url(model_urls['resnet101']))
267 | return model
268 |
269 |
270 | def resnet152(pretrained=False, **kwargs):
271 | """Constructs a ResNet-152 model.
272 |
273 | Args:
274 | pretrained (bool): If True, returns a model pre-trained on ImageNet
275 | """
276 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
277 | if pretrained:
278 | model.load_state_dict(model_zoo.load_url(model_urls['resnet152']))
279 | return model
280 |
281 |
282 | if __name__ == "__main__":
283 | from torchsummary import summary
284 | net = resnet50().cuda()
285 | summary(net, (3, 448, 448)) # the detector is designed for 448x448 inputs
286 | # print(net)
--------------------------------------------------------------------------------
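
A quick sanity check for the detection head defined above: push a dummy batch through resnet50() and confirm the output grid. This is a minimal sketch assuming the 448x448 input used elsewhere in this repo; the final permute should yield a (batch, 14, 14, 30) tensor.

    import torch
    from models.resnet_yolo import resnet50

    net = resnet50()                    # randomly initialized detector
    x = torch.randn(1, 3, 448, 448)     # dummy 448x448 RGB batch
    with torch.no_grad():
        out = net(x)
    print(out.shape)                    # expected: torch.Size([1, 14, 14, 30])
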
/models/vgg_yolo.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2020/08/12 18:30
4 | @Author : Bryce
5 | @File : vgg_yolo.py
6 | @Notice :
7 | @Modification :
8 | @Author :
9 | @Time :
10 | @Detail :
11 | """
12 | import torch.nn as nn
13 | import torch.utils.model_zoo as model_zoo
14 | import math
15 | import torch
16 | import torch.nn.functional as F
17 |
18 |
19 | __all__ = [
20 | 'VGG', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
21 | 'vgg19_bn', 'vgg19',
22 | ]
23 |
24 |
25 | model_urls = {
26 | 'vgg11': 'https://download.pytorch.org/models/vgg11-bbd30ac9.pth',
27 | 'vgg13': 'https://download.pytorch.org/models/vgg13-c768596a.pth',
28 | 'vgg16': 'https://download.pytorch.org/models/vgg16-397923af.pth',
29 | 'vgg19': 'https://download.pytorch.org/models/vgg19-dcbb9e9d.pth',
30 | 'vgg11_bn': 'https://download.pytorch.org/models/vgg11_bn-6002323d.pth',
31 | 'vgg13_bn': 'https://download.pytorch.org/models/vgg13_bn-abd245e5.pth',
32 | 'vgg16_bn': 'https://download.pytorch.org/models/vgg16_bn-6c64b313.pth',
33 | 'vgg19_bn': 'https://download.pytorch.org/models/vgg19_bn-c79401a0.pth',
34 | }
35 |
36 |
37 | class VGG(nn.Module):
38 |
39 | def __init__(self, features, num_classes=1000, image_size=448):
40 | super(VGG, self).__init__()
41 | self.features = features
42 | self.image_size = image_size
43 | # self.classifier = nn.Sequential(
44 | # nn.Linear(512 * 7 * 7, 4096),
45 | # nn.ReLU(True),
46 | # nn.Dropout(),
47 | # nn.Linear(4096, 4096),
48 | # nn.ReLU(True),
49 | # nn.Dropout(),
50 | # nn.Linear(4096, num_classes),
51 | # )
52 | # if self.image_size == 448:
53 | # self.extra_conv1 = conv_bn_relu(512,512)
54 | # self.extra_conv2 = conv_bn_relu(512,512)
55 | # self.downsample = nn.MaxPool2d(kernel_size=2, stride=2)
56 | self.classifier = nn.Sequential(
57 | nn.Linear(512 * 7 * 7, 4096),
58 | nn.ReLU(True),
59 | nn.Dropout(),
60 | nn.Linear(4096, 1470),
61 | )
62 | self._initialize_weights()
63 |
64 | def forward(self, x):
65 | x = self.features(x)
66 | # if self.image_size == 448:
67 | # x = self.extra_conv1(x)
68 | # x = self.extra_conv2(x)
69 | # x = self.downsample(x)
70 | x = x.view(x.size(0), -1)
71 | x = self.classifier(x)
72 | # x = F.sigmoid(x) # normalize outputs to the range 0-1
73 | x = torch.sigmoid(x)
74 | x = x.view(-1,7,7,30)
75 | return x
76 |
77 | def _initialize_weights(self):
78 | for m in self.modules():
79 | if isinstance(m, nn.Conv2d):
80 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
81 | m.weight.data.normal_(0, math.sqrt(2. / n))
82 | if m.bias is not None:
83 | m.bias.data.zero_()
84 | elif isinstance(m, nn.BatchNorm2d):
85 | m.weight.data.fill_(1)
86 | m.bias.data.zero_()
87 | elif isinstance(m, nn.Linear):
88 | m.weight.data.normal_(0, 0.01)
89 | m.bias.data.zero_()
90 |
91 |
92 | def make_layers(cfg, batch_norm=False):
93 | layers = []
94 | in_channels = 3
95 | s = 1
96 | first_flag=True
97 | for v in cfg:
98 | s=1
99 | if (v==64 and first_flag):
100 | s=2
101 | first_flag=False
102 | if v == 'M':
103 | layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
104 | else:
105 | conv2d = nn.Conv2d(in_channels, v, kernel_size=3, stride=s, padding=1)
106 | if batch_norm:
107 | layers += [conv2d, nn.BatchNorm2d(v), nn.ReLU(inplace=True)]
108 | else:
109 | layers += [conv2d, nn.ReLU(inplace=True)]
110 | in_channels = v
111 | return nn.Sequential(*layers)
112 |
113 | def conv_bn_relu(in_channels,out_channels,kernel_size=3,stride=2,padding=1):
114 | return nn.Sequential(
115 | nn.Conv2d(in_channels,out_channels,kernel_size=kernel_size,padding=padding,stride=stride),
116 | nn.BatchNorm2d(out_channels),
117 | nn.ReLU(True)
118 | )
119 |
120 |
121 | cfg = {
122 | 'A': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
123 | 'B': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
124 | 'D': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
125 | 'E': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
126 | }
127 |
128 |
129 | def vgg11(pretrained=False, **kwargs):
130 | """VGG 11-layer model (configuration "A")
131 |
132 | Args:
133 | pretrained (bool): If True, returns a model pre-trained on ImageNet
134 | """
135 | model = VGG(make_layers(cfg['A']), **kwargs)
136 | if pretrained:
137 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11']))
138 | return model
139 |
140 |
141 | def vgg11_bn(pretrained=False, **kwargs):
142 | """VGG 11-layer model (configuration "A") with batch normalization
143 |
144 | Args:
145 | pretrained (bool): If True, returns a model pre-trained on ImageNet
146 | """
147 | model = VGG(make_layers(cfg['A'], batch_norm=True), **kwargs)
148 | if pretrained:
149 | model.load_state_dict(model_zoo.load_url(model_urls['vgg11_bn']))
150 | return model
151 |
152 |
153 | def vgg13(pretrained=False, **kwargs):
154 | """VGG 13-layer model (configuration "B")
155 |
156 | Args:
157 | pretrained (bool): If True, returns a model pre-trained on ImageNet
158 | """
159 | model = VGG(make_layers(cfg['B']), **kwargs)
160 | if pretrained:
161 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13']))
162 | return model
163 |
164 |
165 | def vgg13_bn(pretrained=False, **kwargs):
166 | """VGG 13-layer model (configuration "B") with batch normalization
167 |
168 | Args:
169 | pretrained (bool): If True, returns a model pre-trained on ImageNet
170 | """
171 | model = VGG(make_layers(cfg['B'], batch_norm=True), **kwargs)
172 | if pretrained:
173 | model.load_state_dict(model_zoo.load_url(model_urls['vgg13_bn']))
174 | return model
175 |
176 |
177 | def vgg16(pretrained=False, **kwargs):
178 | """VGG 16-layer model (configuration "D")
179 |
180 | Args:
181 | pretrained (bool): If True, returns a model pre-trained on ImageNet
182 | """
183 | model = VGG(make_layers(cfg['D']), **kwargs)
184 | if pretrained:
185 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16']))
186 | return model
187 |
188 |
189 | def vgg16_bn(pretrained=False, **kwargs):
190 | """VGG 16-layer model (configuration "D") with batch normalization
191 |
192 | Args:
193 | pretrained (bool): If True, returns a model pre-trained on ImageNet
194 | """
195 | model = VGG(make_layers(cfg['D'], batch_norm=True), **kwargs)
196 | if pretrained:
197 | model.load_state_dict(model_zoo.load_url(model_urls['vgg16_bn']))
198 | return model
199 |
200 |
201 | def vgg19(pretrained=False, **kwargs):
202 | """VGG 19-layer model (configuration "E")
203 |
204 | Args:
205 | pretrained (bool): If True, returns a model pre-trained on ImageNet
206 | """
207 | model = VGG(make_layers(cfg['E']), **kwargs)
208 | if pretrained:
209 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
210 | return model
211 |
212 |
213 | def vgg19_bn(pretrained=False, **kwargs):
214 | """VGG 19-layer model (configuration 'E') with batch normalization
215 |
216 | Args:
217 | pretrained (bool): If True, returns a model pre-trained on ImageNet
218 | """
219 | model = VGG(make_layers(cfg['E'], batch_norm=True), **kwargs)
220 | if pretrained:
221 | model.load_state_dict(model_zoo.load_url(model_urls['vgg19_bn']))
222 | return model
223 |
224 |
225 | def test():
226 | import torch
227 | from torch.autograd import Variable
228 | model = vgg16()
229 | model.classifier = nn.Sequential(
230 | nn.Linear(512 * 7 * 7, 4096),
231 | nn.ReLU(True),
232 | nn.Dropout(),
233 | nn.Linear(4096, 4096),
234 | nn.ReLU(True),
235 | nn.Dropout(),
236 | nn.Linear(4096, 1470),
237 | )
238 | print(model.classifier[6])
239 | #print(model)
240 | img = torch.rand(2, 3, 448, 448) # 448 input: the stride-2 first conv plus five pools yield the 7x7 map the classifier expects
241 | img = Variable(img)
242 | output = model(img)
243 | print(output.size())
244 |
245 |
246 | if __name__ == '__main__':
247 | test()
--------------------------------------------------------------------------------
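
Note that make_layers gives the very first 64-channel conv a stride of 2, so a 448x448 image is halved before the five pooling stages: 448 -> 224 -> 7, exactly the 512*7*7 vector the classifier consumes. A minimal sketch (assuming a 448 input) to confirm the end-to-end shape:

    import torch
    from models.vgg_yolo import vgg16_bn

    net = vgg16_bn()
    x = torch.randn(1, 3, 448, 448)   # 448 -> 224 (stride-2 first conv) -> 7 after five pools
    with torch.no_grad():
        out = net(x)
    print(out.shape)                  # expected: torch.Size([1, 7, 7, 30])
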
/models/yoloLoss.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2020/08/12 18:30
4 | @Author : FelixFu / Bryce
5 | @File : yoloLoss.py
6 | @Notice :
7 | @Modification :
8 | @Detail : the tricky part is building the yoloLoss function
9 | """
10 |
11 | import torch
12 | import torch.nn as nn
13 | import torch.nn.functional as F
14 |
15 |
16 | class yoloLoss(nn.Module):
17 | def __init__(self, S, B, l_coord, l_noobj):
18 | super(yoloLoss, self).__init__()
19 | self.S = S
20 | self.B = B
21 | self.l_coord = l_coord
22 | self.l_noobj = l_noobj
23 |
24 | def compute_iou(self, box1, box2):
25 | """Compute the intersection over union of two set of boxes, each box is [x1,y1,x2,y2].
26 | Args:
27 | box1: (tensor) bounding boxes, sized [N,4].
28 | box2: (tensor) bounding boxes, sized [M,4].
29 | Return:
30 | (tensor) iou, sized [N,M].
31 | """
32 | # Take the max of the two boxes' top-left corners and the min of their bottom-right corners, compute the intersection area, then divide it by the union area.
33 | N = box1.size(0)
34 | M = box2.size(0)
35 |
36 | lt = torch.max( # top-left corners
37 | box1[:, :2].unsqueeze(1).expand(N, M, 2), # [N,2] -> [N,1,2] -> [N,M,2]
38 | box2[:, :2].unsqueeze(0).expand(N, M, 2), # [M,2] -> [1,M,2] -> [N,M,2]
39 | )
40 |
41 | rb = torch.min( # bottom-right corners
42 | box1[:, 2:].unsqueeze(1).expand(N, M, 2), # [N,2] -> [N,1,2] -> [N,M,2]
43 | box2[:, 2:].unsqueeze(0).expand(N, M, 2), # [M,2] -> [1,M,2] -> [N,M,2]
44 | )
45 |
46 | wh = rb - lt # [N,M,2]
47 | wh[wh < 0] = 0 # clip at 0: the two boxes do not overlap
48 | inter = wh[:, :, 0] * wh[:, :, 1] # [N,M]
49 |
50 | area1 = (box1[:, 2]-box1[:, 0]) * (box1[:, 3]-box1[:, 1]) # [N,]
51 | area2 = (box2[:, 2]-box2[:, 0]) * (box2[:, 3]-box2[:, 1]) # [M,]
52 | area1 = area1.unsqueeze(1).expand_as(inter) # [N,] -> [N,1] -> [N,M]
53 | area2 = area2.unsqueeze(0).expand_as(inter) # [M,] -> [1,M] -> [N,M]
54 |
55 | iou = inter / (area1 + area2 - inter)
56 | return iou
57 |
58 | def forward(self, pred_tensor, target_tensor):
59 | """
60 | pred_tensor: (tensor) size(batchsize,S,S,Bx5+20=30) [x,y,w,h,c]
61 | target_tensor: (tensor) size(batchsize,S,S,30)
62 | """
63 | N = pred_tensor.size()[0]
64 | # indices of cells that contain an object: which of the 7*7 cells in (bs, 7, 7, 30) holds a target
65 | coo_mask = target_tensor[:, :, :, 4] > 0 # coo_mask.shape = (bs, 7, 7)
66 | noo_mask = target_tensor[:, :, :, 4] == 0 # indices of cells that contain no object
67 | # gather the object entries (expand coo_mask to target_tensor's shape along the last dim)
68 | coo_mask = coo_mask.unsqueeze(-1).expand_as(target_tensor)
69 | noo_mask = noo_mask.unsqueeze(-1).expand_as(target_tensor)
70 |
71 | # coo_pred: tensor[, 30] (all batch entries flattened together)
72 | coo_pred = pred_tensor[coo_mask].view(-1, 30)
73 | box_pred = coo_pred[:, :10].contiguous().view(-1, 5) # box[x1,y1,w1,h1,c1], [x2,y2,w2,h2,c2]
74 | class_pred = coo_pred[:, 10:]
75 |
76 | coo_target = target_tensor[coo_mask].view(-1, 30)
77 | box_target = coo_target[:, :10].contiguous().view(-1, 5)
78 | class_target = coo_target[:, 10:]
79 |
80 | # compute not contain obj loss
81 | noo_pred = pred_tensor[noo_mask].view(-1, 30)
82 | noo_target = target_tensor[noo_mask].view(-1, 30)
83 |
84 | noo_pred_mask = torch.cuda.ByteTensor(noo_pred.size()).bool()
85 | noo_pred_mask.zero_()
86 | noo_pred_mask[:, 4] = 1
87 | noo_pred_mask[:, 9] = 1
88 | noo_pred_c = noo_pred[noo_pred_mask] # no-object cells only contribute the confidence c to the loss, size [-1,2]
89 | noo_target_c = noo_target[noo_pred_mask]
90 | nooobj_loss = F.mse_loss(noo_pred_c, noo_target_c, size_average=False)
91 |
92 | # compute contain obj loss
93 | coo_response_mask = torch.cuda.ByteTensor(box_target.size()).bool()
94 | coo_response_mask.zero_()
95 | coo_not_response_mask = torch.cuda.ByteTensor(box_target.size()).bool()
96 | coo_not_response_mask.zero_()
97 | box_target_iou = torch.zeros(box_target.size()).cuda()
98 | for i in range(0, box_target.size()[0], 2): # choose the best iou box
99 | box1 = box_pred[i:i+2] # the B boxes predicted by the current grid cell
100 | box1_xyxy = torch.FloatTensor(box1.size())
101 | # (x,y,w,h)
102 | box1_xyxy[:, :2] = box1[:, :2]/14. - 0.5 * box1[:, 2:4]
103 | box1_xyxy[:, 2:4] = box1[:, :2]/14. + 0.5 * box1[:, 2:4]
104 | box2 = box_target[i].view(-1, 5)
105 | box2_xyxy = torch.FloatTensor(box2.size())
106 | box2_xyxy[:, :2] = box2[:, :2]/14. - 0.5*box2[:, 2:4]
107 | box2_xyxy[:, 2:4] = box2[:, :2]/14. + 0.5*box2[:, 2:4]
108 | iou = self.compute_iou(box1_xyxy[:, :4], box2_xyxy[:, :4]) # [2,1]
109 | max_iou, max_index = iou.max(0)
110 | max_index = max_index.data.cuda()
111 |
112 | coo_response_mask[i+max_index] = 1
113 | coo_not_response_mask[i+1-max_index] = 1
114 |
115 | #####
116 | # we want the confidence score to equal the
117 | # intersection over union (IOU) between the predicted box
118 | # and the ground truth
119 | #####
120 | # the IoU value becomes the box's object confidence (stored at index 4 of the 5-vector)
121 | box_target_iou[i+max_index, torch.LongTensor([4]).cuda()] = (max_iou).data.cuda()
122 | box_target_iou = box_target_iou.cuda()
123 | # 1.response loss
124 | box_pred_response = box_pred[coo_response_mask].view(-1, 5)
125 | box_target_response_iou = box_target_iou[coo_response_mask].view(-1, 5)
126 | box_target_response = box_target[coo_response_mask].view(-1, 5)
127 | contain_loss = F.mse_loss(box_pred_response[:, 4], box_target_response_iou[:, 4], size_average=False)
128 | loc_loss = F.mse_loss(box_pred_response[:, :2], box_target_response[:, :2], size_average=False) + F.mse_loss(torch.sqrt(box_pred_response[:, 2:4]), torch.sqrt(box_target_response[:, 2:4]), size_average=False)
129 |
130 | # 2.not response loss
131 | box_pred_not_response = box_pred[coo_not_response_mask].view(-1, 5)
132 | box_target_not_response = box_target[coo_not_response_mask].view(-1, 5)
133 | box_target_not_response[:, 4] = 0
134 | # not_contain_loss = F.mse_loss(box_pred_response[:,4],box_target_response[:,4],size_average=False)
135 |
136 | # the commented-out line above (using box_pred_response) was a bug in the original implementation, most likely a simple typo
137 | not_contain_loss = F.mse_loss(box_pred_not_response[:, 4], box_target_not_response[:, 4], size_average=False)
138 |
139 | # 3.class loss
140 | class_loss = F.mse_loss(class_pred, class_target, size_average=False)
141 |
142 | return (self.l_coord*loc_loss + self.B*contain_loss + not_contain_loss + self.l_noobj*nooobj_loss + class_loss)/N
143 |
144 |
145 |
146 |
147 |
--------------------------------------------------------------------------------
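
compute_iou returns an [N, M] matrix of pairwise IoUs. A small hand-checkable sketch (toy boxes of my own, not repo data): two unit squares offset by half a side overlap in a 0.5x1 strip, so IoU = 0.5 / (1 + 1 - 0.5) = 1/3.

    import torch
    from models.yoloLoss import yoloLoss

    crit = yoloLoss(S=7, B=2, l_coord=5, l_noobj=0.5)   # the hyper-parameters train.py uses
    box1 = torch.tensor([[0.0, 0.0, 1.0, 1.0]])         # unit square
    box2 = torch.tensor([[0.5, 0.0, 1.5, 1.0],          # shifted right by 0.5 -> IoU = 1/3
                         [2.0, 2.0, 3.0, 3.0]])         # disjoint -> IoU = 0
    print(crit.compute_iou(box1, box2))                 # tensor([[0.3333, 0.0000]])
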
/predict.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2020/08/12 18:30
4 | @Author : Bryce
5 | @File : predict.py
6 | @Notice :
7 | @Modification :
8 | @Author :
9 | @Time :
10 | @Detail :
11 | """
12 | import torch
13 |
14 | from models.resnet_yolo import resnet50
15 | import torchvision.transforms as transforms
16 | import cv2
17 | import numpy as np
18 | import operator
19 |
20 | VOC_CLASSES = ( # always index 0
21 | 'aeroplane', 'bicycle', 'bird', 'boat',
22 | 'bottle', 'bus', 'car', 'cat', 'chair',
23 | 'cow', 'diningtable', 'dog', 'horse',
24 | 'motorbike', 'person', 'pottedplant',
25 | 'sheep', 'sofa', 'train', 'tvmonitor')
26 |
27 | # a drawing color for each of the dataset's 20 classes
28 | Color = [
29 | [0, 0, 0],
30 | [128, 0, 0],
31 | [0, 128, 0],
32 | [128, 128, 0],
33 | [0, 0, 128],
34 | [128, 0, 128],
35 | [0, 128, 128],
36 | [128, 128, 128],
37 | [64, 0, 0],
38 | [192, 0, 0],
39 | [64, 128, 0],
40 | [192, 128, 0],
41 | [64, 0, 128],
42 | [192, 0, 128],
43 | [64, 128, 128],
44 | [192, 128, 128],
45 | [0, 64, 0],
46 | [128, 64, 0],
47 | [0, 192, 0],
48 | [128, 192, 0],
49 | [0, 64, 128]
50 | ]
51 |
52 | # decode the raw network prediction into boxes and scores to draw on the image
53 | def decoder(pred):
54 | """
55 | pred (tensor) torch.Size([1, 14, 14, 30])
56 | return (tensor) box[[x1,y1,x2,y2]] label[...]
57 | """
58 | grid_num = 14
59 | boxes = []
60 | cls_indexs = []
61 | probs = []
62 | cell_size = 1./grid_num
63 | pred = pred.data # torch.Size([1, 14, 14, 30])
64 | pred = pred.squeeze(0) # torch.Size([14, 14, 30])
65 | # 0 1 2 3 4 5 6 7 8 9
66 | # [center xy, wh, confidence, center xy, wh, confidence, 20 class scores] per cell
67 | contain1 = pred[:, :, 4].unsqueeze(2) # torch.Size([14, 14, 1])
68 | contain2 = pred[:, :, 9].unsqueeze(2) # torch.Size([14, 14, 1])
69 | contain = torch.cat((contain1, contain2), 2) # torch.Size([14, 14, 2])
70 |
71 | mask1 = contain > 0.1 # above the threshold, torch.Size([14, 14, 2]) content: tensor([False, False])
72 | mask2 = (contain == contain.max()) # always keep the box with the highest confidence, even if it is below the threshold
73 | mask = (mask1+mask2).gt(0)
74 |
75 | # min_score,min_index = torch.min(contain, 2) # keep only the highest-probability box per cell
76 | for i in range(grid_num):
77 | for j in range(grid_num):
78 | for b in range(2):
79 | # index = min_index[i,j]
80 | # mask[i,j,index] = 0
81 | if mask[i, j, b] == 1:
82 | box = pred[i, j, b*5:b*5+4]
83 | contain_prob = torch.FloatTensor([pred[i, j, b*5+4]])
84 | xy = torch.FloatTensor([j, i])*cell_size # top-left corner of the cell
85 | box[:2] = box[:2]*cell_size + xy # return cxcy relative to image
86 | box_xy = torch.FloatTensor(box.size()) # convert [cx,cy,w,h] to [x1,y1,x2,y2]
87 | box_xy[:2] = box[:2] - 0.5*box[2:]
88 | box_xy[2:] = box[:2] + 0.5*box[2:]
89 | max_prob, cls_index = torch.max(pred[i, j, 10:], 0)
90 | if float((contain_prob*max_prob)[0]) > 0.1:
91 | boxes.append(box_xy.view(1, 4))
92 | cls_indexs.append(cls_index.item())
93 | probs.append(contain_prob*max_prob)
94 | if len(boxes) == 0:
95 | boxes = torch.zeros((1, 4))
96 | probs = torch.zeros(1)
97 | cls_indexs = torch.zeros(1)
98 | else:
99 | boxes = torch.cat(boxes, 0) # (n,4)
100 | # print(type(probs))
101 | # print(len(probs))
102 | # print(probs)
103 | probs = torch.cat(probs, 0) # (n,)
104 | # print(probs)
105 | # print(type(cls_indexs))
106 | # print(len(cls_indexs))
107 | # print(cls_indexs)
108 | cls_indexs = torch.IntTensor(cls_indexs) # (n,)
109 | keep = nms(boxes, probs)
110 | # print("keep:", keep)
111 |
112 | a = boxes[keep]
113 | b = cls_indexs[keep]
114 | c = probs[keep]
115 | return a, b, c
116 |
117 |
118 | def nms(bboxes, scores, threshold=0.5):
119 | '''
120 | bboxes(tensor) [N,4]
121 | scores(tensor) [N,]
122 | '''
123 | x1 = bboxes[:, 0]
124 | y1 = bboxes[:, 1]
125 | x2 = bboxes[:, 2]
126 | y2 = bboxes[:, 3]
127 | areas = (x2-x1) * (y2-y1)
128 | # print(scores) # tensor([0.1006, 0.2381, 0.1185, 0.5342, 0.2892, 0.3521, 0.6027])
129 | _, order = scores.sort(0, descending=True) # sort scores in descending order
130 | keep = []
131 | # print("order:", order) # order: tensor([6, 3, 5, 4, 1, 2, 0])
132 | # print("order.numel:", order.numel()) # 7
133 | while order.numel() > 0: # torch.numel() returns the number of elements in the tensor
134 | if order.numel() == 1: # only one candidate box left
135 | # print("end1")
136 | # print(type(order))
137 | # print(order)
138 | i = order
139 | keep.append(i)
140 | break
141 | # print("len:", order.size())
142 | # print(keep)
143 | i = order[0] # keep box[i], the box with the highest remaining score
144 | keep.append(i)
145 |
146 | # compute the IoU between box[i] and every remaining box
147 | xx1 = x1[order[1:]].clamp(min=x1[i]) # [N-1,]
148 | yy1 = y1[order[1:]].clamp(min=y1[i])
149 | xx2 = x2[order[1:]].clamp(max=x2[i])
150 | yy2 = y2[order[1:]].clamp(max=y2[i])
151 |
152 | w = (xx2-xx1).clamp(min=0)
153 | h = (yy2-yy1).clamp(min=0)
154 | inter = w*h # [N-1,]
155 | ovr = inter / (areas[i] + areas[order[1:]] - inter)
156 | ids = (ovr <= threshold).nonzero(as_tuple=False).squeeze() # note: ids indexes the [N-1,] tail while order is [N,]
157 | if ids.numel() == 0:
158 | # print("end2")
159 | break
160 | order = order[ids+1] # shift by 1 to map ids back into order's indexing
161 | # print(keep)
162 | return torch.LongTensor(keep)
163 | # return keep
164 |
165 |
166 | # start predict one image
167 | def predict_gpu(model, image_name, root_path=''):
168 | result = []
169 | image = cv2.imread(root_path+image_name)
170 | # print(root_path , image_name)
171 | h, w, _ = image.shape
172 | img = cv2.resize(image, (448, 448))
173 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
174 | mean = (123, 117, 104) # RGB
175 | img = img - np.array(mean, dtype=np.float32)
176 |
177 | transform = transforms.Compose([transforms.ToTensor(), ])
178 | img = transform(img) # torch.Size([3, 448, 448])
179 | img = img[None, :, :, :] # img: torch.Size([1, 3, 448, 448])
180 | img = img.cuda()
181 |
182 | pred = model(img) # 1x14x14x30
183 | pred = pred.cpu()
184 | boxes, cls_indexs, probs = decoder(pred)
185 |
186 | for i, box in enumerate(boxes):
187 | x1 = int(box[0]*w)
188 | x2 = int(box[2]*w)
189 | y1 = int(box[1]*h)
190 | y2 = int(box[3]*h)
191 | cls_index = cls_indexs[i]
192 | cls_index = int(cls_index) # convert LongTensor to int
193 | prob = probs[i]
194 | prob = float(prob)
195 | result.append([(x1, y1), (x2, y2), VOC_CLASSES[cls_index], image_name, prob])
196 | return result
197 |
198 |
199 | if __name__ == '__main__':
200 | model = resnet50()
201 | print('load model...')
202 | model.load_state_dict(torch.load('checkpoints/best.pth'))
203 | model.eval()
204 | model.cuda()
205 | image_name = 'imgs/001526.jpg'
206 | image = cv2.imread(image_name)
207 | print('predicting...')
208 | result = predict_gpu(model, image_name)
209 |
210 | for left_up, right_bottom, class_name, _, prob in result:
211 | color = Color[VOC_CLASSES.index(class_name)]
212 | cv2.rectangle(image, left_up, right_bottom, color, 2)
213 | label = class_name+str(round(prob, 2))
214 | text_size, baseline = cv2.getTextSize(label, cv2.FONT_HERSHEY_SIMPLEX, 0.4, 1)
215 | p1 = (left_up[0], left_up[1] - text_size[1])
216 | cv2.rectangle(image, (p1[0] - 2//2, p1[1] - 2 - baseline), (p1[0] + text_size[0], p1[1] + text_size[1]), color, -1)
217 | cv2.putText(image, label, (p1[0], p1[1] + baseline), cv2.FONT_HERSHEY_SIMPLEX, 0.4, (255, 255, 255), 1, 8)
218 |
219 | cv2.imwrite('imgs/001526_result.jpg', image)
220 | img = cv2.imread('imgs/001526_result.jpg')
221 | cv2.imshow('img', img)
222 | cv2.waitKey(0)
223 |
224 |
225 |
226 |
--------------------------------------------------------------------------------
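
The nms helper above can be exercised in isolation. A minimal sketch with made-up boxes: of two heavily overlapping boxes, only the higher-scoring one should survive, along with the distant third box.

    import torch
    from predict import nms   # the demo at the bottom of predict.py is guarded by __main__

    boxes = torch.tensor([[0.0, 0.0, 1.0, 1.0],
                          [0.1, 0.1, 1.1, 1.1],    # IoU with the first box is ~0.68 > 0.5
                          [2.0, 2.0, 3.0, 3.0]])   # far away from both
    scores = torch.tensor([0.9, 0.8, 0.7])
    print(nms(boxes, scores, threshold=0.5))       # expected: tensor([0, 2])
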
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==0.9.0
2 | adabound==0.0.5
3 | appdirs==1.4.4
4 | asn1crypto==0.24.0
5 | attrs==17.4.0
6 | Automat==0.6.0
7 | backcall==0.2.0
8 | bleach==3.1.5
9 | blinker==1.4
10 | cachetools==4.1.1
11 | certifi==2018.1.18
12 | cffi==1.14.0
13 | chardet==3.0.4
14 | click==6.7
15 | cloud-init==19.4
16 | cloudpickle==1.4.1
17 | colorama==0.3.7
18 | command-not-found==0.3
19 | configobj==5.0.6
20 | constantly==15.1.0
21 | cryptography==2.1.4
22 | cycler==0.10.0
23 | Cython==0.29.17
24 | decorator==4.4.2
25 | defusedxml==0.6.0
26 | distro-info==0.18ubuntu0.18.04.1
27 | entrypoints==0.3
28 | fire==0.3.1
29 | Flask==1.1.2
30 | future==0.18.2
31 | google-auth==1.18.0
32 | google-auth-oauthlib==0.4.1
33 | grpcio==1.30.0
34 | grpcio-tools==1.30.0
35 | horovod==0.19.2
36 | httplib2==0.9.2
37 | hyperlink==17.3.1
38 | idna==2.6
39 | imageio==2.9.0
40 | importlib-metadata==1.7.0
41 | incremental==16.10.1
42 | ipykernel==5.3.0
43 | ipython==7.15.0
44 | ipython-genutils==0.2.0
45 | ipywidgets==7.5.1
46 | itsdangerous==1.1.0
47 | jedi==0.17.0
48 | Jinja2==2.11.2
49 | joblib==0.16.0
50 | jsonpatch==1.16
51 | jsonpointer==1.10
52 | jsonschema==2.6.0
53 | jupyter==1.0.0
54 | jupyter-client==6.1.3
55 | jupyter-console==6.1.0
56 | jupyter-core==4.6.3
57 | keras2onnx==1.7.0
58 | keyring==10.6.0
59 | keyrings.alt==3.0
60 | kiwisolver==1.2.0
61 | language-selector==0.1
62 | Mako==1.1.3
63 | Markdown==3.2.2
64 | MarkupSafe==1.0
65 | matplotlib==3.2.1
66 | mistune==0.8.4
67 | mlxtend==0.17.2
68 | nbconvert==5.6.1
69 | nbformat==5.0.7
70 | netifaces==0.10.4
71 | netron==4.3.5
72 | networkx==2.4
73 | notebook==6.0.3
74 | numpy==1.18.4
75 | oauthlib==3.1.0
76 | onnx==1.7.0
77 | onnxconverter-common==1.7.0
78 | onnxmltools==1.7.0
79 | onnxruntime==1.4.0
80 | onnxruntime-gpu==1.4.0
81 | opencv-python==4.3.0.36
82 | ort-gpu-nightly==1.4.0.dev202007171
83 | packaging==20.4
84 | PAM==0.4.2
85 | pandas==1.0.5
86 | pandocfilters==1.4.2
87 | parso==0.7.0
88 | pexpect==4.8.0
89 | pickleshare==0.7.5
90 | Pillow==7.1.2
91 | pip==20.1
92 | prometheus-client==0.8.0
93 | prompt-toolkit==3.0.5
94 | protobuf==3.12.2
95 | psutil==5.7.0
96 | ptyprocess==0.6.0
97 | pyasn1==0.4.2
98 | pyasn1-modules==0.2.1
99 | pycocotools==2.0.0
100 | pycparser==2.20
101 | pycrypto==2.6.1
102 | pycuda==2019.1.2
103 | Pygments==2.6.1
104 | pygobject==3.26.1
105 | PyJWT==1.5.3
106 | pyOpenSSL==17.5.0
107 | pyparsing==2.4.7
108 | pyserial==3.4
109 | python-apt==1.6.5+ubuntu0.2
110 | python-dateutil==2.8.1
111 | python-debian==0.1.32
112 | pytools==2020.2
113 | pytz==2020.1
114 | PyWavelets==1.1.1
115 | pyxdg==0.25
116 | PyYAML==3.12
117 | pyzmq==19.0.1
118 | qtconsole==4.7.4
119 | QtPy==1.9.0
120 | requests==2.24.0
121 | requests-oauthlib==1.3.0
122 | requests-unixsocket==0.1.5
123 | rsa==4.6
124 | scikit-image==0.17.2
125 | scikit-learn==0.23.1
126 | scipy==1.5.1
127 | SecretStorage==2.3.1
128 | Send2Trash==1.5.0
129 | service-identity==16.0.0
130 | setuptools==49.1.2
131 | six==1.11.0
132 | skl2onnx==1.7.0
133 | sklearn==0.0
134 | ssh-import-id==5.7
135 | systemd-python==234
136 | tensorboard==2.2.2
137 | tensorboard-plugin-wit==1.7.0
138 | tensorrt==7.0.0.11
139 | termcolor==1.1.0
140 | terminado==0.8.3
141 | terminaltables==3.1.0
142 | testpath==0.4.4
143 | threadpoolctl==2.1.0
144 | tifffile==2020.7.17
145 | torch==1.5.0
146 | torch2trt==0.1.0
147 | torchfile==0.1.0
148 | torchsummary==1.5.1
149 | torchvision==0.6.0
150 | tornado==6.0.4
151 | tqdm==4.46.0
152 | traitlets==4.3.3
153 | Twisted==17.9.0
154 | typing-extensions==3.7.4.2
155 | ufw==0.36
156 | unattended-upgrades==0.1
157 | urllib3==1.22
158 | visdom==0.1.8.9
159 | wcwidth==0.2.4
160 | webencodings==0.5.1
161 | websocket-client==0.57.0
162 | Werkzeug==1.0.1
163 | wget==3.2
164 | wheel==0.30.0
165 | widgetsnbextension==3.5.1
166 | zipp==3.1.0
167 | zope.interface==4.3.2
168 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2020/08/12 18:30
4 | @Author : Bryce
5 | @File : train.py
6 | @Notice :
7 | @Modification :
8 | @Author :
9 | @Time :
10 | @Detail :
11 | """
12 | import warnings
13 | import os
14 | import numpy as np
15 |
16 | import torch
17 | from torch.utils.data import DataLoader
18 | import torchvision.transforms as transforms
19 | from torchvision import models
20 |
21 | from models.vgg_yolo import vgg16_bn
22 | from models.resnet_yolo import resnet50
23 | from models.yoloLoss import yoloLoss
24 | from utils.dataset import yoloDataset
25 |
26 | warnings.filterwarnings('ignore')
27 | # select the GPU ID
28 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"
29 |
30 | # check whether a GPU is available
31 | use_gpu = torch.cuda.is_available()
32 |
33 | # dataset root
34 | file_root = 'datasets'
35 |
36 | # hyper-parameters
37 | learning_rate = 0.001
38 | num_epochs = 100
39 | batch_size = 24
40 |
41 | # checkpoints
42 | resume = True
43 |
44 | # --------------------- data loading ---------------------
45 | train_dataset = yoloDataset(root=file_root, list_file='images.txt', train=True,
46 | transform=[transforms.ToTensor()])
47 | # train_dataset = yoloDataset(root=file_root, list_file=['voc12_trainval.txt','voc07_trainval.txt'],
48 | # train=True,transform = [transforms.ToTensor()] )
49 | train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
50 |
51 | test_dataset = yoloDataset(root=file_root, list_file='voc2007test.txt', train=False,
52 | transform=[transforms.ToTensor()])
53 | test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
54 | print('the train dataset has %d images' % (len(train_dataset)))
55 | print('the test dataset has %d images' % (len(test_dataset)))
56 | print('the batch_size is %d' % batch_size)
57 |
58 |
59 | # --------------------- model selection ---------------------
60 | use_resnet = True
61 | if use_resnet:
62 | net = resnet50()
63 | else:
64 | net = vgg16_bn()
65 |
66 | if resume:
67 | print("loading weight from checkpoints/best.pth")
68 | net.load_state_dict(torch.load('checkpoints/best.pth'))
69 | else:
70 | print('loading pre-trained model ......')
71 | if use_resnet:
72 | resnet = models.resnet50(pretrained=True)
73 | new_state_dict = resnet.state_dict()
74 | dd = net.state_dict()
75 | for k in new_state_dict.keys():
76 | print(k)
77 | if k in dd.keys() and not k.startswith('fc'):
78 | # print('yes')
79 | dd[k] = new_state_dict[k]
80 | net.load_state_dict(dd)
81 | else:
82 | vgg = models.vgg16_bn(pretrained=True)
83 | new_state_dict = vgg.state_dict()
84 | dd = net.state_dict()
85 | for k in new_state_dict.keys():
86 | print(k)
87 | if k in dd.keys() and k.startswith('features'):
88 | print('yes')
89 | dd[k] = new_state_dict[k]
90 | net.load_state_dict(dd)
91 |
92 | if use_gpu:
93 | print('this machine has %d gpu(s) and the current device is %s' % (torch.cuda.device_count(),
94 | torch.cuda.current_device()))
95 | net.cuda()
96 |
97 |
98 | # --------------------- loss function ---------------------
99 | criterion = yoloLoss(7, 2, 5, 0.5)
100 |
101 | # --------------------- optimizer ----------------------
102 |
103 | # per-parameter-group learning rates (both groups currently use the same rate)
104 | params = []
105 | params_dict = dict(net.named_parameters())
106 | for key, value in params_dict.items():
107 | if key.startswith('features'):
108 | params += [{'params': [value], 'lr':learning_rate*1}]
109 | else:
110 | params += [{'params': [value], 'lr':learning_rate}]
111 | optimizer = torch.optim.SGD(params, lr=learning_rate, momentum=0.9, weight_decay=5e-4)
112 | # optimizer = torch.optim.Adam(net.parameters(),lr=learning_rate,weight_decay=1e-4)
113 |
114 |
115 | # --------------------- training ---------------------
116 | logfile = open('checkpoints/log.txt', 'w')
117 | num_iter = 0
118 | best_test_loss = np.inf
119 |
120 | for epoch in range(num_epochs):
121 | # train
122 | net.train()
123 | if epoch == 30:
124 | learning_rate = 0.0001
125 | if epoch == 40:
126 | learning_rate = 0.00001
127 | for param_group in optimizer.param_groups:
128 | param_group['lr'] = learning_rate
129 |
130 | print('\n\nStarting epoch %d / %d' % (epoch + 1, num_epochs))
131 | print('Learning Rate for this epoch: {}'.format(learning_rate))
132 |
133 | total_loss = 0.
134 |
135 | for i, (images, target) in enumerate(train_loader):
136 | if use_gpu:
137 | images, target = images.cuda(), target.cuda()
138 |
139 | pred = net(images)
140 | loss = criterion(pred, target)
141 | total_loss += loss.data.item()
142 |
143 | optimizer.zero_grad()
144 | loss.backward()
145 | optimizer.step()
146 | if (i+1) % 5 == 0:
147 | print('Epoch [%d/%d], Iter [%d/%d] Loss: %.4f, average_loss: %.4f'
148 | % (epoch+1, num_epochs, i+1, len(train_loader), loss.item(), total_loss / (i+1)))
149 | num_iter += 1
150 |
151 | # validation
152 | validation_loss = 0.0
153 | net.eval()
154 | for i, (images, target) in enumerate(test_loader):
155 | if use_gpu:
156 | images, target = images.cuda(), target.cuda()
157 |
158 | pred = net(images)
159 | loss = criterion(pred, target)
160 | validation_loss += loss.item()
161 | validation_loss /= len(test_loader)
162 |
163 | if best_test_loss > validation_loss:
164 | best_test_loss = validation_loss
165 | print('get best test loss %.5f' % best_test_loss)
166 | torch.save(net.state_dict(), 'checkpoints/best.pth')
167 | logfile.writelines(str(epoch) + '\t' + str(validation_loss) + '\n')
168 | logfile.flush()
169 |
170 |
171 |
--------------------------------------------------------------------------------
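
train.py drops the learning rate by hand at epochs 30 and 40. A hedged, roughly equivalent sketch using torch.optim.lr_scheduler.MultiStepLR — not what this repo does, just the idiomatic alternative (the stand-in nn.Linear is only there to make the snippet self-contained):

    import torch
    import torch.nn as nn

    net = nn.Linear(10, 2)   # stand-in for the detector
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=5e-4)
    # multiply the lr by gamma=0.1 after 30 and 40 epochs: 1e-3 -> 1e-4 -> 1e-5
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[30, 40], gamma=0.1)

    for epoch in range(50):
        # ... one epoch of training here ...
        scheduler.step()     # advance the schedule once per epoch
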
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/python
2 | # -*- coding:utf-8 -*-
3 | # author: Bryce
--------------------------------------------------------------------------------
/utils/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | """
3 | @Time : 2020/08/12 18:30
4 | @Author : Bryce
5 | @File : dataset.py
6 | @Notice :
7 | @Modification : each line of the txt annotation file looks like `image_name.jpg x1 y1 x2 y2 c x1 y1 x2 y2 c`; two 5-tuples mean the image contains two objects
8 | @Author :
9 | @Time :
10 | @Detail :
11 | """
12 |
13 | import os
14 | import os.path
15 |
16 | import random
17 | import numpy as np
18 |
19 | import torch
20 | import torch.utils.data as data
21 | import torchvision.transforms as transforms
22 |
23 | import cv2
24 | import matplotlib.pyplot as plt
25 |
26 |
27 | class yoloDataset(data.Dataset):
28 | image_size = 448
29 | # train_dataset = yoloDataset(root=datasets/,
30 | # list_file=['voc2012.txt','voc2007.txt'],
31 | # train=True,transform = [transforms.ToTensor()] )
32 |
33 | def __init__(self, root, list_file, train, transform):
34 | self.root = root # dataset root directory
35 | self.train = train # training mode or not
36 | self.transform = transform # transforms to apply
37 | self.fnames = [] # file names, e.g. [001.jpg, 002.jpg]
38 | self.boxes = [] # boxes [ [box], [[x1,y1,x2,y2], ...], ... ]
39 | self.labels = [] # labels [ [1], [2], ... ]
40 | self.mean = (123, 117, 104) # RGB
41 | self.num_samples = 0 # total number of samples
42 |
43 | if isinstance(list_file, list):
44 | # Cat multiple list files together.
45 | # This is especially useful for voc07/voc12 combination.
46 | tmp_file = os.path.join(root, 'images.txt')
47 | list_file = [os.path.join(root, list_file[0]), os.path.join(root, list_file[1])]
48 | os.system('cat %s > %s' % (' '.join(list_file), tmp_file))
49 | list_file = tmp_file
50 | else:
51 | list_file = os.path.join(root, list_file)
52 |
53 | # parse the annotation list
54 | with open(list_file) as f:
55 | lines = f.readlines()
56 | for line in lines:
57 | splited = line.strip().split() # e.g. ['005246.jpg', '84', '48', '493', '387', '2']: corner coordinates + class label
58 | self.fnames.append(splited[0])
59 | num_boxes = (len(splited) - 1) // 5
60 | box = []
61 | label = []
62 | for i in range(num_boxes):
63 | x = float(splited[1+5*i])
64 | y = float(splited[2+5*i])
65 | x2 = float(splited[3+5*i])
66 | y2 = float(splited[4+5*i])
67 | c = splited[5+5*i]
68 | box.append([x, y, x2, y2])
69 | label.append(int(c)+1)
70 | self.boxes.append(torch.Tensor(box))
71 | self.labels.append(torch.LongTensor(label))
72 | self.num_samples = len(self.boxes) # total number of images (one boxes entry per image)
73 |
74 | def __getitem__(self, idx):
75 | fname = self.fnames[idx]
76 | img = cv2.imread(os.path.join(self.root, "images", fname))
77 | boxes = self.boxes[idx].clone()
78 | labels = self.labels[idx].clone()
79 |
80 | # data augmentation
81 | # if self.train:
82 | # img = self.random_bright(img)
83 | # img, boxes = self.random_flip(img, boxes)
84 | # img, boxes = self.randomScale(img, boxes)
85 | # img = self.randomBlur(img)
86 | # img = self.RandomBrightness(img)
87 | # img = self.RandomHue(img)
88 | # img = self.RandomSaturation(img)
89 | # img, boxes, labels = self.randomShift(img, boxes, labels)
90 | # img, boxes, labels = self.randomCrop(img, boxes, labels)
91 |
92 | # # debug
93 | # box_show = boxes.numpy().reshape(-1)
94 | # print(box_show)
95 | # img_show = self.BGR2RGB(img)
96 | # pt1 = (int(box_show[0]), int(box_show[1]))
97 | # pt2 = (int(box_show[2]), int(box_show[3]))
98 | # cv2.rectangle(img_show, pt1=pt1, pt2=pt2, color=(0, 255, 0), thickness=1)
99 | # print(type(img_show))
100 | # plt.figure()
101 | # plt.imshow(img_show)
102 | # plt.show()
103 | # plt.savefig("a.png")
104 | # #debug
105 | h, w, _ = img.shape # channel count is ignored
106 | boxes /= torch.Tensor([w, h, w, h]).expand_as(boxes) # normalize box coordinates to (0,1) relative to the image origin
107 | img = self.BGR2RGB(img) # because pytorch pretrained models expect RGB
108 | img = self.subMean(img, self.mean) # subtract the channel means
109 | img = cv2.resize(img, (self.image_size, self.image_size))
110 | target = self.encoder(boxes, labels) # 7x7x30
111 | for t in self.transform:
112 | img = t(img)
113 |
114 | return img, target
115 |
116 | def __len__(self):
117 | return self.num_samples
118 |
119 | def encoder(self, boxes, labels):
120 | '''
121 | boxes (tensor) [[x1,y1,x2,y2],[]]
122 | labels (tensor) [...]
123 | return 7x7x30
124 | '''
125 | grid_num = 14 # the paper uses 7
126 | target = torch.zeros((grid_num, grid_num, 30))
127 | cell_size = 1./grid_num # boxes were normalized earlier, so the divisor here is 1.
128 | wh = boxes[:, 2:]-boxes[:, :2] # width and height
129 | cxcy = (boxes[:, 2:]+boxes[:, :2])/2 # center point
130 | for i in range(cxcy.size()[0]): # for each ground-truth box in this image
131 | cxcy_sample = cxcy[i]
132 | ij = (cxcy_sample/cell_size).ceil()-1 # ij holds the x/y indices of the grid cell that contains the box center in the normalized image
133 | # 0 1 2 3 4 5 6 7 8 9
134 | # [center xy, wh, confidence, center xy, wh, confidence, 20 class scores] per cell
135 | target[int(ij[1]), int(ij[0]), 4] = 1 # confidence of the first box
136 | target[int(ij[1]), int(ij[0]), 9] = 1 # confidence of the second box
137 | target[int(ij[1]), int(ij[0]), int(labels[i])+9] = 1
138 | xy = ij*cell_size # top-left corner of the matched cell, in normalized coordinates
139 | delta_xy = (cxcy_sample -xy)/cell_size # offset of the box center from the cell's top-left corner, as a fraction of the cell size
140 | target[int(ij[1]), int(ij[0]), 2:4] = wh[i] # w,h are the box size relative to the full image width/height, in (0,1)
141 | target[int(ij[1]), int(ij[0]), :2] = delta_xy
142 | target[int(ij[1]), int(ij[0]), 7:9] = wh[i]
143 | target[int(ij[1]), int(ij[0]), 5:7] = delta_xy
144 | return target
145 |
146 | def BGR2RGB(self, img):
147 | return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
148 |
149 | def BGR2HSV(self,img):
150 | return cv2.cvtColor(img,cv2.COLOR_BGR2HSV)
151 |
152 | def HSV2BGR(self,img):
153 | return cv2.cvtColor(img,cv2.COLOR_HSV2BGR)
154 |
155 | def RandomBrightness(self,bgr):
156 | if random.random() < 0.5:
157 | hsv = self.BGR2HSV(bgr)
158 | h,s,v = cv2.split(hsv)
159 | adjust = random.choice([0.5,1.5])
160 | v = v*adjust
161 | v = np.clip(v, 0, 255).astype(hsv.dtype)
162 | hsv = cv2.merge((h,s,v))
163 | bgr = self.HSV2BGR(hsv)
164 | return bgr
165 |
166 | def RandomSaturation(self,bgr):
167 | if random.random() < 0.5:
168 | hsv = self.BGR2HSV(bgr)
169 | h,s,v = cv2.split(hsv)
170 | adjust = random.choice([0.5,1.5])
171 | s = s*adjust
172 | s = np.clip(s, 0, 255).astype(hsv.dtype)
173 | hsv = cv2.merge((h,s,v))
174 | bgr = self.HSV2BGR(hsv)
175 | return bgr
176 |
177 | def RandomHue(self,bgr):
178 | if random.random() < 0.5:
179 | hsv = self.BGR2HSV(bgr)
180 | h,s,v = cv2.split(hsv)
181 | adjust = random.choice([0.5,1.5])
182 | h = h*adjust
183 | h = np.clip(h, 0, 255).astype(hsv.dtype)
184 | hsv = cv2.merge((h,s,v))
185 | bgr = self.HSV2BGR(hsv)
186 | return bgr
187 |
188 | def randomBlur(self,bgr):
189 | if random.random()<0.5:
190 | bgr = cv2.blur(bgr,(5,5))
191 | return bgr
192 |
193 | def randomShift(self,bgr,boxes,labels):
194 | # random translation
195 | center = (boxes[:,2:]+boxes[:,:2])/2
196 | if random.random() <0.5:
197 | height,width,c = bgr.shape
198 | after_shfit_image = np.zeros((height,width,c),dtype=bgr.dtype)
199 | after_shfit_image[:,:,:] = (104,117,123) #bgr
200 | shift_x = random.uniform(-width*0.2,width*0.2)
201 | shift_y = random.uniform(-height*0.2,height*0.2)
202 | #print(bgr.shape,shift_x,shift_y)
203 | # shift the original image
204 | if shift_x>=0 and shift_y>=0:
205 | after_shfit_image[int(shift_y):,int(shift_x):,:] = bgr[:height-int(shift_y),:width-int(shift_x),:]
206 | elif shift_x>=0 and shift_y<0:
207 | after_shfit_image[:height+int(shift_y),int(shift_x):,:] = bgr[-int(shift_y):,:width-int(shift_x),:]
208 | elif shift_x <0 and shift_y >=0:
209 | after_shfit_image[int(shift_y):,:width+int(shift_x),:] = bgr[:height-int(shift_y),-int(shift_x):,:]
210 | elif shift_x<0 and shift_y<0:
211 | after_shfit_image[:height+int(shift_y),:width+int(shift_x),:] = bgr[-int(shift_y):,-int(shift_x):,:]
212 |
213 | shift_xy = torch.FloatTensor([[int(shift_x),int(shift_y)]]).expand_as(center)
214 | center = center + shift_xy
215 | mask1 = (center[:,0] >0) & (center[:,0] < width)
216 | mask2 = (center[:,1] >0) & (center[:,1] < height)
217 | mask = (mask1 & mask2).view(-1,1)
218 | boxes_in = boxes[mask.expand_as(boxes)].view(-1,4)
219 | if len(boxes_in) == 0:
220 | return bgr,boxes,labels
221 | box_shift = torch.FloatTensor([[int(shift_x),int(shift_y),int(shift_x),int(shift_y)]]).expand_as(boxes_in)
222 | boxes_in = boxes_in+box_shift
223 | labels_in = labels[mask.view(-1)]
224 | return after_shfit_image,boxes_in,labels_in
225 | return bgr,boxes,labels
226 |
227 | def randomScale(self,bgr,boxes):
228 | # keep the height fixed and scale the width by a factor of 0.8-1.2 to deform the image
229 | if random.random() < 0.5:
230 | scale = random.uniform(0.8,1.2)
231 | height,width,c = bgr.shape
232 | bgr = cv2.resize(bgr,(int(width*scale),height))
233 | scale_tensor = torch.FloatTensor([[scale,1,scale,1]]).expand_as(boxes)
234 | boxes = boxes * scale_tensor
235 | return bgr,boxes
236 | return bgr,boxes
237 |
238 | def randomCrop(self,bgr,boxes,labels):
239 | if random.random() < 0.5:
240 | center = (boxes[:,2:]+boxes[:,:2])/2
241 | height,width,c = bgr.shape
242 | h = random.uniform(0.6*height,height)
243 | w = random.uniform(0.6*width,width)
244 | x = random.uniform(0,width-w)
245 | y = random.uniform(0,height-h)
246 | x,y,h,w = int(x),int(y),int(h),int(w)
247 |
248 | center = center - torch.FloatTensor([[x,y]]).expand_as(center)
249 | mask1 = (center[:,0]>0) & (center[:,0]<w)
250 | mask2 = (center[:,1]>0) & (center[:,1]<h)
251 | mask = (mask1 & mask2).view(-1,1)
252 |
253 | boxes_in = boxes[mask.expand_as(boxes)].view(-1,4)
254 | if(len(boxes_in)==0):
255 | return bgr,boxes,labels
256 | box_shift = torch.FloatTensor([[x,y,x,y]]).expand_as(boxes_in)
257 |
258 | boxes_in = boxes_in - box_shift
259 | boxes_in[:,0]=boxes_in[:,0].clamp_(min=0,max=w)
260 | boxes_in[:,2]=boxes_in[:,2].clamp_(min=0,max=w)
261 | boxes_in[:,1]=boxes_in[:,1].clamp_(min=0,max=h)
262 | boxes_in[:,3]=boxes_in[:,3].clamp_(min=0,max=h)
263 |
264 | labels_in = labels[mask.view(-1)]
265 | img_croped = bgr[y:y+h,x:x+w,:]
266 | return img_croped,boxes_in,labels_in
267 | return bgr,boxes,labels
268 |
269 | def subMean(self,bgr,mean):
270 | mean = np.array(mean, dtype=np.float32)
271 | bgr = bgr - mean
272 | return bgr
273 |
274 | def random_flip(self, im, boxes):
275 | if random.random() < 0.5:
276 | im_lr = np.fliplr(im).copy()
277 | h,w,_ = im.shape
278 | xmin = w - boxes[:,2]
279 | xmax = w - boxes[:,0]
280 | boxes[:,0] = xmin
281 | boxes[:,2] = xmax
282 | return im_lr, boxes
283 | return im, boxes
284 |
285 | def random_bright(self, im, delta=16):
286 | alpha = random.random()
287 | if alpha > 0.3:
288 | im = im * alpha + random.randrange(-delta, delta)
289 | im = im.clip(min=0, max=255).astype(np.uint8)
290 | return im
291 |
292 |
293 | if __name__ == '__main__':
294 | from torch.utils.data import DataLoader
295 | import torchvision.transforms as transforms
296 | file_root = "../datasets"
297 | # train_dataset = yoloDataset(root=file_root, list_file=['voc2012.txt', 'voc2007.txt'],
298 | # train=True, transform=[transforms.ToTensor()])
299 | train_dataset = yoloDataset(root=file_root, list_file='images.txt',
300 | train=True, transform=[transforms.ToTensor()])
301 | train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0)
302 | train_iter = iter(train_loader)
303 | for i in range(1):
304 | img, target = next(train_iter)
305 | print(img.shape, target.shape)
306 | print(train_dataset.num_samples)
307 |
308 |
309 |
--------------------------------------------------------------------------------
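
The encoder's cell assignment is easy to verify by hand. A small sketch with toy numbers of my own: a normalized box center at (0.53, 0.24) on a 14x14 grid lands in cell ij = ceil(center/cell_size) - 1, and delta_xy is the center's fractional offset inside that cell.

    import torch

    grid_num = 14
    cell_size = 1.0 / grid_num
    cxcy_sample = torch.tensor([0.53, 0.24])    # normalized box center (toy values)
    ij = (cxcy_sample / cell_size).ceil() - 1   # tensor([7., 3.]): column 7, row 3
    xy = ij * cell_size                         # cell's top-left corner: tensor([0.5000, 0.2143])
    delta_xy = (cxcy_sample - xy) / cell_size   # offset inside the cell: tensor([0.4200, 0.3600])
    print(ij, delta_xy)
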
/utils/piplist2equal.py:
--------------------------------------------------------------------------------
1 | # !/usr/bin/python
2 | # -*- coding:utf-8 -*-
3 | # author: Bryce
4 | # Convert the two-column table produced by `pip list` (saved as requirements.txt) into pinned name==version lines.
5 | import os
6 |
7 |
8 | reqs = []
9 | with open("requirements.txt", mode='r') as f_old:
10 | lines = f_old.readlines()
11 | lines = lines[2:]
12 | for line in lines:
13 | line = line.split()
14 | temp = line[0] + "==" + line[1]
15 | reqs.append(temp)
16 |
17 | with open("requirements.txt", mode='w') as f:
18 | for line in reqs:
19 | f.write(line + "\n")
20 |
--------------------------------------------------------------------------------
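
For clarity, the per-line transformation looks like this (illustrative row taken from the requirements list):

    line = "numpy 1.18.4".split()     # one row of `pip list` output: name, version
    print(line[0] + "==" + line[1])   # -> numpy==1.18.4
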
/utils/xml2txt.py:
--------------------------------------------------------------------------------
1 | import xml.etree.ElementTree as ET
2 | import os
3 |
4 | VOC_CLASSES = ( # always index 0
5 | 'aeroplane', 'bicycle', 'bird', 'boat',
6 | 'bottle', 'bus', 'car', 'cat', 'chair',
7 | 'cow', 'diningtable', 'dog', 'horse',
8 | 'motorbike', 'person', 'pottedplant',
9 | 'sheep', 'sofa', 'train', 'tvmonitor')
10 |
11 |
12 | def parse_rec(filename):
13 | """ Parse a PASCAL VOC xml file """
14 | tree = ET.parse(filename)
15 | objects = []
16 | for obj in tree.findall('object'):
17 | obj_struct = {}
18 | difficult = int(obj.find('difficult').text)
19 | if difficult == 1: # skip objects flagged as difficult
20 | # print(filename)
21 | continue
22 | obj_struct['name'] = obj.find('name').text
23 | bbox = obj.find('bndbox')
24 | obj_struct['bbox'] = [int(float(bbox.find('xmin').text)),
25 | int(float(bbox.find('ymin').text)),
26 | int(float(bbox.find('xmax').text)),
27 | int(float(bbox.find('ymax').text))]
28 | objects.append(obj_struct)
29 |
30 | return objects
31 |
32 |
33 | txt_file = open('voc2007test.txt', 'w')
34 | test_file = open('voc07testimg.txt', 'r')
35 | lines = test_file.readlines()
36 | lines = [x[:-1] for x in lines]
37 | print(lines)
38 |
39 | Annotations = 'path/to/VOC2007/Annotations/'
40 | xml_files = os.listdir(Annotations)
41 |
42 | count = 0
43 | for xml_file in xml_files:
44 | count += 1
45 | if xml_file.split('.')[0] not in lines:
46 | # print(xml_file.split('.')[0])
47 | continue
48 | image_path = xml_file.split('.')[0] + '.jpg'
49 | results = parse_rec(Annotations + xml_file)
50 | if len(results) == 0:
51 | print(xml_file)
52 | continue
53 | txt_file.write(image_path)
54 | # num_obj = len(results)
55 | # txt_file.write(str(num_obj)+' ')
56 | for result in results:
57 | class_name = result['name']
58 | bbox = result['bbox']
59 | class_name = VOC_CLASSES.index(class_name)
60 | txt_file.write(' '+str(bbox[0])+' '+str(bbox[1])+' '+str(bbox[2])+' '+str(bbox[3])+' '+str(class_name))
61 | txt_file.write('\n')
62 |
63 | txt_file.close()
--------------------------------------------------------------------------------
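
parse_rec can be exercised without a full VOC checkout. Since importing utils.xml2txt would also run its conversion code at module level, the sketch below replicates the same traversal inline on a tiny hand-made annotation (toy values of my own):

    import tempfile
    import xml.etree.ElementTree as ET

    xml = """<annotation><object>
      <name>dog</name><difficult>0</difficult>
      <bndbox><xmin>48</xmin><ymin>240</ymin><xmax>195</xmax><ymax>371</ymax></bndbox>
    </object></annotation>"""
    with tempfile.NamedTemporaryFile('w', suffix='.xml', delete=False) as f:
        f.write(xml)

    tree = ET.parse(f.name)                  # same parsing parse_rec performs
    for obj in tree.findall('object'):
        bbox = obj.find('bndbox')
        print(obj.find('name').text,
              [int(float(bbox.find(t).text)) for t in ('xmin', 'ymin', 'xmax', 'ymax')])
    # -> dog [48, 240, 195, 371]
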