├── .dockerignore
├── .gitignore
├── LICENSE
├── README.md
├── docker
│   ├── Dockerfile
│   ├── elementwise_binary_broadcast_op-inl.h
│   └── mxnet
│       └── Dockerfile
├── scripts
│   ├── rundocker.sh
│   ├── test.py
│   ├── test.sh
│   └── train.sh
└── sigr
    ├── __init__.py
    ├── app.py
    ├── base_module.py
    ├── constant.py
    ├── coral.py
    ├── data
    │   ├── __init__.py
    │   ├── capgmyo
    │   │   ├── __init__.py
    │   │   ├── dba.py
    │   │   ├── dbb.py
    │   │   └── dbc.py
    │   ├── csl.py
    │   ├── ninapro
    │   │   ├── __init__.py
    │   │   ├── caputo.py
    │   │   ├── db1.py
    │   │   ├── db1_g12.py
    │   │   ├── db1_g5.py
    │   │   ├── db1_g53.py
    │   │   ├── db1_g8.py
    │   │   └── db1_matlab_lowpass.py
    │   ├── preprocess.py
    │   ├── s21.py
    │   └── s21_soft_label.scv
    ├── evaluation.py
    ├── fft.py
    ├── lstm.py
    ├── module.py
    ├── parse_log.py
    ├── sklearn_module.py
    ├── symbol.py
    ├── utils
    │   ├── __init__.py
    │   └── proxy.py
    └── vote.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | .cache/
2 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | *.swp
2 | *.pyc
3 | .ipynb_checkpoints/
4 | .cache/
5 | /scripts/exp_inter
6 | /tmp/
7 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <http://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | {one line to give the program's name and a brief idea of what it does.}
635 | Copyright (C) {year} {name of author}
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <http://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | {project} Copyright (C) {year} {fullname}
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <http://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <http://www.gnu.org/philosophy/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Surface EMG-based Inter-session Gesture Recognition Enhanced by Deep Domain Adaptation
2 |
3 | ## Requirements
4 |
5 | * A CUDA compatible GPU
6 | * Ubuntu 14.04 or any other Linux/Unix that can run Docker
7 | * [Docker](http://docker.io/)
8 | * [Nvidia Docker](https://github.com/NVIDIA/nvidia-docker)
9 |
10 | ## Usage
11 |
12 | The following commands will
13 | (1) pull the Docker image (see `docker/Dockerfile` for details);
14 | (2) train ConvNets on the training sets of CSL-HDEMG, CapgMyo and NinaPro DB1, respectively;
15 | and (3) test the trained ConvNets on the corresponding test sets.
16 |
17 | ```
18 | mkdir .cache
19 | # put NinaPro DB1 in .cache/ninapro-db1
20 | # put CapgMyo DB-a in .cache/dba
21 | # put CapgMyo DB-b in .cache/dbb
22 | # put CapgMyo DB-c in .cache/dbc
23 | # put CSL-HDEMG in .cache/csl
24 | docker pull answeror/sigr:2016-09-21
25 | scripts/train.sh
26 | scripts/test.sh
27 | ```
28 |
29 | Training on NinaPro and CapgMyo will take 1 to 2 hours depending on your GPU.
30 | Training on CSL-HDEMG will take several days.
31 | You can accelerate training and testing by distributing different folds across different GPUs with the `gpu` parameter.
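
For example, a single fold can be pinned to a specific GPU by adding `--gpu` to the corresponding command from `scripts/train.sh`. The command below is an untested sketch (it assumes GPU 1 is visible inside the container started by `scripts/rundocker.sh`):

```
# Run fold 0 of the CapgMyo DB-b inter-subject experiment on GPU 1.
scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \
  --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \
  --root .cache/sensors-dbb-inter-subject-0 \
  --num-semg-row 16 --num-semg-col 8 \
  --batch-size 1000 --decay-all --dataset dbb \
  --num-filter 64 \
  --adabn --minibatch \
  --gpu 1 \
  crossval --crossval-type inter-subject --fold 0
```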
32 |
33 | The NinaPro DB1 data should be segmented according to the gesture labels and stored in Matlab format as follows.
34 | `.cache/ninapro-db1/data/sss/ggg/sss_ggg_ttt.mat` contains a field `data` (frames x channels) representing trial `ttt` of gesture `ggg` of subject `sss`.
35 | Numbering starts from zero, and gesture 0 is the rest posture.
36 | For example, `.cache/ninapro-db1/data/000/001/000_001_000.mat` is the 0th trial of the 1st gesture of the 0th subject,
37 | and `.cache/ninapro-db1/data/002/003/002_003_004.mat` is the 4th trial of the 3rd gesture of the 2nd subject.
38 | You can download the prepared dataset from or prepare it by yourself.
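For reference, the following Python sketch (a hypothetical helper, not part of this repository; it assumes `scipy` is installed) loads one such trial and checks its shape:

```
import scipy.io as sio

def load_trial(subject, gesture, trial, root='.cache/ninapro-db1/data'):
    # Build the path .cache/ninapro-db1/data/sss/ggg/sss_ggg_ttt.mat
    path = '%s/%03d/%03d/%03d_%03d_%03d.mat' % (
        root, subject, gesture, subject, gesture, trial)
    # The .mat file stores the sEMG frames in a field named `data`
    return sio.loadmat(path)['data']

data = load_trial(0, 1, 0)
print(data.shape)  # (frames, channels); NinaPro DB1 has 10 channels
```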
39 |
40 | ## License
41 |
42 | Licensed under the GPL v3.0 license.
43 |
44 | ## Bibtex
45 |
46 | ```
47 | @article{Du_Sensors_2017,
48 | title={{Surface EMG-based inter-session gesture recognition enhanced by deep domain adaptation}},
49 | author={Du, Yu and Jin, Wenguang and Wei, Wentao and Hu, Yu and Geng, Weidong},
50 | journal={Sensors},
51 | volume={17},
52 | number={3},
53 | pages={458},
54 | year={2017},
55 | publisher={Multidisciplinary Digital Publishing Institute}
56 | }
57 | ```
58 |
59 | ## Misc
60 |
61 | Thanks to the DMLC team for their great [MXNet](https://github.com/dmlc/mxnet)!
62 |
--------------------------------------------------------------------------------
/docker/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM answeror/mxnet:f2684a6
2 | MAINTAINER answeror
3 |
4 | RUN apt-get install -y python-pip python-scipy
5 | RUN pip install click logbook joblib nose
6 |
7 | RUN cd /mxnet && \
8 | git reset --hard && \
9 | git checkout master && \
10 | git pull
11 |
12 | RUN cd /mxnet && \
13 | git checkout 7a485bb && \
14 | git submodule update && \
15 | git checkout 887491d src/operator/elementwise_binary_broadcast_op-inl.h && \
16 | sed -i -e 's/CHECK(ksize_x <= dshape\[3\] && ksize_y <= dshape\[2\])/CHECK(ksize_x <= dshape[3] + 2 * param_.pad[1] \&\& ksize_y <= dshape[2] + 2 * param_.pad[0])/' src/operator/convolution-inl.h && \
17 | cp make/config.mk . && \
18 | echo "USE_CUDA=1" >>config.mk && \
19 | echo "USE_CUDA_PATH=/usr/local/cuda" >>config.mk && \
20 | echo "USE_CUDNN=1" >>config.mk && \
21 | echo "USE_BLAS=openblas" >>config.mk && \
22 | make clean && \
23 | make -j8 ADD_LDFLAGS=-L/usr/local/cuda/lib64/stubs
24 |
25 | ADD elementwise_binary_broadcast_op-inl.h /mxnet/src/operator/elementwise_binary_broadcast_op-inl.h
26 | RUN cd /mxnet && \
27 | make clean && \
28 | make -j8 ADD_LDFLAGS=-L/usr/local/cuda/lib64/stubs
29 |
30 | RUN pip install jupyter pandas matplotlib seaborn scikit-learn
31 | RUN mkdir -p -m 700 /root/.jupyter/ && \
32 | echo "c.NotebookApp.ip = '*'" >> /root/.jupyter/jupyter_notebook_config.py
33 | EXPOSE 8888
34 | CMD ["sh", "-c", "jupyter notebook"]
35 |
36 | WORKDIR /code
37 |
--------------------------------------------------------------------------------
/docker/mxnet/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM nvidia/cuda:7.5-cudnn5-devel
2 | MAINTAINER answeror
3 |
4 | RUN echo "deb http://mirrors.zju.edu.cn/ubuntu/ trusty main restricted universe multiverse" > /etc/apt/sources.list && \
5 | echo "deb http://mirrors.zju.edu.cn/ubuntu/ trusty-security main restricted universe multiverse" >> /etc/apt/sources.list && \
6 | echo "deb http://mirrors.zju.edu.cn/ubuntu/ trusty-updates main restricted universe multiverse" >> /etc/apt/sources.list && \
7 | echo "deb http://mirrors.zju.edu.cn/ubuntu/ trusty-proposed main restricted universe multiverse" >> /etc/apt/sources.list && \
8 | echo "deb http://mirrors.zju.edu.cn/ubuntu/ trusty-backports main restricted universe multiverse" >> /etc/apt/sources.list && \
9 | echo "deb-src http://mirrors.zju.edu.cn/ubuntu/ trusty main restricted universe multiverse" >> /etc/apt/sources.list && \
10 | echo "deb-src http://mirrors.zju.edu.cn/ubuntu/ trusty-security main restricted universe multiverse" >> /etc/apt/sources.list && \
11 | echo "deb-src http://mirrors.zju.edu.cn/ubuntu/ trusty-updates main restricted universe multiverse" >> /etc/apt/sources.list && \
12 | echo "deb-src http://mirrors.zju.edu.cn/ubuntu/ trusty-proposed main restricted universe multiverse" >> /etc/apt/sources.list && \
13 | echo "deb-src http://mirrors.zju.edu.cn/ubuntu/ trusty-backports main restricted universe multiverse" >> /etc/apt/sources.list && \
14 | apt-get -qqy update
15 |
16 | # mxnet
17 | RUN apt-get update && apt-get install -y \
18 | build-essential \
19 | git \
20 | libopenblas-dev \
21 | libopencv-dev \
22 | python-numpy \
23 | wget \
24 | unzip
25 | RUN git clone --recursive https://github.com/dmlc/mxnet/ && cd mxnet && \
26 | git checkout f2684a6 && \
27 | sed -i -e 's/CHECK(ksize_x <= dshape\[3\] && ksize_y <= dshape\[2\])/CHECK(ksize_x <= dshape[3] + 2 * param_.pad[1] \&\& ksize_y <= dshape[2] + 2 * param_.pad[0])/' src/operator/convolution-inl.h && \
28 | cp make/config.mk . && \
29 | echo "USE_CUDA=1" >>config.mk && \
30 | echo "USE_CUDA_PATH=/usr/local/cuda" >>config.mk && \
31 | echo "USE_CUDNN=1" >>config.mk && \
32 | echo "USE_BLAS=openblas" >>config.mk && \
33 | make -j8 ADD_LDFLAGS=-L/usr/local/cuda/lib64/stubs
34 | ENV LD_LIBRARY_PATH /usr/local/cuda/lib64:$LD_LIBRARY_PATH
35 |
36 | ENV PYTHONPATH /mxnet/python
37 |
--------------------------------------------------------------------------------
/scripts/rundocker.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | nvidia-docker run --rm -ti -v "$(pwd)":/code answeror/sigr:2016-09-21 "$@"
4 |
--------------------------------------------------------------------------------
/scripts/test.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function, division
2 | import os
3 | import sys
4 | sys.path.insert(0, os.getcwd())
5 | import numpy as np
6 | import mxnet as mx
7 | from sigr.evaluation import CrossValEvaluation as CV, Exp
8 | from sigr.data import Preprocess, Dataset
9 | from sigr import Context
10 |
11 |
12 | inter_subject_eval = CV(crossval_type='inter-subject', batch_size=1000)
13 | inter_session_eval = CV(crossval_type='inter-session', batch_size=1000)
14 | one_fold_intra_subject_eval = CV(crossval_type='one-fold-intra-subject', batch_size=1000)
15 |
16 | print('Inter-session CSL-HDEMG')
17 | print('============')
18 |
19 | with Context(parallel=True, level='DEBUG'):
20 | acc = inter_session_eval.accuracies(
21 | [Exp(dataset=Dataset.from_name('csl'), vote=-1,
22 | dataset_args=dict(preprocess=Preprocess.parse('(csl-bandpass,csl-cut,median)')),
23 | Mod=dict(num_gesture=27,
24 | adabn=True,
25 | num_adabn_epoch=10,
26 | context=[mx.gpu(0)],
27 | symbol_kargs=dict(dropout=0, num_semg_row=24, num_semg_col=7, num_filter=64),
28 | params='.cache/sensors-csl-inter-session-%d/model-0028.params'))],
29 | folds=np.arange(25))
30 | print('Per-trial majority voting accuracy: %f' % acc.mean())
31 |
32 | print('')
33 | print('Inter-subject CapgMyo DB-b')
34 | print('============')
35 |
36 | with Context(parallel=True, level='DEBUG'):
37 | acc = inter_subject_eval.vote_accuracy_curves(
38 | [Exp(dataset=Dataset.from_name('dbb'),
39 | Mod=dict(num_gesture=8,
40 | adabn=True,
41 | num_adabn_epoch=10,
42 | context=[mx.gpu(0)],
43 | symbol_kargs=dict(dropout=0, num_semg_row=16, num_semg_col=8, num_filter=64),
44 | params='.cache/sensors-dbb-inter-subject-%d/model-0028.params'))],
45 | folds=np.arange(10),
46 | windows=[1, 150])
47 | acc = acc.mean(axis=(0, 1))
48 | print('Single frame accuracy: %f' % acc[0])
49 | print('150 frames (150 ms) majority voting accuracy: %f' % acc[1])
50 |
51 | print('')
52 | print('Inter-session CapgMyo DB-b')
53 | print('============')
54 |
55 | with Context(parallel=True, level='DEBUG'):
56 | acc = inter_session_eval.vote_accuracy_curves(
57 | [Exp(dataset=Dataset.from_name('dbb'),
58 | Mod=dict(num_gesture=8,
59 | adabn=True,
60 | num_adabn_epoch=10,
61 | context=[mx.gpu(0)],
62 | symbol_kargs=dict(dropout=0, num_semg_row=16, num_semg_col=8, num_filter=64),
63 | params='.cache/sensors-dbb-inter-session-%d/model-0028.params'))],
64 | folds=np.arange(10),
65 | windows=[1, 150])
66 | acc = acc.mean(axis=(0, 1))
67 | print('Single frame accuracy: %f' % acc[0])
68 | print('150 frames (150 ms) majority voting accuracy: %f' % acc[1])
69 |
70 | print('')
71 | print('Inter-subject CapgMyo DB-c')
72 | print('============')
73 |
74 | with Context(parallel=True, level='DEBUG'):
75 | acc = inter_subject_eval.vote_accuracy_curves(
76 | [Exp(dataset=Dataset.from_name('dbc'),
77 | Mod=dict(num_gesture=12,
78 | adabn=True,
79 | num_adabn_epoch=10,
80 | context=[mx.gpu(0)],
81 | symbol_kargs=dict(dropout=0, num_semg_row=16, num_semg_col=8, num_filter=64),
82 | params='.cache/sensors-dbc-inter-subject-%d/model-0028.params'))],
83 | folds=np.arange(10),
84 | windows=[1, 150])
85 | acc = acc.mean(axis=(0, 1))
86 | print('Single frame accuracy: %f' % acc[0])
87 | print('150 frames (150 ms) majority voting accuracy: %f' % acc[1])
88 |
89 | print('')
90 | print('Inter-subject NinaPro DB1')
91 | print('===========')
92 | with Context(parallel=True, level='DEBUG'):
93 | acc = one_fold_intra_subject_eval.vote_accuracy_curves(
94 | [Exp(dataset=Dataset.from_name('ninapro-db1/caputo'),
95 | Mod=dict(num_gesture=52,
96 | context=[mx.gpu(0)],
97 | symbol_kargs=dict(dropout=0, num_semg_row=1, num_semg_col=10, num_filter=64),
98 | params='.cache/sensors-ninapro-one-fold-intra-subject-%d/model-0028.params'))],
99 | folds=np.arange(27),
100 | windows=[1, 40])
101 | acc = acc.mean(axis=(0, 1))
102 | print('Single frame accuracy: %f' % acc[0])
103 | print('40 frames (400 ms) majority voting accuracy: %f' % acc[1])
104 |
--------------------------------------------------------------------------------
/scripts/test.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | scripts/rundocker.sh python scripts/test.py
4 |
--------------------------------------------------------------------------------
/scripts/train.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 |
3 | # Inter-subject recognition of 8 gestures in CapgMyo DB-b
4 | for i in $(seq 0 9 | shuf); do
5 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \
6 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \
7 | --root .cache/sensors-dbb-inter-subject-$i \
8 | --num-semg-row 16 --num-semg-col 8 \
9 | --batch-size 1000 --decay-all --dataset dbb \
10 | --num-filter 64 \
11 | --adabn --minibatch \
12 | crossval --crossval-type inter-subject --fold $i
13 | done
14 |
15 | # Inter-session recognition of 8 gestures in CapgMyo DB-b
16 | for i in 1; do
17 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \
18 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \
19 | --root .cache/sensors-dbb-universal-inter-session-$i \
20 | --num-semg-row 16 --num-semg-col 8 \
21 | --batch-size 1000 --decay-all --dataset dbb \
22 | --num-filter 64 \
23 | --adabn --minibatch \
24 | crossval --crossval-type universal-inter-session --fold $i
25 | done
26 | for i in $(seq 1 2 19 | shuf); do
27 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \
28 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \
29 | --root .cache/sensors-dbb-inter-session-$i \
30 | --num-semg-row 16 --num-semg-col 8 \
31 | --batch-size 1000 --decay-all --dataset dbb \
32 | --num-filter 64 \
33 | --params .cache/sensors-dbb-universal-inter-session-1/model-0028.params \
34 | --fix-params ".*conv.*" --fix-params ".*pixel.*" --fix-params "fc1_.*" --fix-params "fc2_.*" \
35 | --adabn \
36 | crossval --crossval-type inter-session --fold $i
37 | done
38 |
39 | # Inter-subject recognition of 12 gestures in CapgMyo DB-c
40 | for i in $(seq 0 9 | shuf); do
41 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \
42 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \
43 | --root .cache/sensors-dbc-inter-subject-$i \
44 | --num-semg-row 16 --num-semg-col 8 \
45 | --batch-size 1000 --decay-all --dataset dbc \
46 | --num-filter 64 \
47 | --adabn --minibatch \
48 | crossval --crossval-type inter-subject --fold $i
49 | done
50 |
51 | # Inter-session recognition of 27 gestures in CSL-HDEMG
52 | for i in $(seq 0 5 | shuf); do
53 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \
54 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \
55 | --root .cache/sensors-csl-universal-inter-session-$i \
56 | --num-semg-row 24 --num-semg-col 7 \
57 | --batch-size 1000 --decay-all --adabn --minibatch --dataset csl \
58 | --preprocess '(csl-bandpass,csl-cut,downsample-5,median)' \
59 | --balance-gesture 1 \
60 | --num-filter 64 \
61 | crossval --crossval-type universal-inter-session --fold $i
62 | done
63 | for i in $(seq 0 24 | shuf); do
64 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \
65 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \
66 | --root .cache/sensors-csl-inter-session-$i \
67 | --num-semg-row 24 --num-semg-col 7 \
68 | --batch-size 1000 --decay-all --adabn --minibatch --dataset csl \
69 | --preprocess '(csl-bandpass,csl-cut,median)' \
70 | --balance-gesture 1 \
71 | --num-filter 64 \
72 | --params .cache/sensors-csl-universal-inter-session-$(($i % 5))/model-0028.params \
73 | crossval --crossval-type inter-session --fold $i
74 | done
75 |
76 | # Inter-subject recognition of 52 gestures in NinaPro DB1 with calibration data
77 | for i in 0; do
78 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \
79 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \
80 | --root .cache/sensors-ninapro-universal-one-fold-intra-subject-$i \
81 | --num-semg-row 1 --num-semg-col 10 \
82 | --batch-size 1000 --decay-all --adabn --minibatch --dataset ninapro-db1/caputo \
83 | --num-filter 64 \
84 | --preprocess 'downsample-16' \
85 | crossval --crossval-type universal-one-fold-intra-subject --fold $i
86 | done
87 | for i in $(seq 0 26 | shuf); do
88 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \
89 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \
90 | --root .cache/sensors-ninapro-one-fold-intra-subject-$i \
91 | --num-semg-row 1 --num-semg-col 10 \
92 | --batch-size 1000 --decay-all --dataset ninapro-db1/caputo \
93 | --num-filter 64 \
94 | --params .cache/sensors-ninapro-universal-one-fold-intra-subject-0/model-0028.params \
95 | --preprocess 'downsample-16' \
96 | crossval --crossval-type one-fold-intra-subject --fold $i
97 | done
98 |
--------------------------------------------------------------------------------
/sigr/__init__.py:
--------------------------------------------------------------------------------
1 | import mxnet as mx
2 | import numpy as np
3 | import random
4 |
5 | mx.random.seed(42)
6 | np.random.seed(43)
7 | random.seed(44)
8 |
9 | import os
10 |
11 | os.environ['JOBLIB_TEMP_FOLDER'] = '/tmp'
12 |
13 | ROOT = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
14 | CACHE = os.path.join(ROOT, '.cache')
15 |
16 | from contextlib import contextmanager
17 |
18 |
19 | @contextmanager
20 | def Context(log=None, parallel=False, level=None):
21 | from .utils import logging_context
22 | with logging_context(log, level=level):
23 | if not parallel:
24 | yield
25 | else:
26 | import joblib as jb
27 | from multiprocessing import cpu_count
28 | with jb.Parallel(n_jobs=cpu_count()) as par:
29 | Context.parallel = par
30 | yield
31 |
32 |
33 | def _patch(func):
34 | func()
35 | return lambda: None
36 |
37 |
38 | @_patch
39 | def _patch_click():
40 | import click
41 | orig = click.option
42 |
43 | def option(*args, **kargs):
44 | if 'help' in kargs and 'default' in kargs:
45 | kargs['help'] += ' (default {})'.format(kargs['default'])
46 | return orig(*args, **kargs)
47 |
48 | click.option = option
49 |
50 |
51 | from .data import s21 as data_s21
52 |
53 |
54 | __all__ = ['ROOT', 'CACHE', 'Context', 'data_s21']
55 |
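The _patch_click patch above wraps click.option so that any option declaring both help and default gets its default value appended to the help text. A minimal sketch of the effect (the option below is purely illustrative, and it assumes sigr has been imported first so the patch is active):

    import click
    import sigr  # importing sigr installs the patched click.option

    @click.command()
    @click.option('--num-filter', type=int, default=64, help='Kernels of the conv layers')
    def demo(num_filter):
        pass

    # "demo --help" now renders the option roughly as:
    #   --num-filter INTEGER  Kernels of the conv layers (default 64)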
--------------------------------------------------------------------------------
/sigr/app.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import click
3 | import mxnet as mx
4 | from logbook import Logger
5 | from pprint import pformat
6 | import os
7 | from .utils import packargs, Bunch
8 | from .module import Module
9 | from .data import Preprocess, Dataset
10 | from . import data_s21, Context, constant
11 |
12 |
13 | logger = Logger('sigr')
14 |
15 |
16 | @click.group()
17 | def cli():
18 | pass
19 |
20 |
21 | @cli.group()
22 | @click.option('--downsample', type=int, default=0)
23 | @click.option('--num-semg-row', type=int, default=constant.NUM_SEMG_ROW, help='Rows of sEMG image')
24 | @click.option('--num-semg-col', type=int, default=constant.NUM_SEMG_COL, help='Cols of sEMG image')
25 | @click.option('--num-epoch', type=int, default=60, help='Maximum epochs')
26 | @click.option('--num-tzeng-batch', type=int, default=constant.NUM_TZENG_BATCH,
27 |               help='Number of batches per Tzeng update; 2 means interleaved domain and label updates')
28 | @click.option('--lr-step', type=int, multiple=True, default=[20, 40], help='Epoch numbers to decay learning rate')
29 | @click.option('--lr-factor', type=float, multiple=True)
30 | @click.option('--batch-size', type=int, default=1000,
31 | help='Batch size, should be 900 with --minibatch for s21 inter-subject experiment')
32 | @click.option('--lr', type=float, default=0.1, help='Base learning rate')
33 | @click.option('--wd', type=float, default=0.0001, help='Weight decay')
34 | @click.option('--subject-wd', type=float, help='Weight decay multiplier of the subject branch')
35 | @click.option('--gpu', type=int, multiple=True, default=[0])
36 | @click.option('--gamma', type=float, default=constant.GAMMA, help='Gamma in RevGrad')
37 | @click.option('--log', type=click.Path(), help='Path of the logging file')
38 | @click.option('--snapshot', type=click.Path(), help='Snapshot prefix')
39 | @click.option('--root', type=click.Path(), help='Root path of the experiment, created automatically if it does not exist')
40 | @click.option('--revgrad', is_flag=True, help='Use RevGrad')
41 | @click.option('--num-revgrad-batch', type=int, default=2,
42 |               help=('Number of batches per RevGrad update; 2 means interleaved domain and label updates, '
43 |                     'see "Adversarial Deep Averaging Networks for Cross-Lingual Sentiment Classification" for details'))
44 | @click.option('--tzeng', is_flag=True, help='Use Tzeng_ICCV_2015')
45 | @click.option('--confuse-conv', is_flag=True, help='Domain confusion (for both RevGrad and Tzeng) on conv2')
46 | @click.option('--confuse-all', is_flag=True, help='Domain confusion (for both RevGrad and Tzeng) on all layers')
47 | @click.option('--subject-loss-weight', type=float, default=1, help='Ganin et al. use 0.1 in their code')
48 | @click.option('--subject-confusion-loss-weight', type=float, default=1,
49 | help='Tzeng confusion loss weight, larger than 1 seems better')
50 | @click.option('--lambda-scale', type=float, default=constant.LAMBDA_SCALE,
51 | help='Global scale of lambda in RevGrad, 1 in their paper and 0.1 in their code')
52 | @click.option('--params', type=click.Path(exists=True), help='Initial weights')
53 | @click.option('--ignore-params', multiple=True, help='Ignore params in --params with regex')
54 | @click.option('--random-shift-fill', type=click.Choice(['zero', 'margin']),
55 | default=constant.RANDOM_SHIFT_FILL, help='Random shift filling value')
56 | @click.option('--random-shift-horizontal', type=int, default=0, help='Random shift input horizontally by x pixels')
57 | @click.option('--random-shift-vertical', type=int, default=0, help='Random shift input vertically by x pixels')
58 | @click.option('--random-scale', type=float, default=0,
59 |               help='Randomly scale input data globally by 2^scale and locally by 2^(scale/4)')
60 | @click.option('--random-bad-channel', type=float, multiple=True, default=[],
61 |               help='Randomly (with probability 0.5 for each image) set a pixel to the specified value, usually -1, 0 or 1')
62 | @click.option('--num-feature-block', type=int, default=constant.NUM_FEATURE_BLOCK, help='Number of FC layers in feature extraction part')
63 | @click.option('--num-gesture-block', type=int, default=constant.NUM_GESTURE_BLOCK, help='Number of FC layers in gesture branch')
64 | @click.option('--num-subject-block', type=int, default=constant.NUM_SUBJECT_BLOCK, help='Number of FC layers in subject branch')
65 | @click.option('--adabn', is_flag=True, help='AdaBN for model adaptation, must be used with --minibatch')
66 | @click.option('--num-adabn-epoch', type=int, default=constant.NUM_ADABN_EPOCH)
67 | @click.option('--num-pixel', type=int, default=constant.NUM_PIXEL, help='Pixelwise reduction layers')
68 | @click.option('--num-filter', type=int, default=constant.NUM_FILTER, help='Kernels of the conv layers')
69 | @click.option('--num-hidden', type=int, default=constant.NUM_HIDDEN, help='Kernels of the FC layers')
70 | @click.option('--num-bottleneck', type=int, default=constant.NUM_BOTTLENECK, help='Kernels of the bottleneck layer')
71 | @click.option('--dropout', type=float, default=constant.DROPOUT, help='Dropout ratio')
72 | @click.option('--window', type=int, default=1, help='Multi-frame as image channels')
73 | @click.option('--lstm-window', type=int)
74 | @click.option('--num-presnet', type=int, multiple=True, help='Deprecated')
75 | @click.option('--presnet-branch', type=int, multiple=True, help='Deprecated')
76 | @click.option('--drop-presnet', is_flag=True)
77 | @click.option('--bng', is_flag=True, help='Deprecated')
78 | @click.option('--minibatch', is_flag=True, help='Split data into minibatch by subject id')
79 | @click.option('--drop-branch', is_flag=True, help='Dropout after each FC in branches')
80 | @click.option('--pool', is_flag=True, help='Deprecated')
81 | @click.option('--fft', is_flag=True, help='Deprecated. Perform FFT and use spectrum amplitude as image channels. Cannot be used on datasets with non-uniform segment lengths such as NinaPro')
82 | @click.option('--fft-append', is_flag=True, help='Append FFT feature to raw frames in channel axis')
83 | @click.option('--dual-stream', is_flag=True, help='Use raw frames and FFT feature as dual-stream')
84 | @click.option('--zscore/--no-zscore', default=True, help='Use z-score normalization on input')
85 | @click.option('--zscore-bng', is_flag=True, help='Use global BatchNorm as z-score normalization, for window > 1 or FFT')
86 | @click.option('--lstm', is_flag=True)
87 | @click.option('--num-lstm-hidden', type=int, default=constant.NUM_LSTM_HIDDEN, help='Kernels of the hidden layers in LSTM')
88 | @click.option('--num-lstm-layer', type=int, default=constant.NUM_LSTM_LAYER, help='Number of the hidden layers in LSTM')
89 | @click.option('--dense-window/--no-dense-window', default=True, help='Dense sampling of windows during training')
90 | @click.option('--lstm-last', type=int, default=0)
91 | @click.option('--lstm-dropout', type=float, default=constant.LSTM_DROPOUT, help='LSTM dropout ratio')
92 | @click.option('--lstm-shortcut', is_flag=True)
93 | @click.option('--lstm-bn/--no-lstm-bn', default=True, help='BatchNorm in LSTM')
94 | @click.option('--lstm-grad-scale/--no-lstm-grad-scale', default=True, help='Grad scale by the number of LSTM output')
95 | @click.option('--faug', type=float, default=0)
96 | @click.option('--faug-classwise', is_flag=True)
97 | @click.option('--num-eval-epoch', type=int, default=1)
98 | @click.option('--snapshot-period', type=int, default=1)
99 | @click.option('--gpu-x', type=int, default=0)
100 | @click.option('--drop-conv', is_flag=True)
101 | @click.option('--drop-pixel', type=int, multiple=True, default=(-1,))
102 | @click.option('--drop-presnet-branch', is_flag=True)
103 | @click.option('--drop-presnet-proj', is_flag=True)
104 | @click.option('--fix-params', multiple=True)
105 | @click.option('--presnet-proj-type', type=click.Choice(['A', 'B']), default='A')
106 | @click.option('--decay-all', is_flag=True)
107 | @click.option('--presnet-promote', is_flag=True)
108 | @click.option('--pixel-reduce-loss-weight', type=float, default=0)
109 | @click.option('--fast-pixel-reduce/--no-fast-pixel-reduce', default=True)
110 | @click.option('--pixel-reduce-bias', is_flag=True)
111 | @click.option('--pixel-reduce-kernel', type=int, multiple=True, default=(1, 1))
112 | @click.option('--pixel-reduce-stride', type=int, multiple=True, default=(1, 1))
113 | @click.option('--pixel-reduce-pad', type=int, multiple=True, default=(0, 0))
114 | @click.option('--pixel-reduce-norm', is_flag=True)
115 | @click.option('--pixel-reduce-reg-out', is_flag=True)
116 | @click.option('--num-pixel-reduce-filter', type=int, multiple=True, default=tuple(None for _ in range(constant.NUM_PIXEL)))
117 | @click.option('--num-conv', type=int, default=2)
118 | @click.option('--pixel-same-init', is_flag=True)
119 | @click.option('--presnet-dense', is_flag=True)
120 | @click.option('--conv-shortcut', is_flag=True)
121 | @click.option('--preprocess', callback=lambda ctx, param, value: Preprocess.parse(value))
122 | @click.option('--bandstop', is_flag=True)
123 | @click.option('--dataset', type=click.Choice(['s21', 'csl',
124 | 'dba', 'dbb', 'dbc',
125 | 'ninapro-db1-matlab-lowpass',
126 | 'ninapro-db1/caputo',
127 | 'ninapro-db1',
128 | 'ninapro-db1/g53',
129 | 'ninapro-db1/g5',
130 | 'ninapro-db1/g8',
131 | 'ninapro-db1/g12']), required=True)
132 | @click.option('--balance-gesture', type=float, default=0)
133 | @click.option('--module', type=click.Choice(['convnet',
134 | 'knn',
135 | 'svm',
136 | 'random-forests',
137 | 'lda']), default='convnet')
138 | @click.option('--amplitude-weighting', is_flag=True)
139 | @packargs
140 | def exp(args):
141 | pass
142 |
143 |
144 | @exp.command()
145 | @click.option('--fold', type=int, required=True, help='Fold number of the crossval experiment')
146 | @click.option('--crossval-type', type=click.Choice(['intra-session',
147 | 'universal-intra-session',
148 | 'inter-session',
149 | 'universal-inter-session',
150 | 'intra-subject',
151 | 'universal-intra-subject',
152 | 'inter-subject',
153 | 'one-fold-intra-subject',
154 | 'universal-one-fold-intra-subject']), required=True)
155 | @packargs
156 | def crossval(args):
157 | if args.root:
158 | if args.log:
159 | args.log = os.path.join(args.root, args.log)
160 | if args.snapshot:
161 | args.snapshot = os.path.join(args.root, args.snapshot)
162 |
163 | if args.gpu_x:
164 | args.gpu = sum([list(args.gpu) for i in range(args.gpu_x)], [])
165 |
166 | if os.path.exists(args.log):
167 | click.echo('Found log {}, exit'.format(args.log))
168 | return
169 |
170 | with Context(args.log, parallel=True):
171 | logger.info('Args:\n{}', pformat(args))
172 | for i in range(args.num_epoch):
173 | path = args.snapshot + '-%04d.params' % (i + 1)
174 | if os.path.exists(path):
175 | logger.info('Found snapshot {}, exit', path)
176 | return
177 |
178 | dataset = Dataset.from_name(args.dataset)
179 | get_crossval_data = getattr(dataset, 'get_%s_data' % args.crossval_type.replace('-', '_'))
180 | train, val = get_crossval_data(
181 | batch_size=args.batch_size,
182 | fold=args.fold,
183 | preprocess=args.preprocess,
184 | adabn=args.adabn,
185 | minibatch=args.minibatch,
186 | balance_gesture=args.balance_gesture,
187 | amplitude_weighting=args.amplitude_weighting,
188 | random_shift_fill=args.random_shift_fill,
189 | random_shift_horizontal=args.random_shift_horizontal,
190 | random_shift_vertical=args.random_shift_vertical
191 | )
192 | logger.info('Train samples: {}', train.num_sample)
193 | logger.info('Val samples: {}', val.num_sample)
194 | mod = Module.parse(
195 | args.module,
196 | revgrad=args.revgrad,
197 | num_revgrad_batch=args.num_revgrad_batch,
198 | tzeng=args.tzeng,
199 | num_tzeng_batch=args.num_tzeng_batch,
200 | num_gesture=train.num_gesture,
201 | num_subject=train.num_subject,
202 | subject_loss_weight=args.subject_loss_weight,
203 | lambda_scale=args.lambda_scale,
204 | adabn=args.adabn,
205 | num_adabn_epoch=args.num_adabn_epoch,
206 | random_scale=args.random_scale,
207 | dual_stream=args.dual_stream,
208 | lstm=args.lstm,
209 | num_lstm_hidden=args.num_lstm_hidden,
210 | num_lstm_layer=args.num_lstm_layer,
211 | for_training=True,
212 | faug=args.faug,
213 | faug_classwise=args.faug_classwise,
214 | num_eval_epoch=args.num_eval_epoch,
215 | snapshot_period=args.snapshot_period,
216 | pixel_same_init=args.pixel_same_init,
217 | symbol_kargs=dict(
218 | num_semg_row=args.num_semg_row,
219 | num_semg_col=args.num_semg_col,
220 | num_filter=args.num_filter,
221 | num_pixel=args.num_pixel,
222 | num_feature_block=args.num_feature_block,
223 | num_gesture_block=args.num_gesture_block,
224 | num_subject_block=args.num_subject_block,
225 | num_hidden=args.num_hidden,
226 | num_bottleneck=args.num_bottleneck,
227 | dropout=args.dropout,
228 | num_channel=train.num_channel // (args.lstm_window or 1),
229 | num_presnet=args.num_presnet,
230 | presnet_branch=args.presnet_branch,
231 | drop_presnet=args.drop_presnet,
232 | bng=args.bng,
233 | subject_confusion_loss_weight=args.subject_confusion_loss_weight,
234 | minibatch=args.minibatch,
235 | confuse_conv=args.confuse_conv,
236 | confuse_all=args.confuse_all,
237 | subject_wd=args.subject_wd,
238 | drop_branch=args.drop_branch,
239 | pool=args.pool,
240 | zscore=args.zscore,
241 | zscore_bng=args.zscore_bng,
242 | num_stream=2 if args.dual_stream else 1,
243 | lstm_last=args.lstm_last,
244 | lstm_dropout=args.lstm_dropout,
245 | lstm_shortcut=args.lstm_shortcut,
246 | lstm_bn=args.lstm_bn,
247 | lstm_window=args.lstm_window,
248 | lstm_grad_scale=args.lstm_grad_scale,
249 | drop_conv=args.drop_conv,
250 | drop_presnet_branch=args.drop_presnet_branch,
251 | drop_presnet_proj=args.drop_presnet_proj,
252 | presnet_proj_type=args.presnet_proj_type,
253 | presnet_promote=args.presnet_promote,
254 | pixel_reduce_loss_weight=args.pixel_reduce_loss_weight,
255 | pixel_reduce_bias=args.pixel_reduce_bias,
256 | pixel_reduce_kernel=args.pixel_reduce_kernel,
257 | pixel_reduce_stride=args.pixel_reduce_stride,
258 | pixel_reduce_pad=args.pixel_reduce_pad,
259 | pixel_reduce_norm=args.pixel_reduce_norm,
260 | pixel_reduce_reg_out=args.pixel_reduce_reg_out,
261 | num_pixel_reduce_filter=args.num_pixel_reduce_filter,
262 | fast_pixel_reduce=args.fast_pixel_reduce,
263 | drop_pixel=args.drop_pixel,
264 | num_conv=args.num_conv,
265 | presnet_dense=args.presnet_dense,
266 | conv_shortcut=args.conv_shortcut
267 | ),
268 | context=[mx.gpu(i) for i in args.gpu]
269 | )
270 | mod.fit(
271 | train_data=train,
272 | eval_data=val,
273 | num_epoch=args.num_epoch,
274 | num_train=train.num_sample,
275 | batch_size=args.batch_size,
276 | lr_step=args.lr_step,
277 | lr_factor=args.lr_factor,
278 | lr=args.lr,
279 | wd=args.wd,
280 | gamma=args.gamma,
281 | snapshot=args.snapshot,
282 | params=args.params,
283 | ignore_params=args.ignore_params,
284 | fix_params=args.fix_params,
285 | decay_all=args.decay_all
286 | )
287 |
288 |
289 | @exp.command()
290 | @packargs
291 | def general(args):
292 | if args.root:
293 | if args.log:
294 | args.log = os.path.join(args.root, args.log)
295 | if args.snapshot:
296 | args.snapshot = os.path.join(args.root, args.snapshot)
297 |
298 | if args.gpu_x:
299 | args.gpu = sum([list(args.gpu) for i in range(args.gpu_x)], [])
300 |
301 | with Context(args.log):
302 | logger.info('Args:\n{}', pformat(args))
303 | for i in range(args.num_epoch):
304 | path = args.snapshot + '-%04d.params' % (i + 1)
305 | if os.path.exists(path):
306 | logger.info('Found snapshot {}, exit', path)
307 | return
308 |
309 | from .data import csl
310 |
311 | train, val = csl.get_general_data(
312 | batch_size=args.batch_size,
313 | adabn=args.adabn,
314 | minibatch=args.minibatch,
315 | downsample=args.downsample
316 | )
317 | logger.info('Train samples: {}', train.num_sample)
318 | logger.info('Val samples: {}', val.num_sample)
319 | mod = Module(
320 | revgrad=args.revgrad,
321 | num_revgrad_batch=args.num_revgrad_batch,
322 | tzeng=args.tzeng,
323 | num_tzeng_batch=args.num_tzeng_batch,
324 | num_gesture=train.num_gesture,
325 | num_subject=train.num_subject,
326 | subject_loss_weight=args.subject_loss_weight,
327 | lambda_scale=args.lambda_scale,
328 | adabn=args.adabn,
329 | num_adabn_epoch=args.num_adabn_epoch,
330 | random_scale=args.random_scale,
331 | dual_stream=args.dual_stream,
332 | lstm=args.lstm,
333 | num_lstm_hidden=args.num_lstm_hidden,
334 | num_lstm_layer=args.num_lstm_layer,
335 | for_training=True,
336 | faug=args.faug,
337 | faug_classwise=args.faug_classwise,
338 | num_eval_epoch=args.num_eval_epoch,
339 | snapshot_period=args.snapshot_period,
340 | pixel_same_init=args.pixel_same_init,
341 | symbol_kargs=dict(
342 | num_semg_row=args.num_semg_row,
343 | num_semg_col=args.num_semg_col,
344 | num_filter=args.num_filter,
345 | num_pixel=args.num_pixel,
346 | num_feature_block=args.num_feature_block,
347 | num_gesture_block=args.num_gesture_block,
348 | num_subject_block=args.num_subject_block,
349 | num_hidden=args.num_hidden,
350 | num_bottleneck=args.num_bottleneck,
351 | dropout=args.dropout,
352 | num_channel=train.num_channel // (args.lstm_window or 1),
353 | num_presnet=args.num_presnet,
354 | presnet_branch=args.presnet_branch,
355 | drop_presnet=args.drop_presnet,
356 | bng=args.bng,
357 | subject_confusion_loss_weight=args.subject_confusion_loss_weight,
358 | minibatch=args.minibatch,
359 | confuse_conv=args.confuse_conv,
360 | confuse_all=args.confuse_all,
361 | subject_wd=args.subject_wd,
362 | drop_branch=args.drop_branch,
363 | pool=args.pool,
364 | zscore=args.zscore,
365 | zscore_bng=args.zscore_bng,
366 | num_stream=2 if args.dual_stream else 1,
367 | lstm_last=args.lstm_last,
368 | lstm_dropout=args.lstm_dropout,
369 | lstm_shortcut=args.lstm_shortcut,
370 | lstm_bn=args.lstm_bn,
371 | lstm_window=args.lstm_window,
372 | lstm_grad_scale=args.lstm_grad_scale,
373 | drop_conv=args.drop_conv,
374 | drop_presnet_branch=args.drop_presnet_branch,
375 | drop_presnet_proj=args.drop_presnet_proj,
376 | presnet_proj_type=args.presnet_proj_type,
377 | presnet_promote=args.presnet_promote,
378 | pixel_reduce_loss_weight=args.pixel_reduce_loss_weight,
379 | pixel_reduce_bias=args.pixel_reduce_bias,
380 | pixel_reduce_kernel=args.pixel_reduce_kernel,
381 | pixel_reduce_stride=args.pixel_reduce_stride,
382 | pixel_reduce_pad=args.pixel_reduce_pad,
383 | pixel_reduce_norm=args.pixel_reduce_norm,
384 | pixel_reduce_reg_out=args.pixel_reduce_reg_out,
385 | num_pixel_reduce_filter=args.num_pixel_reduce_filter,
386 | fast_pixel_reduce=args.fast_pixel_reduce,
387 | drop_pixel=args.drop_pixel,
388 | num_conv=args.num_conv,
389 | presnet_dense=args.presnet_dense,
390 | conv_shortcut=args.conv_shortcut
391 | ),
392 | context=[mx.gpu(i) for i in args.gpu]
393 | )
394 | mod.fit(
395 | train_data=train,
396 | eval_data=val,
397 | num_epoch=args.num_epoch,
398 | num_train=train.num_sample,
399 | batch_size=args.batch_size,
400 | lr_step=args.lr_step,
401 | lr=args.lr,
402 | wd=args.wd,
403 | gamma=args.gamma,
404 | snapshot=args.snapshot,
405 | params=args.params,
406 | ignore_params=args.ignore_params,
407 | fix_params=args.fix_params,
408 | decay_all=args.decay_all
409 | )
410 |
411 |
412 | @cli.command()
413 | @click.option('--num-semg-row', type=int, default=constant.NUM_SEMG_ROW, help='Rows of sEMG image')
414 | @click.option('--num-semg-col', type=int, default=constant.NUM_SEMG_COL, help='Cols of sEMG image')
415 | @click.option('--num-epoch', type=int, default=60, help='Maximum epochs')
416 | @click.option('--num-tzeng-batch', type=int, default=constant.NUM_TZENG_BATCH,
417 |               help='Number of batches per Tzeng update; 2 means interleaved domain and label updates')
418 | @click.option('--lr-step', type=int, multiple=True, default=[20, 40], help='Epoch numbers to decay learning rate')
419 | @click.option('--batch-size', type=int, default=1000,
420 | help='Batch size, should be 900 with --minibatch for s21 inter-subject experiment')
421 | @click.option('--lr', type=float, default=0.1, help='Base learning rate')
422 | @click.option('--wd', type=float, default=0.0001, help='Weight decay')
423 | @click.option('--subject-wd', type=float, help='Weight decay multiplier of the subject branch')
424 | @click.option('--gpu', type=int, multiple=True, default=[0])
425 | @click.option('--gamma', type=float, default=constant.GAMMA, help='Gamma in RevGrad')
426 | @click.option('--log', type=click.Path(), help='Path of the logging file')
427 | @click.option('--snapshot', type=click.Path(), help='Snapshot prefix')
428 | @click.option('--root', type=click.Path(), help='Root path of the experiment, created automatically if it does not exist')
429 | @click.option('--fold', type=int, required=True, help='Fold number of the inter-subject experiment')
430 | @click.option('--maxforce', is_flag=True, help='Use maxforce data of the target subject as calibration data')
431 | @click.option('--calib', is_flag=True, help='Use first repetition of the target subject as calibration data')
432 | @click.option('--only-calib', is_flag=True, help='Only use first repetition of the target subject as calibration data')
433 | @click.option('--target-binary', is_flag=True, help='Make binary prediction of subject and upsampling target dataset')
434 | @click.option('--revgrad', is_flag=True, help='Use RevGrad')
435 | @click.option('--num-revgrad-batch', type=int, default=2,
436 |               help=('Number of batches per RevGrad update; 2 means interleaved domain and label updates, '
437 |                     'see "Adversarial Deep Averaging Networks for Cross-Lingual Sentiment Classification" for details'))
438 | @click.option('--tzeng', is_flag=True, help='Use Tzeng_ICCV_2015')
439 | @click.option('--confuse-conv', is_flag=True, help='Domain confusion (for both RevGrad and Tzeng) on conv2')
440 | @click.option('--confuse-all', is_flag=True, help='Domain confusion (for both RevGrad and Tzeng) on all layers')
441 | @click.option('--subject-loss-weight', type=float, default=1, help='Ganin et al. use 0.1 in their code')
442 | @click.option('--subject-confusion-loss-weight', type=float, default=1,
443 | help='Tzeng confusion loss weight, larger than 1 seems better')
444 | @click.option('--target-gesture-loss-weight', type=float, help='For --calib, to emphasize calibration data')
445 | @click.option('--lambda-scale', type=float, default=constant.LAMBDA_SCALE,
446 | help='Global scale of lambda in RevGrad, 1 in their paper and 0.1 in their code')
447 | @click.option('--params', type=click.Path(exists=True), help='Initial weights')
448 | @click.option('--ignore-params', multiple=True, help='Ignore params in --params with regex')
449 | @click.option('--random-scale', type=float, default=0,
450 |               help='Randomly scale input data globally by 2^scale and locally by 2^(scale/4)')
451 | @click.option('--random-bad-channel', type=float, multiple=True, default=[],
452 |               help='Randomly (with probability 0.5 for each image) set a pixel to the specified value, usually -1, 0 or 1')
453 | @click.option('--num-feature-block', type=int, default=constant.NUM_FEATURE_BLOCK, help='Number of FC layers in feature extraction part')
454 | @click.option('--num-gesture-block', type=int, default=constant.NUM_GESTURE_BLOCK, help='Number of FC layers in gesture branch')
455 | @click.option('--num-subject-block', type=int, default=constant.NUM_SUBJECT_BLOCK, help='Number of FC layers in subject branch')
456 | @click.option('--adabn', is_flag=True, help='AdaBN for model adaptation, must be used with --minibatch')
457 | @click.option('--num-adabn-epoch', type=int, default=constant.NUM_ADABN_EPOCH)
458 | @click.option('--num-pixel', type=int, default=constant.NUM_PIXEL, help='Pixelwise reduction layers')
459 | @click.option('--num-filter', type=int, default=constant.NUM_FILTER, help='Kernels of the conv layers')
460 | @click.option('--num-hidden', type=int, default=constant.NUM_HIDDEN, help='Kernels of the FC layers')
461 | @click.option('--num-bottleneck', type=int, default=constant.NUM_BOTTLENECK, help='Kernels of the bottleneck layer')
462 | @click.option('--dropout', type=float, default=constant.DROPOUT, help='Dropout ratio')
463 | @click.option('--window', type=int, default=1, help='Multi-frame as image channels')
464 | @click.option('--lstm-window', type=int)
465 | @click.option('--num-presnet', type=int, multiple=True, help='Deprecated')
466 | @click.option('--presnet-branch', type=int, multiple=True, help='Deprecated')
467 | @click.option('--drop-presnet', is_flag=True)
468 | @click.option('--bng', is_flag=True, help='Deprecated')
469 | @click.option('--soft-label', is_flag=True, help='Tzeng soft-label for finetuning with calibration data')
470 | @click.option('--minibatch', is_flag=True, help='Split data into minibatch by subject id')
471 | @click.option('--drop-branch', is_flag=True, help='Dropout after each FC in branches')
472 | @click.option('--pool', is_flag=True, help='Deprecated')
473 | @click.option('--fft', is_flag=True, help='Deprecated. Perform FFT and use spectrum amplitude as image channels. Cannot be used on datasets with non-uniform segment lengths such as NinaPro')
474 | @click.option('--fft-append', is_flag=True, help='Append FFT feature to raw frames in channel axis')
475 | @click.option('--dual-stream', is_flag=True, help='Use raw frames and FFT feature as dual-stream')
476 | @click.option('--zscore/--no-zscore', default=True, help='Use z-score normalization on input')
477 | @click.option('--zscore-bng', is_flag=True, help='Use global BatchNorm as z-score normalization, for window > 1 or FFT')
478 | @click.option('--lstm', is_flag=True)
479 | @click.option('--num-lstm-hidden', type=int, default=constant.NUM_LSTM_HIDDEN, help='Kernels of the hidden layers in LSTM')
480 | @click.option('--num-lstm-layer', type=int, default=constant.NUM_LSTM_LAYER, help='Number of the hidden layers in LSTM')
481 | @click.option('--dense-window/--no-dense-window', default=True, help='Dense sampling of windows during training')
482 | @click.option('--lstm-last', type=int, default=0)
483 | @click.option('--lstm-dropout', type=float, default=constant.LSTM_DROPOUT, help='LSTM dropout ratio')
484 | @click.option('--lstm-shortcut', is_flag=True)
485 | @click.option('--lstm-bn/--no-lstm-bn', default=True, help='BatchNorm in LSTM')
486 | @click.option('--lstm-grad-scale/--no-lstm-grad-scale', default=True, help='Grad scale by the number of LSTM output')
487 | @click.option('--faug', type=float, default=0)
488 | @click.option('--faug-classwise', is_flag=True)
489 | @click.option('--num-eval-epoch', type=int, default=1)
490 | @click.option('--snapshot-period', type=int, default=1)
491 | @click.option('--gpu-x', type=int, default=0)
492 | @click.option('--drop-conv', is_flag=True)
493 | @click.option('--drop-pixel', type=int, multiple=True, default=(-1,))
494 | @click.option('--drop-presnet-branch', is_flag=True)
495 | @click.option('--drop-presnet-proj', is_flag=True)
496 | @click.option('--fix-params', multiple=True)
497 | @click.option('--presnet-proj-type', type=click.Choice(['A', 'B']), default='A')
498 | @click.option('--decay-all', is_flag=True)
499 | @click.option('--presnet-promote', is_flag=True)
500 | @click.option('--pixel-reduce-loss-weight', type=float, default=0)
501 | @click.option('--fast-pixel-reduce/--no-fast-pixel-reduce', default=True)
502 | @click.option('--pixel-reduce-bias', is_flag=True)
503 | @click.option('--pixel-reduce-kernel', type=int, multiple=True, default=(1, 1))
504 | @click.option('--pixel-reduce-stride', type=int, multiple=True, default=(1, 1))
505 | @click.option('--pixel-reduce-pad', type=int, multiple=True, default=(0, 0))
506 | @click.option('--pixel-reduce-norm', is_flag=True)
507 | @click.option('--pixel-reduce-reg-out', is_flag=True)
508 | @click.option('--num-pixel-reduce-filter', type=int, multiple=True, default=(16, 16))
509 | @click.option('--num-conv', type=int, default=2)
510 | @click.option('--pixel-same-init', is_flag=True)
511 | @click.option('--presnet-dense', is_flag=True)
512 | @click.option('--conv-shortcut', is_flag=True)
513 | @packargs
514 | def inter(args):
515 | '''Inter-subject experiment on S21 dataset'''
516 | if args.root:
517 | if args.log:
518 | args.log = os.path.join(args.root, args.log)
519 | if args.snapshot:
520 | args.snapshot = os.path.join(args.root, args.snapshot)
521 |
522 | if args.gpu_x:
523 | args.gpu = sum([list(args.gpu) for i in range(args.gpu_x)], [])
524 |
525 | with Context(args.log):
526 | logger.info('Args:\n{}', pformat(args))
527 | for i in range(args.num_epoch):
528 | path = args.snapshot + '-%04d.params' % (i + 1)
529 | if os.path.exists(path):
530 | logger.info('Found snapshot {}, exit', path)
531 | return
532 | train, val = data_s21.get_inter_subject_data(
533 | '.cache/mat.s21.bandstop-45-55.s1000m.scale-01',
534 | fold=args.fold,
535 | batch_size=args.batch_size,
536 | maxforce=args.maxforce,
537 | calib=args.calib or args.only_calib,
538 | only_calib=args.only_calib,
539 | target_binary=args.target_binary,
540 | with_subject=args.revgrad or args.tzeng,
541 | with_target_gesture=args.target_gesture_loss_weight is not None,
542 | random_scale=args.random_scale,
543 | random_bad_channel=args.random_bad_channel,
544 | shuffle=True,
545 | adabn=args.adabn,
546 | window=args.window,
547 | dense_window=args.dense_window,
548 | soft_label=args.soft_label,
549 | minibatch=args.minibatch,
550 | fft=args.fft,
551 | fft_append=args.fft_append,
552 | dual_stream=args.dual_stream,
553 | lstm=args.lstm,
554 | lstm_window=args.lstm_window
555 | )
556 | logger.info('Train samples: {}', train.num_sample)
557 | logger.info('Val samples: {}', val.num_sample)
558 | mod = Module(
559 | revgrad=args.revgrad,
560 | num_revgrad_batch=args.num_revgrad_batch,
561 | tzeng=args.tzeng,
562 | num_tzeng_batch=args.num_tzeng_batch,
563 | num_gesture=train.num_gesture,
564 | num_subject=train.num_subject,
565 | subject_loss_weight=args.subject_loss_weight,
566 | target_gesture_loss_weight=args.target_gesture_loss_weight,
567 | lambda_scale=args.lambda_scale,
568 | adabn=args.adabn,
569 | num_adabn_epoch=args.num_adabn_epoch,
570 | random_scale=args.random_scale,
571 | soft_label=args.soft_label,
572 | dual_stream=args.dual_stream,
573 | lstm=args.lstm,
574 | num_lstm_hidden=args.num_lstm_hidden,
575 | num_lstm_layer=args.num_lstm_layer,
576 | for_training=True,
577 | faug=args.faug,
578 | faug_classwise=args.faug_classwise,
579 | num_eval_epoch=args.num_eval_epoch,
580 | snapshot_period=args.snapshot_period,
581 | pixel_same_init=args.pixel_same_init,
582 | symbol_kargs=dict(
583 | num_semg_row=args.num_semg_row,
584 | num_semg_col=args.num_semg_col,
585 | num_filter=args.num_filter,
586 | num_pixel=args.num_pixel,
587 | num_feature_block=args.num_feature_block,
588 | num_gesture_block=args.num_gesture_block,
589 | num_subject_block=args.num_subject_block,
590 | num_hidden=args.num_hidden,
591 | num_bottleneck=args.num_bottleneck,
592 | dropout=args.dropout,
593 | num_channel=train.num_channel // (args.lstm_window or 1),
594 | num_presnet=args.num_presnet,
595 | presnet_branch=args.presnet_branch,
596 | drop_presnet=args.drop_presnet,
597 | bng=args.bng,
598 | subject_confusion_loss_weight=args.subject_confusion_loss_weight,
599 | minibatch=args.minibatch,
600 | confuse_conv=args.confuse_conv,
601 | confuse_all=args.confuse_all,
602 | subject_wd=args.subject_wd,
603 | drop_branch=args.drop_branch,
604 | pool=args.pool,
605 | zscore=args.zscore,
606 | zscore_bng=args.zscore_bng,
607 | num_stream=2 if args.dual_stream else 1,
608 | lstm_last=args.lstm_last,
609 | lstm_dropout=args.lstm_dropout,
610 | lstm_shortcut=args.lstm_shortcut,
611 | lstm_bn=args.lstm_bn,
612 | lstm_window=args.lstm_window,
613 | lstm_grad_scale=args.lstm_grad_scale,
614 | drop_conv=args.drop_conv,
615 | drop_presnet_branch=args.drop_presnet_branch,
616 | drop_presnet_proj=args.drop_presnet_proj,
617 | presnet_proj_type=args.presnet_proj_type,
618 | presnet_promote=args.presnet_promote,
619 | pixel_reduce_loss_weight=args.pixel_reduce_loss_weight,
620 | pixel_reduce_bias=args.pixel_reduce_bias,
621 | pixel_reduce_kernel=args.pixel_reduce_kernel,
622 | pixel_reduce_stride=args.pixel_reduce_stride,
623 | pixel_reduce_pad=args.pixel_reduce_pad,
624 | pixel_reduce_norm=args.pixel_reduce_norm,
625 | pixel_reduce_reg_out=args.pixel_reduce_reg_out,
626 | num_pixel_reduce_filter=args.num_pixel_reduce_filter,
627 | fast_pixel_reduce=args.fast_pixel_reduce,
628 | drop_pixel=args.drop_pixel,
629 | num_conv=args.num_conv,
630 | presnet_dense=args.presnet_dense,
631 | conv_shortcut=args.conv_shortcut
632 | ),
633 | context=[mx.gpu(i) for i in args.gpu]
634 | )
635 | mod.fit(
636 | train_data=train,
637 | eval_data=val,
638 | num_epoch=args.num_epoch,
639 | num_train=train.num_sample,
640 | batch_size=args.batch_size,
641 | lr_step=args.lr_step,
642 | lr=args.lr,
643 | wd=args.wd,
644 | gamma=args.gamma,
645 | snapshot=args.snapshot,
646 | params=args.params,
647 | ignore_params=args.ignore_params,
648 | fix_params=args.fix_params,
649 | decay_all=args.decay_all
650 | )
651 |
652 |
653 | @cli.command()
654 | @click.option('--num-epoch', type=int, default=150)
655 | @click.option('--lr-step', type=int, default=50)
656 | @click.option('--batch-size', type=int, default=2000)
657 | @click.option('--lr', type=float, default=0.1)
658 | @click.option('--gpu', type=int, multiple=True, default=[0])
659 | @click.option('--log', type=click.Path())
660 | @click.option('--snapshot', type=click.Path())
661 | @click.option('--root', type=click.Path())
662 | @click.option('--adapt', is_flag=True)
663 | @click.option('--gamma', type=float, default=10)
664 | @click.option('--subject-loss-weight', type=float, default=0.1)
665 | def _general(
666 | num_epoch,
667 | lr_step,
668 | batch_size,
669 | lr,
670 | gpu,
671 | log,
672 | snapshot,
673 | root,
674 | adapt,
675 | gamma,
676 | subject_loss_weight
677 | ):
678 | if root:
679 | if log:
680 | log = os.path.join(root, log)
681 | if snapshot:
682 | snapshot = os.path.join(root, snapshot)
683 |
684 | with Context(log):
685 | logger.info('Args:\n{}', pformat(locals()))
686 | mod = Module(
687 | adapt=adapt,
688 | num_gesture=8,
689 | num_subject=10,
690 | subject_loss_weight=subject_loss_weight,
691 | context=[mx.gpu(i) for i in gpu]
692 | )
693 | train, val, num_train, _ = data_s21.get_general_data(
694 | '.cache/mat.s21.bandstop-45-55.s1000m.scale-01',
695 | batch_size=batch_size,
696 | adapt=adapt
697 | )
698 | logger.info('Train samples: {}', num_train)
699 | mod.fit(
700 | train_data=train,
701 | eval_data=val,
702 | num_epoch=num_epoch,
703 | num_train=num_train,
704 | batch_size=batch_size,
705 | lr_step=lr_step,
706 | lr=lr,
707 | gamma=gamma,
708 | snapshot=snapshot
709 | )
710 |
711 |
712 | @cli.command()
713 | def stats():
714 | click.echo(data_s21.get_stats())
715 |
716 |
717 | @cli.command()
718 | @click.option('--gpu', type=int, multiple=True, default=[0])
719 | @click.option('--fold', type=int, required=True)
720 | @click.option('--batch-size', type=int, default=2000)
721 | def coral(gpu, fold, batch_size):
722 | with Context():
723 | val = data_s21.get_inter_subject_val(fold=fold, batch_size=batch_size)
724 |
725 | mod = Module(
726 | num_gesture=8,
727 | coral=True,
728 | adabn=True,
729 | adabn_num_epoch=10,
730 | symbol_kargs=dict(
731 | num_filter=16,
732 | num_pixel=2,
733 | num_feature_block=2,
734 | num_gesture_block=0,
735 | num_hidden=512,
736 | num_bottleneck=128,
737 | dropout=0.5,
738 | num_channel=1
739 | ),
740 | context=[mx.gpu(i) for i in gpu]
741 | )
742 | mod.init_coral(
743 | '.cache/sigr-inter-adabn-%d-v403/model-0060.params' % fold,
744 | [data_s21.get_coral([i], batch_size) for i in range(10) if i != fold],
745 | data_s21.get_coral([fold], batch_size)
746 | )
747 | # mod.bind(data_shapes=val.provide_data, for_training=False)
748 | # mod.load_params('.cache/sigr-inter-%d-final/model-0060.params' % fold)
749 |
750 | metric = mx.metric.create('acc')
751 | mod.score(val, metric)
752 | logger.info('Fold {} accuracy: {}', fold, metric.get()[1])
753 |
754 |
755 | if __name__ == '__main__':
756 | cli(obj=Bunch())
757 |
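In crossval above, the data loader is resolved by reflection: the --crossval-type value is turned into a get_*_data method name and looked up on the dataset with getattr. A short sketch of that mapping (several of these loaders are visible in sigr/data/__init__.py below):

    crossval_type = 'universal-one-fold-intra-subject'
    loader = 'get_%s_data' % crossval_type.replace('-', '_')
    # -> 'get_universal_one_fold_intra_subject_data'; crossval then calls
    # getattr(dataset, loader)(batch_size=..., fold=..., preprocess=..., ...)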
--------------------------------------------------------------------------------
/sigr/base_module.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 |
3 |
4 | class Meta(type):
5 |
6 | impls = []
7 |
8 | def __init__(cls, name, bases, fields):
9 | type.__init__(cls, name, bases, fields)
10 | Meta.impls.append(cls)
11 |
12 |
13 | class BaseModule(object):
14 |
15 | __metaclass__ = Meta
16 |
17 | @classmethod
18 | def parse(cls, text, **kargs):
19 | if cls is BaseModule:
20 | for impl in Meta.impls:
21 | if impl is not BaseModule:
22 | inst = impl.parse(text, **kargs)
23 | if inst is not None:
24 | return inst
25 |
26 |
27 | __all__ = ['BaseModule']
28 |
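Meta records every subclass in Meta.impls, and BaseModule.parse asks each registered implementation in turn until one returns an instance. A hedged sketch of how a concrete module would hook in (DummyModule and the 'dummy' text are illustrative only; the real implementations live elsewhere in the package, e.g. module.py):

    class DummyModule(BaseModule):

        @classmethod
        def parse(cls, text, **kargs):
            # Claim only the texts this module understands; anything else
            # falls through to the next registered implementation.
            if text == 'dummy':
                return cls()

    # BaseModule.parse('dummy')   -> DummyModule instance
    # BaseModule.parse('unknown') -> None (no registered implementation claims it)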
--------------------------------------------------------------------------------
/sigr/constant.py:
--------------------------------------------------------------------------------
1 | NUM_LSTM_HIDDEN = 128
2 | NUM_LSTM_LAYER = 1
3 | LSTM_DROPOUT = 0.
4 | NUM_SEMG_ROW = 16
5 | NUM_SEMG_COL = 8
6 | NUM_SEMG_POINT = NUM_SEMG_ROW * NUM_SEMG_COL
7 | NUM_FILTER = 16
8 | NUM_HIDDEN = 512
9 | NUM_BOTTLENECK = 128
10 | DROPOUT = 0.5
11 | GAMMA = 10
12 | NUM_FEATURE_BLOCK = 2
13 | NUM_GESTURE_BLOCK = 0
14 | NUM_SUBJECT_BLOCK = 0
15 | NUM_PIXEL = 2
16 | LAMBDA_SCALE = 1
17 | NUM_TZENG_BATCH = 2
18 | NUM_ADABN_EPOCH = 1
19 | RANDOM_SHIFT_FILL = 'zero'
20 |
--------------------------------------------------------------------------------
/sigr/coral.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.linalg as splg
3 |
4 |
5 | def get_coral_params(ds, dt, lam=1e-3):
6 | ms = ds.mean(axis=0)
7 | ds = ds - ms
8 | mt = dt.mean(axis=0)
9 | dt = dt - mt
10 | cs = np.cov(ds.T) + lam * np.eye(ds.shape[1])
11 | ct = np.cov(dt.T) + lam * np.eye(dt.shape[1])
12 | sqrt = splg.sqrtm
13 | w = sqrt(ct).dot(np.linalg.inv(sqrt(cs)))
14 | b = mt - w.dot(ms.reshape(-1, 1)).ravel()
15 | return w, b
16 |
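get_coral_params implements CORAL-style feature alignment: after centering, source features are whitened with inv(sqrt(cs)) and re-colored with sqrt(ct), and b additionally matches the means, so each source sample x is mapped to w.dot(x - ms) + mt. A hedged usage sketch (the feature matrices are synthetic placeholders):

    import numpy as np
    from sigr.coral import get_coral_params

    ds = np.random.randn(1000, 128)              # source features, shape (samples, features)
    dt = 2.0 * np.random.randn(800, 128) + 1.0   # target features with different mean/covariance
    w, b = get_coral_params(ds, dt)
    ds_aligned = ds.dot(w.T) + b                 # row-wise w.dot(x - ms) + mt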
--------------------------------------------------------------------------------
/sigr/data/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import mxnet as mx
3 | import os
4 | import scipy.io as spio
5 | import numpy as np
6 | from collections import namedtuple, OrderedDict
7 | from logbook import Logger
8 | from nose.tools import assert_equal
9 | from functools import partial
10 | from itertools import product, izip
11 | from .. import utils, constant
12 |
13 |
14 | logger = Logger('data')
15 | Combo = namedtuple('Combo', ['subject', 'gesture', 'trial'], verbose=False)
16 | Trial = namedtuple('Trial', ['data', 'gesture', 'subject'], verbose=False)
17 |
18 |
19 | def _register(impl):
20 | _register.impls.append(impl)
21 |
22 |
23 | _register.impls = []
24 |
25 |
26 | class Dataset(object):
27 |
28 | class __metaclass__(type):
29 |
30 | def __init__(cls, name, bases, fields):
31 | type.__init__(cls, name, bases, fields)
32 | _register(cls)
33 |
34 | @property
35 | def num_trial(self):
36 | return len(self.trials)
37 |
38 | @property
39 | def num_gesture(self):
40 | return len(self.gestures)
41 |
42 | @property
43 | def num_subject(self):
44 | return len(self.subjects)
45 |
46 | @classmethod
47 | def from_name(cls, name):
48 | if name == 's21':
49 | from . import s21
50 | return s21
51 | if name == 'csl':
52 | from . import csl
53 | return csl
54 | inst = cls.parse(name)
55 | assert inst is not None, 'Unknown dataset {}'.format(name)
56 | return inst
57 |
58 | @classmethod
59 | def parse(cls, text):
60 | if cls is Dataset:
61 | for impl in _register.impls:
62 | if impl is not Dataset:
63 | inst = impl.parse(text)
64 | if inst is not None:
65 | return inst
66 |
67 | def get_combos(self, *args):
68 | for arg in args:
69 | if isinstance(arg, tuple):
70 | arg = [arg]
71 | for a in arg:
72 | yield Combo(*a)
73 |
74 |
75 | class SingleSessionMixin(object):
76 |
77 | def get_one_fold_intra_subject_trials(self):
78 | return self.trials[::2], self.trials[1::2]
79 |
80 | def get_inter_subject_data(self, fold, batch_size, preprocess,
81 | adabn, minibatch, **kargs):
82 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess)
83 | load = partial(get_data,
84 | root=self.root,
85 | last_batch_handle='pad',
86 | get_trial=get_trial,
87 | batch_size=batch_size,
88 | num_semg_row=self.num_semg_row,
89 | num_semg_col=self.num_semg_col)
90 | subject = self.subjects[fold]
91 | train = load(
92 | combos=self.get_combos(product([i for i in self.subjects if i != subject],
93 | self.gestures, self.trials)),
94 | adabn=adabn,
95 | # mini_batch_size=batch_size // (self.num_subject - 1 if minibatch else 1),
96 | mini_batch_size=10 if minibatch else 1,
97 | shuffle=True)
98 | val = load(
99 | combos=self.get_combos(product([subject], self.gestures, self.trials)),
100 | shuffle=False)
101 | return train, val
102 |
103 | def get_inter_subject_val(self, fold, batch_size, preprocess=None, **kargs):
104 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess)
105 | load = partial(get_data,
106 | root=self.root,
107 | last_batch_handle='pad',
108 | get_trial=get_trial,
109 | batch_size=batch_size,
110 | num_semg_row=self.num_semg_row,
111 | num_semg_col=self.num_semg_col)
112 | subject = self.subjects[fold]
113 | val = load(
114 | combos=self.get_combos(product([subject], self.gestures, self.trials)),
115 | shuffle=False)
116 | return val
117 |
118 | def get_intra_subject_data(self, fold, batch_size, preprocess,
119 | adabn, minibatch, **kargs):
120 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess)
121 | load = partial(get_data,
122 | root=self.root,
123 | last_batch_handle='pad',
124 | get_trial=get_trial,
125 | batch_size=batch_size,
126 | num_semg_row=self.num_semg_row,
127 | num_semg_col=self.num_semg_col)
128 | subject = self.subjects[fold // self.num_trial]
129 | trial = self.trials[fold % self.num_trial]
130 | train = load(
131 | combos=self.get_combos(product([subject], self.gestures,
132 | [i for i in self.trials if i != trial])),
133 | adabn=adabn,
134 | # mini_batch_size=batch_size // (self.num_subject if minibatch else 1),
135 | mini_batch_size=10 if minibatch else 1,
136 | shuffle=True)
137 | val = load(
138 | combos=self.get_combos(product([subject], self.gestures, [trial])),
139 | shuffle=False)
140 | return train, val
141 |
142 | def get_intra_subject_val(self, fold, batch_size, preprocess=None, **kargs):
143 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess)
144 | load = partial(get_data,
145 | root=self.root,
146 | last_batch_handle='pad',
147 | get_trial=get_trial,
148 | batch_size=batch_size,
149 | num_semg_row=self.num_semg_row,
150 | num_semg_col=self.num_semg_col)
151 | subject = self.subjects[fold // self.num_trial]
152 | trial = self.trials[fold % self.num_trial]
153 | val = load(
154 | combos=self.get_combos(product([subject], self.gestures, [trial])),
155 | shuffle=False)
156 | return val
157 |
158 | def get_universal_intra_subject_data(self, fold, batch_size, preprocess,
159 | adabn, minibatch, **kargs):
160 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess)
161 | load = partial(get_data,
162 | root=self.root,
163 | last_batch_handle='pad',
164 | get_trial=get_trial,
165 | batch_size=batch_size,
166 | num_semg_row=self.num_semg_row,
167 | num_semg_col=self.num_semg_col)
168 | trial = self.trials[fold]
169 | train = load(
170 | combos=self.get_combos(product(self.subjects, self.gestures,
171 | [i for i in self.trials if i != trial])),
172 | adabn=adabn,
173 | # mini_batch_size=batch_size // (self.num_subject if minibatch else 1),
174 | mini_batch_size=10 if minibatch else 1,
175 | shuffle=True)
176 | val = load(
177 | combos=self.get_combos(product(self.subjects, self.gestures, [trial])),
178 | shuffle=False)
179 | return train, val
180 |
181 | def get_one_fold_intra_subject_val(self, fold, batch_size, preprocess=None, **kargs):
182 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess)
183 | load = partial(get_data,
184 | root=self.root,
185 | last_batch_handle='pad',
186 | get_trial=get_trial,
187 | batch_size=batch_size,
188 | num_semg_row=self.num_semg_row,
189 | num_semg_col=self.num_semg_col)
190 | subject = self.subjects[fold]
191 | _, val_trials = self.get_one_fold_intra_subject_trials()
192 | val = load(
193 | combos=self.get_combos(product([subject], self.gestures,
194 | [i for i in val_trials])),
195 | shuffle=False)
196 | return val
197 |
198 | def get_one_fold_intra_subject_data(self, fold, batch_size, preprocess,
199 | adabn, minibatch, **kargs):
200 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess)
201 | load = partial(get_data,
202 | root=self.root,
203 | last_batch_handle='pad',
204 | get_trial=get_trial,
205 | batch_size=batch_size,
206 | num_semg_row=self.num_semg_row,
207 | num_semg_col=self.num_semg_col)
208 | subject = self.subjects[fold]
209 | train_trials, val_trials = self.get_one_fold_intra_subject_trials()
210 | train = load(
211 | combos=self.get_combos(product([subject], self.gestures,
212 | [i for i in train_trials])),
213 | adabn=adabn,
214 | # mini_batch_size=batch_size // (self.num_subject if minibatch else 1),
215 | mini_batch_size=10 if minibatch else 1,
216 | shuffle=True)
217 | val = load(
218 | combos=self.get_combos(product([subject], self.gestures,
219 | [i for i in val_trials])),
220 | shuffle=False)
221 | return train, val
222 |
223 | def get_universal_one_fold_intra_subject_data(self, fold, batch_size, preprocess,
224 | adabn, minibatch, **kargs):
225 | assert_equal(fold, 0)
226 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess)
227 | load = partial(get_data,
228 | root=self.root,
229 | last_batch_handle='pad',
230 | get_trial=get_trial,
231 | batch_size=batch_size,
232 | num_semg_row=self.num_semg_row,
233 | num_semg_col=self.num_semg_col)
234 | train_trials, val_trials = self.get_one_fold_intra_subject_trials()
235 | train = load(
236 | combos=self.get_combos(product(self.subjects, self.gestures,
237 | [i for i in train_trials])),
238 | adabn=adabn,
239 | # mini_batch_size=batch_size // (self.num_subject if minibatch else 1),
240 | mini_batch_size=10 if minibatch else 1,
241 | shuffle=True)
242 | val = load(
243 | combos=self.get_combos(product(self.subjects, self.gestures,
244 | [i for i in val_trials])),
245 | shuffle=False)
246 | return train, val
247 |
248 |
249 | def get_index(a):
250 | '''Convert label to 0 based index'''
251 | b = list(set(a))
252 | return np.array([x if x < 0 else b.index(x) for x in a.ravel()]).reshape(a.shape)
253 |
254 |
255 | def get_path(root, combo):
256 | return os.path.join(
257 | root,
258 | '{0.subject:03d}',
259 | '{0.gesture:03d}',
260 | '{0.subject:03d}_{0.gesture:03d}_{0.trial:03d}.mat'
261 | ).format(combo)
262 |
263 |
264 | def label_to_gesture(label):
265 | '''Convert maxforce to -1'''
266 | return label if label < 100 else -1
267 |
268 |
269 | def _get_trial(root, combo):
270 | path = get_path(root, combo)
271 | mat = spio.loadmat(path)
272 | data = mat['data'].astype(np.float32)
273 | gesture = np.repeat(label_to_gesture(np.asscalar(mat['label'].astype(np.int))), len(data))
274 | subject = np.repeat(np.asscalar(mat['subject'].astype(np.int)), len(data))
275 | return Trial(data=data, gesture=gesture, subject=subject)
276 |
277 |
278 | def get_data(
279 | root,
280 | combos,
281 | num_semg_row,
282 | num_semg_col,
283 | mean=None,
284 | scale=None,
285 | with_subject=False,
286 | target_combos=None,
287 | target_binary=False,
288 | with_target_gesture=False,
289 | min_size=None,
290 | random_scale=False,
291 | random_bad_channel=[],
292 | shuffle=True,
293 | adabn=False,
294 | window=1,
295 | soft_label=False,
296 | fft=False,
297 | fft_append=False,
298 | dual_stream=False,
299 | num_ignore_per_segment=0,
300 | dense_window=True,
301 | faug=False,
302 | get_trial=None,
303 | balance_gesture=0,
304 | **kargs
305 | ):
306 | '''Get mxnet data iter'''
307 | if os.path.isdir(os.path.join(root, 'data')):
308 | root = os.path.join(root, 'data')
309 |
310 | combos = list(combos)
311 | if target_combos is not None:
312 | target_combos = list(target_combos)
313 |
314 | if get_trial is None:
315 | get_trial = _get_trial
316 |
317 | def try_scale(data):
318 | if mean is not None:
319 | data = data - mean
320 | if scale is not None:
321 | data = data * scale
322 | return data
323 |
324 | data = []
325 | gesture = []
326 | subject = []
327 | segment = []
328 |
329 | for combo in combos:
330 | trial = get_trial(root=root, combo=combo)
331 | data.append(try_scale(trial.data))
332 | gesture.append(trial.gesture)
333 | subject.append(np.repeat(0, len(data[-1])) if target_binary else trial.subject)
334 | segment.append(np.repeat(len(segment), len(data[-1])))
335 |
336 | if target_combos:
337 | for combo in target_combos:
338 | trial = get_trial(root=root, combo=combo)
339 |             data.append(try_scale(trial.data))
340 | gesture.append(trial.gesture)
341 | subject.append(np.repeat(1, len(data[-1])) if target_binary else trial.subject)
342 | segment.append(np.repeat(len(segment), len(data[-1])))
343 |
344 | # if window > 1:
345 | # data = [get_segments(seg, window) for seg in data]
346 | # gesture = [seg[window - 1:] for seg in gesture]
347 | # subject = [seg[window - 1:] for seg in subject]
348 | # for t in zip(data, gesture, subject):
349 | # for lhs, rhs in zip(t[:-1], t[1:]):
350 | # assert len(lhs) == len(rhs)
351 |
352 | logger.debug('MAT loaded')
353 |
354 | if not data:
355 | logger.warn('Empty data')
356 | return
357 |
358 | index = []
359 | n = 0
360 | for seg in data:
361 | if dense_window:
362 | index.append(np.arange(n, n + len(seg) - window + 1 - num_ignore_per_segment))
363 | else:
364 | index.append(np.arange(n, n + len(seg) - window + 1 - num_ignore_per_segment, window))
365 | # Pad with the last value
366 | # index.append(np.repeat(n + len(seg) - window, window - 1))
367 | n += len(seg)
368 | index = np.hstack(index)
369 | logger.debug('Index made')
370 |
371 | logger.debug('Segments: {}', len(data))
372 | logger.debug('First segment shape: {}', data[0].shape)
373 | data = np.vstack(data).reshape(-1, 1, num_semg_row, num_semg_col)
374 | logger.debug('Data stacked')
375 | if min_size is not None:
376 | h = (min_size - num_semg_row) // 2
377 | w = (min_size - num_semg_col) // 2
378 | data = np.pad(
379 | data,
380 | ((0, 0), (0, 0), (h, h), (w, w)),
381 | 'constant',
382 | constant_values=0
383 | )
384 |
385 | # data = np.tile(data, (1, 3, 1, 1))
386 | gesture = get_index(np.hstack(gesture))
387 | subject_orig = np.hstack(subject)
388 | subject = get_index(subject_orig)
389 | segment = np.hstack(segment)
390 |
391 | label = []
392 |
393 | if soft_label is not False:
394 | label.append(('gesture_softmax_label', gesture))
395 | label.append(('soft_label', soft_label[gesture]))
396 | else:
397 | label.append(('gesture_softmax_label', gesture))
398 |
399 | if with_subject:
400 | label.append(('subject_softmax_label', subject))
401 | # for i in range(gesture.max() + 1):
402 | # subset = subject.copy()
403 | # subset[gesture != i] = -1
404 | # label.append(('gesture%d_subject_softmax_label' % i, subset))
405 |
406 | if with_target_gesture:
407 | if target_combos is not None:
408 | mask = np.in1d(subject_orig, list(set({combo.subject for combo in target_combos})))
409 | target_gesture = gesture.copy()
410 | target_gesture[~mask, ...] = -1
411 | label.append(('target_gesture_softmax_label', target_gesture))
412 | else:
413 | label.append(('target_gesture_softmax_label', gesture))
414 |
415 | logger.debug('Make data iter')
416 |
417 | # important, use OrderedDict to ensure label order
418 | data = Data(
419 | data=OrderedDict([('data', data)]),
420 | label=OrderedDict(label),
421 | shuffle=shuffle,
422 | adabn=adabn,
423 | gesture=gesture.copy(),
424 | subject=subject.copy(),
425 | segment=segment.copy(),
426 | window=window,
427 | index=index,
428 | random_scale=random_scale,
429 | random_bad_channel=random_bad_channel,
430 | # num_sample=len(index),
431 | num_gesture=gesture.max() + 1,
432 | num_subject=subject.max() + 1,
433 | fft=fft,
434 | fft_append=fft_append,
435 | dual_stream=dual_stream,
436 | dense_window=dense_window,
437 | faug=faug,
438 | balance_gesture=balance_gesture,
439 | **kargs
440 | )
441 | if not fft:
442 | data = Preload(data)
443 | return data
444 |
445 |
446 | class Preload(mx.io.PrefetchingIter):
447 |
448 | def __getattr__(self, name):
449 | if name != 'iters' and hasattr(self, 'iters') and hasattr(self.iters[0], name):
450 | return getattr(self.iters[0], name)
451 | raise AttributeError(name)
452 |
453 | def __setattr__(self, name, value):
454 | if name in ('shuffle', 'downsample', 'last_batch_handle'):
455 | return setattr(self.iters[0], name, value)
456 | return super(Preload, self).__setattr__(name, value)
457 |
458 | def iter_next(self):
459 | for e in self.data_ready:
460 | e.wait()
461 | if self.next_batch[0] is None:
462 | # for i in self.next_batch:
463 | # assert i is None, "Number of entry mismatches between iterators"
464 | return False
465 | else:
466 | # for batch in self.next_batch:
467 | # assert batch.pad == self.next_batch[0].pad, "Number of entry mismatches between iterators"
468 | self.current_batch = mx.io.DataBatch(sum([batch.data for batch in self.next_batch], []),
469 | sum([batch.label for batch in self.next_batch], []),
470 | self.next_batch[0].pad,
471 | self.next_batch[0].index)
472 | for e in self.data_ready:
473 | e.clear()
474 | for e in self.data_taken:
475 | e.set()
476 | return True
477 |
478 |
479 | class FaugData(mx.io.DataIter):
480 |
481 | def __init__(self, faug, batch_size, num_feature):
482 | super(FaugData, self).__init__()
483 | self.faug = faug
484 | self.batch_size = batch_size
485 | self.num_feature = num_feature
486 |
487 | @property
488 | def provide_data(self):
489 | return [('faug', (self.batch_size, self.num_feature))]
490 |
491 | @property
492 | def provide_label(self):
493 | return []
494 |
495 | def iter_next(self):
496 | return True
497 |
498 | def getdata(self):
499 | if self.faug:
500 | return [mx.nd.array(self.faug * np.random.randn(self.batch_size, self.num_feature))]
501 | else:
502 | return [mx.nd.array(np.zeros((self.batch_size, self.num_feature)))]
503 |
504 | def getlabel(self):
505 | return []
506 |
507 |
508 | class Data(mx.io.NDArrayIter):
509 |
510 | def __init__(self, *args, **kargs):
511 | self.random_shift_vertical = kargs.pop('random_shift_vertical', 0)
512 | self.random_shift_horizontal = kargs.pop('random_shift_horizontal', 0)
513 | self.random_shift_fill = kargs.pop('random_shift_fill', constant.RANDOM_SHIFT_FILL)
514 | self.framerate = kargs.pop('framerate', 1000)
515 | self.amplitude_weighting = kargs.pop('amplitude_weighting', False)
516 | self.amplitude_weighting_sort = kargs.pop('amplitude_weighting_sort', False)
517 | self.downsample = kargs.pop('downsample', None)
518 | self.dense_window = kargs.pop('dense_window')
519 | self.random_scale = kargs.pop('random_scale')
520 | self.random_bad_channel = kargs.pop('random_bad_channel')
521 | self.shuffle = kargs.pop('shuffle', False)
522 | self.adabn = kargs.pop('adabn', False)
523 | self._gesture = kargs.pop('gesture')
524 | self._subject = kargs.pop('subject')
525 | self._segment = kargs.pop('segment')
526 | self.window = kargs.pop('window')
527 | self._index_orig = kargs.pop('index')
528 | self._index = np.copy(self._index_orig)
529 | # self.num_sample = kargs.pop('num_sample')
530 | self.num_gesture = kargs.pop('num_gesture')
531 | self.num_subject = kargs.pop('num_subject')
532 | self.mini_batch_size = kargs.pop('mini_batch_size', kargs.get('batch_size'))
533 | self.random_state = kargs.pop('random_state', np.random)
534 | self.fft = kargs.pop('fft', False)
535 | self.fft_append = kargs.pop('fft_append', False)
536 | self.dual_stream = kargs.pop('dual_stream', False)
537 | self.faug = kargs.pop('faug', False)
538 | self.balance_gesture = kargs.pop('balance_gesture', 0)
539 | if not self.dual_stream:
540 | self.num_channel = self.window if not self.fft else self.window // 2 + (self.window if self.fft_append else 0)
541 | else:
542 | assert self.fft and not self.fft_append
543 | self.num_channel = [self.window, self.window // 2]
544 |
545 | super(Data, self).__init__(*args, **kargs)
546 |
547 | self.data = [(k, self._asnumpy(v)) for k, v in self.data]
548 | self.label = [(k, self._asnumpy(v)) for k, v in self.label]
549 | self.num_data = len(self._index)
550 | self.data_orig = self.data
551 | self.reset()
552 | # self.num_data = len(self._index)
553 |
554 | def _asnumpy(self, a):
555 | return a if not isinstance(a, mx.nd.NDArray) else a.asnumpy()
556 |
557 | @property
558 | def num_sample(self):
559 | return self.num_data
560 |
561 | @property
562 | def gesture(self):
563 | return self._gesture[self._index]
564 |
565 | @property
566 | def subject(self):
567 | return self._subject[self._index]
568 |
569 | @property
570 | def segment(self):
571 | return self._segment[self._index]
572 |
573 | @property
574 | def provide_data(self):
575 | if not self.dual_stream:
576 | res = [(k, tuple([self.batch_size, self.num_channel] + list(v.shape[2:]))) for k, v in self.data]
577 | else:
578 | assert_equal(len(self.data), 1)
579 | res = [('stream%d_' % i + self.data[0][0], tuple([self.batch_size, ch] + list(self.data[0][1].shape[2:])))
580 | for i, ch in enumerate(self.num_channel)]
581 | if self.faug:
582 | res += [('faug', (self.batch_size, 16))]
583 | return res
584 |
585 | def _expand_index(self, index):
586 | return np.hstack([np.arange(i, i + self.window) for i in index])
587 |
588 | def _reshape_data(self, data):
589 | return data.reshape(-1, self.window, *data.shape[2:])
590 |
591 | def _get_fft(self, data):
592 | from .. import Context
593 | import joblib as jb
594 | res = []
595 | for amp in Context.parallel(jb.delayed(_get_fft_aux)(sample, self.fft_append) for sample in data):
596 | res.append(amp[np.newaxis, ...])
597 | return np.concatenate(res, axis=0)
598 |
599 | def _get_segments(self, a, index):
600 | b = mx.nd.empty((len(index), self.window) + a.shape[2:], dtype=a.dtype)
601 | for i, j in enumerate(index):
602 | b[i] = a[j:j + self.window].reshape(self.window, *a.shape[2:])
603 | return b
604 |
605 | def _getdata(self, data_source):
606 | """Load data from underlying arrays, internal use only"""
607 | assert(self.cursor < self.num_data), "DataIter needs reset."
608 |
609 | if data_source is self.data and self.window > 1:
610 | if self.cursor + self.batch_size <= self.num_data:
611 | # res = [self._reshape_data(x[1][self._expand_index(self._index[self.cursor:self.cursor+self.batch_size])]) for x in data_source]
612 | res = [self._get_segments(x[1], self._index[self.cursor:self.cursor+self.batch_size]) for x in data_source]
613 | else:
614 | pad = self.batch_size - self.num_data + self.cursor
615 | res = [(np.concatenate((self._reshape_data(x[1][self._expand_index(self._index[self.cursor:])]),
616 | self._reshape_data(x[1][self._expand_index(self._index[:pad])])), axis=0)) for x in data_source]
617 | else:
618 | if self.cursor + self.batch_size <= self.num_data:
619 | res = [(x[1][self._index[self.cursor:self.cursor+self.batch_size]]) for x in data_source]
620 | else:
621 | pad = self.batch_size - self.num_data + self.cursor
622 | res = [(np.concatenate((x[1][self._index[self.cursor:]], x[1][self._index[:pad]]), axis=0)) for x in data_source]
623 |
624 | # if data_source is self.data:
625 | # for a in res:
626 | # assert np.all(np.isfinite(a)) and not np.all(a == 0)
627 |
628 | if data_source is self.data and self.fft:
629 | if not self.dual_stream:
630 | res = [self._get_fft(a.asnumpy() if isinstance(a, mx.nd.NDArray) else a) for a in res]
631 | else:
632 | res = res + [self._get_fft(a.asnumpy() if isinstance(a, mx.nd.NDArray) else a) for a in res]
633 | assert_equal(len(res), 2)
634 |
635 | if data_source is self.data and self.faug:
636 | res += [self.faug * self.random_state.randn(self.batch_size, 16)]
637 |
638 | res = [a if isinstance(a, mx.nd.NDArray) else mx.nd.array(a) for a in res]
639 | return res
640 |
641 | def _rand(self, smin, smax, shape):
642 | return (smax - smin) * self.random_state.rand(*shape) + smin
643 |
644 | def _do_shuffle(self):
645 | if not self.adabn or len(set(self._subject)) == 1:
646 | self.random_state.shuffle(self._index)
647 | else:
648 | batch_size = self.mini_batch_size
649 | # batch_size = self.batch_size
650 | # logger.info('AdaBN shuffle with a mini batch size of {}', batch_size)
651 | self.random_state.shuffle(self._index)
652 | subject_shuffled = self._subject[self._index]
653 | index_batch = []
654 | for i in sorted(set(self._subject)):
655 | index = self._index[subject_shuffled == i]
656 | index = index[:len(index) // batch_size * batch_size]
657 | index_batch.append(index.reshape(-1, batch_size))
658 | index_batch = np.vstack(index_batch)
659 | index = np.arange(len(index_batch))
660 | self.random_state.shuffle(index)
661 | self._index = index_batch[index, :].ravel()
662 | # assert len(self._index) == len(set(self._index))
663 |
664 | for i in range(0, len(self._subject), batch_size):
665 | # Make sure that the samples in one batch are from the same subject
666 | assert np.all(self._subject[self._index[i:i + batch_size - 1]] ==
667 | self._subject[self._index[i + 1:i + batch_size]])
668 |
669 | if batch_size != self.batch_size:
670 | assert self.batch_size % batch_size == 0
671 | # assert (self.batch_size // batch_size) % self.num_subject == 0
672 | self._index = self._index[:len(self._index) // self.batch_size * self.batch_size].reshape(
673 | -1, self.batch_size // batch_size, batch_size).transpose(0, 2, 1).ravel()
674 |
675 | def reset(self):
676 | self._reset()
677 | super(Data, self).reset()
678 |
679 | def _reset(self):
680 | # self._index.sort()
681 | self._index = np.copy(self._index_orig)
682 |
683 | if self.amplitude_weighting:
684 | assert np.all(self._index[:-1] < self._index[1:])
685 | if not hasattr(self, 'amplitude_weight'):
686 | self.amplitude_weight = get_amplitude_weight(
687 | self.data[0][1], self._segment, self.framerate)
688 | if self.shuffle:
689 | random_state = self.random_state
690 | else:
691 | random_state = np.random.RandomState(677)
692 | self._index = random_state.choice(
693 | self._index, len(self._index), p=self.amplitude_weight)
694 | if self.amplitude_weighting_sort:
695 | logger.debug('Amplitude weighting sort')
696 | self._index.sort()
697 |
698 | if self.downsample:
699 | samples = np.arange(len(self._index))
700 | np.random.RandomState(667).shuffle(samples)
701 | assert self.downsample > 0 and self.downsample <= 1
702 | samples = samples[:int(np.round(len(samples) * self.downsample))]
703 | assert len(samples) > 0
704 | self._index = self._index[samples]
705 |
706 | if self.balance_gesture:
707 | num_sample_per_gesture = int(np.round(self.balance_gesture *
708 | len(self._index) / self.num_gesture))
709 | choice = []
710 | for gesture in set(self.gesture):
711 | mask = self._gesture[self._index] == gesture
712 | choice.append(self.random_state.choice(np.where(mask)[0],
713 | num_sample_per_gesture))
714 | choice = np.hstack(choice)
715 | self._index = self._index[choice]
716 |
717 | if self.shuffle:
718 | self._do_shuffle()
719 |
720 | if self.random_shift_horizontal or self.random_shift_vertical or self.random_scale or self.random_bad_channel:
721 | data = [(k, a.copy()) for k, a in self.data_orig]
722 | if self.random_shift_horizontal or self.random_shift_vertical:
723 | logger.info('shift {} {} {}',
724 | self.random_shift_fill,
725 | self.random_shift_horizontal,
726 | self.random_shift_vertical)
727 | hss = self.random_state.choice(1 + 2 * self.random_shift_horizontal,
728 | len(data[0][1])) - self.random_shift_horizontal
729 | vss = self.random_state.choice(1 + 2 * self.random_shift_vertical,
730 | len(data[0][1])) - self.random_shift_vertical
731 | # data = [(k, np.array([np.roll(row, s, axis=1) for row, s in izip(a, shift)]))
732 | # for k, a in data]
733 | data = [(k, np.array([_shift(row, hs, vs, self.random_shift_fill)
734 | for row, hs, vs in izip(a, hss, vss)]))
735 | for k, a in data]
736 | if self.random_scale:
737 | s = self.random_scale
738 | ss = s / 4
739 | data = [
740 | (k, a * 2 ** (self._rand(-s, s, (a.shape[0], 1, 1, 1)) + self._rand(-ss, ss, a.shape)))
741 | for k, a in data
742 | ]
743 | if self.random_bad_channel:
744 | mask = self.random_state.choice(2, len(data[0][1])) > 0
745 | if mask.sum():
746 | ch = self.random_state.choice(np.prod(data[0][1].shape[2:]), mask.sum())
747 | row = ch // data[0][1].shape[3]
748 | col = ch % data[0][1].shape[3]
749 | val = self.random_state.choice(self.random_bad_channel, mask.sum())
750 | val = np.tile(val.reshape(-1, 1), (1, data[0][1].shape[1]))
751 | for k, a in data:
752 | a[mask, :, row, col] = val
753 | self.data = data
754 |
755 | self.num_data = len(self._index)
756 |
757 |
758 | def _shift(a, hs, vs, fill):
759 | if fill == 'zero':
760 | b = np.zeros(a.shape, dtype=a.dtype)
761 | elif fill == 'margin':
762 | b = np.empty(a.shape, dtype=a.dtype)
763 | else:
764 | assert False, 'Unknown fill type: {}'.format(fill)
765 |
766 | s = a.shape
767 | if hs < 0:
768 | shb, she = -hs, s[2]
769 | thb, the = 0, s[2] + hs
770 | else:
771 | shb, she = 0, s[2] - hs
772 | thb, the = hs, s[2]
773 | if vs < 0:
774 | svb, sve = -vs, s[1]
775 | tvb, tve = 0, s[1] + vs
776 | else:
777 | svb, sve = 0, s[1] - vs
778 | tvb, tve = vs, s[1]
779 | b[:, tvb:tve, thb:the] = a[:, svb:sve, shb:she]
780 |
781 | if fill == 'margin':
782 | # Corners
783 | b[:, :tvb, :thb] = b[:, tvb, thb]
784 | b[:, tve:, :thb] = b[:, tve - 1, thb]
785 | b[:, tve:, the:] = b[:, tve - 1, the - 1]
786 | b[:, :tvb, the:] = b[:, tvb, the - 1]
787 | # Borders
788 | b[:, :tvb, thb:the] = b[:, tvb:tvb + 1, thb:the]
789 | b[:, tvb:tve, :thb] = b[:, tvb:tve, thb:thb + 1]
790 | b[:, tve:, thb:the] = b[:, tve - 1:tve, thb:the]
791 | b[:, tvb:tve, the:] = b[:, tvb:tve, the - 1:the]
792 |
793 | return b
794 |
795 |
796 | def _get_fft_aux(data, append):
797 | from ..fft import fft
798 | _, amp = fft(data.reshape(data.shape[0], -1).transpose(), 1000)
799 | amp = amp.transpose().reshape(-1, *data.shape[1:])
800 | return amp if not append else np.concatenate([data, amp], axis=0)
801 |
802 |
803 | def get_amplitude_weight(data, segment, framerate):
804 | from .. import Context
805 | import joblib as jb
806 | indices = [np.where(segment == i)[0] for i in set(segment)]
807 | w = np.empty(len(segment), dtype=np.float64)
808 | for i, ret in zip(
809 | indices,
810 | Context.parallel(jb.delayed(get_amplitude_weight_aux)(data[i], framerate)
811 | for i in indices)
812 | ):
813 | w[i] = ret
814 | return w / max(w.sum(), 1e-8)
815 |
816 |
817 | def get_amplitude_weight_aux(data, framerate):
818 | return _get_amplitude_weight_aux(data, framerate)
819 |
820 |
821 | @utils.cached
822 | def _get_amplitude_weight_aux(data, framerate):
823 | # High-Density Electromyography and Motor Skill Learning for Robust Long-Term Control of a 7-DoF Robot Arm
824 | lowpass = utils.butter_lowpass_filter
825 | shape = data.shape
826 | data = np.abs(data.reshape(shape[0], -1))
827 | data = np.transpose([lowpass(ch, 3, framerate, 4, zero_phase=True) for ch in data.T])
828 | data = data.mean(axis=1)
829 | data -= data.min()
830 | data /= max(data.max(), 1e-8)
831 | return data
832 |
833 |
834 | from .preprocess import Preprocess
835 | from . import capgmyo, ninapro
836 | assert capgmyo and ninapro
837 |
838 |
839 | __all__ = ['Dataset', 'Preprocess', 'get_data']
840 |
--------------------------------------------------------------------------------
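A minimal sketch of how the _shift helper defined in sigr/data/__init__.py above behaves; the toy array, shift amounts, and values are illustrative only, and the calls assume they run inside that module where _shift is in scope:

import numpy as np

# toy frame: one time step, 3 rows x 3 columns
a = np.arange(9, dtype=np.float32).reshape(1, 3, 3)

# shift one column to the right; the vacated column is filled with zeros
shifted_zero = _shift(a, hs=1, vs=0, fill='zero')

# same shift with 'margin' fill: the vacated column repeats the nearest kept column
shifted_margin = _shift(a, hs=1, vs=0, fill='margin')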
/sigr/data/capgmyo/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | from itertools import product
4 | import numpy as np
5 | import scipy.io as sio
6 | from logbook import Logger
7 | from ... import utils, CACHE
8 | from .. import Dataset as Base, Combo, Trial, SingleSessionMixin
9 |
10 |
11 | TRIALS = list(range(1, 11))
12 | NUM_TRIAL = len(TRIALS)
13 | NUM_SEMG_ROW = 16
14 | NUM_SEMG_COL = 8
15 | FRAMERATE = 1000
16 | PREPROCESS_KARGS = dict(
17 | framerate=FRAMERATE,
18 | num_semg_row=NUM_SEMG_ROW,
19 | num_semg_col=NUM_SEMG_COL
20 | )
21 |
22 | logger = Logger(__name__)
23 |
24 |
25 | class GetTrial(object):
26 |
27 | def __init__(self, gestures, trials, preprocess=None):
28 | self.preprocess = preprocess
29 | self.memo = {}
30 | self.gesture_and_trials = list(product(gestures, trials))
31 |
32 | def get_path(self, root, combo):
33 | return os.path.join(
34 | root,
35 | '{c.subject:03d}-{c.gesture:03d}-{c.trial:03d}.mat'.format(c=combo))
36 |
37 | def __call__(self, root, combo):
38 | path = self.get_path(root, combo)
39 | if path not in self.memo:
40 | logger.debug('Load subject {}', combo.subject)
41 | paths = [self.get_path(root, Combo(combo.subject, gesture, trial))
42 | for gesture, trial in self.gesture_and_trials]
43 | self.memo.update({path: data for path, data in
44 | zip(paths, _get_data(paths, self.preprocess))})
45 | data = self.memo[path]
46 | data = data.copy()
47 | gesture = np.repeat(combo.gesture, len(data))
48 | subject = np.repeat(combo.subject, len(data))
49 | return Trial(data=data, gesture=gesture, subject=subject)
50 |
51 |
52 | @utils.cached
53 | def _get_data(paths, preprocess):
54 | # return list(Context.parallel(
55 | # jb.delayed(_get_data_aux)(path, preprocess) for path in paths))
56 | return [_get_data_aux(path, preprocess) for path in paths]
57 |
58 |
59 | def _get_data_aux(path, preprocess):
60 | data = sio.loadmat(path)['data'].astype(np.float32)
61 | if preprocess:
62 | data = preprocess(data, **PREPROCESS_KARGS)
63 | return data
64 |
65 |
66 | class Dataset(SingleSessionMixin, Base):
67 |
68 | framerate = FRAMERATE
69 | num_semg_row = NUM_SEMG_ROW
70 | num_semg_col = NUM_SEMG_COL
71 | trials = TRIALS
72 |
73 | def __init__(self, root):
74 | self.root = root
75 |
76 | def get_trial_func(self, *args, **kargs):
77 | return GetTrial(*args, **kargs)
78 |
79 | @classmethod
80 | def parse(cls, text):
81 | if cls is not Dataset and text == cls.name:
82 | return cls(root=os.path.join(CACHE, cls.name.split('/')[0]))
83 |
84 |
85 | from . import dba, dbb, dbc
86 | assert dba and dbb and dbc
87 |
--------------------------------------------------------------------------------
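A brief sketch of the file layout the CapgMyo GetTrial above expects; the root path is a placeholder, Combo is the (subject, gesture, trial) container re-exported by sigr.data, and the sigr package is assumed to be importable:

from sigr.data import Combo
from sigr.data.capgmyo import GetTrial

get_trial = GetTrial(gestures=range(1, 9), trials=range(1, 11))
# subject 2, gesture 5, trial 1 maps to a flat file name under the dataset root
print(get_trial.get_path('/path/to/capgmyo/dba', Combo(2, 5, 1)))
# -> /path/to/capgmyo/dba/002-005-001.mat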
/sigr/data/capgmyo/dba.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from . import Dataset as Base
3 |
4 |
5 | class Dataset(Base):
6 |
7 | name = 'dba'
8 | subjects = list(range(1, 19))
9 | gestures = list(range(1, 9))
10 |
--------------------------------------------------------------------------------
/sigr/data/capgmyo/dbb.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from functools import partial
3 | from itertools import product
4 | from logbook import Logger
5 | from . import Dataset as Base
6 | from .. import get_data
7 | from ... import constant
8 |
9 |
10 | logger = Logger(__name__)
11 |
12 |
13 | class Dataset(Base):
14 |
15 | name = 'dbb'
16 | subjects = list(range(2, 21, 2))
17 | gestures = list(range(1, 9))
18 | num_session = 2
19 | sessions = [1, 2]
20 |
21 | def get_universal_inter_session_data(self, fold, batch_size, preprocess, adabn, minibatch, balance_gesture, **kargs):
22 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess)
23 | load = partial(get_data,
24 | framerate=self.framerate,
25 | root=self.root,
26 | last_batch_handle='pad',
27 | get_trial=get_trial,
28 | batch_size=batch_size,
29 | num_semg_row=self.num_semg_row,
30 | num_semg_col=self.num_semg_col)
31 | session = fold + 1
32 | subjects = list(range(1, 11))
33 | num_subject = 10
34 | train = load(combos=self.get_combos(product([self.encode_subject_and_session(s, i) for s, i in
35 | product(subjects, [i for i in self.sessions if i != session])],
36 | self.gestures, self.trials)),
37 | adabn=adabn,
38 | mini_batch_size=batch_size // (num_subject * (self.num_session - 1) if minibatch else 1),
39 | balance_gesture=balance_gesture,
40 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL),
41 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0),
42 | random_shift_vertical=kargs.get('random_shift_vertical', 0),
43 | shuffle=True)
44 | logger.debug('Training set loaded')
45 | val = load(combos=self.get_combos(product([self.encode_subject_and_session(s, session) for s in subjects],
46 | self.gestures, self.trials)),
47 | adabn=adabn,
48 | mini_batch_size=batch_size // (num_subject if minibatch else 1),
49 | shuffle=False)
50 | logger.debug('Test set loaded')
51 | return train, val
52 |
53 | def get_inter_session_data(self, fold, batch_size, preprocess, adabn, minibatch, balance_gesture, **kargs):
54 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess)
55 | load = partial(get_data,
56 | framerate=self.framerate,
57 | root=self.root,
58 | last_batch_handle='pad',
59 | get_trial=get_trial,
60 | batch_size=batch_size,
61 | num_semg_row=self.num_semg_row,
62 | num_semg_col=self.num_semg_col)
63 | subject = fold // self.num_session + 1
64 | session = fold % self.num_session + 1
65 | train = load(combos=self.get_combos(product([self.encode_subject_and_session(subject, i) for i in self.sessions if i != session],
66 | self.gestures, self.trials)),
67 | adabn=adabn,
68 | mini_batch_size=batch_size // (self.num_session - 1 if minibatch else 1),
69 | balance_gesture=balance_gesture,
70 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL),
71 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0),
72 | random_shift_vertical=kargs.get('random_shift_vertical', 0),
73 | shuffle=True)
74 | logger.debug('Training set loaded')
75 | val = load(combos=self.get_combos(product([self.encode_subject_and_session(subject, session)],
76 | self.gestures, self.trials)),
77 | shuffle=False)
78 | logger.debug('Test set loaded')
79 | return train, val
80 |
81 | def get_inter_session_val(self, fold, batch_size, preprocess=None, **kargs):
82 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess)
83 | load = partial(get_data,
84 | framerate=self.framerate,
85 | root=self.root,
86 | last_batch_handle='pad',
87 | get_trial=get_trial,
88 | batch_size=batch_size,
89 | num_semg_row=self.num_semg_row,
90 | num_semg_col=self.num_semg_col)
91 | subject = fold // self.num_session + 1
92 | session = fold % self.num_session + 1
93 | val = load(combos=self.get_combos(product([self.encode_subject_and_session(subject, session)],
94 | self.gestures, self.trials)),
95 | shuffle=False)
96 | logger.debug('Test set loaded')
97 | return val
98 |
99 | def encode_subject_and_session(self, subject, session):
100 | return (subject - 1) * self.num_session + session
101 |
--------------------------------------------------------------------------------
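A small worked example of the fold arithmetic used by get_inter_session_data and encode_subject_and_session above; dbb records each subject in two sessions, and the loop below is illustrative only:

num_session = 2
for fold in range(4):
    subject = fold // num_session + 1
    session = fold % num_session + 1
    encoded = (subject - 1) * num_session + session
    print('fold %d -> subject %d, session %d, encoded id %d' % (fold, subject, session, encoded))
# fold 0 -> subject 1, session 1, encoded id 1
# fold 3 -> subject 2, session 2, encoded id 4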
/sigr/data/capgmyo/dbc.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from . import Dataset as Base
3 |
4 |
5 | class Dataset(Base):
6 |
7 | name = 'dbc'
8 | subjects = list(range(1, 11))
9 | gestures = list(range(1, 13))
10 |
--------------------------------------------------------------------------------
/sigr/data/csl.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from . import get_data, Combo, Trial
3 | from .. import ROOT, Context, constant
4 | import os
5 | from itertools import product
6 | from functools import partial
7 | import scipy.io as sio
8 | import numpy as np
9 | from logbook import Logger
10 | import joblib as jb
11 | from ..utils import cached
12 | from nose.tools import assert_is_not_none
13 |
14 |
15 | ROOT = os.path.join(ROOT, '.cache/csl')
16 | NUM_TRIAL = 10
17 | SUBJECTS = list(range(1, 6))
18 | SESSIONS = list(range(1, 6))
19 | NUM_SESSION = len(SESSIONS)
20 | NUM_SUBJECT = len(SUBJECTS)
21 | NUM_SUBJECT_AND_SESSION = len(SUBJECTS) * NUM_SESSION
22 | SUBJECT_AND_SESSIONS = list(range(1, NUM_SUBJECT_AND_SESSION + 1))
23 | GESTURES = list(range(27))
24 | REST_TRIALS = [x - 1 for x in [2, 4, 7, 8, 11, 13, 19, 25, 26, 30]]
25 | NUM_SEMG_ROW = 24
26 | NUM_SEMG_COL = 7
27 | FRAMERATE = 2048
28 | framerate = FRAMERATE
29 | TRIALS = list(range(NUM_TRIAL))
30 | PREPROCESS_KARGS = dict(
31 | framerate=FRAMERATE,
32 | num_semg_row=NUM_SEMG_ROW,
33 | num_semg_col=NUM_SEMG_COL
34 | )
35 |
36 | logger = Logger('csl')
37 |
38 |
39 | def get_general_data(batch_size, adabn, minibatch, downsample, **kargs):
40 | get_trial = GetTrial(downsample=downsample)
41 | load = partial(get_data,
42 | framerate=FRAMERATE,
43 | root=ROOT,
44 | last_batch_handle='pad',
45 | get_trial=get_trial,
46 | batch_size=batch_size,
47 | num_semg_row=NUM_SEMG_ROW,
48 | num_semg_col=NUM_SEMG_COL)
49 | train = load(combos=get_combos(product(SUBJECT_AND_SESSIONS, GESTURES[1:], range(0, NUM_TRIAL, 2)),
50 | product(SUBJECT_AND_SESSIONS, GESTURES[:1], REST_TRIALS[0::2])),
51 | adabn=adabn,
52 | shuffle=True,
53 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL),
54 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0),
55 | random_shift_vertical=kargs.get('random_shift_vertical', 0),
56 | mini_batch_size=batch_size // (NUM_SUBJECT_AND_SESSION if minibatch else 1))
57 | logger.debug('Training set loaded')
58 | val = load(combos=get_combos(product(SUBJECT_AND_SESSIONS, GESTURES[1:], range(1, NUM_TRIAL, 2)),
59 | product(SUBJECT_AND_SESSIONS, GESTURES[:1], REST_TRIALS[1::2])),
60 | shuffle=False)
61 | logger.debug('Test set loaded')
62 | return train, val
63 |
64 |
65 | def get_intra_session_val(fold, batch_size, preprocess, **kargs):
66 | get_trial = GetTrial(preprocess=preprocess)
67 | load = partial(get_data,
68 | amplitude_weighting=kargs.get('amplitude_weighting', False),
69 | amplitude_weighting_sort=kargs.get('amplitude_weighting_sort', False),
70 | framerate=FRAMERATE,
71 | root=ROOT,
72 | last_batch_handle='pad',
73 | get_trial=get_trial,
74 | batch_size=batch_size,
75 | num_semg_row=NUM_SEMG_ROW,
76 | num_semg_col=NUM_SEMG_COL,
77 | random_state=np.random.RandomState(42))
78 | subject = fold // (NUM_SESSION * NUM_TRIAL) + 1
79 | session = fold // NUM_TRIAL % NUM_SESSION + 1
80 | fold = fold % NUM_TRIAL
81 | val = load(combos=get_combos(product([encode_subject_and_session(subject, session)],
82 | GESTURES[1:], [fold]),
83 | product([encode_subject_and_session(subject, session)],
84 | GESTURES[:1], REST_TRIALS[fold:fold + 1])),
85 | shuffle=False)
86 | return val
87 |
88 |
89 | def get_universal_intra_session_data(fold, batch_size, preprocess, balance_gesture, **kargs):
90 | get_trial = GetTrial(preprocess=preprocess)
91 | load = partial(get_data,
92 | amplitude_weighting=kargs.get('amplitude_weighting', False),
93 | amplitude_weighting_sort=kargs.get('amplitude_weighting_sort', False),
94 | framerate=FRAMERATE,
95 | root=ROOT,
96 | last_batch_handle='pad',
97 | get_trial=get_trial,
98 | batch_size=batch_size,
99 | num_semg_row=NUM_SEMG_ROW,
100 | num_semg_col=NUM_SEMG_COL)
101 | trial = fold
102 | train = load(combos=get_combos(product(SUBJECT_AND_SESSIONS,
103 | GESTURES[1:], [i for i in range(NUM_TRIAL) if i != trial]),
104 | product(SUBJECT_AND_SESSIONS,
105 | GESTURES[:1], [REST_TRIALS[i] for i in range(NUM_TRIAL) if i != trial])),
106 | balance_gesture=balance_gesture,
107 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL),
108 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0),
109 | random_shift_vertical=kargs.get('random_shift_vertical', 0),
110 | shuffle=True)
111 | assert_is_not_none(train)
112 | logger.debug('Training set loaded')
113 | val = load(combos=get_combos(product(SUBJECT_AND_SESSIONS,
114 | GESTURES[1:], [trial]),
115 | product(SUBJECT_AND_SESSIONS,
116 | GESTURES[:1], REST_TRIALS[trial:trial + 1])),
117 | shuffle=False)
118 | logger.debug('Test set loaded')
119 | assert_is_not_none(val)
120 | return train, val
121 |
122 |
123 | def get_intra_session_data(fold, batch_size, preprocess, balance_gesture, **kargs):
124 | get_trial = GetTrial(preprocess=preprocess)
125 | load = partial(get_data,
126 | amplitude_weighting=kargs.get('amplitude_weighting', False),
127 | amplitude_weighting_sort=kargs.get('amplitude_weighting_sort', False),
128 | framerate=FRAMERATE,
129 | root=ROOT,
130 | last_batch_handle='pad',
131 | get_trial=get_trial,
132 | batch_size=batch_size,
133 | num_semg_row=NUM_SEMG_ROW,
134 | num_semg_col=NUM_SEMG_COL)
135 | subject = fold // (NUM_SESSION * NUM_TRIAL) + 1
136 | session = fold // NUM_TRIAL % NUM_SESSION + 1
137 | fold = fold % NUM_TRIAL
138 | train = load(combos=get_combos(product([encode_subject_and_session(subject, session)],
139 | GESTURES[1:], [f for f in range(NUM_TRIAL) if f != fold]),
140 | product([encode_subject_and_session(subject, session)],
141 | GESTURES[:1], [REST_TRIALS[f] for f in range(NUM_TRIAL) if f != fold])),
142 | balance_gesture=balance_gesture,
143 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL),
144 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0),
145 | random_shift_vertical=kargs.get('random_shift_vertical', 0),
146 | shuffle=True)
147 | assert_is_not_none(train)
148 | logger.debug('Training set loaded')
149 | val = load(combos=get_combos(product([encode_subject_and_session(subject, session)],
150 | GESTURES[1:], [fold]),
151 | product([encode_subject_and_session(subject, session)],
152 | GESTURES[:1], REST_TRIALS[fold:fold + 1])),
153 | shuffle=False)
154 | logger.debug('Test set loaded')
155 | assert_is_not_none(val)
156 | return train, val
157 |
158 |
159 | def get_inter_session_data(fold, batch_size, preprocess, adabn, minibatch, balance_gesture, **kargs):
160 | # TODO: calib
161 | get_trial = GetTrial(preprocess=preprocess)
162 | load = partial(get_data,
163 | framerate=FRAMERATE,
164 | root=ROOT,
165 | last_batch_handle='pad',
166 | get_trial=get_trial,
167 | batch_size=batch_size,
168 | num_semg_row=NUM_SEMG_ROW,
169 | num_semg_col=NUM_SEMG_COL)
170 | subject = fold // NUM_SESSION + 1
171 | session = fold % NUM_SESSION + 1
172 | train = load(combos=get_combos(product([encode_subject_and_session(subject, i) for i in SESSIONS if i != session],
173 | GESTURES[1:], TRIALS),
174 | product([encode_subject_and_session(subject, i) for i in SESSIONS if i != session],
175 | GESTURES[:1], REST_TRIALS)),
176 | adabn=adabn,
177 | mini_batch_size=batch_size // (NUM_SESSION - 1 if minibatch else 1),
178 | balance_gesture=balance_gesture,
179 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL),
180 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0),
181 | random_shift_vertical=kargs.get('random_shift_vertical', 0),
182 | shuffle=True)
183 | logger.debug('Training set loaded')
184 | val = load(combos=get_combos(product([encode_subject_and_session(subject, session)],
185 | GESTURES[1:], TRIALS),
186 | product([encode_subject_and_session(subject, session)],
187 | GESTURES[:1], REST_TRIALS)),
188 | shuffle=False)
189 | logger.debug('Test set loaded')
190 | return train, val
191 |
192 |
193 | def get_inter_session_val(fold, batch_size, preprocess, **kargs):
194 | # TODO: calib
195 | get_trial = GetTrial(preprocess=preprocess)
196 | load = partial(get_data,
197 | framerate=FRAMERATE,
198 | root=ROOT,
199 | last_batch_handle='pad',
200 | get_trial=get_trial,
201 | batch_size=batch_size,
202 | num_semg_row=NUM_SEMG_ROW,
203 | num_semg_col=NUM_SEMG_COL,
204 | random_state=np.random.RandomState(42))
205 | subject = fold // NUM_SESSION + 1
206 | session = fold % NUM_SESSION + 1
207 | val = load(combos=get_combos(product([encode_subject_and_session(subject, session)],
208 | GESTURES[1:], TRIALS),
209 | product([encode_subject_and_session(subject, session)],
210 | GESTURES[:1], REST_TRIALS)),
211 | shuffle=False)
212 | return val
213 |
214 |
215 | def get_universal_inter_session_data(fold, batch_size, preprocess, adabn, minibatch, balance_gesture, **kargs):
216 | # TODO: calib
217 | get_trial = GetTrial(preprocess=preprocess)
218 | load = partial(get_data,
219 | framerate=FRAMERATE,
220 | root=ROOT,
221 | last_batch_handle='pad',
222 | get_trial=get_trial,
223 | batch_size=batch_size,
224 | num_semg_row=NUM_SEMG_ROW,
225 | num_semg_col=NUM_SEMG_COL)
226 | session = fold + 1
227 | train = load(combos=get_combos(product([encode_subject_and_session(s, i) for s, i in
228 | product(SUBJECTS, [i for i in SESSIONS if i != session])],
229 | GESTURES[1:], TRIALS),
230 | product([encode_subject_and_session(s, i) for s, i in
231 | product(SUBJECTS, [i for i in SESSIONS if i != session])],
232 | GESTURES[:1], REST_TRIALS)),
233 | adabn=adabn,
234 | mini_batch_size=batch_size // (NUM_SUBJECT * (NUM_SESSION - 1) if minibatch else 1),
235 | balance_gesture=balance_gesture,
236 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL),
237 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0),
238 | random_shift_vertical=kargs.get('random_shift_vertical', 0),
239 | shuffle=True)
240 | logger.debug('Training set loaded')
241 | val = load(combos=get_combos(product([encode_subject_and_session(s, session) for s in SUBJECTS],
242 | GESTURES[1:], TRIALS),
243 | product([encode_subject_and_session(s, session) for s in SUBJECTS],
244 | GESTURES[:1], REST_TRIALS)),
245 | adabn=adabn,
246 | mini_batch_size=batch_size // (NUM_SUBJECT if minibatch else 1),
247 | shuffle=False)
248 | logger.debug('Test set loaded')
249 | return train, val
250 |
251 |
252 | def get_intra_subject_data(fold, batch_size, cut, bandstop, adabn, minibatch, **kargs):
253 | get_trial = GetTrial(cut=cut, bandstop=bandstop)
254 | load = partial(get_data,
255 | framerate=FRAMERATE,
256 | root=ROOT,
257 | last_batch_handle='pad',
258 | get_trial=get_trial,
259 | batch_size=batch_size,
260 | num_semg_row=NUM_SEMG_ROW,
261 | num_semg_col=NUM_SEMG_COL)
262 | subject = fold // NUM_TRIAL + 1
263 | fold = fold % NUM_TRIAL
264 | train = load(combos=get_combos(product([encode_subject_and_session(subject, session) for session in SESSIONS],
265 | GESTURES[1:], [f for f in range(NUM_TRIAL) if f != fold]),
266 | product([encode_subject_and_session(subject, session) for session in SESSIONS],
267 | GESTURES[:1], [REST_TRIALS[f] for f in range(NUM_TRIAL) if f != fold])),
268 | adabn=adabn,
269 | mini_batch_size=batch_size // (NUM_SESSION if minibatch else 1),
270 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL),
271 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0),
272 | random_shift_vertical=kargs.get('random_shift_vertical', 0),
273 | shuffle=True)
274 | logger.debug('Training set loaded')
275 | val = load(combos=get_combos(product([encode_subject_and_session(subject, session) for session in SESSIONS],
276 | GESTURES[1:], [fold]),
277 | product([encode_subject_and_session(subject, session) for session in SESSIONS],
278 | GESTURES[:1], REST_TRIALS[fold:fold + 1])),
279 | shuffle=False)
280 | logger.debug('Test set loaded')
281 | return train, val
282 |
283 |
284 | class GetTrial(object):
285 |
286 | def __init__(self, preprocess=None):
287 | self.preprocess = preprocess
288 | self.memo = {}
289 |
290 | def __call__(self, root, combo):
291 | subject, session = decode_subject_and_session(combo.subject)
292 | path = os.path.join(root,
293 | 'subject%d' % subject,
294 | 'session%d' % session,
295 | 'gest%d.mat' % combo.gesture)
296 | if path not in self.memo:
297 | data = _get_data(path, self.preprocess)
298 | self.memo[path] = data
299 | logger.debug('{}', path)
300 | else:
301 | data = self.memo[path]
302 | assert combo.trial < len(data), str(combo)
303 | data = data[combo.trial].copy()
304 | gesture = np.repeat(combo.gesture, len(data))
305 | subject = np.repeat(combo.subject, len(data))
306 | return Trial(data=data, gesture=gesture, subject=subject)
307 |
308 |
309 | @cached
310 | def _get_data(path, preprocess):
311 | data = sio.loadmat(path)['gestures']
312 | data = [np.transpose(np.delete(segment.astype(np.float32), np.s_[7:192:8], 0))
313 | for segment in data.flat]
314 | if preprocess:
315 | data = list(Context.parallel(jb.delayed(preprocess)(segment, **PREPROCESS_KARGS)
316 | for segment in data))
317 | return data
318 |
319 |
320 | # @cached
321 | # def _get_data(path, bandstop, cut, downsample):
322 | # data = sio.loadmat(path)['gestures']
323 | # data = [np.transpose(np.delete(segment.astype(np.float32), np.s_[7:192:8], 0))
324 | # for segment in data.flat]
325 | # if bandstop:
326 | # data = list(Context.parallel(jb.delayed(get_bandstop)(segment) for segment in data))
327 | # if cut is not None:
328 | # data = list(Context.parallel(jb.delayed(cut)(segment, framerate=FRAMERATE) for segment in data))
329 | # if downsample > 1:
330 | # data = [segment[::downsample].copy() for segment in data]
331 | # return data
332 |
333 |
334 | def decode_subject_and_session(ss):
335 | return (ss - 1) // NUM_SESSION + 1, (ss - 1) % NUM_SESSION + 1
336 |
337 |
338 | def encode_subject_and_session(subject, session):
339 | return (subject - 1) * NUM_SESSION + session
340 |
341 |
342 | def get_bandstop(data):
343 | from ..utils import butter_bandstop_filter
344 | return np.array([butter_bandstop_filter(ch, 45, 55, 2048, 2) for ch in data])
345 |
346 |
347 | def get_combos(*args):
348 | for arg in args:
349 | if isinstance(arg, tuple):
350 | arg = [arg]
351 | for a in arg:
352 | combo = Combo(*a)
353 | if ignore_missing(combo):
354 | continue
355 | yield combo
356 |
357 |
358 | def ignore_missing(combo):
359 | return combo.subject == 19 and combo.gesture in (8, 9) and combo.trial == 9
360 |
--------------------------------------------------------------------------------
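A short sketch of the channel selection done by _get_data in csl.py above: every eighth electrode row (indices 7, 15, ..., 191) is dropped from the raw recording, leaving the 24 x 7 = 168 channels given by NUM_SEMG_ROW and NUM_SEMG_COL; the array shape and zero data below are placeholders:

import numpy as np

raw = np.zeros((192, 2048), dtype=np.float32)    # channels x samples, placeholder data
kept = np.transpose(np.delete(raw, np.s_[7:192:8], 0))
assert kept.shape == (2048, 168)                 # samples x (NUM_SEMG_ROW * NUM_SEMG_COL)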
/sigr/data/ninapro/__init__.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | from itertools import product
4 | import numpy as np
5 | import scipy.io as sio
6 | from logbook import Logger
7 | from ... import utils, CACHE
8 | from .. import Dataset as Base, Combo, Trial, SingleSessionMixin
9 |
10 |
11 | NUM_SEMG_ROW = 1
12 | NUM_SEMG_COL = 10
13 | FRAMERATE = 100
14 | PREPROCESS_KARGS = dict(
15 | framerate=FRAMERATE,
16 | num_semg_row=NUM_SEMG_ROW,
17 | num_semg_col=NUM_SEMG_COL
18 | )
19 |
20 | logger = Logger(__name__)
21 |
22 |
23 | class Dataset(SingleSessionMixin, Base):
24 |
25 | framerate = FRAMERATE
26 | num_semg_row = NUM_SEMG_ROW
27 | num_semg_col = NUM_SEMG_COL
28 | subjects = list(range(27))
29 | gestures = list(range(53))
30 | trials = list(range(10))
31 |
32 | def __init__(self, root):
33 | self.root = root
34 |
35 | def get_one_fold_intra_subject_trials(self):
36 | return [0, 2, 3, 5, 7, 8, 9], [1, 4, 6]
37 |
38 | def get_trial_func(self, *args, **kargs):
39 | return GetTrial(*args, **kargs)
40 |
41 | @classmethod
42 | def parse(cls, text):
43 | if cls is not Dataset and text == cls.name:
44 | return cls(root=os.path.join(CACHE, cls.name.split('/')[0], 'data'))
45 |
46 |
47 | class GetTrial(object):
48 |
49 | def __init__(self, gestures, trials, preprocess=None):
50 | self.preprocess = preprocess
51 | self.memo = {}
52 | self.gesture_and_trials = list(product(gestures, trials))
53 |
54 | def get_path(self, root, combo):
55 | return os.path.join(
56 | root,
57 | '{c.subject:03d}',
58 | '{c.gesture:03d}',
59 | '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}.mat').format(c=combo)
60 |
61 | def __call__(self, root, combo):
62 | path = self.get_path(root, combo)
63 | if path not in self.memo:
64 | logger.debug('Load subject {}', combo.subject)
65 | paths = [self.get_path(root, Combo(combo.subject, gesture, trial))
66 | for gesture, trial in self.gesture_and_trials]
67 | self.memo.update({path: data for path, data in
68 | zip(paths, _get_data(paths, self.preprocess))})
69 | data = self.memo[path]
70 | data = data.copy()
71 | gesture = np.repeat(combo.gesture, len(data))
72 | subject = np.repeat(combo.subject, len(data))
73 | return Trial(data=data, gesture=gesture, subject=subject)
74 |
75 |
76 | @utils.cached
77 | def _get_data(paths, preprocess):
78 | # return list(Context.parallel(
79 | # jb.delayed(_get_data_aux)(path, preprocess) for path in paths))
80 | return [_get_data_aux(path, preprocess) for path in paths]
81 |
82 |
83 | def _get_data_aux(path, preprocess):
84 | data = sio.loadmat(path)['data'].astype(np.float32)
85 | if preprocess:
86 | data = preprocess(data, **PREPROCESS_KARGS)
87 | return data
88 |
89 |
90 | from . import db1, db1_g53, db1_g5, db1_g8, db1_g12, caputo, db1_matlab_lowpass
91 | assert db1 and db1_g53 and db1_g5 and db1_g8 and db1_g12 and caputo and db1_matlab_lowpass
92 |
--------------------------------------------------------------------------------
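In contrast to the flat CapgMyo layout, the NinaPro GetTrial.get_path above nests files by subject and gesture; the root below is a placeholder and the sigr package is assumed to be importable:

from sigr.data import Combo
from sigr.data.ninapro import GetTrial

get_trial = GetTrial(gestures=range(1, 53), trials=range(10))
print(get_trial.get_path('/path/to/ninapro-db1/data', Combo(2, 5, 1)))
# -> /path/to/ninapro-db1/data/002/005/002_005_001.mat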
/sigr/data/ninapro/caputo.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from . import Dataset as Base
3 |
4 |
5 | class Dataset(Base):
6 |
7 | name = 'ninapro-db1/caputo'
8 | gestures = list(range(1, 53))
9 |
10 | def get_one_fold_intra_subject_trials(self):
11 | return [i - 1 for i in [1, 3, 4, 5, 9]], [i - 1 for i in [2, 6, 7, 8, 10]]
12 |
--------------------------------------------------------------------------------
/sigr/data/ninapro/db1.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from . import Dataset as Base
3 |
4 |
5 | class Dataset(Base):
6 |
7 | name = 'ninapro-db1'
8 | gestures = list(range(1, 53))
9 |
--------------------------------------------------------------------------------
/sigr/data/ninapro/db1_g12.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from . import Dataset as Base
3 |
4 |
5 | class Dataset(Base):
6 |
7 | name = 'ninapro-db1/g12'
8 | gestures = list(range(1, 13))
9 |
--------------------------------------------------------------------------------
/sigr/data/ninapro/db1_g5.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from . import Dataset as Base
3 |
4 |
5 | class Dataset(Base):
6 |
7 | name = 'ninapro-db1/g5'
8 | gestures = list(range(25, 30))
9 |
--------------------------------------------------------------------------------
/sigr/data/ninapro/db1_g53.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from . import Dataset as Base
3 |
4 |
5 | class Dataset(Base):
6 |
7 | name = 'ninapro-db1/g53'
8 | gestures = list(range(0, 53))
9 |
--------------------------------------------------------------------------------
/sigr/data/ninapro/db1_g8.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from . import Dataset as Base
3 |
4 |
5 | class Dataset(Base):
6 |
7 | name = 'ninapro-db1/g8'
8 | gestures = list(range(13, 21))
9 |
--------------------------------------------------------------------------------
/sigr/data/ninapro/db1_matlab_lowpass.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from . import Dataset as Base
3 |
4 |
5 | class Dataset(Base):
6 |
7 | name = 'ninapro-db1-matlab-lowpass'
8 | gestures = list(range(1, 53))
9 |
--------------------------------------------------------------------------------
/sigr/data/preprocess.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import re
3 | import numpy as np
4 | from nose.tools import assert_less_equal
5 | from ..utils import cached, butter_lowpass_filter as lowpass
6 | from scipy.ndimage.filters import median_filter
7 |
8 |
9 | class Preprocess(object):
10 |
11 | class __metaclass__(type):
12 |
13 | def __init__(cls, name, bases, fields):
14 | type.__init__(cls, name, bases, fields)
15 | if name != 'Preprocess':
16 | Preprocess.register(cls)
17 |
18 | impls = []
19 |
20 | def __call__(self, data, **kargs):
21 | return data
22 |
23 | @classmethod
24 | def parse(cls, text):
25 | if not text:
26 | return None
27 | if cls is Preprocess:
28 | for impl in cls.impls:
29 | inst = impl.parse(text)
30 | if inst is not None:
31 | return inst
32 |
33 | @classmethod
34 | def register(cls, impl):
35 | cls.impls.append(impl)
36 |
37 |
38 | class Sequence(Preprocess):
39 |
40 | @classmethod
41 | def parse(cls, text):
42 | matched = re.search(r'\((.+)\)', text)
43 | if matched:
44 | return cls([Preprocess.parse(stage) for stage
45 | in matched.group(1).split(',')])
46 |
47 | def __init__(self, stages):
48 | self.stages = stages
49 |
50 | def __call__(self, data, **kargs):
51 | for stage in self.stages:
52 | data = stage(data, **kargs)
53 | return data
54 |
55 | def __repr__(self):
56 | return 'Sequence(%s)' % ','.join(str(stage) for stage in self.stages)
57 |
58 |
59 | class Bandstop(Preprocess):
60 |
61 | @classmethod
62 | def parse(cls, text):
63 | if re.search('bandstop', text):
64 | return cls()
65 |
66 | def __call__(self, data, framerate, **kargs):
67 | from ..utils import butter_bandstop_filter as bandstop
68 | return np.transpose([bandstop(ch, 45, 55, framerate, 2) for ch in data.T])
69 |
70 | def __repr__(self):
71 | return 'Bandstop()'
72 |
73 |
74 | class CSLBandpass(Preprocess):
75 |
76 | @classmethod
77 | def parse(cls, text):
78 | if re.search('csl-bandpass', text):
79 | return cls()
80 |
81 | def __call__(self, data, framerate, **kargs):
82 | from ..utils import butter_bandpass_filter as bandpass
83 | return np.transpose([bandpass(ch, 20, 400, framerate, 4) for ch in data.T])
84 |
85 | def __repr__(self):
86 | return 'CSLBandpass()'
87 |
88 |
89 | class NinaProLowpass(Preprocess):
90 |
91 | @classmethod
92 | def parse(cls, text):
93 | if re.search('ninapro-lowpass', text):
94 | return cls()
95 |
96 | def __call__(self, data, framerate, **kargs):
97 | return np.transpose([lowpass(ch, 1, framerate, 1, zero_phase=True) for ch in data.T])
98 |
99 | def __repr__(self):
100 | return 'NinaProLowpass()'
101 |
102 |
103 | class Downsample(Preprocess):
104 |
105 | @classmethod
106 | def parse(cls, text):
107 | matched = re.search(r'downsample-(\d+)', text)
108 | if matched:
109 | return cls(int(matched.group(1)))
110 |
111 | def __init__(self, step):
112 | self.step = step
113 |
114 | def __call__(self, data, **kargs):
115 | return data[::self.step].copy()
116 |
117 | def __repr__(self):
118 | return 'Downsample(step=%d)' % self.step
119 |
120 |
121 | class Median3x3(Preprocess):
122 |
123 | @classmethod
124 | def parse(cls, text):
125 | if re.search('median', text):
126 | return cls()
127 |
128 | def __call__(self, data, num_semg_row, num_semg_col, **kargs):
129 | return np.array([median_filter(image, 3).ravel() for image
130 | in data.reshape(-1, num_semg_row, num_semg_col)])
131 |
132 | def __repr__(self):
133 | return 'Median3x3()'
134 |
135 |
136 | class Abs(Preprocess):
137 |
138 | @classmethod
139 | def parse(cls, text):
140 | if re.search('abs', text):
141 | return cls()
142 |
143 | def __call__(self, data, **kargs):
144 | return np.abs(data)
145 |
146 | def __repr__(self):
147 | return 'Abs()'
148 |
149 |
150 | class RMS(Preprocess):
151 |
152 | @classmethod
153 | def parse(cls, text):
154 | matched = re.search(r'rms-(\d+)', text)
155 | if matched:
156 | return cls(int(matched.group(1)))
157 |
158 | def __init__(self, window):
159 | self.window = window
160 |
161 | def __call__(self, data, **kargs):
162 | window = min(self.window, len(data))
163 | return np.transpose([moving_rms(ch, window) for ch in data.T])
164 |
165 | def __repr__(self):
166 | return 'RMS(window=%d)' % self.window
167 |
168 |
169 | class Cut(Preprocess):
170 | pass
171 |
172 |
173 | class MiddleCut(Cut):
174 |
175 | @classmethod
176 | def parse(cls, text):
177 | matched = re.search(r'mid-(\d+)', text)
178 | if matched:
179 | return cls(int(matched.group(1)))
180 |
181 | def __init__(self, window):
182 | self.window = window
183 |
184 | def __call__(self, data, **kargs):
185 | if len(data) < self.window:
186 | return data
187 | begin = (len(data) - self.window) // 2
188 | return data[begin:begin + self.window].copy()
189 |
190 | def __repr__(self):
191 | return 'MiddleCut(window=%d)' % self.window
192 |
193 |
194 | class PeakCut(Cut):
195 |
196 | @classmethod
197 | def parse(cls, text):
198 | matched = re.search(r'^peak-(\d+)$', text)
199 | if matched:
200 | return cls(int(matched.group(1)))
201 |
202 | def __init__(self, window):
203 | self.window = window
204 |
205 | def __call__(self, data, framerate, num_semg_row, num_semg_col, **kargs):
206 | if len(data) < self.window:
207 | return data
208 |
209 | begin = np.argmax(_get_amp(data, framerate, num_semg_row, num_semg_col)
210 | [self.window // 2:-(self.window - self.window // 2 - 1)])
211 | assert_less_equal(begin + self.window, len(data))
212 | return data[begin:begin + self.window]
213 |
214 | def __repr__(self):
215 | return 'PeakCut(window=%d)' % self.window
216 |
217 |
218 | class NinaProPeakCut(Cut):
219 |
220 | @classmethod
221 | def parse(cls, text):
222 | matched = re.search(r'^ninapro-peak-(\d+)$', text)
223 | if matched:
224 | return cls(int(matched.group(1)))
225 |
226 | def __init__(self, window):
227 | self.window = window
228 |
229 | def __call__(self, data, framerate, **kargs):
230 | if len(data) < self.window:
231 | return data
232 |
233 | begin = np.argmax(_get_ninapro_amp(data, framerate)
234 | [self.window // 2:-(self.window - self.window // 2 - 1)])
235 | assert_less_equal(begin + self.window, len(data))
236 | return data[begin:begin + self.window]
237 |
238 | def __repr__(self):
239 | return 'NinaProPeakCut(window=%d)' % self.window
240 |
241 |
242 | class CSLCut(Cut):
243 |
244 | @classmethod
245 | def parse(cls, text):
246 | if re.search('csl-cut', text):
247 | return cls()
248 |
249 | def __call__(self, data, framerate, **kargs):
250 | begin, end = _csl_cut(data, framerate)
251 | return data[begin:end]
252 |
253 | def __repr__(self):
254 | return 'CSLCut()'
255 |
256 |
257 | def _csl_cut(data, framerate):
258 | window = int(np.round(150 * framerate / 2048))
259 | data = data[:len(data) // window * window].reshape(-1, window, data.shape[1])
260 | rms = np.sqrt(np.mean(np.square(data), axis=1))
261 | rms = [median_filter(image, 3).ravel() for image in rms.reshape(-1, 24, 7)]
262 | rms = np.mean(rms, axis=1)
263 | threshold = np.mean(rms)
264 | mask = rms > threshold
265 | for i in range(1, len(mask) - 1):
266 | if not mask[i] and mask[i - 1] and mask[i + 1]:
267 | mask[i] = True
268 | from .. import utils
269 | begin, end = max(utils.continuous_segments(mask),
270 | key=lambda s: (mask[s[0]], s[1] - s[0]))
271 | return begin * window, end * window
272 |
273 |
274 | @cached
275 | def _get_amp(data, framerate, num_semg_row, num_semg_col):
276 | data = np.abs(data)
277 | data = np.transpose([lowpass(ch, 2, framerate, 4, zero_phase=True) for ch in data.T])
278 | return [median_filter(image, 3).mean() for image in data.reshape(-1, num_semg_row, num_semg_col)]
279 |
280 |
281 | def _get_ninapro_amp(data, framerate):
282 | data = np.abs(data)
283 | data = np.transpose([lowpass(ch, 2, framerate, 4, zero_phase=True) for ch in data.T])
284 | return data.mean(axis=1)
285 |
286 |
287 | def moving_rms(a, window):
288 | a2 = np.square(a)
289 | window = np.ones(window) / window
290 | return np.sqrt(np.convolve(a2, window, 'valid'))
291 |
--------------------------------------------------------------------------------
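An illustrative sketch of the text grammar handled by Preprocess.parse above: a parenthesised, comma-separated list becomes a Sequence, and each token is matched against the regular expression of one registered stage; registration relies on the Python 2 __metaclass__ hook in this module, so the example assumes a Python 2 interpreter:

from sigr.data.preprocess import Preprocess

pipeline = Preprocess.parse('(csl-bandpass,csl-cut,median)')
print(pipeline)
# -> Sequence(CSLBandpass(),CSLCut(),Median3x3())
# the stages are applied left to right when the returned object is called on data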
/sigr/data/s21.py:
--------------------------------------------------------------------------------
1 | from itertools import product, starmap
2 | from . import get_data, Combo
3 | from .. import ROOT
4 | import os
5 | import numpy as np
6 |
7 |
8 | ROOT = os.path.join(ROOT, '.cache/mat.s21.bandstop-45-55.s1000m.scale-01')
9 |
10 |
11 | def get_coral(folds, batch_size):
12 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20]
13 | return get_data(
14 | root=ROOT,
15 | # combos=get_combos(product([subjects[fold] for fold in folds], [100, 101], [0])),
16 | combos=get_combos(product([subjects[fold] for fold in folds], range(1, 9), [0])),
17 | mean=0.5,
18 | scale=2,
19 | batch_size=2000,
20 | last_batch_handle='pad',
21 | shuffle=False,
22 | adabn=True
23 | )
24 |
25 |
26 | def get_combos(prods):
27 | return list(starmap(Combo, prods))
28 |
29 |
30 | def get_stats():
31 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20]
32 | load = lambda subject: get_data(
33 | root=ROOT,
34 | combos=get_combos(product([subject], range(1, 9), range(10))),
35 | mean=0.5,
36 | scale=2,
37 | batch_size=1000,
38 | last_batch_handle='roll_over'
39 | )
40 | stats = []
41 | for subject in subjects:
42 | batch = next(load(subject)[0])
43 | data = batch.data[0].asnumpy()
44 | stats.append({
45 | 'std': data.std()
46 | })
47 | import pandas as pd
48 | return pd.DataFrame(stats, index=range(10))
49 |
50 |
51 | def get_general_data(root, batch_size, with_subject):
52 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20]
53 | load = lambda **kargs: get_data(
54 | root=root,
55 | mean=0.5,
56 | scale=2,
57 | with_subject=with_subject,
58 | batch_size=batch_size,
59 | last_batch_handle='roll_over',
60 | **kargs
61 | )
62 | val, num_val = load(combos=get_combos(product(subjects, range(1, 9), range(1, 10, 2))))
63 | train, num_train = load(combos=get_combos(product(subjects, range(1, 9), range(0, 10, 2))))
64 | return train, val, num_train, num_val
65 |
66 |
67 | def get_inter_subject_data(
68 | root,
69 | fold,
70 | batch_size,
71 | maxforce,
72 | target_binary,
73 | calib,
74 | with_subject,
75 | with_target_gesture,
76 | random_scale,
77 | random_bad_channel,
78 | shuffle,
79 | adabn,
80 | window,
81 | only_calib,
82 | soft_label,
83 | minibatch,
84 | fft,
85 | fft_append,
86 | dual_stream,
87 | lstm,
88 | dense_window,
89 | lstm_window
90 | ):
91 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20]
92 |
93 | num_subject = 10 if maxforce or calib else 9
94 | if minibatch:
95 | assert batch_size % num_subject == 0, '%d %% %d' % (batch_size, num_subject)
96 | mini_batch_size = batch_size // num_subject
97 | else:
98 | mini_batch_size = batch_size
99 |
100 | load = lambda **kargs: get_data(
101 | root=root,
102 | mean=0.5,
103 | scale=2,
104 | with_subject=with_subject,
105 | target_binary=target_binary,
106 | batch_size=batch_size,
107 | with_target_gesture=with_target_gesture,
108 | fft=fft,
109 | fft_append=fft_append,
110 | dual_stream=dual_stream,
111 | **kargs
112 | )
113 | val_subject = subjects[fold]
114 | del subjects[fold]
115 | val = load(
116 | combos=get_combos(product([val_subject], range(1, 9), range(1, 10) if calib else range(10))),
117 | last_batch_handle='pad',
118 | shuffle=False,
119 | window=(window // (lstm_window or window)) if lstm else window,
120 | num_ignore_per_segment=window - 1 if lstm else 0,
121 | dense_window=dense_window
122 | )
123 |
124 | if maxforce and calib:
125 | target_combos = get_combos(product([val_subject], list(range(1, 9)) * 10 + [100, 101], [0] * (9 if target_binary else 1)))
126 | elif maxforce:
127 | target_combos = get_combos(product([val_subject], [100, 101], [0] * 41 * (9 if target_binary else 1)))
128 | elif only_calib:
129 | target_combos = get_combos(product([val_subject], list(range(1, 9)), [0]))
130 | elif calib:
131 | target_combos = get_combos(product([val_subject], list(range(1, 9)) * 10, [0] * (9 if target_binary else 1)))
132 | else:
133 | target_combos = None
134 |
135 | if only_calib:
136 | combos = []
137 | else:
138 | combos = get_combos(product(subjects, range(1, 9), range(10)))
139 | if maxforce:
140 | combos += get_combos(product(subjects, [100, 101], [0]))
141 |
142 | if soft_label:
143 | import pandas as pd
144 | soft_label = pd.DataFrame.from_csv(os.path.join(os.path.dirname(__file__), 's21_soft_label.scv'))
145 |
146 | train = load(
147 | combos=combos,
148 | target_combos=target_combos,
149 | random_scale=random_scale,
150 | random_bad_channel=random_bad_channel,
151 | last_batch_handle='pad',
152 | shuffle=shuffle,
153 | mini_batch_size=mini_batch_size,
154 | soft_label=False if soft_label is False else soft_label[soft_label['fold'] == fold][[str(i) for i in range(8)]].as_matrix(),
155 | adabn=adabn,
156 | window=window,
157 | dense_window=dense_window
158 | )
159 | return train, val
160 |
161 |
162 | def get_inter_subject_val(fold, batch_size, calib, **kargs):
163 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20]
164 | return get_data(
165 | combos=get_combos(product([subjects[fold]], range(1, 9), range(1, 10) if calib else range(10))),
166 | root=ROOT,
167 | mean=0.5,
168 | scale=2,
169 | batch_size=batch_size,
170 | last_batch_handle='pad',
171 | shuffle=False,
172 | random_state=np.random.RandomState(42),
173 | **kargs
174 | )
175 |
176 |
177 | def get_inter_subject_train(fold, batch_size):
178 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20]
179 | return get_data(
180 | combos=get_combos(product([subjects[i] for i in range(10) if i != fold], range(1, 9), range(10))),
181 | root=ROOT,
182 | mean=0.5,
183 | scale=2,
184 | batch_size=batch_size,
185 | last_batch_handle='pad',
186 | shuffle=False
187 | )
188 |
--------------------------------------------------------------------------------
/sigr/data/s21_soft_label.scv:
--------------------------------------------------------------------------------
1 | ,0,1,2,3,4,5,6,7,fold
2 | 0,0.585914433002,0.0113508105278,0.0469612777233,0.0103245843202,0.0297896899283,0.213723227382,0.0498025380075,0.0520934313536,0
3 | 1,0.0177394840866,0.260400772095,0.0862425193191,0.217843294144,0.0818059891462,0.0307203009725,0.155910596251,0.149291470647,0
4 | 2,0.0565337799489,0.0951329991221,0.290711790323,0.0840612798929,0.18585036695,0.0402426086366,0.122267112136,0.125134795904,0
5 | 3,0.0208049118519,0.22623719275,0.0898522436619,0.244077846408,0.120171234012,0.0297798700631,0.127505093813,0.141515702009,0
6 | 4,0.0588904172182,0.0689897313714,0.238068774343,0.0831308886409,0.332722753286,0.038050621748,0.0563105903566,0.123781606555,0
7 | 5,0.375521868467,0.0143311182037,0.0225172676146,0.00740715721622,0.0129349566996,0.436831176281,0.100549437106,0.0298644062132,0
8 | 6,0.0699726864696,0.116337053478,0.117743805051,0.0787186548114,0.0549033097923,0.118118651211,0.306120842695,0.138010010123,0
9 | 7,0.0656202509999,0.132219433784,0.132582515478,0.126813665032,0.14836679399,0.0508881472051,0.15999814868,0.183442443609,0
10 | 8,0.573849737644,0.0113045303151,0.0458538196981,0.0101971579716,0.0256233215332,0.230556309223,0.0508873090148,0.0516888573766,1
11 | 9,0.0176304485649,0.238844901323,0.0921256542206,0.208853840828,0.100514553487,0.0301328171045,0.161045968533,0.150802791119,1
12 | 10,0.0456261076033,0.0941201895475,0.301260918379,0.0844209119678,0.202144488692,0.035298217088,0.11817112565,0.118890993297,1
13 | 11,0.0206406675279,0.207882523537,0.0887452363968,0.239738628268,0.123213484883,0.0303926169872,0.147305071354,0.142024502158,1
14 | 12,0.0543091744184,0.0716122165322,0.208353236318,0.0926068499684,0.354343175888,0.0366435796022,0.0551674477756,0.126913920045,1
15 | 13,0.390921235085,0.0146044613793,0.0273387394845,0.00665758550167,0.0111129777506,0.439626246691,0.086997166276,0.022702537477,1
16 | 14,0.0779526233673,0.114770486951,0.122850477695,0.0789507627487,0.0550688132644,0.135418519378,0.289047718048,0.125866964459,1
17 | 15,0.0544591732323,0.142568826675,0.129290759563,0.14990568161,0.124165035784,0.0469007156789,0.163525983691,0.189116105437,1
18 | 16,0.594108045101,0.00999377202243,0.0427821725607,0.00983097590506,0.0300427172333,0.21718133986,0.0437685139477,0.0522559508681,2
19 | 17,0.0147701213136,0.252291023731,0.086970448494,0.229896858335,0.0909300968051,0.0276916641742,0.151407673955,0.145995393395,2
20 | 18,0.0548012703657,0.0855624973774,0.321182370186,0.0757946372032,0.200975820422,0.0397630445659,0.109111316502,0.112741105258,2
21 | 19,0.0160632822663,0.23853699863,0.0873838663101,0.254317045212,0.0998763814569,0.0285355579108,0.140727058053,0.134509548545,2
22 | 20,0.0488464124501,0.071256428957,0.250549972057,0.0880576819181,0.333701938391,0.0352295488119,0.0502944551408,0.122012011707,2
23 | 21,0.368753939867,0.0141330743209,0.0257930252701,0.00734513904899,0.0160015933216,0.426667273045,0.109375782311,0.0318860970438,2
24 | 22,0.0769981369376,0.0936448574066,0.119277991354,0.0723595842719,0.0508595369756,0.158597052097,0.304699063301,0.123484656215,2
25 | 23,0.0539519712329,0.140293493867,0.141128987074,0.132432863116,0.142665907741,0.0458317175508,0.16619721055,0.177431449294,2
26 | 24,0.5591365695,0.0114038847387,0.0458650290966,0.0102882077917,0.030391799286,0.239412352443,0.0506078414619,0.0528544560075,3
27 | 25,0.0177513454109,0.238689228892,0.100602254272,0.223592862487,0.101336151361,0.0273324083537,0.143142953515,0.147503301501,3
28 | 26,0.0509831905365,0.0920915007591,0.306010752916,0.0830404087901,0.199058055878,0.0363924279809,0.116912446916,0.115444153547,3
29 | 27,0.0209863614291,0.230153664947,0.0899317339063,0.233842685819,0.108670607209,0.0305126570165,0.147977411747,0.13786932826,3
30 | 28,0.0588953457773,0.0688157305121,0.240384683013,0.0937875658274,0.307589739561,0.038122843951,0.0595138818026,0.132834300399,3
31 | 29,0.34596735239,0.015977114439,0.0254017412663,0.00788462907076,0.0162124875933,0.443999558687,0.11173632741,0.0327737107873,3
32 | 30,0.0714117065072,0.113147959113,0.118221767247,0.0778153985739,0.0507748536766,0.149462670088,0.300601005554,0.11848885566,3
33 | 31,0.0650929734111,0.132121101022,0.133752852678,0.142429187894,0.129983246326,0.0497390404344,0.163882493973,0.182930201292,3
34 | 32,0.594033658504,0.00978412944824,0.0448851883411,0.00877121277153,0.0301306284964,0.217805102468,0.0442418269813,0.050309818238,4
35 | 33,0.0125299263746,0.260252594948,0.0838051810861,0.231062918901,0.0862522274256,0.0271735414863,0.147024214268,0.151855185628,4
36 | 34,0.0581892468035,0.0873780623078,0.314817845821,0.0790030509233,0.185462743044,0.041338711977,0.113018415868,0.120728157461,4
37 | 35,0.0210476107895,0.202843770385,0.089477263391,0.260362178087,0.116318069398,0.0304090902209,0.129763320088,0.149725064635,4
38 | 36,0.0577154792845,0.0685700327158,0.220444232225,0.0952667221427,0.335331767797,0.0371781699359,0.0557963885367,0.129641205072,4
39 | 37,0.350155651569,0.0145853841677,0.0262643638998,0.00586827797815,0.0150276897475,0.449112892151,0.10923538357,0.0297034103423,4
40 | 38,0.0629346594214,0.110558472574,0.113621123135,0.0723159685731,0.048894032836,0.153379887342,0.309411644936,0.128806352615,4
41 | 39,0.0646151080728,0.129451319575,0.131009638309,0.148646071553,0.128791987896,0.0500863455236,0.167172878981,0.180158615112,4
42 | 40,0.580878973007,0.0113963577896,0.0468598306179,0.0103343445808,0.0308610480279,0.219456076622,0.0469783619046,0.0531973131001,5
43 | 41,0.0178357362747,0.249909639359,0.0978484898806,0.218870550394,0.0988143011928,0.0303349476308,0.154420286417,0.131914392114,5
44 | 42,0.0603491105139,0.0798718780279,0.32292303443,0.0799543261528,0.213207960129,0.0414940938354,0.10795687139,0.0941794067621,5
45 | 43,0.0202580057085,0.226799473166,0.0851363390684,0.259252399206,0.0973977297544,0.0303331054747,0.152081489563,0.128690332174,5
46 | 44,0.0585519187152,0.0609935373068,0.246153384447,0.0856251418591,0.341543257236,0.0378966443241,0.0550683364272,0.114118672907,5
47 | 45,0.381124228239,0.0156234931201,0.0270085260272,0.00779242208228,0.0163024608046,0.412237107754,0.107251346111,0.0326209925115,5
48 | 46,0.0739379227161,0.10848467797,0.0726554319263,0.0708635002375,0.043027702719,0.160984665155,0.334530264139,0.135450929403,5
49 | 47,0.0556856766343,0.146907523274,0.128168582916,0.148990258574,0.123327106237,0.0457626357675,0.17281241715,0.178279042244,5
50 | 48,0.564357459545,0.0114045888186,0.0459825992584,0.0102657750249,0.0303347632289,0.232964366674,0.051675580442,0.0529765896499,6
51 | 49,0.0178332515061,0.243779942393,0.0952674150467,0.221842601895,0.0975822508335,0.0296472813934,0.143671065569,0.150326281786,6
52 | 50,0.0530138872564,0.0937199220061,0.303352326155,0.079468511045,0.18704906106,0.0393490642309,0.123327203095,0.12065808475,6
53 | 51,0.0210853293538,0.225017100573,0.0950988605618,0.219580188394,0.125873163342,0.0306841302663,0.138054862618,0.144554525614,6
54 | 52,0.0484657548368,0.0729737207294,0.22917728126,0.0949310436845,0.33986890316,0.0320509634912,0.0581087581813,0.124372884631,6
55 | 53,0.357948482037,0.0155431739986,0.0275333896279,0.00730467308313,0.0162500869483,0.441128909588,0.103582292795,0.030662054196,6
56 | 54,0.0759399980307,0.0977061539888,0.108938999474,0.0703660771251,0.0512497872114,0.153688088059,0.313845336437,0.128190472722,6
57 | 55,0.0589478500187,0.145432949066,0.14507548511,0.136068463326,0.147310584784,0.0499232001603,0.152261927724,0.164912343025,6
58 | 56,0.60515910387,0.00981649104506,0.0380278304219,0.00969350151718,0.0275855381042,0.221200808883,0.0429150685668,0.0455675348639,7
59 | 57,0.0123052867129,0.264212399721,0.0848530232906,0.221839383245,0.0756716877222,0.0268339943141,0.153965786099,0.160277932882,7
60 | 58,0.0582231655717,0.0821022167802,0.329458266497,0.0659338235855,0.178078427911,0.0419608466327,0.120090350509,0.12409196049,7
61 | 59,0.0133151542395,0.223469093442,0.0751332044601,0.256989300251,0.117942109704,0.0219468865544,0.138660281897,0.15249787271,7
62 | 60,0.0427861995995,0.0686382204294,0.232943952084,0.0817427933216,0.359221041203,0.0316239818931,0.0524121262133,0.130583316088,7
63 | 61,0.39994981885,0.00861005764455,0.0213297940791,0.00698478939012,0.0145348263904,0.437821269035,0.0879363417625,0.0227968432009,7
64 | 62,0.0704860463738,0.100654803216,0.115903668106,0.0651987493038,0.0495304837823,0.147507175803,0.323221951723,0.127423748374,7
65 | 63,0.0522875934839,0.1378274858,0.133011072874,0.138035595417,0.145129650831,0.0384888760746,0.167901203036,0.187253862619,7
66 | 64,0.622319102287,0.00547091430053,0.0352219045162,0.0057081039995,0.0194392669946,0.246698439121,0.0414409972727,0.023681294173,8
67 | 65,0.0157551020384,0.265378654003,0.0883843973279,0.225549280643,0.0939034298062,0.0188524145633,0.136794626713,0.15533824265,8
68 | 66,0.0573872178793,0.0854949876666,0.3361107409,0.0717113688588,0.202803596854,0.0280384868383,0.0982796773314,0.120111130178,8
69 | 67,0.0179295912385,0.229140669107,0.0825557112694,0.260068267584,0.121700055897,0.0170418191701,0.133098810911,0.138416275382,8
70 | 68,0.0514223910868,0.0662609562278,0.24324285984,0.0915150269866,0.352055311203,0.0198261514306,0.0505900233984,0.125039443374,8
71 | 69,0.317708104849,0.0155081218109,0.0258607566357,0.00671677058563,0.0152378566563,0.472864151001,0.114921763539,0.031142629683,8
72 | 70,0.0631428137422,0.108594641089,0.101839073002,0.0715390816331,0.0431883595884,0.146862730384,0.332589954138,0.132170587778,8
73 | 71,0.0608781836927,0.137492239475,0.121944501996,0.143452703953,0.148526206613,0.0384837388992,0.165029957891,0.184127256274,8
74 | 72,0.591446340084,0.0108800856397,0.0416322499514,0.00790138915181,0.0240155700594,0.232482612133,0.0463338270783,0.0452779754996,9
75 | 73,0.0167807787657,0.242463931441,0.0944532901049,0.218917831779,0.0947048291564,0.0288841463625,0.153221786022,0.15052652359,9
76 | 74,0.0518501289189,0.0852868705988,0.313529372215,0.0755426958203,0.20850302279,0.040258616209,0.106483064592,0.118484579027,9
77 | 75,0.0191303230822,0.214621096849,0.0793804228306,0.251748353243,0.115605682135,0.029749520123,0.143106788397,0.146608382463,9
78 | 76,0.0519566200674,0.0685137063265,0.235343664885,0.0947128608823,0.343081116676,0.0370681136847,0.0540818944573,0.115194141865,9
79 | 77,0.359065055847,0.0153089584783,0.0246519688517,0.00715514039621,0.0132492622361,0.442372530699,0.1072749421,0.0308763105422,9
80 | 78,0.0754193663597,0.0997704938054,0.1212580055,0.0742529407144,0.0518601499498,0.146741092205,0.307116210461,0.12350846082,9
81 | 79,0.0604774132371,0.143125548959,0.125325471163,0.141259744763,0.139200344682,0.0475078225136,0.166975483298,0.176064044237,9
82 |
--------------------------------------------------------------------------------
/sigr/evaluation.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import os
3 | import numpy as np
4 | from functools import partial
5 | from .parse_log import parse_log
6 | from . import utils
7 | from . import module
8 | from logbook import Logger
9 | from copy import deepcopy
10 | import mxnet as mx
11 |
12 |
13 | Exp = utils.Bunch
14 |
15 | logger = Logger(__name__)
16 |
17 |
18 | @utils.cached(ignore=['context'])
19 | def _crossval_predict_aux(self, Mod, get_crossval_val, fold, context, dataset_args=None):
20 | Mod = deepcopy(Mod)
21 | Mod.update(context=context)
22 | mod = module.RuntimeModule(**Mod)
23 | Val = partial(
24 | get_crossval_val,
25 | fold=fold,
26 | batch_size=self.batch_size,
27 | window=mod.num_channel,
28 | **(dataset_args or {})
29 | )
30 | return mod.predict(utils.LazyProxy(Val))
31 |
32 |
33 | @utils.cached(ignore=['context'])
34 | def _crossval_predict_proba_aux(self, Mod, get_crossval_val, fold, context, dataset_args=None):
35 | Mod = deepcopy(Mod)
36 | Mod.update(context=context)
37 | mod = module.RuntimeModule(**Mod)
38 | Val = partial(
39 | get_crossval_val,
40 | fold=fold,
41 | batch_size=self.batch_size,
42 | window=mod.num_channel,
43 | **(dataset_args or {})
44 | )
45 | return mod.predict_proba(utils.LazyProxy(Val))
46 |
47 |
48 | def _crossval_predict(self, **kargs):
49 | proba = kargs.pop('proba', False)
50 | fold = int(kargs.pop('fold'))
51 | Mod = kargs.pop('Mod')
52 | Mod = deepcopy(Mod)
53 | Mod.update(params=self.format_params(Mod['params'], fold))
54 | context = Mod.pop('context', [mx.gpu(0)])
55 | # import pickle
56 | # d = kargs.copy()
57 | # d.update(Mod=Mod, fold=fold)
58 | # print(pickle.dumps(d))
59 |
60 | # Always load the prediction from disk (call_and_shelve + .get()).
61 | # Otherwise downstream cached methods such as vote would end up with
62 | # two cache entries: one for the freshly computed result and another
63 | # for the result reloaded from the cache.
64 | func = _crossval_predict_aux if not proba else _crossval_predict_proba_aux
65 | return func.call_and_shelve(self, Mod=Mod, fold=fold, context=context, **kargs).get()
66 |
67 |
68 | class Evaluation(object):
69 |
70 | def __init__(self, batch_size=None):
71 | self.batch_size = batch_size
72 |
73 |
74 | class CrossValEvaluation(Evaluation):
75 |
76 | def __init__(self, **kargs):
77 | self.crossval_type = kargs.pop('crossval_type')
78 | super(CrossValEvaluation, self).__init__(**kargs)
79 |
80 | def get_crossval_val_func(self, dataset):
81 | return getattr(dataset, 'get_%s_val' % self.crossval_type.replace('-', '_'))
82 |
83 | def format_params(self, params, fold):
84 | try:
85 | return params % fold
86 | except:
87 | return params
88 |
89 | def transform(self, Mod, dataset, fold, dataset_args=None):
90 | get_crossval_val = self.get_crossval_val_func(dataset)
91 | pred, true, _ = _crossval_predict(
92 | self,
93 | proba=True,
94 | Mod=Mod,
95 | get_crossval_val=get_crossval_val,
96 | fold=fold,
97 | dataset_args=dataset_args)
98 | return pred, true
99 |
100 | def accuracy_mod(self, Mod, dataset, fold,
101 | vote=False,
102 | dataset_args=None,
103 | balance=False):
104 | get_crossval_val = self.get_crossval_val_func(dataset)
105 | pred, true, segment = _crossval_predict(
106 | self,
107 | Mod=Mod,
108 | get_crossval_val=get_crossval_val,
109 | fold=fold,
110 | dataset_args=dataset_args)
111 | if vote:
112 | from .vote import vote as do
113 | return do(true, pred, segment, vote, balance)
114 | return (true == pred).sum() / true.size
115 |
116 | def accuracy_exp(self, exp, fold):
117 | if hasattr(exp, 'Mod') and hasattr(exp, 'dataset'):
118 | return self.accuracy_mod(Mod=exp.Mod,
119 | dataset=exp.dataset,
120 | fold=fold,
121 | vote=exp.get('vote', False),
122 | dataset_args=exp.get('dataset_args'))
123 | else:
124 | try:
125 | return parse_log(os.path.join(exp.root % fold, 'log')).val.iloc[-1]
126 | except:
127 | return np.nan
128 |
129 | def accuracy(self, **kargs):
130 | if 'exp' in kargs:
131 | return self.accuracy_exp(**kargs)
132 | elif 'Mod' in kargs:
133 | return self.accuracy_mod(**kargs)
134 | else:
135 | assert False
136 |
137 | def accuracies(self, exps, folds):
138 | acc = []
139 | for exp in exps:
140 | for fold in folds:
141 | acc.append(self.accuracy(exp=exp, fold=fold))
142 | return np.array(acc).reshape(len(exps), len(folds))
143 |
144 | def compare(self, exps, fold):
145 | acc = []
146 | for exp in exps:
147 | if hasattr(exp, 'Mod') and hasattr(exp, 'dataset'):
148 | acc.append(self.accuracy(Mod=exp.Mod,
149 | dataset=exp.dataset,
150 | fold=fold,
151 | vote=exp.get('vote', False),
152 | dataset_args=exp.get('dataset_args')))
153 | else:
154 | try:
155 | acc.append(parse_log(os.path.join(exp.root % fold, 'log')).val.iloc[-1])
156 | except:
157 | acc.append(np.nan)
158 | return acc
159 |
160 | def vote_accuracy_curves(self, exps, folds, windows, balance=False):
161 | acc = []
162 | for exp in exps:
163 | for fold in folds:
164 | acc.append(self.vote_accuracy_curve(
165 | Mod=exp.Mod,
166 | dataset=exp.dataset,
167 | fold=int(fold),
168 | windows=windows,
169 | dataset_args=exp.get('dataset_args'),
170 | balance=balance))
171 | return np.array(acc).reshape(len(exps), len(folds), len(windows))
172 |
173 | def vote_accuracy_curve(self, Mod, dataset, fold, windows,
174 | dataset_args=None,
175 | balance=False):
176 | get_crossval_val = self.get_crossval_val_func(dataset)
177 | pred, true, segment = _crossval_predict(
178 | self,
179 | Mod=Mod,
180 | get_crossval_val=get_crossval_val,
181 | fold=fold,
182 | dataset_args=dataset_args)
183 | from .vote import get_vote_accuracy_curve as do
184 | return do(true, pred, segment, windows, balance)[1]
185 |
186 |
187 | def get_crossval_accuracies(crossval_type, exps, folds, batch_size=1000):
188 | acc = []
189 | evaluation = CrossValEvaluation(
190 | crossval_type=crossval_type,
191 | batch_size=batch_size
192 | )
193 | for fold in folds:
194 | acc.append(evaluation.compare(exps, fold))
195 | return acc
196 |
--------------------------------------------------------------------------------
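
The evaluation entry points above are driven entirely through keyword bundles, which can be hard to see from the code alone, so here is a minimal, hypothetical usage sketch. Exp is just utils.Bunch, Mod holds the keyword arguments later handed to module.RuntimeModule, and crossval_type='intra-subject' resolves to a get_intra_subject_val loader on the dataset module. The dataset module, parameter path, Mod contents and fold count below are illustrative assumptions, not values taken from the repository.

# Hypothetical sketch, not part of the repository.
from sigr.evaluation import CrossValEvaluation, Exp
from sigr.data.ninapro import db1  # assumed to expose get_intra_subject_val()

evaluation = CrossValEvaluation(crossval_type='intra-subject', batch_size=1000)
exps = [Exp(Mod=dict(params='.cache/some-model-fold-%d/model-0028.params'),
            dataset=db1)]
# accuracies() calls accuracy() for every (exp, fold) pair and returns an
# array of shape (len(exps), len(folds)).  format_params() substitutes the
# fold number into the %d of the params path; context defaults to mx.gpu(0).
acc = evaluation.accuracies(exps, folds=range(27))
print(acc.mean(axis=1))
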
/sigr/fft.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import numpy as np
3 |
4 |
5 | def fft(data, fs):
6 | n = data.shape[-1]
7 | window = np.hanning(n)
8 | windowed = data * window
9 | spectrum = np.fft.fft(windowed)
10 | freq = np.fft.fftfreq(n, 1 / fs)
11 | half_n = int(np.ceil(n / 2))
12 | spectrum_half = (2 / n) * spectrum[..., :half_n]
13 | freq_half = freq[:half_n]
14 | return freq_half, np.abs(spectrum_half)
15 |
--------------------------------------------------------------------------------
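
The fft helper returns the single-sided amplitude spectrum over the last axis of its input. A quick sanity check (a sketch; it assumes the sigr package and its dependencies are importable): a 50 Hz sine sampled at 1 kHz should put the spectral peak at, or within one bin of, 50 Hz. Note that the Hann window applied inside fft() roughly halves the apparent peak amplitude.

import numpy as np
from sigr.fft import fft

fs = 1000
t = np.arange(1024) / fs
freq, amp = fft(np.sin(2 * np.pi * 50 * t), fs)
# The peak lands within one frequency bin (fs / 1024 ~ 0.98 Hz) of 50 Hz.
print(freq[np.argmax(amp)])
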
/sigr/lstm.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 | import mxnet as mx
3 |
4 |
5 | LSTMState = namedtuple("LSTMState", ["c", "h"])
6 | # LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias", "h2h_weight", "h2h_bias"])
7 | LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_gamma", "h2h_weight", "h2h_gamma",
8 | "beta", "c_gamma", "c_beta"])
9 | LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol",
10 | "init_states", "last_states",
11 | "seq_data", "seq_labels", "seq_outputs",
12 | "param_blocks"])
13 |
14 |
15 | class LSTM(object):
16 |
17 | def lstm_orig(self, prefix, num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
18 | """LSTM Cell symbol"""
19 | if dropout > 0.:
20 | indata = mx.sym.Dropout(data=indata, p=dropout)
21 | i2h = mx.sym.FullyConnected(data=indata,
22 | weight=param.i2h_weight,
23 | bias=param.i2h_bias,
24 | num_hidden=num_hidden * 4,
25 | name=prefix + "t%d_l%d_i2h" % (seqidx, layeridx))
26 | h2h = mx.sym.FullyConnected(data=prev_state.h,
27 | weight=param.h2h_weight,
28 | bias=param.h2h_bias,
29 | num_hidden=num_hidden * 4,
30 | name=prefix + "t%d_l%d_h2h" % (seqidx, layeridx))
31 | gates = i2h + h2h
32 | slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
33 | name=prefix + "t%d_l%d_slice" % (seqidx, layeridx))
34 | in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
35 | in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
36 | forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
37 | out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
38 | next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
39 | next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh")
40 | return LSTMState(c=next_c, h=next_h)
41 |
42 | def lstm_not_share_beta_gamma(self, prefix, num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
43 | """LSTM Cell symbol"""
44 | if dropout > 0.:
45 | indata = mx.sym.Dropout(data=indata, p=dropout)
46 | i2h = mx.sym.FullyConnected(data=indata,
47 | weight=param.i2h_weight,
48 | bias=param.i2h_bias,
49 | num_hidden=num_hidden * 4,
50 | name=prefix + "t%d_l%d_i2h" % (seqidx, layeridx))
51 | i2h = mx.sym.BatchNorm(
52 | name=prefix + "t%d_l%d_i2h_bn" % (seqidx, layeridx),
53 | data=i2h,
54 | fix_gamma=False,
55 | momentum=0.9,
56 | attr={'wd_mult': '0'}
57 | )
58 | h2h = mx.sym.FullyConnected(data=prev_state.h,
59 | weight=param.h2h_weight,
60 | bias=param.h2h_bias,
61 | num_hidden=num_hidden * 4,
62 | name=prefix + "t%d_l%d_h2h" % (seqidx, layeridx))
63 | h2h = mx.sym.BatchNorm(
64 | name=prefix + "t%d_l%d_h2h_bn" % (seqidx, layeridx),
65 | data=h2h,
66 | fix_gamma=False,
67 | momentum=0.9,
68 | attr={'wd_mult': '0'}
69 | )
70 | gates = i2h + h2h
71 | slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
72 | name=prefix + "t%d_l%d_slice" % (seqidx, layeridx))
73 | in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
74 | in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
75 | forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
76 | out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
77 | next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
78 | next_h = out_gate * mx.sym.Activation(
79 | mx.symbol.BatchNorm(
80 | name=prefix + 't%d_l%d_c_bn' % (seqidx, layeridx),
81 | data=next_c,
82 | fix_gamma=False,
83 | momentum=0.9,
84 | attr={'wd_mult': '0'}
85 | ),
86 | act_type="tanh"
87 | )
88 | return LSTMState(c=next_c, h=next_h)
89 |
90 | def lstm(self, prefix, num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.):
91 | """LSTM Cell symbol"""
92 | if dropout > 0.:
93 | indata = mx.sym.Dropout(data=indata, p=dropout)
94 | i2h = mx.sym.FullyConnected(data=indata,
95 | weight=param.i2h_weight,
96 | num_hidden=num_hidden * 4,
97 | no_bias=True,
98 | name=prefix + "t%d_l%d_i2h" % (seqidx, layeridx))
99 | i2h = self.BatchNorm(
100 | name=prefix + "t%d_l%d_i2h_bn" % (seqidx, layeridx),
101 | data=i2h,
102 | gamma=param.i2h_gamma,
103 | num_channel=num_hidden * 4
104 | )
105 | h2h = mx.sym.FullyConnected(data=prev_state.h,
106 | weight=param.h2h_weight,
107 | num_hidden=num_hidden * 4,
108 | no_bias=True,
109 | name=prefix + "t%d_l%d_h2h" % (seqidx, layeridx))
110 | h2h = self.BatchNorm(
111 | name=prefix + "t%d_l%d_h2h_bn" % (seqidx, layeridx),
112 | data=h2h,
113 | gamma=param.h2h_gamma,
114 | beta=param.beta,
115 | num_channel=num_hidden * 4
116 | )
117 | gates = i2h + h2h
118 | slice_gates = mx.sym.SliceChannel(gates, num_outputs=4,
119 | name=prefix + "t%d_l%d_slice" % (seqidx, layeridx))
120 | in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid")
121 | in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh")
122 | forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid")
123 | out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid")
124 | next_c = (forget_gate * prev_state.c) + (in_gate * in_transform)
125 | next_h = out_gate * mx.sym.Activation(
126 | self.BatchNorm(
127 | name=prefix + 't%d_l%d_c_bn' % (seqidx, layeridx),
128 | data=next_c,
129 | gamma=param.c_gamma,
130 | beta=param.c_beta,
131 | num_channel=num_hidden
132 | ),
133 | act_type="tanh"
134 | )
135 | return LSTMState(c=next_c, h=next_h)
136 |
137 | def BatchNorm(self, name, data, gamma, beta=None, **kargs):
138 | net = data
139 |
140 | if not self.bn:
141 | return net
142 |
143 | if self.minibatch:
144 | num_channel = kargs.pop('num_channel')
145 | net = mx.symbol.Reshape(net, shape=(-1, self.num_subject * num_channel))
146 | net = mx.symbol.BatchNorm(
147 | name=name + '_norm',
148 | data=net,
149 | fix_gamma=True,
150 | momentum=0.9,
151 | attr={'wd_mult': '0', 'lr_mult': '0'}
152 | )
153 | net = mx.symbol.Reshape(data=net, shape=(-1, num_channel))
154 | else:
155 | net = mx.symbol.BatchNorm(
156 | name=name + '_norm',
157 | data=net,
158 | fix_gamma=True,
159 | momentum=0.9,
160 | attr={'wd_mult': '0', 'lr_mult': '0'}
161 | )
162 | net = mx.symbol.broadcast_mul(net, gamma)
163 | if beta is not None:
164 | net = mx.symbol.broadcast_plus(net, beta)
165 | return net
166 |
167 | def __init__(
168 | self,
169 | prefix,
170 | data,
171 | num_lstm_layer,
172 | seq_len,
173 | num_hidden,
174 | dropout=0.,
175 | minibatch=False,
176 | num_subject=0,
177 | bn=True,
178 | ):
179 | self.bn = bn
180 | self.minibatch = minibatch
181 | self.num_subject = num_subject
182 | if self.minibatch:
183 | assert self.num_subject > 0
184 |
185 | prefix += 'lstm_'
186 |
187 | param_cells = []
188 | last_states = []
189 | for i in range(num_lstm_layer):
190 | param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable(prefix + "l%d_i2h_weight" % i),
191 | # i2h_bias=mx.sym.Variable(prefix + "l%d_i2h_bias" % i),
192 | h2h_weight=mx.sym.Variable(prefix + "l%d_h2h_weight" % i),
193 | # h2h_bias=mx.sym.Variable(prefix + "l%d_h2h_bias" % i)))
194 | i2h_gamma=mx.symbol.Variable(prefix + 'l%d_i2h_gamma' % i, shape=(1, num_hidden * 4), attr={'wd_mult': '0'}),
195 | h2h_gamma=mx.symbol.Variable(prefix + 'l%d_h2h_gamma' % i, shape=(1, num_hidden * 4), attr={'wd_mult': '0'}),
196 | beta=mx.symbol.Variable(prefix + 'l%d_beta' % i, shape=(1, num_hidden * 4), attr={'wd_mult': '0'}),
197 | c_gamma=mx.symbol.Variable(prefix + 'l%d_c_gamma' % i, shape=(1, num_hidden), attr={'wd_mult': '0'}),
198 | c_beta=mx.symbol.Variable(prefix + 'l%d_c_beta' % i, shape=(1, num_hidden), attr={'wd_mult': '0'})))
199 | state = LSTMState(c=mx.sym.Variable(prefix + "l%d_init_c" % i, attr={'lr_mult': '0'}),
200 | h=mx.sym.Variable(prefix + "l%d_init_h" % i, attr={'lr_mult': '0'}))
201 | last_states.append(state)
202 | assert(len(last_states) == num_lstm_layer)
203 |
204 | wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1)
205 |
206 | hidden_all = []
207 | for seqidx in range(seq_len):
208 | hidden = wordvec[seqidx]
209 |
210 | # stack LSTM
211 | for i in range(num_lstm_layer):
212 | if i == 0:
213 | dp_ratio = 0.
214 | else:
215 | dp_ratio = dropout
216 | next_state = self.lstm(prefix, num_hidden, indata=hidden,
217 | prev_state=last_states[i],
218 | param=param_cells[i],
219 | seqidx=seqidx, layeridx=i, dropout=dp_ratio)
220 | hidden = next_state.h
221 | last_states[i] = next_state
222 |
223 | # decoder
224 | if dropout > 0.:
225 | hidden = mx.sym.Dropout(data=hidden, p=dropout)
226 | hidden_all.append(hidden)
227 |
228 | self.net = hidden_all
229 | # return mx.sym.Concat(*hidden_all, dim=1)
230 | # return mx.sym.Pooling(mx.sym.Concat(*[mx.sym.Reshape(h, shape=(0, 0, 1, 1)) for h in hidden_all], dim=2), kernel=(1, 1), global_pool=True, pool_type='max')
231 | # return mx.sym.Pooling(mx.sym.Concat(*[mx.sym.Reshape(h, shape=(0, 0, 1, 1)) for h in hidden_all], dim=2), kernel=(1, 1), global_pool=True, pool_type='avg')
232 |
233 |
234 | def lstm_unroll(**kargs):
235 | return LSTM(**kargs).net
236 |
--------------------------------------------------------------------------------
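
Note that lstm_orig and lstm_not_share_beta_gamma still reference param.i2h_bias and param.h2h_bias, which the current LSTMParam no longer defines, so only the batch-normalized lstm variant is usable with this constructor; the other two appear to be kept for reference. The sketch below shows how lstm_unroll is meant to be called and what it returns; the shapes and hyper-parameters are illustrative, not values taken from the repository.

import mxnet as mx
from sigr.lstm import lstm_unroll

seq_len, num_hidden = 10, 64
data = mx.sym.Variable('data')  # expected shape: (batch, seq_len, features)
hidden_all = lstm_unroll(prefix='', data=data,
                         num_lstm_layer=1, seq_len=seq_len,
                         num_hidden=num_hidden, dropout=0.5, bn=True)
# lstm_unroll returns the list of per-timestep hidden-state symbols; the
# caller decides how to pool or concatenate them, for example:
net = mx.sym.Concat(*hidden_all, dim=1)
# Besides 'data', the resulting symbol also expects the lstm_l0_init_c and
# lstm_l0_init_h inputs plus the gamma/beta variables declared in __init__.
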
/sigr/parse_log.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import re
3 | import numpy as np
4 |
5 |
6 | def div(up, down):
7 | try:
8 | return up / down
9 | except:
10 | return np.nan
11 |
12 |
13 | def parse_log(path):
14 | with open(path, 'r') as f:
15 | lines = f.readlines()
16 |
17 | res = [re.compile(r'.*Epoch\[(\d+)\] Train-accuracy(?:\[g\])?=([.\d]+)'),
18 | re.compile(r'.*Epoch\[(\d+)\] Validation-accuracy(?:\[g\])?=([.\d]+)'),
19 | re.compile(r'.*Epoch\[(\d+)\] Time.*=([.\d]+)')]
20 |
21 | data = {}
22 | for l in lines:
23 | i = 0
24 | for r in res:
25 | m = r.match(l)
26 | if m is not None:
27 | break
28 | i += 1
29 | if m is None:
30 | continue
31 |
32 | assert len(m.groups()) == 2
33 | epoch = int(m.groups()[0])
34 | val = float(m.groups()[1])
35 |
36 | if epoch not in data:
37 | data[epoch] = [0] * len(res) * 2
38 |
39 | data[epoch][i*2] += val
40 | data[epoch][i*2+1] += 1
41 |
42 | df = []
43 | for k, v in data.items():
44 | try:
45 | df.append({
46 | 'epoch': k + 1,
47 | 'train': div(v[0], v[1]),
48 | 'val': div(v[2], v[3]),
49 | 'time': div(v[4], v[5])
50 | })
51 | except:
52 | pass
53 | try:
54 | import pandas as pd
55 | return pd.DataFrame(df)
56 | except:
57 | return df
58 |
--------------------------------------------------------------------------------
/sigr/sklearn_module.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | from nose.tools import assert_equal
3 | import mxnet as mx
4 | import numpy as np
5 | from logbook import Logger
6 | import joblib as jb
7 | from .base_module import BaseModule
8 |
9 |
10 | logger = Logger('sigr')
11 |
12 |
13 | class SklearnModule(BaseModule):
14 |
15 | def _get_data_label(self, data_iter):
16 | data = []
17 | label = []
18 | for batch in data_iter:
19 | data.append(batch.data[0].asnumpy().reshape(
20 | batch.data[0].shape[0], -1))
21 | label.append(batch.label[0].asnumpy())
22 | if batch.pad:
23 | data[-1] = data[-1][:-batch.pad]
24 | label[-1] = label[-1][:-batch.pad]
25 | data = np.vstack(data)
26 | label = np.hstack(label)
27 | assert_equal(len(data), len(label))
28 | return data, label
29 |
30 | def fit(self, train_data, eval_data, eval_metric='acc', **kargs):
31 | snapshot = kargs.pop('snapshot')
32 | self.clf.fit(*self._get_data_label(train_data))
33 | jb.dump(self.clf, snapshot + '-0001.params')
34 |
35 | if not isinstance(eval_metric, mx.metric.EvalMetric):
36 | eval_metric = mx.metric.create(eval_metric)
37 | data, label = self._get_data_label(eval_data)
38 | pred = self.clf.predict(data).astype(np.int64)
39 | prob = np.zeros((len(pred), pred.max() + 1))
40 | prob[np.arange(len(prob)), pred] = 1
41 | eval_metric.update([mx.nd.array(label)], [mx.nd.array(prob)])
42 | for name, val in eval_metric.get_name_value():
43 | logger.info('Epoch[0] Validation-{}={}', name, val)
44 |
45 |
46 | class KNNModule(SklearnModule):
47 |
48 | def __init__(self):
49 | from sklearn.neighbors import KNeighborsClassifier as KNN
50 | self.clf = KNN()
51 |
52 | @classmethod
53 | def parse(cls, text, **kargs):
54 | if text == 'knn':
55 | return cls()
56 |
57 |
58 | class SVMModule(SklearnModule):
59 |
60 | def __init__(self):
61 | from sklearn.svm import LinearSVC
62 | self.clf = LinearSVC()
63 |
64 | @classmethod
65 | def parse(cls, text, **kargs):
66 | if text == 'svm':
67 | return cls()
68 |
69 |
70 | class RandomForestsModule(SklearnModule):
71 |
72 | def __init__(self):
73 | from sklearn.ensemble import RandomForestClassifier as RandomForests
74 | self.clf = RandomForests()
75 |
76 | @classmethod
77 | def parse(cls, text, **kargs):
78 | if text == 'random-forests':
79 | return cls()
80 |
81 |
82 | class LDAModule(SklearnModule):
83 |
84 | def __init__(self):
85 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA
86 | self.clf = LDA()
87 |
88 | @classmethod
89 | def parse(cls, text, **kargs):
90 | if text == 'lda':
91 | return cls()
92 |
--------------------------------------------------------------------------------
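
Each wrapper exposes a parse classmethod that maps a text spec to a configured classifier. A small, hypothetical dispatcher (make_module below is not part of the repository) shows the intended pattern, assuming scikit-learn is installed.

from sigr.sklearn_module import (
    KNNModule, SVMModule, RandomForestsModule, LDAModule)

def make_module(text):
    # Try each wrapper; parse() returns an instance on a match, else None.
    for cls in (KNNModule, SVMModule, RandomForestsModule, LDAModule):
        mod = cls.parse(text)
        if mod is not None:
            return mod
    raise ValueError('unknown classifier spec: %s' % text)

mod = make_module('lda')  # wraps sklearn's LinearDiscriminantAnalysis
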
/sigr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from contextlib import contextmanager
2 | import os
3 | import numpy as np
4 | from .proxy import LazyProxy
5 | assert LazyProxy
6 |
7 |
8 | @contextmanager
9 | def logging_context(path=None, level=None):
10 | from logbook import StderrHandler, FileHandler
11 | from logbook.compat import redirected_logging
12 | with StderrHandler(level=level or 'INFO').applicationbound():
13 | if path:
14 | if not os.path.isdir(os.path.dirname(path)):
15 | os.makedirs(os.path.dirname(path))
16 | with FileHandler(path, bubble=True).applicationbound():
17 | with redirected_logging():
18 | yield
19 | else:
20 | with redirected_logging():
21 | yield
22 |
23 |
24 | def return_list(func):
25 | import inspect
26 | from functools import wraps
27 | assert inspect.isgeneratorfunction(func)
28 |
29 | @wraps(func)
30 | def wrapped(*args, **kargs):
31 | return list(func(*args, **kargs))
32 |
33 | return wrapped
34 |
35 |
36 | @return_list
37 | def continuous_segments(label):
38 | label = np.asarray(label)
39 |
40 | if not len(label):
41 | return
42 |
43 | breaks = list(np.where(label[:-1] != label[1:])[0] + 1)
44 | for begin, end in zip([0] + breaks, breaks + [len(label)]):
45 | assert begin < end
46 | yield begin, end
47 |
48 |
49 | def cached(*args, **kargs):
50 | import joblib as jb
51 | from .. import CACHE
52 | memo = getattr(cached, 'memo', None)
53 | if memo is None:
54 | cached.memo = memo = jb.Memory(CACHE, verbose=0)
55 | return memo.cache(*args, **kargs)
56 |
57 |
58 | def get_segments(data, window):
59 | return windowed_view(
60 | data.flat,
61 | window * data.shape[1],
62 | (window - 1) * data.shape[1]
63 | )
64 |
65 |
66 | def windowed_view(arr, window, overlap):
67 | from numpy.lib.stride_tricks import as_strided
68 | arr = np.asarray(arr)
69 | window_step = window - overlap
70 | new_shape = arr.shape[:-1] + ((arr.shape[-1] - overlap) // window_step,
71 | window)
72 | new_strides = (arr.strides[:-1] + (window_step * arr.strides[-1],) +
73 | arr.strides[-1:])
74 | return as_strided(arr, shape=new_shape, strides=new_strides)
75 |
76 |
77 | class Bunch(dict):
78 |
79 | def __getattr__(self, key):
80 | if key in self:
81 | return self[key]
82 | raise AttributeError(key)
83 |
84 | def __setattr__(self, key, value):
85 | self[key] = value
86 |
87 |
88 | def _packargs(func):
89 | from functools import wraps
90 | import inspect
91 |
92 | @wraps(func)
93 | def wrapped(ctx_or_args, **kargs):
94 | if isinstance(ctx_or_args, Bunch):
95 | args = ctx_or_args
96 | else:
97 | args = ctx_or_args.obj
98 | ignore = inspect.getargspec(func).args
99 | args.update({key: kargs.pop(key) for key in list(kargs)
100 | if key not in ignore and key not in args})
101 | return func(ctx_or_args, **kargs)
102 | return wrapped
103 |
104 |
105 | def packargs(func):
106 | import click
107 | return click.pass_obj(_packargs(func))
108 |
109 |
110 | def butter_bandpass_filter(data, lowcut, highcut, fs, order):
111 | from scipy.signal import butter, lfilter
112 |
113 | nyq = 0.5 * fs
114 | low = lowcut / nyq
115 | high = highcut / nyq
116 |
117 | b, a = butter(order, [low, high], btype='bandpass')
118 | y = lfilter(b, a, data)
119 | return y
120 |
121 |
122 | def butter_bandstop_filter(data, lowcut, highcut, fs, order):
123 | from scipy.signal import butter, lfilter
124 |
125 | nyq = 0.5 * fs
126 | low = lowcut / nyq
127 | high = highcut / nyq
128 |
129 | b, a = butter(order, [low, high], btype='bandstop')
130 | y = lfilter(b, a, data)
131 | return y
132 |
133 |
134 | def butter_lowpass_filter(data, cut, fs, order, zero_phase=False):
135 | from scipy.signal import butter, lfilter, filtfilt
136 |
137 | nyq = 0.5 * fs
138 | cut = cut / nyq
139 |
140 | b, a = butter(order, cut, btype='low')
141 | y = (filtfilt if zero_phase else lfilter)(b, a, data)
142 | return y
143 |
--------------------------------------------------------------------------------
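
The stride arithmetic in windowed_view and get_segments is easy to misread, so here is a short worked example (assuming the package and numpy are importable).

import numpy as np
from sigr.utils import windowed_view, continuous_segments

x = np.arange(10)
# window=4 with overlap=2 advances by 2 samples, giving 4 windows:
# [0 1 2 3], [2 3 4 5], [4 5 6 7], [6 7 8 9]
print(windowed_view(x, window=4, overlap=2))

# continuous_segments() yields (begin, end) index pairs of constant runs:
# [(0, 2), (2, 5), (5, 6)]
print(continuous_segments([7, 7, 3, 3, 3, 1]))
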
/sigr/utils/proxy.py:
--------------------------------------------------------------------------------
1 | class LazyProxy(object):
2 |
3 | def __init__(self, make):
4 | self._make = make
5 |
6 | def __getattr__(self, name):
7 | if name == '_inst':
8 | self._inst = self._make()
9 | return self._inst
10 | return getattr(self._inst, name)
11 |
12 | def __setattr__(self, name, value):
13 | if name in ('_make', '_inst'):
14 | return super(LazyProxy, self).__setattr__(name, value)
15 | return setattr(self._inst, name, value)
16 |
17 | def __getstate__(self):
18 | return self._make
19 |
20 | def __setstate__(self, make):
21 | self._make = make
22 |
23 | def __hash__(self):
24 | return hash(self._make)
25 |
26 | def __iter__(self):
27 | return self._inst.__iter__()
28 |
--------------------------------------------------------------------------------
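
LazyProxy defers construction until the first attribute access, which is how the Val loaders handed to mod.predict in evaluation.py avoid loading data before it is actually needed. A minimal sketch with a hypothetical Expensive class:

from sigr.utils.proxy import LazyProxy

class Expensive(object):
    def __init__(self):
        print('building...')
        self.value = 42

proxy = LazyProxy(Expensive)  # nothing is built yet
print(proxy.value)            # prints 'building...' and then 42
print(proxy.value)            # the cached instance is reused; no second build
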
/sigr/vote.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import numpy as np
3 | import joblib as jb
4 | from nose.tools import assert_greater
5 | from .utils import return_list, cached
6 | from . import Context
7 |
8 |
9 | def get_vote_accuracy_curve(labels, predictions, segments, windows, balance=False):
10 | if len(set(segments)) < len(windows):
11 | func = get_vote_accuracy_curve_aux
12 | else:
13 | func = get_vote_accuracy_curve_aux_few_windows
14 | return func(np.asarray(labels),
15 | np.asarray(predictions),
16 | np.asarray(segments),
17 | np.asarray(windows),
18 | balance)
19 |
20 |
21 | @cached
22 | def get_vote_accuracy_curve_aux(labels, predictions, segments, windows, balance):
23 | segment_labels = partial_vote(labels, segments)
24 | return (
25 | np.asarray(windows),
26 | np.array(list(Context.parallel(
27 | jb.delayed(get_vote_accuracy_curve_step)(
28 | segment_labels,
29 | predictions,
30 | segments,
31 | window,
32 | balance
33 | ) for window in windows
34 | )))
35 | )
36 |
37 |
38 | @cached
39 | def get_vote_accuracy_curve_aux_few_windows(labels, predictions, segments, windows, balance):
40 | segment_labels = partial_vote(labels, segments)
41 | return (
42 | np.asarray(windows),
43 | np.array([
44 | get_vote_accuracy_curve_step(
45 | segment_labels,
46 | predictions,
47 | segments,
48 | window,
49 | balance,
50 | parallel=True
51 | ) for window in windows
52 | ])
53 | )
54 |
55 |
56 | def get_vote_accuracy(labels, predictions, segments, window, balance):
57 | _, y = get_vote_accuracy_curve(labels, predictions, segments, [window], balance)
58 | return y[0]
59 |
60 |
61 | vote = get_vote_accuracy
62 |
63 |
64 | def get_segment_vote_accuracy(segment_label, segment_predictions, window):
65 | def gen():
66 | count = {
67 | label: np.hstack([[0], np.cumsum(segment_predictions == label)])
68 | for label in set(segment_predictions)
69 | }
70 | tmp = window
71 | if tmp == -1:
72 | tmp = len(segment_predictions)
73 | tmp = min(tmp, len(segment_predictions))
74 | for begin in range(len(segment_predictions) - tmp + 1):
75 | yield segment_label == max(
76 | count,
77 | key=lambda label: count[label][begin + tmp] - count[label][begin]
78 | ), segment_label
79 | return list(gen())
80 |
81 |
82 | def get_vote_accuracy_curve_step(segment_labels, predictions, segments, window,
83 | balance,
84 | parallel=False):
85 | def gen():
86 | # assert_greater(window, 0)
87 | assert window > 0 or window == -1
88 | if not parallel:
89 | for segment_label, segment_predictions in zip(segment_labels, split(predictions, segments)):
90 | for ret in get_segment_vote_accuracy(segment_label, segment_predictions, window):
91 | yield ret
92 | else:
93 | for rets in Context.parallel(
94 | jb.delayed(get_segment_vote_accuracy)(segment_label, segment_predictions, window)
95 | for segment_label, segment_predictions in zip(segment_labels, split(predictions, segments))
96 | ):
97 | for ret in rets:
98 | yield ret
99 |
100 | good, labels = zip(*list(gen()))
101 | good = np.asarray(good)
102 |
103 | if not balance:
104 | return np.sum(good) / len(good)
105 | else:
106 | acc = []
107 | for label in set(labels):
108 | mask = np.asarray(labels) == label
109 | acc.append(np.sum(good[mask]) / np.sum(mask))
110 | return np.mean(acc)
111 |
112 |
113 | @return_list
114 | def partial_vote(labels, segments, length=None):
115 | for part in split(labels, segments):
116 | part = list(part)
117 |
118 | if length is not None:
119 | part = part[:length]
120 |
121 | assert_greater(len(part), 0)
122 | yield max([(part.count(label), label) for label in set(part)])[1]
123 |
124 |
125 | def split(labels, segments):
126 | return [labels[segments == segment] for segment in sorted(set(segments))]
127 |
--------------------------------------------------------------------------------
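
The voting code above depends on the package-level Context.parallel pool and the joblib cache, so a standalone run is not shown here. Instead, the sketch below restates in plain numpy the per-segment sliding-window majority vote that get_segment_vote_accuracy implements: the segment's ground-truth label (itself a majority over the segment, see partial_vote) is compared against the majority of every window of frame-level predictions. This is an illustration of the technique, not the repository's implementation.

import numpy as np
from collections import Counter

def sliding_majority_vote_accuracy(segment_label, segment_predictions, window):
    preds = np.asarray(segment_predictions)
    if window == -1:
        window = len(preds)
    window = min(window, len(preds))
    hits = []
    for begin in range(len(preds) - window + 1):
        # Majority label inside the current window of frame-level predictions.
        majority = Counter(preds[begin:begin + window]).most_common(1)[0][0]
        hits.append(majority == segment_label)
    return np.mean(hits)

# A segment labelled 2 whose first frames are misclassified as 1:
# the four windows of 3 vote 1, 2, 2, 2 -> accuracy 3/4.
print(sliding_majority_vote_accuracy(2, [1, 1, 2, 2, 2, 2], window=3))
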