├── .gitignore
├── LICENSE
├── README.md
├── blstm_model.py
├── blstm_model_run.py
├── example_data
│   ├── README.md
│   ├── features.arff
│   ├── input.arff
│   ├── model.h5
│   └── output.arff
├── feature_extraction
│   ├── AnnotateData.m
│   ├── AnnotateDataAll.m
│   ├── README.md
│   └── arff_utils
│       ├── AddAttArff.m
│       ├── GetAcceleration.m
│       ├── GetAttPositionArff.m
│       ├── GetNomAttValue.m
│       ├── GetVelocity.m
│       ├── IsNomAttribute.m
│       ├── LoadArff.m
│       └── SaveArff.m
├── figures
│   ├── blstm_vs_lstm.png
│   ├── dense_units.png
│   ├── network.png
│   ├── num_blstm.png
│   ├── num_conv.png
│   ├── num_dense.png
│   ├── performance_all.png
│   ├── performance_ours.png
│   └── performance_vs_context.png
└── requirements.txt
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year>  <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program>  Copyright (C) <year>  <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep eye movement (EM) classifier: a 1D CNN-BLSTM model
2 |
3 | This is the implementation of the deep learning approach for eye movement classification from the "1D CNN with BLSTM for automated classification of fixations, saccades, and smooth pursuits" paper. If you use this code, please cite it as
4 |
5 | @Article{startsev2018cnn,
6 | author="Startsev, Mikhail
7 | and Agtzidis, Ioannis
8 | and Dorr, Michael",
9 | title="1D CNN with BLSTM for automated classification of fixations, saccades, and smooth pursuits",
10 | journal="Behavior Research Methods",
11 | year="2018",
12 | month="Nov",
13 | day="08",
14 | issn="1554-3528",
15 | doi="10.3758/s13428-018-1144-2",
16 | url="https://doi.org/10.3758/s13428-018-1144-2"
17 | }
18 |
19 | The full paper is freely accessible via a [SharedIt link](https://rdcu.be/bbMo3).
20 |
21 | Authors:
22 |
23 | Mikhail Startsev, Ioannis Agtzidis, Michael Dorr
24 |
25 | For feedback and collaboration you can contact Mikhail Startsev via mikhail.startsev@tum.de, or any of the authors above at\
26 | `< firstname.lastname > @tum.de`.
27 |
28 |
29 | # DESCRIPTION
30 |
31 | The model implemented here is a combination of one-dimensional (i.e. temporal, in our case) convolutions, fully-connected and bidirectional LSTM layers. The model is trained to classify eye movements in the eye tracking signal.
32 |
33 | Before being processed by the model, the signal (x and y coordinates over time; we used 250 Hz recordings of the GazeCom data set -- see [here](http://michaeldorr.de/smoothpursuit/) for its data, ca. 4.5h of eye tracking recordings in total) is pre-processed to extract speed, direction, and acceleration features. This is not a necessary step, since our architecture works well utilising just the `xy` features, but performance improves with some straightforward feature extraction.
34 |
35 | Here is our `default` architecture (as introduced in the paper):
36 |
37 | ![Our "default" 1D CNN-BLSTM architecture](figures/network.png)
38 |
39 | We improved the architecture for ECVP'18 (see presentation slides [here](http://michaeldorr.de/smoothpursuit/ECVP2018_presentation_slides.pdf)) to include 4 convolutional layers instead of 3, no dense layers, and 2 BLSTM layers instead of 1 (both pre-trained models provided via a [link](https://drive.google.com/drive/folders/1SPGTwUKnvZCUFJO05CnYTqakv-Akdth-?usp=sharing) below, together with the data for this research).
40 |
41 | Our approach delivers state-of-the-art performance in terms of fixation, saccade, and smooth pursuit detection in the eye tracking data. It is also the first deep learning approach for eye movement classification that accounts for smooth pursuits (as well as one of the very first dynamic deep learning models for eye movement detection overall).
42 |
43 | ![Performance comparison with other detection approaches](figures/performance_all.png)
44 |
45 |
46 | For more details, evaluation protocol, and results -- see our paper.
47 |
48 | # DEPENDENCIES
49 |
50 | To make use of this software, you first need to install the [sp_tool](https://github.com/MikhailStartsev/sp_tool/). For its installation instructions, see the respective README.
51 |
52 | If you want to use the `blstm_model.py` script (to train/test models on GazeCom -- the data can be found [here](http://michaeldorr.de/smoothpursuit/)), provide the correct path to the sp_tool folder via the `--sp-tool-folder /path/to/sp_tool/` argument.
53 |
54 |
55 | **You will also need to download and unzip the data archive from [here](https://drive.google.com/drive/folders/1SPGTwUKnvZCUFJO05CnYTqakv-Akdth-?usp=sharing).** In particular,
56 |
57 | * The files in `data/models` contain the pre-trained models of two different architectures: the "standard" architecture with 3 Conv1D layers, 1 dense layer, and 1 BLSTM layer (described in the main paper from above) and the improved ("final") architecture that was presented at ECVP'18 (4 Conv1D layers, 2 BLSTM layers; presentation slides can be found [here](http://michaeldorr.de/smoothpursuit/ECVP2018_presentation_slides.pdf)). Note that the "improved" architecture performs better (see the paper for the evaluation of the standard model or [the project page](http://michaeldorr.de/smoothpursuit/) for the final one).
58 | * The files in `data/inputs` need to be unzipped if you want to use GazeCom data as input to the model
59 | * The files in `data/outputs` need to be unzipped if you want to examine the outputs of our models without having to run them
60 |
61 |
62 | ## Standard package dependencies
63 |
64 | See `requirements.txt`, or use `pip install -r requirements.txt` directly.
65 |
66 | # USAGE
67 |
68 | The files in this repository provide the interface to our 1D CNN-BLSTM eye movement classification model.
69 |
70 | `blstm_model.py` is the main script that contains all the necessary tools to train and test the model (currently -- on GazeCom).
71 |
72 | ## Testing the model on your data
73 |
74 | If you simply want to run our best (so far) model(s) on external data, this is the pipeline you need to follow (using the `blstm_model_run.py` script, which is a reduced-interface version of `blstm_model.py`; see below). This is, perhaps, not the final version of the code, but the pipeline (steps 1 to 3) outlined below has been tested.
75 |
76 |
77 | 1. Convert your data into the .arff file format (see `example_data/input.arff` for an example file). Your .arff files need to contain some metadata about the eye tracking experiment (the values are to be set in accordance with your particular data set!):
78 |
79 | %@METADATA width_px 1280.0
80 | %@METADATA height_px 720.0
81 | %@METADATA width_mm 400.0
82 | %@METADATA height_mm 225.0
83 | %@METADATA distance_mm 450.0
84 |
85 | The fields indicate
86 | * the width and height of the stimulus or monitor (in pixels),
87 | * the width and height of the stimulus or monitor in millimeters,
88 | * the distance between the subjects' eyes and the stimulus or monitor (in millimeters).
89 |
90 | The data in this file needs just 4 columns:
91 | * time (in microseconds)
92 | * x and y - the on-screen (or relative to the stimulus) coordinates (in pixels; will be converted to degrees of visual angle automatically - that is why we need the metadata)
93 | * confidence - the eye tracker confidence for the tracking of the subjects' eyes (1 means normal tracking, 0 means lost tracking). If you do not have this information, set to 1 everywhere.
94 |
95 | The format is described in more detail in https://ieeexplore.ieee.org/abstract/document/7851169/
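
To make the expected layout concrete, here is a minimal hand-written sketch of such an input file (the values and the relation name are invented, and the exact header layout may differ slightly; treat `example_data/input.arff` as the authoritative reference). At 250 Hz, consecutive timestamps are 4000 microseconds apart:

    @RELATION gaze_recording
    %@METADATA width_px 1280.0
    %@METADATA height_px 720.0
    %@METADATA width_mm 400.0
    %@METADATA height_mm 225.0
    %@METADATA distance_mm 450.0
    @ATTRIBUTE time NUMERIC
    @ATTRIBUTE x NUMERIC
    @ATTRIBUTE y NUMERIC
    @ATTRIBUTE confidence NUMERIC
    @DATA
    0,640.2,360.1,1
    4000,641.0,359.9,1
    8000,641.8,359.7,1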
96 |
97 | 2. Such files can then be processed by the Matlab scripts in the `feature_extraction` folder to produce .arff files with gaze features (see an example file in example_data/features.arff). For single files, use AnnotateData.m (see the respective README; usage: `AnnotateData('path/to/input/file.arff', 'path/to/output/file.arff')` ). You can alternatively use the model that only utilises x and y as features, in which case you can skip this step.
98 |
99 | 3. Call `python blstm_model_run.py --input path/to/extracted/feature/file.arff --output folder/where/to/put/the/result.arff --model path/to/model/folder/or/file --feat <list of feature groups>`.
100 |
101 | Feature groups refer to the groups of features that are present in the path/to/extracted/feature/file.arff file. These can be any of the following: xy, speed, direction, acceleration. Any number of feature groups can be listed (use a space as a separator, see the example below). We found that using speed and direction as features performed best.
102 |
103 | Example command:
104 | > python blstm_model_run.py --feat speed direction --model example_data/model.h5 --in example_data/features.arff --out example_data/output_reproduced.arff
105 |
106 | More info about the arguments can be obtained via
107 |
108 | > python blstm_model_run.py --help
109 |
110 | ## Training and testing a new model
111 |
112 | For this purpose, use the `blstm_model.py` script. You can find out all its options by running
113 |
114 | > python blstm_model.py --help
115 |
116 | If you wish to train on another data set or to use some new features, please take note of the comments in `blstm_model.py`, in particular in the beginning of the `run()` function.
117 |
118 | During training (after each cross-validation fold has been processed), the script will output some sample-level performance statistics (F1 scores) for the detection of fixation, saccade, pursuit, and noise samples. *These results are intermediate and serve monitoring purposes only!* The full evaluation can (and will) only be undertaken when all the cross-validation folds are processed.
119 |
120 | ### Multi-instance training (recommended)
121 |
122 | For training, we usually used the `--run-once` option, which trains only one model (i.e. runs through only one fold of the Leave-One-Video-Out cross-validation process), since some GPU memory is likely not perfectly freed between consecutive folds, and this also allows for simultaneous training on several program instances or machines (for different folds of our cross-validation), provided that they have access to some common folder for synchronisation.
123 |
124 | Each run will first create a placeholder .h5 model file and then run the training. Training is skipped for folds whose corresponding .h5 files already exist. With this option the script needs to be run 18 times (the number of videos in GazeCom), so a simple bash wrapper (see the sketch below) would be handy. No additional parameter specifying which cross-validation fold is to be processed is necessary (although the fold can be specified through `--run-once-video`).
125 |
126 | If you are running the training on several machines or script instances, you can run the 18-iteration loop of the `--run-once`-enabled commands, and the next free machine will start processing the next needed cross-validation fold.
127 |
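A minimal sketch of such a bash wrapper (the feature set and the shared model folder below are placeholders -- adjust them to your setup; folds whose .h5 files already exist are skipped automatically):

    #!/bin/bash
    # One iteration per GazeCom video (= one cross-validation fold)
    for i in $(seq 1 18); do
        python blstm_model.py --run-once \
            --features speed direction \
            --model-root-path /path/to/shared/models/
    done
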
128 | One thing to note here is that if you interrupt the training process, you will end up with one or more empty .h5 files in the corresponding model folder. **These need to be deleted before the training is resumed**, since the training on these folds will be skipped otherwise.
129 |
130 | #### NB. Right now the paths for storing models (`data/models`) and outputs (`data/outputs`) are set for convenient local training. Set the model path to a location that is accessible to all instances of the program that you plan to run via an appropriate `--model-root-path`. This folder might need to be created beforehand.
131 |
132 | ### Other important options
133 |
134 | * `--final-run` is for inference runs only. It disables data pre-cleaning during loading. This operation mode is intended for obtaining the final classification results with a set of already trained models (see the example command after this list)
135 | * `--output-folder` is the output folder, where the labelled .arff files will be saved. Set to `auto` if you wish for the folder to be selected automatically (the folder name will include the model structure descriptor). **If no argument is provided, no outputs will be created!**
136 | * `--model-root-path` is the path, where all the models will be stored. The script will create a sub-folder for each model architecture (the name will contain the model descriptor), and this sub-folder will then contain individual .h5 trained model files for each of the cross-validation folds when the respective iteration finishes.
137 | * `--batch-size`, `--num-epochs`, `--initial-epoch` (useful for fine-tuning a model), `--training-samples` (**mind the help if adjusting this option**: you might need to adjust `--overlap`, too) are all fairly standard parameters
138 | * since the model deals with temporal windows of data (we tested windows up to 257 samples), the parameters for such windows can be specified via `--window-size` (number of samples) and `--window-overlap` (overlap between the sampled windows -- you don't necessarily want windows that are shifted by only 1 sample, as this could lead to overfitting). Generally, using larger window sizes leads to better performance (larger context = better classification):
139 |
140 | ![Performance as a function of context (window) size](figures/performance_vs_context.png)
141 |
142 | * `--features` is another important parameter. It lists the gaze coordinates' features that you want to use for training. It supports the following options: `xy`, `speed`, `direction`, `acceleration`, and `movement`, the latter referring to the speed, direction, and acceleration features combined. Features can be specified in combination, e.g. `--features speed acceleration`. In our tests, using acceleration decreases the overall performance of the model, especially for smooth pursuit.
143 |
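As an example of combining these options, a hypothetical inference-only invocation with a set of already trained models could look like this (the model path is a placeholder):

    > python blstm_model.py --final-run --features speed direction --model-root-path /path/to/trained/models/ --output-folder auto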
144 |
145 | ### Architecture-defining options
146 |
147 | You can configure the architecture by passing appropriate parameters via console arguments. Any architecture achievable with these options consists of 3 blocks (all but the convolutional block can also be omitted entirely):
148 | * Convolutional
149 | * `--num-conv` will set the number of convolutional layers (default: 3, **recommended: 4**)
150 |
151 | ![Effect of the number of convolutional layers](figures/num_conv.png)
152 |
153 | * `--conv-padding-mode` sets the padding mode (valid or same)
154 | * `--conv-units` can be used to set the number of convolutional filters that are learned on each layer. This parameter accepts a list of values (e.g. `--conv-units 32 16 8`). If the list is longer than `--num-conv`, it will be truncated. If it is shorter, the last element is repeated as many times as necessary, so passing `--conv-units 32 16 8` together with `--num-conv 5` will result in 5 convolutional layers, with 32, 16, 8, 8, and 8 filters, respectively (see the sketch after this list)
155 |
156 | * Dense (fully-connected)
157 | * `--num-dense` sets the number of dense layers (default: 1, **recommended: 0**)
158 | * `--dense-units` acts much like `--conv-units`, but for the number of dense units in respective layers
159 |
160 | Not using any dense layers proved to be a better choice:
161 |
162 | ![Effect of the number of dense layers](figures/num_dense.png) ![Effect of the number of dense units](figures/dense_units.png)
163 |
164 | * BLSTM
165 | * `--num-blstm` sets the number of BLSTM layers (default: 1, **recommended: 2**)
166 |
167 | ![Effect of the number of BLSTM layers](figures/num_blstm.png)
168 |
169 | * `--blstm-units` acts just like `--conv-units`, but for the number of BLSTM units in respective layers
170 | * `--no-bidirectional` will force the model to use LSTM instead of BLSTM (leads to poorer performance, but could be used in an online detection set-up). The plot below shows the training loss value (categorical cross-entropy) for a BLSTM model vs. 2 stacked uni-directional LSTMs (to roughly match the number of parameters):
171 |
172 | ![Training loss: BLSTM vs. stacked uni-directional LSTMs](figures/blstm_vs_lstm.png)
173 |
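As mentioned above for `--conv-units`, `--dense-units`, and `--blstm-units`, the per-layer unit lists are truncated or padded by repeating the last element. A small illustrative Python sketch of this rule (not the actual code from `blstm_model.py`; the helper name is made up):

    def expand_unit_counts(units, num_layers):
        # Truncate if the list is too long; repeat the last element if it is too short.
        units = list(units)[:num_layers]
        units += [units[-1]] * (num_layers - len(units))
        return units

    print(expand_unit_counts([32, 16, 8], 5))  # [32, 16, 8, 8, 8]
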
174 | Here is the comparison between the achieved F1 scores of our "default" architecture and the "recommended" final one:
175 |
176 | ![F1 scores: "default" vs. final ("recommended") architecture](figures/performance_ours.png)
177 |
178 | If you want to just create the architecture and see the number of trainable parameters or other details, use `--dry-run`.
179 |
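For instance, the following should build and summarise the "recommended" final architecture (4 convolutional layers, no dense layers, 2 BLSTM layers) without training anything; the remaining options keep their defaults:

    > python blstm_model.py --dry-run --num-conv 4 --num-dense 0 --num-blstm 2 --features speed direction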
--------------------------------------------------------------------------------
/blstm_model.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # coding: utf-8
3 | from datetime import datetime
4 |
5 | import glob
6 | import pickle
7 | import json
8 |
9 | import keras
10 | from keras.models import Sequential
11 | from keras.layers import LSTM, Dense, Conv1D, TimeDistributed, Flatten, Activation, Dropout, Bidirectional
12 | from keras.layers.pooling import MaxPooling1D
13 | from keras.callbacks import History, TensorBoard, Callback
14 | import keras.initializers as KI
15 | from keras.layers import BatchNormalization
16 |
17 | from keras import backend as K
18 | import numpy as np
19 | import os
20 | from argparse import ArgumentParser
21 | import math
22 | from copy import deepcopy
23 | import itertools
24 | import warnings
25 |
26 | from sp_tool import util as sp_util
27 | from sp_tool.arff_helper import ArffHelper
28 | from sp_tool.evaluate import CORRESPONDENCE_TO_HAND_LABELLING_VALUES
29 | from sp_tool import recording_processor as sp_processor
30 |
31 | # If need to set a limit on how much GPU memory the model is allowed to use
32 |
33 | # import tensorflow as tf
34 | # from keras.backend.tensorflow_backend import set_session
35 | # config = tf.ConfigProto()
36 | # config.gpu_options.per_process_gpu_memory_fraction = 0.5
37 | # config.gpu_options.visible_device_list = "0"
38 | # set_session(tf.Session(config=config))
39 |
40 |
41 | def zip_equal(*args):
42 | """
43 | Iterate the zip-ed combination of @args, making sure they have the same length
44 | :param args: iterables to zip
45 | :return: yields what a usual zip would
46 | """
47 | fallback = object()
48 | for combination in itertools.zip_longest(*args, fillvalue=fallback):
49 | if any((c is fallback for c in combination)):
50 |             raise ValueError('zip_equal arguments have different lengths')
51 | yield combination
52 |
53 |
54 | def categorical_f1_score_for_class(y_true, y_pred, class_i, dtype=None):
55 | """
56 | A generic function for computing sample-level F1 score for a class @class_i (i.e. for the classification problem
57 | "is @class_i" vs "is not @class_i").
58 | :param y_true: one-hot encoded true labels for a set of samples
59 | :param y_pred: predicted probabilities for all classes for the same set of samples
60 | :param class_i: which class to consider for the F1 score computation
61 | :param dtype: an optional intermediate data type parameter; unless some relevant exception is raised (type mismatch,
62 | for example), no need to pass anything; in some cases 'float64' had to be passed, instead of the
63 | default 'float32'
64 | :return: one floating-point value, the F1 score for the binary @class_i vs not-@class_i problem
65 | """
66 | pred_labels = K.argmax(y_pred, axis=-1)
67 | tp = K.sum(y_true[:, :, class_i] * K.cast(K.equal(pred_labels, class_i), 'float32' if not dtype else dtype))
68 | all_p_detections = K.sum(K.cast(K.equal(pred_labels, class_i), 'float32' if not dtype else dtype))
69 | all_p_true = K.sum(y_true[:, :, class_i])
70 |
71 | precision = tp / (all_p_detections + K.epsilon())
72 | recall = tp / (all_p_true + K.epsilon())
73 | f_score = 2 * precision * recall / (precision + recall + K.epsilon())
74 |
75 | return f_score
76 |
77 |
78 | # A set of F1-score functions for the three major eye movement types that we use to monitor model training on both
79 | # the training and the validation set. Almost the same signature as the "master"
80 | # function categorical_f1_score_for_class() above.
81 |
82 | def f1_FIX(y_true, y_pred, dtype=None):
83 | return categorical_f1_score_for_class(y_true, y_pred, 1, dtype)
84 |
85 |
86 | def f1_SACC(y_true, y_pred, dtype=None):
87 | return categorical_f1_score_for_class(y_true, y_pred, 2, dtype)
88 |
89 |
90 | def f1_SP(y_true, y_pred, dtype=None):
91 | return categorical_f1_score_for_class(y_true, y_pred, 3, dtype)
92 |
93 |
94 | def create_model(num_classes, batch_size, train_data_shape, dropout_rate=0.3,
95 | padding_mode='valid',
96 | num_conv_layers=3, conv_filter_counts=(32, 16, 8),
97 | num_dense_layers=1, dense_units_count=(32,),
98 | num_blstm_layers=1, blstm_unit_counts=(16,),
99 | unroll_blstm=False,
100 | stateful=False,
101 | no_bidirectional=True):
102 | """
103 | Create a 1D CNN-BLSTM model that contains 3 blocks of layers: Conv1D block, Dense block, and BLSTM block.
104 | Each of these is configurable via the parameters of this function; only the convolutional block cannot be entirely
105 | skipped.
106 |
107 | Each layer in the Conv1D block has the filter size of 3, and is followed by BatchNormalization and ReLU activation.
108 | Every layer in this block, starting from the second one, is preceded by Dropout.
109 |
110 | In the Dense block, every layer uses a TimeDistributed wrapper and is preceded by Dropout, followed by ReLU
111 | activation.
112 |
113 | All [B]LSTM layer(s) use tanh activation.
114 |
115 | After the BLSTM block, the model contains a single (time-distributed) Dense layer with softmax activation that
116 | has the @num_classes units.
117 |
118 |
119 | :param num_classes: number of classes to be classified
120 | :param batch_size: batch size
121 | :param train_data_shape: shape of the training data array (will infer sequence length -- @train_data_shape[1] -- and
122 | number of features -- train_data_shape[2] -- here, @train_data_shape[0] is ignored, and
123 | @batch_size is used instead).
124 | :param dropout_rate: each convolutional (except for the first one) and dense layer is preceded by Dropout with this
125 | rate.
126 | :param padding_mode: convolution padding mode; can be 'valid' (default), 'same', or 'causal'; the latter can be
127 | useful, if a modification into a realtime-like model is desired:
128 | https://keras.io/layers/convolutional/#conv1d
129 | :param num_conv_layers: number of convolutional layers in the Conv1D block
130 | :param conv_filter_counts: number of filters in each respective Conv1D layer; has to be of length of at least
131 | @num_conv_layers - will use the first @num_conv_layers elements
132 | :param num_dense_layers: number of dense layers in the Dense block
133 | :param dense_units_count: number of units in each respective Dense layer; has to be of length of at least
134 | @num_dense_layers - will use the first @num_dense_layers elements
135 | :param num_blstm_layers: number of dense layers in the BLSTM block
136 | :param blstm_unit_counts: number of units in each respective [B]LSTM layer; has to be of length of at least
137 | @num_blstm_layers - will use the first @num_blstm_layers elements
138 | :param unroll_blstm: whether to unroll the [B]LSTM(s), see https://keras.io/layers/recurrent/#lstm
139 | :param stateful: whether to make the [B]LSTM(s) stateful; not used yet
140 | :param no_bidirectional: if True, will use traditional LSTMs, not the Bidirectional wrapper,
141 | see https://keras.io/layers/wrappers/#bidirectional
142 | :return: a keras.models.Sequential() model
143 | """
144 | assert num_conv_layers <= len(conv_filter_counts)
145 | if len(conv_filter_counts) != num_conv_layers:
146 | warnings.warn('@num_conv_layers={} is shorter than @conv_filter_counts={}, so the last {} elements of the '
147 | 'latter will be ignored. Might be incorrectly passed arguments!'.format(
148 | num_conv_layers, conv_filter_counts, len(conv_filter_counts) - num_conv_layers))
149 |
150 | assert num_dense_layers <= len(dense_units_count)
151 | if len(dense_units_count) != num_dense_layers:
152 |         warnings.warn('@num_dense_layers={} is shorter than @dense_units_count={}, so the last {} elements of the '
153 | 'latter will be ignored. Might be incorrectly passed arguments!'.format(
154 | num_dense_layers, dense_units_count, len(dense_units_count) - num_dense_layers))
155 |
156 | assert num_blstm_layers <= len(blstm_unit_counts)
157 | if len(blstm_unit_counts) != num_blstm_layers:
158 | warnings.warn('@num_blstm_layers={} is shorter than @blstm_unit_counts={}, so the last {} elements of the '
159 | 'latter will be ignored. Might be incorrectly passed arguments!'.format(
160 |                           num_blstm_layers, blstm_unit_counts, len(blstm_unit_counts) - num_blstm_layers))
161 |
162 | model = Sequential()
163 |
164 | for conv_layer_id in range(num_conv_layers):
165 | if conv_layer_id != 0:
166 | model.add(Dropout(dropout_rate))
167 |
168 | conv_layer_args = {
169 | 'filters': conv_filter_counts[conv_layer_id],
170 | 'kernel_size': 3,
171 | 'padding': padding_mode,
172 | 'kernel_initializer': KI.RandomNormal(),
173 | 'bias_initializer': KI.Ones()
174 | }
175 | # special args for the first layer
176 | if conv_layer_id == 0:
177 | conv_layer_args['batch_input_shape'] = (batch_size, train_data_shape[1], train_data_shape[2])
178 |
179 | model.add(Conv1D(**conv_layer_args))
180 | model.add(BatchNormalization(axis=-1))
181 | model.add(Activation('relu'))
182 |
183 | model.add(TimeDistributed(Flatten()))
184 |
185 | for dense_layer_id in range(num_dense_layers):
186 | model.add(Dropout(dropout_rate))
187 | model.add(TimeDistributed(Dense(dense_units_count[dense_layer_id], activation='relu',
188 | kernel_initializer=KI.RandomNormal(),
189 | bias_initializer=KI.Ones())))
190 |
191 | for blstm_layer_id in range(num_blstm_layers):
192 | if not no_bidirectional:
193 | model.add(Bidirectional(LSTM(blstm_unit_counts[blstm_layer_id],
194 | return_sequences=True, stateful=stateful,
195 | unroll=unroll_blstm)))
196 | else:
197 | model.add(LSTM(blstm_unit_counts[blstm_layer_id],
198 | return_sequences=True, stateful=stateful,
199 | unroll=unroll_blstm))
200 |
201 | model.add(TimeDistributed(Dense(num_classes, activation='softmax',
202 | kernel_initializer=KI.RandomNormal(),
203 | bias_initializer=KI.Ones())))
204 |
205 | model.compile(loss='categorical_crossentropy',
206 | optimizer='rmsprop', metrics=['accuracy',
207 | f1_SP,
208 | f1_FIX,
209 | f1_SACC])
210 | model.summary()
211 | return model
212 |
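# A minimal usage sketch (illustrative values, not taken from the training pipeline below): build the
# "recommended" architecture for 65-sample windows with 10 input features. With 'valid' padding and
# four convolutions of kernel size 3, the input windows have to be 65 + 2 * 4 = 73 samples long.
# Note that, unlike the command line interface, this function defaults to plain LSTMs
# (no_bidirectional=True), so bidirectional layers have to be requested explicitly.
#
# example_model = create_model(num_classes=5, batch_size=5000,
#                              train_data_shape=(None, 73, 10),
#                              num_conv_layers=4, conv_filter_counts=(32, 16, 8, 8),
#                              num_dense_layers=0, dense_units_count=(32,),
#                              num_blstm_layers=2, blstm_unit_counts=(16, 16),
#                              no_bidirectional=False)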
213 |
214 | def extract_windows(X, y, window_length,
215 | padding_features=0,
216 | downsample=1, temporal_padding=False):
217 | """
218 | Extract fixed-sized (@window_length) windows from arbitrary-length sequences (in X and y),
219 | padding them, if necessary (mirror-padding is used).
220 |
221 | :param X: input data; list of arrays, each shaped like (NUM_SAMPLES, NUM_FEATURES);
222 | each list item corresponds to one eye tracking recording (one observer & one stimulus clip)
223 | :param y: corresponding labels; list of arrays, each shaped like (NUM_SAMPLES,);
224 | each list element corresponds to sample-level eye movement class labels in the respective sequence;
225 | list elements in X and y are assumed to be matching.
226 | :param window_length: the length of resulting windows; this is the input "context size" in the paper, in samples.
227 | :param padding_features: how many extra samples to take in the feature (X) space on each side
228 | (resulting X will have sequence length longer than resulting Y, by 2 * padding_features,
229 | while Y will have sample length of @window_length);
230 | this is necessary due to the use of valid padding in convolution operations in the model;
231 | if all convolutions are of size 3, and if they all use valid padding, @padding_features
232 | should be set to the number of convolution layers.
233 | :param downsample: take each @downsample'th window; if equal to @window_length, no overlap between windows;
234 | by default, all possible windows with the shift of 1 sample between them will be created,
235 | resulting in NUM_SAMPLES-1 overlap; if overlap of K samples is desired, should set
236 | @downsample=(NUM_SAMPLES-K)
237 | :param temporal_padding: whether to pad the entire sequences, so that the first window is centered around the
238 | first sample of the real sequence (i.e. the sequence of recorded eye tracking samples);
239 | not used
240 | :return: two lists of numpy arrays:
241 | (1) a list of windows corresponding to input data (features),
242 | (2) a list of windows corresponding to labels we will predict.
243 | These can be used as input to network training procedures, for example.
244 | """
245 | res_X = []
246 | res_Y = []
247 | # iterate through each file in this subset of videos
248 | for x_item, y_item in zip_equal(X, y):
249 | # pad for all windows
250 | padding_size_x = padding_features
251 | padding_size_y = 0
252 | if temporal_padding:
253 | padding_size_x += window_length // 2
254 | padding_size_y += window_length // 2
255 |
256 | padded_x = np.pad(x_item, ((padding_size_x, padding_size_x), (0, 0)), 'reflect')
257 | padded_y = np.pad(y_item, ((padding_size_y, padding_size_y), (0, 0)), 'reflect')
258 |
259 | #
260 | # Extract all valid windows in @padded_x, with given downsampling and size.
261 | # @res_X will have windows of size @window_length + 2*@padding_features
262 | window_length_x = window_length + 2 * padding_features
263 | res_X += [padded_x[i:i + window_length_x, :] for i in
264 | range(0, padded_x.shape[0] - window_length_x + 1, downsample)]
265 | # @res_Y will have windows of size @window_length, central to the ones in @res_X
266 | res_Y += [padded_y[i:i + window_length, :] for i in
267 | range(0, padded_y.shape[0] - window_length + 1, downsample)]
268 | return res_X, res_Y
269 |
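# Illustrative example (assumed numbers): for a single recording of 1000 samples with F features and
# C one-hot-encoded classes,
#     extract_windows([x], [y], window_length=65, padding_features=4, downsample=65)
# yields 15 non-overlapping windows; each element of the first returned list has shape (73, F) and
# each element of the second one has shape (65, C).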
270 |
271 | def evaluate_test(model, X, y=None,
272 | keys_to_subtract_start_indices=(),
273 | correct_for_unknown_class=True,
274 | padding_features=0,
275 | temporal_padding=False,
276 | split_by_items=False):
277 | """
278 | Predict (with a trained @model) labels for full sequences in @X and compare those to @y. The function returns both
279 | the predicted and true labels ("raw" results), and the results of several metrics
280 | (accuracy, sample-level F1 scores for classes fixation, saccade, pursuit, and noise) - "aggregated" results.
281 | If @y is not provided, only the predicted per-class probabilities are returned
282 |
283 | The context window width for prediction will be inferred from the @model.
284 |
285 | If predicting on data without labels, set @correct_for_unknown_class to False (see below)! If @y is not passed, it
286 | will not be used either way, but if a "fake" @y is provided (e.g. filled with UNKNOWN labels), make sure to change
287 | this option.
288 |
289 |     :param model: a trained deep model (has to have a .predict() method)
290 | :param X: input data; list of arrays, each shaped like (NUM_SAMPLES, NUM_FEATURES);
291 | each list item corresponds to one eye tracking recording (one observer & one stimulus clip)
292 | :param y: corresponding labels; needed only to evaluate the model predictions;
293 | the list of arrays, each shaped like (NUM_SAMPLES,);
294 | each list element corresponds to sample-level eye movement class labels in the respective sequence;
295 | list elements in X and y are assumed to be matching.
296 | :param keys_to_subtract_start_indices: indices of the keys, the starting value of which should be subtracted from
297 | all samples (inside of each window that is passed to the model)
298 | :param correct_for_unknown_class: if True, will assign "correct" probabilities (1 for class 0, 0 for other classes)
299 | to the samples that have an UNKNOWN label, since we don't want them influencing
300 | the evaluation scores.
301 | :param padding_features: how many extra samples to take in the feature (X) space on each side
302 | (resulting X will have sequence length longer than resulting Y, by 2 * padding_features,
303 | while Y will have sample length of @window_length);
304 | this is necessary due to the use of valid padding in convolution operations in the model;
305 | if all convolutions are of size 3, and if they all use valid padding, @padding_features
306 | should be set to the number of convolution layers.
307 | :param temporal_padding: whether to pad the entire sequences, so that the first window is centered around the
308 | first sample of the real sequence (i.e. the sequence of recorded eye tracking samples);
309 | False by default
310 | :param split_by_items: whether to split "raw" results by individual sequences in @X and @y
311 | (necessary when want to create output .arff files that match all input .arff files)
312 | :return: two dictionaries:
313 | (1) "raw_results", with keys "pred" (per-class probabilities predicted by the network, for every
314 | sample in @X) and "true" (one-hot encoded class labels, for every sample in @X -- taken from @y,
315 | if @y was provided).
316 | If @split_by_items is set to True, these will be lists of numpy arrays that correspond to
317 | *sequences* (not samples!) in @X. Otherwise, these are themselves numpy arrays.
318 | (2) "results", with keys "accuracy", "F1-FIX", "F1-SACC", "F1-SP", and "F1-NOISE" (if @y is provided)
319 | """
320 | res_Y = [] if y is not None else None
321 | res_X = []
322 | if y is None:
323 | # if no true labels are provided, fake the sequence
324 | y = [None] * len(X)
325 |
326 | window_length = model.output_shape[1] # output_shape is (batch_size, NUM_SAMPLES, NUM_CLASSES)
327 | downsample = window_length # no overlap here
328 |
329 | # Keep track of where each individual recording (1 observer, 1 stimulus clip) starts and ends -- will need this
330 | # to produce final automatically labelled .arff files
331 | items_start = []
332 | items_end = []
333 |
334 | for x_item, y_item in zip_equal(X, y):
335 | items_start.append(len(res_X))
336 |
337 | # how much padding is needed additionally, for full-window size
338 | target_length = int(math.ceil(x_item.shape[0] / float(window_length)) * window_length)
339 | padding_size_x = [0, target_length - x_item.shape[0]]
340 | padding_size_y = [0, target_length - y_item.shape[0]] if y_item is not None else [0, 0]
341 |
342 | # pad features (@x_item) for all windows
343 | padding_size_x = [elem + padding_features for elem in padding_size_x]
344 | # no @y_item-padding required
345 |
346 | if temporal_padding:
347 | padding_size_x = [elem + window_length // 2 for elem in padding_size_x]
348 | padding_size_y = [elem + window_length // 2 for elem in padding_size_y]
349 |
350 | padded_x = np.pad(x_item, (padding_size_x, (0, 0)), 'reflect') # x is padded with reflections to limit artifacts
351 | if y_item is not None:
352 | padded_y = np.pad(y_item, (padding_size_y, (0, 0)), 'constant') # y is zero-padded to ignore those labels
353 | # set to UNKNOWN class
354 | has_no_label_mask = padded_y.sum(axis=1) == 0
355 | padded_y[has_no_label_mask, 0] = 1
356 |
357 | #
358 | # Extract all valid windows in @padded_x, with given downsampling and size.
359 | # @res_X will have windows of size @window_length + 2*@padding_features
360 | window_length_x = window_length + 2 * padding_features
361 | res_X += [padded_x[i:i + window_length_x, :] for i in
362 | range(0, padded_x.shape[0] - window_length_x + 1, downsample)]
363 | if y_item is not None:
364 | # @res_Y will have windows of size @window_length, central to the ones in @res_X
365 | res_Y += [padded_y[i:i + window_length, :] for i in
366 | range(0, padded_y.shape[0] - window_length + 1, downsample)]
367 |
368 | items_end.append(len(res_X))
369 |
370 | res_X = np.array(res_X)
371 | if res_Y is not None:
372 | res_Y = np.array(res_Y)
373 |
374 | # Subtract the first value of each feature (for which it is necessary) inside each window
375 | for col_ind in keys_to_subtract_start_indices:
376 | res_X[:, :, col_ind] -= res_X[:, 0, col_ind].reshape(-1, 1)
377 |
378 | # augment to fit batch size
379 | batch_size = model.get_config()['layers'][0]['config']['batch_input_shape'][0]
380 | original_len = res_X.shape[0]
381 | target_len = int(np.ceil(float(original_len) / batch_size)) * batch_size
382 | res_X = np.pad(res_X,
383 | pad_width=((0, target_len - res_X.shape[0]), (0, 0), (0, 0)),
384 | mode='constant')
385 | # take only the needed predictions
386 | res_proba = model.predict(res_X, batch_size=batch_size)[:original_len]
387 |
388 | results = {}
389 |
390 | # The unknown labels in @res_Y should play no role, set respective probabilities in @res_proba:
391 | # If @y was provided, set "correct" probabilities for the UNKNOWN class, since we cannot properly evaluate against
392 | # an undefined label
393 | if correct_for_unknown_class and res_Y is not None:
394 | unknown_class_mask = res_Y[:, :, 0] == 1
395 | res_proba[unknown_class_mask] = 0.0
396 | res_proba[unknown_class_mask, 0] = 1.0
397 | res = np.argmax(res_proba, axis=-1)
398 |
399 | raw_results = {'true': res_Y, 'pred': res_proba}
400 |
401 | # cannot evaluate unless the labels @y were provided
402 | if res_Y is not None:
403 | results['accuracy'] = (np.mean(res[np.logical_not(unknown_class_mask)] ==
404 | np.argmax(res_Y[np.logical_not(unknown_class_mask)], axis=-1)))
405 |
406 | for stat_i, stat_name in zip(list(range(1, 5)), ['FIX', 'SACC', 'SP', 'NOISE']):
407 | results['F1-{}'.format(stat_name)] = K.eval(categorical_f1_score_for_class(res_Y, res_proba, stat_i,
408 | 'float64'))
409 |
410 | if split_by_items:
411 | # split into individual sequences
412 | split_res_true = [] if res_Y is not None else None
413 | split_res_pred = []
414 | for individual_start, individual_end in zip_equal(items_start, items_end):
415 | if res_Y is not None:
416 | split_res_true.append(raw_results['true'][individual_start:individual_end])
417 | split_res_pred.append(raw_results['pred'][individual_start:individual_end])
418 | raw_results = {'true': split_res_true, 'pred': split_res_pred}
419 |
420 | return raw_results, results
421 |
422 |
423 | def get_architecture_descriptor(args):
424 | """
425 | Generate the descriptor of the model
426 | :param args: command line arguments
427 | :return: descriptor string
428 | """
429 | # (convert lists to tuples to avoid [] in the descriptor, which later confuse glob.glob())
430 | return '{numC}x{padC}C@{unitsC}_{numD}xD@{unitsD}_{numB}x{typeB}@{unitsB}'.format(
431 | numC=args.num_conv, padC=args.conv_padding_mode[0], unitsC=tuple(args.conv_units)[:args.num_conv],
432 | numD=args.num_dense, unitsD=tuple(args.dense_units)[:args.num_dense],
433 | numB=args.num_blstm, typeB=('B' if not args.no_bidirectional else 'L'),
434 | unitsB=tuple(args.blstm_units)[:args.num_blstm])
435 |
436 |
437 | def get_feature_descriptor(args):
438 | """
439 | Generate the descriptor of the feature set that is used
440 | :param args: command line arguments
441 | :return: descriptor string
442 | """
443 | feature_postfix = []
444 | features_to_name = args.features[:] # copy list
445 | naming_priority = ['movement', 'speed', 'acc', 'direction']
446 | naming_priority += sorted(set(args.features).difference(naming_priority)) # add the rest in alphabetical order
447 | for n in naming_priority:
448 | if n == 'movement':
449 | # to add "movement" to the descriptor, check that all 3 parts of these features are present
450 | if 'speed' in features_to_name and \
451 | 'acc' in features_to_name and \
452 | 'direction' in features_to_name:
453 | feature_postfix.append('movement')
454 |
455 | features_to_name.remove('speed')
456 | features_to_name.remove('acc')
457 | features_to_name.remove('direction')
458 | else:
459 | continue
460 | if n in features_to_name:
461 | feature_postfix.append(n)
462 | features_to_name.remove(n)
463 | feature_postfix = '_'.join(feature_postfix)
464 |
465 | # if we want to limit the number of temporal scales of the features, mark it in the signature
466 | if args.num_feature_scales < 5:
467 | feature_postfix += '_{}_temp_scales'.format(args.num_feature_scales)
468 |
469 | return feature_postfix
470 |
471 |
472 | def get_full_model_descriptor(args):
473 | """
474 | Get a full descriptor of the model, which includes the architecture descriptor, feature descriptor, and info
475 | about context window size and overlap.
476 | :param args:
477 | :return:
478 | """
479 | return '{mdl}_{feat}_WINDOW_{win}{overlap}/'.format(mdl=get_architecture_descriptor(args),
480 | feat=get_feature_descriptor(args),
481 | win=args.window_size,
482 | overlap='_overlap_{}'.format(args.overlap)
483 | if args.overlap > 0 else '')
484 |
485 |
486 | def get_arff_attributes_to_keep(args):
487 | keys_to_keep = []
488 | if 'xy' in args.features:
489 | keys_to_keep += ['x', 'y']
490 |
491 | if 'speed' in args.features:
492 | keys_to_keep += ['speed_{}'.format(i) for i in (1, 2, 4, 8, 16)[:args.num_feature_scales]]
493 | if 'direction' in args.features:
494 | keys_to_keep += ['direction_{}'.format(i) for i in (1, 2, 4, 8, 16)[:args.num_feature_scales]]
495 | if 'acc' in args.features:
496 | keys_to_keep += ['acceleration_{}'.format(i) for i in (1, 2, 4, 8, 16)[:args.num_feature_scales]]
497 |
498 | return keys_to_keep
499 |
500 |
501 | def run(args):
502 | """
503 | Run model training/testing, depending on @args. See description of parse_args() for more information!
504 | :param args: terminal arguments to the program. See parse_args() help, or run with --help.
505 | :return: if args.dry_run is set to True, returns a model (created or loaded)
506 | """
507 |
508 | # For every numeric label we will need a categorical (human-readable) value to easier interpret results.
509 | CORRESPONDENCE_TO_HAND_LABELLING_VALUES_REVERSE = {v: k for k, v in
510 | CORRESPONDENCE_TO_HAND_LABELLING_VALUES.items()}
511 | print(CORRESPONDENCE_TO_HAND_LABELLING_VALUES_REVERSE)
512 |     num_classes = 5  # 0 = UNKNOWN, 1 = FIX, 2 = SACC, 3 = SP, 4 = NOISE
513 |
514 | # Paths, where to store models and generated outputs (.arff files with sample-level labels).
515 | # These are just the "root" folders for this data, subfolders will be automatically created
516 | # for each model configuration.
517 | #
518 | # Now these are stored locally, but if distributed training is desired, make sure that especially MODELS_DIR
519 | # points to a shared location that is accessible for reading and writing to all training nodes!
520 | # Pass the appropriate --model-root-path argument!
521 | MODELS_DIR = 'data/models/' if args.model_root_path is None else args.model_root_path
522 | OUT_DIR = 'data/outputs/'
523 |
524 | # by default, do training; --final-run will initiate testing mode
525 | TRAINING_MODE = True
526 | if args.final_run:
527 | TRAINING_MODE = False
528 |
529 | # NB: Some training parameters are set for the GazeCom data set! If you wish to *train* this model on
530 | # another data set, you will need to adjust these:
531 | # - CLEAN_TIME_LIMIT -- during training, all samples (for all clips) with a timestamp exceeding this threshold
532 | # will be disregarded. Set LOAD_CLEAN_DATA to False instead of "=TRAINING_MODE", if this is
533 | # not desired.
534 | # - certain .arff column names ("handlabeller_final" for the ground truth labels, feature names - e.g. "speed_1")
535 | # would have to be adjusted, if respective data is located in other columns
536 | # - keys_to_subtract_start -- if you add more features that need to be zeroed in the beginning of each window
537 | # that is given to the network, you need to add them to this set; by default, only
538 | # features "x" and "y" are treated this way.
539 | # - num_classes above, if more classes are labelled in the data set (5 by default, 0 = UNKNOWN, the rest are
540 | # actual labels (1 = FIXATION, 2 = SACCADE, 3 = SMOOTH PURSUIT, 4 = NOISE);
541 | # also, CORRESPONDENCE_TO_HAND_LABELLING_VALUES_REVERSE will have to be updated then!
542 | # - time_scale -- in GazeCom, time is given in microseconds, so it has to be multiplied by 1e-6 to convert to
543 | # seconds; either convert your data set's time to microseconds as well, or change the @time_scale
544 |
545 | # If you only want to test a pre-trained model on a new data set, you have to convert the data set to .arff file
546 | # format with the fields 'time', 'x', 'y', and maybe 'handlabeller_final', if manual annotations are available.
547 | # These fields will be preserved in final outputs. Also you have to provide fields that correspond to all features
548 | # that the model utilises (e.g. for speed and direction, feature names are speed_1, speed_2, speed_4, speed_8,
549 | # and speed_16 (speed in degrees of visual angle per second, extracted at corresponding scales, in samples),
550 | # same for direction (in radians relative to the horizontal vector from left to right). See the `feature_extraction`
551 | # folder and its scripts for implementation details, as well as the corresponding paper.
552 |
553 | CLEAN_TIME_LIMIT = 21 * 1e6 # in microseconds; 21 seconds
554 |
555 | # During training, some data purification is performed ("clean" data is loaded):
556 | # - samples after CLEAN_TIME_LIMIT microseconds are ignored
557 | # - files SSK_*.arff are re-sampled by taking every second gaze sample
558 | # (in GazeCom, these contain 500Hz recordings, compared to 250Hz of the rest of the data set)
559 | # Set LOAD_CLEAN_DATA to False, if this is not desired.
560 | LOAD_CLEAN_DATA = TRAINING_MODE
561 | if LOAD_CLEAN_DATA:
562 | print('Loading clean data')
563 | else:
564 | print('Loading raw data')
565 |
566 | print('Feature description:', get_feature_descriptor(args))
567 | print('Architecture description:', get_architecture_descriptor(args))
568 |
569 | # Where to find feature files. "{}" in the template is where a clip name will be inserted.
570 | # The "root" folder is by default 'data/inputs/GazeCom_all_features'
571 | files_template = args.feature_files_folder + '/{}/*.arff'
572 |
573 | # When extracting window of fixed size with fixed overlap, this formula defines the downsampling factor
574 | # (i.e. will take every @downsample'th window of @args.window_size in length, which will result in overlap of
575 | # exactly @args.overlap samples between subsequent windows)
576 | downsample = args.window_size - args.overlap
577 | # All convolutions are of size 3, so each Conv1D layer (with valid padding) requires 1 sample padded on each side.
578 | padding_features = args.num_conv if args.conv_padding_mode == 'valid' else 0
579 |
580 | # Set paths to model files and output files' directory, which include model parameters,
581 | # such as feature signatures, window size, overlap size.
582 | MODEL_PATH = '{dir}/LOO_{descr}/'.format(dir=MODELS_DIR,
583 | descr=get_full_model_descriptor(args))
584 | OUT_PATH = '{dir}/output_LOO_{descr}/'.format(dir=OUT_DIR,
585 | descr=get_full_model_descriptor(args))
586 | # The record .model_name in @args overrides the MODEL_PATH set above.
587 | if args.model_name is not None:
588 | MODEL_PATH = '{}/{}/'.format(MODELS_DIR, args.model_name)
589 |
590 | print('Selected model path:', MODEL_PATH)
591 | if not args.dry_run:
592 | if not os.path.exists(MODEL_PATH):
593 | os.mkdir(MODEL_PATH)
594 |
595 | # Load all pre-computed video parameters just to get video names (for cross-video validation, LOVO)
596 | # Only need the 'video_names' key in the dictionary in this json file (in case it needs to be adjusted for a
597 | # different data set).
598 | gc = json.load(open('data/inputs/GazeCom_video_parameters.json'))
599 | all_video_names = gc['video_names']
600 |
601 | # To clear the used memory, we will normally run this script as many times as there are different clips in GazeCom.
602 |     # To avoid loading and pre-processing the raw input sequences that often, we dump them into a cache file
603 |     # (pickled, despite the .h5 extension in the name), or load it, if it already exists.
604 | raw_data_set_fname = 'data/cached/GazeCom_data_{feat}{is_clean}.h5'.format(feat=get_feature_descriptor(args),
605 | is_clean='' if LOAD_CLEAN_DATA
606 | else '_not_clean')
607 |
608 | # During a final run, we want to create output files that have fewer fields: just the ones describing the raw data
609 | # (no features) and the ground truth label.
610 | if args.final_run:
611 | source_fnames = []
612 | source_objs = []
613 | source_keys_to_keep = ['time', 'x', 'y', 'confidence', 'handlabeller_final']
614 |
615 | # Depending on the command line arguments, certain keys are used as features (@keys_to_keep here)
616 | keys_to_keep = get_arff_attributes_to_keep(args)
617 |
618 | print('Using the following features:', keys_to_keep)
619 |
620 | data_X = []
621 | data_Y = []
622 | data_Y_one_hot = []
623 | # If no file with pre-processed features and labels exists yet, create it.
624 | if not os.path.exists(raw_data_set_fname):
625 | if not os.path.exists(os.path.split(raw_data_set_fname)[0]):
626 | os.makedirs(os.path.split(raw_data_set_fname)[0])
627 |
628 | # Will convert to degrees the following keys: x, y, and all speed and acceleration features.
629 | keys_to_convert_to_degrees = ['x', 'y'] + [k for k in keys_to_keep if 'speed_' in k or 'acceleration_' in k]
630 | keys_to_convert_to_degrees = sorted(set(keys_to_convert_to_degrees).intersection(keys_to_keep))
631 | # Conversion is carried out by dividing by pixels-per-degree value (PPD)
632 | print('Will divide by PPD the following keys', keys_to_convert_to_degrees)
633 |
634 | time_scale = 1e-6 # time is originally in microseconds; scale to seconds
635 |
636 | total_files = 0
637 | for video_name in gc['video_names']:
638 | print('For {} using files from {}'.format(video_name, files_template.format(video_name)))
639 | fnames = sorted(glob.glob(files_template.format(video_name)))
640 | total_files += len(fnames)
641 |
642 | data_X.append([])
643 | data_Y.append([])
644 | data_Y_one_hot.append([])
645 | if args.final_run:
646 | source_fnames.append(list(fnames))
647 | source_objs.append([])
648 |
649 | for f in fnames:
650 | o = ArffHelper.load(open(f))
651 | if LOAD_CLEAN_DATA and 'SSK_' in f:
652 | # the one observer with 500HZ instead of 250Hz
653 | o['data'] = o['data'][::2]
654 | if LOAD_CLEAN_DATA:
655 | o['data'] = o['data'][o['data']['time'] <= CLEAN_TIME_LIMIT]
656 |
657 | if args.final_run: # make a copy of the data before any further modifications! record to @source_objs
658 | source_obj = deepcopy(o)
659 | # only the essential keys in array
660 | source_obj['data'] = source_obj['data'][source_keys_to_keep]
661 | # in attributes too
662 | attribute_names = [n for n, _ in source_obj['attributes']]
663 | source_obj['attributes'] = [source_obj['attributes'][attribute_names.index(attr)]
664 | for attr in source_keys_to_keep]
665 | source_objs[-1].append(source_obj)
666 |
667 | # normalize coordinates in @o by dividing by @ppd_f -- the pixels-per-degree value of this file @f
668 | ppd_f = sp_util.calculate_ppd(o)
669 | for k in keys_to_convert_to_degrees:
670 | o['data'][k] /= ppd_f
671 |
672 | # add to respective data sets (only the features to be used and the true labels)
673 | data_X[-1].append(np.hstack([np.reshape(o['data'][key], (-1, 1)) for key in keys_to_keep]).astype(np.float64))
674 | assert data_X[-1][-1].dtype == np.float64
675 | if 'time' in keys_to_keep:
676 | data_X[-1][-1][:, keys_to_keep.index('time')] *= time_scale
677 | data_Y[-1].append(o['data']['handlabeller_final']) # "true" labels
678 | data_Y_one_hot[-1].append(np.eye(num_classes)[data_Y[-1][-1]]) # convert numeric labels to one-hot form
679 |
680 | if total_files > 0:
681 | print('Loaded a total of {} files'.format(total_files))
682 | else:
683 | raise ValueError('No input files found! Check that the directory "{}" exists '
684 | 'and is accessible for reading, or provide a different value for the '
685 | '--feature-files-folder argument. Make sure you extracted the respective archive, '
686 | 'if the data was provided in this way. This folder must contain '
687 | '18 subfolders with names corresponding to '
688 | 'GazeCom clip names.'.format(args.feature_files_folder))
689 |
690 | # As mentioned above, preserve the pre-processed data to not re-do this again, at least on the same system.
691 | # This creates files that are dependent of the features that are preserved and whether @LOAD_CLEAN_DATA is set.
692 | if not args.final_run:
693 | pickle.dump({'data_X': data_X, 'data_Y': data_Y, 'data_Y_one_hot': data_Y_one_hot},
694 |                         open(raw_data_set_fname, 'wb'))
695 | else:
696 | pickle.dump({'data_X': data_X, 'data_Y': data_Y, 'data_Y_one_hot': data_Y_one_hot,
697 | 'source_fnames': source_fnames, 'source_objs': source_objs},
698 |                         open(raw_data_set_fname, 'wb'))
699 | print('Written to', raw_data_set_fname)
700 | else:
701 | # If the raw file already exists, just load it
702 | print('Loading from', raw_data_set_fname)
703 |         loaded_data = pickle.load(open(raw_data_set_fname, 'rb'))
704 | data_X, data_Y, data_Y_one_hot = loaded_data['data_X'], loaded_data['data_Y'], loaded_data['data_Y_one_hot']
705 | if args.final_run:
706 | source_fnames, source_objs = loaded_data['source_fnames'], loaded_data['source_objs']
707 |
708 | # Each record of the @windows_X and @windows_Y lists corresponds to one "fold" of the used cross-validation
709 | # procedure
710 | windows_X = []
711 | windows_Y = []
712 |
713 | # Will subtract the initial value from the following keys (only the changes in these keys should matter),
714 | # not to overfit, for example, for spatial location of the eye movement: x and y coordinates.
715 | keys_to_subtract_start = sorted({'x', 'y'}.intersection(keys_to_keep))
716 | print('Will subtract the starting values of the following features:', keys_to_subtract_start)
717 | keys_to_subtract_start_indices = [i for i, key in enumerate(keys_to_keep) if key in keys_to_subtract_start]
718 |
719 | for subset_index in range(len(data_X)):
720 | x, y = extract_windows(data_X[subset_index], data_Y_one_hot[subset_index],
721 | window_length=args.window_size,
722 | downsample=downsample,
723 | padding_features=padding_features)
724 | windows_X.append(np.array(x))
725 | windows_Y.append(np.array(y))
726 |
727 | # Subtract the first value of each feature inside each window
728 | for col_ind in keys_to_subtract_start_indices:
729 | windows_X[-1][:, :, col_ind] -= windows_X[-1][:, 0, col_ind].reshape(-1, 1)
730 |
731 | # TensorBoard logs directory to supervise model training
732 | log_root_dir = 'data/tensorboard_logs/{}'
733 | logname = 'logs_{}_{}'.format(datetime.now().strftime("%Y-%m-%d_%H-%M"),
734 | os.path.split(MODEL_PATH.rstrip('/'))[1]) # add model description to filename
735 |
736 | callbacks_list = [History(), TensorBoard(batch_size=args.batch_size, log_dir=log_root_dir.format(logname),
737 | write_images=False)] # , histogram_freq=100, write_grads=True)]
738 |
739 | # All results are stored as lists - one element for one left-out group (one fold of the cross-validation,
740 | # in this case)
741 | results = {
742 | 'accuracy': [],
743 | 'F1-SP': [],
744 | 'F1-FIX': [],
745 | 'F1-SACC': [],
746 | 'F1-NOISE': []
747 | }
748 | raw_results = {'true': [], 'pred': []}
749 | training_histories = [] # histories for all folds
750 |
751 | # LOVO-CV (Leave-One-Video-Out; Leave-n-Observers-Out overestimates model performance, see paper)
752 | num_training_runs = 0 # count the actual training runs, in case we need to stop after just one
753 | # (if @args.run_once == True)
754 | # iterate through stimulus video clips, leaving each one out in turn
755 | for i, video_name in enumerate(all_video_names):
756 | if args.run_once and (args.run_once_video is not None and video_name != args.run_once_video):
757 | # --run-once is enabled and the --run-once-video argument has been passed. If this video name does not
758 | # match the @video_name, skip it
759 | continue
760 | # Check if final (trained) model already exists.
761 | # Thanks to creating an empty file in the "else" clause below, can run several training procedures at the same
762 | # time (preferably on different machines or docker instances, with MODEL_PATH set to some location that is
763 | # accessible to all the machines/etc.) - the other training procedures will skip training the model for this
764 | # particular fold of the cross-validation procedure.
765 | model_fname = MODEL_PATH + '/Conv_sample_windows_epochs_{}_without_{}.h5'.format(args.num_epochs, video_name)
766 | if os.path.exists(model_fname):
767 | ALREADY_TRAINED = True
768 | print('Skipped training, file in', model_fname, 'exists')
769 | if not args.final_run:
770 | # if no need to generate output .arff files, just skip this cross-validation fold
771 | continue
772 | else:
773 | # have not begun training yet
774 | ALREADY_TRAINED = False
775 | print('Creating an empty file in {}'.format(model_fname))
776 | os.system('touch "{}"'.format(model_fname))
777 |
778 | if not ALREADY_TRAINED:
779 | # need to train the model
780 |
781 | r = np.random.RandomState(0)
782 | # here, we ignore the "left out" video clip @i, but aggregate over all others
783 | train_set_len = sum([len(windows_X[j]) for j in range(len(windows_X)) if j != i])
784 | print('Total amount of windows:', train_set_len)
785 | # will shuffle all samples according to this permutation, and will keep only the necessary amount
786 | # (by default, 50,000 for training)
787 | perm = r.permutation(train_set_len)[:args.num_training_samples]
788 | train_X = []
789 | train_Y = []
790 | indices_range_low = 0
791 | for j in range(len(windows_X)):
792 | if j == i:
793 | continue # skip the test set
794 | indices_range_high = indices_range_low + len(windows_X[j])
795 | # permutation indices [indices_range_low; indices_range_high) are referring
796 | # to the windows in @windows_X[j]
797 | local_indices = perm[(perm >= indices_range_low) * (perm < indices_range_high)]
798 | # re-map in "global" indices (inside @perm) onto "local" indices (inside @windows_X[j])
799 | local_indices -= indices_range_low
800 | train_X.append(windows_X[j][local_indices])
801 | train_Y.append(windows_Y[j][local_indices])
802 |
803 | indices_range_low = indices_range_high
804 |
805 | # assemble the entire training set
806 | train_X = np.concatenate(train_X)
807 | train_Y = np.concatenate(train_Y)
808 |
809 | # if fine-tuning, load a pre-trained model; if not - create a model from scratch
810 | if args.initial_epoch:
811 | # We have to pass our metrics (f1_SP and so on) as custom_objects here and below, since it won't load
812 | # otherwise
813 | model = keras.models.load_model(MODEL_PATH + '/Conv_sample_windows_epochs_{}_without_{}.h5'.format(args.initial_epoch, video_name),
814 | custom_objects={'f1_SP': f1_SP, 'f1_SACC': f1_SACC, 'f1_FIX': f1_FIX})
815 | print('Loaded model from', MODEL_PATH + '/Conv_sample_windows_epochs_{}_without_{}.h5'.format(args.initial_epoch, video_name))
816 | else:
817 | model = create_model(num_classes=num_classes, batch_size=args.batch_size,
818 | train_data_shape=train_X.shape,
819 | dropout_rate=0.3,
820 | padding_mode=args.conv_padding_mode,
821 | num_conv_layers=args.num_conv, conv_filter_counts=args.conv_units,
822 | num_dense_layers=args.num_dense, dense_units_count=args.dense_units,
823 | num_blstm_layers=args.num_blstm, blstm_unit_counts=args.blstm_units,
824 | unroll_blstm=False,
825 | no_bidirectional=args.no_bidirectional)
826 | else:
827 | model = keras.models.load_model(model_fname, custom_objects={'f1_SP': f1_SP,
828 | 'f1_SACC': f1_SACC,
829 | 'f1_FIX': f1_FIX})
830 | print('Skipped training, loaded model from', model_fname)
831 |
832 | if args.dry_run:
833 | return model
834 |
835 | # need to run training now?
836 | if TRAINING_MODE and not ALREADY_TRAINED:
837 | # Store model training history. Make sure (again?) that training is always performed with
838 | # @num_training_samples sequences
839 | assert train_X.shape[0] == train_Y.shape[0]
840 | assert train_X.shape[0] == args.num_training_samples, 'Not enough training samples (might need to ' \
841 | 'increase --overlap, decrease --window, or ' \
842 | 'decrease --num-training-samples in the command ' \
843 |                                                                   'line arguments).'
844 |
845 | training_histories.append(model.fit(train_X, train_Y,
846 | epochs=args.num_epochs,
847 | batch_size=args.batch_size,
848 | shuffle=True,
849 | callbacks=callbacks_list,
850 | validation_split=0.1,
851 | verbose=1,
852 | initial_epoch=args.initial_epoch))
853 | model.save(model_fname)
854 | # Completed a cross-validation fold with actual training
855 | num_training_runs += 1
856 |
857 | # Test model on the left-out cross-validation group (@i is the index corresponding to the test set on this run).
858 |         # @raw will contain predicted class probabilities as well as true labels.
859 | # @processed will contain performance statistics of this fold
860 | raw, processed = evaluate_test(model, data_X[i], data_Y_one_hot[i],
861 | keys_to_subtract_start_indices=keys_to_subtract_start_indices,
862 | padding_features=padding_features,
863 | split_by_items=args.final_run) # keep sequences of different observers separate,
864 | # if we still need to write output .arff files
865 |
866 | # store all results (raw and processed)
867 | for k in list(raw_results.keys()):
868 | if args.final_run:
869 | # On the "final" run, @raw[k] is split into several sequences each, need to concatenate here
870 | # (in order to maintain the same format of the .pickle files)
871 | raw_results[k].append(np.concatenate(raw[k], axis=0))
872 | else:
873 | raw_results[k].append(raw[k])
874 | for k in list(results.keys()):
875 | results[k].append(processed[k])
876 |
877 | print('Evaluating for video', video_name)
878 | for stat_name in ['FIX', 'SACC', 'SP', 'NOISE']:
879 | print('F1-{}'.format(stat_name), results['F1-{}'.format(stat_name)][-1])
880 |
881 | if args.final_run and args.output_folder is not None:
882 | if args.output_folder == 'auto':
883 | args.output_folder = OUT_PATH
884 | print('Creating the detection outputs in', args.output_folder)
885 | # Generate actual ARFF outputs
886 | # Iterate through file names, original objects (from input .arff's), ground truth labels,
887 | # and predicted labels:
888 | for source_fname, source_obj, labels_true, labels_pred in \
889 | zip_equal(source_fnames[i], source_objs[i],
890 | raw['true'], raw['pred']):
891 | full_folder, suffix = os.path.split(source_fname)
892 | folder_name = os.path.split(full_folder)[1] # subfolder name = video name
893 | out_fname = '{}/{}/{}'.format(args.output_folder, folder_name, suffix)
894 |
895 | # create folders that might not exist yet
896 | if not os.path.exists(args.output_folder):
897 | os.mkdir(args.output_folder)
898 | if not os.path.exists('{}/{}'.format(args.output_folder, folder_name)):
899 | os.mkdir('{}/{}'.format(args.output_folder, folder_name))
900 |
901 | # get labels from probabilities for each label
902 | labels_true = np.argmax(labels_true, axis=-1)
903 | labels_pred = np.argmax(labels_pred, axis=-1)
904 | known_class_mask = labels_true != 0 # in case there were some unassigned labels in the ground truth
905 | labels_true = labels_true[known_class_mask]
906 | labels_pred = labels_pred[known_class_mask]
907 |
908 | # sanity check: "true" labels must match the "handlabeller_final" column of the input .arff files
909 | assert (labels_true == source_obj['data']['handlabeller_final']).all()
910 |
911 | # add a column containing predicted labels
912 | source_obj = ArffHelper.add_column(source_obj,
913 | name=sp_processor.EM_TYPE_ATTRIBUTE_NAME,
914 | dtype=sp_processor.EM_TYPE_ARFF_DATA_TYPE,
915 | default_value=sp_processor.EM_TYPE_DEFAULT_VALUE)
916 | # fill in with categorical values instead of numerical ones
917 | # (use @CORRESPONDENCE_TO_HAND_LABELLING_VALUES_REVERSE for conversion)
918 | source_obj['data'][sp_processor.EM_TYPE_ATTRIBUTE_NAME] = \
919 | [CORRESPONDENCE_TO_HAND_LABELLING_VALUES_REVERSE[x] for x in labels_pred]
920 |
921 | ArffHelper.dump(source_obj, open(out_fname, 'w'))
922 |
923 | if args.run_once and num_training_runs >= 1:
924 | break
925 |
926 | # get statistics over all splits that are already processed
927 | if not ALREADY_TRAINED or args.final_run:
928 |
929 | raw_results['true'] = np.concatenate(raw_results['true'])
930 | raw_results['pred'] = np.concatenate(raw_results['pred'])
931 |
932 | mask = np.argmax(raw_results['true'], axis=-1) != 0
933 | print('Found', np.logical_not(mask).sum(), 'UNKNOWN samples in the raw ``true\'\' predictions ' \
934 | '(including the artificially padded parts of the last windows ' \
935 | 'in each sequence, in order to match window width)')
936 |
937 | print(raw_results['true'].shape, raw_results['pred'].shape)
938 |
939 | unknown_class_mask = raw_results['true'][:, :, 0] == 1 # count "unknown"s in the one-hot-encoded true labels
940 |
941 | print('Overall classification scores per class:')
942 | for stat_i, stat_name in zip(list(range(1, 5)), ['FIX', 'SACC', 'SP', 'NOISE']):
943 | results['overall-F1-{}'.format(stat_name)] = K.eval(categorical_f1_score_for_class(raw_results['true'],
944 | raw_results['pred'],
945 | stat_i,
946 | 'float64'))
947 | print('F1-{}'.format(stat_name), results['overall-F1-{}'.format(stat_name)])
948 |
949 | results['overall-acc'] = np.mean(np.argmax(raw_results['true'][np.logical_not(unknown_class_mask)], axis=-1) ==
950 | np.argmax(raw_results['pred'][np.logical_not(unknown_class_mask)], axis=-1))
951 |
952 | # how many samples did the network leave "UNKNOWN"
953 | results['overall-UNKNOWN-samples'] = (np.argmax(raw_results['pred'], axis=-1) == 0).sum()
954 |         print('Samples left UNKNOWN:', results['overall-UNKNOWN-samples'], '(including the UNKNOWN samples matching ' \
955 | 'the window-padded ``true\'\' labels from ' \
956 | 'above)')
957 |
958 | # Run the full evaluation.
959 | # Need a downloaded and installed sp_tool package for this! See http://michaeldorr.de/smoothpursuit/sp_tool.zip
960 | # See README for more information on how to do this.
961 | if args.final_run and args.output_folder:
962 | print('Running sp_tool eval --> {out_dir}/eval.json'.format(out_dir=args.output_folder))
963 | cmd = 'python {sp_tool_dir}/examples/evaluate_on_gazecom.py ' \
964 | '--in "{out_dir}" ' \
965 | '--hand "{gt_dir}" ' \
966 | '--pretty > ' \
967 | '"{out_dir}/eval.json"'.format(sp_tool_dir=args.sp_tool_folder,
968 | out_dir=args.output_folder,
969 | gt_dir=args.ground_truth_folder)
970 | print('Running command:\n', cmd)
971 |
972 | if os.path.exists('{}/examples/'.format(args.sp_tool_folder)):
973 | cmd_res = os.system(cmd)
974 | if cmd_res != 0:
975 | print('Something failed during the sp_tool evaluation run. Check the command above and run it ' \
976 | 'manually, if necessary! Make sure you also *installed* the sp_tool framework, not just ' \
977 | 'downloaded it (see sp_tool/README for details). Also, make sure both the output folder ' \
978 | '"{}" and the ground truth folder "{}" exist (e.g. were extracted from the respective ' \
979 | 'archives).'.format(args.output_folder, args.ground_truth_folder))
980 |
981 | else:
982 | print('\n\nCould not run final evaluation! sp_tool folder could not be found in', args.sp_tool_folder)
983 | print('Pass the --sp-tool-folder argument that points to the correct location (relative or absolute) ' \
984 |                   'of the sp_tool folder, or download the full deep_eye_movement_classification.zip archive from ' \
985 | 'http://michaeldorr.de/smoothpursuit again.')
986 |
987 |
988 | def parse_args(dry_run=False):
989 | """
990 | Parse command line arguments, or just create the parser (if @dry_run == True)
991 | :param dry_run: if True, will return the argument parser itself, and not the parsed arguments
992 | :return: either parser arguments' Namespace object (if @dry_run == False),
993 | or the argument parser itself (if @dry_run == True)
994 | """
995 | parser = ArgumentParser('Sequence-to-sequence eye movement classification')
996 | parser.add_argument('--final-run', '--final', '-f', dest='final_run', action='store_true',
997 | help='Final run with testing only, but on full data (not "clean" data).')
998 | parser.add_argument('--folder', '--output-folder', dest='output_folder', default=None,
999 |                         help='Only for --final-run: write prediction results as ARFF files here. '
1000 | 'Can be set to "auto" to select automatically.')
1001 | parser.add_argument('--model-root-path', default='data/models',
1002 | help='The path which will contain all trained models. If you are running the testing, set this '
1003 | 'argument to the same value as was used during training, so that the models can be '
1004 | 'automatically detected.')
1005 |
1006 | parser.add_argument('--feature-files-folder', '--feature-folder', '--feat-folder',
1007 | default='data/inputs/GazeCom_all_features/',
1008 | help='Folder containing the files with already extracted features.')
1009 | parser.add_argument('--ground-truth-folder', '--gt-folder', default='data/inputs/GazeCom_ground_truth/',
1010 | help='Folder containing the ground truth.')
1011 |
1012 | parser.add_argument('--run-once', '--once', '-o', dest='run_once', action='store_true',
1013 |                         help='Run one step of the LOO-CV run-through and exit (helps with memory consumption, '
1014 | 'then run manually 18 times for GazeCom, for 18 videos).')
1015 | parser.add_argument('--run-once-video', default=None,
1016 | help='Run one step of the LOO-CV run-through on the video with *this* name and exit. '
1017 | 'Used for partial testing of the models.')
1018 |
1019 | parser.add_argument('--batch-size', dest='batch_size', default=5000, type=int,
1020 | help='Batch size for training')
1021 | parser.add_argument('--num-epochs', '--epochs', dest='num_epochs', default=1000, type=int,
1022 | help='Number of epochs')
1023 | parser.add_argument('--initial-epoch', dest='initial_epoch', default=0, type=int,
1024 | help='Start training from this epoch')
1025 | parser.add_argument('--training-samples', dest='num_training_samples', default=50000, type=int,
1026 | help='Number of training samples. The default value is appropriate for windows of 65 samples, '
1027 | 'no overlap between the windows. If window size is increased, need to set --overlap to '
1028 | 'something greater than zero, ideally - maintain a similar number of windows. For example,'
1029 | ' if we use windows of size 257, we set overlap to (257-65) = 192, so that there would be '
1030 | 'as many windows as with a window size of 65, but without overlap.\n\n'
1031 | 'If you decide to increase the number of training samples, you will most likely have to '
1032 | 'adjust --window-size and --overlap values!')
1033 | parser.add_argument('--window-size', '--window', dest='window_size', default=65, type=int,
1034 | help='Window size for classifying')
1035 | parser.add_argument('--window-overlap', '--overlap', dest='overlap', default=0, type=int,
1036 | help='Windows overlap for training data generation')
1037 |
1038 | parser.add_argument('--model-name', '--model', dest='model_name', default=None,
1039 | help='Model name. This allows for naming your models as you wish, BUT '
1040 | 'it will override the model and feature descriptors that are included in the '
1041 | 'automatically generated model names, so use with caution!')
1042 |
1043 | parser.add_argument('--features', '--feat', choices=['movement', # = speed + direction + acceleration
1044 | 'speed', 'acc', 'direction',
1045 | 'xy'],
1046 | nargs='+', default=['speed', 'direction'],
1047 | help='All of the features that are to be used, can be listed without separators, e.g. '
1048 | '"--features speed direction". "acc" stands for acceleration; "movement" is a combination '
1049 | 'of all movement features -- speed, direction, acceleration.')
1050 |
1051 | parser.add_argument('--num-feature-scales', type=int, default=5,
1052 | help='Number of temporal scales for speed/direction/acceleration features (max is 5, min is 1).'
1053 | ' Recommended to leave default.'
1054 | ' The actual scales are 1, 2, 4, 8, and 16 samples, and the first @num_feature_scales '
1055 | 'from this list will be used.')
1056 |
1057 | parser.add_argument('--num-conv', default=3, type=int,
1058 | help='Number of convolutional layers before dense layers and BLSTM')
1059 | parser.add_argument('--num-dense', default=1, type=int,
1060 | help='Number of dense layers before BLSTM')
1061 | parser.add_argument('--num-blstm', default=1, type=int,
1062 | help='Number of BLSTM layers')
1063 |
1064 | parser.add_argument('--conv-padding-mode', default='valid', choices=['valid', 'same'],
1065 | help='Conv1D padding type (applied to all convolutional layers)')
1066 |
1067 | parser.add_argument('--conv-units', nargs='+', default=[32, 16, 8], type=int,
1068 | help='Number of filters in respective 1D convolutional layers. If not enough is provided '
1069 | 'for all the layers, the last layer\'s number of filters will be re-used. '
1070 | 'Should pass as, for example, "--conv-units 32 16 8 4".')
1071 | parser.add_argument('--dense-units', nargs='+', default=[32], type=int,
1072 | help='Number of units in the dense layer (before BLSTM). If not enough is provided '
1073 |                              'for all the layers, the last layer\'s number of units will be re-used. '
1074 | 'Should pass as, for example, "--dense-units 32 32".')
1075 | parser.add_argument('--blstm-units', nargs='+', default=[16], type=int,
1076 | help='Number of units in BLSTM layers. If not enough is provided '
1077 |                              'for all the layers, the last layer\'s number of units will be re-used. '
1078 | 'Should pass as, for example, "--blstm-units 16 16 16".')
1079 | parser.add_argument('--no-bidirectional', '--no-bi', action='store_true',
1080 | help='Use conventional LSTMs, no bi-directional wrappers.')
1081 |
1082 | parser.add_argument('--dry-run', action='store_true',
1083 | help='Do not train or test anything, just create a model and show the architecture and '
1084 | 'number of trainable parameters.')
1085 |
1086 | parser.add_argument('--sp-tool-folder', default='../sp_tool/',
1087 | help='Folder containing the sp_tool framework. Can be downloaded and installed from '
1088 | 'http://michaeldorr.de/smoothpursuit/sp_tool.zip as a stand-alone package, or as '
1089 | 'part of the deep_eye_movement_classification.zip archive via '
1090 | 'http://michaeldorr.de/smoothpursuit.')
1091 |
1092 | if dry_run:
1093 | return parser
1094 |
1095 | args = parser.parse_args()
1096 | if 'movement' in args.features:
1097 | args.features.remove('movement')
1098 | args.features += ['speed', 'acc', 'direction']
1099 | args.features = list(sorted(set(args.features)))
1100 |
1101 | if not 1 <= args.num_feature_scales <= 5:
1102 | raise ValueError('--num-feature-scales can be between 1 and 5')
1103 |
1104 | if len(args.conv_units) < args.num_conv:
1105 | args.conv_units += [args.conv_units[-1]] * (args.num_conv - len(args.conv_units))
1106 | warnings.warn('Not enough --conv-units passed, repeating the last one. Resulting filter '
1107 | 'counts: {}'.format(args.conv_units))
1108 | if len(args.dense_units) < args.num_dense:
1109 | args.dense_units += [args.dense_units[-1]] * (args.num_dense - len(args.dense_units))
1110 | warnings.warn('Not enough --dense-units passed, repeating the last one. Resulting filter '
1111 | 'counts: {}'.format(args.dense_units))
1112 | if len(args.blstm_units) < args.num_blstm:
1113 | args.blstm_units += [args.blstm_units[-1]] * (args.num_blstm - len(args.blstm_units))
1114 | warnings.warn('Not enough --blstm-units passed, repeating the last one. Resulting unit '
1115 | 'counts: {}'.format(args.blstm_units))
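    # Illustrative note (not in the original source): with "--num-conv 4 --conv-units 32 16",
    # args.conv_units becomes [32, 16, 16, 16] -- the last provided value is repeated for the missing layers.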
1116 |
1117 | return args
1118 |
1119 |
1120 | def __main__():
1121 | args = parse_args()
1122 | run(args)
1123 |
1124 | if __name__ == '__main__':
1125 | __main__()
1126 |
--------------------------------------------------------------------------------
/blstm_model_run.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import sys
4 | import glob
5 | import tempfile
6 | import numpy as np
7 |
8 | import keras
9 |
10 | from sp_tool.arff_helper import ArffHelper
11 | import sp_tool.util as sp_util
12 | from sp_tool.evaluate import CORRESPONDENCE_TO_HAND_LABELLING_VALUES
13 | from sp_tool import recording_processor as sp_processor
14 |
15 | import blstm_model
16 | from blstm_model import zip_equal
17 |
18 |
19 | def run(args):
20 | """
21 | Run prediction for a trained model on a set of .arff files (with features already extracted).
22 | See feature_extraction folder for the code to compute appropriate features.
23 | :param args: command line arguments
24 | :return: a list of tuples (corresponding to all processed files) that consist of
25 | - the path to an outputted file
26 | - predicted per-class probabilities
27 | """
28 | subfolders_and_fnames = find_all_subfolder_prefixes_and_input_files(args)
29 | out_fnames = get_corresponding_output_paths(subfolders_and_fnames, args)
30 |
31 | print('Processing {} file(s) from "{}" into "{}"'.format(len(out_fnames),
32 | args.input,
33 | args.output))
34 |
35 | arff_objects = [ArffHelper.load(open(fname)) for _, fname in subfolders_and_fnames]
36 |
37 | keys_to_keep = blstm_model.get_arff_attributes_to_keep(args)
38 | print('Will look for the following keys in all .arff files: {}. ' \
39 | 'If any of these are missing, an error will follow!'.format(keys_to_keep))
40 | all_features = [get_features_columns(obj, args) for obj in arff_objects]
41 |
42 | model = keras.models.load_model(args.model_path,
43 | custom_objects={'f1_SP': blstm_model.f1_SP,
44 | 'f1_SACC': blstm_model.f1_SACC,
45 | 'f1_FIX': blstm_model.f1_FIX})
46 |
47 | # Guess the padding size from model input and output size
48 | window_length = model.output_shape[1] # (batch size, window size, number of classes)
49 |     padded_window_length = model.input_shape[1] # (batch size, window size, number of features)
50 |     padding_features = (padded_window_length - window_length) // 2
51 | print('Will pad the feature sequences with {} samples on each side.'.format(padding_features))
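    # Hypothetical example (numbers not taken from the repository): if model.input_shape is
    # (None, 65, num_features) and model.output_shape is (None, 57, num_classes),
    # then padding_features == (65 - 57) // 2 == 4.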
52 |
53 | keys_to_subtract_start = sorted({'x', 'y'}.intersection(keys_to_keep))
54 | if len(keys_to_subtract_start) > 0:
55 | print('Will subtract the starting values of the following features:', keys_to_subtract_start)
56 | keys_to_subtract_start_indices = [i for i, key in enumerate(keys_to_keep) if key in keys_to_subtract_start]
57 |
58 | predictions, _ = blstm_model.evaluate_test(model=model,
59 | X=all_features,
60 | y=None, # no ground truth available or needed
61 | keys_to_subtract_start_indices=keys_to_subtract_start_indices,
62 | correct_for_unknown_class=False,
63 | padding_features=padding_features,
64 | split_by_items=True)
65 |
66 | CORRESPONDENCE_TO_HAND_LABELLING_VALUES_REVERSE = {v: k for k, v in
67 | CORRESPONDENCE_TO_HAND_LABELLING_VALUES.items()}
68 | print('Class names:', CORRESPONDENCE_TO_HAND_LABELLING_VALUES_REVERSE)
69 |
70 | for original_obj, out_fname, predicted_labels in zip_equal(arff_objects, out_fnames, predictions['pred']):
71 | # Create folders that might not exist yet
72 | containing_folder = os.path.split(out_fname)[0]
73 | if not os.path.exists(containing_folder):
74 | os.makedirs(containing_folder)
75 |
76 | # Get labels from probabilities for each label
77 | labels_pred = np.argmax(predicted_labels, axis=-1)
78 |         # The model outputs labels in fixed-size windows, so we concatenate them into one sequence
79 |         # and truncate it to the original length (the input was mirror-padded to contain only whole windows).
80 | labels_pred = np.concatenate(labels_pred)[:original_obj['data'].shape[0]]
81 |
82 | # Add a column containing predicted labels
83 | original_obj = ArffHelper.add_column(original_obj,
84 | name=sp_processor.EM_TYPE_ATTRIBUTE_NAME,
85 | dtype=sp_processor.EM_TYPE_ARFF_DATA_TYPE,
86 | default_value=sp_processor.EM_TYPE_DEFAULT_VALUE)
87 | # Fill in with categorical values instead of numerical ones
88 | # (use @CORRESPONDENCE_TO_HAND_LABELLING_VALUES_REVERSE for conversion)
89 | original_obj['data'][sp_processor.EM_TYPE_ATTRIBUTE_NAME] = \
90 | [CORRESPONDENCE_TO_HAND_LABELLING_VALUES_REVERSE[x] for x in labels_pred]
91 |
92 | ArffHelper.dump(original_obj, open(out_fname, 'w'))
93 |
94 | print('Prediction and file operations finished, check {} for outputs!'.format(args.output))
95 |
96 | return zip_equal(out_fnames, predictions['pred'])
97 |
98 |
99 | def parse_args():
100 | # Will keep most of the arguments, but suppress others
101 | base_parser = blstm_model.parse_args(dry_run=True)
102 |     # Inherit all arguments, while keeping the possibility to re-add (and thereby suppress) some of them
103 | parser = argparse.ArgumentParser(parents=[base_parser], add_help=False, conflict_handler='resolve')
104 |
105 |     # List all arguments (as lists of all the ways to address each) that are to be suppressed
106 | args_to_suppress = [
107 | ['--model-name', '--model'], # will add a more intuitive --model-path argument below
108 | # no need for the following when training is completed already
109 | ['--initial-epoch'],
110 | ['--batch-size'],
111 | ['--run-once', '--once', '-o'],
112 | ['--run-once-video'],
113 | ['--ground-truth-folder', '--gt-folder'], # no need for ground truth
114 | ['--final-run', '--final', '-f'], # it's always a "final" run here
115 | ['--folder', '--output-folder'], # will override
116 | ['--training-samples'],
117 | ['--sp-tool-folder']
118 | ]
119 |
120 | for arg_group in args_to_suppress:
121 | parser.add_argument(*arg_group, help=argparse.SUPPRESS)
122 |
123 | parser.add_argument('--input', '--in', required=True,
124 | help='Path to input data. Can be either a single .arff file, or a whole directory. '
125 | 'In the latter case, this directory will be scanned for .arff files, and all of them will '
126 | 'be used as inputs, generating corresponding labelled files.')
127 |
128 | # rewrite the help
129 | parser.add_argument('--output', '--output-folder', '--out', dest='output', default=None,
130 | help='Write prediction results as ARFF file(s) here. Will mimic the structure of the --input '
131 | 'folder, or just create a single file, if --input itself points to an .arff file. '
132 | 'Can be a path to the desired output .arff file, in case --input is also just one file. '
133 | 'If not provided, will create a temporary folder and write the outputs there.')
134 |
135 | parser.add_argument('--model-path', '--model', default=None,
136 | help='Path to a particular model (an .h5 file), which is to be used, or a folder containing '
137 | 'all 18 models that are trained in the Leave-One-Video-Out cross-validation procedure '
138 | 'on GazeCom. If this argument is '
139 | 'provided, it overrides all the architecture- and model-defining parameters. The '
140 | 'provided .h5 file will be loaded instead. \n\nIf --model-path is not provided, will '
141 | 'generate a model descriptor from architecture parameters and so on, and look for it '
142 | 'in the respective subfolder of ``data/models/''. Will then (or if --model-path contains '
143 | 'a path to a folder, and not to an .h5 file) take the model that was '
144 | 'trained on all data except for `bridge_1`, since this video has no "true" smooth '
145 |                              'pursuit; this way we maximise the amount of this relatively rare class in the '
146 |                              'training set that is used.')
147 |
148 | args = parser.parse_args()
149 |
150 | if args.model_path is None:
151 | model_descriptor = blstm_model.get_full_model_descriptor(args)
152 | args.model_path = 'data/models/LOO_{descr}/'.format(descr=model_descriptor)
153 |
154 | # If it is a path to a directory, find the model trained for the ``bridge_1'' clip.
155 | # Otherwise, we just assume that the path points to a model file.
156 | if os.path.isdir(args.model_path):
157 | all_model_candidates = sorted(glob.glob('{}/*_without_bridge_1*.h5'.format(args.model_path)))
158 | if len(all_model_candidates) == 0:
159 |             raise ValueError('No model in the "{dir}" folder has ``without_bridge_1\'\' in its name. Either pass '
160 | 'a path to an exact .h5 model file in --model-path, or make sure you have the right model '
161 | 'in the aforementioned folder.'.format(dir=args.model_path))
162 | elif len(all_model_candidates) > 1:
163 |             raise ValueError('More than one model with ``without_bridge_1\'\' in its name has been found in the "{dir}" '
164 | 'folder: {candidates}. Either pass a path to an exact .h5 model file in --model-path, '
165 | 'or make sure you have only one model trained without the clip ``bridge_1\'\' in the '
166 | 'aforementioned folder.'.format(dir=args.model_path,
167 | candidates=all_model_candidates))
168 | args.model_path = all_model_candidates[0] # since there has to be just one
169 |
170 | return args
171 |
172 |
173 | def find_all_subfolder_prefixes_and_input_files(args):
174 | """
175 | Extract a matching set of paths to .arff files and additional folders between the --input folder and the files
176 | themselves (so that we will be able to replicate the structure later on)
177 | :param args: command line arguments
178 | :return: a list of tuples, where the first element is the sub-folder prefix and the second one is the full path
179 | to each .arff file
180 | """
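    # Illustrative example with hypothetical paths: for --input 'data/inputs/GazeCom_features' containing
    # 'data/inputs/GazeCom_features/bridge_1/XYZ_bridge_1.arff', the returned list would include
    # ('bridge_1', 'data/inputs/GazeCom_features/bridge_1/XYZ_bridge_1.arff').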
181 | if os.path.isfile(args.input):
182 | return [('', args.input)]
183 | assert os.path.isdir(args.input), '--input is neither a file nor a folder'
184 |
185 | res = []
186 | for dirpath, dirnames, filenames in os.walk(args.input):
187 | filenames = [x for x in filenames if x.lower().endswith('.arff')]
188 | if filenames:
189 | dirpath_suffix = dirpath[len(args.input):].strip('/')
190 | res += [(dirpath_suffix, os.path.join(dirpath, fname)) for fname in filenames]
191 | return res
192 |
193 |
194 | def get_corresponding_output_paths(subfolders_and_full_input_filenames, args):
195 | """
196 | Create a list that will contain output paths for all the @subfolders_and_full_input_filenames
197 | (the output of find_all_subfolder_prefixes_and_input_files() function) in the output folder.
198 |     :param subfolders_and_full_input_filenames: (subfolder prefix, input file path) tuples,
199 |                                                 as returned by find_all_subfolder_prefixes_and_input_files()
200 | :param args: command line arguments
201 | :return:
202 | """
203 | if args.output is None:
204 | args.output = tempfile.mkdtemp(prefix='blstm_model_output_')
205 |         print('No --output was provided, created a temporary output folder:', args.output, file=sys.stderr)
206 |
207 | if args.output.lower().endswith('.arff'):
208 | assert len(subfolders_and_full_input_filenames) == 1, 'If --output is just one file, cannot have more than ' \
209 | 'one input file! Consider providing a folder as the ' \
210 | '--output.'
211 | return [args.output]
212 |
213 | res = []
214 | for subfolder, full_name in subfolders_and_full_input_filenames:
215 | res.append(os.path.join(args.output, subfolder, os.path.split(full_name)[-1]))
216 | return res
217 |
218 |
219 | def get_features_columns(arff_obj, args):
220 | """
221 |     Extract the feature columns from an already loaded .arff object (selecting the relevant columns and converting units).
222 | :param arff_obj: a loaded .arff file
223 | :param args: command line arguments
224 |     :return: a 2D numpy array with one column per selected feature
225 | """
226 | keys_to_keep = blstm_model.get_arff_attributes_to_keep(args)
227 |
228 | keys_to_convert_to_degrees = ['x', 'y'] + [k for k in keys_to_keep if 'speed_' in k or 'acceleration_' in k]
229 | keys_to_convert_to_degrees = sorted(set(keys_to_convert_to_degrees).intersection(keys_to_keep))
230 | # Conversion is carried out by dividing by pixels-per-degree value (PPD)
231 | if get_features_columns.run_count == 0:
232 | if len(keys_to_convert_to_degrees) > 0:
233 |             print('Will divide the following features by the PPD value:', keys_to_convert_to_degrees)
234 | get_features_columns.run_count += 1
235 |
236 |     # normalize coordinates in @arff_obj by dividing by @ppd_f -- its pixels-per-degree value
237 | ppd_f = sp_util.calculate_ppd(arff_obj)
238 | for k in keys_to_convert_to_degrees:
239 | arff_obj['data'][k] /= ppd_f
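    # Hypothetical example (PPD value not taken from the repository): with ppd_f of 30 pixels per degree,
    # a speed of 300 pixels/second becomes 10 degrees/second after this division.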
240 |
241 |     # stack the selected feature columns into a single (samples x features) array
242 | return np.hstack([np.reshape(arff_obj['data'][key], (-1, 1)) for key in keys_to_keep]).astype(np.float64)
243 |
244 |
245 | get_features_columns.run_count = 0
246 |
247 | if __name__ == '__main__':
248 | run(parse_args())
249 |
--------------------------------------------------------------------------------
/example_data/README.md:
--------------------------------------------------------------------------------
1 | The .arff files correspond to the input file `GazeCom/holsten_gate/GGM_holsten_gate.arff` (all input files can be downloaded [here](https://drive.google.com/drive/folders/1SPGTwUKnvZCUFJO05CnYTqakv-Akdth-?usp=sharing)).
2 |
3 | The model corresponds to the "final architecture" (descriptor "4xvC@(32, 16, 8, 8)_0xD@()_2xB@(16, 16)",
4 | speed and direction used as features, context size 257 samples), which was trained **without** the data of the `holsten_gate` video (more information on Leave-One-Video-Out cross-validation in the [corresponding paper](https://rdcu.be/bbMo3)).
5 |
6 | You can reproduce this output example by running
7 |
8 | > python blstm_model_run.py --feat speed direction --model example_data/model.h5 --in example_data/features.arff --out example_data/output_reproduced.arff
9 |
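As a minimal sketch (not part of the repository) of how to compare the reproduced file with the provided `output.arff`, one could re-use the sp_tool helpers that `blstm_model_run.py` itself imports:

```python
import numpy as np

from sp_tool.arff_helper import ArffHelper
from sp_tool import recording_processor as sp_processor

# Load the provided reference output and the freshly reproduced one
reference = ArffHelper.load(open('example_data/output.arff'))
reproduced = ArffHelper.load(open('example_data/output_reproduced.arff'))

# Compare the predicted eye movement labels sample by sample
ref_labels = reference['data'][sp_processor.EM_TYPE_ATTRIBUTE_NAME]
rep_labels = reproduced['data'][sp_processor.EM_TYPE_ATTRIBUTE_NAME]
print('Label agreement: {:.2%}'.format(np.mean(ref_labels == rep_labels)))
```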
--------------------------------------------------------------------------------
/example_data/model.h5:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MikhailStartsev/deep_em_classifier/9a345a37aab47ac6780ce0d4b5798cc15291c75b/example_data/model.h5
--------------------------------------------------------------------------------
/feature_extraction/AnnotateData.m:
--------------------------------------------------------------------------------
1 | % AnnotateData.m
2 | %
3 | % This function gets ARFF data as input and annotates them with speed, acceleration, and direction
4 | % attributes computed at window sizes of 1, 2, 4, 8, and 16 samples
5 | %
6 | % input:
7 | % arffFile - file containing arff data
8 | %   outputFile - file to store the annotated data
9 | function AnnotateData(arffFile, outputFile)
10 | addpath('arff_utils')
11 |
12 | windowsSize = [1 2 4 8 16];
13 |
14 | [data, metadata, attributes, relation, comments] = LoadArff(arffFile);
15 | comments{end+1} = 'The number after speed, direction denotes the step size that was used for the calculation.';
16 | comments{end+1} = 'Acceleration was calculated between two adjacent samples of the already low pass filtered velocity';
17 |
18 | for i=1:length(windowsSize)
19 | step = windowsSize(i);
20 |
21 | [speed, direction] = GetVelocity(data, attributes, step);
22 |
23 | speedAttName = ['speed_' num2str(step)];
24 | [data, attributes] = AddAttArff(data, attributes, speed, speedAttName, 'numeric');
25 |
26 | dirAttName = ['direction_' num2str(step)];
27 | [data, attributes] = AddAttArff(data, attributes, direction, dirAttName, 'numeric');
28 |
29 | acceleration = GetAcceleration(data, attributes, speedAttName, dirAttName, 1);
30 | accAttName = ['acceleration_' num2str(step)];
31 | [data, attributes] = AddAttArff(data, attributes, acceleration, accAttName, 'numeric');
32 | end
33 |
34 | SaveArff(outputFile, data, metadata, attributes, relation, comments);
35 | end
36 |
--------------------------------------------------------------------------------
/feature_extraction/AnnotateDataAll.m:
--------------------------------------------------------------------------------
1 | % AnnotateDataAll.m
2 | %
3 | % This function annotates all the gaze samples with velocity and acceleration data
4 | % input:
5 | % arffBasepath - folder containing the input .arff files (with gaze data)
6 | % outBasepath - path to the folder where the resulting extracted
7 | % feature .arff files will be written; does not have to
8 | % exist already.
9 | % Both arguments can be omitted, will then default to the folders intended
10 | % to be used with the repository (see global README).
11 | function AnnotateDataAll(arffBasepath, outBasepath)
12 | if nargin < 1
13 | arffBasepath = '../data/inputs/GazeCom_ground_truth';
14 | end
15 | if nargin < 2
16 | outBasepath = '../data/inputs/GazeCom_features'; % already generated features should appear in ../data/inputs/GazeCom_all_features
17 | end
18 |
19 | dirList = glob([arffBasepath '/*']);
20 | for i=1:size(dirList,1)
21 | pos = strfind(dirList{i}, '/');
22 | name = dirList{i}(pos(end-1)+1:pos(end)-1);
23 | outputDir = [outBasepath '/' name];
24 |
25 | if (exist(outputDir) ~= 7)
26 | mkdir(outputDir);
27 | end
28 |
29 | arffFiles = glob([arffBasepath '/' name '/*.arff']);
30 |
31 |     for arffInd=1:size(arffFiles,1)
32 | arffFile = arffFiles{arffInd,1};
33 | [arffDir, arffName, ext] = fileparts(arffFile);
34 | outputFile = [outputDir '/' arffName '.arff'];
35 |
36 | disp(['Processing ' arffFile]);
37 | AnnotateData(arffFile, outputFile);
38 | end
39 | end
40 | end
41 |
--------------------------------------------------------------------------------
/feature_extraction/README.md:
--------------------------------------------------------------------------------
1 | To extract features from GazeCom data, run AnnotateDataAll.
2 | If any alterations are needed, make sure to change the paths in that file (or pass appropriate arguments) so as not to overwrite the existing .arff feature files!
3 | You can call this function as `AnnotateDataAll(path_to_input_arff_files_folder, path_to_resulting_features_folder)`.
4 |
5 | If you want to extract features from another data set, you can manually run the AnnotateData function in a loop for all
6 | the needed files (run as `AnnotateData('path/to/input/file.arff', 'path/to/output/file.arff')`).
7 |
8 | The .arff files to be processed need to have at least 4 columns:
9 | * x, the x coordinate of gaze location (in pixels, in case of GazeCom)
10 | * y, the y coordinate of gaze location (in pixels, in case of GazeCom)
11 | * time, the timestamp of the respective sample (in microseconds)
12 | * confidence, the tracking confidence value (from the eye tracker).
13 |
14 | In the feature extraction scripts, the "acceptable" confidence level is set to 0.75:
15 | every sample with confidence value below this threshold will be discarded (features
16 | set to zero). In GazeCom, the confidence values are either 0 or 1 (0 meaning lost tracking,
17 | 1 meaning normal tracking). If your data does not have such information, set confidence
18 | to 1 for all samples.
19 |
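If you need to add such a constant confidence attribute to already existing .arff files, a minimal sketch (not part of these scripts, and assuming the `'NUMERIC'` dtype string is accepted) using the sp_tool helpers that the Python part of this repository relies on could look like this:

```python
from sp_tool.arff_helper import ArffHelper

# Load a recording, add a "confidence" column set to 1.0 everywhere, and save the result
obj = ArffHelper.load(open('path/to/input/file.arff'))
obj = ArffHelper.add_column(obj, name='confidence', dtype='NUMERIC', default_value=1.0)  # dtype string is an assumption
ArffHelper.dump(obj, open('path/to/output/file.arff', 'w'))
```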
20 |
21 | ### REQUIRED PACKAGES: ###
22 |
23 | - glob: Download the archive from https://www.mathworks.com/matlabcentral/fileexchange/40149-expand-wildcards-for-files-and-directory-names,
24 | extract the glob.m file and add its folder to your MATLAB path
25 | - add arff_utils to your MATLAB path if you want to use them outside the provided scripts (if you use AnnotateData, this folder will be automatically added to your path)
26 |
27 | ### ACKNOWLEDGEMENTS ###
28 |
29 | The MATLAB utilities for handling ARFF files (in arff_utils/) are from the https://web.gin.g-node.org/ioannis.agtzidis/matlab_utils repository, which contains
30 | many more functions (only the ones necessary for feature extraction scripts were copied here).
31 |
--------------------------------------------------------------------------------
/feature_extraction/arff_utils/AddAttArff.m:
--------------------------------------------------------------------------------
1 | % function AddAttArff:
2 | %
3 | % This function adds the data and the name of the new attribute to the initial
4 | % data as a new column.
5 | %
6 | % input:
7 | % data - data of the initial arff file
8 | % attributes - attributes of the initial arff file
9 | % attData - attribute data to append at the data. When nominal attributes
10 | %           are appended, the attribute values should correspond to the enumeration
11 | % equivalent
12 | % attName - attribute name
13 | % attType - attribute type (Integer, Numeric or nominal in the form '{val1,val2}')
14 | %
15 | % output:
16 | % newData - data after addition of the new column
17 | % newAttributes - attributes containing the addition of the new attribute
18 |
19 | function [newData, newAttributes] = AddAttArff(data, attributes, attData, attName, attType)
20 |     % check that the data and the new attribute have the same number of rows
21 | assert(size(data,1)==size(attData,1), 'Provided attribute does not have same number of entries as initial data');
22 |
23 | % check if attribute already exists
24 | for i=1:size(attributes,1)
25 | if (strcmpi(attributes{i,1}, attName))
26 |             error(['Attribute "' attName '" already exists. Cannot add it.']);
27 | end
28 | end
29 |
30 | % merge returned attributes
31 | newAttributes = attributes;
32 | index = size(attributes,1)+1;
33 | newAttributes{index,1} = attName;
34 | newAttributes{index,2} = attType;
35 |
36 | % concatenate attribute to the returned data
37 | newData = zeros(size(data,1), size(data,2)+1);
38 | newData(:,1:end-1) = data(:,:);
39 | newData(:,end) = attData(:);
40 | end
41 |
--------------------------------------------------------------------------------
/feature_extraction/arff_utils/GetAcceleration.m:
--------------------------------------------------------------------------------
1 | % GetAcceleration.m
2 | %
3 | % This function calculates the acceleration from the precomputed velocity.
4 | %
5 | % input:
6 | % data - arff data
7 | % attributes - attributes describing the data
8 | % attSpeed - attribute name holding the speed
9 | % attDir - attribute name holding the speed direction
10 | % windowWidth - step size
11 | %
12 | % output:
13 | % acceleration - computed acceleration
14 |
15 | function acceleration = GetAcceleration(data, attributes, attSpeed, attDir, windowWidth)
16 | c_minConf = 0.75;
17 | step = ceil(windowWidth/2);
18 |
19 | acceleration = zeros(size(data,1),1);
20 |
21 | timeInd = GetAttPositionArff(attributes, 'time');
22 | confInd = GetAttPositionArff(attributes, 'confidence');
23 | speedInd = GetAttPositionArff(attributes, attSpeed);
24 | dirInd = GetAttPositionArff(attributes, attDir);
25 |
26 | for i=1:size(data,1)
27 | if (data(i,confInd) < c_minConf)
28 | continue;
29 | end
30 |
31 | % get initial interval
32 | if (step == windowWidth)
33 | startPos = i - step;
34 | endPos = i;
35 | else
36 | startPos = i -step;
37 | endPos = i + step;
38 | end
39 |
40 | % fine tune intervals
41 | if (startPos < 1 || data(startPos,confInd) < c_minConf)
42 | startPos = i;
43 | end
44 | if (endPos > size(data,1) || data(endPos,confInd) < c_minConf)
45 | endPos = i;
46 | end
47 |
48 | % invalid interval
49 | if (startPos == endPos)
50 | continue;
51 | end
52 |
53 | velStartX = data(startPos,speedInd)*cos(data(startPos,dirInd));
54 | velStartY = data(startPos,speedInd)*sin(data(startPos,dirInd));
55 |
56 | velEndX = data(endPos,speedInd)*cos(data(endPos,dirInd));
57 | velEndY = data(endPos,speedInd)*sin(data(endPos,dirInd));
58 |
59 | deltaT = (data(endPos,timeInd)-data(startPos,timeInd))/1000000;
60 |
61 | accX = (velEndX-velStartX)/deltaT;
62 | accY = (velEndY-velStartY)/deltaT;
63 |
64 | acceleration(i) = sqrt(accX^2 + accY^2);
65 | end
66 | end
67 |
--------------------------------------------------------------------------------
/feature_extraction/arff_utils/GetAttPositionArff.m:
--------------------------------------------------------------------------------
1 | % function GetAttPositionArff:
2 | %
3 | % Gets a list of attributes returned from LoadArff and an attribute name to
4 | % search. If it finds the attribute returns its index otherwise it can raise
5 | % an error.
6 | %
7 | % input:
8 | % arffAttributes - attribute list returned from LoadArff
9 | % attribute - attribute to search
10 | % check - (optional) boolean to check if attribute exists. Default is true
11 | %
12 | % output:
13 | %   attIndex    - index of the attribute in the list if it was found.
14 | % Returns 0 if it wasn't found
15 |
16 | function [attIndex] = GetAttPositionArff(arffAttributes, attribute, check)
17 | if (nargin < 3)
18 | check = true;
19 | end
20 | attIndex = 0;
21 |
22 | for i=1:size(arffAttributes,1)
23 | if (strcmpi(arffAttributes{i,1}, attribute) == 1)
24 | attIndex = i;
25 | end
26 | end
27 |
28 | % check index
29 | if (check)
30 | assert(attIndex>0, ['Attribute "' attribute '" not found']);
31 | end
32 | end
33 |
--------------------------------------------------------------------------------
/feature_extraction/arff_utils/GetNomAttValue.m:
--------------------------------------------------------------------------------
1 | % GetNomAttValue.m
2 | %
3 | % This function returns the value of a nominal attribute in its correct form without
4 | % spaces.
5 | %
6 | % input:
7 | % attDatatype - the part that describes the attribute after its name
8 | %
9 | % output:
10 | % attValue - nominal attribute in its correct form
11 |
12 | function [attValue] = GetNomAttValue(attDatatype)
13 | openCurl = strfind(attDatatype, '{');
14 | closeCurl = strfind(attDatatype, '}');
15 |
16 |     if (isempty(openCurl) && isempty(closeCurl))
17 |         % not a nominal attribute: nothing to extract, return an empty value
18 |         attValue = '';
19 |         return;
20 |     end
21 | 
22 |
23 | assert(length(openCurl) == 1, ['Invalid attribute datatype ' attDatatype]);
24 | assert(length(closeCurl) == 1, ['Invalid attribute datatype ' attDatatype]);
25 | attValue = attDatatype(openCurl:closeCurl);
26 |
27 | % remove spaces from nominal
28 | attValue = attValue(~isspace(attValue));
29 | end
30 |
--------------------------------------------------------------------------------
/feature_extraction/arff_utils/GetVelocity.m:
--------------------------------------------------------------------------------
1 | % GetVelocity.m
2 | %
3 | % This function calculates the speed and direction for the given data and step size
4 | %
5 | % input:
6 | % data - arff data
7 | %   attributes  - attributes describing the data
8 | % windowWidth - step size
9 | %
10 | % output:
11 | % speed - vector containing speed in pixels per second
12 | %   direction   - direction of movement in radians from -pi to pi
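% Illustrative example (not part of the original description): for 250 Hz data (4000 us between samples)
% and windowWidth = 2, step = ceil(2/2) = 1, so sample i is compared against samples i-1 and i+1,
% i.e. an interval of 8000 us = 0.008 s; a displacement of 2 pixels over that interval gives
% speed(i) = 2 / 0.008 = 250 pixels per second.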
13 |
14 | function [speed, direction] = GetVelocity(data, attributes, windowWidth)
15 | c_minConf = 0.75;
16 | step = ceil(windowWidth/2);
17 |
18 | speed = zeros(size(data,1),1);
19 | direction = zeros(size(data,1),1);
20 |
21 | timeInd = GetAttPositionArff(attributes, 'time');
22 | xInd = GetAttPositionArff(attributes, 'x');
23 | yInd = GetAttPositionArff(attributes, 'y');
24 | confInd = GetAttPositionArff(attributes, 'confidence');
25 |
26 | for i=1:size(data,1)
27 | if (data(i,confInd) < c_minConf)
28 | continue;
29 | end
30 |
31 | % get initial interval
32 | if (step == windowWidth)
33 | startPos = i - step;
34 | endPos = i;
35 | else
36 | startPos = i -step;
37 | endPos = i + step;
38 | end
39 |
40 | % fine tune intervals
41 | if (startPos < 1 || data(startPos,confInd) < c_minConf)
42 | startPos = i;
43 | end
44 | if (endPos > size(data,1) || data(endPos,confInd) < c_minConf)
45 | endPos = i;
46 | end
47 |
48 | % invalid interval
49 | if (startPos == endPos)
50 | continue;
51 | end
52 |
53 | ampl = sqrt((data(endPos,xInd)-data(startPos,xInd))^2 + (data(endPos,yInd)-data(startPos,yInd))^2);
54 | time = (data(endPos,timeInd) - data(startPos,timeInd))/1000000;
55 | speed(i) = ampl/time;
56 |
57 | direction(i) = atan2(data(endPos,yInd)-data(startPos,yInd), data(endPos,xInd)-data(startPos,xInd));
58 | end
59 | end
60 |
--------------------------------------------------------------------------------
/feature_extraction/arff_utils/IsNomAttribute.m:
--------------------------------------------------------------------------------
1 | % IsNomAttribute.m
2 | %
3 | % This function checks if an attribute is of nominal type and returns true along
4 | % with nominal and numeric maps. Otherwise it returns false.
5 | %
6 | % input:
7 | % attDatatype - the part that describes the attribute after its name
8 | %
9 | % output:
10 | % isNom - boolean value denoting if nominal
11 | %   nominalMap  - mapping of nominal values to doubles as in a C++ enumeration
12 | % numericMap - mapping of doubles to nominal values
13 |
14 | function [isNom, nominalMap, numericMap] = IsNomAttribute(attDatatype)
15 | openCurl = strfind(attDatatype, '{');
16 | closeCurl = strfind(attDatatype, '}');
17 |
18 | if (isempty(openCurl) && isempty(closeCurl))
19 | isNom = false;
20 | nominalMap = containers.Map;
21 | numericMap = containers.Map;
22 | return;
23 | end
24 |
25 | assert(length(openCurl) == 1, ['Invalid attribute datatype ' attDatatype]);
26 | assert(length(closeCurl) == 1, ['Invalid attribute datatype ' attDatatype]);
27 | attDatatype = attDatatype(openCurl+1:closeCurl-1);
28 |
29 | % remove spaces from nominal
30 | attDatatype = attDatatype(~isspace(attDatatype));
31 |
32 | keys = split(attDatatype, ',');
33 | values = 0:length(keys)-1;
34 |
35 | nominalMap = containers.Map(keys, values);
36 |
37 |     % convert the single key to a plain string; otherwise the cell type is invalid for map creation
38 | if (length(keys) == 1)
39 | keys = string(keys);
40 | end
41 | numericMap = containers.Map(values, keys);
42 | isNom = true;
43 | end
44 |
--------------------------------------------------------------------------------
/feature_extraction/arff_utils/LoadArff.m:
--------------------------------------------------------------------------------
1 | % LoadArff.m
2 | %
3 | % This function loads data from an ARFF file and returns the data, metadata,
4 | % attributes, relation and comments. All returned strings are lower case.
5 | %
6 | % input:
7 | % arffFile - path to ARFF file to read
8 | %
9 | % output:
10 | % data - data stored in the ARFF file
11 | %   metadata    - structure holding metadata in the form metadata.{width_px, height_px, width_mm, height_mm, distance_mm}; each field is -1 if not available. Extra metadata are stored in metadata.extra, which is an nx2 cell array holding name-value pairs
12 | % attributes - nx2 cell array with attribute names and types, where n is the number of attributes
13 | % relation - relation described in ARFF
14 | % comments - nx1 cell array containing one comment line per cell
15 |
16 | function [data, metadata, attributes, relation, comments] = LoadArff(arffFile)
17 | % initialize data
18 | data = [];
19 | % initialize metadata
20 | metadata.width_px = -1;
21 | metadata.height_px = -1;
22 | metadata.width_mm = -1;
23 | metadata.height_mm = -1;
24 | metadata.distance_mm = -1;
25 | metadata.extra = {};
26 | attributes = {};
27 | relation = '';
28 | comments = {};
29 |
30 | % nominal attribute handling
31 | nomMat = logical([]);
32 | nomMaps = {};
33 |
34 | % read header
35 | numOfHeaderLines = 1;
36 | fid = fopen(arffFile, 'r');
37 | fline = fgetl(fid);
38 | while (ischar(fline))
39 | % split lines into words
40 | words = strsplit(fline,' ');
41 | % check for relation
42 | if (size(words,2)>1 && strcmpi(words{1,1},'@relation')==1)
43 | relation = lower(words{1,2});
44 | % check for width_px
45 | elseif (size(words,2)>2 && strcmpi(words{1,1},'%@metadata')==1 && strcmpi(words{1,2},'width_px')==1)
46 | metadata.width_px = str2num(words{1,3});
47 | % check for height_px
48 | elseif (size(words,2)>2 && strcmpi(words{1,1},'%@metadata')==1 && strcmpi(words{1,2},'height_px')==1)
49 | metadata.height_px = str2num(words{1,3});
50 | % check for width_mm
51 | elseif (size(words,2)>2 && strcmpi(words{1,1},'%@metadata')==1 && strcmpi(words{1,2},'width_mm')==1)
52 | metadata.width_mm = str2num(words{1,3});
53 | % check for height_mm
54 | elseif (size(words,2)>2 && strcmpi(words{1,1},'%@metadata')==1 && strcmpi(words{1,2},'height_mm')==1)
55 | metadata.height_mm = str2num(words{1,3});
56 | % check for distance_mm
57 | elseif (size(words,2)>2 && strcmpi(words{1,1},'%@metadata')==1 && strcmpi(words{1,2},'distance_mm')==1)
58 | metadata.distance_mm = str2num(words{1,3});
59 | % process the rest of the metadata
60 | elseif (size(words,2)>2 && strcmpi(words{1,1},'%@metadata')==1)
61 | pos = size(metadata.extra,1)+1;
62 | metadata.extra{pos,1} = words{1,2};
63 | metadata.extra{pos,2} = words{1,3};
64 | % check for attributes
65 | elseif (size(words,2)>2 && strcmpi(words{1,1},'@attribute')==1)
66 | index = size(attributes,1)+1;
67 | attributes{index,1} = lower(words{1,2});
68 | attributes{index,2} = words{1,3};
69 | [isNom, nominalMap] = IsNomAttribute(fline);
70 | nomMat = [nomMat; isNom];
71 | if (isNom)
72 | nomMaps = [nomMaps; {nominalMap}];
73 | attributes{index,2} = GetNomAttValue(fline);
74 | else
75 | nomMaps = [nomMaps; {[]}];
76 | end
77 | % check if it is a comment
78 |         elseif (length(fline) > 0 && fline(1) == '%')
79 | comments{end+1} = fline;
80 | % check if data has been reached
81 | elseif (size(words,2)>0 && strcmpi(words{1,1},'@data')==1)
82 | break;
83 | end
84 |
85 | fline = fgetl(fid);
86 | numOfHeaderLines = numOfHeaderLines+1;
87 | end
88 |
89 | numAtts = size(attributes,1);
90 | readFormat = '';
91 | for ind=1:numAtts
92 | if (nomMat(ind))
93 | readFormat = [readFormat '%s '];
94 | else
95 | readFormat = [readFormat '%f '];
96 | end
97 | end
98 | lines = textscan(fid, readFormat, 'Delimiter', ',');
99 |
100 | nomIndices = find(nomMat);
101 | for nomInd=nomIndices'
102 | if (isempty(nomInd))
103 | break;
104 | end
105 |
106 | for ind=1:size(lines{1,nomInd},1)
107 | lines{1,nomInd}{ind} = nomMaps{nomInd,1}(lines{1,nomInd}{ind});
108 | end
109 | lines{1,nomInd} = cell2mat(lines{1,nomInd});
110 | end
111 |
112 | data = cell2mat(lines);
113 |
114 | fclose(fid);
115 | end
116 |
--------------------------------------------------------------------------------
/feature_extraction/arff_utils/SaveArff.m:
--------------------------------------------------------------------------------
1 | % SaveArff.m
2 | %
3 | % Function to save ARFF data to file.
4 | %
5 | % input:
6 | % arffFile - name of the file to save data
7 | % data - data to write in arff file
8 | %   metadata    - metadata struct in the form returned by LoadArff
9 | %   attributes  - nx2 cell array holding the attribute names and types
10 | % relation - relation described in the file
11 | % comments - (optional) nx1 cell array containing one comment line per cell
12 |
13 | function SaveArff(arffFile, data, metadata, attributes, relation, comments)
14 | if (nargin < 6)
15 | comments = {};
16 | end
17 | % check input
18 | assert(isfield(metadata,'width_px'), 'metadata should contain "width_px" field');
19 | assert(isfield(metadata,'height_px'), 'metadata should contain "height_px" field');
20 | assert(isfield(metadata,'width_mm'), 'metadata should contain "width_mm" field');
21 | assert(isfield(metadata,'height_mm'), 'metadata should contain "height_mm" field');
22 | assert(isfield(metadata,'distance_mm'), 'metadata should contain "distance_mm" field');
23 | assert(size(relation,2)>0, 'relation should not be empty');
24 | assert(size(attributes,1)==size(data,2), 'attribute number should be the same with data');
25 |
26 | % start writing
27 | fid = fopen(arffFile, 'w+');
28 |
29 | % write relation
30 | fprintf(fid, '@RELATION %s\n\n', relation);
31 |
32 | % write metadata
33 | fprintf(fid, '%%@METADATA width_px %d\n', metadata.width_px);
34 | fprintf(fid, '%%@METADATA height_px %d\n', metadata.height_px);
35 | fprintf(fid, '%%@METADATA width_mm %.2f\n', metadata.width_mm);
36 | fprintf(fid, '%%@METADATA height_mm %.2f\n', metadata.height_mm);
37 | fprintf(fid, '%%@METADATA distance_mm %.2f\n\n', metadata.distance_mm);
38 |
39 | % write metadata extras. Those are data that vary between experiments
40 | for i=1:size(metadata.extra,1)
41 | fprintf(fid, '%%@METADATA %s %s\n', metadata.extra{i,1}, metadata.extra{i,2});
42 | end
43 | % print an empty line
44 | fprintf(fid, '\n');
45 |
46 | % write attributes and get their type
47 | % 1 = integer
48 | % 2 = numeric
49 | % 3 = nominal
50 | % -1 = other
51 | numAtts = size(attributes,1);
52 | attType = -1*ones(numAtts,1);
53 | numMaps = cell(numAtts,1);
54 | for i=1:numAtts
55 | fprintf(fid, '@ATTRIBUTE %s %s\n', attributes{i,1}, attributes{i,2});
56 | [isNom, ~, numericMap] = IsNomAttribute(attributes{i,2});
57 |
58 | % get type
59 | if (strcmpi(attributes{i,2},'integer')==1)
60 | attType(i) = 1;
61 | elseif (strcmpi(attributes{i,2},'numeric')==1)
62 | attType(i) = 2;
63 | elseif (isNom)
64 | attType(i) = 3;
65 | numMaps{i,1} = numericMap;
66 | end
67 | end
68 |
69 | % write comments if they exist
70 | if (~isempty(comments))
71 | fprintf(fid, '\n');
72 | for i=1:length(comments)
73 | comment = comments{i};
74 | % check if % is the first character
75 | if (length(comment)>0 && comment(1)~='%')
76 | comment = ['%' comment];
77 | end
78 |
79 | fprintf(fid, '%s\n', comment);
80 | end
81 | end
82 |
83 | % write data keyword
84 | fprintf(fid,'\n@DATA\n');
85 |
86 | numEntries = size(data,1);
87 | % transpose data in order to allow one line writing because fprintf handles
88 | % matrices column wise when writing in file
89 | data = num2cell(data');
90 | nomIndices = find(attType==3);
91 | for nomInd=nomIndices'
92 | if (isempty(nomInd))
93 | break;
94 | end
95 |
96 | % convert numbers to nominal values
97 | for ind=1:numEntries
98 | data{nomInd, ind} = numMaps{nomInd,1}(data{nomInd, ind});
99 | end
100 | end
101 |
102 | writeFormat = '';
103 | for ind=1:numAtts
104 | if (attType(ind) == 1)
105 | writeFormat = [writeFormat '%d'];
106 | elseif (attType(ind) == 2)
107 | writeFormat = [writeFormat '%.2f'];
108 | elseif (attType(ind) == 3)
109 | writeFormat = [writeFormat '%s'];
110 | else
111 | error(['Attribute type "' num2str(attType(ind)) '" is not recognised']);
112 | end
113 |
114 | if (ind