├── .dockerignore
├── .gitignore
├── Dockerfile
├── LICENSE
├── Plots.ipynb
├── assets
│   └── architecture.svg
├── config
│   ├── conf.yaml
│   ├── features
│   │   ├── degree.yaml
│   │   ├── method_attributes.yaml
│   │   └── method_summary.yaml
│   └── logger
│       └── wandb.yaml
├── core
│   ├── __init__.py
│   ├── callbacks.py
│   ├── data_module.py
│   ├── dataset.py
│   ├── model.py
│   └── utils.py
├── data
│   ├── README.md
│   ├── test_new.sha256
│   ├── test_old.sha256
│   ├── train_new.sha256
│   └── train_old.sha256
├── malware-learning.def
├── metadata
│   └── api.list
├── notebooks
│   ├── 0-Preliminary Analysis.ipynb
│   ├── 1-AGfeatures.ipynb
│   ├── 2-GFeatures.ipynb
│   ├── 3-APIFeatures.ipynb
│   ├── 4-APIFeatures-Binary.ipynb
│   └── 5-AGfeatures-Binary.ipynb
├── readme.md
├── requirements.txt
├── scripts
│   ├── __init__.py
│   ├── plot_callgraph.py
│   ├── process_dataset.py
│   └── split_dataset.py
└── train_model.py
/.dockerignore:
--------------------------------------------------------------------------------
1 | data/
2 | temp/
3 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | data/
2 | models/
3 | notebooks/lightning_logs
4 | .idea
5 | output*/
6 | metadata/
7 | # Dataset list
8 | *.list
9 | # Pickle
10 | *.pkl
11 | # Java files
12 | *.jar
13 | # Singularity Image
14 | *.sif
15 | # Nano saved
16 | *.save
17 | # Byte-compiled / optimized / DLL files
18 | __pycache__/
19 | *.py[cod]
20 | *$py.class
21 |
22 | # C extensions
23 | *.so
24 |
25 | # Distribution / packaging
26 | .Python
27 | build/
28 | develop-eggs/
29 | dist/
30 | downloads/
31 | eggs/
32 | .eggs/
33 | lib/
34 | lib64/
35 | parts/
36 | sdist/
37 | var/
38 | wheels/
39 | pip-wheel-metadata/
40 | share/python-wheels/
41 | *.egg-info/
42 | .installed.cfg
43 | *.egg
44 | MANIFEST
45 |
46 | # PyInstaller
47 | # Usually these files are written by a python script from a template
48 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
49 | *.manifest
50 | *.spec
51 |
52 | # Installer logs
53 | pip-log.txt
54 | pip-delete-this-directory.txt
55 |
56 | # Unit test / coverage reports
57 | htmlcov/
58 | .tox/
59 | .nox/
60 | .coverage
61 | .coverage.*
62 | .cache
63 | nosetests.xml
64 | coverage.xml
65 | *.cover
66 | *.py,cover
67 | .hypothesis/
68 | .pytest_cache/
69 |
70 | # Translations
71 | *.mo
72 | *.pot
73 |
74 | # Django stuff:
75 | *.log
76 | local_settings.py
77 | db.sqlite3
78 | db.sqlite3-journal
79 |
80 | # Flask stuff:
81 | instance/
82 | .webassets-cache
83 |
84 | # Scrapy stuff:
85 | .scrapy
86 |
87 | # Sphinx documentation
88 | docs/_build/
89 |
90 | # PyBuilder
91 | target/
92 |
93 | # Jupyter Notebook
94 | .ipynb_checkpoints
95 | *ipynb
96 |
97 | # IPython
98 | profile_default/
99 | ipython_config.py
100 |
101 | # pyenv
102 | .python-version
103 |
104 | # pipenv
105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
108 | # install all needed dependencies.
109 | #Pipfile.lock
110 |
111 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
112 | __pypackages__/
113 |
114 | # Celery stuff
115 | celerybeat-schedule
116 | celerybeat.pid
117 |
118 | # SageMath parsed files
119 | *.sage.py
120 |
121 | # Environments
122 | .env
123 | .venv
124 | env/
125 | venv/
126 | ENV/
127 | env.bak/
128 | venv.bak/
129 |
130 | # Spyder project settings
131 | .spyderproject
132 | .spyproject
133 |
134 | # Rope project settings
135 | .ropeproject
136 |
137 | # mkdocs documentation
138 | /site
139 |
140 | # mypy
141 | .mypy_cache/
142 | .dmypy.json
143 | dmypy.json
144 |
145 | # Pyre type checker
146 | .pyre/
--------------------------------------------------------------------------------
/Dockerfile:
--------------------------------------------------------------------------------
1 | FROM pytorch/pytorch:1.7.0-cuda11.0-cudnn8-devel
2 | COPY requirements.txt /mnt/
3 | RUN apt-get update && apt-get install -y git graphviz graphviz-dev && rm -rf /var/lib/apt/lists/*
4 | RUN pip install -r /mnt/requirements.txt
5 | RUN git clone https://github.com/androguard/androguard.git && cd androguard && python setup.py install
6 | VOLUME /model
7 | WORKDIR /model
--------------------------------------------------------------------------------
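
The image above only installs the environment; the project itself is expected to be bind-mounted at /model. A typical (hypothetical) workflow, with an illustrative image tag, looks like:

    docker build -t malware-learning .
    docker run --rm --gpus all -v "$(pwd)":/model malware-learning python train_model.py

The --gpus all flag requires the NVIDIA container toolkit and can be dropped for CPU-only runs.
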
/LICENSE:
--------------------------------------------------------------------------------
1 | GNU GENERAL PUBLIC LICENSE
2 | Version 3, 29 June 2007
3 |
4 | Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
5 | Everyone is permitted to copy and distribute verbatim copies
6 | of this license document, but changing it is not allowed.
7 |
8 | Preamble
9 |
10 | The GNU General Public License is a free, copyleft license for
11 | software and other kinds of works.
12 |
13 | The licenses for most software and other practical works are designed
14 | to take away your freedom to share and change the works. By contrast,
15 | the GNU General Public License is intended to guarantee your freedom to
16 | share and change all versions of a program--to make sure it remains free
17 | software for all its users. We, the Free Software Foundation, use the
18 | GNU General Public License for most of our software; it applies also to
19 | any other work released this way by its authors. You can apply it to
20 | your programs, too.
21 |
22 | When we speak of free software, we are referring to freedom, not
23 | price. Our General Public Licenses are designed to make sure that you
24 | have the freedom to distribute copies of free software (and charge for
25 | them if you wish), that you receive source code or can get it if you
26 | want it, that you can change the software or use pieces of it in new
27 | free programs, and that you know you can do these things.
28 |
29 | To protect your rights, we need to prevent others from denying you
30 | these rights or asking you to surrender the rights. Therefore, you have
31 | certain responsibilities if you distribute copies of the software, or if
32 | you modify it: responsibilities to respect the freedom of others.
33 |
34 | For example, if you distribute copies of such a program, whether
35 | gratis or for a fee, you must pass on to the recipients the same
36 | freedoms that you received. You must make sure that they, too, receive
37 | or can get the source code. And you must show them these terms so they
38 | know their rights.
39 |
40 | Developers that use the GNU GPL protect your rights with two steps:
41 | (1) assert copyright on the software, and (2) offer you this License
42 | giving you legal permission to copy, distribute and/or modify it.
43 |
44 | For the developers' and authors' protection, the GPL clearly explains
45 | that there is no warranty for this free software. For both users' and
46 | authors' sake, the GPL requires that modified versions be marked as
47 | changed, so that their problems will not be attributed erroneously to
48 | authors of previous versions.
49 |
50 | Some devices are designed to deny users access to install or run
51 | modified versions of the software inside them, although the manufacturer
52 | can do so. This is fundamentally incompatible with the aim of
53 | protecting users' freedom to change the software. The systematic
54 | pattern of such abuse occurs in the area of products for individuals to
55 | use, which is precisely where it is most unacceptable. Therefore, we
56 | have designed this version of the GPL to prohibit the practice for those
57 | products. If such problems arise substantially in other domains, we
58 | stand ready to extend this provision to those domains in future versions
59 | of the GPL, as needed to protect the freedom of users.
60 |
61 | Finally, every program is threatened constantly by software patents.
62 | States should not allow patents to restrict development and use of
63 | software on general-purpose computers, but in those that do, we wish to
64 | avoid the special danger that patents applied to a free program could
65 | make it effectively proprietary. To prevent this, the GPL assures that
66 | patents cannot be used to render the program non-free.
67 |
68 | The precise terms and conditions for copying, distribution and
69 | modification follow.
70 |
71 | TERMS AND CONDITIONS
72 |
73 | 0. Definitions.
74 |
75 | "This License" refers to version 3 of the GNU General Public License.
76 |
77 | "Copyright" also means copyright-like laws that apply to other kinds of
78 | works, such as semiconductor masks.
79 |
80 | "The Program" refers to any copyrightable work licensed under this
81 | License. Each licensee is addressed as "you". "Licensees" and
82 | "recipients" may be individuals or organizations.
83 |
84 | To "modify" a work means to copy from or adapt all or part of the work
85 | in a fashion requiring copyright permission, other than the making of an
86 | exact copy. The resulting work is called a "modified version" of the
87 | earlier work or a work "based on" the earlier work.
88 |
89 | A "covered work" means either the unmodified Program or a work based
90 | on the Program.
91 |
92 | To "propagate" a work means to do anything with it that, without
93 | permission, would make you directly or secondarily liable for
94 | infringement under applicable copyright law, except executing it on a
95 | computer or modifying a private copy. Propagation includes copying,
96 | distribution (with or without modification), making available to the
97 | public, and in some countries other activities as well.
98 |
99 | To "convey" a work means any kind of propagation that enables other
100 | parties to make or receive copies. Mere interaction with a user through
101 | a computer network, with no transfer of a copy, is not conveying.
102 |
103 | An interactive user interface displays "Appropriate Legal Notices"
104 | to the extent that it includes a convenient and prominently visible
105 | feature that (1) displays an appropriate copyright notice, and (2)
106 | tells the user that there is no warranty for the work (except to the
107 | extent that warranties are provided), that licensees may convey the
108 | work under this License, and how to view a copy of this License. If
109 | the interface presents a list of user commands or options, such as a
110 | menu, a prominent item in the list meets this criterion.
111 |
112 | 1. Source Code.
113 |
114 | The "source code" for a work means the preferred form of the work
115 | for making modifications to it. "Object code" means any non-source
116 | form of a work.
117 |
118 | A "Standard Interface" means an interface that either is an official
119 | standard defined by a recognized standards body, or, in the case of
120 | interfaces specified for a particular programming language, one that
121 | is widely used among developers working in that language.
122 |
123 | The "System Libraries" of an executable work include anything, other
124 | than the work as a whole, that (a) is included in the normal form of
125 | packaging a Major Component, but which is not part of that Major
126 | Component, and (b) serves only to enable use of the work with that
127 | Major Component, or to implement a Standard Interface for which an
128 | implementation is available to the public in source code form. A
129 | "Major Component", in this context, means a major essential component
130 | (kernel, window system, and so on) of the specific operating system
131 | (if any) on which the executable work runs, or a compiler used to
132 | produce the work, or an object code interpreter used to run it.
133 |
134 | The "Corresponding Source" for a work in object code form means all
135 | the source code needed to generate, install, and (for an executable
136 | work) run the object code and to modify the work, including scripts to
137 | control those activities. However, it does not include the work's
138 | System Libraries, or general-purpose tools or generally available free
139 | programs which are used unmodified in performing those activities but
140 | which are not part of the work. For example, Corresponding Source
141 | includes interface definition files associated with source files for
142 | the work, and the source code for shared libraries and dynamically
143 | linked subprograms that the work is specifically designed to require,
144 | such as by intimate data communication or control flow between those
145 | subprograms and other parts of the work.
146 |
147 | The Corresponding Source need not include anything that users
148 | can regenerate automatically from other parts of the Corresponding
149 | Source.
150 |
151 | The Corresponding Source for a work in source code form is that
152 | same work.
153 |
154 | 2. Basic Permissions.
155 |
156 | All rights granted under this License are granted for the term of
157 | copyright on the Program, and are irrevocable provided the stated
158 | conditions are met. This License explicitly affirms your unlimited
159 | permission to run the unmodified Program. The output from running a
160 | covered work is covered by this License only if the output, given its
161 | content, constitutes a covered work. This License acknowledges your
162 | rights of fair use or other equivalent, as provided by copyright law.
163 |
164 | You may make, run and propagate covered works that you do not
165 | convey, without conditions so long as your license otherwise remains
166 | in force. You may convey covered works to others for the sole purpose
167 | of having them make modifications exclusively for you, or provide you
168 | with facilities for running those works, provided that you comply with
169 | the terms of this License in conveying all material for which you do
170 | not control copyright. Those thus making or running the covered works
171 | for you must do so exclusively on your behalf, under your direction
172 | and control, on terms that prohibit them from making any copies of
173 | your copyrighted material outside their relationship with you.
174 |
175 | Conveying under any other circumstances is permitted solely under
176 | the conditions stated below. Sublicensing is not allowed; section 10
177 | makes it unnecessary.
178 |
179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
180 |
181 | No covered work shall be deemed part of an effective technological
182 | measure under any applicable law fulfilling obligations under article
183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or
184 | similar laws prohibiting or restricting circumvention of such
185 | measures.
186 |
187 | When you convey a covered work, you waive any legal power to forbid
188 | circumvention of technological measures to the extent such circumvention
189 | is effected by exercising rights under this License with respect to
190 | the covered work, and you disclaim any intention to limit operation or
191 | modification of the work as a means of enforcing, against the work's
192 | users, your or third parties' legal rights to forbid circumvention of
193 | technological measures.
194 |
195 | 4. Conveying Verbatim Copies.
196 |
197 | You may convey verbatim copies of the Program's source code as you
198 | receive it, in any medium, provided that you conspicuously and
199 | appropriately publish on each copy an appropriate copyright notice;
200 | keep intact all notices stating that this License and any
201 | non-permissive terms added in accord with section 7 apply to the code;
202 | keep intact all notices of the absence of any warranty; and give all
203 | recipients a copy of this License along with the Program.
204 |
205 | You may charge any price or no price for each copy that you convey,
206 | and you may offer support or warranty protection for a fee.
207 |
208 | 5. Conveying Modified Source Versions.
209 |
210 | You may convey a work based on the Program, or the modifications to
211 | produce it from the Program, in the form of source code under the
212 | terms of section 4, provided that you also meet all of these conditions:
213 |
214 | a) The work must carry prominent notices stating that you modified
215 | it, and giving a relevant date.
216 |
217 | b) The work must carry prominent notices stating that it is
218 | released under this License and any conditions added under section
219 | 7. This requirement modifies the requirement in section 4 to
220 | "keep intact all notices".
221 |
222 | c) You must license the entire work, as a whole, under this
223 | License to anyone who comes into possession of a copy. This
224 | License will therefore apply, along with any applicable section 7
225 | additional terms, to the whole of the work, and all its parts,
226 | regardless of how they are packaged. This License gives no
227 | permission to license the work in any other way, but it does not
228 | invalidate such permission if you have separately received it.
229 |
230 | d) If the work has interactive user interfaces, each must display
231 | Appropriate Legal Notices; however, if the Program has interactive
232 | interfaces that do not display Appropriate Legal Notices, your
233 | work need not make them do so.
234 |
235 | A compilation of a covered work with other separate and independent
236 | works, which are not by their nature extensions of the covered work,
237 | and which are not combined with it such as to form a larger program,
238 | in or on a volume of a storage or distribution medium, is called an
239 | "aggregate" if the compilation and its resulting copyright are not
240 | used to limit the access or legal rights of the compilation's users
241 | beyond what the individual works permit. Inclusion of a covered work
242 | in an aggregate does not cause this License to apply to the other
243 | parts of the aggregate.
244 |
245 | 6. Conveying Non-Source Forms.
246 |
247 | You may convey a covered work in object code form under the terms
248 | of sections 4 and 5, provided that you also convey the
249 | machine-readable Corresponding Source under the terms of this License,
250 | in one of these ways:
251 |
252 | a) Convey the object code in, or embodied in, a physical product
253 | (including a physical distribution medium), accompanied by the
254 | Corresponding Source fixed on a durable physical medium
255 | customarily used for software interchange.
256 |
257 | b) Convey the object code in, or embodied in, a physical product
258 | (including a physical distribution medium), accompanied by a
259 | written offer, valid for at least three years and valid for as
260 | long as you offer spare parts or customer support for that product
261 | model, to give anyone who possesses the object code either (1) a
262 | copy of the Corresponding Source for all the software in the
263 | product that is covered by this License, on a durable physical
264 | medium customarily used for software interchange, for a price no
265 | more than your reasonable cost of physically performing this
266 | conveying of source, or (2) access to copy the
267 | Corresponding Source from a network server at no charge.
268 |
269 | c) Convey individual copies of the object code with a copy of the
270 | written offer to provide the Corresponding Source. This
271 | alternative is allowed only occasionally and noncommercially, and
272 | only if you received the object code with such an offer, in accord
273 | with subsection 6b.
274 |
275 | d) Convey the object code by offering access from a designated
276 | place (gratis or for a charge), and offer equivalent access to the
277 | Corresponding Source in the same way through the same place at no
278 | further charge. You need not require recipients to copy the
279 | Corresponding Source along with the object code. If the place to
280 | copy the object code is a network server, the Corresponding Source
281 | may be on a different server (operated by you or a third party)
282 | that supports equivalent copying facilities, provided you maintain
283 | clear directions next to the object code saying where to find the
284 | Corresponding Source. Regardless of what server hosts the
285 | Corresponding Source, you remain obligated to ensure that it is
286 | available for as long as needed to satisfy these requirements.
287 |
288 | e) Convey the object code using peer-to-peer transmission, provided
289 | you inform other peers where the object code and Corresponding
290 | Source of the work are being offered to the general public at no
291 | charge under subsection 6d.
292 |
293 | A separable portion of the object code, whose source code is excluded
294 | from the Corresponding Source as a System Library, need not be
295 | included in conveying the object code work.
296 |
297 | A "User Product" is either (1) a "consumer product", which means any
298 | tangible personal property which is normally used for personal, family,
299 | or household purposes, or (2) anything designed or sold for incorporation
300 | into a dwelling. In determining whether a product is a consumer product,
301 | doubtful cases shall be resolved in favor of coverage. For a particular
302 | product received by a particular user, "normally used" refers to a
303 | typical or common use of that class of product, regardless of the status
304 | of the particular user or of the way in which the particular user
305 | actually uses, or expects or is expected to use, the product. A product
306 | is a consumer product regardless of whether the product has substantial
307 | commercial, industrial or non-consumer uses, unless such uses represent
308 | the only significant mode of use of the product.
309 |
310 | "Installation Information" for a User Product means any methods,
311 | procedures, authorization keys, or other information required to install
312 | and execute modified versions of a covered work in that User Product from
313 | a modified version of its Corresponding Source. The information must
314 | suffice to ensure that the continued functioning of the modified object
315 | code is in no case prevented or interfered with solely because
316 | modification has been made.
317 |
318 | If you convey an object code work under this section in, or with, or
319 | specifically for use in, a User Product, and the conveying occurs as
320 | part of a transaction in which the right of possession and use of the
321 | User Product is transferred to the recipient in perpetuity or for a
322 | fixed term (regardless of how the transaction is characterized), the
323 | Corresponding Source conveyed under this section must be accompanied
324 | by the Installation Information. But this requirement does not apply
325 | if neither you nor any third party retains the ability to install
326 | modified object code on the User Product (for example, the work has
327 | been installed in ROM).
328 |
329 | The requirement to provide Installation Information does not include a
330 | requirement to continue to provide support service, warranty, or updates
331 | for a work that has been modified or installed by the recipient, or for
332 | the User Product in which it has been modified or installed. Access to a
333 | network may be denied when the modification itself materially and
334 | adversely affects the operation of the network or violates the rules and
335 | protocols for communication across the network.
336 |
337 | Corresponding Source conveyed, and Installation Information provided,
338 | in accord with this section must be in a format that is publicly
339 | documented (and with an implementation available to the public in
340 | source code form), and must require no special password or key for
341 | unpacking, reading or copying.
342 |
343 | 7. Additional Terms.
344 |
345 | "Additional permissions" are terms that supplement the terms of this
346 | License by making exceptions from one or more of its conditions.
347 | Additional permissions that are applicable to the entire Program shall
348 | be treated as though they were included in this License, to the extent
349 | that they are valid under applicable law. If additional permissions
350 | apply only to part of the Program, that part may be used separately
351 | under those permissions, but the entire Program remains governed by
352 | this License without regard to the additional permissions.
353 |
354 | When you convey a copy of a covered work, you may at your option
355 | remove any additional permissions from that copy, or from any part of
356 | it. (Additional permissions may be written to require their own
357 | removal in certain cases when you modify the work.) You may place
358 | additional permissions on material, added by you to a covered work,
359 | for which you have or can give appropriate copyright permission.
360 |
361 | Notwithstanding any other provision of this License, for material you
362 | add to a covered work, you may (if authorized by the copyright holders of
363 | that material) supplement the terms of this License with terms:
364 |
365 | a) Disclaiming warranty or limiting liability differently from the
366 | terms of sections 15 and 16 of this License; or
367 |
368 | b) Requiring preservation of specified reasonable legal notices or
369 | author attributions in that material or in the Appropriate Legal
370 | Notices displayed by works containing it; or
371 |
372 | c) Prohibiting misrepresentation of the origin of that material, or
373 | requiring that modified versions of such material be marked in
374 | reasonable ways as different from the original version; or
375 |
376 | d) Limiting the use for publicity purposes of names of licensors or
377 | authors of the material; or
378 |
379 | e) Declining to grant rights under trademark law for use of some
380 | trade names, trademarks, or service marks; or
381 |
382 | f) Requiring indemnification of licensors and authors of that
383 | material by anyone who conveys the material (or modified versions of
384 | it) with contractual assumptions of liability to the recipient, for
385 | any liability that these contractual assumptions directly impose on
386 | those licensors and authors.
387 |
388 | All other non-permissive additional terms are considered "further
389 | restrictions" within the meaning of section 10. If the Program as you
390 | received it, or any part of it, contains a notice stating that it is
391 | governed by this License along with a term that is a further
392 | restriction, you may remove that term. If a license document contains
393 | a further restriction but permits relicensing or conveying under this
394 | License, you may add to a covered work material governed by the terms
395 | of that license document, provided that the further restriction does
396 | not survive such relicensing or conveying.
397 |
398 | If you add terms to a covered work in accord with this section, you
399 | must place, in the relevant source files, a statement of the
400 | additional terms that apply to those files, or a notice indicating
401 | where to find the applicable terms.
402 |
403 | Additional terms, permissive or non-permissive, may be stated in the
404 | form of a separately written license, or stated as exceptions;
405 | the above requirements apply either way.
406 |
407 | 8. Termination.
408 |
409 | You may not propagate or modify a covered work except as expressly
410 | provided under this License. Any attempt otherwise to propagate or
411 | modify it is void, and will automatically terminate your rights under
412 | this License (including any patent licenses granted under the third
413 | paragraph of section 11).
414 |
415 | However, if you cease all violation of this License, then your
416 | license from a particular copyright holder is reinstated (a)
417 | provisionally, unless and until the copyright holder explicitly and
418 | finally terminates your license, and (b) permanently, if the copyright
419 | holder fails to notify you of the violation by some reasonable means
420 | prior to 60 days after the cessation.
421 |
422 | Moreover, your license from a particular copyright holder is
423 | reinstated permanently if the copyright holder notifies you of the
424 | violation by some reasonable means, this is the first time you have
425 | received notice of violation of this License (for any work) from that
426 | copyright holder, and you cure the violation prior to 30 days after
427 | your receipt of the notice.
428 |
429 | Termination of your rights under this section does not terminate the
430 | licenses of parties who have received copies or rights from you under
431 | this License. If your rights have been terminated and not permanently
432 | reinstated, you do not qualify to receive new licenses for the same
433 | material under section 10.
434 |
435 | 9. Acceptance Not Required for Having Copies.
436 |
437 | You are not required to accept this License in order to receive or
438 | run a copy of the Program. Ancillary propagation of a covered work
439 | occurring solely as a consequence of using peer-to-peer transmission
440 | to receive a copy likewise does not require acceptance. However,
441 | nothing other than this License grants you permission to propagate or
442 | modify any covered work. These actions infringe copyright if you do
443 | not accept this License. Therefore, by modifying or propagating a
444 | covered work, you indicate your acceptance of this License to do so.
445 |
446 | 10. Automatic Licensing of Downstream Recipients.
447 |
448 | Each time you convey a covered work, the recipient automatically
449 | receives a license from the original licensors, to run, modify and
450 | propagate that work, subject to this License. You are not responsible
451 | for enforcing compliance by third parties with this License.
452 |
453 | An "entity transaction" is a transaction transferring control of an
454 | organization, or substantially all assets of one, or subdividing an
455 | organization, or merging organizations. If propagation of a covered
456 | work results from an entity transaction, each party to that
457 | transaction who receives a copy of the work also receives whatever
458 | licenses to the work the party's predecessor in interest had or could
459 | give under the previous paragraph, plus a right to possession of the
460 | Corresponding Source of the work from the predecessor in interest, if
461 | the predecessor has it or can get it with reasonable efforts.
462 |
463 | You may not impose any further restrictions on the exercise of the
464 | rights granted or affirmed under this License. For example, you may
465 | not impose a license fee, royalty, or other charge for exercise of
466 | rights granted under this License, and you may not initiate litigation
467 | (including a cross-claim or counterclaim in a lawsuit) alleging that
468 | any patent claim is infringed by making, using, selling, offering for
469 | sale, or importing the Program or any portion of it.
470 |
471 | 11. Patents.
472 |
473 | A "contributor" is a copyright holder who authorizes use under this
474 | License of the Program or a work on which the Program is based. The
475 | work thus licensed is called the contributor's "contributor version".
476 |
477 | A contributor's "essential patent claims" are all patent claims
478 | owned or controlled by the contributor, whether already acquired or
479 | hereafter acquired, that would be infringed by some manner, permitted
480 | by this License, of making, using, or selling its contributor version,
481 | but do not include claims that would be infringed only as a
482 | consequence of further modification of the contributor version. For
483 | purposes of this definition, "control" includes the right to grant
484 | patent sublicenses in a manner consistent with the requirements of
485 | this License.
486 |
487 | Each contributor grants you a non-exclusive, worldwide, royalty-free
488 | patent license under the contributor's essential patent claims, to
489 | make, use, sell, offer for sale, import and otherwise run, modify and
490 | propagate the contents of its contributor version.
491 |
492 | In the following three paragraphs, a "patent license" is any express
493 | agreement or commitment, however denominated, not to enforce a patent
494 | (such as an express permission to practice a patent or covenant not to
495 | sue for patent infringement). To "grant" such a patent license to a
496 | party means to make such an agreement or commitment not to enforce a
497 | patent against the party.
498 |
499 | If you convey a covered work, knowingly relying on a patent license,
500 | and the Corresponding Source of the work is not available for anyone
501 | to copy, free of charge and under the terms of this License, through a
502 | publicly available network server or other readily accessible means,
503 | then you must either (1) cause the Corresponding Source to be so
504 | available, or (2) arrange to deprive yourself of the benefit of the
505 | patent license for this particular work, or (3) arrange, in a manner
506 | consistent with the requirements of this License, to extend the patent
507 | license to downstream recipients. "Knowingly relying" means you have
508 | actual knowledge that, but for the patent license, your conveying the
509 | covered work in a country, or your recipient's use of the covered work
510 | in a country, would infringe one or more identifiable patents in that
511 | country that you have reason to believe are valid.
512 |
513 | If, pursuant to or in connection with a single transaction or
514 | arrangement, you convey, or propagate by procuring conveyance of, a
515 | covered work, and grant a patent license to some of the parties
516 | receiving the covered work authorizing them to use, propagate, modify
517 | or convey a specific copy of the covered work, then the patent license
518 | you grant is automatically extended to all recipients of the covered
519 | work and works based on it.
520 |
521 | A patent license is "discriminatory" if it does not include within
522 | the scope of its coverage, prohibits the exercise of, or is
523 | conditioned on the non-exercise of one or more of the rights that are
524 | specifically granted under this License. You may not convey a covered
525 | work if you are a party to an arrangement with a third party that is
526 | in the business of distributing software, under which you make payment
527 | to the third party based on the extent of your activity of conveying
528 | the work, and under which the third party grants, to any of the
529 | parties who would receive the covered work from you, a discriminatory
530 | patent license (a) in connection with copies of the covered work
531 | conveyed by you (or copies made from those copies), or (b) primarily
532 | for and in connection with specific products or compilations that
533 | contain the covered work, unless you entered into that arrangement,
534 | or that patent license was granted, prior to 28 March 2007.
535 |
536 | Nothing in this License shall be construed as excluding or limiting
537 | any implied license or other defenses to infringement that may
538 | otherwise be available to you under applicable patent law.
539 |
540 | 12. No Surrender of Others' Freedom.
541 |
542 | If conditions are imposed on you (whether by court order, agreement or
543 | otherwise) that contradict the conditions of this License, they do not
544 | excuse you from the conditions of this License. If you cannot convey a
545 | covered work so as to satisfy simultaneously your obligations under this
546 | License and any other pertinent obligations, then as a consequence you may
547 | not convey it at all. For example, if you agree to terms that obligate you
548 | to collect a royalty for further conveying from those to whom you convey
549 | the Program, the only way you could satisfy both those terms and this
550 | License would be to refrain entirely from conveying the Program.
551 |
552 | 13. Use with the GNU Affero General Public License.
553 |
554 | Notwithstanding any other provision of this License, you have
555 | permission to link or combine any covered work with a work licensed
556 | under version 3 of the GNU Affero General Public License into a single
557 | combined work, and to convey the resulting work. The terms of this
558 | License will continue to apply to the part which is the covered work,
559 | but the special requirements of the GNU Affero General Public License,
560 | section 13, concerning interaction through a network will apply to the
561 | combination as such.
562 |
563 | 14. Revised Versions of this License.
564 |
565 | The Free Software Foundation may publish revised and/or new versions of
566 | the GNU General Public License from time to time. Such new versions will
567 | be similar in spirit to the present version, but may differ in detail to
568 | address new problems or concerns.
569 |
570 | Each version is given a distinguishing version number. If the
571 | Program specifies that a certain numbered version of the GNU General
572 | Public License "or any later version" applies to it, you have the
573 | option of following the terms and conditions either of that numbered
574 | version or of any later version published by the Free Software
575 | Foundation. If the Program does not specify a version number of the
576 | GNU General Public License, you may choose any version ever published
577 | by the Free Software Foundation.
578 |
579 | If the Program specifies that a proxy can decide which future
580 | versions of the GNU General Public License can be used, that proxy's
581 | public statement of acceptance of a version permanently authorizes you
582 | to choose that version for the Program.
583 |
584 | Later license versions may give you additional or different
585 | permissions. However, no additional obligations are imposed on any
586 | author or copyright holder as a result of your choosing to follow a
587 | later version.
588 |
589 | 15. Disclaimer of Warranty.
590 |
591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
599 |
600 | 16. Limitation of Liability.
601 |
602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
610 | SUCH DAMAGES.
611 |
612 | 17. Interpretation of Sections 15 and 16.
613 |
614 | If the disclaimer of warranty and limitation of liability provided
615 | above cannot be given local legal effect according to their terms,
616 | reviewing courts shall apply local law that most closely approximates
617 | an absolute waiver of all civil liability in connection with the
618 | Program, unless a warranty or assumption of liability accompanies a
619 | copy of the Program in return for a fee.
620 |
621 | END OF TERMS AND CONDITIONS
622 |
623 | How to Apply These Terms to Your New Programs
624 |
625 | If you develop a new program, and you want it to be of the greatest
626 | possible use to the public, the best way to achieve this is to make it
627 | free software which everyone can redistribute and change under these terms.
628 |
629 | To do so, attach the following notices to the program. It is safest
630 | to attach them to the start of each source file to most effectively
631 | state the exclusion of warranty; and each file should have at least
632 | the "copyright" line and a pointer to where the full notice is found.
633 |
634 | <one line to give the program's name and a brief idea of what it does.>
635 | Copyright (C) <year> <name of author>
636 |
637 | This program is free software: you can redistribute it and/or modify
638 | it under the terms of the GNU General Public License as published by
639 | the Free Software Foundation, either version 3 of the License, or
640 | (at your option) any later version.
641 |
642 | This program is distributed in the hope that it will be useful,
643 | but WITHOUT ANY WARRANTY; without even the implied warranty of
644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
645 | GNU General Public License for more details.
646 |
647 | You should have received a copy of the GNU General Public License
648 | along with this program. If not, see <https://www.gnu.org/licenses/>.
649 |
650 | Also add information on how to contact you by electronic and paper mail.
651 |
652 | If the program does terminal interaction, make it output a short
653 | notice like this when it starts in an interactive mode:
654 |
655 | <program> Copyright (C) <year> <name of author>
656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
657 | This is free software, and you are welcome to redistribute it
658 | under certain conditions; type `show c' for details.
659 |
660 | The hypothetical commands `show w' and `show c' should show the appropriate
661 | parts of the General Public License. Of course, your program's commands
662 | might be different; for a GUI interface, you would use an "about box".
663 |
664 | You should also get your employer (if you work as a programmer) or school,
665 | if any, to sign a "copyright disclaimer" for the program, if necessary.
666 | For more information on this, and how to apply and follow the GNU GPL, see
667 | <https://www.gnu.org/licenses/>.
668 |
669 | The GNU General Public License does not permit incorporating your program
670 | into proprietary programs. If your program is a subroutine library, you
671 | may consider it more useful to permit linking proprietary applications with
672 | the library. If this is what you want to do, use the GNU Lesser General
673 | Public License instead of this License. But first, please read
674 | <https://www.gnu.org/licenses/why-not-lgpl.html>.
675 |
--------------------------------------------------------------------------------
/config/conf.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | train_dir: ${env:PWD}/data/train
3 | test_dir: ${env:PWD}/data/test
4 | batch_size: 16
5 | pin_memory: false
6 | num_workers: 6
7 | split_train_val: true
8 | split_ratios: [0.75, 0.25]
9 | consider_features: ${features.attributes}
10 |
11 | model:
12 | convolution_count: 0
13 |   convolution_algorithm: GraphConv # Can be one of GraphConv, SAGEConv, TAGConv, DotGatConv (the set accepted by core/model.py)
14 | input_dimension: ${features.size}
15 |
16 | trainer:
17 | max_epochs: 100
18 | gpus: null
19 |
20 | hydra:
21 | run:
22 | dir: output/${model.convolution_algorithm}/${features.name}-conv_count=${model.convolution_count}
23 | sweep:
24 | dir: output/${model.convolution_algorithm}
25 | subdir: ${features.name}-conv_count=${model.convolution_count}
26 |
27 | defaults:
28 | - features: degree
29 | - logger: wandb
30 |
--------------------------------------------------------------------------------
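
The keys under data and model above line up one-to-one with the constructor arguments of MalwareDataModule and MalwareDetector in core/. The following is a minimal sketch of an entry point driven by this config; it is an illustration assuming Hydra 1.0-style APIs, not the repository's actual train_model.py (whose contents are not reproduced here):

    import hydra
    import pytorch_lightning as pl
    from omegaconf import DictConfig

    from core.data_module import MalwareDataModule
    from core.model import MalwareDetector


    @hydra.main(config_path="config", config_name="conf")
    def main(cfg: DictConfig):
        # cfg.data and cfg.model mirror the two constructors' keyword arguments
        data_module = MalwareDataModule(**cfg.data)
        model = MalwareDetector(**cfg.model)
        trainer = pl.Trainer(max_epochs=cfg.trainer.max_epochs, gpus=cfg.trainer.gpus)
        trainer.fit(model, datamodule=data_module)


    if __name__ == "__main__":
        main()

Hydra then materializes the run directory from the hydra.run.dir template, giving each algorithm/feature-set combination its own output folder.
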
/config/features/degree.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: degree
3 | attributes: []
4 | size: 1
--------------------------------------------------------------------------------
/config/features/method_attributes.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: method
3 | attributes:
4 | - external
5 | - native
6 | - public
7 | - static
8 | - codesize
9 | size: 5
--------------------------------------------------------------------------------
/config/features/method_summary.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: hybrid
3 | attributes:
4 | - api
5 | - user
6 | size: 247
7 |
--------------------------------------------------------------------------------
/config/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: wandb
3 | args:
4 | project: malware_homo
5 | log_model: true
6 | hparams:
7 | convolution_algorithm: ${model.convolution_algorithm}
8 | features: ${features.name}
9 | convolution_count: ${model.convolution_count}
--------------------------------------------------------------------------------
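
The args block maps directly onto keyword arguments of pytorch_lightning.loggers.WandbLogger, and hparams gathers the config values worth recording with each run. Continuing the hypothetical entry-point sketch from config/conf.yaml above (cfg is the composed Hydra config):

    from pytorch_lightning.loggers import WandbLogger

    logger = WandbLogger(**cfg.logger.args)           # project='malware_homo', log_model=True
    logger.log_hyperparams(dict(cfg.logger.hparams))  # algorithm, feature set, layer count
    trainer = pl.Trainer(logger=logger, max_epochs=cfg.trainer.max_epochs)
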
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinayakakv/android-malware-detection/1aab288ec599a3958982866ce989311a96cbffd9/core/__init__.py
--------------------------------------------------------------------------------
/core/callbacks.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | from typing import Tuple, List, Union
3 |
4 | import dgl
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import pytorch_lightning as pl
8 | import torch.nn
9 | import wandb
10 | from pytorch_lightning.callbacks import Callback
11 | from pytorch_lightning.metrics.metric import Metric
12 |
13 | from core.model import MalwareDetector
14 | from core.utils import plot_confusion_matrix, plot_curve
15 |
16 |
17 | class InputMonitor(Callback):
18 | """
19 |     Logs a histogram of the labels in each training batch to the experiment logger
20 | """
21 |
22 | def __init__(self):
23 | pass
24 |
25 | def on_train_batch_start(
26 | self,
27 | trainer: pl.Trainer,
28 | pl_module: pl.LightningModule,
29 | batch: Tuple[dgl.DGLHeteroGraph, torch.Tensor],
30 | batch_idx: int,
31 | dataloader_idx: int
32 | ):
33 | samples, labels = batch
34 | trainer.logger.experiment.log({
35 | 'train_data_histogram': wandb.Histogram(labels.detach().cpu().numpy())
36 | }, commit=False)
37 |
38 |
39 | class BestModelTagger(Callback):
40 | """
41 |     Logs the "best_epoch" and the corresponding value of the monitored metric to the logger.
42 |     Inspired by https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/early_stopping.py
43 | """
44 |
45 | def __init__(self, monitor: str = 'val_loss', mode: str = 'min'):
46 | self.monitor = monitor
47 | if mode not in ['min', 'max']:
48 | raise ValueError(f"Invalid mode {mode}. Must be one of 'min' or 'max'")
49 | self.mode = mode
50 | self.monitor_op = torch.lt if mode == 'min' else torch.gt
51 | torch_inf = torch.tensor(np.Inf)
52 | self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
53 |
54 |     def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
55 |         logs = trainer.callback_metrics
56 |         monitor_val = logs.get(self.monitor)
57 |         if monitor_val is None:
58 |             raise RuntimeError(f"{self.monitor} was supposed to be logged by the model, but it was not found")
59 |         # Normalize the monitored value to a tensor before comparing against the best score
60 |         if isinstance(monitor_val, Metric):
61 |             monitor_val = monitor_val.compute()
62 |         elif isinstance(monitor_val, numbers.Number):
63 |             monitor_val = torch.tensor(monitor_val, device=pl_module.device, dtype=torch.float)
64 |         if self.monitor_op(monitor_val, self.best_score):
65 |             self.best_score = monitor_val
66 |             trainer.logger.experiment.log({
67 |                 f'{self.mode}_{self.monitor}': monitor_val.cpu().numpy(),
68 |                 'best_epoch': trainer.current_epoch
69 |             }, commit=False)
70 |
71 |
72 | class MetricsLogger(Callback):
73 |
74 | def __init__(self, stages: Union[List[str], str]):
75 | valid_stages = {'train', 'val', 'test'}
76 | if stages == 'all':
77 | self.stages = valid_stages
78 | else:
79 | for stage in stages:
80 | if stage not in valid_stages:
81 | raise ValueError(f"Stage {stage} is not valid. Must be one of {valid_stages}")
82 | self.stages = set(stages) & valid_stages
83 |
84 | @staticmethod
85 | def _plot_metrics(trainer: pl.Trainer, pl_module: MalwareDetector, stage: str):
86 | confusion_matrix = pl_module.test_outputs['confusion_matrix'].compute().cpu().numpy()
87 | plot_confusion_matrix(
88 | confusion_matrix,
89 | group_names=['TN', 'FP', 'FN', 'TP'],
90 | categories=['Benign', 'Malware'],
91 | cmap='binary'
92 | )
93 | trainer.logger.experiment.log({
94 | f'{stage}_confusion_matrix': wandb.Image(plt)
95 | }, commit=False)
96 | if stage != 'test':
97 | return
98 | roc = pl_module.test_outputs['roc'].compute()
99 | figure = plot_curve(roc[0].cpu(), roc[1].cpu(), 'roc')
100 | trainer.logger.experiment.log({
101 |             'ROC': figure
102 | }, commit=False)
103 | prc = pl_module.test_outputs['prc'].compute()
104 | figure = plot_curve(prc[1].cpu(), prc[0].cpu(), 'prc')
105 | trainer.logger.experiment.log({
106 |             'PRC': figure
107 | }, commit=False)
108 |
109 | @staticmethod
110 | def compute_metrics(pl_module: MalwareDetector, stage: str):
111 | metrics = {}
112 | if stage == 'train':
113 | metric_dict = pl_module.train_metrics
114 | elif stage == 'val':
115 | metric_dict = pl_module.val_metrics
116 | elif stage == 'test':
117 | metric_dict = pl_module.test_metrics
118 | else:
119 | raise ValueError(f"Invalid stage: {stage}")
120 | for metric_name, metric in metric_dict.items():
121 | metrics[metric_name] = metric.compute()
122 | return metrics
123 |
124 | def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: MalwareDetector, outputs):
125 | if 'train' not in self.stages:
126 | return
127 | trainer.logger.experiment.log(self.compute_metrics(pl_module, 'train'), commit=False)
128 |
129 | def on_validation_end(self, trainer: pl.Trainer, pl_module: MalwareDetector):
130 | if 'val' not in self.stages or trainer.running_sanity_check:
131 | return
132 | trainer.logger.experiment.log(self.compute_metrics(pl_module, 'val'), commit=False)
133 |
134 | def on_test_end(self, trainer: pl.Trainer, pl_module: MalwareDetector):
135 | if 'test' not in self.stages:
136 | return
137 | trainer.logger.experiment.log(self.compute_metrics(pl_module, 'test'), commit=False)
138 | self._plot_metrics(trainer, pl_module, 'test')
139 |
--------------------------------------------------------------------------------
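
A plausible (hypothetical) way to attach these callbacks to a Trainer, following the constructors defined above; note that all three log through trainer.logger.experiment, so they presuppose a wandb-backed logger:

    import pytorch_lightning as pl

    from core.callbacks import BestModelTagger, InputMonitor, MetricsLogger

    trainer = pl.Trainer(
        max_epochs=100,
        callbacks=[
            InputMonitor(),                                   # label histogram per training batch
            BestModelTagger(monitor='val_loss', mode='min'),  # tracks the best epoch so far
            MetricsLogger(stages='all'),                      # per-stage metrics, plus test plots
        ],
    )
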
/core/data_module.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List, Dict, Tuple, Union
3 |
4 | import dgl
5 | import pytorch_lightning as pl
6 | import torch
7 | from sklearn.model_selection import StratifiedShuffleSplit
8 | from torch.utils.data import DataLoader
9 |
10 | from core.dataset import MalwareDataset
11 |
12 |
13 | def stratified_split_dataset(samples: List[str],
14 | labels: Dict[str, int],
15 | ratios: Tuple[float, float]) -> Tuple[List[str], List[str]]:
16 | """
17 | Split the dataset into train and validation datasets based on the given ratio
18 | :param samples: List of file names
19 | :param labels: Mapping from file name to label
20 | :param ratios: Training ratio, validation ratio
21 | :return: List of file names in training and validation split
22 | """
23 |     if abs(sum(ratios) - 1) > 1e-6:  # tolerate floating-point rounding in the ratios
24 |         raise ValueError(f"Invalid split ratios {ratios}; they must sum to 1")
25 | train_ratio, val_ratio = ratios
26 | sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=0)
27 | train_idx, val_idx = list(sss.split(samples, [labels[x] for x in samples]))[0]
28 | train_list = [samples[x] for x in train_idx]
29 | val_list = [samples[x] for x in val_idx]
30 | return train_list, val_list
31 |
32 |
33 | @torch.no_grad()
34 | def collate(samples: List[Tuple[dgl.DGLGraph, int]]) -> Tuple[dgl.DGLGraph, torch.Tensor]:
35 | """
36 | Batches several graphs into one
37 |     :param samples: List of (graph, label) tuples
38 | :return: Batched graph, and labels concatenated into a tensor
39 | """
40 | graphs, labels = map(list, zip(*samples))
41 | batched_graph = dgl.batch(graphs)
42 | labels = torch.tensor(labels)
43 | return batched_graph, labels.float()
44 |
45 |
46 | class MalwareDataModule(pl.LightningDataModule):
47 | """
48 | Handler class for data loading, splitting and initializing datasets and dataloaders.
49 | """
50 |
51 | def __init__(
52 | self,
53 | train_dir: Union[str, Path],
54 | test_dir: Union[str, Path],
55 | batch_size: int,
56 | split_ratios: Tuple[float, float],
57 | consider_features: List[str],
58 | num_workers: int,
59 | pin_memory: bool,
60 | split_train_val: bool,
61 | ):
62 | """
63 | Creates the MalwareDataModule
64 | :param train_dir: The directory containing the training samples
65 | :param test_dir: The directory containing the testing samples
66 | :param batch_size: Number of graphs in a batch
67 | :param split_ratios: Tuple containing training and validation split
68 | :param consider_features: Features types to consider
69 |         :param num_workers: Number of worker processes used by the dataloaders
70 |         :param pin_memory: If True, pin host memory in the dataloaders, which can speed up CPU-to-GPU transfer
71 |         :param split_train_val: If True, split the training dataset into training and validation sets;
72 |                                 otherwise, use the test dataset for validation
73 | """
74 | super().__init__()
75 | self.train_dir = Path(train_dir)
76 | if not self.train_dir.exists():
77 | raise FileNotFoundError(f"Train directory {train_dir} does not exist. Could not read from it.")
78 | self.test_dir = Path(test_dir)
79 | if not self.test_dir.exists():
80 | raise FileNotFoundError(f"Test directory {test_dir} does not exist. Could not read from it.")
81 | self.dataloader_kwargs = {
82 | 'num_workers': num_workers,
83 | 'batch_size': batch_size,
84 | 'pin_memory': pin_memory,
85 | 'collate_fn': collate,
86 | 'drop_last': True
87 | }
88 | self.split_ratios = split_ratios
89 | self.split = split_train_val
90 | self.splitter = stratified_split_dataset
91 | self.consider_features = consider_features
92 |
93 | @staticmethod
94 | def get_samples(path: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
95 | """
96 | Get samples and labels from the given path
97 | :param path: The directory containing graphs
98 | :return: The file list, and their label mapping
99 | """
100 | base_path = Path(path)
101 | if not base_path.exists():
102 | raise FileNotFoundError(f'{base_path} does not exist')
103 |         apk_list = sorted(base_path.iterdir())
104 | samples = []
105 | labels = {}
106 | for apk in apk_list:
107 | samples.append(apk.name)
108 |             labels[apk.name] = int("Benig" not in apk.name)  # "Benig*" in the name means benign (0); everything else is malware (1)
109 | return samples, labels
110 |
111 | def setup(self, stage=None):
112 | samples, labels = self.get_samples(self.train_dir)
113 | test_samples, test_labels = self.get_samples(self.test_dir)
114 | if self.split:
115 | train_samples, val_samples = self.splitter(samples, labels, self.split_ratios)
116 | val_dir = self.train_dir
117 | val_labels = labels
118 | else:
119 | train_samples = samples
120 | val_dir = self.test_dir
121 | val_samples, val_labels = test_samples, test_labels
122 | self.train_dataset = MalwareDataset(
123 | source_dir=self.train_dir,
124 | samples=train_samples,
125 | labels=labels,
126 | consider_features=self.consider_features
127 | )
128 | self.val_dataset = MalwareDataset(
129 | source_dir=val_dir,
130 | samples=val_samples,
131 | labels=val_labels,
132 | consider_features=self.consider_features
133 | )
134 | self.test_dataset = MalwareDataset(
135 | source_dir=self.test_dir,
136 | samples=test_samples,
137 | labels=test_labels,
138 | consider_features=self.consider_features
139 | )
140 |
141 | def train_dataloader(self):
142 | return DataLoader(self.train_dataset, shuffle=True, **self.dataloader_kwargs)
143 |
144 | def val_dataloader(self):
145 | return DataLoader(self.val_dataset, **self.dataloader_kwargs)
146 |
147 | def test_dataloader(self):
148 | return DataLoader(self.test_dataset, **self.dataloader_kwargs)
149 |
--------------------------------------------------------------------------------
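
As a quick sanity check of collate(), here is a toy example (the graphs are made up) that batches two small graphs:

    import dgl
    import torch

    from core.data_module import collate

    g1 = dgl.graph((torch.tensor([0, 1]), torch.tensor([1, 0])))  # 2 nodes, 2 edges
    g2 = dgl.graph((torch.tensor([0]), torch.tensor([1])))        # 2 nodes, 1 edge
    batched_graph, labels = collate([(g1, 0), (g2, 1)])
    assert batched_graph.batch_size == 2      # both graphs live in one batched graph
    assert labels.tolist() == [0.0, 1.0]      # labels are collected as a float tensor
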
/core/dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List, Dict, Tuple, Union
3 |
4 | import dgl
5 | import torch
6 | from torch.utils.data import Dataset
7 |
8 | attributes = {'external', 'entrypoint', 'native', 'public', 'static', 'codesize'}
9 |
10 |
11 | class MalwareDataset(Dataset):
12 | def __init__(
13 | self,
14 | source_dir: Union[str, Path],
15 | samples: List[str],
16 | labels: Dict[str, int],
17 | consider_features: List[str],
18 | ):
19 | self.source_dir = Path(source_dir)
20 | self.samples = samples
21 | self.labels = labels
22 | self.consider_features = consider_features
23 |
24 | def __len__(self) -> int:
25 |         """Returns the number of samples in the dataset"""
26 | return len(self.samples)
27 |
28 | @staticmethod
29 | def _process_node_attributes(g: dgl.DGLGraph):
30 | for attribute in attributes & set(g.ndata.keys()):
31 | g.ndata[attribute] = g.ndata[attribute].view(-1, 1)
32 | return g
33 |
34 | def __getitem__(self, index: int) -> Tuple[dgl.DGLGraph, int]:
35 | """Generates one sample of data"""
36 | name = self.samples[index]
37 | graphs, _ = dgl.data.utils.load_graphs(str(self.source_dir / name))
38 | graph: dgl.DGLGraph = dgl.add_self_loop(graphs[0])
39 | g = self._process_node_attributes(graph)
40 |         if len(g.ndata.keys()) > 0:  # node attributes were saved with the graph
41 |             features = torch.cat([g.ndata[x] for x in self.consider_features], dim=1).float()
42 |         else:  # no saved attributes: fall back to the total degree as a 1-d feature
43 |             features = (g.in_degrees() + g.out_degrees()).view(-1, 1).float()
44 | g.ndata.clear()
45 | g.ndata['features'] = features
46 | return g, self.labels[name]
47 |
--------------------------------------------------------------------------------
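Note how `_process_node_attributes` reshapes each scalar node attribute into a column vector so that `torch.cat(..., dim=1)` can stack the requested attributes into a single per-node feature matrix, with the total node degree used as a 1-d fallback for graphs saved without attributes. A hedged sketch of direct use (directory and file name are hypothetical; the file is expected to be a graph saved via `dgl.save_graphs`):

    ds = MalwareDataset(
        source_dir='data/processed/train',   # hypothetical
        samples=['Benigh0000.apk'],
        labels={'Benigh0000.apk': 0},
        consider_features=['external', 'codesize'],
    )
    g, label = ds[0]
    # One column per requested attribute, one row per call-graph node.
    assert g.ndata['features'].shape[1] == 2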
/core/model.py:
--------------------------------------------------------------------------------
1 | from typing import Mapping, Optional, Tuple
2 |
3 |
4 | import dgl
5 | import dgl.nn.pytorch as graph_nn
6 | import pytorch_lightning as pl
7 | import pytorch_lightning.metrics as metrics
8 | import torch
9 | import torch.nn.functional as F
10 | from dgl.nn import Sequential
11 | from pytorch_lightning.metrics import Metric
12 | from torch import nn
13 |
14 |
15 | class MalwareDetector(pl.LightningModule):
16 | def __init__(
17 | self,
18 | input_dimension: int,
19 | convolution_algorithm: str,
20 | convolution_count: int,
21 | ):
22 | super().__init__()
23 | supported_algorithms = ['GraphConv', 'SAGEConv', 'TAGConv', 'DotGatConv']
24 | if convolution_algorithm not in supported_algorithms:
25 | raise ValueError(
26 | f"{convolution_algorithm} is not supported. Supported algorithms are {supported_algorithms}")
27 | self.save_hyperparameters()
28 | self.convolution_layers = []
29 | convolution_dimensions = [64, 32, 16]
30 | for dimension in convolution_dimensions[:convolution_count]:
31 | self.convolution_layers.append(self._get_convolution_layer(
32 | name=convolution_algorithm,
33 | input_dimension=input_dimension,
34 | output_dimension=dimension
35 | ))
36 | input_dimension = dimension
37 | self.convolution_layers = Sequential(*self.convolution_layers)
38 | self.last_dimension = input_dimension
39 | self.classify = nn.Linear(input_dimension, 1)
40 | # Metrics
41 | self.loss_func = nn.BCEWithLogitsLoss()
42 | self.train_metrics = self._get_metric_dict('train')
43 | self.val_metrics = self._get_metric_dict('val')
44 | self.test_metrics = self._get_metric_dict('test')
45 | self.test_outputs = nn.ModuleDict({
46 | 'confusion_matrix': metrics.ConfusionMatrix(num_classes=2),
47 | 'prc': metrics.PrecisionRecallCurve(compute_on_step=False),
48 | 'roc': metrics.ROC(compute_on_step=False)
49 | })
50 |
51 | @staticmethod
52 | def _get_convolution_layer(
53 | name: str,
54 | input_dimension: int,
55 | output_dimension: int
56 | ) -> Optional[nn.Module]:
57 | return {
58 | "GraphConv": graph_nn.GraphConv(
59 | input_dimension,
60 | output_dimension,
61 | activation=F.relu
62 | ),
63 | "SAGEConv": graph_nn.SAGEConv(
64 | input_dimension,
65 | output_dimension,
66 | activation=F.relu,
67 | aggregator_type='mean',
68 | norm=F.normalize
69 | ),
70 | "DotGatConv": graph_nn.DotGatConv(
71 | input_dimension,
72 | output_dimension,
73 | num_heads=1
74 | ),
75 | "TAGConv": graph_nn.TAGConv(
76 | input_dimension,
77 | output_dimension,
78 | k=4
79 | )
80 | }.get(name, None)
81 |
82 | @staticmethod
83 | def _get_metric_dict(stage: str) -> Mapping[str, Metric]:
84 | return nn.ModuleDict({
85 | f'{stage}_accuracy': metrics.Accuracy(),
86 | f'{stage}_precision': metrics.Precision(num_classes=1),
87 | f'{stage}_recall': metrics.Recall(num_classes=1),
88 | f'{stage}_f1': metrics.FBeta(num_classes=1)
89 | })
90 |
91 | def forward(self, g: dgl.DGLGraph) -> torch.Tensor:
92 | with g.local_scope():
93 | h = g.ndata['features']
94 | h = self.convolution_layers(g, h)
95 | g.ndata['h'] = h if len(self.convolution_layers) > 0 else h[0]
96 | # Calculate graph representation by averaging all the node representations.
97 | hg = dgl.mean_nodes(g, 'h')
98 | return self.classify(hg).squeeze()
99 |
100 | def training_step(self, batch: Tuple[dgl.DGLGraph, torch.Tensor], batch_idx: int) -> torch.Tensor:
101 | bg, label = batch
102 | logits = self.forward(bg)
103 | loss = self.loss_func(logits, label)
104 | prediction = torch.sigmoid(logits)
105 | for metric_name, metric in self.train_metrics.items():
106 | metric.update(prediction, label)
107 | self.log('train_loss', loss, on_step=True, on_epoch=True)
108 | return loss
109 |
110 | def validation_step(self, batch: Tuple[dgl.DGLGraph, torch.Tensor], batch_idx: int):
111 | bg, label = batch
112 | logits = self.forward(bg)
113 | loss = self.loss_func(logits, label)
114 | prediction = torch.sigmoid(logits)
115 | for metric_name, metric in self.val_metrics.items():
116 | metric.update(prediction, label)
117 | self.log('val_loss', loss, on_step=False, on_epoch=True)
118 | return loss
119 |
120 | def test_step(self, batch: Tuple[dgl.DGLGraph, torch.Tensor], batch_idx: int):
121 | bg, label = batch
122 | logits = self.forward(bg)
123 | prediction = torch.sigmoid(logits)
124 | loss = self.loss_func(logits, label)
125 | for metric_name, metric in self.test_metrics.items():
126 | metric.update(prediction, label)
127 | for metric_name, metric in self.test_outputs.items():
128 | metric.update(prediction, label)
129 | self.log('test_loss', loss, on_step=False, on_epoch=True)
130 | return loss
131 |
132 | def configure_optimizers(self) -> torch.optim.Adam:
133 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
134 | return optimizer
135 |
--------------------------------------------------------------------------------
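A hedged smoke test for the detector: build a small random graph, attach a feature matrix under the key `forward` expects ('features'), and run one pass. `dgl.rand_graph` is used purely for illustration; real inputs come from `MalwareDataset`:

    import dgl
    import torch

    model = MalwareDetector(
        input_dimension=6,
        convolution_algorithm='GraphConv',
        convolution_count=2,
    )
    g = dgl.add_self_loop(dgl.rand_graph(10, 30))   # 10 nodes, 30 random edges
    g.ndata['features'] = torch.randn(10, 6)
    logit = model(g)                  # a single squeezed logit for this graph
    prob = torch.sigmoid(logit)       # probability that the APK is malware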
/core/utils.py:
--------------------------------------------------------------------------------
1 | import matplotlib.pyplot as plt
2 | import numpy as np
3 | import plotly.express as px
4 | import pytorch_lightning.metrics.functional as M
5 | import seaborn as sns
6 |
7 |
8 | def plot_curve(x, y, curve_type):
9 | """
10 | Plots ROC or PRC
11 | Inspired from https://plotly.com/python/roc-and-pr-curves/
12 | :param x: The x co-ordinates
13 | :param y: The y co-ordinates
14 | :param curve_type: one of 'roc' or 'prc'
15 | :return: Plotly figure
16 | """
17 | auc = M.classification.auc(x, y)
18 | x, y = x.numpy(), y.numpy()
19 | if curve_type == 'roc':
20 | title = f"ROC, AUC = {auc}"
21 | labels = dict(x='FPR', y='TPR')
22 | elif curve_type == 'prc':
23 | title = f"PRC, mAP = {auc}"
24 | labels = dict(x='Recall', y='Precision')
25 | else:
26 | raise ValueError(f"Invalid curve type - {curve_type}. Must be one of 'roc' or 'prc'.")
27 | fig = px.area(x=x, y=y, labels=labels, title=title)
28 | fig.update_yaxes(scaleanchor="x", scaleratio=1)
29 | fig.update_xaxes(constrain='domain')
30 | return fig
31 |
32 |
33 | def plot_confusion_matrix(cf,
34 | group_names=None,
35 | categories='auto',
36 | count=True,
37 | percent=True,
38 | cbar=True,
39 | xyticks=True,
40 | xyplotlabels=True,
41 | sum_stats=True,
42 | fig_size=None,
43 | cmap='Blues',
44 | title=None):
45 | '''
46 | From https://github.com/DTrimarchi10/confusion_matrix/blob/master/cf_matrix.py
47 | Blog https://medium.com/@dtuk81/confusion-matrix-visualization-fc31e3f30fea
48 | This function will make a pretty plot of an sklearn Confusion Matrix cm using a Seaborn heatmap visualization.
49 | Arguments
50 | ---------
51 | cf: confusion matrix to be passed in
52 | group_names: List of strings that represent the labels row by row to be shown in each square.
53 | categories: List of strings containing the categories to be displayed on the x,y axis. Default is 'auto'
54 | count: If True, show the raw number in the confusion matrix. Default is True.
55 |     percent: If True, show the proportions for each category. Default is True.
56 | cbar: If True, show the color bar. The cbar values are based off the values in the confusion matrix.
57 | Default is True.
58 | xyticks: If True, show x and y ticks. Default is True.
59 | xyplotlabels: If True, show 'True Label' and 'Predicted Label' on the figure. Default is True.
60 | sum_stats: If True, display summary statistics below the figure. Default is True.
61 | fig_size: Tuple representing the figure size. Default will be the matplotlib rcParams value.
62 | cmap: Colormap of the values displayed from matplotlib.pyplot.cm. Default is 'Blues'
63 | See http://matplotlib.org/examples/color/colormaps_reference.html
64 | title: Title for the heatmap. Default is None.
65 | '''
66 | plt.clf()
67 | # CODE TO GENERATE TEXT INSIDE EACH SQUARE
68 | blanks = ['' for i in range(cf.size)]
69 |
70 | if group_names and len(group_names) == cf.size:
71 | group_labels = ["{}\n".format(value) for value in group_names]
72 | else:
73 | group_labels = blanks
74 |
75 | if count:
76 | group_counts = ["{0:0.0f}\n".format(value) for value in cf.flatten()]
77 | else:
78 | group_counts = blanks
79 |
80 | if percent:
81 | group_percentages = ["{0:.2%}".format(value) for value in cf.flatten() / np.sum(cf)]
82 | else:
83 | group_percentages = blanks
84 |
85 | box_labels = [f"{v1}{v2}{v3}".strip() for v1, v2, v3 in zip(group_labels, group_counts, group_percentages)]
86 | box_labels = np.asarray(box_labels).reshape(cf.shape[0], cf.shape[1])
87 |
88 | # CODE TO GENERATE SUMMARY STATISTICS & TEXT FOR SUMMARY STATS
89 | if sum_stats:
90 | # Accuracy is sum of diagonal divided by total observations
91 | accuracy = np.trace(cf) / float(np.sum(cf))
92 |
93 | # if it is a binary confusion matrix, show some more stats
94 | if len(cf) == 2:
95 | # Metrics for Binary Confusion Matrices
96 | precision = cf[1, 1] / sum(cf[:, 1])
97 | recall = cf[1, 1] / sum(cf[1, :])
98 | f1_score = 2 * precision * recall / (precision + recall)
99 | stats_text = "\n\nAccuracy={:0.4f}\nPrecision={:0.4f}\nRecall={:0.4f}\nF1 Score={:0.4f}".format(
100 | accuracy, precision, recall, f1_score)
101 | else:
102 | stats_text = "\n\nAccuracy={:0.3f}".format(accuracy)
103 | else:
104 | stats_text = ""
105 |
106 | # SET FIGURE PARAMETERS ACCORDING TO OTHER ARGUMENTS
107 | if fig_size is None:
108 | # Get default figure size if not set
109 | fig_size = plt.rcParams.get('figure.figsize')
110 |
111 | if not xyticks:
112 | # Do not show categories if xyticks is False
113 | categories = False
114 |
115 | # MAKE THE HEATMAP VISUALIZATION
116 | plt.figure(figsize=fig_size)
117 | sns.heatmap(cf, annot=box_labels, fmt="", cmap=cmap, cbar=cbar, xticklabels=categories, yticklabels=categories)
118 |
119 | if xyplotlabels:
120 | plt.ylabel('True label')
121 | plt.xlabel('Predicted label' + stats_text)
122 | else:
123 | plt.xlabel(stats_text)
124 |
125 | if title:
126 | plt.title(title)
127 |
128 | plt.tight_layout()
129 | plt.savefig("CM.png")
130 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Datasets
2 |
3 | Two datasets were used to conduct the experiments.
4 |
5 | 1. [`MalDroid2020`](https://www.unb.ca/cic/datasets/maldroid-2020.html)
6 | 2. [`AndroZoo`](https://androzoo.uni.lu/)
7 |
8 | While `MalDroid2020` was used as a base, `AndroZoo` was used to collect new benign APKs.
9 |
10 | ## Format
11 | This directory contains the hashes of APKs in `*.sha256` files.
12 | Each line in the file consists of a `sha256 name` pair.
13 | There are 4 such files:
14 |
15 | 1. `train_old.sha256` - The training samples from `MalDroid2020`
16 | 2. `train_new.sha256` - The training samples from `AndroZoo`
17 | 3. `test_old.sha256` - The testing samples from `MalDroid2020`
18 | 4. `test_new.sha256` - The testing samples from `AndroZoo`
--------------------------------------------------------------------------------
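For illustration, a line in one of these files looks like the following (the hash below is made up):

    e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855  Benigh0000.apk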
/malware-learning.def:
--------------------------------------------------------------------------------
1 | Bootstrap: docker
2 | From: pytorch/pytorch:1.7.0-cuda11.0-cudnn8-devel
3 |
4 | %files
5 | requirements.txt /mnt/requirements.txt
6 | dgl-0.6-cp38-cp38-linux_x86_64.whl /mnt/dgl-0.6-cp38-cp38-linux_x86_64.whl
7 |
8 | %post
9 | apt-get update && apt-get install -y git graphviz graphviz-dev && rm -rf /var/lib/apt/lists/*
10 | pip install -r /mnt/requirements.txt
11 | pip install /mnt/dgl-0.6-cp38-cp38-linux_x86_64.whl
12 | cd /mnt
13 | git clone --depth 1 https://github.com/androguard/androguard.git && cd androguard && python setup.py install
--------------------------------------------------------------------------------
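A hedged usage note: with a standard Singularity (Apptainer) installation, an image is typically built from this definition with `sudo singularity build malware-learning.sif malware-learning.def`; the DGL wheel named in the %files section must sit next to the definition file when building.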
/metadata/api.list:
--------------------------------------------------------------------------------
1 | android
2 | android.accessibilityservice
3 | android.accounts
4 | android.animation
5 | android.annotation
6 | android.app
7 | android.app.admin
8 | android.app.assist
9 | android.app.backup
10 | android.app.blob
11 | android.app.job
12 | android.app.role
13 | android.app.slice
14 | android.app.usage
15 | android.appwidget
16 | android.bluetooth
17 | android.bluetooth.le
18 | android.companion
19 | android.content
20 | android.content.pm
21 | android.content.res
22 | android.content.res.loader
23 | android.database
24 | android.database.sqlite
25 | android.drm
26 | android.gesture
27 | android.graphics
28 | android.graphics.drawable
29 | android.graphics.drawable.shapes
30 | android.graphics.fonts
31 | android.graphics.pdf
32 | android.graphics.text
33 | android.hardware
34 | android.hardware.biometrics
35 | android.hardware.camera2
36 | android.hardware.camera2.params
37 | android.hardware.display
38 | android.hardware.fingerprint
39 | android.hardware.input
40 | android.hardware.usb
41 | android.icu.lang
42 | android.icu.math
43 | android.icu.number
44 | android.icu.text
45 | android.icu.util
46 | android.inputmethodservice
47 | android.location
48 | android.media
49 | android.media.audiofx
50 | android.media.browse
51 | android.media.effect
52 | android.media.midi
53 | android.media.projection
54 | android.media.session
55 | android.media.tv
56 | android.mtp
57 | android.net
58 | android.net.http
59 | android.net.nsd
60 | android.net.rtp
61 | android.net.sip
62 | android.net.ssl
63 | android.net.wifi
64 | android.net.wifi.aware
65 | android.net.wifi.hotspot2
66 | android.net.wifi.hotspot2.omadm
67 | android.net.wifi.hotspot2.pps
68 | android.net.wifi.p2p
69 | android.net.wifi.p2p.nsd
70 | android.net.wifi.rtt
71 | android.nfc
72 | android.nfc.cardemulation
73 | android.nfc.tech
74 | android.opengl
75 | android.os
76 | android.os.health
77 | android.os.storage
78 | android.os.strictmode
79 | android.preference
80 | android.print
81 | android.print.pdf
82 | android.printservice
83 | android.provider
84 | android.renderscript
85 | android.sax
86 | android.se.omapi
87 | android.security
88 | android.security.identity
89 | android.security.keystore
90 | android.service.autofill
91 | android.service.carrier
92 | android.service.chooser
93 | android.service.controls
94 | android.service.controls.actions
95 | android.service.controls.templates
96 | android.service.dreams
97 | android.service.media
98 | android.service.notification
99 | android.service.quickaccesswallet
100 | android.service.quicksettings
101 | android.service.restrictions
102 | android.service.textservice
103 | android.service.voice
104 | android.service.vr
105 | android.service.wallpaper
106 | android.speech
107 | android.speech.tts
108 | android.system
109 | android.telecom
110 | android.telephony
111 | android.telephony.cdma
112 | android.telephony.data
113 | android.telephony.emergency
114 | android.telephony.euicc
115 | android.telephony.gsm
116 | android.telephony.ims
117 | android.telephony.ims.feature
118 | android.telephony.mbms
119 | android.test
120 | android.test.mock
121 | android.test.suitebuilder
122 | android.test.suitebuilder.annotation
123 | android.text
124 | android.text.format
125 | android.text.method
126 | android.text.style
127 | android.text.util
128 | android.transition
129 | android.util
130 | android.util.proto
131 | android.view
132 | android.view.accessibility
133 | android.view.animation
134 | android.view.autofill
135 | android.view.contentcapture
136 | android.view.inputmethod
137 | android.view.inspector
138 | android.view.textclassifier
139 | android.view.textservice
140 | android.webkit
141 | android.widget
142 | android.widget.inline
143 | com.google.android.collect
144 | com.google.android.gles_jni
145 | com.google.android.util
146 | dalvik.annotation
147 | dalvik.bytecode
148 | dalvik.system
149 | java.awt.font
150 | java.beans
151 | java.io
152 | java.lang
153 | java.lang.annotation
154 | java.lang.invoke
155 | java.lang.ref
156 | java.lang.reflect
157 | java.math
158 | java.net
159 | java.nio
160 | java.nio.channels
161 | java.nio.channels.spi
162 | java.nio.charset
163 | java.nio.charset.spi
164 | java.nio.file
165 | java.nio.file.attribute
166 | java.nio.file.spi
167 | java.security
168 | java.security.acl
169 | java.security.cert
170 | java.security.interfaces
171 | java.security.spec
172 | java.sql
173 | java.text
174 | java.time
175 | java.time.chrono
176 | java.time.format
177 | java.time.temporal
178 | java.time.zone
179 | java.util
180 | java.util.concurrent
181 | java.util.concurrent.atomic
182 | java.util.concurrent.locks
183 | java.util.function
184 | java.util.jar
185 | java.util.logging
186 | java.util.prefs
187 | java.util.regex
188 | java.util.stream
189 | java.util.zip
190 | javax.crypto
191 | javax.crypto.interfaces
192 | javax.crypto.spec
193 | javax.microedition.khronos.egl
194 | javax.microedition.khronos.opengles
195 | javax.net
196 | javax.net.ssl
197 | javax.security.auth
198 | javax.security.auth.callback
199 | javax.security.auth.login
200 | javax.security.auth.x500
201 | javax.security.cert
202 | javax.sql
203 | javax.xml
204 | javax.xml.datatype
205 | javax.xml.namespace
206 | javax.xml.parsers
207 | javax.xml.transform
208 | javax.xml.transform.dom
209 | javax.xml.transform.sax
210 | javax.xml.transform.stream
211 | javax.xml.validation
212 | javax.xml.xpath
213 | junit.framework
214 | junit.runner
215 | org.apache.http.conn
216 | org.apache.http.conn.scheme
217 | org.apache.http.conn.ssl
218 | org.apache.http.params
219 | org.json
220 | org.w3c.dom
221 | org.w3c.dom.ls
222 | org.xml.sax
223 | org.xml.sax.ext
224 | org.xml.sax.helpers
225 | org.xmlpull.v1
226 | org.xmlpull.v1.sax2
--------------------------------------------------------------------------------
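This list of Android/Java API packages is consumed by the API-feature notebooks, where each package name is mapped to an index for one-hot node features. A hedged paraphrase of the `get_api_list` helper from 3-APIFeatures.ipynb, assuming it is run from the repository root:

    # Map each API package name to its line index, e.g. {'android': 0, ...}.
    api_index = {line.strip(): i for i, line in enumerate(open('metadata/api.list'))}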
/notebooks/2-GFeatures.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stderr",
10 | "output_type": "stream",
11 | "text": [
12 | "Using backend: pytorch\n"
13 | ]
14 | }
15 | ],
16 | "source": [
17 | "import re\n",
18 | "import dgl\n",
19 | "import torch\n",
20 | "from torch.utils.data import DataLoader\n",
21 | "\n",
22 | "import torch.nn as nn\n",
23 | "import torch.nn.functional as F\n",
24 | "import networkx as nx\n",
25 | "\n",
26 | "from pathlib import Path\n",
27 | "from androguard.misc import AnalyzeAPK\n",
28 | "import pickle\n",
29 | "import pytorch_lightning as pl\n",
30 | "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
31 | "from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n",
32 | "import sklearn.metrics as M\n",
33 | "\n",
34 | "from dgl.nn.pytorch import GraphConv, GATConv, SAGEConv\n",
35 | "from sklearn.model_selection import StratifiedShuffleSplit\n",
36 | "\n",
37 | "import joblib as J"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 2,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "#%xmode verbose"
47 | ]
48 | },
49 | {
50 | "cell_type": "markdown",
51 | "metadata": {},
52 | "source": [
53 | "## Params"
54 | ]
55 | },
56 | {
57 | "cell_type": "code",
58 | "execution_count": 3,
59 | "metadata": {},
60 | "outputs": [],
61 | "source": [
62 | "model_kwargs = {'in_dim': 15, 'hidden_dim': 30, 'n_classes': 5 }"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 4,
68 | "metadata": {},
69 | "outputs": [],
70 | "source": [
71 | "train = False"
72 | ]
73 | },
74 | {
75 | "cell_type": "code",
76 | "execution_count": 5,
77 | "metadata": {},
78 | "outputs": [],
79 | "source": [
80 | "extract = False"
81 | ]
82 | },
83 | {
84 | "cell_type": "markdown",
85 | "metadata": {},
86 | "source": [
87 | "## Dataset"
88 | ]
89 | },
90 | {
91 | "cell_type": "code",
92 | "execution_count": 6,
93 | "metadata": {},
94 | "outputs": [],
95 | "source": [
96 | "def get_samples(base_path):\n",
97 | " base_path = Path(base_path)\n",
98 | " labels_dict = {x:i for i,x in enumerate(sorted([\"Adware\", \"Benigh\", \"Banking\", \"SMS\", \"Riskware\"]))}\n",
99 | " if not base_path.exists():\n",
100 | " raise Exception(f'{base_path} does not exist')\n",
101 | " apk_list = sorted([x for x in base_path.iterdir() if not x.is_dir()])\n",
102 | " samples = []\n",
103 | " labels = {}\n",
104 | " for apk in apk_list:\n",
105 | " samples.append(apk.name)\n",
106 | " labels[apk.name] = labels_dict[re.findall(r'[A-Z](?:[a-z]|[A-Z])+',apk.name)[0]]\n",
107 | " return samples, labels"
108 | ]
109 | },
110 | {
111 | "cell_type": "code",
112 | "execution_count": 7,
113 | "metadata": {},
114 | "outputs": [],
115 | "source": [
116 | "samples, labels = get_samples('../data/large/raw')"
117 | ]
118 | },
119 | {
120 | "cell_type": "code",
121 | "execution_count": 8,
122 | "metadata": {},
123 | "outputs": [
124 | {
125 | "data": {
126 | "text/plain": [
127 | "'Adware0000.apk'"
128 | ]
129 | },
130 | "execution_count": 8,
131 | "metadata": {},
132 | "output_type": "execute_result"
133 | }
134 | ],
135 | "source": [
136 | "samples[0]"
137 | ]
138 | },
139 | {
140 | "cell_type": "code",
141 | "execution_count": 9,
142 | "metadata": {},
143 | "outputs": [],
144 | "source": [
145 | "raw_prefix = Path('../data/large/raw')\n",
146 | "processed_prefix = Path('../data/large/G-feat')"
147 | ]
148 | },
149 | {
150 | "cell_type": "code",
151 | "execution_count": 10,
152 | "metadata": {},
153 | "outputs": [],
154 | "source": [
155 | "def process(file):\n",
156 | " _, _, dx = AnalyzeAPK(raw_prefix/file)\n",
157 | " cg = dx.get_call_graph()\n",
158 | " opcodes = {}\n",
159 | " for node in cg.nodes():\n",
160 | " sequence = [0] * 15\n",
161 | " if not node.is_external():\n",
162 | " for instr in node.get_method().get_instructions():\n",
163 | " value = instr.get_op_value()\n",
164 | " if value == 0x00: # nop\n",
165 | " sequence[0] = 1\n",
166 | " elif value >= 0x01 and value <= 0x0D: # mov\n",
167 | " sequence[1] = 1\n",
168 | " elif value >= 0x0E and value <= 0x11: # return\n",
169 | " sequence[2] = 1\n",
170 | " elif value == 0x1D or value == 0x1E: # monitor\n",
171 | " sequence[3] = 1\n",
172 | " elif value >= 0x32 and value <= 0x3D: # if\n",
173 | " sequence[4] = 1\n",
174 | " elif value == 0x27: # throw\n",
175 | " sequence[5] = 1\n",
176 | " elif value == 0x28 or value == 0x29: #goto\n",
177 | " sequence[6] = 1\n",
178 | " elif value >= 0x2F and value <= 0x31: # compare\n",
179 | " sequence[7] = 1\n",
180 | " elif value >= 0x7F and value <= 0x8F: # unop\n",
181 | " sequence[8] = 1\n",
182 | " elif value >=90 and value <= 0xE2: # binop\n",
183 | " sequence[9] = 1\n",
184 | " elif value == 0x21 or (value >= 0x23 and value <= 0x26) or (value >= 0x44 and value <= 0x51): # aop\n",
185 | " sequence[10] = 1\n",
186 | " elif (value >= 0x52 and value <= 0x5F) or (value >= 0xF2 and value <= 0xF7): # instanceop\n",
187 | " sequence[11] = 1\n",
188 | " elif (value >= 0x60 and value <= 0x6D): # staticop\n",
189 | " sequence[12] = 1\n",
190 | " elif (value >= 0x6E and value <= 0x72) and (value >= 0x74 and value <= 0x78) and (value >= 0xF9 and value <= 0xFB):\n",
191 | " sequence[13] = 1\n",
192 | " elif (value >= 0x22 and value <= 0x25):\n",
193 | " sequence[14] = 1\n",
194 | " opcodes[node] = {'sequence': sequence}\n",
195 | " nx.set_node_attributes(cg, opcodes)\n",
196 | " labels = {x: {'name': x.full_name} for x in cg.nodes()}\n",
197 | " nx.set_node_attributes(cg, labels)\n",
198 | " cg = nx.convert_node_labels_to_integers(cg)\n",
199 | " torch.save(cg, processed_prefix/ (file.split('.')[0]+'.graph'))"
200 | ]
201 | },
202 | {
203 | "cell_type": "code",
204 | "execution_count": 11,
205 | "metadata": {},
206 | "outputs": [],
207 | "source": [
208 | "if extract:\n",
209 | " J.Parallel(n_jobs=40)(J.delayed(process)(x) for x in samples);"
210 | ]
211 | },
212 | {
213 | "cell_type": "code",
214 | "execution_count": 12,
215 | "metadata": {},
216 | "outputs": [],
217 | "source": [
218 | "class MalwareDataset(torch.utils.data.Dataset):\n",
219 | " def __init__(self, save_dir, list_IDs, labels):\n",
220 | " self.save_dir = Path(save_dir)\n",
221 | " self.list_IDs = list_IDs\n",
222 | " self.labels = labels\n",
223 | " self.cache = {}\n",
224 | "\n",
225 | " def __len__(self):\n",
226 | " 'Denotes the total number of samples'\n",
227 | " return len(self.list_IDs)\n",
228 | "\n",
229 | " def __getitem__(self, index):\n",
230 | " 'Generates one sample of data'\n",
231 | " # Select sample\n",
232 | " if index not in self.cache:\n",
233 | " ID = self.list_IDs[index]\n",
234 | " graph_path = self.save_dir / (ID.split('.')[0] + '.graph')\n",
235 | " cg = torch.load(graph_path)\n",
236 | " dg = dgl.from_networkx(cg, node_attrs=['sequence'], edge_attrs=['offset'])\n",
237 | " dg = dgl.add_self_loop(dg)\n",
238 | " self.cache[index] = (dg, self.labels[ID])\n",
239 | " return self.cache[index]"
240 | ]
241 | },
242 | {
243 | "cell_type": "markdown",
244 | "metadata": {},
245 | "source": [
246 | "## Data Loading"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": 13,
252 | "metadata": {},
253 | "outputs": [],
254 | "source": [
255 | "def split_dataset(samples, labels, ratios):\n",
256 | " if sum(ratios) != 1:\n",
257 | " raise Exception(\"Invalid ratios provided\")\n",
258 | " train_ratio, val_ratio, test_ratio = ratios\n",
259 | " sss = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=0)\n",
260 | " train_idx, test_idx = list(sss.split(samples, [labels[x] for x in samples]))[0]\n",
261 | " sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio/(1-test_ratio), random_state=0)\n",
262 | " test_list = [samples[x] for x in test_idx]\n",
263 | " train_list = [samples[x] for x in train_idx]\n",
264 | " train_idx, val_idx = list(sss.split(train_list, [labels[x] for x in train_list]))[0]\n",
265 | " train_list = [samples[x] for x in train_idx]\n",
266 | " val_list = [samples[x] for x in val_idx]\n",
267 | " return train_list, val_list, test_list"
268 | ]
269 | },
270 | {
271 | "cell_type": "code",
272 | "execution_count": 14,
273 | "metadata": {},
274 | "outputs": [],
275 | "source": [
276 | "train_list, val_list, test_list = split_dataset(samples, labels, [0.6, 0.2, 0.2])"
277 | ]
278 | },
279 | {
280 | "cell_type": "code",
281 | "execution_count": 15,
282 | "metadata": {},
283 | "outputs": [
284 | {
285 | "data": {
286 | "text/plain": [
287 | "tensor([0.6000, 0.2000, 0.2000])"
288 | ]
289 | },
290 | "execution_count": 15,
291 | "metadata": {},
292 | "output_type": "execute_result"
293 | }
294 | ],
295 | "source": [
296 | "torch.tensor([len(train_list), len(val_list), len(test_list)]).float()/len(samples)"
297 | ]
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": 16,
302 | "metadata": {},
303 | "outputs": [],
304 | "source": [
305 | "def collate(samples):\n",
306 | " graphs, labels = [], []\n",
307 | " for graph, label in samples:\n",
308 | " graphs.append(graph)\n",
309 | " labels.append(label)\n",
310 | " batched_graph = dgl.batch(graphs)\n",
311 | " return batched_graph, torch.tensor(labels)"
312 | ]
313 | },
314 | {
315 | "cell_type": "code",
316 | "execution_count": 17,
317 | "metadata": {},
318 | "outputs": [],
319 | "source": [
320 | "train_dataset = MalwareDataset(processed_prefix , train_list, labels)\n",
321 | "val_dataset = MalwareDataset(processed_prefix , val_list, labels)\n",
322 | "test_dataset = MalwareDataset(processed_prefix , test_list, labels)"
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": 18,
328 | "metadata": {},
329 | "outputs": [],
330 | "source": [
331 | "train_data = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate, num_workers=8)\n",
332 | "val_data = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate , num_workers=40)\n",
333 | "test_data = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate, num_workers=4)"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": 19,
339 | "metadata": {},
340 | "outputs": [
341 | {
342 | "data": {
343 | "text/plain": [
344 | "0"
345 | ]
346 | },
347 | "execution_count": 19,
348 | "metadata": {},
349 | "output_type": "execute_result"
350 | }
351 | ],
352 | "source": [
353 | "len(test_dataset.cache)"
354 | ]
355 | },
356 | {
357 | "cell_type": "markdown",
358 | "metadata": {},
359 | "source": [
360 | "## Model"
361 | ]
362 | },
363 | {
364 | "cell_type": "code",
365 | "execution_count": 20,
366 | "metadata": {},
367 | "outputs": [],
368 | "source": [
369 | "class MalwareClassifier(pl.LightningModule):\n",
370 | " def __init__(self, in_dim, hidden_dim, n_classes):\n",
371 | " super().__init__()\n",
372 | " self.conv1 = SAGEConv(in_dim, hidden_dim, aggregator_type='mean')\n",
373 | " self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggregator_type='mean')\n",
374 | " self.classify = nn.Linear(hidden_dim, n_classes)\n",
375 | " self.loss_func = nn.CrossEntropyLoss()\n",
376 | " \n",
377 | " \n",
378 | " def forward(self, g):\n",
379 | " h = g.ndata['sequence'].float()\n",
380 | " #h = torch.cat([g.ndata[x].view(-1,1).float() for x in ['public', 'entrypoint', 'external', 'native', 'codesize' ]], dim=1)\n",
381 | " # h = g.in_degrees().view(-1,1).float()\n",
382 | " # Perform graph convolution and activation function.\n",
383 | " h = F.relu(self.conv1(g, h))\n",
384 | " h = F.relu(self.conv2(g, h))\n",
385 | " g.ndata['h'] = h\n",
386 | " # Calculate graph representation by averaging all the node representations.\n",
387 | " hg = dgl.mean_nodes(g, 'h')\n",
388 | " return self.classify(hg) \n",
389 | " \n",
390 | " def training_step(self, batch, batch_idx):\n",
391 | " bg, label = batch\n",
392 | " #print(\"Outer\", len(label))\n",
393 | " prediction = self.forward(bg)\n",
394 | " loss = self.loss_func(prediction, label)\n",
395 | " return loss\n",
396 | " \n",
397 | " def validation_step(self, batch, batch_idx):\n",
398 | " bg, label = batch\n",
399 | " prediction = self.forward(bg)\n",
400 | " loss = self.loss_func(prediction, label)\n",
401 | " self.log('val_loss', loss)\n",
402 | " \n",
403 | " def configure_optimizers(self):\n",
404 | " optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n",
405 | " return optimizer"
406 | ]
407 | },
408 | {
409 | "cell_type": "code",
410 | "execution_count": 21,
411 | "metadata": {},
412 | "outputs": [],
413 | "source": [
414 | "callbacks = [\n",
415 | " EarlyStopping(monitor='val_loss', patience=5, min_delta=0.01),\n",
416 | "]"
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": 22,
422 | "metadata": {},
423 | "outputs": [],
424 | "source": [
425 | "checkpointer = ModelCheckpoint(filepath='../models/3Nov-{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', mode='min')"
426 | ]
427 | },
428 | {
429 | "cell_type": "code",
430 | "execution_count": 23,
431 | "metadata": {},
432 | "outputs": [
433 | {
434 | "name": "stderr",
435 | "output_type": "stream",
436 | "text": [
437 | "GPU available: True, used: True\n",
438 | "TPU available: False, using: 0 TPU cores\n",
439 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]\n"
440 | ]
441 | }
442 | ],
443 | "source": [
444 | "classifier= MalwareClassifier(**model_kwargs)\n",
445 | "trainer = pl.Trainer(callbacks=callbacks, checkpoint_callback=checkpointer, gpus=[2])"
446 | ]
447 | },
448 | {
449 | "cell_type": "code",
450 | "execution_count": 24,
451 | "metadata": {},
452 | "outputs": [
453 | {
454 | "data": {
455 | "text/plain": [
456 | "False"
457 | ]
458 | },
459 | "execution_count": 24,
460 | "metadata": {},
461 | "output_type": "execute_result"
462 | }
463 | ],
464 | "source": [
465 | "train"
466 | ]
467 | },
468 | {
469 | "cell_type": "code",
470 | "execution_count": 25,
471 | "metadata": {},
472 | "outputs": [],
473 | "source": [
474 | "if train:\n",
475 | " trainer.fit(classifier, train_data, val_data)"
476 | ]
477 | },
478 | {
479 | "cell_type": "markdown",
480 | "metadata": {},
481 | "source": [
482 | "## Testing "
483 | ]
484 | },
485 | {
486 | "cell_type": "code",
487 | "execution_count": 26,
488 | "metadata": {},
489 | "outputs": [],
490 | "source": [
491 | "classifier_saved = MalwareClassifier.load_from_checkpoint('../models/3Nov-epoch=36-val_loss=0.51.pt.ckpt', **model_kwargs)"
492 | ]
493 | },
494 | {
495 | "cell_type": "code",
496 | "execution_count": 33,
497 | "metadata": {},
498 | "outputs": [
499 | {
500 | "data": {
501 | "text/plain": [
502 | "tensor([[ -2.0630, 1.1638, -11.9895, 5.1457, -1.6285]])"
503 | ]
504 | },
505 | "execution_count": 33,
506 | "metadata": {},
507 | "output_type": "execute_result"
508 | }
509 | ],
510 | "source": [
511 | "classifier_saved(train_dataset[0][0])"
512 | ]
513 | },
514 | {
515 | "cell_type": "code",
516 | "execution_count": 31,
517 | "metadata": {},
518 | "outputs": [],
519 | "source": [
520 | "classifier_saved.freeze()"
521 | ]
522 | },
523 | {
524 | "cell_type": "code",
525 | "execution_count": 43,
526 | "metadata": {},
527 | "outputs": [
528 | {
529 | "data": {
530 | "text/plain": [
531 | "tensor([4, 2, 3, ..., 1, 3, 2])"
532 | ]
533 | },
534 | "execution_count": 43,
535 | "metadata": {},
536 | "output_type": "execute_result"
537 | }
538 | ],
539 | "source": [
540 | "predicted = torch.argmax(classifier_saved(dgl.batch([g for g,l in test_dataset])),dim=1)\n",
541 | "predicted"
542 | ]
543 | },
544 | {
545 | "cell_type": "code",
546 | "execution_count": 36,
547 | "metadata": {},
548 | "outputs": [
549 | {
550 | "data": {
551 | "text/plain": [
552 | "3302"
553 | ]
554 | },
555 | "execution_count": 36,
556 | "metadata": {},
557 | "output_type": "execute_result"
558 | }
559 | ],
560 | "source": [
561 | "len(test_dataset)"
562 | ]
563 | },
564 | {
565 | "cell_type": "code",
566 | "execution_count": 37,
567 | "metadata": {},
568 | "outputs": [
569 | {
570 | "data": {
571 | "text/plain": [
572 | "3302"
573 | ]
574 | },
575 | "execution_count": 37,
576 | "metadata": {},
577 | "output_type": "execute_result"
578 | }
579 | ],
580 | "source": [
581 | "len(test_dataset.cache)"
582 | ]
583 | },
584 | {
585 | "cell_type": "code",
586 | "execution_count": 44,
587 | "metadata": {},
588 | "outputs": [
589 | {
590 | "data": {
591 | "text/plain": [
592 | "tensor([4, 2, 3, ..., 3, 3, 2])"
593 | ]
594 | },
595 | "execution_count": 44,
596 | "metadata": {},
597 | "output_type": "execute_result"
598 | }
599 | ],
600 | "source": [
601 | "actual = torch.tensor([l for g,l in test_dataset])\n",
602 | "actual"
603 | ]
604 | },
605 | {
606 | "cell_type": "code",
607 | "execution_count": 45,
608 | "metadata": {},
609 | "outputs": [
610 | {
611 | "name": "stdout",
612 | "output_type": "stream",
613 | "text": [
614 | " precision recall f1-score support\n",
615 | "\n",
616 | " 0 0.8911 0.7318 0.8036 302\n",
617 | " 1 0.6159 0.8107 0.7000 449\n",
618 | " 2 0.9124 0.9282 0.9202 808\n",
619 | " 3 0.8524 0.7856 0.8176 779\n",
620 | " 4 0.9707 0.9295 0.9497 964\n",
621 | "\n",
622 | " accuracy 0.8610 3302\n",
623 | " macro avg 0.8485 0.8372 0.8382 3302\n",
624 | "weighted avg 0.8730 0.8610 0.8640 3302\n",
625 | "\n"
626 | ]
627 | }
628 | ],
629 | "source": [
630 | "print(M.classification_report(actual, predicted, digits=4))"
631 | ]
632 | },
633 | {
634 | "cell_type": "code",
635 | "execution_count": 46,
636 | "metadata": {},
637 | "outputs": [
638 | {
639 | "data": {
640 | "text/plain": [
641 | "array([[221, 38, 5, 37, 1],\n",
642 | " [ 5, 364, 19, 39, 22],\n",
643 | " [ 5, 29, 750, 23, 1],\n",
644 | " [ 10, 106, 48, 612, 3],\n",
645 | " [ 7, 54, 0, 7, 896]])"
646 | ]
647 | },
648 | "execution_count": 46,
649 | "metadata": {},
650 | "output_type": "execute_result"
651 | }
652 | ],
653 | "source": [
654 | "M.confusion_matrix(actual, predicted)"
655 | ]
656 | },
657 | {
658 | "cell_type": "code",
659 | "execution_count": null,
660 | "metadata": {},
661 | "outputs": [],
662 | "source": [
663 | "\"Adware\", \"Benigh\", \"Banking\", \"SMS\", \"Riskware\""
664 | ]
665 | },
666 | {
667 | "cell_type": "markdown",
668 | "metadata": {},
669 | "source": [
670 | "## Results\n",
671 | "Accuracy - 86.10%,\n",
672 | "Precision - 0.8485,\n",
673 | "Recall - 0.8372,\n",
674 | "F1 - 0.8382"
675 | ]
676 | }
677 | ],
678 | "metadata": {
679 | "kernelspec": {
680 | "display_name": "Python 3",
681 | "language": "python",
682 | "name": "python3"
683 | },
684 | "language_info": {
685 | "codemirror_mode": {
686 | "name": "ipython",
687 | "version": 3
688 | },
689 | "file_extension": ".py",
690 | "mimetype": "text/x-python",
691 | "name": "python",
692 | "nbconvert_exporter": "python",
693 | "pygments_lexer": "ipython3",
694 | "version": "3.6.9"
695 | }
696 | },
697 | "nbformat": 4,
698 | "nbformat_minor": 4
699 | }
700 |
--------------------------------------------------------------------------------
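The opcode bucketing inside the notebook's `process` function can be stated more compactly as a range table. A hedged standalone sketch (inclusive bounds copied from the cell above, with the invoke ranges combined with `or`, as the elif chain intends):

    # 15 opcode families; each entry lists inclusive (low, high) value ranges.
    OPCODE_BUCKETS = [
        [(0x00, 0x00)],                              # 0: nop
        [(0x01, 0x0D)],                              # 1: mov
        [(0x0E, 0x11)],                              # 2: return
        [(0x1D, 0x1E)],                              # 3: monitor
        [(0x32, 0x3D)],                              # 4: if
        [(0x27, 0x27)],                              # 5: throw
        [(0x28, 0x29)],                              # 6: goto
        [(0x2F, 0x31)],                              # 7: compare
        [(0x7F, 0x8F)],                              # 8: unop
        [(0x90, 0xE2)],                              # 9: binop
        [(0x21, 0x21), (0x23, 0x26), (0x44, 0x51)],  # 10: aop
        [(0x52, 0x5F), (0xF2, 0xF7)],                # 11: instanceop
        [(0x60, 0x6D)],                              # 12: staticop
        [(0x6E, 0x72), (0x74, 0x78), (0xF9, 0xFB)],  # 13: invoke
        [(0x22, 0x25)],                              # 14: new (0x23-0x25 shadowed by aop)
    ]

    def opcode_sequence(op_values):
        """Binary 15-vector: which opcode families appear in a method."""
        sequence = [0] * 15
        for value in op_values:
            for i, ranges in enumerate(OPCODE_BUCKETS):
                if any(lo <= value <= hi for lo, hi in ranges):
                    sequence[i] = 1
                    break  # first match wins, mirroring the elif chain
        return sequence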
/notebooks/3-APIFeatures.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {},
7 | "outputs": [
8 | {
9 | "name": "stderr",
10 | "output_type": "stream",
11 | "text": [
12 | "Using backend: pytorch\n"
13 | ]
14 | }
15 | ],
16 | "source": [
17 | "import re\n",
18 | "import dgl\n",
19 | "import torch\n",
20 | "from torch.utils.data import DataLoader\n",
21 | "\n",
22 | "import torch.nn as nn\n",
23 | "import torch.nn.functional as F\n",
24 | "import networkx as nx\n",
25 | "\n",
26 | "from pathlib import Path\n",
27 | "from androguard.misc import AnalyzeAPK\n",
28 | "import pickle\n",
29 | "import pytorch_lightning as pl\n",
30 | "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n",
31 | "from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n",
32 | "import sklearn.metrics as M\n",
33 | "\n",
34 | "from dgl.nn.pytorch import GraphConv, GATConv, SAGEConv\n",
35 | "from sklearn.model_selection import StratifiedShuffleSplit\n",
36 | "\n",
37 | "import joblib as J"
38 | ]
39 | },
40 | {
41 | "cell_type": "code",
42 | "execution_count": 2,
43 | "metadata": {},
44 | "outputs": [],
45 | "source": [
46 | "def get_api_list(file):\n",
47 | " apis = open(file).readlines()\n",
48 | " return {x.strip(): i for i, x in enumerate(apis)}"
49 | ]
50 | },
51 | {
52 | "cell_type": "code",
53 | "execution_count": 3,
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "api_list = get_api_list('api.list')"
58 | ]
59 | },
60 | {
61 | "cell_type": "code",
62 | "execution_count": 4,
63 | "metadata": {},
64 | "outputs": [
65 | {
66 | "data": {
67 | "text/plain": [
68 | "226"
69 | ]
70 | },
71 | "execution_count": 4,
72 | "metadata": {},
73 | "output_type": "execute_result"
74 | }
75 | ],
76 | "source": [
77 | "len(api_list)"
78 | ]
79 | },
80 | {
81 | "cell_type": "markdown",
82 | "metadata": {},
83 | "source": [
84 | "## Params"
85 | ]
86 | },
87 | {
88 | "cell_type": "code",
89 | "execution_count": 5,
90 | "metadata": {},
91 | "outputs": [],
92 | "source": [
93 | "model_kwargs = {'in_dim': len(api_list), 'hidden_dim': 64, 'n_classes': 5 }"
94 | ]
95 | },
96 | {
97 | "cell_type": "code",
98 | "execution_count": 6,
99 | "metadata": {},
100 | "outputs": [],
101 | "source": [
102 | "train = True"
103 | ]
104 | },
105 | {
106 | "cell_type": "code",
107 | "execution_count": 7,
108 | "metadata": {},
109 | "outputs": [],
110 | "source": [
111 | "extract = True"
112 | ]
113 | },
114 | {
115 | "cell_type": "markdown",
116 | "metadata": {},
117 | "source": [
118 | "## Dataset"
119 | ]
120 | },
121 | {
122 | "cell_type": "code",
123 | "execution_count": 8,
124 | "metadata": {},
125 | "outputs": [],
126 | "source": [
127 | "def get_samples(base_path):\n",
128 | " base_path = Path(base_path)\n",
129 | " labels_dict = {x:i for i,x in enumerate(sorted([\"Adware\", \"Benigh\", \"Banking\", \"SMS\", \"Riskware\"]))}\n",
130 | " if not base_path.exists():\n",
131 | " raise Exception(f'{base_path} does not exist')\n",
132 | " apk_list = sorted([x for x in base_path.iterdir() if not x.is_dir()])\n",
133 | " samples = []\n",
134 | " labels = {}\n",
135 | " for apk in apk_list:\n",
136 | " samples.append(apk.name)\n",
137 | " labels[apk.name] = labels_dict[re.findall(r'[A-Z](?:[a-z]|[A-Z])+',apk.name)[0]]\n",
138 | " return samples, labels"
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 9,
144 | "metadata": {},
145 | "outputs": [],
146 | "source": [
147 | "samples, labels = get_samples('../data/large/raw')"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 10,
153 | "metadata": {},
154 | "outputs": [],
155 | "source": [
156 | "raw_prefix = Path('../data/large/raw')\n",
157 | "processed_prefix = Path('../data/large/APIFeatures')"
158 | ]
159 | },
160 | {
161 | "cell_type": "code",
162 | "execution_count": 11,
163 | "metadata": {},
164 | "outputs": [],
165 | "source": [
166 | "def process(file):\n",
167 | " _, _, dx = AnalyzeAPK(raw_prefix/file)\n",
168 | " cg = dx.get_call_graph()\n",
169 | " mappings = {}\n",
170 | " #print(set(map(lambda x: x.full_name.split(';')[0][1:], filter(lambda x: x.is_external(), cg.nodes()))))\n",
171 | " #return\n",
172 | " for node in cg.nodes():\n",
173 | " mapping = {\"api_package\": None}\n",
174 | " if node.is_external():\n",
175 | " name = '.'.join(map(str, node.full_name.split(';')[0][1:].split('/')[:-2]))\n",
176 | " index = api_list.get(name, None)\n",
177 | " mapping[\"api_package\"] = index\n",
178 | " mappings[node] = mapping\n",
179 | " nx.set_node_attributes(cg, mappings)\n",
180 | " labels = {x: {'name': x.full_name} for x in cg.nodes()}\n",
181 | " nx.set_node_attributes(cg, labels)\n",
182 | " cg = nx.convert_node_labels_to_integers(cg)\n",
183 | " #return cg\n",
184 | " torch.save(cg, processed_prefix/ (file.split('.')[0]+'.graph'))"
185 | ]
186 | },
187 | {
188 | "cell_type": "code",
189 | "execution_count": null,
190 | "metadata": {},
191 | "outputs": [
192 | {
193 | "name": "stderr",
194 | "output_type": "stream",
195 | "text": [
196 | "/home/vinayak/.local/lib/python3.6/site-packages/joblib/externals/loky/process_executor.py:691: UserWarning: A worker stopped while some jobs were given to the executor. This can be caused by a too short worker timeout or by a memory leak.\n",
197 | " \"timeout or by a memory leak.\", UserWarning\n"
198 | ]
199 | }
200 | ],
201 | "source": [
202 | "if extract:\n",
203 | " J.Parallel(n_jobs=40)(J.delayed(process)(x) for x in samples);"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": 15,
209 | "metadata": {},
210 | "outputs": [],
211 | "source": [
212 | "nx.get_node_attributes(torch.load('../data/large/APIFeatures/Benigh0000.graph'), \"api_package\");"
213 | ]
214 | },
215 | {
216 | "cell_type": "code",
217 | "execution_count": 24,
218 | "metadata": {},
219 | "outputs": [],
220 | "source": [
221 | "class MalwareDataset(torch.utils.data.Dataset):\n",
222 | " def __init__(self, save_dir, list_IDs, labels):\n",
223 | " self.save_dir = Path(save_dir)\n",
224 | " self.list_IDs = list_IDs\n",
225 | " self.labels = labels\n",
226 | " self.cache = {}\n",
227 | "\n",
228 | " def __len__(self):\n",
229 | " 'Denotes the total number of samples'\n",
230 | " return len(self.list_IDs)\n",
231 | " \n",
232 | " def get_node_vector(self, pos):\n",
233 | " vector = torch.zeros(len(api_list))\n",
234 | " if pos:\n",
235 | " vector[pos] = 1\n",
236 | " return vector\n",
237 | "\n",
238 | " def __getitem__(self, index):\n",
239 | " 'Generates one sample of data'\n",
240 | " # Select sample\n",
241 | " if index not in self.cache:\n",
242 | " ID = self.list_IDs[index]\n",
243 | " graph_path = self.save_dir / (ID.split('.')[0] + '.graph')\n",
244 | " cg = torch.load(graph_path)\n",
245 | " feature = {n: self.get_node_vector(pos) for n, pos in nx.get_node_attributes(cg, 'api_package').items()}\n",
246 | " nx.set_node_attributes(cg, feature, 'feature')\n",
247 | " dg = dgl.from_networkx(cg, node_attrs=['feature'])\n",
248 | " dg = dgl.add_self_loop(dg)\n",
249 | " self.cache[index] = (dg, self.labels[ID])\n",
250 | " return self.cache[index]"
251 | ]
252 | },
253 | {
254 | "cell_type": "markdown",
255 | "metadata": {},
256 | "source": [
257 | "## Data Loading"
258 | ]
259 | },
260 | {
261 | "cell_type": "code",
262 | "execution_count": 25,
263 | "metadata": {},
264 | "outputs": [],
265 | "source": [
266 | "def split_dataset(samples, labels, ratios):\n",
267 | " if sum(ratios) != 1:\n",
268 | " raise Exception(\"Invalid ratios provided\")\n",
269 | " train_ratio, val_ratio, test_ratio = ratios\n",
270 | " sss = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=0)\n",
271 | " train_idx, test_idx = list(sss.split(samples, [labels[x] for x in samples]))[0]\n",
272 | " sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio/(1-test_ratio), random_state=0)\n",
273 | " test_list = [samples[x] for x in test_idx]\n",
274 | " train_list = [samples[x] for x in train_idx]\n",
275 | " train_idx, val_idx = list(sss.split(train_list, [labels[x] for x in train_list]))[0]\n",
276 | " train_list = [samples[x] for x in train_idx]\n",
277 | " val_list = [samples[x] for x in val_idx]\n",
278 | " return train_list, val_list, test_list"
279 | ]
280 | },
281 | {
282 | "cell_type": "code",
283 | "execution_count": 26,
284 | "metadata": {},
285 | "outputs": [],
286 | "source": [
287 | "train_list, val_list, test_list = split_dataset(samples, labels, [0.6, 0.2, 0.2])"
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": 27,
293 | "metadata": {},
294 | "outputs": [
295 | {
296 | "data": {
297 | "text/plain": [
298 | "tensor([0.6000, 0.2000, 0.2000])"
299 | ]
300 | },
301 | "execution_count": 27,
302 | "metadata": {},
303 | "output_type": "execute_result"
304 | }
305 | ],
306 | "source": [
307 | "torch.tensor([len(train_list), len(val_list), len(test_list)]).float()/len(samples)"
308 | ]
309 | },
310 | {
311 | "cell_type": "code",
312 | "execution_count": 28,
313 | "metadata": {},
314 | "outputs": [],
315 | "source": [
316 | "def collate(samples):\n",
317 | " graphs, labels = [], []\n",
318 | " for graph, label in samples:\n",
319 | " graphs.append(graph)\n",
320 | " labels.append(label)\n",
321 | " batched_graph = dgl.batch(graphs)\n",
322 | " return batched_graph, torch.tensor(labels)"
323 | ]
324 | },
325 | {
326 | "cell_type": "code",
327 | "execution_count": 29,
328 | "metadata": {},
329 | "outputs": [],
330 | "source": [
331 | "train_dataset = MalwareDataset(processed_prefix , train_list, labels)\n",
332 | "val_dataset = MalwareDataset(processed_prefix , val_list, labels)\n",
333 | "test_dataset = MalwareDataset(processed_prefix , test_list, labels)"
334 | ]
335 | },
336 | {
337 | "cell_type": "code",
338 | "execution_count": 30,
339 | "metadata": {},
340 | "outputs": [
341 | {
342 | "data": {
343 | "text/plain": [
344 | "(0, 0, 0)"
345 | ]
346 | },
347 | "execution_count": 30,
348 | "metadata": {},
349 | "output_type": "execute_result"
350 | }
351 | ],
352 | "source": [
353 | "len(train_dataset.cache), len(val_dataset.cache), len(test_dataset.cache)"
354 | ]
355 | },
356 | {
357 | "cell_type": "code",
358 | "execution_count": 44,
359 | "metadata": {},
360 | "outputs": [],
361 | "source": [
362 | "train_data = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate, num_workers=8)\n",
363 | "val_data = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate , num_workers=40)\n",
364 | "test_data = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate, num_workers=4)"
365 | ]
366 | },
367 | {
368 | "cell_type": "markdown",
369 | "metadata": {},
370 | "source": [
371 | "## Model"
372 | ]
373 | },
374 | {
375 | "cell_type": "code",
376 | "execution_count": 37,
377 | "metadata": {},
378 | "outputs": [],
379 | "source": [
380 | "class MalwareClassifier(pl.LightningModule):\n",
381 | " def __init__(self, in_dim, hidden_dim, n_classes):\n",
382 | " super().__init__()\n",
383 | " self.conv1 = SAGEConv(in_dim, hidden_dim, aggregator_type='mean')\n",
384 | " self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggregator_type='mean')\n",
385 | " self.classify = nn.Linear(hidden_dim, n_classes)\n",
386 | " self.loss_func = nn.CrossEntropyLoss()\n",
387 | " \n",
388 | " \n",
389 | " def forward(self, g):\n",
390 | " h = g.ndata['feature']\n",
391 | " #h = torch.cat([g.ndata[x].view(-1,1).float() for x in ['public', 'entrypoint', 'external', 'native', 'codesize' ]], dim=1)\n",
392 | " # h = g.in_degrees().view(-1,1).float()\n",
393 | " # Perform graph convolution and activation function.\n",
394 | " h = F.relu(self.conv1(g, h))\n",
395 | " h = F.relu(self.conv2(g, h))\n",
396 | " g.ndata['h'] = h\n",
397 | " # Calculate graph representation by averaging all the node representations.\n",
398 | " hg = dgl.sum_nodes(g, 'h')\n",
399 | " return self.classify(hg) \n",
400 | " \n",
401 | " def training_step(self, batch, batch_idx):\n",
402 | " bg, label = batch\n",
403 | " #print(\"Outer\", len(label))\n",
404 | " prediction = self.forward(bg)\n",
405 | " loss = self.loss_func(prediction, label)\n",
406 | " return loss\n",
407 | " \n",
408 | " def validation_step(self, batch, batch_idx):\n",
409 | " bg, label = batch\n",
410 | " prediction = self.forward(bg)\n",
411 | " loss = self.loss_func(prediction, label)\n",
412 | " self.log('val_loss', loss)\n",
413 | " \n",
414 | " def configure_optimizers(self):\n",
415 | " optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n",
416 | " return optimizer"
417 | ]
418 | },
419 | {
420 | "cell_type": "code",
421 | "execution_count": 38,
422 | "metadata": {},
423 | "outputs": [],
424 | "source": [
425 | "callbacks = [\n",
426 | " EarlyStopping(monitor='val_loss', patience=5, min_delta=0.01),\n",
427 | "]"
428 | ]
429 | },
430 | {
431 | "cell_type": "code",
432 | "execution_count": 39,
433 | "metadata": {},
434 | "outputs": [],
435 | "source": [
436 | "checkpointer = ModelCheckpoint(filepath='../models/10Nov-{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', mode='min', save_top_k=3)"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": 40,
442 | "metadata": {},
443 | "outputs": [
444 | {
445 | "name": "stderr",
446 | "output_type": "stream",
447 | "text": [
448 | "GPU available: True, used: True\n",
449 | "TPU available: False, using: 0 TPU cores\n",
450 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]\n"
451 | ]
452 | }
453 | ],
454 | "source": [
455 | "classifier= MalwareClassifier(**model_kwargs)\n",
456 | "trainer = pl.Trainer(callbacks=callbacks, checkpoint_callback=checkpointer, gpus=[3])"
457 | ]
458 | },
459 | {
460 | "cell_type": "code",
461 | "execution_count": 53,
462 | "metadata": {},
463 | "outputs": [
464 | {
465 | "data": {
466 | "text/plain": [
467 | "True"
468 | ]
469 | },
470 | "execution_count": 53,
471 | "metadata": {},
472 | "output_type": "execute_result"
473 | }
474 | ],
475 | "source": [
476 | "train"
477 | ]
478 | },
479 | {
480 | "cell_type": "code",
481 | "execution_count": null,
482 | "metadata": {},
483 | "outputs": [],
484 | "source": [
485 | "iter(train_data).next()"
486 | ]
487 | },
488 | {
489 | "cell_type": "code",
490 | "execution_count": 59,
491 | "metadata": {},
492 | "outputs": [
493 | {
494 | "data": {
495 | "text/plain": [
496 | "8"
497 | ]
498 | },
499 | "execution_count": 59,
500 | "metadata": {},
501 | "output_type": "execute_result"
502 | }
503 | ],
504 | "source": [
505 | "len(train_dataset.cache)"
506 | ]
507 | },
508 | {
509 | "cell_type": "code",
510 | "execution_count": null,
511 | "metadata": {},
512 | "outputs": [
513 | {
514 | "name": "stderr",
515 | "output_type": "stream",
516 | "text": [
517 | "\n",
518 | " | Name | Type | Params\n",
519 | "-----------------------------------------------\n",
520 | "0 | conv1 | SAGEConv | 29 K \n",
521 | "1 | conv2 | SAGEConv | 8 K \n",
522 | "2 | classify | Linear | 325 \n",
523 | "3 | loss_func | CrossEntropyLoss | 0 \n"
524 | ]
525 | },
526 | {
527 | "data": {
528 | "application/vnd.jupyter.widget-view+json": {
529 | "model_id": "",
530 | "version_major": 2,
531 | "version_minor": 0
532 | },
533 | "text/plain": [
534 | "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…"
535 | ]
536 | },
537 | "metadata": {},
538 | "output_type": "display_data"
539 | },
540 | {
541 | "data": {
542 | "application/vnd.jupyter.widget-view+json": {
543 | "model_id": "c0db099f0d424271bd5198e1442156d7",
544 | "version_major": 2,
545 | "version_minor": 0
546 | },
547 | "text/plain": [
548 | "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…"
549 | ]
550 | },
551 | "metadata": {},
552 | "output_type": "display_data"
553 | },
554 | {
555 | "data": {
556 | "application/vnd.jupyter.widget-view+json": {
557 | "model_id": "dd1593bce733449095083d8ac4d4201b",
558 | "version_major": 2,
559 | "version_minor": 0
560 | },
561 | "text/plain": [
562 | "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
563 | ]
564 | },
565 | "metadata": {},
566 | "output_type": "display_data"
567 | },
568 | {
569 | "name": "stderr",
570 | "output_type": "stream",
571 | "text": [
572 | "IOPub message rate exceeded.\n",
573 | "The notebook server will temporarily stop sending output\n",
574 | "to the client in order to avoid crashing it.\n",
575 | "To change this limit, set the config variable\n",
576 | "`--NotebookApp.iopub_msg_rate_limit`.\n",
577 | "\n",
578 | "Current values:\n",
579 | "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n",
580 | "NotebookApp.rate_limit_window=3.0 (secs)\n",
581 | "\n"
582 | ]
583 | },
584 | {
585 | "data": {
586 | "application/vnd.jupyter.widget-view+json": {
587 | "model_id": "3feb31f9ddda4b1aab25aba9295c0a69",
588 | "version_major": 2,
589 | "version_minor": 0
590 | },
591 | "text/plain": [
592 | "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…"
593 | ]
594 | },
595 | "metadata": {},
596 | "output_type": "display_data"
597 | }
598 | ],
599 | "source": [
600 | "if train:\n",
601 | " trainer.fit(classifier, train_data, val_data)"
602 | ]
603 | },
604 | {
605 | "cell_type": "code",
606 | "execution_count": null,
607 | "metadata": {},
608 | "outputs": [],
609 | "source": []
610 | },
611 | {
612 | "cell_type": "markdown",
613 | "metadata": {},
614 | "source": [
615 | "## Testing "
616 | ]
617 | },
618 | {
619 | "cell_type": "code",
620 | "execution_count": 54,
621 | "metadata": {},
622 | "outputs": [],
623 | "source": [
624 | "classifier_saved = MalwareClassifier.load_from_checkpoint('../models/10Nov-epoch=15-val_loss=0.67.pt.ckpt', **model_kwargs)"
625 | ]
626 | },
627 | {
628 | "cell_type": "code",
629 | "execution_count": null,
630 | "metadata": {},
631 | "outputs": [],
632 | "source": [
633 | "predicted = torch.argmax(classifier(dgl.batch([g for g,l in test_dataset])),dim=1)\n",
634 | "predicted"
635 | ]
636 | },
637 | {
638 | "cell_type": "code",
639 | "execution_count": 60,
640 | "metadata": {},
641 | "outputs": [
642 | {
643 | "data": {
644 | "text/plain": [
645 | "tensor([4, 2, 3, ..., 1, 4, 2])"
646 | ]
647 | },
648 | "execution_count": 60,
649 | "metadata": {},
650 | "output_type": "execute_result"
651 | }
652 | ],
653 | "source": [
654 | "predicted"
655 | ]
656 | },
657 | {
658 | "cell_type": "code",
659 | "execution_count": 61,
660 | "metadata": {},
661 | "outputs": [
662 | {
663 | "data": {
664 | "text/plain": [
665 | "tensor([4, 2, 3, ..., 3, 3, 2])"
666 | ]
667 | },
668 | "execution_count": 61,
669 | "metadata": {},
670 | "output_type": "execute_result"
671 | }
672 | ],
673 | "source": [
674 | "actual = torch.tensor([l for g,l in test_dataset])\n",
675 | "actual"
676 | ]
677 | },
678 | {
679 | "cell_type": "code",
680 | "execution_count": 62,
681 | "metadata": {},
682 | "outputs": [
683 | {
684 | "name": "stdout",
685 | "output_type": "stream",
686 | "text": [
687 | " precision recall f1-score support\n",
688 | "\n",
689 | " 0 0.9139 0.8079 0.8576 302\n",
690 | " 1 0.7181 0.6526 0.6838 449\n",
691 | " 2 0.9272 0.8824 0.9042 808\n",
692 | " 3 0.7611 0.7728 0.7669 779\n",
693 | " 4 0.8641 0.9564 0.9079 964\n",
694 | "\n",
695 | " accuracy 0.8401 3302\n",
696 | " macro avg 0.8369 0.8144 0.8241 3302\n",
697 | "weighted avg 0.8399 0.8401 0.8387 3302\n",
698 | "\n"
699 | ]
700 | }
701 | ],
702 | "source": [
703 | "print(M.classification_report(actual, predicted, digits=4))"
704 | ]
705 | },
706 | {
707 | "cell_type": "code",
708 | "execution_count": 63,
709 | "metadata": {},
710 | "outputs": [
711 | {
712 | "data": {
713 | "text/plain": [
714 | "array([[244, 15, 2, 37, 4],\n",
715 | " [ 7, 293, 21, 70, 58],\n",
716 | " [ 5, 25, 713, 60, 5],\n",
717 | " [ 11, 55, 33, 602, 78],\n",
718 | " [ 0, 20, 0, 22, 922]])"
719 | ]
720 | },
721 | "execution_count": 63,
722 | "metadata": {},
723 | "output_type": "execute_result"
724 | }
725 | ],
726 | "source": [
727 | "M.confusion_matrix(actual, predicted)"
728 | ]
729 | },
730 | {
731 | "cell_type": "code",
732 | "execution_count": 64,
733 | "metadata": {},
734 | "outputs": [
735 | {
736 | "data": {
737 | "text/plain": [
738 | "['Adware', 'Banking', 'Benigh', 'Riskware', 'SMS']"
739 | ]
740 | },
741 | "execution_count": 64,
742 | "metadata": {},
743 | "output_type": "execute_result"
744 | }
745 | ],
746 | "source": [
747 | "sorted([\"Adware\", \"Benigh\", \"Banking\", \"SMS\", \"Riskware\"])"
748 | ]
749 | },
750 | {
751 | "cell_type": "code",
752 | "execution_count": 75,
753 | "metadata": {},
754 | "outputs": [
755 | {
756 | "data": {
757 | "text/plain": [
758 | "tensor(2)"
759 | ]
760 | },
761 | "execution_count": 75,
762 | "metadata": {},
763 | "output_type": "execute_result"
764 | }
765 | ],
766 | "source": [
767 | "predicted[15]"
768 | ]
769 | },
770 | {
771 | "cell_type": "code",
772 | "execution_count": 73,
773 | "metadata": {},
774 | "outputs": [
775 | {
776 | "data": {
777 | "text/plain": [
778 | "tensor([ 8, 15, 20, 25, 50, 56, 57, 58, 62, 67, 72, 73,\n",
779 | " 74, 75, 76, 96, 97, 99, 100, 114, 120, 121, 123, 125,\n",
780 | " 126, 131, 138, 140, 143, 158, 159, 186, 187, 211, 212, 221,\n",
781 | " 233, 239, 241, 263, 271, 281, 284, 297, 298, 305, 306, 309,\n",
782 | " 310, 313, 316, 318, 321, 329, 333, 335, 342, 355, 359, 360,\n",
783 | " 368, 370, 388, 390, 392, 396, 407, 411, 412, 426, 429, 434,\n",
784 | " 435, 442, 450, 453, 456, 459, 467, 469, 470, 471, 475, 476,\n",
785 | " 498, 505, 512, 522, 526, 528, 541, 546, 552, 553, 555, 557,\n",
786 | " 560, 567, 573, 577, 578, 579, 589, 602, 616, 618, 619, 625,\n",
787 | " 632, 642, 652, 653, 656, 658, 663, 664, 666, 672, 677, 682,\n",
788 | " 685, 690, 705, 708, 714, 728, 733, 734, 735, 744, 747, 775,\n",
789 | " 780, 781, 785, 798, 800, 802, 804, 811, 820, 825, 829, 830,\n",
790 | " 833, 838, 842, 846, 865, 870, 876, 881, 902, 903, 905, 909,\n",
791 | " 918, 928, 934, 949, 952, 960, 967, 971, 975, 976, 985, 1016,\n",
792 | " 1020, 1021, 1023, 1028, 1029, 1054, 1070, 1091, 1095, 1097, 1107, 1108,\n",
793 | " 1114, 1119, 1120, 1135, 1136, 1174, 1195, 1199, 1205, 1207, 1222, 1223,\n",
794 | " 1224, 1229, 1233, 1236, 1252, 1265, 1268, 1269, 1275, 1278, 1282, 1284,\n",
795 | " 1294, 1299, 1300, 1303, 1305, 1316, 1317, 1318, 1328, 1342, 1343, 1358,\n",
796 | " 1366, 1377, 1386, 1387, 1389, 1392, 1395, 1412, 1414, 1415, 1420, 1424,\n",
797 | " 1425, 1427, 1433, 1434, 1439, 1447, 1448, 1451, 1457, 1461, 1481, 1501,\n",
798 | " 1507, 1510, 1512, 1521, 1528, 1556, 1562, 1563, 1565, 1566, 1567, 1569,\n",
799 | " 1570, 1575, 1584, 1589, 1592, 1601, 1607, 1608, 1612, 1613, 1621, 1622,\n",
800 | " 1624, 1630, 1638, 1653, 1659, 1664, 1682, 1684, 1690, 1691, 1693, 1694,\n",
801 | " 1715, 1717, 1718, 1745, 1750, 1766, 1770, 1771, 1793, 1795, 1811, 1817,\n",
802 | " 1827, 1849, 1854, 1878, 1882, 1886, 1888, 1896, 1898, 1901, 1913, 1924,\n",
803 | " 1928, 1930, 1935, 1936, 1950, 1956, 1957, 1970, 1975, 1984, 1985, 1992,\n",
804 | " 1994, 2005, 2012, 2021, 2024, 2035, 2041, 2045, 2052, 2056, 2057, 2058,\n",
805 | " 2065, 2067, 2068, 2074, 2079, 2091, 2095, 2099, 2102, 2104, 2105, 2107,\n",
806 | " 2115, 2116, 2117, 2122, 2132, 2133, 2155, 2162, 2163, 2170, 2179, 2180,\n",
807 | " 2183, 2185, 2186, 2189, 2200, 2203, 2206, 2208, 2216, 2221, 2237, 2241,\n",
808 | " 2247, 2248, 2255, 2256, 2262, 2267, 2268, 2275, 2276, 2279, 2300, 2301,\n",
809 | " 2304, 2307, 2322, 2328, 2337, 2346, 2349, 2353, 2363, 2365, 2375, 2377,\n",
810 | " 2381, 2389, 2394, 2397, 2419, 2430, 2436, 2448, 2449, 2454, 2457, 2475,\n",
811 | " 2479, 2485, 2487, 2490, 2515, 2519, 2528, 2534, 2535, 2541, 2542, 2543,\n",
812 | " 2544, 2549, 2553, 2558, 2576, 2581, 2587, 2596, 2598, 2609, 2610, 2611,\n",
813 | " 2616, 2617, 2619, 2628, 2638, 2640, 2642, 2654, 2662, 2665, 2675, 2677,\n",
814 | " 2681, 2688, 2692, 2693, 2694, 2698, 2704, 2727, 2731, 2744, 2766, 2778,\n",
815 | " 2780, 2782, 2787, 2791, 2794, 2796, 2803, 2813, 2814, 2822, 2861, 2887,\n",
816 | " 2888, 2898, 2908, 2914, 2916, 2922, 2923, 2942, 2951, 2953, 2958, 2960,\n",
817 | " 2970, 2990, 2991, 2993, 2998, 3000, 3001, 3008, 3011, 3012, 3027, 3030,\n",
818 | " 3038, 3046, 3048, 3049, 3050, 3060, 3070, 3072, 3081, 3089, 3090, 3093,\n",
819 | " 3095, 3096, 3122, 3128, 3129, 3130, 3140, 3142, 3146, 3150, 3151, 3176,\n",
820 | " 3178, 3180, 3189, 3204, 3210, 3218, 3228, 3237, 3240, 3241, 3242, 3247,\n",
821 | " 3251, 3254, 3260, 3261, 3265, 3267, 3273, 3279, 3291, 3295, 3299, 3300])"
822 | ]
823 | },
824 | "execution_count": 73,
825 | "metadata": {},
826 | "output_type": "execute_result"
827 | }
828 | ],
829 | "source": [
830 | "torch.where(actual!=predicted)[0]"
831 | ]
832 | },
833 | {
834 | "cell_type": "code",
835 | "execution_count": 70,
836 | "metadata": {},
837 | "outputs": [],
838 | "source": [
839 | "import numpy as np"
840 | ]
841 | },
842 | {
843 | "cell_type": "code",
844 | "execution_count": 71,
845 | "metadata": {},
846 | "outputs": [],
847 | "source": [
848 | "test_list_np = np.array(test_list)"
849 | ]
850 | },
851 | {
852 | "cell_type": "code",
853 | "execution_count": 72,
854 | "metadata": {},
855 | "outputs": [
856 | {
857 | "data": {
858 | "text/plain": [
859 | "array(['Adware0767.apk', 'Banking1835.apk', 'Riskware0594.apk',\n",
860 | " 'Banking0851.apk', 'Riskware4216.apk', 'Riskware2248.apk',\n",
861 | " 'Banking0452.apk', 'Benigh0310.apk', 'Banking1730.apk',\n",
862 | " 'Banking0032.apk', 'Benigh0353.apk', 'Banking1341.apk',\n",
863 | " 'Banking0427.apk', 'Benigh2838.apk', 'Banking1769.apk',\n",
864 | " 'Banking1240.apk', 'Riskware4175.apk', 'Riskware1607.apk',\n",
865 | " 'Riskware0891.apk', 'Riskware2847.apk', 'Riskware4066.apk',\n",
866 | " 'Banking1974.apk', 'Benigh0736.apk', 'Riskware1756.apk',\n",
867 | " 'Adware0635.apk', 'Adware0093.apk', 'Banking0844.apk',\n",
868 | " 'Riskware3207.apk', 'SMS2836.apk', 'Benigh0740.apk',\n",
869 | " 'Banking1832.apk', 'Riskware2993.apk', 'Banking0169.apk',\n",
870 | " 'Riskware0301.apk', 'Banking0434.apk', 'Riskware1234.apk',\n",
871 | " 'Riskware3747.apk', 'Banking0577.apk', 'Banking1982.apk',\n",
872 | " 'Benigh1585.apk', 'Banking0099.apk', 'Banking0505.apk',\n",
873 | " 'Banking2505.apk', 'Banking0772.apk', 'Benigh3860.apk',\n",
874 | " 'Riskware2045.apk', 'Riskware3325.apk', 'Riskware1556.apk',\n",
875 | " 'Banking1570.apk', 'SMS3149.apk', 'Banking0966.apk',\n",
876 | " 'Riskware1269.apk', 'Benigh3483.apk', 'Banking2474.apk',\n",
877 | " 'Riskware1450.apk', 'Benigh2712.apk', 'Riskware0937.apk',\n",
878 | " 'SMS3697.apk', 'Banking2479.apk', 'Banking1733.apk',\n",
879 | " 'Adware1157.apk', 'SMS0266.apk', 'Adware0256.apk',\n",
880 | " 'Benigh1095.apk', 'Banking1418.apk', 'Riskware0442.apk',\n",
881 | " 'Riskware0396.apk', 'Riskware4114.apk', 'Riskware3205.apk',\n",
882 | " 'Adware1164.apk', 'Banking1098.apk', 'Banking2064.apk',\n",
883 | " 'Benigh0037.apk', 'Benigh3475.apk', 'Riskware3470.apk',\n",
884 | " 'Banking1285.apk', 'Banking2057.apk', 'Banking2073.apk',\n",
885 | " 'Banking1577.apk', 'Benigh1864.apk', 'Benigh0792.apk',\n",
886 | " 'Riskware1681.apk', 'Benigh1971.apk', 'Banking1405.apk',\n",
887 | " 'Riskware2764.apk', 'Riskware4039.apk', 'Banking0555.apk',\n",
888 | " 'Benigh2512.apk', 'Riskware0749.apk', 'SMS1263.apk',\n",
889 | " 'Riskware3435.apk', 'SMS2480.apk', 'Benigh0774.apk',\n",
890 | " 'Banking1152.apk', 'Riskware3737.apk', 'Benigh1511.apk',\n",
891 | " 'Banking2107.apk', 'Banking0744.apk', 'SMS1882.apk',\n",
892 | " 'Adware0851.apk', 'Benigh2390.apk', 'Riskware1853.apk',\n",
893 | " 'Benigh1048.apk', 'Benigh2096.apk', 'Riskware3455.apk',\n",
894 | " 'Benigh1203.apk', 'Banking2314.apk', 'SMS4618.apk',\n",
895 | " 'Banking0821.apk', 'Riskware3399.apk', 'Riskware2527.apk',\n",
896 | " 'Riskware3785.apk', 'Riskware2920.apk', 'Riskware3040.apk',\n",
897 | " 'Banking1558.apk', 'Riskware3602.apk', 'Benigh3157.apk',\n",
898 | " 'Benigh0566.apk', 'Riskware3317.apk', 'Adware0340.apk',\n",
899 | " 'Adware0703.apk', 'Banking2357.apk', 'Benigh2381.apk',\n",
900 | " 'Riskware3976.apk', 'SMS1072.apk', 'Banking0593.apk',\n",
901 | " 'Benigh1727.apk', 'Benigh1766.apk', 'Riskware0696.apk',\n",
902 | " 'Adware1217.apk', 'Benigh1855.apk', 'Benigh3369.apk',\n",
903 | " 'Riskware2471.apk', 'Benigh3685.apk', 'Banking0333.apk',\n",
904 | " 'Banking0718.apk', 'SMS0566.apk', 'Banking1504.apk',\n",
905 | " 'Riskware2216.apk', 'Banking0119.apk', 'Benigh3598.apk',\n",
906 | " 'Riskware1688.apk', 'Riskware3894.apk', 'Adware0449.apk',\n",
907 | " 'Banking0185.apk', 'Riskware2190.apk', 'Banking2200.apk',\n",
908 | " 'Banking1395.apk', 'Riskware3401.apk', 'Riskware3623.apk',\n",
909 | " 'Riskware4072.apk', 'Riskware0085.apk', 'Banking1550.apk',\n",
910 | " 'Adware0805.apk', 'Banking0309.apk', 'Riskware1574.apk',\n",
911 | " 'Banking1301.apk', 'Banking2225.apk', 'Banking1145.apk',\n",
912 | " 'Benigh3575.apk', 'Banking2331.apk', 'Benigh1259.apk',\n",
913 | " 'Banking1732.apk', 'Adware0431.apk', 'Adware0546.apk',\n",
914 | " 'Banking0839.apk', 'Banking1415.apk', 'Banking2355.apk',\n",
915 | " 'Adware0291.apk', 'Riskware3841.apk', 'Banking2252.apk',\n",
916 | " 'Banking1643.apk', 'Riskware4172.apk', 'Riskware0005.apk',\n",
917 | " 'Banking0339.apk', 'Riskware0950.apk', 'Riskware1184.apk',\n",
918 | " 'Banking2061.apk', 'Riskware2878.apk', 'Riskware3763.apk',\n",
919 | " 'Riskware2391.apk', 'Banking0042.apk', 'Banking1054.apk',\n",
920 | " 'Adware0526.apk', 'Benigh2876.apk', 'Banking1770.apk',\n",
921 | " 'Riskware0910.apk', 'SMS0311.apk', 'Banking2421.apk',\n",
922 | " 'Riskware0791.apk', 'Banking1325.apk', 'Adware0613.apk',\n",
923 | " 'Benigh0525.apk', 'Riskware0864.apk', 'Banking1290.apk',\n",
924 | " 'Banking1104.apk', 'Benigh2000.apk', 'Banking1738.apk',\n",
925 | " 'Riskware3268.apk', 'Benigh2831.apk', 'Riskware2095.apk',\n",
926 | " 'Benigh1315.apk', 'Benigh1070.apk', 'Riskware0519.apk',\n",
927 | " 'Banking1399.apk', 'Adware0509.apk', 'Banking2300.apk',\n",
928 | " 'Banking2376.apk', 'Riskware3438.apk', 'Riskware0450.apk',\n",
929 | " 'Riskware0347.apk', 'Riskware1758.apk', 'Riskware4079.apk',\n",
930 | " 'SMS1509.apk', 'Riskware2578.apk', 'SMS4525.apk', 'Adware1366.apk',\n",
931 | " 'Riskware2598.apk', 'Riskware3295.apk', 'Riskware2549.apk',\n",
932 | " 'Riskware0913.apk', 'SMS0928.apk', 'Banking1671.apk',\n",
933 | " 'Adware0861.apk', 'SMS2605.apk', 'Banking0365.apk',\n",
934 | " 'Adware0290.apk', 'Riskware1517.apk', 'SMS1008.apk',\n",
935 | " 'Banking0857.apk', 'SMS1574.apk', 'Benigh2745.apk',\n",
936 | " 'Adware0287.apk', 'Riskware0753.apk', 'Benigh2260.apk',\n",
937 | " 'Banking0488.apk', 'Riskware3329.apk', 'Banking1814.apk',\n",
938 | " 'Riskware4267.apk', 'Banking1529.apk', 'Riskware3121.apk',\n",
939 | " 'Adware0878.apk', 'Banking2173.apk', 'Riskware2561.apk',\n",
940 | " 'Riskware2245.apk', 'SMS0386.apk', 'Benigh1118.apk',\n",
941 | " 'Riskware2474.apk', 'Adware1108.apk', 'SMS4221.apk',\n",
942 | " 'Banking2434.apk', 'Adware1492.apk', 'Riskware3423.apk',\n",
943 | " 'Banking1538.apk', 'SMS1047.apk', 'Benigh3656.apk', 'SMS1980.apk',\n",
944 | " 'Benigh3011.apk', 'Benigh0128.apk', 'Riskware2808.apk',\n",
945 | " 'Banking0453.apk', 'Riskware2866.apk', 'Banking1628.apk',\n",
946 | " 'Benigh1507.apk', 'Riskware1915.apk', 'Banking0270.apk',\n",
947 | " 'Benigh1050.apk', 'Adware0837.apk', 'SMS2513.apk',\n",
948 | " 'Benigh2576.apk', 'Benigh1782.apk', 'Banking1030.apk',\n",
949 | " 'Benigh0981.apk', 'Banking1356.apk', 'Banking0020.apk',\n",
950 | " 'Adware0553.apk', 'Benigh0868.apk', 'Benigh1352.apk',\n",
951 | " 'Riskware2569.apk', 'Benigh1016.apk', 'SMS0983.apk',\n",
952 | " 'Riskware2127.apk', 'Benigh0087.apk', 'Benigh0073.apk',\n",
953 | " 'Riskware3855.apk', 'Benigh0844.apk', 'Riskware2247.apk',\n",
954 | " 'Adware1463.apk', 'Benigh0522.apk', 'Riskware0503.apk',\n",
955 | " 'Riskware2861.apk', 'Riskware1998.apk', 'SMS4184.apk',\n",
956 | " 'SMS2380.apk', 'Riskware1063.apk', 'Banking1420.apk',\n",
957 | " 'Benigh2884.apk', 'Banking2444.apk', 'Riskware0788.apk',\n",
958 | " 'Benigh0022.apk', 'SMS3378.apk', 'Banking1500.apk',\n",
959 | " 'Benigh3259.apk', 'Riskware0682.apk', 'Benigh2313.apk',\n",
960 | " 'Benigh1027.apk', 'Riskware3974.apk', 'Banking0336.apk',\n",
961 | " 'Banking0247.apk', 'Adware0625.apk', 'SMS3931.apk',\n",
962 | " 'Riskware4116.apk', 'Riskware2706.apk', 'Adware0244.apk',\n",
963 | " 'Banking1846.apk', 'Adware1393.apk', 'Benigh0922.apk',\n",
964 | " 'Banking2282.apk', 'Adware0975.apk', 'Benigh1472.apk',\n",
965 | " 'Adware1113.apk', 'Riskware2818.apk', 'Benigh3354.apk',\n",
966 | " 'Riskware3544.apk', 'Benigh2723.apk', 'Banking2102.apk',\n",
967 | " 'Riskware3272.apk', 'Banking1708.apk', 'Adware0329.apk',\n",
968 | " 'Riskware0964.apk', 'Riskware4263.apk', 'Banking2360.apk',\n",
969 | " 'Benigh0742.apk', 'Adware0111.apk', 'Banking1862.apk',\n",
970 | " 'Riskware2564.apk', 'Riskware1360.apk', 'Banking1074.apk',\n",
971 | " 'Adware0087.apk', 'Riskware0505.apk', 'SMS4480.apk',\n",
972 | " 'Adware0758.apk', 'Banking0322.apk', 'Riskware2606.apk',\n",
973 | " 'Banking2295.apk', 'Benigh3300.apk', 'Benigh2638.apk',\n",
974 | " 'Adware1297.apk', 'Benigh3951.apk', 'SMS2015.apk',\n",
975 | " 'Banking0682.apk', 'Riskware4104.apk', 'Banking1144.apk',\n",
976 | " 'Riskware3211.apk', 'Banking0210.apk', 'Banking2453.apk',\n",
977 | " 'Benigh3099.apk', 'Benigh1636.apk', 'Benigh1040.apk',\n",
978 | " 'SMS1414.apk', 'Banking1143.apk', 'Adware0146.apk',\n",
979 | " 'Banking1697.apk', 'Adware0099.apk', 'Adware1174.apk',\n",
980 | " 'Banking1860.apk', 'Banking0459.apk', 'Riskware4161.apk',\n",
981 | " 'Riskware0405.apk', 'Riskware1578.apk', 'Riskware2554.apk',\n",
982 | " 'Banking0222.apk', 'Benigh0823.apk', 'SMS3522.apk',\n",
983 | " 'Riskware3061.apk', 'Riskware0979.apk', 'Riskware2884.apk',\n",
984 | " 'Benigh1675.apk', 'SMS4760.apk', 'Riskware3814.apk',\n",
985 | " 'Riskware4173.apk', 'Riskware2855.apk', 'Riskware4120.apk',\n",
986 | " 'Banking0946.apk', 'Riskware1887.apk', 'Riskware1453.apk',\n",
987 | " 'Benigh0225.apk', 'Benigh0968.apk', 'SMS0143.apk',\n",
988 | " 'Riskware0628.apk', 'SMS0530.apk', 'Banking1407.apk',\n",
989 | " 'Riskware3744.apk', 'Riskware2538.apk', 'Adware0451.apk',\n",
990 | " 'Banking1162.apk', 'Benigh3747.apk', 'Benigh2595.apk',\n",
991 | " 'Banking1239.apk', 'Banking1887.apk', 'Banking0722.apk',\n",
992 | " 'Adware0445.apk', 'Benigh2217.apk', 'Banking1466.apk',\n",
993 | " 'Banking1844.apk', 'Riskware2210.apk', 'Riskware0597.apk',\n",
994 | " 'Riskware2742.apk', 'Riskware0990.apk', 'Riskware3062.apk',\n",
995 | " 'Riskware3016.apk', 'Banking1977.apk', 'Riskware1713.apk',\n",
996 | " 'Banking1397.apk', 'Banking1896.apk', 'Riskware1646.apk',\n",
997 | " 'Riskware1019.apk', 'Riskware0453.apk', 'Banking1322.apk',\n",
998 | " 'Adware0106.apk', 'Adware1510.apk', 'Banking0742.apk',\n",
999 | " 'Adware0086.apk', 'Benigh3999.apk', 'Riskware1353.apk',\n",
1000 | " 'Riskware0151.apk', 'Banking1995.apk', 'Benigh3291.apk',\n",
1001 | " 'Benigh3780.apk', 'Banking2036.apk', 'Adware0912.apk',\n",
1002 | " 'Benigh1539.apk', 'Riskware2241.apk', 'Banking2056.apk',\n",
1003 | " 'Benigh3888.apk', 'Riskware1761.apk', 'SMS2112.apk',\n",
1004 | " 'Banking1499.apk', 'Banking0940.apk', 'Banking1118.apk',\n",
1005 | " 'Riskware3994.apk', 'Riskware3917.apk', 'Riskware4197.apk',\n",
1006 | " 'Riskware3835.apk', 'Banking2390.apk', 'Banking1024.apk',\n",
1007 | " 'Riskware0613.apk', 'Adware0951.apk', 'Riskware0115.apk',\n",
1008 | " 'Banking1690.apk', 'SMS3593.apk', 'SMS2882.apk', 'Banking2457.apk',\n",
1009 | " 'Benigh0510.apk', 'Adware1361.apk', 'Riskware3432.apk',\n",
1010 | " 'Riskware3215.apk', 'Riskware3197.apk', 'Benigh0230.apk',\n",
1011 | " 'Riskware2911.apk', 'Banking0920.apk', 'SMS4304.apk',\n",
1012 | " 'Banking2394.apk', 'Benigh0646.apk', 'Banking2398.apk',\n",
1013 | " 'Banking1076.apk', 'Riskware0611.apk', 'Benigh3941.apk',\n",
1014 | " 'Banking1163.apk', 'Banking2348.apk', 'Riskware0204.apk',\n",
1015 | " 'Riskware0091.apk', 'Banking1518.apk', 'Riskware3844.apk',\n",
1016 | " 'Banking2256.apk', 'Banking0125.apk', 'Riskware2369.apk',\n",
1017 | " 'Riskware3876.apk', 'Riskware1473.apk', 'Banking0680.apk',\n",
1018 | " 'Banking1414.apk', 'SMS2469.apk', 'Riskware0935.apk',\n",
1019 | " 'Riskware1698.apk', 'Riskware0617.apk', 'Riskware2778.apk',\n",
1020 | " 'Riskware3130.apk', 'Riskware2655.apk', 'Adware0884.apk',\n",
1021 | " 'Banking2324.apk', 'Banking1277.apk', 'Benigh1729.apk',\n",
1022 | " 'Banking1882.apk', 'Riskware2733.apk', 'Adware1255.apk',\n",
1023 | " 'Adware0974.apk', 'Adware1073.apk', 'SMS1516.apk',\n",
1024 | " 'Banking1935.apk', 'Riskware1419.apk', 'Riskware0352.apk',\n",
1025 | " 'Banking2261.apk', 'Riskware0176.apk', 'Benigh2412.apk',\n",
1026 | " 'Adware0712.apk', 'Banking1516.apk', 'Banking1132.apk',\n",
1027 | " 'Riskware3141.apk', 'Riskware4091.apk', 'Banking2464.apk',\n",
1028 | " 'Riskware2850.apk', 'Benigh1448.apk', 'Adware0209.apk',\n",
1029 | " 'Adware0168.apk', 'Banking1439.apk', 'Banking1248.apk',\n",
1030 | " 'SMS3394.apk', 'Adware0383.apk', 'Riskware2744.apk',\n",
1031 | " 'Riskware0958.apk', 'Benigh2119.apk', 'Benigh2978.apk',\n",
1032 | " 'SMS3920.apk', 'Riskware0975.apk', 'Benigh1047.apk',\n",
1033 | " 'Adware0069.apk', 'Riskware1216.apk', 'Riskware0931.apk'],\n",
1034 | " dtype=')"
651 | ]
652 | },
653 | "execution_count": 20,
654 | "metadata": {},
655 | "output_type": "execute_result"
656 | }
657 | ],
658 | "source": [
659 | "predicted = classifier(dgl.batch([g for g,l in test_dataset]))\n",
660 | "predicted"
661 | ]
662 | },
663 | {
664 | "cell_type": "code",
665 | "execution_count": 21,
666 | "metadata": {},
667 | "outputs": [],
668 | "source": [
669 | "predicted_mod = predicted.detach()"
670 | ]
671 | },
672 | {
673 | "cell_type": "code",
674 | "execution_count": 23,
675 | "metadata": {},
676 | "outputs": [],
677 | "source": [
678 | "predicted_mod[predicted_mod>0.5] = 1\n",
679 | "predicted_mod[predicted_mod<0.5] = 0"
680 | ]
681 | },
682 | {
683 | "cell_type": "code",
684 | "execution_count": 29,
685 | "metadata": {},
686 | "outputs": [
687 | {
688 | "data": {
689 | "text/plain": [
690 | "tensor([0, 1, 0, ..., 1, 0, 1])"
691 | ]
692 | },
693 | "execution_count": 29,
694 | "metadata": {},
695 | "output_type": "execute_result"
696 | }
697 | ],
698 | "source": [
699 | "predicted_mod.long()"
700 | ]
701 | },
702 | {
703 | "cell_type": "code",
704 | "execution_count": 25,
705 | "metadata": {},
706 | "outputs": [
707 | {
708 | "data": {
709 | "text/plain": [
710 | "tensor([0, 1, 0, ..., 0, 0, 1])"
711 | ]
712 | },
713 | "execution_count": 25,
714 | "metadata": {},
715 | "output_type": "execute_result"
716 | }
717 | ],
718 | "source": [
719 | "actual = torch.tensor([l for g,l in test_dataset])\n",
720 | "actual[actual!=2]=0\n",
721 | "actual[actual==2]=1\n",
722 | "actual"
723 | ]
724 | },
725 | {
726 | "cell_type": "code",
727 | "execution_count": 47,
728 | "metadata": {},
729 | "outputs": [
730 | {
731 | "data": {
732 | "text/plain": [
733 | "(3302, 2494)"
734 | ]
735 | },
736 | "execution_count": 47,
737 | "metadata": {},
738 | "output_type": "execute_result"
739 | }
740 | ],
741 | "source": [
742 | "len(actual), len(torch.where(actual==0)[0])"
743 | ]
744 | },
745 | {
746 | "cell_type": "code",
747 | "execution_count": 48,
748 | "metadata": {},
749 | "outputs": [
750 | {
751 | "data": {
752 | "text/plain": [
753 | "808"
754 | ]
755 | },
756 | "execution_count": 48,
757 | "metadata": {},
758 | "output_type": "execute_result"
759 | }
760 | ],
761 | "source": [
762 | "_[0]-_[1]"
763 | ]
764 | },
765 | {
766 | "cell_type": "code",
767 | "execution_count": 31,
768 | "metadata": {},
769 | "outputs": [
770 | {
771 | "name": "stdout",
772 | "output_type": "stream",
773 | "text": [
774 | " precision recall f1-score support\n",
775 | "\n",
776 | " 0 0.9594 0.9379 0.9485 2494\n",
777 | " 1 0.8206 0.8775 0.8481 808\n",
778 | "\n",
779 | " accuracy 0.9231 3302\n",
780 | " macro avg 0.8900 0.9077 0.8983 3302\n",
781 | "weighted avg 0.9254 0.9231 0.9239 3302\n",
782 | "\n"
783 | ]
784 | }
785 | ],
786 | "source": [
787 | "print(M.classification_report(actual, predicted_mod.long(), digits=4))"
788 | ]
789 | },
790 | {
791 | "cell_type": "code",
792 | "execution_count": 49,
793 | "metadata": {},
794 | "outputs": [
795 | {
796 | "data": {
797 | "text/plain": [
798 | "array([[2339, 155],\n",
799 | " [ 99, 709]])"
800 | ]
801 | },
802 | "execution_count": 49,
803 | "metadata": {},
804 | "output_type": "execute_result"
805 | }
806 | ],
807 | "source": [
808 | "M.confusion_matrix(actual, predicted_mod.long())"
809 | ]
810 | },
811 | {
812 | "cell_type": "markdown",
813 | "metadata": {},
814 | "source": [
815 | "## Results\n",
816 | "Accuracy - 93.21%,\n",
817 | "Precision - 0.9254,\n",
818 | "Recall - 0.9231,\n",
819 | "F1 - 0.9239"
820 | ]
821 | },
822 | {
823 | "cell_type": "code",
824 | "execution_count": null,
825 | "metadata": {},
826 | "outputs": [],
827 | "source": []
828 | }
829 | ],
830 | "metadata": {
831 | "kernelspec": {
832 | "display_name": "Python 3",
833 | "language": "python",
834 | "name": "python3"
835 | },
836 | "language_info": {
837 | "codemirror_mode": {
838 | "name": "ipython",
839 | "version": 3
840 | },
841 | "file_extension": ".py",
842 | "mimetype": "text/x-python",
843 | "name": "python",
844 | "nbconvert_exporter": "python",
845 | "pygments_lexer": "ipython3",
846 | "version": "3.6.9"
847 | }
848 | },
849 | "nbformat": 4,
850 | "nbformat_minor": 4
851 | }
852 |
--------------------------------------------------------------------------------
/readme.md:
--------------------------------------------------------------------------------
1 | # AndMal-Detect
2 |
3 | Android Malware Detection using Function Call Graphs and Graph Convolutional Networks
4 |
5 | # What?
6 |
7 | A research work carried out by me ([Vinayaka K V](https://github.com/vinayakakv)) during my MTech (Research) degree at the Department of IT, NITK.
8 |
9 | The objectives of the research were:
10 |
11 | 1. To evaluate whether GCNs are effective in detecting Android malware using FCGs, and which GCN algorithm is best suited for this task.
12 | 2. To enhance the FCGs by incorporating the callback information obtained from the framework code, and to evaluate them against the normal FCGs.
13 |
14 | # Code organization
15 |
16 | The code achieving the first objective is present on the `master` (current) branch, while the code achieving the second objective is present on the `experiment` branch.
17 |
18 | # Methodology
19 | 
20 |
21 | ## Datasets
22 |
23 | Stored in the [`/data`](/data) folder. Currently, it contains the SHA256 hashes of the APKs contained in the training and testing splits.
24 |
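25 | As a hedged sketch (not part of this repository), the APKs can be fetched from AndroZoo by SHA256, assuming you have your own AndroZoo API key and the `requests` package (verify the endpoint against AndroZoo's API documentation):
26 | 
27 |     import requests
28 | 
29 |     API_KEY = "..."  # placeholder: your AndroZoo API key
30 |     with open("data/train_new.sha256") as f:
31 |         for sha256 in map(str.strip, f):
32 |             response = requests.get(
33 |                 "https://androzoo.uni.lu/api/download",
34 |                 params={"apikey": API_KEY, "sha256": sha256},
35 |             )
36 |             response.raise_for_status()
37 |             with open(f"{sha256}.apk", "wb") as out:
38 |                 out.write(response.content)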
25 |
26 | ## APK Size Balancer
27 |
28 | Obtains the histogram of APK sizes, and adds APKs wherever there is a large imbalance in the number of APKs between the classes.
29 |
30 | > **Note:** *The provided dataset is already APK Size balanced* 🥳
31 |
32 | ## FCG Extractor
33 |
34 | Implemented in [`scripts/process_dataset.py`](scripts/process_dataset.py).
35 |
36 | The class `FeatureExtractors` provides two public methods:
37 |
38 | 1. `get_user_features()` - Returns 15-bit feature vector for *internal* methods
39 | 2. `get_api_features()` - Returns a one-hot feature vector for *external* methods
40 |
41 | The method `process` extracts the FCG and assigns node features; a usage sketch follows.
42 |
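43 | A minimal usage sketch (assuming the script is importable, e.g. when run from the `scripts` directory; `sample.apk` is a placeholder path):
44 | 
45 |     from androguard.misc import AnalyzeAPK
46 |     from process_dataset import FeatureExtractors
47 | 
48 |     _, _, dx = AnalyzeAPK("sample.apk")  # placeholder APK
49 |     for node in dx.get_call_graph().nodes():
50 |         if node.is_external():
51 |             vec = FeatureExtractors.get_api_features(node)   # one-hot API package vector
52 |         else:
53 |             vec = FeatureExtractors.get_user_features(node)  # opcode-group bit vector
54 | 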
43 | ## Node Count Balancer
44 |
45 | Balances the dataset so that the node count distribution of the APKs between the classes is exactly the same.
46 |
47 | Implemented in [`scripts/split_dataset.py`](scripts/split_dataset.py).
48 |
49 | > **Note:** *The provided dataset is already node-count balanced to ensure **reproducibility*** 🤩
50 |
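51 | The bin-equalization idea, as a toy example (illustrative numbers only; the actual logic is the per-bin `np.min` over class histograms in [`scripts/split_dataset.py`](scripts/split_dataset.py)):
52 | 
53 |     import numpy as np
54 | 
55 |     # toy node-count histograms (APKs per 500-node bin) for the two classes
56 |     benign = np.array([120, 80, 40])
57 |     malware = np.array([100, 90, 10])
58 |     keep = np.minimum(benign, malware)
59 |     # -> array([100,  80,  10]): sample this many APKs of each class from each bin
60 | 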
51 | ## GCN Classifier
52 |
53 | A multi-layer GCN with a dense layer at the end; a sketch is given below.
54 |
55 | Implemented in [`core/model.py`](core/model.py)
56 |
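57 | A hedged sketch of the overall shape (illustrative only: the class and argument names below are invented, and the real implementation lives in [`core/model.py`](core/model.py)):
58 | 
59 |     import torch
60 |     import torch.nn as nn
61 |     import dgl
62 |     from dgl.nn.pytorch import GraphConv
63 | 
64 |     class GCNSketch(nn.Module):
65 |         def __init__(self, in_dim, hidden_dim, n_classes, n_layers=2):
66 |             super().__init__()
67 |             self.convs = nn.ModuleList([
68 |                 # FCGs contain nodes with no incoming calls, hence allow_zero_in_degree
69 |                 GraphConv(in_dim if i == 0 else hidden_dim, hidden_dim,
70 |                           allow_zero_in_degree=True)
71 |                 for i in range(n_layers)
72 |             ])
73 |             self.dense = nn.Linear(hidden_dim, n_classes)
74 | 
75 |         def forward(self, g, h):
76 |             for conv in self.convs:
77 |                 h = torch.relu(conv(g, h))
78 |             g.ndata['h'] = h
79 |             hg = dgl.mean_nodes(g, 'h')  # graph-level mean readout
80 |             return self.dense(hg)
81 | 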
57 | # The Execution Pipeline
58 |
59 | 1. Obtain the APKs using the given SHA256 hashes from [AndroZoo](https://androzoo.uni.lu/) (a download sketch is given in the Datasets section above)
60 | 2. Build the container (either Singularity or Docker), and get into its shell
61 | 3. Run [`scripts/process_dataset.py`](scripts/process_dataset.py) on the downloaded dataset
63 |
64 |     # Optional: add --override to re-process existing files, or --dry for a dry run
65 |     python process_dataset.py \
66 |         --source-dir <source_dir> \
67 |         --dest-dir <dest_dir>
69 |
70 | 4. Train the model! For configuration, refer to the section below.
71 |
72 | python train_model.py
73 |
74 | # Configuration
75 |
76 | The configuration is managed using [Hydra](https://hydra.cc/). See [`config/conf.yaml`](config/conf.yaml) for the available configuration options.
77 |
78 | Any configuration option can be overridden in the command line. As an example, to change the number of convolution layers to 2, invoke the program as
79 |
80 | python train_model.py model.convolution_count=2
81 |
82 | You can also perform a sweep, for example,
83 |
84 |     python train_model.py --multirun \
85 |         model.convolution_count=0,1,2,3 \
86 |         model.convolution_algorithm=GraphConv,SAGEConv,TAGConv,SGConv,DotGatConv \
87 |         features=degree,method_attributes,method_summary
88 |
89 | to train the model in all possible configurations! 🥳
90 |
91 | # Stack
92 |
93 | - [`androguard`](https://androguard.readthedocs.io/en/latest/) - for FCG extraction and feature assignment
94 | - [`pytorch`](https://pytorch.org/) - for Neural networks
95 | - [`dgl`](https://www.dgl.ai/) - for GCN modules
96 | - [`pytorch-lightning`](https://github.com/PyTorchLightning/pytorch-lightning) - for organization and pipeline 💖
97 | - [`hydra`](https://hydra.cc/) - for configuring experiments
98 | - [`wandb`](https://wandb.ai/) - for tracking experiments 🔥
99 |
100 | # Cite as
101 |
102 | The research paper corresponding to this work is available at [IEEE Xplore](https://ieeexplore.ieee.org/document/9478141). If you find this work helpful and use it, please cite it as
103 |
104 | @INPROCEEDINGS{9478141,
105 | author={V, Vinayaka K and D, Jaidhar C},
106 | booktitle={2021 2nd International Conference on Secure Cyber Computing and Communications (ICSCCC)},
107 | title={Android Malware Detection using Function Call Graph with Graph Convolutional Networks},
108 | year={2021},
109 | volume={},
110 | number={},
111 | pages={279-287},
112 | doi={10.1109/ICSCCC51823.2021.9478141}
113 | }
114 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pytorch-lightning==1.1.7
2 | wandb==0.10.18
3 | plotly==4.14.3
4 | scikit-learn==0.24.1
5 | joblib~=1.0.0
6 | hydra-core~=1.0.5
7 | pandas==1.2.1
8 | pygtrie==2.4.2
9 | seaborn==0.11.1
10 | pygraphviz==1.7
11 | dgl-cu110==0.6
--------------------------------------------------------------------------------
/scripts/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinayakakv/android-malware-detection/1aab288ec599a3958982866ce989311a96cbffd9/scripts/__init__.py
--------------------------------------------------------------------------------
/scripts/plot_callgraph.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 |
4 | import matplotlib.pyplot as plt
5 | import networkx as nx
6 | from androguard.misc import AnalyzeAPK
7 |
8 | plt.figure(figsize=(10, 5))
9 |
10 |
11 | def plot_call_graph(cg: nx.classes.multidigraph.MultiDiGraph):
12 | layout = nx.drawing.nx_agraph.graphviz_layout(cg, prog='dot')
13 | labels, cm = {}, []
14 | legend = ''
15 | node_list = []
16 | for i, node in enumerate(nx.topological_sort(cg)):
17 | node_list.append(node)
18 | labels[node] = i
19 | cm.append('yellow' if node.is_external() else 'blue')
20 | legend += '%d, \\texttt{%s %s}\n' % (i, node.class_name.replace('$', '\\$'), node.name)
21 | plt.axis('off')
22 | nx.draw_networkx(cg, pos=layout, nodelist=node_list, node_color=cm, labels=labels, alpha=0.6, node_size=500,
23 | font_family='serif')
24 | with open("cg.table", "w") as f:
25 | f.write(legend)
26 | plt.tight_layout()
27 | plt.savefig("cg.pdf", dpi=300, bbox_inches="tight")
28 | plt.show()
29 |
30 |
31 | if __name__ == '__main__':
32 | parser = argparse.ArgumentParser(description='Draw FCG of small APKs')
33 | parser.add_argument(
34 | '-s', '--source-file',
35 | help='The APK file to analyze and draw',
36 | required=True
37 | )
38 | args = parser.parse_args()
39 | if not Path(args.source_file).exists():
40 | raise FileNotFoundError(f"{args.source_file} doesn't exist")
41 | a, d, dx = AnalyzeAPK(args.source_file)
42 | plot_call_graph(dx.get_call_graph())
43 |
--------------------------------------------------------------------------------
/scripts/process_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import json
3 | import multiprocessing
4 | import os
5 | import sys
6 | import traceback
7 | from collections import defaultdict
8 | from pathlib import Path
9 | from typing import Dict, List, Union, Optional
10 |
11 | import dgl
12 | import joblib as J
13 | import networkx as nx
14 | import torch
15 | from androguard.core.analysis.analysis import MethodAnalysis
16 | from androguard.core.api_specific_resources import load_permission_mappings
17 | from androguard.misc import AnalyzeAPK
18 | from pygtrie import StringTrie
19 |
20 | ATTRIBUTES = ['external', 'entrypoint', 'native', 'public', 'static', 'codesize', 'api', 'user']
21 | package_directory = os.path.dirname(os.path.abspath(__file__))
22 |
23 | stats: Dict[str, int] = defaultdict(int)
24 |
25 |
26 | def memoize(function):
27 | """
28 | Alternative to @lru_cache which could not be pickled in ray
29 | :param function: Function to be cached
30 | :return: Wrapped function
31 | """
32 | memo = {}
33 |
34 | def wrapper(*args):
35 | if args in memo:
36 | return memo[args]
37 | else:
38 | rv = function(*args)
39 | memo[args] = rv
40 | return rv
41 |
42 | return wrapper
43 |
44 |
45 | class FeatureExtractors:
46 | NUM_PERMISSION_GROUPS = 20
47 | NUM_API_PACKAGES = 226
48 | NUM_OPCODE_MAPPINGS = 21
49 |
50 | @staticmethod
51 | def _get_opcode_mapping() -> Dict[str, int]:
52 | """
53 | Group opcodes and assign them an ID
54 | :return: Mapping from opcode group name to their ID
55 | """
56 | mapping = {x: i for i, x in enumerate(['nop', 'mov', 'return',
57 | 'const', 'monitor', 'check-cast', 'instanceof', 'new',
58 | 'fill', 'throw', 'goto/switch', 'cmp', 'if', 'unused',
59 | 'arrayop', 'instanceop', 'staticop', 'invoke',
60 | 'unaryop', 'binop', 'inline'])}
61 | mapping['invalid'] = -1
62 | return mapping
63 |
64 | @staticmethod
65 | @memoize
66 | def _get_instruction_type(op_value: int) -> str:
67 | """
68 | Get instruction group name from instruction
69 | :param op_value: Opcode value
70 | :return: String containing ID of :instr:
71 | """
72 | if 0x00 == op_value:
73 | return 'nop'
74 | elif 0x01 <= op_value <= 0x0D:
75 | return 'mov'
76 | elif 0x0E <= op_value <= 0x11:
77 | return 'return'
78 | elif 0x12 <= op_value <= 0x1C:
79 | return 'const'
80 | elif 0x1D <= op_value <= 0x1E:
81 | return 'monitor'
82 | elif 0x1F == op_value:
83 | return 'check-cast'
84 | elif 0x20 == op_value:
85 | return 'instanceof'
86 | elif 0x22 <= op_value <= 0x23:
87 | return 'new'
88 | elif 0x24 <= op_value <= 0x26:
89 | return 'fill'
90 | elif 0x27 == op_value:
91 | return 'throw'
92 | elif 0x28 <= op_value <= 0x2C:
93 | return 'goto/switch'
94 | elif 0x2D <= op_value <= 0x31:
95 | return 'cmp'
96 | elif 0x32 <= op_value <= 0x3D:
97 | return 'if'
98 | elif (0x3E <= op_value <= 0x43) or (op_value == 0x73) or (0x79 <= op_value <= 0x7A) or (
99 | 0xE3 <= op_value <= 0xED):
100 | return 'unused'
101 | elif (0x44 <= op_value <= 0x51) or (op_value == 0x21):
102 | return 'arrayop'
103 | elif (0x52 <= op_value <= 0x5F) or (0xF2 <= op_value <= 0xF7):
104 | return 'instanceop'
105 | elif 0x60 <= op_value <= 0x6D:
106 | return 'staticop'
107 | elif (0x6E <= op_value <= 0x72) or (0x74 <= op_value <= 0x78) or (0xF0 == op_value) or (
108 | 0xF8 <= op_value <= 0xFB):
109 | return 'invoke'
110 | elif 0x7B <= op_value <= 0x8F:
111 | return 'unaryop'
112 | elif 0x90 <= op_value <= 0xE2:
113 | return 'binop'
114 | elif 0xEE == op_value:
115 | return 'inline'
116 | else:
117 | return 'invalid'
118 |
119 | @staticmethod
120 | def _mapping_to_bitstring(mapping: List[int], max_len) -> torch.Tensor:
121 | """
122 | Convert opcode mappings to bitstring
123 | :param max_len:
124 |         :param mapping: List of IDs of opcode groups (present in a method)
125 |         :return: Binary tensor of length `len(opcode_mapping)` with value 1 at positions specified by :param mapping:
126 | """
127 | size = torch.Size([1, max_len])
128 | if len(mapping) > 0:
129 | indices = torch.LongTensor([[0, x] for x in mapping]).t()
130 | values = torch.LongTensor([1] * len(mapping))
131 | tensor = torch.sparse.LongTensor(indices, values, size)
132 | else:
133 | tensor = torch.sparse.LongTensor(size)
134 | # Sparse tensor is normal tensor on CPU!
135 | return tensor.to_dense().squeeze()
136 |
137 | @staticmethod
138 | def _get_api_trie() -> StringTrie:
139 |         with open(Path(package_directory).parent / "metadata" / "api.list") as f:
140 |             apis = f.readlines()
140 | api_list = {x.strip(): i for i, x in enumerate(apis)}
141 | api_trie = StringTrie(separator='.')
142 | for k, v in api_list.items():
143 | api_trie[k] = v
144 | return api_trie
145 |
146 | @staticmethod
147 | @memoize
148 | def get_api_features(api: MethodAnalysis) -> Optional[torch.Tensor]:
149 | if not api.is_external():
150 | return None
151 | api_trie = FeatureExtractors._get_api_trie()
152 | name = str(api.class_name)[1:-1].replace('/', '.')
153 | _, index = api_trie.longest_prefix(name)
154 | if index is None:
155 | indices = []
156 | else:
157 | indices = [index]
158 | feature_vector = FeatureExtractors._mapping_to_bitstring(indices, FeatureExtractors.NUM_API_PACKAGES)
159 | return feature_vector
160 |
161 | @staticmethod
162 | @memoize
163 | def get_user_features(user: MethodAnalysis) -> Optional[torch.Tensor]:
164 | if user.is_external():
165 | return None
166 | opcode_mapping = FeatureExtractors._get_opcode_mapping()
167 | opcode_groups = set()
168 | for instr in user.get_method().get_instructions():
169 | instruction_type = FeatureExtractors._get_instruction_type(instr.get_op_value())
170 | instruction_id = opcode_mapping[instruction_type]
171 | if instruction_id >= 0:
172 | opcode_groups.add(instruction_id)
173 | # 1 subtraction for 'invalid' opcode group
174 | feature_vector = FeatureExtractors._mapping_to_bitstring(list(opcode_groups), len(opcode_mapping) - 1)
175 | return torch.LongTensor(feature_vector)
176 |
177 |
178 | def process(source_file: Path, dest_dir: Path):
179 | try:
180 | file_name = source_file.stem
181 | _, _, dx = AnalyzeAPK(source_file)
182 | cg = dx.get_call_graph()
183 | mappings = {}
184 | for node in cg.nodes():
185 | features = {
186 | "api": torch.zeros(FeatureExtractors.NUM_API_PACKAGES),
187 | "user": torch.zeros(FeatureExtractors.NUM_OPCODE_MAPPINGS)
188 | }
189 | if node.is_external():
190 | features["api"] = FeatureExtractors.get_api_features(node)
191 | else:
192 | features["user"] = FeatureExtractors.get_user_features(node)
193 | mappings[node] = features
194 | nx.set_node_attributes(cg, mappings)
195 | cg = nx.convert_node_labels_to_integers(cg)
196 | dg = dgl.from_networkx(cg, node_attrs=ATTRIBUTES)
197 | dest_dir = dest_dir / f'{file_name}.fcg'
198 | dgl.data.utils.save_graphs(str(dest_dir), [dg])
199 | print(f"Processed {source_file}")
200 |     except Exception:  # log and skip APKs that fail to process
201 | print(f"Error while processing {source_file}")
202 | traceback.print_exception(*sys.exc_info())
203 | return
204 |
205 |
206 | if __name__ == '__main__':
207 | parser = argparse.ArgumentParser(description='Preprocess APK Dataset into Graphs')
208 | parser.add_argument(
209 | '-s', '--source-dir',
210 | help='The directory containing apks',
211 | required=True
212 | )
213 | parser.add_argument(
214 | '-d', '--dest-dir',
215 | help='The directory to store processed graphs',
216 | required=True
217 | )
218 | parser.add_argument(
219 | '--override',
220 | help='Override existing processed files',
221 | action='store_true'
222 | )
223 | parser.add_argument(
224 | '--dry',
225 | help='Run without actual processing',
226 | action='store_true'
227 | )
228 |     parser.add_argument(
229 |         '--n-jobs',
230 |         type=int,
231 |         default=multiprocessing.cpu_count(),
232 |         help='Number of jobs to be used for processing'
233 |     )
233 | parser.add_argument(
234 | '--limit',
235 | help='Run for n apks',
236 | default=-1
237 | )
238 | args = parser.parse_args()
239 | source_dir = Path(args.source_dir)
240 | if not source_dir.exists():
241 | raise FileNotFoundError(f'{source_dir} not found')
242 | dest_dir = Path(args.dest_dir)
243 | if not dest_dir.exists():
244 | raise FileNotFoundError(f'{dest_dir} not found')
245 | n_jobs = args.n_jobs
246 | if n_jobs < 2:
247 | print(f"n_jobs={n_jobs} is too less. Switching to number of CPUs in this machine instead")
248 | n_jobs = multiprocessing.cpu_count()
249 | files = [x for x in source_dir.iterdir() if x.is_file()]
250 | source_files = set([x.stem for x in files])
251 |     # compare stems, so that a processed 'X.fcg' matches its source 'X.apk'
252 |     dest_files = set([x.stem for x in dest_dir.iterdir() if x.is_file()])
252 | unprocessed = [source_dir / f'{x}.apk' for x in source_files - dest_files]
253 | print(f"Only {len(unprocessed)} out of {len(source_files)} remain to be processed")
254 | if args.override:
255 | print(f"--override specified. Ignoring {len(source_files) - len(unprocessed)} processed files")
256 | unprocessed = [source_dir / f'{x}.apk' for x in source_files]
257 | print(f"Starting dataset processing with {n_jobs} Jobs")
258 | limit = int(args.limit)
259 | if limit != -1:
260 | print(f"Limiting dataset processing to {limit} apks.")
261 | unprocessed = unprocessed[:limit]
262 | if not args.dry:
263 | J.Parallel(n_jobs=n_jobs)(J.delayed(process)(x, dest_dir) for x in unprocessed)
264 | print("DONE")
265 |
--------------------------------------------------------------------------------
/scripts/split_dataset.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing
3 | from pathlib import Path
4 |
5 | import dgl
6 | import joblib as J
7 | import numpy as np
8 | import pandas as pd
9 |
10 |
11 | def extract_stats(file: str):
12 | file = Path(file)
13 | if not file.exists():
14 | raise ValueError(f"{file} doesn't exist")
15 | result = {}
16 | graphs, labels = dgl.data.utils.load_graphs(str(file))
17 | graph: dgl.DGLGraph = graphs[0]
18 |     # benign samples are named like 'Benigh0310.apk' in this dataset; 'Benig' matches that spelling
19 |     result['label'] = 'Benign' if 'Benig' in file.stem else 'Malware'
19 | result['file_name'] = str(file)
20 | result['num_nodes'] = graph.num_nodes()
21 | result['num_edges'] = graph.num_edges()
22 | return result
23 |
24 |
25 | def save_list(dataframe, file_name):
26 | with open(file_name, 'a') as target:
27 | for file in dataframe['file_name']:
28 |             target.write(f'{file.split(".")[0]}\n')
29 |
30 |
31 | def get_dataset(df: pd.DataFrame, test_ratio: float, log_dir: Path):
32 | assert 0 <= test_ratio < 1, "Ratio must be within 0 and 1"
33 | q1 = df['num_nodes'].quantile(0.25)
34 | q3 = df['num_nodes'].quantile(0.75)
35 | iqr = q3 - q1
36 | print(f"Initial range {df['num_nodes'].min(), df['num_nodes'].max()}")
37 | print(f"IQR num_nodes = {iqr}")
38 | df = df.query(f'{q1 - iqr} <= num_nodes <= {q3 + iqr}')
39 | print(f"Final range {df['num_nodes'].min(), df['num_nodes'].max()}")
40 | bins = np.arange(0, df['num_nodes'].max(), 500)
41 | ben_hist, _ = np.histogram(df.query('label == "Benign"')['num_nodes'], bins=bins)
42 | mal_hist, _ = np.histogram(df.query('label != "Benign"')['num_nodes'], bins=bins)
43 | combined = np.concatenate([ben_hist[:, np.newaxis], mal_hist[:, np.newaxis]], axis=1)
44 | np.savetxt(
45 | log_dir / 'histogram.list',
46 | combined
47 | )
48 | final_sizes = [(x, x) for x in np.min(combined, axis=1)]
49 | final_train = []
50 | final_test = []
51 | for i, (ben_size, mal_size) in enumerate(final_sizes):
52 | low, high = bins[i], bins[i + 1]
53 | benign_samples = df.query(f'label == "Benign" and {low} <= num_nodes < {high}')
54 | malware_samples = df.query(f'label == "Malware" and {low} <= num_nodes < {high}')
55 | assert len(benign_samples) >= ben_size and len(malware_samples) >= mal_size, "Mismatch"
56 | benign_samples = benign_samples.sample(ben_size)
57 | malware_samples = malware_samples.sample(mal_size)
58 | if test_ratio > 0:
59 | benign_samples, benign_test_samples = np.split(benign_samples,
60 | [round((1 - test_ratio) * len(benign_samples))])
61 | malware_samples, malware_test_samples = np.split(malware_samples,
62 | [round((1 - test_ratio) * len(malware_samples))])
63 | final_test.append(benign_test_samples)
64 | final_test.append(malware_test_samples)
65 | final_train.append(benign_samples)
66 | final_train.append(malware_samples)
67 | final_train = pd.concat(final_train)
68 | if final_test:
69 | final_test = pd.concat(final_test)
70 | return final_train, final_test
71 |
72 |
73 | if __name__ == '__main__':
74 | parser = argparse.ArgumentParser(
75 | description='Split the input dataset into train and test partitions (80%, 20%) based on bin equalization'
76 | )
77 | parser.add_argument(
78 | '-i', '--input-dirs',
79 | help="List of input paths",
80 | nargs='+',
81 | required=True
82 | )
83 | parser.add_argument(
84 | '-o', '--output-dir',
85 | help="The path to write the result lists to",
86 | required=True
87 | )
88 | parser.add_argument(
89 | '-s', '--strict',
90 | help="If set, program will terminate on error while in loop",
91 | action='store_true',
92 | default=False
93 | )
94 | args = parser.parse_args()
95 | output_dir = Path(args.output_dir)
96 | if not output_dir.exists():
97 | output_dir.mkdir(parents=True)
98 |
99 | input_stats = []
100 | for input_dir in args.input_dirs:
101 | input_dir = Path(input_dir)
102 | if not input_dir.exists():
103 | if args.strict:
104 | raise FileNotFoundError(f"{input_dir} does not exist. Halting")
105 | else:
106 | print(f"{input_dir} does not exist. Skipping...")
107 | continue
108 | stats = J.Parallel(n_jobs=multiprocessing.cpu_count())(
109 | J.delayed(extract_stats)(x) for x in input_dir.glob("*.fcg")
110 | )
111 | input_stats.append(pd.DataFrame.from_records(stats))
112 | input_stats = pd.concat(input_stats)
113 | zero_nodes = input_stats.query('num_nodes == 0')
114 | if len(zero_nodes) > 0:
115 | print(f"Warning: {len(zero_nodes)} APKs with num_nodes = 0 found. Writing their names to zero_nodes.list")
116 | save_list(zero_nodes, 'zero_nodes.list')
117 | input_stats = input_stats.query('num_nodes != 0')
118 | train_list, test_list = get_dataset(input_stats, 0.2, output_dir)
119 |     save_list(train_list, output_dir / 'train.list')
120 |     save_list(test_list, output_dir / 'test.list')
121 |
--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 |
4 | import hydra
5 | import wandb
6 | import torch
7 | from omegaconf import DictConfig
8 | from pytorch_lightning import Trainer
9 | from pytorch_lightning.callbacks import ModelCheckpoint
10 | from pytorch_lightning.loggers import WandbLogger
11 |
12 | from core.callbacks import InputMonitor, BestModelTagger, MetricsLogger
13 | from core.data_module import MalwareDataModule
14 | from core.model import MalwareDetector
15 |
16 |
17 | @hydra.main(config_path="config", config_name="conf")
18 | def train_model(cfg: DictConfig) -> None:
19 | data_module = MalwareDataModule(**cfg['data'])
20 |
21 | model = MalwareDetector(**cfg['model'])
22 |
23 | callbacks = [ModelCheckpoint(
24 | dirpath=os.getcwd(),
25 | filename=str('{epoch:02d}-{val_loss:.2f}.pt'),
26 | monitor='val_loss',
27 | mode='min',
28 | save_last=True,
29 | save_top_k=-1
30 | )]
31 |
32 | trainer_kwargs = dict(cfg['trainer'])
33 | force_retrain = cfg.get('force_retrain', False)
34 | if Path('last.ckpt').exists() and not force_retrain:
35 | trainer_kwargs['resume_from_checkpoint'] = 'last.ckpt'
36 |
37 |     logger = None
38 |     if 'logger' in cfg:
39 |         # We use the WandB logger
40 |         logger = WandbLogger(
40 | **cfg['logger']['args'],
41 |             tags=['testing' if 'testing' in cfg else 'training']
42 | )
43 | if "testing" in cfg:
44 | logger.experiment.summary["test_type"] = cfg["testing"]
45 | logger.watch(model)
46 | logger.log_hyperparams(cfg['logger']['hparams'])
47 | if logger:
48 | trainer_kwargs['logger'] = logger
49 | callbacks.append(InputMonitor())
50 | callbacks.append(BestModelTagger(monitor='val_loss', mode='min'))
51 | callbacks.append(MetricsLogger(stages='all'))
52 |
53 | trainer = Trainer(
54 | callbacks=callbacks,
55 | **trainer_kwargs
56 | )
57 | testing = cfg.get('testing', '')
58 | if not testing:
59 | trainer.fit(model, datamodule=data_module)
60 | else:
61 | if testing not in ['last', 'best'] and 'epoch' not in testing:
62 | raise ValueError(f"testing must be one of 'best' or 'last' or 'epoch=N'. It is {testing}")
63 | elif 'epoch' in testing:
64 | # epoch in testing
65 | epoch = testing.split('@')[1]
66 | checkpoints = list(Path(os.getcwd()).glob(f"epoch={epoch}*.ckpt"))
67 |             if len(checkpoints) == 0:
68 |                 raise FileNotFoundError(f"Checkpoint at epoch = {epoch} not found.")
69 |             assert len(checkpoints) == 1, f"Multiple checkpoints corresponding to epoch = {epoch} found."
70 | ckpt_path = checkpoints[0]
71 | else:
72 | if not Path('last.ckpt').exists():
73 | raise FileNotFoundError("No last.ckpt exists. Could not do any testing.")
74 | if testing == 'last':
75 | ckpt_path = 'last.ckpt'
76 | else:
77 | # best
78 | last_checkpoint = torch.load('last.ckpt')
79 | ckpt_path = last_checkpoint['callbacks'][ModelCheckpoint]['best_model_path']
80 | print(f"Using checkpoint {ckpt_path} for testing.")
81 | model = MalwareDetector.load_from_checkpoint(ckpt_path, **cfg['model'])
82 | trainer.test(model, datamodule=data_module, verbose=True)
83 | wandb.finish()
84 |
85 |
86 | if __name__ == '__main__':
87 | train_model()
88 |
--------------------------------------------------------------------------------