├── .dockerignore ├── .gitignore ├── LICENSE ├── README.md ├── docker ├── Dockerfile ├── elementwise_binary_broadcast_op-inl.h └── mxnet │ └── Dockerfile ├── scripts ├── rundocker.sh ├── test.py ├── test.sh └── train.sh └── sigr ├── __init__.py ├── app.py ├── base_module.py ├── constant.py ├── coral.py ├── data ├── __init__.py ├── capgmyo │ ├── __init__.py │ ├── dba.py │ ├── dbb.py │ └── dbc.py ├── csl.py ├── ninapro │ ├── __init__.py │ ├── caputo.py │ ├── db1.py │ ├── db1_g12.py │ ├── db1_g5.py │ ├── db1_g53.py │ ├── db1_g8.py │ └── db1_matlab_lowpass.py ├── preprocess.py ├── s21.py └── s21_soft_label.scv ├── evaluation.py ├── fft.py ├── lstm.py ├── module.py ├── parse_log.py ├── sklearn_module.py ├── symbol.py ├── utils ├── __init__.py └── proxy.py └── vote.py /.dockerignore: -------------------------------------------------------------------------------- 1 | .cache/ 2 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | *.swp 2 | *.pyc 3 | .ipynb_checkpoints/ 4 | .cache/ 5 | /scripts/exp_inter 6 | /tmp/ 7 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 
43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. 
If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. 
Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | {one line to give the program's name and a brief idea of what it does.} 635 | Copyright (C) {year} {name of author} 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see . 
649 | 650 | Also add information on how to contact you by electronic and paper mail. 651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | {project} Copyright (C) {year} {fullname} 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <http://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <http://www.gnu.org/philosophy/why-not-lgpl.html>. 675 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Surface EMG-based Inter-session Gesture Recognition Enhanced by Deep Domain Adaptation 2 | 3 | ## Requirements 4 | 5 | * A CUDA compatible GPU 6 | * Ubuntu 14.04 or any other Linux/Unix that can run Docker 7 | * [Docker](http://docker.io/) 8 | * [Nvidia Docker](https://github.com/NVIDIA/nvidia-docker) 9 | 10 | ## Usage 11 | 12 | The following commands will 13 | (1) pull the docker image (see `docker/Dockerfile` for details); 14 | (2) train ConvNets on the training sets of CSL-HDEMG, CapgMyo and NinaPro DB1, respectively; 15 | and (3) test the trained ConvNets on the test sets. 16 | 17 | ``` 18 | mkdir .cache 19 | # put NinaPro DB1 in .cache/ninapro-db1 20 | # put CapgMyo DB-a in .cache/dba 21 | # put CapgMyo DB-b in .cache/dbb 22 | # put CapgMyo DB-c in .cache/dbc 23 | # put CSL-HDEMG in .cache/csl 24 | docker pull answeror/sigr:2016-09-21 25 | scripts/train.sh 26 | scripts/test.sh 27 | ``` 28 | 29 | Training on NinaPro and CapgMyo will take 1 to 2 hours depending on your GPU. 30 | Training on CSL-HDEMG will take several days. 31 | You can accelerate training and testing by distributing different folds across different GPUs with the `gpu` parameter. 32 | 33 | The NinaPro DB1 data should be segmented according to the gesture labels and stored in Matlab format as follows. 34 | `.cache/ninapro-db1/data/sss/ggg/sss_ggg_ttt.mat` contains a field `data` (frames x channels) that represents trial `ttt` of gesture `ggg` of subject `sss`. 35 | Numbering starts from zero. Gesture 0 is the rest posture. 36 | For example, `.cache/ninapro-db1/data/000/001/000_001_000.mat` is the 0th trial of the 1st gesture of the 0th subject, 37 | and `.cache/ninapro-db1/data/002/003/002_003_004.mat` is the 4th trial of the 3rd gesture of the 2nd subject. 38 | You can download the prepared dataset from or prepare it yourself; a minimal loading sketch is given below. 39 | 40 | ## License 41 | 42 | Licensed under the GPL v3.0 license.
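For reference, the NinaPro DB1 layout described in the Usage section above can be sanity-checked with a few lines of Python. This is only an illustrative sketch, not part of the repository; it assumes `scipy` is available (it is installed in the docker image) and that the dataset has already been placed under `.cache/`.

```
import scipy.io as sio

# 4th trial of the 3rd gesture of the 2nd subject (all indices start at 0),
# i.e. the second example path given in the Usage section.
mat = sio.loadmat('.cache/ninapro-db1/data/002/003/002_003_004.mat')
data = mat['data']  # frames x channels
print(data.shape)   # e.g. (N, 10): NinaPro DB1 has 10 sEMG channels
```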
43 | 44 | ## Bibtex 45 | 46 | ``` 47 | @article{Du_Sensors_2017, 48 | title={{Surface EMG-based inter-session gesture recognition enhanced by deep domain adaptation}}, 49 | author={Du, Yu and Jin, Wenguang and Wei, Wentao and Hu, Yu and Geng, Weidong}, 50 | journal={Sensors}, 51 | volume={17}, 52 | number={3}, 53 | pages={458}, 54 | year={2017}, 55 | publisher={Multidisciplinary Digital Publishing Institute} 56 | } 57 | ``` 58 | 59 | ## Misc 60 | 61 | Thanks DMLC team for their great [MxNet](https://github.com/dmlc/mxnet)! 62 | -------------------------------------------------------------------------------- /docker/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM answeror/mxnet:f2684a6 2 | MAINTAINER answeror 3 | 4 | RUN apt-get install -y python-pip python-scipy 5 | RUN pip install click logbook joblib nose 6 | 7 | RUN cd /mxnet && \ 8 | git reset --hard && \ 9 | git checkout master && \ 10 | git pull 11 | 12 | RUN cd /mxnet && \ 13 | git checkout 7a485bb && \ 14 | git submodule update && \ 15 | git checkout 887491d src/operator/elementwise_binary_broadcast_op-inl.h && \ 16 | sed -i -e 's/CHECK(ksize_x <= dshape\[3\] && ksize_y <= dshape\[2\])/CHECK(ksize_x <= dshape[3] + 2 * param_.pad[1] \&\& ksize_y <= dshape[2] + 2 * param_.pad[0])/' src/operator/convolution-inl.h && \ 17 | cp make/config.mk . && \ 18 | echo "USE_CUDA=1" >>config.mk && \ 19 | echo "USE_CUDA_PATH=/usr/local/cuda" >>config.mk && \ 20 | echo "USE_CUDNN=1" >>config.mk && \ 21 | echo "USE_BLAS=openblas" >>config.mk && \ 22 | make clean && \ 23 | make -j8 ADD_LDFLAGS=-L/usr/local/cuda/lib64/stubs 24 | 25 | ADD elementwise_binary_broadcast_op-inl.h /mxnet/src/operator/elementwise_binary_broadcast_op-inl.h 26 | RUN cd /mxnet && \ 27 | make clean && \ 28 | make -j8 ADD_LDFLAGS=-L/usr/local/cuda/lib64/stubs 29 | 30 | RUN pip install jupyter pandas matplotlib seaborn scikit-learn 31 | RUN mkdir -p -m 700 /root/.jupyter/ && \ 32 | echo "c.NotebookApp.ip = '*'" >> /root/.jupyter/jupyter_notebook_config.py 33 | EXPOSE 8888 34 | CMD ["sh", "-c", "jupyter notebook"] 35 | 36 | WORKDIR /code 37 | -------------------------------------------------------------------------------- /docker/mxnet/Dockerfile: -------------------------------------------------------------------------------- 1 | FROM nvidia/cuda:7.5-cudnn5-devel 2 | MAINTAINER answeror 3 | 4 | RUN echo "deb http://mirrors.zju.edu.cn/ubuntu/ trusty main restricted universe multiverse" > /etc/apt/sources.list && \ 5 | echo "deb http://mirrors.zju.edu.cn/ubuntu/ trusty-security main restricted universe multiverse" >> /etc/apt/sources.list && \ 6 | echo "deb http://mirrors.zju.edu.cn/ubuntu/ trusty-updates main restricted universe multiverse" >> /etc/apt/sources.list && \ 7 | echo "deb http://mirrors.zju.edu.cn/ubuntu/ trusty-proposed main restricted universe multiverse" >> /etc/apt/sources.list && \ 8 | echo "deb http://mirrors.zju.edu.cn/ubuntu/ trusty-backports main restricted universe multiverse" >> /etc/apt/sources.list && \ 9 | echo "deb-src http://mirrors.zju.edu.cn/ubuntu/ trusty main restricted universe multiverse" >> /etc/apt/sources.list && \ 10 | echo "deb-src http://mirrors.zju.edu.cn/ubuntu/ trusty-security main restricted universe multiverse" >> /etc/apt/sources.list && \ 11 | echo "deb-src http://mirrors.zju.edu.cn/ubuntu/ trusty-updates main restricted universe multiverse" >> /etc/apt/sources.list && \ 12 | echo "deb-src http://mirrors.zju.edu.cn/ubuntu/ trusty-proposed main restricted universe multiverse" >> 
/etc/apt/sources.list && \ 13 | echo "deb-src http://mirrors.zju.edu.cn/ubuntu/ trusty-backports main restricted universe multiverse" >> /etc/apt/sources.list && \ 14 | apt-get -qqy update 15 | 16 | # mxnet 17 | RUN apt-get update && apt-get install -y \ 18 | build-essential \ 19 | git \ 20 | libopenblas-dev \ 21 | libopencv-dev \ 22 | python-numpy \ 23 | wget \ 24 | unzip 25 | RUN git clone --recursive https://github.com/dmlc/mxnet/ && cd mxnet && \ 26 | git checkout f2684a6 && \ 27 | sed -i -e 's/CHECK(ksize_x <= dshape\[3\] && ksize_y <= dshape\[2\])/CHECK(ksize_x <= dshape[3] + 2 * param_.pad[1] \&\& ksize_y <= dshape[2] + 2 * param_.pad[0])/' src/operator/convolution-inl.h && \ 28 | cp make/config.mk . && \ 29 | echo "USE_CUDA=1" >>config.mk && \ 30 | echo "USE_CUDA_PATH=/usr/local/cuda" >>config.mk && \ 31 | echo "USE_CUDNN=1" >>config.mk && \ 32 | echo "USE_BLAS=openblas" >>config.mk && \ 33 | make -j8 ADD_LDFLAGS=-L/usr/local/cuda/lib64/stubs 34 | ENV LD_LIBRARY_PATH /usr/local/cuda/lib64:$LD_LIBRARY_PATH 35 | 36 | ENV PYTHONPATH /mxnet/python 37 | -------------------------------------------------------------------------------- /scripts/rundocker.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | nvidia-docker run --rm -ti -v $(pwd):/code answeror/sigr:2016-09-21 $@ 4 | -------------------------------------------------------------------------------- /scripts/test.py: -------------------------------------------------------------------------------- 1 | from __future__ import print_function, division 2 | import os 3 | import sys 4 | sys.path.insert(0, os.getcwd()) 5 | import numpy as np 6 | import mxnet as mx 7 | from sigr.evaluation import CrossValEvaluation as CV, Exp 8 | from sigr.data import Preprocess, Dataset 9 | from sigr import Context 10 | 11 | 12 | inter_subject_eval = CV(crossval_type='inter-subject', batch_size=1000) 13 | inter_session_eval = CV(crossval_type='inter-session', batch_size=1000) 14 | one_fold_intra_subject_eval = CV(crossval_type='one-fold-intra-subject', batch_size=1000) 15 | 16 | print('Inter-session CSL-HDEMG') 17 | print('============') 18 | 19 | with Context(parallel=True, level='DEBUG'): 20 | acc = inter_session_eval.accuracies( 21 | [Exp(dataset=Dataset.from_name('csl'), vote=-1, 22 | dataset_args=dict(preprocess=Preprocess.parse('(csl-bandpass,csl-cut,median)')), 23 | Mod=dict(num_gesture=27, 24 | adabn=True, 25 | num_adabn_epoch=10, 26 | context=[mx.gpu(0)], 27 | symbol_kargs=dict(dropout=0, num_semg_row=24, num_semg_col=7, num_filter=64), 28 | params='.cache/sensors-csl-inter-session-%d/model-0028.params'))], 29 | folds=np.arange(25)) 30 | print('Per-trial majority voting accuracy: %f' % acc.mean()) 31 | 32 | print('') 33 | print('Inter-subject CapgMyo DB-b') 34 | print('============') 35 | 36 | with Context(parallel=True, level='DEBUG'): 37 | acc = inter_subject_eval.vote_accuracy_curves( 38 | [Exp(dataset=Dataset.from_name('dbb'), 39 | Mod=dict(num_gesture=8, 40 | adabn=True, 41 | num_adabn_epoch=10, 42 | context=[mx.gpu(0)], 43 | symbol_kargs=dict(dropout=0, num_semg_row=16, num_semg_col=8, num_filter=64), 44 | params='.cache/sensors-dbb-inter-subject-%d/model-0028.params'))], 45 | folds=np.arange(10), 46 | windows=[1, 150]) 47 | acc = acc.mean(axis=(0, 1)) 48 | print('Single frame accuracy: %f' % acc[0]) 49 | print('150 frames (150 ms) majority voting accuracy: %f' % acc[1]) 50 | 51 | print('') 52 | print('Inter-session CapgMyo DB-b') 53 | print('============') 54 | 55 | with 
Context(parallel=True, level='DEBUG'): 56 | acc = inter_session_eval.vote_accuracy_curves( 57 | [Exp(dataset=Dataset.from_name('dbb'), 58 | Mod=dict(num_gesture=8, 59 | adabn=True, 60 | num_adabn_epoch=10, 61 | context=[mx.gpu(0)], 62 | symbol_kargs=dict(dropout=0, num_semg_row=16, num_semg_col=8, num_filter=64), 63 | params='.cache/sensors-dbb-inter-session-%d/model-0028.params'))], 64 | folds=np.arange(10), 65 | windows=[1, 150]) 66 | acc = acc.mean(axis=(0, 1)) 67 | print('Single frame accuracy: %f' % acc[0]) 68 | print('150 frames (150 ms) majority voting accuracy: %f' % acc[1]) 69 | 70 | print('') 71 | print('Inter-subject CapgMyo DB-c') 72 | print('============') 73 | 74 | with Context(parallel=True, level='DEBUG'): 75 | acc = inter_subject_eval.vote_accuracy_curves( 76 | [Exp(dataset=Dataset.from_name('dbc'), 77 | Mod=dict(num_gesture=12, 78 | adabn=True, 79 | num_adabn_epoch=10, 80 | context=[mx.gpu(0)], 81 | symbol_kargs=dict(dropout=0, num_semg_row=16, num_semg_col=8, num_filter=64), 82 | params='.cache/sensors-dbc-inter-subject-%d/model-0028.params'))], 83 | folds=np.arange(10), 84 | windows=[1, 150]) 85 | acc = acc.mean(axis=(0, 1)) 86 | print('Single frame accuracy: %f' % acc[0]) 87 | print('150 frames (150 ms) majority voting accuracy: %f' % acc[1]) 88 | 89 | print('') 90 | print('Inter-subject NinaPro DB1') 91 | print('===========') 92 | with Context(parallel=True, level='DEBUG'): 93 | acc = one_fold_intra_subject_eval.vote_accuracy_curves( 94 | [Exp(dataset=Dataset.from_name('ninapro-db1/caputo'), 95 | Mod=dict(num_gesture=52, 96 | context=[mx.gpu(0)], 97 | symbol_kargs=dict(dropout=0, num_semg_row=1, num_semg_col=10, num_filter=64), 98 | params='.cache/sensors-ninapro-one-fold-intra-subject-%d/model-0028.params'))], 99 | folds=np.arange(27), 100 | windows=[1, 40]) 101 | acc = acc.mean(axis=(0, 1)) 102 | print('Single frame accuracy: %f' % acc[0]) 103 | print('40 frames (400 ms) majority voting accuracy: %f' % acc[1]) 104 | -------------------------------------------------------------------------------- /scripts/test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | scripts/rundocker.sh python scripts/test.py 4 | -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | # Inter-subjet recognition of 8 gestures in CapgMyo DB-b 4 | for i in $(seq 0 9 | shuf); do 5 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \ 6 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 7 | --root .cache/sensors-dbb-inter-subject-$i \ 8 | --num-semg-row 16 --num-semg-col 8 \ 9 | --batch-size 1000 --decay-all --dataset dbb \ 10 | --num-filter 64 \ 11 | --adabn --minibatch \ 12 | crossval --crossval-type inter-subject --fold $i 13 | done 14 | 15 | # Inter-session recognition of 8 gestures in CapgMyo DB-b 16 | for i in 1; do 17 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \ 18 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 19 | --root .cache/sensors-dbb-universal-inter-session-$i \ 20 | --num-semg-row 16 --num-semg-col 8 \ 21 | --batch-size 1000 --decay-all --dataset dbb \ 22 | --num-filter 64 \ 23 | --adabn --minibatch \ 24 | crossval --crossval-type universal-inter-session --fold $i 25 | done 26 | for i in $(seq 1 2 19 | shuf); do 27 | scripts/rundocker.sh python -m sigr.app exp --log log 
--snapshot model \ 28 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 29 | --root .cache/sensors-dbb-inter-session-$i \ 30 | --num-semg-row 16 --num-semg-col 8 \ 31 | --batch-size 1000 --decay-all --dataset dbb \ 32 | --num-filter 64 \ 33 | --params .cache/sensors-dbb-universal-inter-session-1/model-0028.params \ 34 | --fix-params ".*conv.*" --fix-params ".*pixel.*" --fix-params "fc1_.*" --fix-params "fc2_.*" \ 35 | --adabn \ 36 | crossval --crossval-type inter-session --fold $i 37 | done 38 | 39 | # Inter-subjet recognition of 12 gestures in CapgMyo DB-c 40 | for i in $(seq 0 9 | shuf); do 41 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \ 42 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 43 | --root .cache/sensors-dbc-inter-subject-$i \ 44 | --num-semg-row 16 --num-semg-col 8 \ 45 | --batch-size 1000 --decay-all --dataset dbc \ 46 | --num-filter 64 \ 47 | --adabn --minibatch \ 48 | crossval --crossval-type inter-subject --fold $i 49 | done 50 | 51 | # Inter-session recognition of 27 gestures in CSL-HDEMG 52 | for i in $(seq 0 5 | shuf); do 53 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \ 54 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 55 | --root .cache/sensors-csl-universal-inter-session-$i \ 56 | --num-semg-row 24 --num-semg-col 7 \ 57 | --batch-size 1000 --decay-all --adabn --minibatch --dataset csl \ 58 | --preprocess '(csl-bandpass,csl-cut,downsample-5,median)' \ 59 | --balance-gesture 1 \ 60 | --num-filter 64 \ 61 | crossval --crossval-type universal-inter-session --fold $i 62 | done 63 | for i in $(seq 0 24 | shuf); do 64 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \ 65 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 66 | --root .cache/sensors-csl-inter-session-$i \ 67 | --num-semg-row 24 --num-semg-col 7 \ 68 | --batch-size 1000 --decay-all --adabn --minibatch --dataset csl \ 69 | --preprocess '(csl-bandpass,csl-cut,median)' \ 70 | --balance-gesture 1 \ 71 | --num-filter 64 \ 72 | --params .cache/sensors-csl-universal-inter-session-$(($i % 5))/model-0028.params \ 73 | crossval --crossval-type inter-session --fold $i 74 | done 75 | 76 | # Inter-subject recognition of 52 gestures in NinaPro DB1 with calibration data 77 | for i in 0; do 78 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \ 79 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 80 | --root .cache/sensors-ninapro-universal-one-fold-intra-subject-$i \ 81 | --num-semg-row 1 --num-semg-col 10 \ 82 | --batch-size 1000 --decay-all --adabn --minibatch --dataset ninapro-db1/caputo \ 83 | --num-filter 64 \ 84 | --preprocess 'downsample-16' \ 85 | crossval --crossval-type universal-one-fold-intra-subject --fold $i 86 | done 87 | for i in $(seq 0 26 | shuf); do 88 | scripts/rundocker.sh python -m sigr.app exp --log log --snapshot model \ 89 | --num-epoch 28 --lr-step 16 --lr-step 24 --snapshot-period 28 \ 90 | --root .cache/sensors-ninapro-one-fold-intra-subject-$i \ 91 | --num-semg-row 1 --num-semg-col 10 \ 92 | --batch-size 1000 --decay-all --dataset ninapro-db1/caputo \ 93 | --num-filter 64 \ 94 | --params .cache/sensors-ninapro-universal-one-fold-intra-subject-0/model-0028.params \ 95 | --preprocess 'downsample-16' \ 96 | crossval --crossval-type one-fold-intra-subject --fold $i 97 | done 98 | -------------------------------------------------------------------------------- /sigr/__init__.py: 
-------------------------------------------------------------------------------- 1 | import mxnet as mx 2 | import numpy as np 3 | import random 4 | 5 | mx.random.seed(42) 6 | np.random.seed(43) 7 | random.seed(44) 8 | 9 | import os 10 | 11 | os.environ['JOBLIB_TEMP_FOLDER'] = '/tmp' 12 | 13 | ROOT = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) 14 | CACHE = os.path.join(ROOT, '.cache') 15 | 16 | from contextlib import contextmanager 17 | 18 | 19 | @contextmanager 20 | def Context(log=None, parallel=False, level=None): 21 | from .utils import logging_context 22 | with logging_context(log, level=level): 23 | if not parallel: 24 | yield 25 | else: 26 | import joblib as jb 27 | from multiprocessing import cpu_count 28 | with jb.Parallel(n_jobs=cpu_count()) as par: 29 | Context.parallel = par 30 | yield 31 | 32 | 33 | def _patch(func): 34 | func() 35 | return lambda: None 36 | 37 | 38 | @_patch 39 | def _patch_click(): 40 | import click 41 | orig = click.option 42 | 43 | def option(*args, **kargs): 44 | if 'help' in kargs and 'default' in kargs: 45 | kargs['help'] += ' (default {})'.format(kargs['default']) 46 | return orig(*args, **kargs) 47 | 48 | click.option = option 49 | 50 | 51 | from .data import s21 as data_s21 52 | 53 | 54 | __all__ = ['ROOT', 'CACHE', 'Context', 'data_s21'] 55 | -------------------------------------------------------------------------------- /sigr/app.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import click 3 | import mxnet as mx 4 | from logbook import Logger 5 | from pprint import pformat 6 | import os 7 | from .utils import packargs, Bunch 8 | from .module import Module 9 | from .data import Preprocess, Dataset 10 | from . import data_s21, Context, constant 11 | 12 | 13 | logger = Logger('sigr') 14 | 15 | 16 | @click.group() 17 | def cli(): 18 | pass 19 | 20 | 21 | @cli.group() 22 | @click.option('--downsample', type=int, default=0) 23 | @click.option('--num-semg-row', type=int, default=constant.NUM_SEMG_ROW, help='Rows of sEMG image') 24 | @click.option('--num-semg-col', type=int, default=constant.NUM_SEMG_COL, help='Cols of sEMG image') 25 | @click.option('--num-epoch', type=int, default=60, help='Maximum epoches') 26 | @click.option('--num-tzeng-batch', type=int, default=constant.NUM_TZENG_BATCH, 27 | help='Batch number of each Tzeng update, 2 means interleaved domain and label update') 28 | @click.option('--lr-step', type=int, multiple=True, default=[20, 40], help='Epoch numbers to decay learning rate') 29 | @click.option('--lr-factor', type=float, multiple=True) 30 | @click.option('--batch-size', type=int, default=1000, 31 | help='Batch size, should be 900 with --minibatch for s21 inter-subject experiment') 32 | @click.option('--lr', type=float, default=0.1, help='Base learning rate') 33 | @click.option('--wd', type=float, default=0.0001, help='Weight decay') 34 | @click.option('--subject-wd', type=float, help='Weight decay multiplier of the subject branch') 35 | @click.option('--gpu', type=int, multiple=True, default=[0]) 36 | @click.option('--gamma', type=float, default=constant.GAMMA, help='Gamma in RevGrad') 37 | @click.option('--log', type=click.Path(), help='Path of the logging file') 38 | @click.option('--snapshot', type=click.Path(), help='Snapshot prefix') 39 | @click.option('--root', type=click.Path(), help='Root path of the experiment, auto create if not exists') 40 | @click.option('--revgrad', is_flag=True, help='Use RevGrad') 41 | 
@click.option('--num-revgrad-batch', type=int, default=2, 42 | help=('Batch number of each RevGrad update, 2 means interleaved domain and label update, ' 43 | 'see "Adversarial Deep Averaging Networks for Cross-Lingual Sentiment Classification" for details')) 44 | @click.option('--tzeng', is_flag=True, help='Use Tzeng_ICCV_2015') 45 | @click.option('--confuse-conv', is_flag=True, help='Domain confusion (for both RevGrad and Tzeng) on conv2') 46 | @click.option('--confuse-all', is_flag=True, help='Domain confusion (for both RevGrad and Tzeng) on all layers') 47 | @click.option('--subject-loss-weight', type=float, default=1, help='Ganin et al. use 0.1 in their code') 48 | @click.option('--subject-confusion-loss-weight', type=float, default=1, 49 | help='Tzeng confusion loss weight, larger than 1 seems better') 50 | @click.option('--lambda-scale', type=float, default=constant.LAMBDA_SCALE, 51 | help='Global scale of lambda in RevGrad, 1 in their paper and 0.1 in their code') 52 | @click.option('--params', type=click.Path(exists=True), help='Inital weights') 53 | @click.option('--ignore-params', multiple=True, help='Ignore params in --params with regex') 54 | @click.option('--random-shift-fill', type=click.Choice(['zero', 'margin']), 55 | default=constant.RANDOM_SHIFT_FILL, help='Random shift filling value') 56 | @click.option('--random-shift-horizontal', type=int, default=0, help='Random shift input horizontally by x pixels') 57 | @click.option('--random-shift-vertical', type=int, default=0, help='Random shift input vertically by x pixels') 58 | @click.option('--random-scale', type=float, default=0, 59 | help='Random scale input data globally by 2^scale, and locally by 2^(scale/4)') 60 | @click.option('--random-bad-channel', type=float, multiple=True, default=[], 61 | help='Random (with a probability of 0.5 for each image) assign a pixel as specified value, usually [-1, 0, 1]') 62 | @click.option('--num-feature-block', type=int, default=constant.NUM_FEATURE_BLOCK, help='Number of FC layers in feature extraction part') 63 | @click.option('--num-gesture-block', type=int, default=constant.NUM_GESTURE_BLOCK, help='Number of FC layers in gesture branch') 64 | @click.option('--num-subject-block', type=int, default=constant.NUM_SUBJECT_BLOCK, help='Number of FC layers in subject branch') 65 | @click.option('--adabn', is_flag=True, help='AdaBN for model adaptation, must be used with --minibatch') 66 | @click.option('--num-adabn-epoch', type=int, default=constant.NUM_ADABN_EPOCH) 67 | @click.option('--num-pixel', type=int, default=constant.NUM_PIXEL, help='Pixelwise reduction layers') 68 | @click.option('--num-filter', type=int, default=constant.NUM_FILTER, help='Kernels of the conv layers') 69 | @click.option('--num-hidden', type=int, default=constant.NUM_HIDDEN, help='Kernels of the FC layers') 70 | @click.option('--num-bottleneck', type=int, default=constant.NUM_BOTTLENECK, help='Kernels of the bottleneck layer') 71 | @click.option('--dropout', type=float, default=constant.DROPOUT, help='Dropout ratio') 72 | @click.option('--window', type=int, default=1, help='Multi-frame as image channels') 73 | @click.option('--lstm-window', type=int) 74 | @click.option('--num-presnet', type=int, multiple=True, help='Deprecated') 75 | @click.option('--presnet-branch', type=int, multiple=True, help='Deprecated') 76 | @click.option('--drop-presnet', is_flag=True) 77 | @click.option('--bng', is_flag=True, help='Deprecated') 78 | @click.option('--minibatch', is_flag=True, help='Split data into minibatch by subject 
id') 79 | @click.option('--drop-branch', is_flag=True, help='Dropout after each FC in branches') 80 | @click.option('--pool', is_flag=True, help='Deprecated') 81 | @click.option('--fft', is_flag=True, help='Deprecaded. Perform FFT and use spectrum amplitude as image channels. Cannot be used on non-uniform (segment length) dataset like NinaPro') 82 | @click.option('--fft-append', is_flag=True, help='Append FFT feature to raw frames in channel axis') 83 | @click.option('--dual-stream', is_flag=True, help='Use raw frames and FFT feature as dual-stream') 84 | @click.option('--zscore/--no-zscore', default=True, help='Use z-score normalization on input') 85 | @click.option('--zscore-bng', is_flag=True, help='Use global BatchNorm as z-score normalization, for window > 1 or FFT') 86 | @click.option('--lstm', is_flag=True) 87 | @click.option('--num-lstm-hidden', type=int, default=constant.NUM_LSTM_HIDDEN, help='Kernels of the hidden layers in LSTM') 88 | @click.option('--num-lstm-layer', type=int, default=constant.NUM_LSTM_LAYER, help='Number of the hidden layers in LSTM') 89 | @click.option('--dense-window/--no-dense-window', default=True, help='Dense sampling of windows during training') 90 | @click.option('--lstm-last', type=int, default=0) 91 | @click.option('--lstm-dropout', type=float, default=constant.LSTM_DROPOUT, help='LSTM dropout ratio') 92 | @click.option('--lstm-shortcut', is_flag=True) 93 | @click.option('--lstm-bn/--no-lstm-bn', default=True, help='BatchNorm in LSTM') 94 | @click.option('--lstm-grad-scale/--no-lstm-grad-scale', default=True, help='Grad scale by the number of LSTM output') 95 | @click.option('--faug', type=float, default=0) 96 | @click.option('--faug-classwise', is_flag=True) 97 | @click.option('--num-eval-epoch', type=int, default=1) 98 | @click.option('--snapshot-period', type=int, default=1) 99 | @click.option('--gpu-x', type=int, default=0) 100 | @click.option('--drop-conv', is_flag=True) 101 | @click.option('--drop-pixel', type=int, multiple=True, default=(-1,)) 102 | @click.option('--drop-presnet-branch', is_flag=True) 103 | @click.option('--drop-presnet-proj', is_flag=True) 104 | @click.option('--fix-params', multiple=True) 105 | @click.option('--presnet-proj-type', type=click.Choice(['A', 'B']), default='A') 106 | @click.option('--decay-all', is_flag=True) 107 | @click.option('--presnet-promote', is_flag=True) 108 | @click.option('--pixel-reduce-loss-weight', type=float, default=0) 109 | @click.option('--fast-pixel-reduce/--no-fast-pixel-reduce', default=True) 110 | @click.option('--pixel-reduce-bias', is_flag=True) 111 | @click.option('--pixel-reduce-kernel', type=int, multiple=True, default=(1, 1)) 112 | @click.option('--pixel-reduce-stride', type=int, multiple=True, default=(1, 1)) 113 | @click.option('--pixel-reduce-pad', type=int, multiple=True, default=(0, 0)) 114 | @click.option('--pixel-reduce-norm', is_flag=True) 115 | @click.option('--pixel-reduce-reg-out', is_flag=True) 116 | @click.option('--num-pixel-reduce-filter', type=int, multiple=True, default=tuple(None for _ in range(constant.NUM_PIXEL))) 117 | @click.option('--num-conv', type=int, default=2) 118 | @click.option('--pixel-same-init', is_flag=True) 119 | @click.option('--presnet-dense', is_flag=True) 120 | @click.option('--conv-shortcut', is_flag=True) 121 | @click.option('--preprocess', callback=lambda ctx, param, value: Preprocess.parse(value)) 122 | @click.option('--bandstop', is_flag=True) 123 | @click.option('--dataset', type=click.Choice(['s21', 'csl', 124 | 'dba', 'dbb', 'dbc', 125 | 
'ninapro-db1-matlab-lowpass', 126 | 'ninapro-db1/caputo', 127 | 'ninapro-db1', 128 | 'ninapro-db1/g53', 129 | 'ninapro-db1/g5', 130 | 'ninapro-db1/g8', 131 | 'ninapro-db1/g12']), required=True) 132 | @click.option('--balance-gesture', type=float, default=0) 133 | @click.option('--module', type=click.Choice(['convnet', 134 | 'knn', 135 | 'svm', 136 | 'random-forests', 137 | 'lda']), default='convnet') 138 | @click.option('--amplitude-weighting', is_flag=True) 139 | @packargs 140 | def exp(args): 141 | pass 142 | 143 | 144 | @exp.command() 145 | @click.option('--fold', type=int, required=True, help='Fold number of the crossval experiment') 146 | @click.option('--crossval-type', type=click.Choice(['intra-session', 147 | 'universal-intra-session', 148 | 'inter-session', 149 | 'universal-inter-session', 150 | 'intra-subject', 151 | 'universal-intra-subject', 152 | 'inter-subject', 153 | 'one-fold-intra-subject', 154 | 'universal-one-fold-intra-subject']), required=True) 155 | @packargs 156 | def crossval(args): 157 | if args.root: 158 | if args.log: 159 | args.log = os.path.join(args.root, args.log) 160 | if args.snapshot: 161 | args.snapshot = os.path.join(args.root, args.snapshot) 162 | 163 | if args.gpu_x: 164 | args.gpu = sum([list(args.gpu) for i in range(args.gpu_x)], []) 165 | 166 | if os.path.exists(args.log): 167 | click.echo('Found log {}, exit'.format(args.log)) 168 | return 169 | 170 | with Context(args.log, parallel=True): 171 | logger.info('Args:\n{}', pformat(args)) 172 | for i in range(args.num_epoch): 173 | path = args.snapshot + '-%04d.params' % (i + 1) 174 | if os.path.exists(path): 175 | logger.info('Found snapshot {}, exit', path) 176 | return 177 | 178 | dataset = Dataset.from_name(args.dataset) 179 | get_crossval_data = getattr(dataset, 'get_%s_data' % args.crossval_type.replace('-', '_')) 180 | train, val = get_crossval_data( 181 | batch_size=args.batch_size, 182 | fold=args.fold, 183 | preprocess=args.preprocess, 184 | adabn=args.adabn, 185 | minibatch=args.minibatch, 186 | balance_gesture=args.balance_gesture, 187 | amplitude_weighting=args.amplitude_weighting, 188 | random_shift_fill=args.random_shift_fill, 189 | random_shift_horizontal=args.random_shift_horizontal, 190 | random_shift_vertical=args.random_shift_vertical 191 | ) 192 | logger.info('Train samples: {}', train.num_sample) 193 | logger.info('Val samples: {}', val.num_sample) 194 | mod = Module.parse( 195 | args.module, 196 | revgrad=args.revgrad, 197 | num_revgrad_batch=args.num_revgrad_batch, 198 | tzeng=args.tzeng, 199 | num_tzeng_batch=args.num_tzeng_batch, 200 | num_gesture=train.num_gesture, 201 | num_subject=train.num_subject, 202 | subject_loss_weight=args.subject_loss_weight, 203 | lambda_scale=args.lambda_scale, 204 | adabn=args.adabn, 205 | num_adabn_epoch=args.num_adabn_epoch, 206 | random_scale=args.random_scale, 207 | dual_stream=args.dual_stream, 208 | lstm=args.lstm, 209 | num_lstm_hidden=args.num_lstm_hidden, 210 | num_lstm_layer=args.num_lstm_layer, 211 | for_training=True, 212 | faug=args.faug, 213 | faug_classwise=args.faug_classwise, 214 | num_eval_epoch=args.num_eval_epoch, 215 | snapshot_period=args.snapshot_period, 216 | pixel_same_init=args.pixel_same_init, 217 | symbol_kargs=dict( 218 | num_semg_row=args.num_semg_row, 219 | num_semg_col=args.num_semg_col, 220 | num_filter=args.num_filter, 221 | num_pixel=args.num_pixel, 222 | num_feature_block=args.num_feature_block, 223 | num_gesture_block=args.num_gesture_block, 224 | num_subject_block=args.num_subject_block, 225 | 
num_hidden=args.num_hidden, 226 | num_bottleneck=args.num_bottleneck, 227 | dropout=args.dropout, 228 | num_channel=train.num_channel // (args.lstm_window or 1), 229 | num_presnet=args.num_presnet, 230 | presnet_branch=args.presnet_branch, 231 | drop_presnet=args.drop_presnet, 232 | bng=args.bng, 233 | subject_confusion_loss_weight=args.subject_confusion_loss_weight, 234 | minibatch=args.minibatch, 235 | confuse_conv=args.confuse_conv, 236 | confuse_all=args.confuse_all, 237 | subject_wd=args.subject_wd, 238 | drop_branch=args.drop_branch, 239 | pool=args.pool, 240 | zscore=args.zscore, 241 | zscore_bng=args.zscore_bng, 242 | num_stream=2 if args.dual_stream else 1, 243 | lstm_last=args.lstm_last, 244 | lstm_dropout=args.lstm_dropout, 245 | lstm_shortcut=args.lstm_shortcut, 246 | lstm_bn=args.lstm_bn, 247 | lstm_window=args.lstm_window, 248 | lstm_grad_scale=args.lstm_grad_scale, 249 | drop_conv=args.drop_conv, 250 | drop_presnet_branch=args.drop_presnet_branch, 251 | drop_presnet_proj=args.drop_presnet_proj, 252 | presnet_proj_type=args.presnet_proj_type, 253 | presnet_promote=args.presnet_promote, 254 | pixel_reduce_loss_weight=args.pixel_reduce_loss_weight, 255 | pixel_reduce_bias=args.pixel_reduce_bias, 256 | pixel_reduce_kernel=args.pixel_reduce_kernel, 257 | pixel_reduce_stride=args.pixel_reduce_stride, 258 | pixel_reduce_pad=args.pixel_reduce_pad, 259 | pixel_reduce_norm=args.pixel_reduce_norm, 260 | pixel_reduce_reg_out=args.pixel_reduce_reg_out, 261 | num_pixel_reduce_filter=args.num_pixel_reduce_filter, 262 | fast_pixel_reduce=args.fast_pixel_reduce, 263 | drop_pixel=args.drop_pixel, 264 | num_conv=args.num_conv, 265 | presnet_dense=args.presnet_dense, 266 | conv_shortcut=args.conv_shortcut 267 | ), 268 | context=[mx.gpu(i) for i in args.gpu] 269 | ) 270 | mod.fit( 271 | train_data=train, 272 | eval_data=val, 273 | num_epoch=args.num_epoch, 274 | num_train=train.num_sample, 275 | batch_size=args.batch_size, 276 | lr_step=args.lr_step, 277 | lr_factor=args.lr_factor, 278 | lr=args.lr, 279 | wd=args.wd, 280 | gamma=args.gamma, 281 | snapshot=args.snapshot, 282 | params=args.params, 283 | ignore_params=args.ignore_params, 284 | fix_params=args.fix_params, 285 | decay_all=args.decay_all 286 | ) 287 | 288 | 289 | @exp.command() 290 | @packargs 291 | def general(args): 292 | if args.root: 293 | if args.log: 294 | args.log = os.path.join(args.root, args.log) 295 | if args.snapshot: 296 | args.snapshot = os.path.join(args.root, args.snapshot) 297 | 298 | if args.gpu_x: 299 | args.gpu = sum([list(args.gpu) for i in range(args.gpu_x)], []) 300 | 301 | with Context(args.log): 302 | logger.info('Args:\n{}', pformat(args)) 303 | for i in range(args.num_epoch): 304 | path = args.snapshot + '-%04d.params' % (i + 1) 305 | if os.path.exists(path): 306 | logger.info('Found snapshot {}, exit', path) 307 | return 308 | 309 | from .data import csl 310 | 311 | train, val = csl.get_general_data( 312 | batch_size=args.batch_size, 313 | adabn=args.adabn, 314 | minibatch=args.minibatch, 315 | downsample=args.downsample 316 | ) 317 | logger.info('Train samples: {}', train.num_sample) 318 | logger.info('Val samples: {}', val.num_sample) 319 | mod = Module( 320 | revgrad=args.revgrad, 321 | num_revgrad_batch=args.num_revgrad_batch, 322 | tzeng=args.tzeng, 323 | num_tzeng_batch=args.num_tzeng_batch, 324 | num_gesture=train.num_gesture, 325 | num_subject=train.num_subject, 326 | subject_loss_weight=args.subject_loss_weight, 327 | lambda_scale=args.lambda_scale, 328 | adabn=args.adabn, 329 | 
num_adabn_epoch=args.num_adabn_epoch, 330 | random_scale=args.random_scale, 331 | dual_stream=args.dual_stream, 332 | lstm=args.lstm, 333 | num_lstm_hidden=args.num_lstm_hidden, 334 | num_lstm_layer=args.num_lstm_layer, 335 | for_training=True, 336 | faug=args.faug, 337 | faug_classwise=args.faug_classwise, 338 | num_eval_epoch=args.num_eval_epoch, 339 | snapshot_period=args.snapshot_period, 340 | pixel_same_init=args.pixel_same_init, 341 | symbol_kargs=dict( 342 | num_semg_row=args.num_semg_row, 343 | num_semg_col=args.num_semg_col, 344 | num_filter=args.num_filter, 345 | num_pixel=args.num_pixel, 346 | num_feature_block=args.num_feature_block, 347 | num_gesture_block=args.num_gesture_block, 348 | num_subject_block=args.num_subject_block, 349 | num_hidden=args.num_hidden, 350 | num_bottleneck=args.num_bottleneck, 351 | dropout=args.dropout, 352 | num_channel=train.num_channel // (args.lstm_window or 1), 353 | num_presnet=args.num_presnet, 354 | presnet_branch=args.presnet_branch, 355 | drop_presnet=args.drop_presnet, 356 | bng=args.bng, 357 | subject_confusion_loss_weight=args.subject_confusion_loss_weight, 358 | minibatch=args.minibatch, 359 | confuse_conv=args.confuse_conv, 360 | confuse_all=args.confuse_all, 361 | subject_wd=args.subject_wd, 362 | drop_branch=args.drop_branch, 363 | pool=args.pool, 364 | zscore=args.zscore, 365 | zscore_bng=args.zscore_bng, 366 | num_stream=2 if args.dual_stream else 1, 367 | lstm_last=args.lstm_last, 368 | lstm_dropout=args.lstm_dropout, 369 | lstm_shortcut=args.lstm_shortcut, 370 | lstm_bn=args.lstm_bn, 371 | lstm_window=args.lstm_window, 372 | lstm_grad_scale=args.lstm_grad_scale, 373 | drop_conv=args.drop_conv, 374 | drop_presnet_branch=args.drop_presnet_branch, 375 | drop_presnet_proj=args.drop_presnet_proj, 376 | presnet_proj_type=args.presnet_proj_type, 377 | presnet_promote=args.presnet_promote, 378 | pixel_reduce_loss_weight=args.pixel_reduce_loss_weight, 379 | pixel_reduce_bias=args.pixel_reduce_bias, 380 | pixel_reduce_kernel=args.pixel_reduce_kernel, 381 | pixel_reduce_stride=args.pixel_reduce_stride, 382 | pixel_reduce_pad=args.pixel_reduce_pad, 383 | pixel_reduce_norm=args.pixel_reduce_norm, 384 | pixel_reduce_reg_out=args.pixel_reduce_reg_out, 385 | num_pixel_reduce_filter=args.num_pixel_reduce_filter, 386 | fast_pixel_reduce=args.fast_pixel_reduce, 387 | drop_pixel=args.drop_pixel, 388 | num_conv=args.num_conv, 389 | presnet_dense=args.presnet_dense, 390 | conv_shortcut=args.conv_shortcut 391 | ), 392 | context=[mx.gpu(i) for i in args.gpu] 393 | ) 394 | mod.fit( 395 | train_data=train, 396 | eval_data=val, 397 | num_epoch=args.num_epoch, 398 | num_train=train.num_sample, 399 | batch_size=args.batch_size, 400 | lr_step=args.lr_step, 401 | lr=args.lr, 402 | wd=args.wd, 403 | gamma=args.gamma, 404 | snapshot=args.snapshot, 405 | params=args.params, 406 | ignore_params=args.ignore_params, 407 | fix_params=args.fix_params, 408 | decay_all=args.decay_all 409 | ) 410 | 411 | 412 | @cli.command() 413 | @click.option('--num-semg-row', type=int, default=constant.NUM_SEMG_ROW, help='Rows of sEMG image') 414 | @click.option('--num-semg-col', type=int, default=constant.NUM_SEMG_COL, help='Cols of sEMG image') 415 | @click.option('--num-epoch', type=int, default=60, help='Maximum epoches') 416 | @click.option('--num-tzeng-batch', type=int, default=constant.NUM_TZENG_BATCH, 417 | help='Batch number of each Tzeng update, 2 means interleaved domain and label update') 418 | @click.option('--lr-step', type=int, multiple=True, default=[20, 40], 
help='Epoch numbers to decay learning rate') 419 | @click.option('--batch-size', type=int, default=1000, 420 | help='Batch size, should be 900 with --minibatch for s21 inter-subject experiment') 421 | @click.option('--lr', type=float, default=0.1, help='Base learning rate') 422 | @click.option('--wd', type=float, default=0.0001, help='Weight decay') 423 | @click.option('--subject-wd', type=float, help='Weight decay multiplier of the subject branch') 424 | @click.option('--gpu', type=int, multiple=True, default=[0]) 425 | @click.option('--gamma', type=float, default=constant.GAMMA, help='Gamma in RevGrad') 426 | @click.option('--log', type=click.Path(), help='Path of the logging file') 427 | @click.option('--snapshot', type=click.Path(), help='Snapshot prefix') 428 | @click.option('--root', type=click.Path(), help='Root path of the experiment, auto create if not exists') 429 | @click.option('--fold', type=int, required=True, help='Fold number of the inter-subject experiment') 430 | @click.option('--maxforce', is_flag=True, help='Use maxforce data of the target subject as calibration data') 431 | @click.option('--calib', is_flag=True, help='Use first repetition of the target subject as calibration data') 432 | @click.option('--only-calib', is_flag=True, help='Only use first repetition of the target subject as calibration data') 433 | @click.option('--target-binary', is_flag=True, help='Make binary prediction of subject and upsampling target dataset') 434 | @click.option('--revgrad', is_flag=True, help='Use RevGrad') 435 | @click.option('--num-revgrad-batch', type=int, default=2, 436 | help=('Batch number of each RevGrad update, 2 means interleaved domain and label update, ' 437 | 'see "Adversarial Deep Averaging Networks for Cross-Lingual Sentiment Classification" for details')) 438 | @click.option('--tzeng', is_flag=True, help='Use Tzeng_ICCV_2015') 439 | @click.option('--confuse-conv', is_flag=True, help='Domain confusion (for both RevGrad and Tzeng) on conv2') 440 | @click.option('--confuse-all', is_flag=True, help='Domain confusion (for both RevGrad and Tzeng) on all layers') 441 | @click.option('--subject-loss-weight', type=float, default=1, help='Ganin et al. 
use 0.1 in their code') 442 | @click.option('--subject-confusion-loss-weight', type=float, default=1, 443 | help='Tzeng confusion loss weight, larger than 1 seems better') 444 | @click.option('--target-gesture-loss-weight', type=float, help='For --calib to emphasize calibration data') 445 | @click.option('--lambda-scale', type=float, default=constant.LAMBDA_SCALE, 446 | help='Global scale of lambda in RevGrad, 1 in their paper and 0.1 in their code') 447 | @click.option('--params', type=click.Path(exists=True), help='Initial weights') 448 | @click.option('--ignore-params', multiple=True, help='Ignore params in --params with regex') 449 | @click.option('--random-scale', type=float, default=0, 450 | help='Random scale input data globally by 2^scale, and locally by 2^(scale/4)') 451 | @click.option('--random-bad-channel', type=float, multiple=True, default=[], 452 | help='Random (with a probability of 0.5 for each image) assign a pixel as specified value, usually [-1, 0, 1]') 453 | @click.option('--num-feature-block', type=int, default=constant.NUM_FEATURE_BLOCK, help='Number of FC layers in feature extraction part') 454 | @click.option('--num-gesture-block', type=int, default=constant.NUM_GESTURE_BLOCK, help='Number of FC layers in gesture branch') 455 | @click.option('--num-subject-block', type=int, default=constant.NUM_SUBJECT_BLOCK, help='Number of FC layers in subject branch') 456 | @click.option('--adabn', is_flag=True, help='AdaBN for model adaptation, must be used with --minibatch') 457 | @click.option('--num-adabn-epoch', type=int, default=constant.NUM_ADABN_EPOCH) 458 | @click.option('--num-pixel', type=int, default=constant.NUM_PIXEL, help='Pixelwise reduction layers') 459 | @click.option('--num-filter', type=int, default=constant.NUM_FILTER, help='Kernels of the conv layers') 460 | @click.option('--num-hidden', type=int, default=constant.NUM_HIDDEN, help='Kernels of the FC layers') 461 | @click.option('--num-bottleneck', type=int, default=constant.NUM_BOTTLENECK, help='Kernels of the bottleneck layer') 462 | @click.option('--dropout', type=float, default=constant.DROPOUT, help='Dropout ratio') 463 | @click.option('--window', type=int, default=1, help='Multi-frame as image channels') 464 | @click.option('--lstm-window', type=int) 465 | @click.option('--num-presnet', type=int, multiple=True, help='Deprecated') 466 | @click.option('--presnet-branch', type=int, multiple=True, help='Deprecated') 467 | @click.option('--drop-presnet', is_flag=True) 468 | @click.option('--bng', is_flag=True, help='Deprecated') 469 | @click.option('--soft-label', is_flag=True, help='Tzeng soft-label for finetuning with calibration data') 470 | @click.option('--minibatch', is_flag=True, help='Split data into minibatch by subject id') 471 | @click.option('--drop-branch', is_flag=True, help='Dropout after each FC in branches') 472 | @click.option('--pool', is_flag=True, help='Deprecated') 473 | @click.option('--fft', is_flag=True, help='Deprecated. Perform FFT and use spectrum amplitude as image channels.
Cannot be used on non-uniform (segment length) dataset like NinaPro') 474 | @click.option('--fft-append', is_flag=True, help='Append FFT feature to raw frames in channel axis') 475 | @click.option('--dual-stream', is_flag=True, help='Use raw frames and FFT feature as dual-stream') 476 | @click.option('--zscore/--no-zscore', default=True, help='Use z-score normalization on input') 477 | @click.option('--zscore-bng', is_flag=True, help='Use global BatchNorm as z-score normalization, for window > 1 or FFT') 478 | @click.option('--lstm', is_flag=True) 479 | @click.option('--num-lstm-hidden', type=int, default=constant.NUM_LSTM_HIDDEN, help='Kernels of the hidden layers in LSTM') 480 | @click.option('--num-lstm-layer', type=int, default=constant.NUM_LSTM_LAYER, help='Number of the hidden layers in LSTM') 481 | @click.option('--dense-window/--no-dense-window', default=True, help='Dense sampling of windows during training') 482 | @click.option('--lstm-last', type=int, default=0) 483 | @click.option('--lstm-dropout', type=float, default=constant.LSTM_DROPOUT, help='LSTM dropout ratio') 484 | @click.option('--lstm-shortcut', is_flag=True) 485 | @click.option('--lstm-bn/--no-lstm-bn', default=True, help='BatchNorm in LSTM') 486 | @click.option('--lstm-grad-scale/--no-lstm-grad-scale', default=True, help='Grad scale by the number of LSTM output') 487 | @click.option('--faug', type=float, default=0) 488 | @click.option('--faug-classwise', is_flag=True) 489 | @click.option('--num-eval-epoch', type=int, default=1) 490 | @click.option('--snapshot-period', type=int, default=1) 491 | @click.option('--gpu-x', type=int, default=0) 492 | @click.option('--drop-conv', is_flag=True) 493 | @click.option('--drop-pixel', type=int, multiple=True, default=(-1,)) 494 | @click.option('--drop-presnet-branch', is_flag=True) 495 | @click.option('--drop-presnet-proj', is_flag=True) 496 | @click.option('--fix-params', multiple=True) 497 | @click.option('--presnet-proj-type', type=click.Choice(['A', 'B']), default='A') 498 | @click.option('--decay-all', is_flag=True) 499 | @click.option('--presnet-promote', is_flag=True) 500 | @click.option('--pixel-reduce-loss-weight', type=float, default=0) 501 | @click.option('--fast-pixel-reduce/--no-fast-pixel-reduce', default=True) 502 | @click.option('--pixel-reduce-bias', is_flag=True) 503 | @click.option('--pixel-reduce-kernel', type=int, multiple=True, default=(1, 1)) 504 | @click.option('--pixel-reduce-stride', type=int, multiple=True, default=(1, 1)) 505 | @click.option('--pixel-reduce-pad', type=int, multiple=True, default=(0, 0)) 506 | @click.option('--pixel-reduce-norm', is_flag=True) 507 | @click.option('--pixel-reduce-reg-out', is_flag=True) 508 | @click.option('--num-pixel-reduce-filter', type=int, multiple=True, default=(16, 16)) 509 | @click.option('--num-conv', type=int, default=2) 510 | @click.option('--pixel-same-init', is_flag=True) 511 | @click.option('--presnet-dense', is_flag=True) 512 | @click.option('--conv-shortcut', is_flag=True) 513 | @packargs 514 | def inter(args): 515 | '''Inter-subject experiment on S21 dataset''' 516 | if args.root: 517 | if args.log: 518 | args.log = os.path.join(args.root, args.log) 519 | if args.snapshot: 520 | args.snapshot = os.path.join(args.root, args.snapshot) 521 | 522 | if args.gpu_x: 523 | args.gpu = sum([list(args.gpu) for i in range(args.gpu_x)], []) 524 | 525 | with Context(args.log): 526 | logger.info('Args:\n{}', pformat(args)) 527 | for i in range(args.num_epoch): 528 | path = args.snapshot + '-%04d.params' % (i + 1) 529 
| if os.path.exists(path): 530 | logger.info('Found snapshot {}, exit', path) 531 | return 532 | train, val = data_s21.get_inter_subject_data( 533 | '.cache/mat.s21.bandstop-45-55.s1000m.scale-01', 534 | fold=args.fold, 535 | batch_size=args.batch_size, 536 | maxforce=args.maxforce, 537 | calib=args.calib or args.only_calib, 538 | only_calib=args.only_calib, 539 | target_binary=args.target_binary, 540 | with_subject=args.revgrad or args.tzeng, 541 | with_target_gesture=args.target_gesture_loss_weight is not None, 542 | random_scale=args.random_scale, 543 | random_bad_channel=args.random_bad_channel, 544 | shuffle=True, 545 | adabn=args.adabn, 546 | window=args.window, 547 | dense_window=args.dense_window, 548 | soft_label=args.soft_label, 549 | minibatch=args.minibatch, 550 | fft=args.fft, 551 | fft_append=args.fft_append, 552 | dual_stream=args.dual_stream, 553 | lstm=args.lstm, 554 | lstm_window=args.lstm_window 555 | ) 556 | logger.info('Train samples: {}', train.num_sample) 557 | logger.info('Val samples: {}', val.num_sample) 558 | mod = Module( 559 | revgrad=args.revgrad, 560 | num_revgrad_batch=args.num_revgrad_batch, 561 | tzeng=args.tzeng, 562 | num_tzeng_batch=args.num_tzeng_batch, 563 | num_gesture=train.num_gesture, 564 | num_subject=train.num_subject, 565 | subject_loss_weight=args.subject_loss_weight, 566 | target_gesture_loss_weight=args.target_gesture_loss_weight, 567 | lambda_scale=args.lambda_scale, 568 | adabn=args.adabn, 569 | num_adabn_epoch=args.num_adabn_epoch, 570 | random_scale=args.random_scale, 571 | soft_label=args.soft_label, 572 | dual_stream=args.dual_stream, 573 | lstm=args.lstm, 574 | num_lstm_hidden=args.num_lstm_hidden, 575 | num_lstm_layer=args.num_lstm_layer, 576 | for_training=True, 577 | faug=args.faug, 578 | faug_classwise=args.faug_classwise, 579 | num_eval_epoch=args.num_eval_epoch, 580 | snapshot_period=args.snapshot_period, 581 | pixel_same_init=args.pixel_same_init, 582 | symbol_kargs=dict( 583 | num_semg_row=args.num_semg_row, 584 | num_semg_col=args.num_semg_col, 585 | num_filter=args.num_filter, 586 | num_pixel=args.num_pixel, 587 | num_feature_block=args.num_feature_block, 588 | num_gesture_block=args.num_gesture_block, 589 | num_subject_block=args.num_subject_block, 590 | num_hidden=args.num_hidden, 591 | num_bottleneck=args.num_bottleneck, 592 | dropout=args.dropout, 593 | num_channel=train.num_channel // (args.lstm_window or 1), 594 | num_presnet=args.num_presnet, 595 | presnet_branch=args.presnet_branch, 596 | drop_presnet=args.drop_presnet, 597 | bng=args.bng, 598 | subject_confusion_loss_weight=args.subject_confusion_loss_weight, 599 | minibatch=args.minibatch, 600 | confuse_conv=args.confuse_conv, 601 | confuse_all=args.confuse_all, 602 | subject_wd=args.subject_wd, 603 | drop_branch=args.drop_branch, 604 | pool=args.pool, 605 | zscore=args.zscore, 606 | zscore_bng=args.zscore_bng, 607 | num_stream=2 if args.dual_stream else 1, 608 | lstm_last=args.lstm_last, 609 | lstm_dropout=args.lstm_dropout, 610 | lstm_shortcut=args.lstm_shortcut, 611 | lstm_bn=args.lstm_bn, 612 | lstm_window=args.lstm_window, 613 | lstm_grad_scale=args.lstm_grad_scale, 614 | drop_conv=args.drop_conv, 615 | drop_presnet_branch=args.drop_presnet_branch, 616 | drop_presnet_proj=args.drop_presnet_proj, 617 | presnet_proj_type=args.presnet_proj_type, 618 | presnet_promote=args.presnet_promote, 619 | pixel_reduce_loss_weight=args.pixel_reduce_loss_weight, 620 | pixel_reduce_bias=args.pixel_reduce_bias, 621 | pixel_reduce_kernel=args.pixel_reduce_kernel, 622 | 
pixel_reduce_stride=args.pixel_reduce_stride, 623 | pixel_reduce_pad=args.pixel_reduce_pad, 624 | pixel_reduce_norm=args.pixel_reduce_norm, 625 | pixel_reduce_reg_out=args.pixel_reduce_reg_out, 626 | num_pixel_reduce_filter=args.num_pixel_reduce_filter, 627 | fast_pixel_reduce=args.fast_pixel_reduce, 628 | drop_pixel=args.drop_pixel, 629 | num_conv=args.num_conv, 630 | presnet_dense=args.presnet_dense, 631 | conv_shortcut=args.conv_shortcut 632 | ), 633 | context=[mx.gpu(i) for i in args.gpu] 634 | ) 635 | mod.fit( 636 | train_data=train, 637 | eval_data=val, 638 | num_epoch=args.num_epoch, 639 | num_train=train.num_sample, 640 | batch_size=args.batch_size, 641 | lr_step=args.lr_step, 642 | lr=args.lr, 643 | wd=args.wd, 644 | gamma=args.gamma, 645 | snapshot=args.snapshot, 646 | params=args.params, 647 | ignore_params=args.ignore_params, 648 | fix_params=args.fix_params, 649 | decay_all=args.decay_all 650 | ) 651 | 652 | 653 | @cli.command() 654 | @click.option('--num-epoch', type=int, default=150) 655 | @click.option('--lr-step', type=int, default=50) 656 | @click.option('--batch-size', type=int, default=2000) 657 | @click.option('--lr', type=float, default=0.1) 658 | @click.option('--gpu', type=int, multiple=True, default=[0]) 659 | @click.option('--log', type=click.Path()) 660 | @click.option('--snapshot', type=click.Path()) 661 | @click.option('--root', type=click.Path()) 662 | @click.option('--adapt', is_flag=True) 663 | @click.option('--gamma', type=float, default=10) 664 | @click.option('--subject-loss-weight', type=float, default=0.1) 665 | def _general( 666 | num_epoch, 667 | lr_step, 668 | batch_size, 669 | lr, 670 | gpu, 671 | log, 672 | snapshot, 673 | root, 674 | adapt, 675 | gamma, 676 | subject_loss_weight 677 | ): 678 | if root: 679 | if log: 680 | log = os.path.join(root, log) 681 | if snapshot: 682 | snapshot = os.path.join(root, snapshot) 683 | 684 | with Context(log): 685 | logger.info('Args:\n{}', pformat(locals())) 686 | mod = Module( 687 | adapt=adapt, 688 | num_gesture=8, 689 | num_subject=10, 690 | subject_loss_weight=subject_loss_weight, 691 | context=[mx.gpu(i) for i in gpu] 692 | ) 693 | train, val, num_train, _ = data_s21.get_general_data( 694 | '.cache/mat.s21.bandstop-45-55.s1000m.scale-01', 695 | batch_size=batch_size, 696 | adapt=adapt 697 | ) 698 | logger.info('Train samples: {}', num_train) 699 | mod.fit( 700 | train_data=train, 701 | eval_data=val, 702 | num_epoch=num_epoch, 703 | num_train=num_train, 704 | batch_size=batch_size, 705 | lr_step=lr_step, 706 | lr=lr, 707 | gamma=gamma, 708 | snapshot=snapshot 709 | ) 710 | 711 | 712 | @cli.command() 713 | def stats(): 714 | click.echo(data_s21.get_stats()) 715 | 716 | 717 | @cli.command() 718 | @click.option('--gpu', type=int, multiple=True, default=[0]) 719 | @click.option('--fold', type=int, required=True) 720 | @click.option('--batch-size', type=int, default=2000) 721 | def coral(gpu, fold, batch_size): 722 | with Context(): 723 | val = data_s21.get_inter_subject_val(fold=fold, batch_size=batch_size) 724 | 725 | mod = Module( 726 | num_gesture=8, 727 | coral=True, 728 | adabn=True, 729 | adabn_num_epoch=10, 730 | symbol_kargs=dict( 731 | num_filter=16, 732 | num_pixel=2, 733 | num_feature_block=2, 734 | num_gesture_block=0, 735 | num_hidden=512, 736 | num_bottleneck=128, 737 | dropout=0.5, 738 | num_channel=1 739 | ), 740 | context=[mx.gpu(i) for i in gpu] 741 | ) 742 | mod.init_coral( 743 | '.cache/sigr-inter-adabn-%d-v403/model-0060.params' % fold, 744 | [data_s21.get_coral([i], batch_size) for i in 
range(10) if i != fold], 745 | data_s21.get_coral([fold], batch_size) 746 | ) 747 | # mod.bind(data_shapes=val.provide_data, for_training=False) 748 | # mod.load_params('.cache/sigr-inter-%d-final/model-0060.params' % fold) 749 | 750 | metric = mx.metric.create('acc') 751 | mod.score(val, metric) 752 | logger.info('Fold {} accuracy: {}', fold, metric.get()[1]) 753 | 754 | 755 | if __name__ == '__main__': 756 | cli(obj=Bunch()) 757 | -------------------------------------------------------------------------------- /sigr/base_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | 3 | 4 | class Meta(type): 5 | 6 | impls = [] 7 | 8 | def __init__(cls, name, bases, fields): 9 | type.__init__(cls, name, bases, fields) 10 | Meta.impls.append(cls) 11 | 12 | 13 | class BaseModule(object): 14 | 15 | __metaclass__ = Meta 16 | 17 | @classmethod 18 | def parse(cls, text, **kargs): 19 | if cls is BaseModule: 20 | for impl in Meta.impls: 21 | if impl is not BaseModule: 22 | inst = impl.parse(text, **kargs) 23 | if inst is not None: 24 | return inst 25 | 26 | 27 | __all__ = ['BaseModule'] 28 | -------------------------------------------------------------------------------- /sigr/constant.py: -------------------------------------------------------------------------------- 1 | NUM_LSTM_HIDDEN = 128 2 | NUM_LSTM_LAYER = 1 3 | LSTM_DROPOUT = 0. 4 | NUM_SEMG_ROW = 16 5 | NUM_SEMG_COL = 8 6 | NUM_SEMG_POINT = NUM_SEMG_ROW * NUM_SEMG_COL 7 | NUM_FILTER = 16 8 | NUM_HIDDEN = 512 9 | NUM_BOTTLENECK = 128 10 | DROPOUT = 0.5 11 | GAMMA = 10 12 | NUM_FEATURE_BLOCK = 2 13 | NUM_GESTURE_BLOCK = 0 14 | NUM_SUBJECT_BLOCK = 0 15 | NUM_PIXEL = 2 16 | LAMBDA_SCALE = 1 17 | NUM_TZENG_BATCH = 2 18 | NUM_ADABN_EPOCH = 1 19 | RANDOM_SHIFT_FILL = 'zero' 20 | -------------------------------------------------------------------------------- /sigr/coral.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy.linalg as splg 3 | 4 | 5 | def get_coral_params(ds, dt, lam=1e-3): 6 | ms = ds.mean(axis=0) 7 | ds = ds - ms 8 | mt = dt.mean(axis=0) 9 | dt = dt - mt 10 | cs = np.cov(ds.T) + lam * np.eye(ds.shape[1]) 11 | ct = np.cov(dt.T) + lam * np.eye(dt.shape[1]) 12 | sqrt = splg.sqrtm 13 | w = sqrt(ct).dot(np.linalg.inv(sqrt(cs))) 14 | b = mt - w.dot(ms.reshape(-1, 1)).ravel() 15 | return w, b 16 | -------------------------------------------------------------------------------- /sigr/data/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import mxnet as mx 3 | import os 4 | import scipy.io as spio 5 | import numpy as np 6 | from collections import namedtuple, OrderedDict 7 | from logbook import Logger 8 | from nose.tools import assert_equal 9 | from functools import partial 10 | from itertools import product, izip 11 | from .. 
import utils, constant 12 | 13 | 14 | logger = Logger('data') 15 | Combo = namedtuple('Combo', ['subject', 'gesture', 'trial'], verbose=False) 16 | Trial = namedtuple('Trial', ['data', 'gesture', 'subject'], verbose=False) 17 | 18 | 19 | def _register(impl): 20 | _register.impls.append(impl) 21 | 22 | 23 | _register.impls = [] 24 | 25 | 26 | class Dataset(object): 27 | 28 | class __metaclass__(type): 29 | 30 | def __init__(cls, name, bases, fields): 31 | type.__init__(cls, name, bases, fields) 32 | _register(cls) 33 | 34 | @property 35 | def num_trial(self): 36 | return len(self.trials) 37 | 38 | @property 39 | def num_gesture(self): 40 | return len(self.gestures) 41 | 42 | @property 43 | def num_subject(self): 44 | return len(self.subjects) 45 | 46 | @classmethod 47 | def from_name(cls, name): 48 | if name == 's21': 49 | from . import s21 50 | return s21 51 | if name == 'csl': 52 | from . import csl 53 | return csl 54 | inst = cls.parse(name) 55 | assert inst is not None, 'Unknown dataset {}'.format(name) 56 | return inst 57 | 58 | @classmethod 59 | def parse(cls, text): 60 | if cls is Dataset: 61 | for impl in _register.impls: 62 | if impl is not Dataset: 63 | inst = impl.parse(text) 64 | if inst is not None: 65 | return inst 66 | 67 | def get_combos(self, *args): 68 | for arg in args: 69 | if isinstance(arg, tuple): 70 | arg = [arg] 71 | for a in arg: 72 | yield Combo(*a) 73 | 74 | 75 | class SingleSessionMixin(object): 76 | 77 | def get_one_fold_intra_subject_trials(self): 78 | return self.trials[::2], self.trials[1::2] 79 | 80 | def get_inter_subject_data(self, fold, batch_size, preprocess, 81 | adabn, minibatch, **kargs): 82 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 83 | load = partial(get_data, 84 | root=self.root, 85 | last_batch_handle='pad', 86 | get_trial=get_trial, 87 | batch_size=batch_size, 88 | num_semg_row=self.num_semg_row, 89 | num_semg_col=self.num_semg_col) 90 | subject = self.subjects[fold] 91 | train = load( 92 | combos=self.get_combos(product([i for i in self.subjects if i != subject], 93 | self.gestures, self.trials)), 94 | adabn=adabn, 95 | # mini_batch_size=batch_size // (self.num_subject - 1 if minibatch else 1), 96 | mini_batch_size=10 if minibatch else 1, 97 | shuffle=True) 98 | val = load( 99 | combos=self.get_combos(product([subject], self.gestures, self.trials)), 100 | shuffle=False) 101 | return train, val 102 | 103 | def get_inter_subject_val(self, fold, batch_size, preprocess=None, **kargs): 104 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 105 | load = partial(get_data, 106 | root=self.root, 107 | last_batch_handle='pad', 108 | get_trial=get_trial, 109 | batch_size=batch_size, 110 | num_semg_row=self.num_semg_row, 111 | num_semg_col=self.num_semg_col) 112 | subject = self.subjects[fold] 113 | val = load( 114 | combos=self.get_combos(product([subject], self.gestures, self.trials)), 115 | shuffle=False) 116 | return val 117 | 118 | def get_intra_subject_data(self, fold, batch_size, preprocess, 119 | adabn, minibatch, **kargs): 120 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 121 | load = partial(get_data, 122 | root=self.root, 123 | last_batch_handle='pad', 124 | get_trial=get_trial, 125 | batch_size=batch_size, 126 | num_semg_row=self.num_semg_row, 127 | num_semg_col=self.num_semg_col) 128 | subject = self.subjects[fold // self.num_trial] 129 | trial = self.trials[fold % self.num_trial] 130 | train = load( 131 | 
combos=self.get_combos(product([subject], self.gestures, 132 | [i for i in self.trials if i != trial])), 133 | adabn=adabn, 134 | # mini_batch_size=batch_size // (self.num_subject if minibatch else 1), 135 | mini_batch_size=10 if minibatch else 1, 136 | shuffle=True) 137 | val = load( 138 | combos=self.get_combos(product([subject], self.gestures, [trial])), 139 | shuffle=False) 140 | return train, val 141 | 142 | def get_intra_subject_val(self, fold, batch_size, preprocess=None, **kargs): 143 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 144 | load = partial(get_data, 145 | root=self.root, 146 | last_batch_handle='pad', 147 | get_trial=get_trial, 148 | batch_size=batch_size, 149 | num_semg_row=self.num_semg_row, 150 | num_semg_col=self.num_semg_col) 151 | subject = self.subjects[fold // self.num_trial] 152 | trial = self.trials[fold % self.num_trial] 153 | val = load( 154 | combos=self.get_combos(product([subject], self.gestures, [trial])), 155 | shuffle=False) 156 | return val 157 | 158 | def get_universal_intra_subject_data(self, fold, batch_size, preprocess, 159 | adabn, minibatch, **kargs): 160 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 161 | load = partial(get_data, 162 | root=self.root, 163 | last_batch_handle='pad', 164 | get_trial=get_trial, 165 | batch_size=batch_size, 166 | num_semg_row=self.num_semg_row, 167 | num_semg_col=self.num_semg_col) 168 | trial = self.trials[fold] 169 | train = load( 170 | combos=self.get_combos(product(self.subjects, self.gestures, 171 | [i for i in self.trials if i != trial])), 172 | adabn=adabn, 173 | # mini_batch_size=batch_size // (self.num_subject if minibatch else 1), 174 | mini_batch_size=10 if minibatch else 1, 175 | shuffle=True) 176 | val = load( 177 | combos=self.get_combos(product(self.subjects, self.gestures, [trial])), 178 | shuffle=False) 179 | return train, val 180 | 181 | def get_one_fold_intra_subject_val(self, fold, batch_size, preprocess=None, **kargs): 182 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 183 | load = partial(get_data, 184 | root=self.root, 185 | last_batch_handle='pad', 186 | get_trial=get_trial, 187 | batch_size=batch_size, 188 | num_semg_row=self.num_semg_row, 189 | num_semg_col=self.num_semg_col) 190 | subject = self.subjects[fold] 191 | _, val_trials = self.get_one_fold_intra_subject_trials() 192 | val = load( 193 | combos=self.get_combos(product([subject], self.gestures, 194 | [i for i in val_trials])), 195 | shuffle=False) 196 | return val 197 | 198 | def get_one_fold_intra_subject_data(self, fold, batch_size, preprocess, 199 | adabn, minibatch, **kargs): 200 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 201 | load = partial(get_data, 202 | root=self.root, 203 | last_batch_handle='pad', 204 | get_trial=get_trial, 205 | batch_size=batch_size, 206 | num_semg_row=self.num_semg_row, 207 | num_semg_col=self.num_semg_col) 208 | subject = self.subjects[fold] 209 | train_trials, val_trials = self.get_one_fold_intra_subject_trials() 210 | train = load( 211 | combos=self.get_combos(product([subject], self.gestures, 212 | [i for i in train_trials])), 213 | adabn=adabn, 214 | # mini_batch_size=batch_size // (self.num_subject if minibatch else 1), 215 | mini_batch_size=10 if minibatch else 1, 216 | shuffle=True) 217 | val = load( 218 | combos=self.get_combos(product([subject], self.gestures, 219 | [i for i in val_trials])), 220 | shuffle=False) 221 | return train, val 
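# Annotation (not part of sigr/data/__init__.py): the crossval command resolves
# --crossval-type by name via getattr(dataset, 'get_%s_data' % crossval_type.replace('-', '_')),
# so 'one-fold-intra-subject' dispatches to the method above. Its train/validation split comes
# from get_one_fold_intra_subject_trials(), which simply alternates trials. A toy illustration
# with hypothetical trial ids:
trials = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
train_trials, val_trials = trials[::2], trials[1::2]
assert train_trials == [1, 3, 5, 7, 9]    # even-indexed trials go to training
assert val_trials == [2, 4, 6, 8, 10]     # odd-indexed trials go to validation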
222 | 223 | def get_universal_one_fold_intra_subject_data(self, fold, batch_size, preprocess, 224 | adabn, minibatch, **kargs): 225 | assert_equal(fold, 0) 226 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 227 | load = partial(get_data, 228 | root=self.root, 229 | last_batch_handle='pad', 230 | get_trial=get_trial, 231 | batch_size=batch_size, 232 | num_semg_row=self.num_semg_row, 233 | num_semg_col=self.num_semg_col) 234 | train_trials, val_trials = self.get_one_fold_intra_subject_trials() 235 | train = load( 236 | combos=self.get_combos(product(self.subjects, self.gestures, 237 | [i for i in train_trials])), 238 | adabn=adabn, 239 | # mini_batch_size=batch_size // (self.num_subject if minibatch else 1), 240 | mini_batch_size=10 if minibatch else 1, 241 | shuffle=True) 242 | val = load( 243 | combos=self.get_combos(product(self.subjects, self.gestures, 244 | [i for i in val_trials])), 245 | shuffle=False) 246 | return train, val 247 | 248 | 249 | def get_index(a): 250 | '''Convert label to 0 based index''' 251 | b = list(set(a)) 252 | return np.array([x if x < 0 else b.index(x) for x in a.ravel()]).reshape(a.shape) 253 | 254 | 255 | def get_path(root, combo): 256 | return os.path.join( 257 | root, 258 | '{0.subject:03d}', 259 | '{0.gesture:03d}', 260 | '{0.subject:03d}_{0.gesture:03d}_{0.trial:03d}.mat' 261 | ).format(combo) 262 | 263 | 264 | def label_to_gesture(label): 265 | '''Convert maxforce to -1''' 266 | return label if label < 100 else -1 267 | 268 | 269 | def _get_trial(root, combo): 270 | path = get_path(root, combo) 271 | mat = spio.loadmat(path) 272 | data = mat['data'].astype(np.float32) 273 | gesture = np.repeat(label_to_gesture(np.asscalar(mat['label'].astype(np.int))), len(data)) 274 | subject = np.repeat(np.asscalar(mat['subject'].astype(np.int)), len(data)) 275 | return Trial(data=data, gesture=gesture, subject=subject) 276 | 277 | 278 | def get_data( 279 | root, 280 | combos, 281 | num_semg_row, 282 | num_semg_col, 283 | mean=None, 284 | scale=None, 285 | with_subject=False, 286 | target_combos=None, 287 | target_binary=False, 288 | with_target_gesture=False, 289 | min_size=None, 290 | random_scale=False, 291 | random_bad_channel=[], 292 | shuffle=True, 293 | adabn=False, 294 | window=1, 295 | soft_label=False, 296 | fft=False, 297 | fft_append=False, 298 | dual_stream=False, 299 | num_ignore_per_segment=0, 300 | dense_window=True, 301 | faug=False, 302 | get_trial=None, 303 | balance_gesture=0, 304 | **kargs 305 | ): 306 | '''Get mxnet data iter''' 307 | if os.path.isdir(os.path.join(root, 'data')): 308 | root = os.path.join(root, 'data') 309 | 310 | combos = list(combos) 311 | if target_combos is not None: 312 | target_combos = list(target_combos) 313 | 314 | if get_trial is None: 315 | get_trial = _get_trial 316 | 317 | def try_scale(data): 318 | if mean is not None: 319 | data = data - mean 320 | if scale is not None: 321 | data = data * scale 322 | return data 323 | 324 | data = [] 325 | gesture = [] 326 | subject = [] 327 | segment = [] 328 | 329 | for combo in combos: 330 | trial = get_trial(root=root, combo=combo) 331 | data.append(try_scale(trial.data)) 332 | gesture.append(trial.gesture) 333 | subject.append(np.repeat(0, len(data[-1])) if target_binary else trial.subject) 334 | segment.append(np.repeat(len(segment), len(data[-1]))) 335 | 336 | if target_combos: 337 | for combo in target_combos: 338 | trial = get_trial(root=root, combo=combo) 339 | data.append(try_scale(data)) 340 | gesture.append(trial.gesture) 341 | 
subject.append(np.repeat(1, len(data[-1])) if target_binary else trial.subject) 342 | segment.append(np.repeat(len(segment), len(data[-1]))) 343 | 344 | # if window > 1: 345 | # data = [get_segments(seg, window) for seg in data] 346 | # gesture = [seg[window - 1:] for seg in gesture] 347 | # subject = [seg[window - 1:] for seg in subject] 348 | # for t in zip(data, gesture, subject): 349 | # for lhs, rhs in zip(t[:-1], t[1:]): 350 | # assert len(lhs) == len(rhs) 351 | 352 | logger.debug('MAT loaded') 353 | 354 | if not data: 355 | logger.warn('Empty data') 356 | return 357 | 358 | index = [] 359 | n = 0 360 | for seg in data: 361 | if dense_window: 362 | index.append(np.arange(n, n + len(seg) - window + 1 - num_ignore_per_segment)) 363 | else: 364 | index.append(np.arange(n, n + len(seg) - window + 1 - num_ignore_per_segment, window)) 365 | # Pad with the last value 366 | # index.append(np.repeat(n + len(seg) - window, window - 1)) 367 | n += len(seg) 368 | index = np.hstack(index) 369 | logger.debug('Index made') 370 | 371 | logger.debug('Segments: {}', len(data)) 372 | logger.debug('First segment shape: {}', data[0].shape) 373 | data = np.vstack(data).reshape(-1, 1, num_semg_row, num_semg_col) 374 | logger.debug('Data stacked') 375 | if min_size is not None: 376 | h = (min_size - num_semg_row) // 2 377 | w = (min_size - num_semg_col) // 2 378 | data = np.pad( 379 | data, 380 | ((0, 0), (0, 0), (h, h), (w, w)), 381 | 'constant', 382 | constant_values=0 383 | ) 384 | 385 | # data = np.tile(data, (1, 3, 1, 1)) 386 | gesture = get_index(np.hstack(gesture)) 387 | subject_orig = np.hstack(subject) 388 | subject = get_index(subject_orig) 389 | segment = np.hstack(segment) 390 | 391 | label = [] 392 | 393 | if soft_label is not False: 394 | label.append(('gesture_softmax_label', gesture)) 395 | label.append(('soft_label', soft_label[gesture])) 396 | else: 397 | label.append(('gesture_softmax_label', gesture)) 398 | 399 | if with_subject: 400 | label.append(('subject_softmax_label', subject)) 401 | # for i in range(gesture.max() + 1): 402 | # subset = subject.copy() 403 | # subset[gesture != i] = -1 404 | # label.append(('gesture%d_subject_softmax_label' % i, subset)) 405 | 406 | if with_target_gesture: 407 | if target_combos is not None: 408 | mask = np.in1d(subject_orig, list(set({combo.subject for combo in target_combos}))) 409 | target_gesture = gesture.copy() 410 | target_gesture[~mask, ...] 
= -1 411 | label.append(('target_gesture_softmax_label', target_gesture)) 412 | else: 413 | label.append(('target_gesture_softmax_label', gesture)) 414 | 415 | logger.debug('Make data iter') 416 | 417 | # important, use OrderedDict to ensure label order 418 | data = Data( 419 | data=OrderedDict([('data', data)]), 420 | label=OrderedDict(label), 421 | shuffle=shuffle, 422 | adabn=adabn, 423 | gesture=gesture.copy(), 424 | subject=subject.copy(), 425 | segment=segment.copy(), 426 | window=window, 427 | index=index, 428 | random_scale=random_scale, 429 | random_bad_channel=random_bad_channel, 430 | # num_sample=len(index), 431 | num_gesture=gesture.max() + 1, 432 | num_subject=subject.max() + 1, 433 | fft=fft, 434 | fft_append=fft_append, 435 | dual_stream=dual_stream, 436 | dense_window=dense_window, 437 | faug=faug, 438 | balance_gesture=balance_gesture, 439 | **kargs 440 | ) 441 | if not fft: 442 | data = Preload(data) 443 | return data 444 | 445 | 446 | class Preload(mx.io.PrefetchingIter): 447 | 448 | def __getattr__(self, name): 449 | if name != 'iters' and hasattr(self, 'iters') and hasattr(self.iters[0], name): 450 | return getattr(self.iters[0], name) 451 | raise AttributeError(name) 452 | 453 | def __setattr__(self, name, value): 454 | if name in ('shuffle', 'downsample', 'last_batch_handle'): 455 | return setattr(self.iters[0], name, value) 456 | return super(Preload, self).__setattr__(name, value) 457 | 458 | def iter_next(self): 459 | for e in self.data_ready: 460 | e.wait() 461 | if self.next_batch[0] is None: 462 | # for i in self.next_batch: 463 | # assert i is None, "Number of entry mismatches between iterators" 464 | return False 465 | else: 466 | # for batch in self.next_batch: 467 | # assert batch.pad == self.next_batch[0].pad, "Number of entry mismatches between iterators" 468 | self.current_batch = mx.io.DataBatch(sum([batch.data for batch in self.next_batch], []), 469 | sum([batch.label for batch in self.next_batch], []), 470 | self.next_batch[0].pad, 471 | self.next_batch[0].index) 472 | for e in self.data_ready: 473 | e.clear() 474 | for e in self.data_taken: 475 | e.set() 476 | return True 477 | 478 | 479 | class FaugData(mx.io.DataIter): 480 | 481 | def __init__(self, faug, batch_size, num_feature): 482 | super(FaugData, self).__init__() 483 | self.faug = faug 484 | self.batch_size = batch_size 485 | self.num_feature = num_feature 486 | 487 | @property 488 | def provide_data(self): 489 | return [('faug', (self.batch_size, self.num_feature))] 490 | 491 | @property 492 | def provide_label(self): 493 | return [] 494 | 495 | def iter_next(self): 496 | return True 497 | 498 | def getdata(self): 499 | if self.faug: 500 | return [mx.nd.array(self.faug * np.random.randn(self.batch_size, self.num_feature))] 501 | else: 502 | return [mx.nd.array(np.zeros((self.batch_size, self.num_feature)))] 503 | 504 | def getlabel(self): 505 | return [] 506 | 507 | 508 | class Data(mx.io.NDArrayIter): 509 | 510 | def __init__(self, *args, **kargs): 511 | self.random_shift_vertical = kargs.pop('random_shift_vertical', 0) 512 | self.random_shift_horizontal = kargs.pop('random_shift_horizontal', 0) 513 | self.random_shift_fill = kargs.pop('random_shift_fill', constant.RANDOM_SHIFT_FILL) 514 | self.framerate = kargs.pop('framerate', 1000) 515 | self.amplitude_weighting = kargs.pop('amplitude_weighting', False) 516 | self.amplitude_weighting_sort = kargs.pop('amplitude_weighting_sort', False) 517 | self.downsample = kargs.pop('downsample', None) 518 | self.dense_window = 
kargs.pop('dense_window') 519 | self.random_scale = kargs.pop('random_scale') 520 | self.random_bad_channel = kargs.pop('random_bad_channel') 521 | self.shuffle = kargs.pop('shuffle', False) 522 | self.adabn = kargs.pop('adabn', False) 523 | self._gesture = kargs.pop('gesture') 524 | self._subject = kargs.pop('subject') 525 | self._segment = kargs.pop('segment') 526 | self.window = kargs.pop('window') 527 | self._index_orig = kargs.pop('index') 528 | self._index = np.copy(self._index_orig) 529 | # self.num_sample = kargs.pop('num_sample') 530 | self.num_gesture = kargs.pop('num_gesture') 531 | self.num_subject = kargs.pop('num_subject') 532 | self.mini_batch_size = kargs.pop('mini_batch_size', kargs.get('batch_size')) 533 | self.random_state = kargs.pop('random_state', np.random) 534 | self.fft = kargs.pop('fft', False) 535 | self.fft_append = kargs.pop('fft_append', False) 536 | self.dual_stream = kargs.pop('dual_stream', False) 537 | self.faug = kargs.pop('faug', False) 538 | self.balance_gesture = kargs.pop('balance_gesture', 0) 539 | if not self.dual_stream: 540 | self.num_channel = self.window if not self.fft else self.window // 2 + (self.window if self.fft_append else 0) 541 | else: 542 | assert self.fft and not self.fft_append 543 | self.num_channel = [self.window, self.window // 2] 544 | 545 | super(Data, self).__init__(*args, **kargs) 546 | 547 | self.data = [(k, self._asnumpy(v)) for k, v in self.data] 548 | self.label = [(k, self._asnumpy(v)) for k, v in self.label] 549 | self.num_data = len(self._index) 550 | self.data_orig = self.data 551 | self.reset() 552 | # self.num_data = len(self._index) 553 | 554 | def _asnumpy(self, a): 555 | return a if not isinstance(a, mx.nd.NDArray) else a.asnumpy() 556 | 557 | @property 558 | def num_sample(self): 559 | return self.num_data 560 | 561 | @property 562 | def gesture(self): 563 | return self._gesture[self._index] 564 | 565 | @property 566 | def subject(self): 567 | return self._subject[self._index] 568 | 569 | @property 570 | def segment(self): 571 | return self._segment[self._index] 572 | 573 | @property 574 | def provide_data(self): 575 | if not self.dual_stream: 576 | res = [(k, tuple([self.batch_size, self.num_channel] + list(v.shape[2:]))) for k, v in self.data] 577 | else: 578 | assert_equal(len(self.data), 1) 579 | res = [('stream%d_' % i + self.data[0][0], tuple([self.batch_size, ch] + list(self.data[0][1].shape[2:]))) 580 | for i, ch in enumerate(self.num_channel)] 581 | if self.faug: 582 | res += [('faug', (self.batch_size, 16))] 583 | return res 584 | 585 | def _expand_index(self, index): 586 | return np.hstack([np.arange(i, i + self.window) for i in index]) 587 | 588 | def _reshape_data(self, data): 589 | return data.reshape(-1, self.window, *data.shape[2:]) 590 | 591 | def _get_fft(self, data): 592 | from .. import Context 593 | import joblib as jb 594 | res = [] 595 | for amp in Context.parallel(jb.delayed(_get_fft_aux)(sample, self.fft_append) for sample in data): 596 | res.append(amp[np.newaxis, ...]) 597 | return np.concatenate(res, axis=0) 598 | 599 | def _get_segments(self, a, index): 600 | b = mx.nd.empty((len(index), self.window) + a.shape[2:], dtype=a.dtype) 601 | for i, j in enumerate(index): 602 | b[i] = a[j:j + self.window].reshape(self.window, *a.shape[2:]) 603 | return b 604 | 605 | def _getdata(self, data_source): 606 | """Load data from underlying arrays, internal use only""" 607 | assert(self.cursor < self.num_data), "DataIter needs reset." 
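# Annotation (not part of sigr/data/__init__.py): when window > 1, the branch below gathers
# `window` consecutive frames per sample, so a non-FFT batch has shape
# (batch_size, window, num_semg_row, num_semg_col), matching num_channel = window in
# provide_data. A stand-alone NumPy sketch of what _get_segments does (the real method fills
# an mx.nd array; the 16x8 frame size and the index values are assumed for illustration):
import numpy as np

def get_segments_sketch(frames, index, window):
    # frames: (num_frame, 1, rows, cols); index holds the first frame of each window.
    out = np.empty((len(index), window) + frames.shape[2:], dtype=frames.dtype)
    for i, j in enumerate(index):
        out[i] = frames[j:j + window].reshape(window, *frames.shape[2:])
    return out

frames = np.zeros((100, 1, 16, 8), dtype=np.float32)
assert get_segments_sketch(frames, index=[0, 5, 10], window=4).shape == (3, 4, 16, 8)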
608 | 609 | if data_source is self.data and self.window > 1: 610 | if self.cursor + self.batch_size <= self.num_data: 611 | # res = [self._reshape_data(x[1][self._expand_index(self._index[self.cursor:self.cursor+self.batch_size])]) for x in data_source] 612 | res = [self._get_segments(x[1], self._index[self.cursor:self.cursor+self.batch_size]) for x in data_source] 613 | else: 614 | pad = self.batch_size - self.num_data + self.cursor 615 | res = [(np.concatenate((self._reshape_data(x[1][self._expand_index(self._index[self.cursor:])]), 616 | self._reshape_data(x[1][self._expand_index(self._index[:pad])])), axis=0)) for x in data_source] 617 | else: 618 | if self.cursor + self.batch_size <= self.num_data: 619 | res = [(x[1][self._index[self.cursor:self.cursor+self.batch_size]]) for x in data_source] 620 | else: 621 | pad = self.batch_size - self.num_data + self.cursor 622 | res = [(np.concatenate((x[1][self._index[self.cursor:]], x[1][self._index[:pad]]), axis=0)) for x in data_source] 623 | 624 | # if data_source is self.data: 625 | # for a in res: 626 | # assert np.all(np.isfinite(a)) and not np.all(a == 0) 627 | 628 | if data_source is self.data and self.fft: 629 | if not self.dual_stream: 630 | res = [self._get_fft(a.asnumpy() if isinstance(a, mx.nd.NDArray) else a) for a in res] 631 | else: 632 | res = res + [self._get_fft(a.asnumpy() if isinstance(a, mx.nd.NDArray) else a) for a in res] 633 | assert_equal(len(res), 2) 634 | 635 | if data_source is self.data and self.faug: 636 | res += [self.faug * self.random_state.randn(self.batch_size, 16)] 637 | 638 | res = [a if isinstance(a, mx.nd.NDArray) else mx.nd.array(a) for a in res] 639 | return res 640 | 641 | def _rand(self, smin, smax, shape): 642 | return (smax - smin) * self.random_state.rand(*shape) + smin 643 | 644 | def _do_shuffle(self): 645 | if not self.adabn or len(set(self._subject)) == 1: 646 | self.random_state.shuffle(self._index) 647 | else: 648 | batch_size = self.mini_batch_size 649 | # batch_size = self.batch_size 650 | # logger.info('AdaBN shuffle with a mini batch size of {}', batch_size) 651 | self.random_state.shuffle(self._index) 652 | subject_shuffled = self._subject[self._index] 653 | index_batch = [] 654 | for i in sorted(set(self._subject)): 655 | index = self._index[subject_shuffled == i] 656 | index = index[:len(index) // batch_size * batch_size] 657 | index_batch.append(index.reshape(-1, batch_size)) 658 | index_batch = np.vstack(index_batch) 659 | index = np.arange(len(index_batch)) 660 | self.random_state.shuffle(index) 661 | self._index = index_batch[index, :].ravel() 662 | # assert len(self._index) == len(set(self._index)) 663 | 664 | for i in range(0, len(self._subject), batch_size): 665 | # Make sure that the samples in one batch are from the same subject 666 | assert np.all(self._subject[self._index[i:i + batch_size - 1]] == 667 | self._subject[self._index[i + 1:i + batch_size]]) 668 | 669 | if batch_size != self.batch_size: 670 | assert self.batch_size % batch_size == 0 671 | # assert (self.batch_size // batch_size) % self.num_subject == 0 672 | self._index = self._index[:len(self._index) // self.batch_size * self.batch_size].reshape( 673 | -1, self.batch_size // batch_size, batch_size).transpose(0, 2, 1).ravel() 674 | 675 | def reset(self): 676 | self._reset() 677 | super(Data, self).reset() 678 | 679 | def _reset(self): 680 | # self._index.sort() 681 | self._index = np.copy(self._index_orig) 682 | 683 | if self.amplitude_weighting: 684 | assert np.all(self._index[:-1] < self._index[1:]) 685 | if not 
hasattr(self, 'amplitude_weight'): 686 | self.amplitude_weight = get_amplitude_weight( 687 | self.data[0][1], self._segment, self.framerate) 688 | if self.shuffle: 689 | random_state = self.random_state 690 | else: 691 | random_state = np.random.RandomState(677) 692 | self._index = random_state.choice( 693 | self._index, len(self._index), p=self.amplitude_weight) 694 | if self.amplitude_weighting_sort: 695 | logger.debug('Amplitude weighting sort') 696 | self._index.sort() 697 | 698 | if self.downsample: 699 | samples = np.arange(len(self._index)) 700 | np.random.RandomState(667).shuffle(samples) 701 | assert self.downsample > 0 and self.downsample <= 1 702 | samples = samples[:int(np.round(len(samples) * self.downsample))] 703 | assert len(samples) > 0 704 | self._index = self._index[samples] 705 | 706 | if self.balance_gesture: 707 | num_sample_per_gesture = int(np.round(self.balance_gesture * 708 | len(self._index) / self.num_gesture)) 709 | choice = [] 710 | for gesture in set(self.gesture): 711 | mask = self._gesture[self._index] == gesture 712 | choice.append(self.random_state.choice(np.where(mask)[0], 713 | num_sample_per_gesture)) 714 | choice = np.hstack(choice) 715 | self._index = self._index[choice] 716 | 717 | if self.shuffle: 718 | self._do_shuffle() 719 | 720 | if self.random_shift_horizontal or self.random_shift_vertical or self.random_scale or self.random_bad_channel: 721 | data = [(k, a.copy()) for k, a in self.data_orig] 722 | if self.random_shift_horizontal or self.random_shift_vertical: 723 | logger.info('shift {} {} {}', 724 | self.random_shift_fill, 725 | self.random_shift_horizontal, 726 | self.random_shift_vertical) 727 | hss = self.random_state.choice(1 + 2 * self.random_shift_horizontal, 728 | len(data[0][1])) - self.random_shift_horizontal 729 | vss = self.random_state.choice(1 + 2 * self.random_shift_vertical, 730 | len(data[0][1])) - self.random_shift_vertical 731 | # data = [(k, np.array([np.roll(row, s, axis=1) for row, s in izip(a, shift)])) 732 | # for k, a in data] 733 | data = [(k, np.array([_shift(row, hs, vs, self.random_shift_fill) 734 | for row, hs, vs in izip(a, hss, vss)])) 735 | for k, a in data] 736 | if self.random_scale: 737 | s = self.random_scale 738 | ss = s / 4 739 | data = [ 740 | (k, a * 2 ** (self._rand(-s, s, (a.shape[0], 1, 1, 1)) + self._rand(-ss, ss, a.shape))) 741 | for k, a in data 742 | ] 743 | if self.random_bad_channel: 744 | mask = self.random_state.choice(2, len(data[0][1])) > 0 745 | if mask.sum(): 746 | ch = self.random_state.choice(np.prod(data[0][1].shape[2:]), mask.sum()) 747 | row = ch // data[0][1].shape[3] 748 | col = ch % data[0][1].shape[3] 749 | val = self.random_state.choice(self.random_bad_channel, mask.sum()) 750 | val = np.tile(val.reshape(-1, 1), (1, data[0][1].shape[1])) 751 | for k, a in data: 752 | a[mask, :, row, col] = val 753 | self.data = data 754 | 755 | self.num_data = len(self._index) 756 | 757 | 758 | def _shift(a, hs, vs, fill): 759 | if fill == 'zero': 760 | b = np.zeros(a.shape, dtype=a.dtype) 761 | elif fill == 'margin': 762 | b = np.empty(a.shape, dtype=a.dtype) 763 | else: 764 | assert False, 'Unknown fill type: {}'.format(fill) 765 | 766 | s = a.shape 767 | if hs < 0: 768 | shb, she = -hs, s[2] 769 | thb, the = 0, s[2] + hs 770 | else: 771 | shb, she = 0, s[2] - hs 772 | thb, the = hs, s[2] 773 | if vs < 0: 774 | svb, sve = -vs, s[1] 775 | tvb, tve = 0, s[1] + vs 776 | else: 777 | svb, sve = 0, s[1] - vs 778 | tvb, tve = vs, s[1] 779 | b[:, tvb:tve, thb:the] = a[:, svb:sve, shb:she] 780 | 781 |
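# A note on the fill modes handled below (describing the existing code): 'zero'
# leaves the vacated border at 0, while 'margin' replicates the nearest copied
# row/column outward (edge padding), so shifted frames carry no artificial zero
# edges. For example, _shift(a, hs=1, vs=0, fill='margin') moves each frame one
# electrode column toward higher indices and duplicates the adjacent column into
# the vacated column 0.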
if fill == 'margin': 782 | # Corners 783 | b[:, :tvb, :thb] = b[:, tvb, thb] 784 | b[:, tve:, :thb] = b[:, tve - 1, thb] 785 | b[:, tve:, the:] = b[:, tve - 1, the - 1] 786 | b[:, :tvb, the:] = b[:, tvb, the - 1] 787 | # Borders 788 | b[:, :tvb, thb:the] = b[:, tvb:tvb + 1, thb:the] 789 | b[:, tvb:tve, :thb] = b[:, tvb:tve, thb:thb + 1] 790 | b[:, tve:, thb:the] = b[:, tve - 1:tve, thb:the] 791 | b[:, tvb:tve, the:] = b[:, tvb:tve, the - 1:the] 792 | 793 | return b 794 | 795 | 796 | def _get_fft_aux(data, append): 797 | from ..fft import fft 798 | _, amp = fft(data.reshape(data.shape[0], -1).transpose(), 1000) 799 | amp = amp.transpose().reshape(-1, *data.shape[1:]) 800 | return amp if not append else np.concatenate([data, amp], axis=0) 801 | 802 | 803 | def get_amplitude_weight(data, segment, framerate): 804 | from .. import Context 805 | import joblib as jb 806 | indices = [np.where(segment == i)[0] for i in set(segment)] 807 | w = np.empty(len(segment), dtype=np.float64) 808 | for i, ret in zip( 809 | indices, 810 | Context.parallel(jb.delayed(get_amplitude_weight_aux)(data[i], framerate) 811 | for i in indices) 812 | ): 813 | w[i] = ret 814 | return w / max(w.sum(), 1e-8) 815 | 816 | 817 | def get_amplitude_weight_aux(data, framerate): 818 | return _get_amplitude_weight_aux(data, framerate) 819 | 820 | 821 | @utils.cached 822 | def _get_amplitude_weight_aux(data, framerate): 823 | # High-Density Electromyography and Motor Skill Learning for Robust Long-Term Control of a 7-DoF Robot Arm 824 | lowpass = utils.butter_lowpass_filter 825 | shape = data.shape 826 | data = np.abs(data.reshape(shape[0], -1)) 827 | data = np.transpose([lowpass(ch, 3, framerate, 4, zero_phase=True) for ch in data.T]) 828 | data = data.mean(axis=1) 829 | data -= data.min() 830 | data /= max(data.max(), 1e-8) 831 | return data 832 | 833 | 834 | from .preprocess import Preprocess 835 | from . import capgmyo, ninapro 836 | assert capgmyo and ninapro 837 | 838 | 839 | __all__ = ['Dataset', 'Preprocess', 'get_data'] 840 | -------------------------------------------------------------------------------- /sigr/data/capgmyo/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | from itertools import product 4 | import numpy as np 5 | import scipy.io as sio 6 | from logbook import Logger 7 | from ... import utils, CACHE 8 | from ..
import Dataset as Base, Combo, Trial, SingleSessionMixin 9 | 10 | 11 | TRIALS = list(range(1, 11)) 12 | NUM_TRIAL = len(TRIALS) 13 | NUM_SEMG_ROW = 16 14 | NUM_SEMG_COL = 8 15 | FRAMERATE = 1000 16 | PREPROCESS_KARGS = dict( 17 | framerate=FRAMERATE, 18 | num_semg_row=NUM_SEMG_ROW, 19 | num_semg_col=NUM_SEMG_COL 20 | ) 21 | 22 | logger = Logger(__name__) 23 | 24 | 25 | class GetTrial(object): 26 | 27 | def __init__(self, gestures, trials, preprocess=None): 28 | self.preprocess = preprocess 29 | self.memo = {} 30 | self.gesture_and_trials = list(product(gestures, trials)) 31 | 32 | def get_path(self, root, combo): 33 | return os.path.join( 34 | root, 35 | '{c.subject:03d}-{c.gesture:03d}-{c.trial:03d}.mat'.format(c=combo)) 36 | 37 | def __call__(self, root, combo): 38 | path = self.get_path(root, combo) 39 | if path not in self.memo: 40 | logger.debug('Load subject {}', combo.subject) 41 | paths = [self.get_path(root, Combo(combo.subject, gesture, trial)) 42 | for gesture, trial in self.gesture_and_trials] 43 | self.memo.update({path: data for path, data in 44 | zip(paths, _get_data(paths, self.preprocess))}) 45 | data = self.memo[path] 46 | data = data.copy() 47 | gesture = np.repeat(combo.gesture, len(data)) 48 | subject = np.repeat(combo.subject, len(data)) 49 | return Trial(data=data, gesture=gesture, subject=subject) 50 | 51 | 52 | @utils.cached 53 | def _get_data(paths, preprocess): 54 | # return list(Context.parallel( 55 | # jb.delayed(_get_data_aux)(path, preprocess) for path in paths)) 56 | return [_get_data_aux(path, preprocess) for path in paths] 57 | 58 | 59 | def _get_data_aux(path, preprocess): 60 | data = sio.loadmat(path)['data'].astype(np.float32) 61 | if preprocess: 62 | data = preprocess(data, **PREPROCESS_KARGS) 63 | return data 64 | 65 | 66 | class Dataset(SingleSessionMixin, Base): 67 | 68 | framerate = FRAMERATE 69 | num_semg_row = NUM_SEMG_ROW 70 | num_semg_col = NUM_SEMG_COL 71 | trials = TRIALS 72 | 73 | def __init__(self, root): 74 | self.root = root 75 | 76 | def get_trial_func(self, *args, **kargs): 77 | return GetTrial(*args, **kargs) 78 | 79 | @classmethod 80 | def parse(cls, text): 81 | if cls is not Dataset and text == cls.name: 82 | return cls(root=os.path.join(CACHE, cls.name.split('/')[0])) 83 | 84 | 85 | from . import dba, dbb, dbc 86 | assert dba and dbb and dbc 87 | -------------------------------------------------------------------------------- /sigr/data/capgmyo/dba.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'dba' 8 | subjects = list(range(1, 19)) 9 | gestures = list(range(1, 9)) 10 | -------------------------------------------------------------------------------- /sigr/data/capgmyo/dbb.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from functools import partial 3 | from itertools import product 4 | from logbook import Logger 5 | from . import Dataset as Base 6 | from .. import get_data 7 | from ... 
import constant 8 | 9 | 10 | logger = Logger(__name__) 11 | 12 | 13 | class Dataset(Base): 14 | 15 | name = 'dbb' 16 | subjects = list(range(2, 21, 2)) 17 | gestures = list(range(1, 9)) 18 | num_session = 2 19 | sessions = [1, 2] 20 | 21 | def get_universal_inter_session_data(self, fold, batch_size, preprocess, adabn, minibatch, balance_gesture, **kargs): 22 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 23 | load = partial(get_data, 24 | framerate=self.framerate, 25 | root=self.root, 26 | last_batch_handle='pad', 27 | get_trial=get_trial, 28 | batch_size=batch_size, 29 | num_semg_row=self.num_semg_row, 30 | num_semg_col=self.num_semg_col) 31 | session = fold + 1 32 | subjects = list(range(1, 11)) 33 | num_subject = 10 34 | train = load(combos=self.get_combos(product([self.encode_subject_and_session(s, i) for s, i in 35 | product(subjects, [i for i in self.sessions if i != session])], 36 | self.gestures, self.trials)), 37 | adabn=adabn, 38 | mini_batch_size=batch_size // (num_subject * (self.num_session - 1) if minibatch else 1), 39 | balance_gesture=balance_gesture, 40 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL), 41 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0), 42 | random_shift_vertical=kargs.get('random_shift_vertical', 0), 43 | shuffle=True) 44 | logger.debug('Training set loaded') 45 | val = load(combos=self.get_combos(product([self.encode_subject_and_session(s, session) for s in subjects], 46 | self.gestures, self.trials)), 47 | adabn=adabn, 48 | mini_batch_size=batch_size // (num_subject if minibatch else 1), 49 | shuffle=False) 50 | logger.debug('Test set loaded') 51 | return train, val 52 | 53 | def get_inter_session_data(self, fold, batch_size, preprocess, adabn, minibatch, balance_gesture, **kargs): 54 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 55 | load = partial(get_data, 56 | framerate=self.framerate, 57 | root=self.root, 58 | last_batch_handle='pad', 59 | get_trial=get_trial, 60 | batch_size=batch_size, 61 | num_semg_row=self.num_semg_row, 62 | num_semg_col=self.num_semg_col) 63 | subject = fold // self.num_session + 1 64 | session = fold % self.num_session + 1 65 | train = load(combos=self.get_combos(product([self.encode_subject_and_session(subject, i) for i in self.sessions if i != session], 66 | self.gestures, self.trials)), 67 | adabn=adabn, 68 | mini_batch_size=batch_size // (self.num_session - 1 if minibatch else 1), 69 | balance_gesture=balance_gesture, 70 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL), 71 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0), 72 | random_shift_vertical=kargs.get('random_shift_vertical', 0), 73 | shuffle=True) 74 | logger.debug('Training set loaded') 75 | val = load(combos=self.get_combos(product([self.encode_subject_and_session(subject, session)], 76 | self.gestures, self.trials)), 77 | shuffle=False) 78 | logger.debug('Test set loaded') 79 | return train, val 80 | 81 | def get_inter_session_val(self, fold, batch_size, preprocess=None, **kargs): 82 | get_trial = self.get_trial_func(self.gestures, self.trials, preprocess=preprocess) 83 | load = partial(get_data, 84 | framerate=self.framerate, 85 | root=self.root, 86 | last_batch_handle='pad', 87 | get_trial=get_trial, 88 | batch_size=batch_size, 89 | num_semg_row=self.num_semg_row, 90 | num_semg_col=self.num_semg_col) 91 | subject = fold // self.num_session + 1 92 | session = fold % self.num_session + 1 93 | 
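# Validation-only counterpart of get_inter_session_data above: only the held-out
# session of this subject is loaded, presumably so evaluation code can score an
# already-trained model without reloading the training session.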
val = load(combos=self.get_combos(product([self.encode_subject_and_session(subject, session)], 94 | self.gestures, self.trials)), 95 | shuffle=False) 96 | logger.debug('Test set loaded') 97 | return val 98 | 99 | def encode_subject_and_session(self, subject, session): 100 | return (subject - 1) * self.num_session + session 101 | -------------------------------------------------------------------------------- /sigr/data/capgmyo/dbc.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'dbc' 8 | subjects = list(range(1, 11)) 9 | gestures = list(range(1, 13)) 10 | -------------------------------------------------------------------------------- /sigr/data/csl.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import get_data, Combo, Trial 3 | from .. import ROOT, Context, constant 4 | import os 5 | from itertools import product 6 | from functools import partial 7 | import scipy.io as sio 8 | import numpy as np 9 | from logbook import Logger 10 | import joblib as jb 11 | from ..utils import cached 12 | from nose.tools import assert_is_not_none 13 | 14 | 15 | ROOT = os.path.join(ROOT, '.cache/csl') 16 | NUM_TRIAL = 10 17 | SUBJECTS = list(range(1, 6)) 18 | SESSIONS = list(range(1, 6)) 19 | NUM_SESSION = len(SESSIONS) 20 | NUM_SUBJECT = len(SUBJECTS) 21 | NUM_SUBJECT_AND_SESSION = len(SUBJECTS) * NUM_SESSION 22 | SUBJECT_AND_SESSIONS = list(range(1, NUM_SUBJECT_AND_SESSION + 1)) 23 | GESTURES = list(range(27)) 24 | REST_TRIALS = [x - 1 for x in [2, 4, 7, 8, 11, 13, 19, 25, 26, 30]] 25 | NUM_SEMG_ROW = 24 26 | NUM_SEMG_COL = 7 27 | FRAMERATE = 2048 28 | framerate = FRAMERATE 29 | TRIALS = list(range(NUM_TRIAL)) 30 | PREPROCESS_KARGS = dict( 31 | framerate=FRAMERATE, 32 | num_semg_row=NUM_SEMG_ROW, 33 | num_semg_col=NUM_SEMG_COL 34 | ) 35 | 36 | logger = Logger('csl') 37 | 38 | 39 | def get_general_data(batch_size, adabn, minibatch, downsample, **kargs): 40 | get_trial = GetTrial(downsample=downsample) 41 | load = partial(get_data, 42 | framerate=FRAMERATE, 43 | root=ROOT, 44 | last_batch_handle='pad', 45 | get_trial=get_trial, 46 | batch_size=batch_size, 47 | num_semg_row=NUM_SEMG_ROW, 48 | num_semg_col=NUM_SEMG_COL) 49 | train = load(combos=get_combos(product(SUBJECT_AND_SESSIONS, GESTURES[1:], range(0, NUM_TRIAL, 2)), 50 | product(SUBJECT_AND_SESSIONS, GESTURES[:1], REST_TRIALS[0::2])), 51 | adabn=adabn, 52 | shuffle=True, 53 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL), 54 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0), 55 | random_shift_vertical=kargs.get('random_shift_vertical', 0), 56 | mini_batch_size=batch_size // (NUM_SUBJECT_AND_SESSION if minibatch else 1)) 57 | logger.debug('Training set loaded') 58 | val = load(combos=get_combos(product(SUBJECT_AND_SESSIONS, GESTURES[1:], range(1, NUM_TRIAL, 2)), 59 | product(SUBJECT_AND_SESSIONS, GESTURES[:1], REST_TRIALS[1::2])), 60 | shuffle=False) 61 | logger.debug('Test set loaded') 62 | return train, val 63 | 64 | 65 | def get_intra_session_val(fold, batch_size, preprocess, **kargs): 66 | get_trial = GetTrial(preprocess=preprocess) 67 | load = partial(get_data, 68 | amplitude_weighting=kargs.get('amplitude_weighting', False), 69 | amplitude_weighting_sort=kargs.get('amplitude_weighting_sort', False), 70 | framerate=FRAMERATE, 71 | root=ROOT, 72 | 
last_batch_handle='pad', 73 | get_trial=get_trial, 74 | batch_size=batch_size, 75 | num_semg_row=NUM_SEMG_ROW, 76 | num_semg_col=NUM_SEMG_COL, 77 | random_state=np.random.RandomState(42)) 78 | subject = fold // (NUM_SESSION * NUM_TRIAL) + 1 79 | session = fold // NUM_TRIAL % NUM_SESSION + 1 80 | fold = fold % NUM_TRIAL 81 | val = load(combos=get_combos(product([encode_subject_and_session(subject, session)], 82 | GESTURES[1:], [fold]), 83 | product([encode_subject_and_session(subject, session)], 84 | GESTURES[:1], REST_TRIALS[fold:fold + 1])), 85 | shuffle=False) 86 | return val 87 | 88 | 89 | def get_universal_intra_session_data(fold, batch_size, preprocess, balance_gesture, **kargs): 90 | get_trial = GetTrial(preprocess=preprocess) 91 | load = partial(get_data, 92 | amplitude_weighting=kargs.get('amplitude_weighting', False), 93 | amplitude_weighting_sort=kargs.get('amplitude_weighting_sort', False), 94 | framerate=FRAMERATE, 95 | root=ROOT, 96 | last_batch_handle='pad', 97 | get_trial=get_trial, 98 | batch_size=batch_size, 99 | num_semg_row=NUM_SEMG_ROW, 100 | num_semg_col=NUM_SEMG_COL) 101 | trial = fold 102 | train = load(combos=get_combos(product(SUBJECT_AND_SESSIONS, 103 | GESTURES[1:], [i for i in range(NUM_TRIAL) if i != trial]), 104 | product(SUBJECT_AND_SESSIONS, 105 | GESTURES[:1], [REST_TRIALS[i] for i in range(NUM_TRIAL) if i != trial])), 106 | balance_gesture=balance_gesture, 107 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL), 108 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0), 109 | random_shift_vertical=kargs.get('random_shift_vertical', 0), 110 | shuffle=True) 111 | assert_is_not_none(train) 112 | logger.debug('Training set loaded') 113 | val = load(combos=get_combos(product(SUBJECT_AND_SESSIONS, 114 | GESTURES[1:], [trial]), 115 | product(SUBJECT_AND_SESSIONS, 116 | GESTURES[:1], REST_TRIALS[trial:trial + 1])), 117 | shuffle=False) 118 | logger.debug('Test set loaded') 119 | assert_is_not_none(val) 120 | return train, val 121 | 122 | 123 | def get_intra_session_data(fold, batch_size, preprocess, balance_gesture, **kargs): 124 | get_trial = GetTrial(preprocess=preprocess) 125 | load = partial(get_data, 126 | amplitude_weighting=kargs.get('amplitude_weighting', False), 127 | amplitude_weighting_sort=kargs.get('amplitude_weighting_sort', False), 128 | framerate=FRAMERATE, 129 | root=ROOT, 130 | last_batch_handle='pad', 131 | get_trial=get_trial, 132 | batch_size=batch_size, 133 | num_semg_row=NUM_SEMG_ROW, 134 | num_semg_col=NUM_SEMG_COL) 135 | subject = fold // (NUM_SESSION * NUM_TRIAL) + 1 136 | session = fold // NUM_TRIAL % NUM_SESSION + 1 137 | fold = fold % NUM_TRIAL 138 | train = load(combos=get_combos(product([encode_subject_and_session(subject, session)], 139 | GESTURES[1:], [f for f in range(NUM_TRIAL) if f != fold]), 140 | product([encode_subject_and_session(subject, session)], 141 | GESTURES[:1], [REST_TRIALS[f] for f in range(NUM_TRIAL) if f != fold])), 142 | balance_gesture=balance_gesture, 143 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL), 144 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0), 145 | random_shift_vertical=kargs.get('random_shift_vertical', 0), 146 | shuffle=True) 147 | assert_is_not_none(train) 148 | logger.debug('Training set loaded') 149 | val = load(combos=get_combos(product([encode_subject_and_session(subject, session)], 150 | GESTURES[1:], [fold]), 151 | product([encode_subject_and_session(subject, session)], 152 | GESTURES[:1], 
REST_TRIALS[fold:fold + 1])), 153 | shuffle=False) 154 | logger.debug('Test set loaded') 155 | assert_is_not_none(val) 156 | return train, val 157 | 158 | 159 | def get_inter_session_data(fold, batch_size, preprocess, adabn, minibatch, balance_gesture, **kargs): 160 | # TODO: calib 161 | get_trial = GetTrial(preprocess=preprocess) 162 | load = partial(get_data, 163 | framerate=FRAMERATE, 164 | root=ROOT, 165 | last_batch_handle='pad', 166 | get_trial=get_trial, 167 | batch_size=batch_size, 168 | num_semg_row=NUM_SEMG_ROW, 169 | num_semg_col=NUM_SEMG_COL) 170 | subject = fold // NUM_SESSION + 1 171 | session = fold % NUM_SESSION + 1 172 | train = load(combos=get_combos(product([encode_subject_and_session(subject, i) for i in SESSIONS if i != session], 173 | GESTURES[1:], TRIALS), 174 | product([encode_subject_and_session(subject, i) for i in SESSIONS if i != session], 175 | GESTURES[:1], REST_TRIALS)), 176 | adabn=adabn, 177 | mini_batch_size=batch_size // (NUM_SESSION - 1 if minibatch else 1), 178 | balance_gesture=balance_gesture, 179 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL), 180 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0), 181 | random_shift_vertical=kargs.get('random_shift_vertical', 0), 182 | shuffle=True) 183 | logger.debug('Training set loaded') 184 | val = load(combos=get_combos(product([encode_subject_and_session(subject, session)], 185 | GESTURES[1:], TRIALS), 186 | product([encode_subject_and_session(subject, session)], 187 | GESTURES[:1], REST_TRIALS)), 188 | shuffle=False) 189 | logger.debug('Test set loaded') 190 | return train, val 191 | 192 | 193 | def get_inter_session_val(fold, batch_size, preprocess, **kargs): 194 | # TODO: calib 195 | get_trial = GetTrial(preprocess=preprocess) 196 | load = partial(get_data, 197 | framerate=FRAMERATE, 198 | root=ROOT, 199 | last_batch_handle='pad', 200 | get_trial=get_trial, 201 | batch_size=batch_size, 202 | num_semg_row=NUM_SEMG_ROW, 203 | num_semg_col=NUM_SEMG_COL, 204 | random_state=np.random.RandomState(42)) 205 | subject = fold // NUM_SESSION + 1 206 | session = fold % NUM_SESSION + 1 207 | val = load(combos=get_combos(product([encode_subject_and_session(subject, session)], 208 | GESTURES[1:], TRIALS), 209 | product([encode_subject_and_session(subject, session)], 210 | GESTURES[:1], REST_TRIALS)), 211 | shuffle=False) 212 | return val 213 | 214 | 215 | def get_universal_inter_session_data(fold, batch_size, preprocess, adabn, minibatch, balance_gesture, **kargs): 216 | # TODO: calib 217 | get_trial = GetTrial(preprocess=preprocess) 218 | load = partial(get_data, 219 | framerate=FRAMERATE, 220 | root=ROOT, 221 | last_batch_handle='pad', 222 | get_trial=get_trial, 223 | batch_size=batch_size, 224 | num_semg_row=NUM_SEMG_ROW, 225 | num_semg_col=NUM_SEMG_COL) 226 | session = fold + 1 227 | train = load(combos=get_combos(product([encode_subject_and_session(s, i) for s, i in 228 | product(SUBJECTS, [i for i in SESSIONS if i != session])], 229 | GESTURES[1:], TRIALS), 230 | product([encode_subject_and_session(s, i) for s, i in 231 | product(SUBJECTS, [i for i in SESSIONS if i != session])], 232 | GESTURES[:1], REST_TRIALS)), 233 | adabn=adabn, 234 | mini_batch_size=batch_size // (NUM_SUBJECT * (NUM_SESSION - 1) if minibatch else 1), 235 | balance_gesture=balance_gesture, 236 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL), 237 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0), 238 | 
random_shift_vertical=kargs.get('random_shift_vertical', 0), 239 | shuffle=True) 240 | logger.debug('Training set loaded') 241 | val = load(combos=get_combos(product([encode_subject_and_session(s, session) for s in SUBJECTS], 242 | GESTURES[1:], TRIALS), 243 | product([encode_subject_and_session(s, session) for s in SUBJECTS], 244 | GESTURES[:1], REST_TRIALS)), 245 | adabn=adabn, 246 | mini_batch_size=batch_size // (NUM_SUBJECT if minibatch else 1), 247 | shuffle=False) 248 | logger.debug('Test set loaded') 249 | return train, val 250 | 251 | 252 | def get_intra_subject_data(fold, batch_size, cut, bandstop, adabn, minibatch, **kargs): 253 | get_trial = GetTrial(cut=cut, bandstop=bandstop) 254 | load = partial(get_data, 255 | framerate=FRAMERATE, 256 | root=ROOT, 257 | last_batch_handle='pad', 258 | get_trial=get_trial, 259 | batch_size=batch_size, 260 | num_semg_row=NUM_SEMG_ROW, 261 | num_semg_col=NUM_SEMG_COL) 262 | subject = fold // NUM_TRIAL + 1 263 | fold = fold % NUM_TRIAL 264 | train = load(combos=get_combos(product([encode_subject_and_session(subject, session) for session in SESSIONS], 265 | GESTURES[1:], [f for f in range(NUM_TRIAL) if f != fold]), 266 | product([encode_subject_and_session(subject, session) for session in SESSIONS], 267 | GESTURES[:1], [REST_TRIALS[f] for f in range(NUM_TRIAL) if f != fold])), 268 | adabn=adabn, 269 | mini_batch_size=batch_size // (NUM_SESSION if minibatch else 1), 270 | random_shift_fill=kargs.get('random_shift_fill', constant.RANDOM_SHIFT_FILL), 271 | random_shift_horizontal=kargs.get('random_shift_horizontal', 0), 272 | random_shift_vertical=kargs.get('random_shift_vertical', 0), 273 | shuffle=True) 274 | logger.debug('Training set loaded') 275 | val = load(combos=get_combos(product([encode_subject_and_session(subject, session) for session in SESSIONS], 276 | GESTURES[1:], [fold]), 277 | product([encode_subject_and_session(subject, session) for session in SESSIONS], 278 | GESTURES[:1], REST_TRIALS[fold:fold + 1])), 279 | shuffle=False) 280 | logger.debug('Test set loaded') 281 | return train, val 282 | 283 | 284 | class GetTrial(object): 285 | 286 | def __init__(self, preprocess=None): 287 | self.preprocess = preprocess 288 | self.memo = {} 289 | 290 | def __call__(self, root, combo): 291 | subject, session = decode_subject_and_session(combo.subject) 292 | path = os.path.join(root, 293 | 'subject%d' % subject, 294 | 'session%d' % session, 295 | 'gest%d.mat' % combo.gesture) 296 | if path not in self.memo: 297 | data = _get_data(path, self.preprocess) 298 | self.memo[path] = data 299 | logger.debug('{}', path) 300 | else: 301 | data = self.memo[path] 302 | assert combo.trial < len(data), str(combo) 303 | data = data[combo.trial].copy() 304 | gesture = np.repeat(combo.gesture, len(data)) 305 | subject = np.repeat(combo.subject, len(data)) 306 | return Trial(data=data, gesture=gesture, subject=subject) 307 | 308 | 309 | @cached 310 | def _get_data(path, preprocess): 311 | data = sio.loadmat(path)['gestures'] 312 | data = [np.transpose(np.delete(segment.astype(np.float32), np.s_[7:192:8], 0)) 313 | for segment in data.flat] 314 | if preprocess: 315 | data = list(Context.parallel(jb.delayed(preprocess)(segment, **PREPROCESS_KARGS) 316 | for segment in data)) 317 | return data 318 | 319 | 320 | # @cached 321 | # def _get_data(path, bandstop, cut, downsample): 322 | # data = sio.loadmat(path)['gestures'] 323 | # data = [np.transpose(np.delete(segment.astype(np.float32), np.s_[7:192:8], 0)) 324 | # for segment in data.flat] 325 | # if bandstop: 326 | # 
data = list(Context.parallel(jb.delayed(get_bandstop)(segment) for segment in data)) 327 | # if cut is not None: 328 | # data = list(Context.parallel(jb.delayed(cut)(segment, framerate=FRAMERATE) for segment in data)) 329 | # if downsample > 1: 330 | # data = [segment[::downsample].copy() for segment in data] 331 | # return data 332 | 333 | 334 | def decode_subject_and_session(ss): 335 | return (ss - 1) // NUM_SESSION + 1, (ss - 1) % NUM_SESSION + 1 336 | 337 | 338 | def encode_subject_and_session(subject, session): 339 | return (subject - 1) * NUM_SESSION + session 340 | 341 | 342 | def get_bandstop(data): 343 | from ..utils import butter_bandstop_filter 344 | return np.array([butter_bandstop_filter(ch, 45, 55, 2048, 2) for ch in data]) 345 | 346 | 347 | def get_combos(*args): 348 | for arg in args: 349 | if isinstance(arg, tuple): 350 | arg = [arg] 351 | for a in arg: 352 | combo = Combo(*a) 353 | if ignore_missing(combo): 354 | continue 355 | yield combo 356 | 357 | 358 | def ignore_missing(combo): 359 | return combo.subject == 19 and combo.gesture in (8, 9) and combo.trial == 9 360 | -------------------------------------------------------------------------------- /sigr/data/ninapro/__init__.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | from itertools import product 4 | import numpy as np 5 | import scipy.io as sio 6 | from logbook import Logger 7 | from ... import utils, CACHE 8 | from .. import Dataset as Base, Combo, Trial, SingleSessionMixin 9 | 10 | 11 | NUM_SEMG_ROW = 1 12 | NUM_SEMG_COL = 10 13 | FRAMERATE = 100 14 | PREPROCESS_KARGS = dict( 15 | framerate=FRAMERATE, 16 | num_semg_row=NUM_SEMG_ROW, 17 | num_semg_col=NUM_SEMG_COL 18 | ) 19 | 20 | logger = Logger(__name__) 21 | 22 | 23 | class Dataset(SingleSessionMixin, Base): 24 | 25 | framerate = FRAMERATE 26 | num_semg_row = NUM_SEMG_ROW 27 | num_semg_col = NUM_SEMG_COL 28 | subjects = list(range(27)) 29 | gestures = list(range(53)) 30 | trials = list(range(10)) 31 | 32 | def __init__(self, root): 33 | self.root = root 34 | 35 | def get_one_fold_intra_subject_trials(self): 36 | return [0, 2, 3, 5, 7, 8, 9], [1, 4, 6] 37 | 38 | def get_trial_func(self, *args, **kargs): 39 | return GetTrial(*args, **kargs) 40 | 41 | @classmethod 42 | def parse(cls, text): 43 | if cls is not Dataset and text == cls.name: 44 | return cls(root=os.path.join(CACHE, cls.name.split('/')[0], 'data')) 45 | 46 | 47 | class GetTrial(object): 48 | 49 | def __init__(self, gestures, trials, preprocess=None): 50 | self.preprocess = preprocess 51 | self.memo = {} 52 | self.gesture_and_trials = list(product(gestures, trials)) 53 | 54 | def get_path(self, root, combo): 55 | return os.path.join( 56 | root, 57 | '{c.subject:03d}', 58 | '{c.gesture:03d}', 59 | '{c.subject:03d}_{c.gesture:03d}_{c.trial:03d}.mat').format(c=combo) 60 | 61 | def __call__(self, root, combo): 62 | path = self.get_path(root, combo) 63 | if path not in self.memo: 64 | logger.debug('Load subject {}', combo.subject) 65 | paths = [self.get_path(root, Combo(combo.subject, gesture, trial)) 66 | for gesture, trial in self.gesture_and_trials] 67 | self.memo.update({path: data for path, data in 68 | zip(paths, _get_data(paths, self.preprocess))}) 69 | data = self.memo[path] 70 | data = data.copy() 71 | gesture = np.repeat(combo.gesture, len(data)) 72 | subject = np.repeat(combo.subject, len(data)) 73 | return Trial(data=data, gesture=gesture, subject=subject) 74 | 75 | 76 | @utils.cached 77 | def 
_get_data(paths, preprocess): 78 | # return list(Context.parallel( 79 | # jb.delayed(_get_data_aux)(path, preprocess) for path in paths)) 80 | return [_get_data_aux(path, preprocess) for path in paths] 81 | 82 | 83 | def _get_data_aux(path, preprocess): 84 | data = sio.loadmat(path)['data'].astype(np.float32) 85 | if preprocess: 86 | data = preprocess(data, **PREPROCESS_KARGS) 87 | return data 88 | 89 | 90 | from . import db1, db1_g53, db1_g5, db1_g8, db1_g12, caputo, db1_matlab_lowpass 91 | assert db1 and db1_g53 and db1_g5 and db1_g8 and db1_g12 and caputo and db1_matlab_lowpass 92 | -------------------------------------------------------------------------------- /sigr/data/ninapro/caputo.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'ninapro-db1/caputo' 8 | gestures = list(range(1, 53)) 9 | 10 | def get_one_fold_intra_subject_trials(self): 11 | return [i - 1 for i in [1, 3, 4, 5, 9]], [i - 1 for i in [2, 6, 7, 8, 10]] 12 | -------------------------------------------------------------------------------- /sigr/data/ninapro/db1.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'ninapro-db1' 8 | gestures = list(range(1, 53)) 9 | -------------------------------------------------------------------------------- /sigr/data/ninapro/db1_g12.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'ninapro-db1/g12' 8 | gestures = list(range(1, 13)) 9 | -------------------------------------------------------------------------------- /sigr/data/ninapro/db1_g5.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'ninapro-db1/g5' 8 | gestures = list(range(25, 30)) 9 | -------------------------------------------------------------------------------- /sigr/data/ninapro/db1_g53.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'ninapro-db1/g53' 8 | gestures = list(range(0, 53)) 9 | -------------------------------------------------------------------------------- /sigr/data/ninapro/db1_g8.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . import Dataset as Base 3 | 4 | 5 | class Dataset(Base): 6 | 7 | name = 'ninapro-db1/g8' 8 | gestures = list(range(13, 21)) 9 | -------------------------------------------------------------------------------- /sigr/data/ninapro/db1_matlab_lowpass.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from . 
import Dataset as Base 3 | 4 | 5 | class Database(Base): 6 | 7 | name = 'ninapro-db1-matlab-lowpass' 8 | gestures = list(range(1, 53)) 9 | -------------------------------------------------------------------------------- /sigr/data/preprocess.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import re 3 | import numpy as np 4 | from nose.tools import assert_less_equal 5 | from ..utils import cached, butter_lowpass_filter as lowpass 6 | from scipy.ndimage.filters import median_filter 7 | 8 | 9 | class Preprocess(object): 10 | 11 | class __metaclass__(type): 12 | 13 | def __init__(cls, name, bases, fields): 14 | type.__init__(cls, name, bases, fields) 15 | if name != 'Preprocess': 16 | Preprocess.register(cls) 17 | 18 | impls = [] 19 | 20 | def __call__(self, data, **kargs): 21 | return data 22 | 23 | @classmethod 24 | def parse(cls, text): 25 | if not text: 26 | return None 27 | if cls is Preprocess: 28 | for impl in cls.impls: 29 | inst = impl.parse(text) 30 | if inst is not None: 31 | return inst 32 | 33 | @classmethod 34 | def register(cls, impl): 35 | cls.impls.append(impl) 36 | 37 | 38 | class Sequence(Preprocess): 39 | 40 | @classmethod 41 | def parse(cls, text): 42 | matched = re.search('\((.+)\)', text) 43 | if matched: 44 | return cls([Preprocess.parse(stage) for stage 45 | in matched.group(1).split(',')]) 46 | 47 | def __init__(self, stages): 48 | self.stages = stages 49 | 50 | def __call__(self, data, **kargs): 51 | for stage in self.stages: 52 | data = stage(data, **kargs) 53 | return data 54 | 55 | def __repr__(self): 56 | return 'Sequence(%s)' % ','.join(str(stage) for stage in self.stages) 57 | 58 | 59 | class Bandstop(Preprocess): 60 | 61 | @classmethod 62 | def parse(cls, text): 63 | if re.search('bandstop', text): 64 | return cls() 65 | 66 | def __call__(self, data, framerate, **kargs): 67 | from ..utils import butter_bandstop_filter as bandstop 68 | return np.transpose([bandstop(ch, 45, 55, framerate, 2) for ch in data.T]) 69 | 70 | def __repr__(self): 71 | return 'Bandstop()' 72 | 73 | 74 | class CSLBandpass(Preprocess): 75 | 76 | @classmethod 77 | def parse(cls, text): 78 | if re.search('csl-bandpass', text): 79 | return cls() 80 | 81 | def __call__(self, data, framerate, **kargs): 82 | from ..utils import butter_bandpass_filter as bandpass 83 | return np.transpose([bandpass(ch, 20, 400, framerate, 4) for ch in data.T]) 84 | 85 | def __repr__(self): 86 | return 'CSLBandpass()' 87 | 88 | 89 | class NinaProLowpass(Preprocess): 90 | 91 | @classmethod 92 | def parse(cls, text): 93 | if re.search('ninapro-lowpass', text): 94 | return cls() 95 | 96 | def __call__(self, data, framerate, **kargs): 97 | return np.transpose([lowpass(ch, 1, framerate, 1, zero_phase=True) for ch in data.T]) 98 | 99 | def __repr__(self): 100 | return 'NinaProLowpass()' 101 | 102 | 103 | class Downsample(Preprocess): 104 | 105 | @classmethod 106 | def parse(cls, text): 107 | matched = re.search('downsample-(\d+)', text) 108 | if matched: 109 | return cls(int(matched.group(1))) 110 | 111 | def __init__(self, step): 112 | self.step = step 113 | 114 | def __call__(self, data, **kargs): 115 | return data[::self.step].copy() 116 | 117 | def __repr__(self): 118 | return 'Downsample(step=%d)' % self.step 119 | 120 | 121 | class Median3x3(Preprocess): 122 | 123 | @classmethod 124 | def parse(cls, text): 125 | if re.search('median', text): 126 | return cls() 127 | 128 | def __call__(self, data, num_semg_row, num_semg_col, **kargs): 129 
| return np.array([median_filter(image, 3).ravel() for image 130 | in data.reshape(-1, num_semg_row, num_semg_col)]) 131 | 132 | def __repr__(self): 133 | return 'Median3x3()' 134 | 135 | 136 | class Abs(Preprocess): 137 | 138 | @classmethod 139 | def parse(cls, text): 140 | if re.search('abs', text): 141 | return cls() 142 | 143 | def __call__(self, data, **kargs): 144 | return np.abs(data) 145 | 146 | def __repr__(self): 147 | return 'Abs()' 148 | 149 | 150 | class RMS(Preprocess): 151 | 152 | @classmethod 153 | def parse(cls, text): 154 | matched = re.search('rms-(\d+)', text) 155 | if matched: 156 | return cls(int(matched.group(1))) 157 | 158 | def __init__(self, window): 159 | self.window = window 160 | 161 | def __call__(self, data, **kargs): 162 | window = min(self.window, len(data)) 163 | return np.transpose([moving_rms(ch, window) for ch in data.T]) 164 | 165 | def __repr__(self): 166 | return 'RMS(window=%d)' % self.window 167 | 168 | 169 | class Cut(Preprocess): 170 | pass 171 | 172 | 173 | class MiddleCut(Cut): 174 | 175 | @classmethod 176 | def parse(cls, text): 177 | matched = re.search('mid-(\d+)', text) 178 | if matched: 179 | return cls(int(matched.group(1))) 180 | 181 | def __init__(self, window): 182 | self.window = window 183 | 184 | def __call__(self, data, **kargs): 185 | if len(data) < self.window: 186 | return data 187 | begin = (len(data) - self.window) // 2 188 | return data[begin:begin + self.window].copy() 189 | 190 | def __repr__(self): 191 | return 'MiddleCut(window=%d)' % self.window 192 | 193 | 194 | class PeakCut(Cut): 195 | 196 | @classmethod 197 | def parse(cls, text): 198 | matched = re.search('^peak-(\d+)$', text) 199 | if matched: 200 | return cls(int(matched.group(1))) 201 | 202 | def __init__(self, window): 203 | self.window = window 204 | 205 | def __call__(self, data, framerate, num_semg_row, num_semg_col, **kargs): 206 | if len(data) < self.window: 207 | return data 208 | 209 | begin = np.argmax(_get_amp(data, framerate, num_semg_row, num_semg_col) 210 | [self.window // 2:-(self.window - self.window // 2 - 1)]) 211 | assert_less_equal(begin + self.window, len(data)) 212 | return data[begin:begin + self.window] 213 | 214 | def __repr__(self): 215 | return 'PeakCut(window=%d)' % self.window 216 | 217 | 218 | class NinaProPeakCut(Cut): 219 | 220 | @classmethod 221 | def parse(cls, text): 222 | matched = re.search('^ninapro-peak-(\d+)$', text) 223 | if matched: 224 | return cls(int(matched.group(1))) 225 | 226 | def __init__(self, window): 227 | self.window = window 228 | 229 | def __call__(self, data, framerate, **kargs): 230 | if len(data) < self.window: 231 | return data 232 | 233 | begin = np.argmax(_get_ninapro_amp(data, framerate) 234 | [self.window // 2:-(self.window - self.window // 2 - 1)]) 235 | assert_less_equal(begin + self.window, len(data)) 236 | return data[begin:begin + self.window] 237 | 238 | def __repr__(self): 239 | return 'NinaProPeakCut(window=%d)' % self.window 240 | 241 | 242 | class CSLCut(Cut): 243 | 244 | @classmethod 245 | def parse(cls, text): 246 | if re.search('csl-cut', text): 247 | return cls() 248 | 249 | def __call__(self, data, framerate, **kargs): 250 | begin, end = _csl_cut(data, framerate) 251 | return data[begin:end] 252 | 253 | def __repr__(self): 254 | return 'CSLCut()' 255 | 256 | 257 | def _csl_cut(data, framerate): 258 | window = int(np.round(150 * framerate / 2048)) 259 | data = data[:len(data) // window * window].reshape(-1, 150, data.shape[1]) 260 | rms = np.sqrt(np.mean(np.square(data), axis=1)) 261 | 
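# The remaining steps implement the CSL activity segmentation: the per-window
# RMS image (computed above from 150-sample windows, ~73 ms at 2048 Hz) is
# median-filtered over the 24x7 electrode grid, averaged to one value per
# window, and thresholded at its mean over the trial; single-window gaps are
# bridged and the longest active run is returned as a (begin, end) range in
# samples.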
rms = [median_filter(image, 3).ravel() for image in rms.reshape(-1, 24, 7)] 262 | rms = np.mean(rms, axis=1) 263 | threshold = np.mean(rms) 264 | mask = rms > threshold 265 | for i in range(1, len(mask) - 1): 266 | if not mask[i] and mask[i - 1] and mask[i + 1]: 267 | mask[i] = True 268 | from .. import utils 269 | begin, end = max(utils.continuous_segments(mask), 270 | key=lambda s: (mask[s[0]], s[1] - s[0])) 271 | return begin * window, end * window 272 | 273 | 274 | @cached 275 | def _get_amp(data, framerate, num_semg_row, num_semg_col): 276 | data = np.abs(data) 277 | data = np.transpose([lowpass(ch, 2, framerate, 4, zero_phase=True) for ch in data.T]) 278 | return [median_filter(image, 3).mean() for image in data.reshape(-1, num_semg_row, num_semg_col)] 279 | 280 | 281 | def _get_ninapro_amp(data, framerate): 282 | data = np.abs(data) 283 | data = np.transpose([lowpass(ch, 2, framerate, 4, zero_phase=True) for ch in data.T]) 284 | return data.mean(axis=1) 285 | 286 | 287 | def moving_rms(a, window): 288 | a2 = np.square(a) 289 | window = np.ones(window) / window 290 | return np.sqrt(np.convolve(a2, window, 'valid')) 291 | -------------------------------------------------------------------------------- /sigr/data/s21.py: -------------------------------------------------------------------------------- 1 | from itertools import product, starmap 2 | from . import get_data, Combo 3 | from .. import ROOT 4 | import os 5 | import numpy as np 6 | 7 | 8 | ROOT = os.path.join(ROOT, '.cache/mat.s21.bandstop-45-55.s1000m.scale-01') 9 | 10 | 11 | def get_coral(folds, batch_size): 12 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20] 13 | return get_data( 14 | root=ROOT, 15 | # combos=get_combos(product([subjects[fold] for fold in folds], [100, 101], [0])), 16 | combos=get_combos(product([subjects[fold] for fold in folds], range(1, 9), [0])), 17 | mean=0.5, 18 | scale=2, 19 | batch_size=2000, 20 | last_batch_handle='pad', 21 | shuffle=False, 22 | adabn=True 23 | ) 24 | 25 | 26 | def get_combos(prods): 27 | return list(starmap(Combo, prods)) 28 | 29 | 30 | def get_stats(): 31 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20] 32 | load = lambda subject: get_data( 33 | root=ROOT, 34 | combos=get_combos(product([subject], range(1, 9), range(10))), 35 | mean=0.5, 36 | scale=2, 37 | batch_size=1000, 38 | last_batch_handle='roll_over' 39 | ) 40 | stats = [] 41 | for subject in subjects: 42 | batch = next(load(subject)[0]) 43 | data = batch.data[0].asnumpy() 44 | stats.append({ 45 | 'std': data.std() 46 | }) 47 | import pandas as pd 48 | return pd.DataFrame(stats, index=range(10)) 49 | 50 | 51 | def get_general_data(root, batch_size, with_subject): 52 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20] 53 | load = lambda **kargs: get_data( 54 | root=root, 55 | mean=0.5, 56 | scale=2, 57 | with_subject=with_subject, 58 | batch_size=batch_size, 59 | last_batch_handle='roll_over', 60 | **kargs 61 | ) 62 | val, num_val = load(combos=get_combos(product(subjects, range(1, 9), range(1, 10, 2)))) 63 | train, num_train = load(combos=get_combos(product(subjects, range(1, 9), range(0, 10, 2)))) 64 | return train, val, num_train, num_val 65 | 66 | 67 | def get_inter_subject_data( 68 | root, 69 | fold, 70 | batch_size, 71 | maxforce, 72 | target_binary, 73 | calib, 74 | with_subject, 75 | with_target_gesture, 76 | random_scale, 77 | random_bad_channel, 78 | shuffle, 79 | adabn, 80 | window, 81 | only_calib, 82 | soft_label, 83 | minibatch, 84 | fft, 85 | fft_append, 86 | dual_stream, 87 | lstm, 88 | dense_window, 89 | 
lstm_window 90 | ): 91 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20] 92 | 93 | num_subject = 10 if maxforce or calib else 9 94 | if minibatch: 95 | assert batch_size % num_subject == 0, '%d %% %d' % (batch_size, num_subject) 96 | mini_batch_size = batch_size // num_subject 97 | else: 98 | mini_batch_size = batch_size 99 | 100 | load = lambda **kargs: get_data( 101 | root=root, 102 | mean=0.5, 103 | scale=2, 104 | with_subject=with_subject, 105 | target_binary=target_binary, 106 | batch_size=batch_size, 107 | with_target_gesture=with_target_gesture, 108 | fft=fft, 109 | fft_append=fft_append, 110 | dual_stream=dual_stream, 111 | **kargs 112 | ) 113 | val_subject = subjects[fold] 114 | del subjects[fold] 115 | val = load( 116 | combos=get_combos(product([val_subject], range(1, 9), range(1, 10) if calib else range(10))), 117 | last_batch_handle='pad', 118 | shuffle=False, 119 | window=(window // (lstm_window or window)) if lstm else window, 120 | num_ignore_per_segment=window - 1 if lstm else 0, 121 | dense_window=dense_window 122 | ) 123 | 124 | if maxforce and calib: 125 | target_combos = get_combos(product([val_subject], list(range(1, 9)) * 10 + [100, 101], [0] * (9 if target_binary else 1))) 126 | elif maxforce: 127 | target_combos = get_combos(product([val_subject], [100, 101], [0] * 41 * (9 if target_binary else 1))) 128 | elif only_calib: 129 | target_combos = get_combos(product([val_subject], list(range(1, 9)), [0])) 130 | elif calib: 131 | target_combos = get_combos(product([val_subject], list(range(1, 9)) * 10, [0] * (9 if target_binary else 1))) 132 | else: 133 | target_combos = None 134 | 135 | if only_calib: 136 | combos = [] 137 | else: 138 | combos = get_combos(product(subjects, range(1, 9), range(10))) 139 | if maxforce: 140 | combos += get_combos(product(subjects, [100, 101], [0])) 141 | 142 | if soft_label: 143 | import pandas as pd 144 | soft_label = pd.DataFrame.from_csv(os.path.join(os.path.dirname(__file__), 's21_soft_label.scv')) 145 | 146 | train = load( 147 | combos=combos, 148 | target_combos=target_combos, 149 | random_scale=random_scale, 150 | random_bad_channel=random_bad_channel, 151 | last_batch_handle='pad', 152 | shuffle=shuffle, 153 | mini_batch_size=mini_batch_size, 154 | soft_label=False if soft_label is False else soft_label[soft_label['fold'] == fold][[str(i) for i in range(8)]].as_matrix(), 155 | adabn=adabn, 156 | window=window, 157 | dense_window=dense_window 158 | ) 159 | return train, val 160 | 161 | 162 | def get_inter_subject_val(fold, batch_size, calib, **kargs): 163 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20] 164 | return get_data( 165 | combos=get_combos(product([subjects[fold]], range(1, 9), range(1, 10) if calib else range(10))), 166 | root=ROOT, 167 | mean=0.5, 168 | scale=2, 169 | batch_size=batch_size, 170 | last_batch_handle='pad', 171 | shuffle=False, 172 | random_state=np.random.RandomState(42), 173 | **kargs 174 | ) 175 | 176 | 177 | def get_inter_subject_train(fold, batch_size): 178 | subjects = [1, 3, 6, 8, 10, 12, 14, 16, 18, 20] 179 | return get_data( 180 | combos=get_combos(product([subjects[i] for i in range(10) if i != fold], range(1, 9), range(10))), 181 | root=ROOT, 182 | mean=0.5, 183 | scale=2, 184 | batch_size=batch_size, 185 | last_batch_handle='pad', 186 | shuffle=False 187 | ) 188 | -------------------------------------------------------------------------------- /sigr/data/s21_soft_label.scv: -------------------------------------------------------------------------------- 1 | ,0,1,2,3,4,5,6,7,fold 2 | 
0,0.585914433002,0.0113508105278,0.0469612777233,0.0103245843202,0.0297896899283,0.213723227382,0.0498025380075,0.0520934313536,0 3 | 1,0.0177394840866,0.260400772095,0.0862425193191,0.217843294144,0.0818059891462,0.0307203009725,0.155910596251,0.149291470647,0 4 | 2,0.0565337799489,0.0951329991221,0.290711790323,0.0840612798929,0.18585036695,0.0402426086366,0.122267112136,0.125134795904,0 5 | 3,0.0208049118519,0.22623719275,0.0898522436619,0.244077846408,0.120171234012,0.0297798700631,0.127505093813,0.141515702009,0 6 | 4,0.0588904172182,0.0689897313714,0.238068774343,0.0831308886409,0.332722753286,0.038050621748,0.0563105903566,0.123781606555,0 7 | 5,0.375521868467,0.0143311182037,0.0225172676146,0.00740715721622,0.0129349566996,0.436831176281,0.100549437106,0.0298644062132,0 8 | 6,0.0699726864696,0.116337053478,0.117743805051,0.0787186548114,0.0549033097923,0.118118651211,0.306120842695,0.138010010123,0 9 | 7,0.0656202509999,0.132219433784,0.132582515478,0.126813665032,0.14836679399,0.0508881472051,0.15999814868,0.183442443609,0 10 | 8,0.573849737644,0.0113045303151,0.0458538196981,0.0101971579716,0.0256233215332,0.230556309223,0.0508873090148,0.0516888573766,1 11 | 9,0.0176304485649,0.238844901323,0.0921256542206,0.208853840828,0.100514553487,0.0301328171045,0.161045968533,0.150802791119,1 12 | 10,0.0456261076033,0.0941201895475,0.301260918379,0.0844209119678,0.202144488692,0.035298217088,0.11817112565,0.118890993297,1 13 | 11,0.0206406675279,0.207882523537,0.0887452363968,0.239738628268,0.123213484883,0.0303926169872,0.147305071354,0.142024502158,1 14 | 12,0.0543091744184,0.0716122165322,0.208353236318,0.0926068499684,0.354343175888,0.0366435796022,0.0551674477756,0.126913920045,1 15 | 13,0.390921235085,0.0146044613793,0.0273387394845,0.00665758550167,0.0111129777506,0.439626246691,0.086997166276,0.022702537477,1 16 | 14,0.0779526233673,0.114770486951,0.122850477695,0.0789507627487,0.0550688132644,0.135418519378,0.289047718048,0.125866964459,1 17 | 15,0.0544591732323,0.142568826675,0.129290759563,0.14990568161,0.124165035784,0.0469007156789,0.163525983691,0.189116105437,1 18 | 16,0.594108045101,0.00999377202243,0.0427821725607,0.00983097590506,0.0300427172333,0.21718133986,0.0437685139477,0.0522559508681,2 19 | 17,0.0147701213136,0.252291023731,0.086970448494,0.229896858335,0.0909300968051,0.0276916641742,0.151407673955,0.145995393395,2 20 | 18,0.0548012703657,0.0855624973774,0.321182370186,0.0757946372032,0.200975820422,0.0397630445659,0.109111316502,0.112741105258,2 21 | 19,0.0160632822663,0.23853699863,0.0873838663101,0.254317045212,0.0998763814569,0.0285355579108,0.140727058053,0.134509548545,2 22 | 20,0.0488464124501,0.071256428957,0.250549972057,0.0880576819181,0.333701938391,0.0352295488119,0.0502944551408,0.122012011707,2 23 | 21,0.368753939867,0.0141330743209,0.0257930252701,0.00734513904899,0.0160015933216,0.426667273045,0.109375782311,0.0318860970438,2 24 | 22,0.0769981369376,0.0936448574066,0.119277991354,0.0723595842719,0.0508595369756,0.158597052097,0.304699063301,0.123484656215,2 25 | 23,0.0539519712329,0.140293493867,0.141128987074,0.132432863116,0.142665907741,0.0458317175508,0.16619721055,0.177431449294,2 26 | 24,0.5591365695,0.0114038847387,0.0458650290966,0.0102882077917,0.030391799286,0.239412352443,0.0506078414619,0.0528544560075,3 27 | 25,0.0177513454109,0.238689228892,0.100602254272,0.223592862487,0.101336151361,0.0273324083537,0.143142953515,0.147503301501,3 28 | 
26,0.0509831905365,0.0920915007591,0.306010752916,0.0830404087901,0.199058055878,0.0363924279809,0.116912446916,0.115444153547,3 29 | 27,0.0209863614291,0.230153664947,0.0899317339063,0.233842685819,0.108670607209,0.0305126570165,0.147977411747,0.13786932826,3 30 | 28,0.0588953457773,0.0688157305121,0.240384683013,0.0937875658274,0.307589739561,0.038122843951,0.0595138818026,0.132834300399,3 31 | 29,0.34596735239,0.015977114439,0.0254017412663,0.00788462907076,0.0162124875933,0.443999558687,0.11173632741,0.0327737107873,3 32 | 30,0.0714117065072,0.113147959113,0.118221767247,0.0778153985739,0.0507748536766,0.149462670088,0.300601005554,0.11848885566,3 33 | 31,0.0650929734111,0.132121101022,0.133752852678,0.142429187894,0.129983246326,0.0497390404344,0.163882493973,0.182930201292,3 34 | 32,0.594033658504,0.00978412944824,0.0448851883411,0.00877121277153,0.0301306284964,0.217805102468,0.0442418269813,0.050309818238,4 35 | 33,0.0125299263746,0.260252594948,0.0838051810861,0.231062918901,0.0862522274256,0.0271735414863,0.147024214268,0.151855185628,4 36 | 34,0.0581892468035,0.0873780623078,0.314817845821,0.0790030509233,0.185462743044,0.041338711977,0.113018415868,0.120728157461,4 37 | 35,0.0210476107895,0.202843770385,0.089477263391,0.260362178087,0.116318069398,0.0304090902209,0.129763320088,0.149725064635,4 38 | 36,0.0577154792845,0.0685700327158,0.220444232225,0.0952667221427,0.335331767797,0.0371781699359,0.0557963885367,0.129641205072,4 39 | 37,0.350155651569,0.0145853841677,0.0262643638998,0.00586827797815,0.0150276897475,0.449112892151,0.10923538357,0.0297034103423,4 40 | 38,0.0629346594214,0.110558472574,0.113621123135,0.0723159685731,0.048894032836,0.153379887342,0.309411644936,0.128806352615,4 41 | 39,0.0646151080728,0.129451319575,0.131009638309,0.148646071553,0.128791987896,0.0500863455236,0.167172878981,0.180158615112,4 42 | 40,0.580878973007,0.0113963577896,0.0468598306179,0.0103343445808,0.0308610480279,0.219456076622,0.0469783619046,0.0531973131001,5 43 | 41,0.0178357362747,0.249909639359,0.0978484898806,0.218870550394,0.0988143011928,0.0303349476308,0.154420286417,0.131914392114,5 44 | 42,0.0603491105139,0.0798718780279,0.32292303443,0.0799543261528,0.213207960129,0.0414940938354,0.10795687139,0.0941794067621,5 45 | 43,0.0202580057085,0.226799473166,0.0851363390684,0.259252399206,0.0973977297544,0.0303331054747,0.152081489563,0.128690332174,5 46 | 44,0.0585519187152,0.0609935373068,0.246153384447,0.0856251418591,0.341543257236,0.0378966443241,0.0550683364272,0.114118672907,5 47 | 45,0.381124228239,0.0156234931201,0.0270085260272,0.00779242208228,0.0163024608046,0.412237107754,0.107251346111,0.0326209925115,5 48 | 46,0.0739379227161,0.10848467797,0.0726554319263,0.0708635002375,0.043027702719,0.160984665155,0.334530264139,0.135450929403,5 49 | 47,0.0556856766343,0.146907523274,0.128168582916,0.148990258574,0.123327106237,0.0457626357675,0.17281241715,0.178279042244,5 50 | 48,0.564357459545,0.0114045888186,0.0459825992584,0.0102657750249,0.0303347632289,0.232964366674,0.051675580442,0.0529765896499,6 51 | 49,0.0178332515061,0.243779942393,0.0952674150467,0.221842601895,0.0975822508335,0.0296472813934,0.143671065569,0.150326281786,6 52 | 50,0.0530138872564,0.0937199220061,0.303352326155,0.079468511045,0.18704906106,0.0393490642309,0.123327203095,0.12065808475,6 53 | 51,0.0210853293538,0.225017100573,0.0950988605618,0.219580188394,0.125873163342,0.0306841302663,0.138054862618,0.144554525614,6 54 | 
52,0.0484657548368,0.0729737207294,0.22917728126,0.0949310436845,0.33986890316,0.0320509634912,0.0581087581813,0.124372884631,6 55 | 53,0.357948482037,0.0155431739986,0.0275333896279,0.00730467308313,0.0162500869483,0.441128909588,0.103582292795,0.030662054196,6 56 | 54,0.0759399980307,0.0977061539888,0.108938999474,0.0703660771251,0.0512497872114,0.153688088059,0.313845336437,0.128190472722,6 57 | 55,0.0589478500187,0.145432949066,0.14507548511,0.136068463326,0.147310584784,0.0499232001603,0.152261927724,0.164912343025,6 58 | 56,0.60515910387,0.00981649104506,0.0380278304219,0.00969350151718,0.0275855381042,0.221200808883,0.0429150685668,0.0455675348639,7 59 | 57,0.0123052867129,0.264212399721,0.0848530232906,0.221839383245,0.0756716877222,0.0268339943141,0.153965786099,0.160277932882,7 60 | 58,0.0582231655717,0.0821022167802,0.329458266497,0.0659338235855,0.178078427911,0.0419608466327,0.120090350509,0.12409196049,7 61 | 59,0.0133151542395,0.223469093442,0.0751332044601,0.256989300251,0.117942109704,0.0219468865544,0.138660281897,0.15249787271,7 62 | 60,0.0427861995995,0.0686382204294,0.232943952084,0.0817427933216,0.359221041203,0.0316239818931,0.0524121262133,0.130583316088,7 63 | 61,0.39994981885,0.00861005764455,0.0213297940791,0.00698478939012,0.0145348263904,0.437821269035,0.0879363417625,0.0227968432009,7 64 | 62,0.0704860463738,0.100654803216,0.115903668106,0.0651987493038,0.0495304837823,0.147507175803,0.323221951723,0.127423748374,7 65 | 63,0.0522875934839,0.1378274858,0.133011072874,0.138035595417,0.145129650831,0.0384888760746,0.167901203036,0.187253862619,7 66 | 64,0.622319102287,0.00547091430053,0.0352219045162,0.0057081039995,0.0194392669946,0.246698439121,0.0414409972727,0.023681294173,8 67 | 65,0.0157551020384,0.265378654003,0.0883843973279,0.225549280643,0.0939034298062,0.0188524145633,0.136794626713,0.15533824265,8 68 | 66,0.0573872178793,0.0854949876666,0.3361107409,0.0717113688588,0.202803596854,0.0280384868383,0.0982796773314,0.120111130178,8 69 | 67,0.0179295912385,0.229140669107,0.0825557112694,0.260068267584,0.121700055897,0.0170418191701,0.133098810911,0.138416275382,8 70 | 68,0.0514223910868,0.0662609562278,0.24324285984,0.0915150269866,0.352055311203,0.0198261514306,0.0505900233984,0.125039443374,8 71 | 69,0.317708104849,0.0155081218109,0.0258607566357,0.00671677058563,0.0152378566563,0.472864151001,0.114921763539,0.031142629683,8 72 | 70,0.0631428137422,0.108594641089,0.101839073002,0.0715390816331,0.0431883595884,0.146862730384,0.332589954138,0.132170587778,8 73 | 71,0.0608781836927,0.137492239475,0.121944501996,0.143452703953,0.148526206613,0.0384837388992,0.165029957891,0.184127256274,8 74 | 72,0.591446340084,0.0108800856397,0.0416322499514,0.00790138915181,0.0240155700594,0.232482612133,0.0463338270783,0.0452779754996,9 75 | 73,0.0167807787657,0.242463931441,0.0944532901049,0.218917831779,0.0947048291564,0.0288841463625,0.153221786022,0.15052652359,9 76 | 74,0.0518501289189,0.0852868705988,0.313529372215,0.0755426958203,0.20850302279,0.040258616209,0.106483064592,0.118484579027,9 77 | 75,0.0191303230822,0.214621096849,0.0793804228306,0.251748353243,0.115605682135,0.029749520123,0.143106788397,0.146608382463,9 78 | 76,0.0519566200674,0.0685137063265,0.235343664885,0.0947128608823,0.343081116676,0.0370681136847,0.0540818944573,0.115194141865,9 79 | 77,0.359065055847,0.0153089584783,0.0246519688517,0.00715514039621,0.0132492622361,0.442372530699,0.1072749421,0.0308763105422,9 80 | 
78,0.0754193663597,0.0997704938054,0.1212580055,0.0742529407144,0.0518601499498,0.146741092205,0.307116210461,0.12350846082,9 81 | 79,0.0604774132371,0.143125548959,0.125325471163,0.141259744763,0.139200344682,0.0475078225136,0.166975483298,0.176064044237,9 82 | -------------------------------------------------------------------------------- /sigr/evaluation.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import os 3 | import numpy as np 4 | from functools import partial 5 | from .parse_log import parse_log 6 | from . import utils 7 | from . import module 8 | from logbook import Logger 9 | from copy import deepcopy 10 | import mxnet as mx 11 | 12 | 13 | Exp = utils.Bunch 14 | 15 | logger = Logger(__name__) 16 | 17 | 18 | @utils.cached(ignore=['context']) 19 | def _crossval_predict_aux(self, Mod, get_crossval_val, fold, context, dataset_args=None): 20 | Mod = deepcopy(Mod) 21 | Mod.update(context=context) 22 | mod = module.RuntimeModule(**Mod) 23 | Val = partial( 24 | get_crossval_val, 25 | fold=fold, 26 | batch_size=self.batch_size, 27 | window=mod.num_channel, 28 | **(dataset_args or {}) 29 | ) 30 | return mod.predict(utils.LazyProxy(Val)) 31 | 32 | 33 | @utils.cached(ignore=['context']) 34 | def _crossval_predict_proba_aux(self, Mod, get_crossval_val, fold, context, dataset_args=None): 35 | Mod = deepcopy(Mod) 36 | Mod.update(context=context) 37 | mod = module.RuntimeModule(**Mod) 38 | Val = partial( 39 | get_crossval_val, 40 | fold=fold, 41 | batch_size=self.batch_size, 42 | window=mod.num_channel, 43 | **(dataset_args or {}) 44 | ) 45 | return mod.predict_proba(utils.LazyProxy(Val)) 46 | 47 | 48 | def _crossval_predict(self, **kargs): 49 | proba = kargs.pop('proba', False) 50 | fold = int(kargs.pop('fold')) 51 | Mod = kargs.pop('Mod') 52 | Mod = deepcopy(Mod) 53 | Mod.update(params=self.format_params(Mod['params'], fold)) 54 | context = Mod.pop('context', [mx.gpu(0)]) 55 | # import pickle 56 | # d = kargs.copy() 57 | # d.update(Mod=Mod, fold=fold) 58 | # print(pickle.dumps(d)) 59 | 60 | # Ensure load from disk. 61 | # Otherwise following cached methods like vote will have two caches, 62 | # one for the first computation, 63 | # and the other for the cached one. 
64 | func = _crossval_predict_aux if not proba else _crossval_predict_proba_aux 65 | return func.call_and_shelve(self, Mod=Mod, fold=fold, context=context, **kargs).get() 66 | 67 | 68 | class Evaluation(object): 69 | 70 | def __init__(self, batch_size=None): 71 | self.batch_size = batch_size 72 | 73 | 74 | class CrossValEvaluation(Evaluation): 75 | 76 | def __init__(self, **kargs): 77 | self.crossval_type = kargs.pop('crossval_type') 78 | super(CrossValEvaluation, self).__init__(**kargs) 79 | 80 | def get_crossval_val_func(self, dataset): 81 | return getattr(dataset, 'get_%s_val' % self.crossval_type.replace('-', '_')) 82 | 83 | def format_params(self, params, fold): 84 | try: 85 | return params % fold 86 | except: 87 | return params 88 | 89 | def transform(self, Mod, dataset, fold, dataset_args=None): 90 | get_crossval_val = self.get_crossval_val_func(dataset) 91 | pred, true, _ = _crossval_predict( 92 | self, 93 | proba=True, 94 | Mod=Mod, 95 | get_crossval_val=get_crossval_val, 96 | fold=fold, 97 | dataset_args=dataset_args) 98 | return pred, true 99 | 100 | def accuracy_mod(self, Mod, dataset, fold, 101 | vote=False, 102 | dataset_args=None, 103 | balance=False): 104 | get_crossval_val = self.get_crossval_val_func(dataset) 105 | pred, true, segment = _crossval_predict( 106 | self, 107 | Mod=Mod, 108 | get_crossval_val=get_crossval_val, 109 | fold=fold, 110 | dataset_args=dataset_args) 111 | if vote: 112 | from .vote import vote as do 113 | return do(true, pred, segment, vote, balance) 114 | return (true == pred).sum() / true.size 115 | 116 | def accuracy_exp(self, exp, fold): 117 | if hasattr(exp, 'Mod') and hasattr(exp, 'dataset'): 118 | return self.accuracy_mod(Mod=exp.Mod, 119 | dataset=exp.dataset, 120 | fold=fold, 121 | vote=exp.get('vote', False), 122 | dataset_args=exp.get('dataset_args')) 123 | else: 124 | try: 125 | return parse_log(os.path.join(exp.root % fold, 'log')).val.iloc[-1] 126 | except: 127 | return np.nan 128 | 129 | def accuracy(self, **kargs): 130 | if 'exp' in kargs: 131 | return self.accuracy_exp(**kargs) 132 | elif 'Mod' in kargs: 133 | return self.accuracy_mod(**kargs) 134 | else: 135 | assert False 136 | 137 | def accuracies(self, exps, folds): 138 | acc = [] 139 | for exp in exps: 140 | for fold in folds: 141 | acc.append(self.accuracy(exp=exp, fold=fold)) 142 | return np.array(acc).reshape(len(exps), len(folds)) 143 | 144 | def compare(self, exps, fold): 145 | acc = [] 146 | for exp in exps: 147 | if hasattr(exp, 'Mod') and hasattr(exp, 'dataset'): 148 | acc.append(self.accuracy(Mod=exp.Mod, 149 | dataset=exp.dataset, 150 | fold=fold, 151 | vote=exp.get('vote', False), 152 | dataset_args=exp.get('dataset_args'))) 153 | else: 154 | try: 155 | acc.append(parse_log(os.path.join(exp.root % fold, 'log')).val.iloc[-1]) 156 | except: 157 | acc.append(np.nan) 158 | return acc 159 | 160 | def vote_accuracy_curves(self, exps, folds, windows, balance=False): 161 | acc = [] 162 | for exp in exps: 163 | for fold in folds: 164 | acc.append(self.vote_accuracy_curve( 165 | Mod=exp.Mod, 166 | dataset=exp.dataset, 167 | fold=int(fold), 168 | windows=windows, 169 | dataset_args=exp.get('dataset_args'), 170 | balance=balance)) 171 | return np.array(acc).reshape(len(exps), len(folds), len(windows)) 172 | 173 | def vote_accuracy_curve(self, Mod, dataset, fold, windows, 174 | dataset_args=None, 175 | balance=False): 176 | get_crossval_val = self.get_crossval_val_func(dataset) 177 | pred, true, segment = _crossval_predict( 178 | self, 179 | Mod=Mod, 180 | 
get_crossval_val=get_crossval_val, 181 | fold=fold, 182 | dataset_args=dataset_args) 183 | from .vote import get_vote_accuracy_curve as do 184 | return do(true, pred, segment, windows, balance)[1] 185 | 186 | 187 | def get_crossval_accuracies(crossval_type, exps, folds, batch_size=1000): 188 | acc = [] 189 | evaluation = CrossValEvaluation( 190 | crossval_type=crossval_type, 191 | batch_size=batch_size 192 | ) 193 | for fold in folds: 194 | acc.append(evaluation.compare(exps, fold)) 195 | return acc 196 | -------------------------------------------------------------------------------- /sigr/fft.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import numpy as np 3 | 4 | 5 | def fft(data, fs): 6 | n = data.shape[-1] 7 | window = np.hanning(n) 8 | windowed = data * window 9 | spectrum = np.fft.fft(windowed) 10 | freq = np.fft.fftfreq(n, 1 / fs) 11 | half_n = int(np.ceil(n / 2))  # cast to int so it can be used as a slice index 12 | spectrum_half = (2 / n) * spectrum[..., :half_n] 13 | freq_half = freq[:half_n] 14 | return freq_half, np.abs(spectrum_half) 15 | -------------------------------------------------------------------------------- /sigr/lstm.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | import mxnet as mx 3 | 4 | 5 | LSTMState = namedtuple("LSTMState", ["c", "h"]) 6 | # LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_bias", "h2h_weight", "h2h_bias"]) 7 | LSTMParam = namedtuple("LSTMParam", ["i2h_weight", "i2h_gamma", "h2h_weight", "h2h_gamma", 8 | "beta", "c_gamma", "c_beta"]) 9 | LSTMModel = namedtuple("LSTMModel", ["rnn_exec", "symbol", 10 | "init_states", "last_states", 11 | "seq_data", "seq_labels", "seq_outputs", 12 | "param_blocks"]) 13 | 14 | 15 | class LSTM(object): 16 | 17 | def lstm_orig(self, prefix, num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.): 18 | """LSTM Cell symbol""" 19 | if dropout > 0.: 20 | indata = mx.sym.Dropout(data=indata, p=dropout) 21 | i2h = mx.sym.FullyConnected(data=indata, 22 | weight=param.i2h_weight, 23 | bias=param.i2h_bias, 24 | num_hidden=num_hidden * 4, 25 | name=prefix + "t%d_l%d_i2h" % (seqidx, layeridx)) 26 | h2h = mx.sym.FullyConnected(data=prev_state.h, 27 | weight=param.h2h_weight, 28 | bias=param.h2h_bias, 29 | num_hidden=num_hidden * 4, 30 | name=prefix + "t%d_l%d_h2h" % (seqidx, layeridx)) 31 | gates = i2h + h2h 32 | slice_gates = mx.sym.SliceChannel(gates, num_outputs=4, 33 | name=prefix + "t%d_l%d_slice" % (seqidx, layeridx)) 34 | in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid") 35 | in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh") 36 | forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid") 37 | out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid") 38 | next_c = (forget_gate * prev_state.c) + (in_gate * in_transform) 39 | next_h = out_gate * mx.sym.Activation(next_c, act_type="tanh") 40 | return LSTMState(c=next_c, h=next_h) 41 | 42 | def lstm_not_share_beta_gamma(self, prefix, num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.): 43 | """LSTM Cell symbol""" 44 | if dropout > 0.: 45 | indata = mx.sym.Dropout(data=indata, p=dropout) 46 | i2h = mx.sym.FullyConnected(data=indata, 47 | weight=param.i2h_weight, 48 | bias=param.i2h_bias, 49 | num_hidden=num_hidden * 4, 50 | name=prefix + "t%d_l%d_i2h" % (seqidx, layeridx)) 51 | i2h = mx.sym.BatchNorm( 52 | name=prefix + "t%d_l%d_i2h_bn" % (seqidx, layeridx), 53 | data=i2h, 54 |
fix_gamma=False, 55 | momentum=0.9, 56 | attr={'wd_mult': '0'} 57 | ) 58 | h2h = mx.sym.FullyConnected(data=prev_state.h, 59 | weight=param.h2h_weight, 60 | bias=param.h2h_bias, 61 | num_hidden=num_hidden * 4, 62 | name=prefix + "t%d_l%d_h2h" % (seqidx, layeridx)) 63 | h2h = mx.sym.BatchNorm( 64 | name=prefix + "t%d_l%d_h2h_bn" % (seqidx, layeridx), 65 | data=h2h, 66 | fix_gamma=False, 67 | momentum=0.9, 68 | attr={'wd_mult': '0'} 69 | ) 70 | gates = i2h + h2h 71 | slice_gates = mx.sym.SliceChannel(gates, num_outputs=4, 72 | name=prefix + "t%d_l%d_slice" % (seqidx, layeridx)) 73 | in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid") 74 | in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh") 75 | forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid") 76 | out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid") 77 | next_c = (forget_gate * prev_state.c) + (in_gate * in_transform) 78 | next_h = out_gate * mx.sym.Activation( 79 | mx.symbol.BatchNorm( 80 | name=prefix + 't%d_l%d_c_bn' % (seqidx, layeridx), 81 | data=next_c, 82 | fix_gamma=False, 83 | momentum=0.9, 84 | attr={'wd_mult': '0'} 85 | ), 86 | act_type="tanh" 87 | ) 88 | return LSTMState(c=next_c, h=next_h) 89 | 90 | def lstm(self, prefix, num_hidden, indata, prev_state, param, seqidx, layeridx, dropout=0.): 91 | """LSTM Cell symbol""" 92 | if dropout > 0.: 93 | indata = mx.sym.Dropout(data=indata, p=dropout) 94 | i2h = mx.sym.FullyConnected(data=indata, 95 | weight=param.i2h_weight, 96 | num_hidden=num_hidden * 4, 97 | no_bias=True, 98 | name=prefix + "t%d_l%d_i2h" % (seqidx, layeridx)) 99 | i2h = self.BatchNorm( 100 | name=prefix + "t%d_l%d_i2h_bn" % (seqidx, layeridx), 101 | data=i2h, 102 | gamma=param.i2h_gamma, 103 | num_channel=num_hidden * 4 104 | ) 105 | h2h = mx.sym.FullyConnected(data=prev_state.h, 106 | weight=param.h2h_weight, 107 | num_hidden=num_hidden * 4, 108 | no_bias=True, 109 | name=prefix + "t%d_l%d_h2h" % (seqidx, layeridx)) 110 | h2h = self.BatchNorm( 111 | name=prefix + "t%d_l%d_h2h_bn" % (seqidx, layeridx), 112 | data=h2h, 113 | gamma=param.h2h_gamma, 114 | beta=param.beta, 115 | num_channel=num_hidden * 4 116 | ) 117 | gates = i2h + h2h 118 | slice_gates = mx.sym.SliceChannel(gates, num_outputs=4, 119 | name=prefix + "t%d_l%d_slice" % (seqidx, layeridx)) 120 | in_gate = mx.sym.Activation(slice_gates[0], act_type="sigmoid") 121 | in_transform = mx.sym.Activation(slice_gates[1], act_type="tanh") 122 | forget_gate = mx.sym.Activation(slice_gates[2], act_type="sigmoid") 123 | out_gate = mx.sym.Activation(slice_gates[3], act_type="sigmoid") 124 | next_c = (forget_gate * prev_state.c) + (in_gate * in_transform) 125 | next_h = out_gate * mx.sym.Activation( 126 | self.BatchNorm( 127 | name=prefix + 't%d_l%d_c_bn' % (seqidx, layeridx), 128 | data=next_c, 129 | gamma=param.c_gamma, 130 | beta=param.c_beta, 131 | num_channel=num_hidden 132 | ), 133 | act_type="tanh" 134 | ) 135 | return LSTMState(c=next_c, h=next_h) 136 | 137 | def BatchNorm(self, name, data, gamma, beta=None, **kargs): 138 | net = data 139 | 140 | if not self.bn: 141 | return net 142 | 143 | if self.minibatch: 144 | num_channel = kargs.pop('num_channel') 145 | net = mx.symbol.Reshape(net, shape=(-1, self.num_subject * num_channel)) 146 | net = mx.symbol.BatchNorm( 147 | name=name + '_norm', 148 | data=net, 149 | fix_gamma=True, 150 | momentum=0.9, 151 | attr={'wd_mult': '0', 'lr_mult': '0'} 152 | ) 153 | net = mx.symbol.Reshape(data=net, shape=(-1, num_channel)) 154 | else: 155 | net = mx.symbol.BatchNorm( 
156 | name=name + '_norm', 157 | data=net, 158 | fix_gamma=True, 159 | momentum=0.9, 160 | attr={'wd_mult': '0', 'lr_mult': '0'} 161 | ) 162 | net = mx.symbol.broadcast_mul(net, gamma) 163 | if beta is not None: 164 | net = mx.symbol.broadcast_plus(net, beta) 165 | return net 166 | 167 | def __init__( 168 | self, 169 | prefix, 170 | data, 171 | num_lstm_layer, 172 | seq_len, 173 | num_hidden, 174 | dropout=0., 175 | minibatch=False, 176 | num_subject=0, 177 | bn=True, 178 | ): 179 | self.bn = bn 180 | self.minibatch = minibatch 181 | self.num_subject = num_subject 182 | if self.minibatch: 183 | assert self.num_subject > 0 184 | 185 | prefix += 'lstm_' 186 | 187 | param_cells = [] 188 | last_states = [] 189 | for i in range(num_lstm_layer): 190 | param_cells.append(LSTMParam(i2h_weight=mx.sym.Variable(prefix + "l%d_i2h_weight" % i), 191 | # i2h_bias=mx.sym.Variable(prefix + "l%d_i2h_bias" % i), 192 | h2h_weight=mx.sym.Variable(prefix + "l%d_h2h_weight" % i), 193 | # h2h_bias=mx.sym.Variable(prefix + "l%d_h2h_bias" % i))) 194 | i2h_gamma=mx.symbol.Variable(prefix + 'l%d_i2h_gamma' % i, shape=(1, num_hidden * 4), attr={'wd_mult': '0'}), 195 | h2h_gamma=mx.symbol.Variable(prefix + 'l%d_h2h_gamma' % i, shape=(1, num_hidden * 4), attr={'wd_mult': '0'}), 196 | beta=mx.symbol.Variable(prefix + 'l%d_beta' % i, shape=(1, num_hidden * 4), attr={'wd_mult': '0'}), 197 | c_gamma=mx.symbol.Variable(prefix + 'l%d_c_gamma' % i, shape=(1, num_hidden), attr={'wd_mult': '0'}), 198 | c_beta=mx.symbol.Variable(prefix + 'l%d_c_beta' % i, shape=(1, num_hidden), attr={'wd_mult': '0'}))) 199 | state = LSTMState(c=mx.sym.Variable(prefix + "l%d_init_c" % i, attr={'lr_mult': '0'}), 200 | h=mx.sym.Variable(prefix + "l%d_init_h" % i, attr={'lr_mult': '0'})) 201 | last_states.append(state) 202 | assert(len(last_states) == num_lstm_layer) 203 | 204 | wordvec = mx.sym.SliceChannel(data=data, num_outputs=seq_len, squeeze_axis=1) 205 | 206 | hidden_all = [] 207 | for seqidx in range(seq_len): 208 | hidden = wordvec[seqidx] 209 | 210 | # stack LSTM 211 | for i in range(num_lstm_layer): 212 | if i == 0: 213 | dp_ratio = 0. 
214 | else: 215 | dp_ratio = dropout 216 | next_state = self.lstm(prefix, num_hidden, indata=hidden, 217 | prev_state=last_states[i], 218 | param=param_cells[i], 219 | seqidx=seqidx, layeridx=i, dropout=dp_ratio) 220 | hidden = next_state.h 221 | last_states[i] = next_state 222 | 223 | # decoder 224 | if dropout > 0.: 225 | hidden = mx.sym.Dropout(data=hidden, p=dropout) 226 | hidden_all.append(hidden) 227 | 228 | self.net = hidden_all 229 | # return mx.sym.Concat(*hidden_all, dim=1) 230 | # return mx.sym.Pooling(mx.sym.Concat(*[mx.sym.Reshape(h, shape=(0, 0, 1, 1)) for h in hidden_all], dim=2), kernel=(1, 1), global_pool=True, pool_type='max') 231 | # return mx.sym.Pooling(mx.sym.Concat(*[mx.sym.Reshape(h, shape=(0, 0, 1, 1)) for h in hidden_all], dim=2), kernel=(1, 1), global_pool=True, pool_type='avg') 232 | 233 | 234 | def lstm_unroll(**kargs): 235 | return LSTM(**kargs).net 236 | -------------------------------------------------------------------------------- /sigr/parse_log.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | import re 3 | import numpy as np 4 | 5 | 6 | def div(up, down): 7 | try: 8 | return up / down 9 | except: 10 | return np.nan 11 | 12 | 13 | def parse_log(path): 14 | with open(path, 'r') as f: 15 | lines = f.readlines() 16 | 17 | res = [re.compile('.*Epoch\[(\d+)\] Train-accuracy(?:\[g\])?=([.\d]+)'), 18 | re.compile('.*Epoch\[(\d+)\] Validation-accuracy(?:\[g\])?=([.\d]+)'), 19 | re.compile('.*Epoch\[(\d+)\] Time.*=([.\d]+)')] 20 | 21 | data = {} 22 | for l in lines: 23 | i = 0 24 | for r in res: 25 | m = r.match(l) 26 | if m is not None: 27 | break 28 | i += 1 29 | if m is None: 30 | continue 31 | 32 | assert len(m.groups()) == 2 33 | epoch = int(m.groups()[0]) 34 | val = float(m.groups()[1]) 35 | 36 | if epoch not in data: 37 | data[epoch] = [0] * len(res) * 2 38 | 39 | data[epoch][i*2] += val 40 | data[epoch][i*2+1] += 1 41 | 42 | df = [] 43 | for k, v in data.items(): 44 | try: 45 | df.append({ 46 | 'epoch': k + 1, 47 | 'train': div(v[0], v[1]), 48 | 'val': div(v[2], v[3]), 49 | 'time': div(v[4], v[5]) 50 | }) 51 | except: 52 | pass 53 | try: 54 | import pandas as pd 55 | return pd.DataFrame(df) 56 | except: 57 | return df 58 | -------------------------------------------------------------------------------- /sigr/sklearn_module.py: -------------------------------------------------------------------------------- 1 | from __future__ import division 2 | from nose.tools import assert_equal 3 | import mxnet as mx 4 | import numpy as np 5 | from logbook import Logger 6 | import joblib as jb 7 | from .base_module import BaseModule 8 | 9 | 10 | logger = Logger('sigr') 11 | 12 | 13 | class SklearnModule(BaseModule): 14 | 15 | def _get_data_label(self, data_iter): 16 | data = [] 17 | label = [] 18 | for batch in data_iter: 19 | data.append(batch.data[0].asnumpy().reshape( 20 | batch.data[0].shape[0], -1)) 21 | label.append(batch.label[0].asnumpy()) 22 | if batch.pad: 23 | data[-1] = data[-1][:-batch.pad] 24 | label[-1] = label[-1][:-batch.pad] 25 | data = np.vstack(data) 26 | label = np.hstack(label) 27 | assert_equal(len(data), len(label)) 28 | return data, label 29 | 30 | def fit(self, train_data, eval_data, eval_metric='acc', **kargs): 31 | snapshot = kargs.pop('snapshot') 32 | self.clf.fit(*self._get_data_label(train_data)) 33 | jb.dump(self.clf, snapshot + '-0001.params') 34 | 35 | if not isinstance(eval_metric, mx.metric.EvalMetric): 36 | eval_metric = mx.metric.create(eval_metric) 37 | 
data, label = self._get_data_label(eval_data) 38 | pred = self.clf.predict(data).astype(np.int64) 39 | prob = np.zeros((len(pred), pred.max() + 1)) 40 | prob[np.arange(len(prob)), pred] = 1 41 | eval_metric.update([mx.nd.array(label)], [mx.nd.array(prob)]) 42 | for name, val in eval_metric.get_name_value(): 43 | logger.info('Epoch[0] Validation-{}={}', name, val) 44 | 45 | 46 | class KNNModule(SklearnModule): 47 | 48 | def __init__(self): 49 | from sklearn.neighbors import KNeighborsClassifier as KNN 50 | self.clf = KNN() 51 | 52 | @classmethod 53 | def parse(cls, text, **kargs): 54 | if text == 'knn': 55 | return cls() 56 | 57 | 58 | class SVMModule(SklearnModule): 59 | 60 | def __init__(self): 61 | from sklearn.svm import LinearSVC 62 | self.clf = LinearSVC() 63 | 64 | @classmethod 65 | def parse(cls, text, **kargs): 66 | if text == 'svm': 67 | return cls() 68 | 69 | 70 | class RandomForestsModule(SklearnModule): 71 | 72 | def __init__(self): 73 | from sklearn.ensemble import RandomForestClassifier as RandomForests 74 | self.clf = RandomForests() 75 | 76 | @classmethod 77 | def parse(cls, text, **kargs): 78 | if text == 'random-forests': 79 | return cls() 80 | 81 | 82 | class LDAModule(SklearnModule): 83 | 84 | def __init__(self): 85 | from sklearn.discriminant_analysis import LinearDiscriminantAnalysis as LDA 86 | self.clf = LDA() 87 | 88 | @classmethod 89 | def parse(cls, text, **kargs): 90 | if text == 'lda': 91 | return cls() 92 | -------------------------------------------------------------------------------- /sigr/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from contextlib import contextmanager 2 | import os 3 | import numpy as np 4 | from .proxy import LazyProxy 5 | assert LazyProxy 6 | 7 | 8 | @contextmanager 9 | def logging_context(path=None, level=None): 10 | from logbook import StderrHandler, FileHandler 11 | from logbook.compat import redirected_logging 12 | with StderrHandler(level=level or 'INFO').applicationbound(): 13 | if path: 14 | if not os.path.isdir(os.path.dirname(path)): 15 | os.makedirs(os.path.dirname(path)) 16 | with FileHandler(path, bubble=True).applicationbound(): 17 | with redirected_logging(): 18 | yield 19 | else: 20 | with redirected_logging(): 21 | yield 22 | 23 | 24 | def return_list(func): 25 | import inspect 26 | from functools import wraps 27 | assert inspect.isgeneratorfunction(func) 28 | 29 | @wraps(func) 30 | def wrapped(*args, **kargs): 31 | return list(func(*args, **kargs)) 32 | 33 | return wrapped 34 | 35 | 36 | @return_list 37 | def continuous_segments(label): 38 | label = np.asarray(label) 39 | 40 | if not len(label): 41 | return 42 | 43 | breaks = list(np.where(label[:-1] != label[1:])[0] + 1) 44 | for begin, end in zip([0] + breaks, breaks + [len(label)]): 45 | assert begin < end 46 | yield begin, end 47 | 48 | 49 | def cached(*args, **kargs): 50 | import joblib as jb 51 | from .. 
import CACHE 52 | memo = getattr(cached, 'memo', None) 53 | if memo is None: 54 | cached.memo = memo = jb.Memory(CACHE, verbose=0) 55 | return memo.cache(*args, **kargs) 56 | 57 | 58 | def get_segments(data, window): 59 | return windowed_view( 60 | data.flat, 61 | window * data.shape[1], 62 | (window - 1) * data.shape[1] 63 | ) 64 | 65 | 66 | def windowed_view(arr, window, overlap): 67 | from numpy.lib.stride_tricks import as_strided 68 | arr = np.asarray(arr) 69 | window_step = window - overlap 70 | new_shape = arr.shape[:-1] + ((arr.shape[-1] - overlap) // window_step, 71 | window) 72 | new_strides = (arr.strides[:-1] + (window_step * arr.strides[-1],) + 73 | arr.strides[-1:]) 74 | return as_strided(arr, shape=new_shape, strides=new_strides) 75 | 76 | 77 | class Bunch(dict): 78 | 79 | def __getattr__(self, key): 80 | if key in self: 81 | return self[key] 82 | raise AttributeError(key) 83 | 84 | def __setattr__(self, key, value): 85 | self[key] = value 86 | 87 | 88 | def _packargs(func): 89 | from functools import wraps 90 | import inspect 91 | 92 | @wraps(func) 93 | def wrapped(ctx_or_args, **kargs): 94 | if isinstance(ctx_or_args, Bunch): 95 | args = ctx_or_args 96 | else: 97 | args = ctx_or_args.obj 98 | ignore = inspect.getargspec(func).args 99 | args.update({key: kargs.pop(key) for key in list(kargs) 100 | if key not in ignore and key not in args}) 101 | return func(ctx_or_args, **kargs) 102 | return wrapped 103 | 104 | 105 | def packargs(func): 106 | import click 107 | return click.pass_obj(_packargs(func)) 108 | 109 | 110 | def butter_bandpass_filter(data, lowcut, highcut, fs, order): 111 | from scipy.signal import butter, lfilter 112 | 113 | nyq = 0.5 * fs 114 | low = lowcut / nyq 115 | high = highcut / nyq 116 | 117 | b, a = butter(order, [low, high], btype='bandpass') 118 | y = lfilter(b, a, data) 119 | return y 120 | 121 | 122 | def butter_bandstop_filter(data, lowcut, highcut, fs, order): 123 | from scipy.signal import butter, lfilter 124 | 125 | nyq = 0.5 * fs 126 | low = lowcut / nyq 127 | high = highcut / nyq 128 | 129 | b, a = butter(order, [low, high], btype='bandstop') 130 | y = lfilter(b, a, data) 131 | return y 132 | 133 | 134 | def butter_lowpass_filter(data, cut, fs, order, zero_phase=False): 135 | from scipy.signal import butter, lfilter, filtfilt 136 | 137 | nyq = 0.5 * fs 138 | cut = cut / nyq 139 | 140 | b, a = butter(order, cut, btype='low') 141 | y = (filtfilt if zero_phase else lfilter)(b, a, data) 142 | return y 143 | -------------------------------------------------------------------------------- /sigr/utils/proxy.py: -------------------------------------------------------------------------------- 1 | class LazyProxy(object): 2 | 3 | def __init__(self, make): 4 | self._make = make 5 | 6 | def __getattr__(self, name): 7 | if name == '_inst': 8 | self._inst = self._make() 9 | return self._inst 10 | return getattr(self._inst, name) 11 | 12 | def __setattr__(self, name, value): 13 | if name in ('_make', '_inst'): 14 | return super(LazyProxy, self).__setattr__(name, value) 15 | return setattr(self._inst, name, value) 16 | 17 | def __getstate__(self): 18 | return self._make 19 | 20 | def __setstate__(self, make): 21 | self._make = make 22 | 23 | def __hash__(self): 24 | return hash(self._make) 25 | 26 | def __iter__(self): 27 | return self._inst.__iter__() 28 | -------------------------------------------------------------------------------- /sigr/vote.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
division 2 | import numpy as np 3 | import joblib as jb 4 | from nose.tools import assert_greater 5 | from .utils import return_list, cached 6 | from . import Context 7 | 8 | 9 | def get_vote_accuracy_curve(labels, predictions, segments, windows, balance=False): 10 | if len(set(segments)) < len(windows): 11 | func = get_vote_accuracy_curve_aux 12 | else: 13 | func = get_vote_accuracy_curve_aux_few_windows 14 | return func(np.asarray(labels), 15 | np.asarray(predictions), 16 | np.asarray(segments), 17 | np.asarray(windows), 18 | balance) 19 | 20 | 21 | @cached 22 | def get_vote_accuracy_curve_aux(labels, predictions, segments, windows, balance): 23 | segment_labels = partial_vote(labels, segments) 24 | return ( 25 | np.asarray(windows), 26 | np.array(list(Context.parallel( 27 | jb.delayed(get_vote_accuracy_curve_step)( 28 | segment_labels, 29 | predictions, 30 | segments, 31 | window, 32 | balance 33 | ) for window in windows 34 | ))) 35 | ) 36 | 37 | 38 | @cached 39 | def get_vote_accuracy_curve_aux_few_windows(labels, predictions, segments, windows, balance): 40 | segment_labels = partial_vote(labels, segments) 41 | return ( 42 | np.asarray(windows), 43 | np.array([ 44 | get_vote_accuracy_curve_step( 45 | segment_labels, 46 | predictions, 47 | segments, 48 | window, 49 | balance, 50 | parallel=True 51 | ) for window in windows 52 | ]) 53 | ) 54 | 55 | 56 | def get_vote_accuracy(labels, predictions, segments, window, balance): 57 | _, y = get_vote_accuracy_curve(labels, predictions, segments, [window], balance) 58 | return y[0] 59 | 60 | 61 | vote = get_vote_accuracy 62 | 63 | 64 | def get_segment_vote_accuracy(segment_label, segment_predictions, window): 65 | def gen(): 66 | count = { 67 | label: np.hstack([[0], np.cumsum(segment_predictions == label)]) 68 | for label in set(segment_predictions) 69 | } 70 | tmp = window 71 | if tmp == -1: 72 | tmp = len(segment_predictions) 73 | tmp = min(tmp, len(segment_predictions)) 74 | for begin in range(len(segment_predictions) - tmp + 1): 75 | yield segment_label == max( 76 | count, 77 | key=lambda label: count[label][begin + tmp] - count[label][begin] 78 | ), segment_label 79 | return list(gen()) 80 | 81 | 82 | def get_vote_accuracy_curve_step(segment_labels, predictions, segments, window, 83 | balance, 84 | parallel=False): 85 | def gen(): 86 | # assert_greater(window, 0) 87 | assert window > 0 or window == -1 88 | if not parallel: 89 | for segment_label, segment_predictions in zip(segment_labels, split(predictions, segments)): 90 | for ret in get_segment_vote_accuracy(segment_label, segment_predictions, window): 91 | yield ret 92 | else: 93 | for rets in Context.parallel( 94 | jb.delayed(get_segment_vote_accuracy)(segment_label, segment_predictions, window) 95 | for segment_label, segment_predictions in zip(segment_labels, split(predictions, segments)) 96 | ): 97 | for ret in rets: 98 | yield ret 99 | 100 | good, labels = zip(*list(gen())) 101 | good, labels = np.asarray(good), np.asarray(labels)  # arrays, so the balanced branch can mask by label 102 | 103 | if not balance: 104 | return np.sum(good) / len(good) 105 | else: 106 | acc = [] 107 | for label in set(labels): 108 | mask = labels == label 109 | acc.append(np.sum(good[mask]) / np.sum(mask)) 110 | return np.mean(acc) 111 | 112 | 113 | @return_list 114 | def partial_vote(labels, segments, length=None): 115 | for part in split(labels, segments): 116 | part = list(part) 117 | 118 | if length is not None: 119 | part = part[:length] 120 | 121 | assert_greater(len(part), 0) 122 | yield max([(part.count(label), label) for label in set(part)])[1] 123 | 124 | 125 | def
split(labels, segments): 126 | return [labels[segments == segment] for segment in sorted(set(segments))] 127 | --------------------------------------------------------------------------------
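A minimal usage sketch for the stand-alone helpers in sigr/utils/__init__.py, illustrative only: it assumes numpy is installed and that the sigr package (and whatever sigr/__init__.py itself imports) is importable, and the sample arrays are made up.

import numpy as np
from sigr.utils import windowed_view, continuous_segments

x = np.arange(10)
# Overlapping frames of length 4 with an overlap of 2 samples:
# frames [0 1 2 3], [2 3 4 5], [4 5 6 7], [6 7 8 9]
print(windowed_view(x, window=4, overlap=2))

label = [1, 1, 2, 2, 2, 1]
# Contiguous runs of identical labels as (begin, end) index pairs:
# [(0, 2), (2, 5), (5, 6)]
print(continuous_segments(label))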