├── .dockerignore ├── .gitignore ├── Dockerfile ├── LICENSE ├── Plots.ipynb ├── assets └── architecture.svg ├── config ├── conf.yaml ├── features │ ├── degree.yaml │ ├── method_attributes.yaml │ └── method_summary.yaml └── logger │ └── wandb.yaml ├── core ├── __init__.py ├── callbacks.py ├── data_module.py ├── dataset.py ├── model.py └── utils.py ├── data ├── README.md ├── test_new.sha256 ├── test_old.sha256 ├── train_new.sha256 └── train_old.sha256 ├── malware-learning.def ├── metadata └── api.list ├── notebooks ├── 0-Preliminary Analysis.ipynb ├── 1-AGfeatures.ipynb ├── 2-GFeatures.ipynb ├── 3-APIFeatures.ipynb ├── 4-APIFeatures-Binary.ipynb └── 5-AGfeatures-Binary.ipynb ├── readme.md ├── requirements.txt ├── scripts ├── __init__.py ├── plot_callgraph.py ├── process_dataset.py └── split_dataset.py └── train_model.py /.dockerignore: -------------------------------------------------------------------------------- 1 | data/ 2 | temp/ 3 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | data/ 2 | models/ 3 | notebooks/lightning_logs 4 | .idea 5 | output*/ 6 | metadata/ 7 | # Dataset list 8 | *.list 9 | # Pickle 10 | *.pkl 11 | # Java files 12 | *.jar 13 | # Singularity Image 14 | *.sif 15 | # Nano saved 16 | *.save 17 | # Byte-compiled / optimized / DLL files 18 | __pycache__/ 19 | *.py[cod] 20 | *$py.class 21 | 22 | # C extensions 23 | *.so 24 | 25 | # Distribution / packaging 26 | .Python 27 | build/ 28 | develop-eggs/ 29 | dist/ 30 | downloads/ 31 | eggs/ 32 | .eggs/ 33 | lib/ 34 | lib64/ 35 | parts/ 36 | sdist/ 37 | var/ 38 | wheels/ 39 | pip-wheel-metadata/ 40 | share/python-wheels/ 41 | *.egg-info/ 42 | .installed.cfg 43 | *.egg 44 | MANIFEST 45 | 46 | # PyInstaller 47 | # Usually these files are written by a python script from a template 48 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 49 | *.manifest 50 | *.spec 51 | 52 | # Installer logs 53 | pip-log.txt 54 | pip-delete-this-directory.txt 55 | 56 | # Unit test / coverage reports 57 | htmlcov/ 58 | .tox/ 59 | .nox/ 60 | .coverage 61 | .coverage.* 62 | .cache 63 | nosetests.xml 64 | coverage.xml 65 | *.cover 66 | *.py,cover 67 | .hypothesis/ 68 | .pytest_cache/ 69 | 70 | # Translations 71 | *.mo 72 | *.pot 73 | 74 | # Django stuff: 75 | *.log 76 | local_settings.py 77 | db.sqlite3 78 | db.sqlite3-journal 79 | 80 | # Flask stuff: 81 | instance/ 82 | .webassets-cache 83 | 84 | # Scrapy stuff: 85 | .scrapy 86 | 87 | # Sphinx documentation 88 | docs/_build/ 89 | 90 | # PyBuilder 91 | target/ 92 | 93 | # Jupyter Notebook 94 | .ipynb_checkpoints 95 | *ipynb 96 | 97 | # IPython 98 | profile_default/ 99 | ipython_config.py 100 | 101 | # pyenv 102 | .python-version 103 | 104 | # pipenv 105 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 106 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 107 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 108 | # install all needed dependencies. 109 | #Pipfile.lock 110 | 111 | # PEP 582; used by e.g. 
github.com/David-OConnor/pyflow 112 | __pypackages__/ 113 | 114 | # Celery stuff 115 | celerybeat-schedule 116 | celerybeat.pid 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # Environments 122 | .env 123 | .venv 124 | env/ 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | .dmypy.json 143 | dmypy.json 144 | 145 | # Pyre type checker 146 | .pyre/ -------------------------------------------------------------------------------- /Dockerfile: -------------------------------------------------------------------------------- 1 | FROM pytorch/pytorch:1.7.0-cuda11.0-cudnn8-devel 2 | COPY requirements.txt /mnt/ 3 | RUN apt-get update && apt-get install -y git graphviz graphviz-dev && rm -rf /var/lib/apt/lists/* 4 | RUN pip install -r /mnt/requirements.txt 5 | RUN git clone https://github.com/androguard/androguard.git && cd androguard && python setup.py install 6 | VOLUME /model 7 | WORKDIR /model -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | GNU GENERAL PUBLIC LICENSE 2 | Version 3, 29 June 2007 3 | 4 | Copyright (C) 2007 Free Software Foundation, Inc. 5 | Everyone is permitted to copy and distribute verbatim copies 6 | of this license document, but changing it is not allowed. 7 | 8 | Preamble 9 | 10 | The GNU General Public License is a free, copyleft license for 11 | software and other kinds of works. 12 | 13 | The licenses for most software and other practical works are designed 14 | to take away your freedom to share and change the works. By contrast, 15 | the GNU General Public License is intended to guarantee your freedom to 16 | share and change all versions of a program--to make sure it remains free 17 | software for all its users. We, the Free Software Foundation, use the 18 | GNU General Public License for most of our software; it applies also to 19 | any other work released this way by its authors. You can apply it to 20 | your programs, too. 21 | 22 | When we speak of free software, we are referring to freedom, not 23 | price. Our General Public Licenses are designed to make sure that you 24 | have the freedom to distribute copies of free software (and charge for 25 | them if you wish), that you receive source code or can get it if you 26 | want it, that you can change the software or use pieces of it in new 27 | free programs, and that you know you can do these things. 28 | 29 | To protect your rights, we need to prevent others from denying you 30 | these rights or asking you to surrender the rights. Therefore, you have 31 | certain responsibilities if you distribute copies of the software, or if 32 | you modify it: responsibilities to respect the freedom of others. 33 | 34 | For example, if you distribute copies of such a program, whether 35 | gratis or for a fee, you must pass on to the recipients the same 36 | freedoms that you received. You must make sure that they, too, receive 37 | or can get the source code. And you must show them these terms so they 38 | know their rights. 39 | 40 | Developers that use the GNU GPL protect your rights with two steps: 41 | (1) assert copyright on the software, and (2) offer you this License 42 | giving you legal permission to copy, distribute and/or modify it. 
43 | 44 | For the developers' and authors' protection, the GPL clearly explains 45 | that there is no warranty for this free software. For both users' and 46 | authors' sake, the GPL requires that modified versions be marked as 47 | changed, so that their problems will not be attributed erroneously to 48 | authors of previous versions. 49 | 50 | Some devices are designed to deny users access to install or run 51 | modified versions of the software inside them, although the manufacturer 52 | can do so. This is fundamentally incompatible with the aim of 53 | protecting users' freedom to change the software. The systematic 54 | pattern of such abuse occurs in the area of products for individuals to 55 | use, which is precisely where it is most unacceptable. Therefore, we 56 | have designed this version of the GPL to prohibit the practice for those 57 | products. If such problems arise substantially in other domains, we 58 | stand ready to extend this provision to those domains in future versions 59 | of the GPL, as needed to protect the freedom of users. 60 | 61 | Finally, every program is threatened constantly by software patents. 62 | States should not allow patents to restrict development and use of 63 | software on general-purpose computers, but in those that do, we wish to 64 | avoid the special danger that patents applied to a free program could 65 | make it effectively proprietary. To prevent this, the GPL assures that 66 | patents cannot be used to render the program non-free. 67 | 68 | The precise terms and conditions for copying, distribution and 69 | modification follow. 70 | 71 | TERMS AND CONDITIONS 72 | 73 | 0. Definitions. 74 | 75 | "This License" refers to version 3 of the GNU General Public License. 76 | 77 | "Copyright" also means copyright-like laws that apply to other kinds of 78 | works, such as semiconductor masks. 79 | 80 | "The Program" refers to any copyrightable work licensed under this 81 | License. Each licensee is addressed as "you". "Licensees" and 82 | "recipients" may be individuals or organizations. 83 | 84 | To "modify" a work means to copy from or adapt all or part of the work 85 | in a fashion requiring copyright permission, other than the making of an 86 | exact copy. The resulting work is called a "modified version" of the 87 | earlier work or a work "based on" the earlier work. 88 | 89 | A "covered work" means either the unmodified Program or a work based 90 | on the Program. 91 | 92 | To "propagate" a work means to do anything with it that, without 93 | permission, would make you directly or secondarily liable for 94 | infringement under applicable copyright law, except executing it on a 95 | computer or modifying a private copy. Propagation includes copying, 96 | distribution (with or without modification), making available to the 97 | public, and in some countries other activities as well. 98 | 99 | To "convey" a work means any kind of propagation that enables other 100 | parties to make or receive copies. Mere interaction with a user through 101 | a computer network, with no transfer of a copy, is not conveying. 102 | 103 | An interactive user interface displays "Appropriate Legal Notices" 104 | to the extent that it includes a convenient and prominently visible 105 | feature that (1) displays an appropriate copyright notice, and (2) 106 | tells the user that there is no warranty for the work (except to the 107 | extent that warranties are provided), that licensees may convey the 108 | work under this License, and how to view a copy of this License. 
If 109 | the interface presents a list of user commands or options, such as a 110 | menu, a prominent item in the list meets this criterion. 111 | 112 | 1. Source Code. 113 | 114 | The "source code" for a work means the preferred form of the work 115 | for making modifications to it. "Object code" means any non-source 116 | form of a work. 117 | 118 | A "Standard Interface" means an interface that either is an official 119 | standard defined by a recognized standards body, or, in the case of 120 | interfaces specified for a particular programming language, one that 121 | is widely used among developers working in that language. 122 | 123 | The "System Libraries" of an executable work include anything, other 124 | than the work as a whole, that (a) is included in the normal form of 125 | packaging a Major Component, but which is not part of that Major 126 | Component, and (b) serves only to enable use of the work with that 127 | Major Component, or to implement a Standard Interface for which an 128 | implementation is available to the public in source code form. A 129 | "Major Component", in this context, means a major essential component 130 | (kernel, window system, and so on) of the specific operating system 131 | (if any) on which the executable work runs, or a compiler used to 132 | produce the work, or an object code interpreter used to run it. 133 | 134 | The "Corresponding Source" for a work in object code form means all 135 | the source code needed to generate, install, and (for an executable 136 | work) run the object code and to modify the work, including scripts to 137 | control those activities. However, it does not include the work's 138 | System Libraries, or general-purpose tools or generally available free 139 | programs which are used unmodified in performing those activities but 140 | which are not part of the work. For example, Corresponding Source 141 | includes interface definition files associated with source files for 142 | the work, and the source code for shared libraries and dynamically 143 | linked subprograms that the work is specifically designed to require, 144 | such as by intimate data communication or control flow between those 145 | subprograms and other parts of the work. 146 | 147 | The Corresponding Source need not include anything that users 148 | can regenerate automatically from other parts of the Corresponding 149 | Source. 150 | 151 | The Corresponding Source for a work in source code form is that 152 | same work. 153 | 154 | 2. Basic Permissions. 155 | 156 | All rights granted under this License are granted for the term of 157 | copyright on the Program, and are irrevocable provided the stated 158 | conditions are met. This License explicitly affirms your unlimited 159 | permission to run the unmodified Program. The output from running a 160 | covered work is covered by this License only if the output, given its 161 | content, constitutes a covered work. This License acknowledges your 162 | rights of fair use or other equivalent, as provided by copyright law. 163 | 164 | You may make, run and propagate covered works that you do not 165 | convey, without conditions so long as your license otherwise remains 166 | in force. You may convey covered works to others for the sole purpose 167 | of having them make modifications exclusively for you, or provide you 168 | with facilities for running those works, provided that you comply with 169 | the terms of this License in conveying all material for which you do 170 | not control copyright. 
Those thus making or running the covered works 171 | for you must do so exclusively on your behalf, under your direction 172 | and control, on terms that prohibit them from making any copies of 173 | your copyrighted material outside their relationship with you. 174 | 175 | Conveying under any other circumstances is permitted solely under 176 | the conditions stated below. Sublicensing is not allowed; section 10 177 | makes it unnecessary. 178 | 179 | 3. Protecting Users' Legal Rights From Anti-Circumvention Law. 180 | 181 | No covered work shall be deemed part of an effective technological 182 | measure under any applicable law fulfilling obligations under article 183 | 11 of the WIPO copyright treaty adopted on 20 December 1996, or 184 | similar laws prohibiting or restricting circumvention of such 185 | measures. 186 | 187 | When you convey a covered work, you waive any legal power to forbid 188 | circumvention of technological measures to the extent such circumvention 189 | is effected by exercising rights under this License with respect to 190 | the covered work, and you disclaim any intention to limit operation or 191 | modification of the work as a means of enforcing, against the work's 192 | users, your or third parties' legal rights to forbid circumvention of 193 | technological measures. 194 | 195 | 4. Conveying Verbatim Copies. 196 | 197 | You may convey verbatim copies of the Program's source code as you 198 | receive it, in any medium, provided that you conspicuously and 199 | appropriately publish on each copy an appropriate copyright notice; 200 | keep intact all notices stating that this License and any 201 | non-permissive terms added in accord with section 7 apply to the code; 202 | keep intact all notices of the absence of any warranty; and give all 203 | recipients a copy of this License along with the Program. 204 | 205 | You may charge any price or no price for each copy that you convey, 206 | and you may offer support or warranty protection for a fee. 207 | 208 | 5. Conveying Modified Source Versions. 209 | 210 | You may convey a work based on the Program, or the modifications to 211 | produce it from the Program, in the form of source code under the 212 | terms of section 4, provided that you also meet all of these conditions: 213 | 214 | a) The work must carry prominent notices stating that you modified 215 | it, and giving a relevant date. 216 | 217 | b) The work must carry prominent notices stating that it is 218 | released under this License and any conditions added under section 219 | 7. This requirement modifies the requirement in section 4 to 220 | "keep intact all notices". 221 | 222 | c) You must license the entire work, as a whole, under this 223 | License to anyone who comes into possession of a copy. This 224 | License will therefore apply, along with any applicable section 7 225 | additional terms, to the whole of the work, and all its parts, 226 | regardless of how they are packaged. This License gives no 227 | permission to license the work in any other way, but it does not 228 | invalidate such permission if you have separately received it. 229 | 230 | d) If the work has interactive user interfaces, each must display 231 | Appropriate Legal Notices; however, if the Program has interactive 232 | interfaces that do not display Appropriate Legal Notices, your 233 | work need not make them do so. 
234 | 235 | A compilation of a covered work with other separate and independent 236 | works, which are not by their nature extensions of the covered work, 237 | and which are not combined with it such as to form a larger program, 238 | in or on a volume of a storage or distribution medium, is called an 239 | "aggregate" if the compilation and its resulting copyright are not 240 | used to limit the access or legal rights of the compilation's users 241 | beyond what the individual works permit. Inclusion of a covered work 242 | in an aggregate does not cause this License to apply to the other 243 | parts of the aggregate. 244 | 245 | 6. Conveying Non-Source Forms. 246 | 247 | You may convey a covered work in object code form under the terms 248 | of sections 4 and 5, provided that you also convey the 249 | machine-readable Corresponding Source under the terms of this License, 250 | in one of these ways: 251 | 252 | a) Convey the object code in, or embodied in, a physical product 253 | (including a physical distribution medium), accompanied by the 254 | Corresponding Source fixed on a durable physical medium 255 | customarily used for software interchange. 256 | 257 | b) Convey the object code in, or embodied in, a physical product 258 | (including a physical distribution medium), accompanied by a 259 | written offer, valid for at least three years and valid for as 260 | long as you offer spare parts or customer support for that product 261 | model, to give anyone who possesses the object code either (1) a 262 | copy of the Corresponding Source for all the software in the 263 | product that is covered by this License, on a durable physical 264 | medium customarily used for software interchange, for a price no 265 | more than your reasonable cost of physically performing this 266 | conveying of source, or (2) access to copy the 267 | Corresponding Source from a network server at no charge. 268 | 269 | c) Convey individual copies of the object code with a copy of the 270 | written offer to provide the Corresponding Source. This 271 | alternative is allowed only occasionally and noncommercially, and 272 | only if you received the object code with such an offer, in accord 273 | with subsection 6b. 274 | 275 | d) Convey the object code by offering access from a designated 276 | place (gratis or for a charge), and offer equivalent access to the 277 | Corresponding Source in the same way through the same place at no 278 | further charge. You need not require recipients to copy the 279 | Corresponding Source along with the object code. If the place to 280 | copy the object code is a network server, the Corresponding Source 281 | may be on a different server (operated by you or a third party) 282 | that supports equivalent copying facilities, provided you maintain 283 | clear directions next to the object code saying where to find the 284 | Corresponding Source. Regardless of what server hosts the 285 | Corresponding Source, you remain obligated to ensure that it is 286 | available for as long as needed to satisfy these requirements. 287 | 288 | e) Convey the object code using peer-to-peer transmission, provided 289 | you inform other peers where the object code and Corresponding 290 | Source of the work are being offered to the general public at no 291 | charge under subsection 6d. 292 | 293 | A separable portion of the object code, whose source code is excluded 294 | from the Corresponding Source as a System Library, need not be 295 | included in conveying the object code work. 
296 | 297 | A "User Product" is either (1) a "consumer product", which means any 298 | tangible personal property which is normally used for personal, family, 299 | or household purposes, or (2) anything designed or sold for incorporation 300 | into a dwelling. In determining whether a product is a consumer product, 301 | doubtful cases shall be resolved in favor of coverage. For a particular 302 | product received by a particular user, "normally used" refers to a 303 | typical or common use of that class of product, regardless of the status 304 | of the particular user or of the way in which the particular user 305 | actually uses, or expects or is expected to use, the product. A product 306 | is a consumer product regardless of whether the product has substantial 307 | commercial, industrial or non-consumer uses, unless such uses represent 308 | the only significant mode of use of the product. 309 | 310 | "Installation Information" for a User Product means any methods, 311 | procedures, authorization keys, or other information required to install 312 | and execute modified versions of a covered work in that User Product from 313 | a modified version of its Corresponding Source. The information must 314 | suffice to ensure that the continued functioning of the modified object 315 | code is in no case prevented or interfered with solely because 316 | modification has been made. 317 | 318 | If you convey an object code work under this section in, or with, or 319 | specifically for use in, a User Product, and the conveying occurs as 320 | part of a transaction in which the right of possession and use of the 321 | User Product is transferred to the recipient in perpetuity or for a 322 | fixed term (regardless of how the transaction is characterized), the 323 | Corresponding Source conveyed under this section must be accompanied 324 | by the Installation Information. But this requirement does not apply 325 | if neither you nor any third party retains the ability to install 326 | modified object code on the User Product (for example, the work has 327 | been installed in ROM). 328 | 329 | The requirement to provide Installation Information does not include a 330 | requirement to continue to provide support service, warranty, or updates 331 | for a work that has been modified or installed by the recipient, or for 332 | the User Product in which it has been modified or installed. Access to a 333 | network may be denied when the modification itself materially and 334 | adversely affects the operation of the network or violates the rules and 335 | protocols for communication across the network. 336 | 337 | Corresponding Source conveyed, and Installation Information provided, 338 | in accord with this section must be in a format that is publicly 339 | documented (and with an implementation available to the public in 340 | source code form), and must require no special password or key for 341 | unpacking, reading or copying. 342 | 343 | 7. Additional Terms. 344 | 345 | "Additional permissions" are terms that supplement the terms of this 346 | License by making exceptions from one or more of its conditions. 347 | Additional permissions that are applicable to the entire Program shall 348 | be treated as though they were included in this License, to the extent 349 | that they are valid under applicable law. 
If additional permissions 350 | apply only to part of the Program, that part may be used separately 351 | under those permissions, but the entire Program remains governed by 352 | this License without regard to the additional permissions. 353 | 354 | When you convey a copy of a covered work, you may at your option 355 | remove any additional permissions from that copy, or from any part of 356 | it. (Additional permissions may be written to require their own 357 | removal in certain cases when you modify the work.) You may place 358 | additional permissions on material, added by you to a covered work, 359 | for which you have or can give appropriate copyright permission. 360 | 361 | Notwithstanding any other provision of this License, for material you 362 | add to a covered work, you may (if authorized by the copyright holders of 363 | that material) supplement the terms of this License with terms: 364 | 365 | a) Disclaiming warranty or limiting liability differently from the 366 | terms of sections 15 and 16 of this License; or 367 | 368 | b) Requiring preservation of specified reasonable legal notices or 369 | author attributions in that material or in the Appropriate Legal 370 | Notices displayed by works containing it; or 371 | 372 | c) Prohibiting misrepresentation of the origin of that material, or 373 | requiring that modified versions of such material be marked in 374 | reasonable ways as different from the original version; or 375 | 376 | d) Limiting the use for publicity purposes of names of licensors or 377 | authors of the material; or 378 | 379 | e) Declining to grant rights under trademark law for use of some 380 | trade names, trademarks, or service marks; or 381 | 382 | f) Requiring indemnification of licensors and authors of that 383 | material by anyone who conveys the material (or modified versions of 384 | it) with contractual assumptions of liability to the recipient, for 385 | any liability that these contractual assumptions directly impose on 386 | those licensors and authors. 387 | 388 | All other non-permissive additional terms are considered "further 389 | restrictions" within the meaning of section 10. If the Program as you 390 | received it, or any part of it, contains a notice stating that it is 391 | governed by this License along with a term that is a further 392 | restriction, you may remove that term. If a license document contains 393 | a further restriction but permits relicensing or conveying under this 394 | License, you may add to a covered work material governed by the terms 395 | of that license document, provided that the further restriction does 396 | not survive such relicensing or conveying. 397 | 398 | If you add terms to a covered work in accord with this section, you 399 | must place, in the relevant source files, a statement of the 400 | additional terms that apply to those files, or a notice indicating 401 | where to find the applicable terms. 402 | 403 | Additional terms, permissive or non-permissive, may be stated in the 404 | form of a separately written license, or stated as exceptions; 405 | the above requirements apply either way. 406 | 407 | 8. Termination. 408 | 409 | You may not propagate or modify a covered work except as expressly 410 | provided under this License. Any attempt otherwise to propagate or 411 | modify it is void, and will automatically terminate your rights under 412 | this License (including any patent licenses granted under the third 413 | paragraph of section 11). 
414 | 415 | However, if you cease all violation of this License, then your 416 | license from a particular copyright holder is reinstated (a) 417 | provisionally, unless and until the copyright holder explicitly and 418 | finally terminates your license, and (b) permanently, if the copyright 419 | holder fails to notify you of the violation by some reasonable means 420 | prior to 60 days after the cessation. 421 | 422 | Moreover, your license from a particular copyright holder is 423 | reinstated permanently if the copyright holder notifies you of the 424 | violation by some reasonable means, this is the first time you have 425 | received notice of violation of this License (for any work) from that 426 | copyright holder, and you cure the violation prior to 30 days after 427 | your receipt of the notice. 428 | 429 | Termination of your rights under this section does not terminate the 430 | licenses of parties who have received copies or rights from you under 431 | this License. If your rights have been terminated and not permanently 432 | reinstated, you do not qualify to receive new licenses for the same 433 | material under section 10. 434 | 435 | 9. Acceptance Not Required for Having Copies. 436 | 437 | You are not required to accept this License in order to receive or 438 | run a copy of the Program. Ancillary propagation of a covered work 439 | occurring solely as a consequence of using peer-to-peer transmission 440 | to receive a copy likewise does not require acceptance. However, 441 | nothing other than this License grants you permission to propagate or 442 | modify any covered work. These actions infringe copyright if you do 443 | not accept this License. Therefore, by modifying or propagating a 444 | covered work, you indicate your acceptance of this License to do so. 445 | 446 | 10. Automatic Licensing of Downstream Recipients. 447 | 448 | Each time you convey a covered work, the recipient automatically 449 | receives a license from the original licensors, to run, modify and 450 | propagate that work, subject to this License. You are not responsible 451 | for enforcing compliance by third parties with this License. 452 | 453 | An "entity transaction" is a transaction transferring control of an 454 | organization, or substantially all assets of one, or subdividing an 455 | organization, or merging organizations. If propagation of a covered 456 | work results from an entity transaction, each party to that 457 | transaction who receives a copy of the work also receives whatever 458 | licenses to the work the party's predecessor in interest had or could 459 | give under the previous paragraph, plus a right to possession of the 460 | Corresponding Source of the work from the predecessor in interest, if 461 | the predecessor has it or can get it with reasonable efforts. 462 | 463 | You may not impose any further restrictions on the exercise of the 464 | rights granted or affirmed under this License. For example, you may 465 | not impose a license fee, royalty, or other charge for exercise of 466 | rights granted under this License, and you may not initiate litigation 467 | (including a cross-claim or counterclaim in a lawsuit) alleging that 468 | any patent claim is infringed by making, using, selling, offering for 469 | sale, or importing the Program or any portion of it. 470 | 471 | 11. Patents. 472 | 473 | A "contributor" is a copyright holder who authorizes use under this 474 | License of the Program or a work on which the Program is based. 
The 475 | work thus licensed is called the contributor's "contributor version". 476 | 477 | A contributor's "essential patent claims" are all patent claims 478 | owned or controlled by the contributor, whether already acquired or 479 | hereafter acquired, that would be infringed by some manner, permitted 480 | by this License, of making, using, or selling its contributor version, 481 | but do not include claims that would be infringed only as a 482 | consequence of further modification of the contributor version. For 483 | purposes of this definition, "control" includes the right to grant 484 | patent sublicenses in a manner consistent with the requirements of 485 | this License. 486 | 487 | Each contributor grants you a non-exclusive, worldwide, royalty-free 488 | patent license under the contributor's essential patent claims, to 489 | make, use, sell, offer for sale, import and otherwise run, modify and 490 | propagate the contents of its contributor version. 491 | 492 | In the following three paragraphs, a "patent license" is any express 493 | agreement or commitment, however denominated, not to enforce a patent 494 | (such as an express permission to practice a patent or covenant not to 495 | sue for patent infringement). To "grant" such a patent license to a 496 | party means to make such an agreement or commitment not to enforce a 497 | patent against the party. 498 | 499 | If you convey a covered work, knowingly relying on a patent license, 500 | and the Corresponding Source of the work is not available for anyone 501 | to copy, free of charge and under the terms of this License, through a 502 | publicly available network server or other readily accessible means, 503 | then you must either (1) cause the Corresponding Source to be so 504 | available, or (2) arrange to deprive yourself of the benefit of the 505 | patent license for this particular work, or (3) arrange, in a manner 506 | consistent with the requirements of this License, to extend the patent 507 | license to downstream recipients. "Knowingly relying" means you have 508 | actual knowledge that, but for the patent license, your conveying the 509 | covered work in a country, or your recipient's use of the covered work 510 | in a country, would infringe one or more identifiable patents in that 511 | country that you have reason to believe are valid. 512 | 513 | If, pursuant to or in connection with a single transaction or 514 | arrangement, you convey, or propagate by procuring conveyance of, a 515 | covered work, and grant a patent license to some of the parties 516 | receiving the covered work authorizing them to use, propagate, modify 517 | or convey a specific copy of the covered work, then the patent license 518 | you grant is automatically extended to all recipients of the covered 519 | work and works based on it. 520 | 521 | A patent license is "discriminatory" if it does not include within 522 | the scope of its coverage, prohibits the exercise of, or is 523 | conditioned on the non-exercise of one or more of the rights that are 524 | specifically granted under this License. 
You may not convey a covered 525 | work if you are a party to an arrangement with a third party that is 526 | in the business of distributing software, under which you make payment 527 | to the third party based on the extent of your activity of conveying 528 | the work, and under which the third party grants, to any of the 529 | parties who would receive the covered work from you, a discriminatory 530 | patent license (a) in connection with copies of the covered work 531 | conveyed by you (or copies made from those copies), or (b) primarily 532 | for and in connection with specific products or compilations that 533 | contain the covered work, unless you entered into that arrangement, 534 | or that patent license was granted, prior to 28 March 2007. 535 | 536 | Nothing in this License shall be construed as excluding or limiting 537 | any implied license or other defenses to infringement that may 538 | otherwise be available to you under applicable patent law. 539 | 540 | 12. No Surrender of Others' Freedom. 541 | 542 | If conditions are imposed on you (whether by court order, agreement or 543 | otherwise) that contradict the conditions of this License, they do not 544 | excuse you from the conditions of this License. If you cannot convey a 545 | covered work so as to satisfy simultaneously your obligations under this 546 | License and any other pertinent obligations, then as a consequence you may 547 | not convey it at all. For example, if you agree to terms that obligate you 548 | to collect a royalty for further conveying from those to whom you convey 549 | the Program, the only way you could satisfy both those terms and this 550 | License would be to refrain entirely from conveying the Program. 551 | 552 | 13. Use with the GNU Affero General Public License. 553 | 554 | Notwithstanding any other provision of this License, you have 555 | permission to link or combine any covered work with a work licensed 556 | under version 3 of the GNU Affero General Public License into a single 557 | combined work, and to convey the resulting work. The terms of this 558 | License will continue to apply to the part which is the covered work, 559 | but the special requirements of the GNU Affero General Public License, 560 | section 13, concerning interaction through a network will apply to the 561 | combination as such. 562 | 563 | 14. Revised Versions of this License. 564 | 565 | The Free Software Foundation may publish revised and/or new versions of 566 | the GNU General Public License from time to time. Such new versions will 567 | be similar in spirit to the present version, but may differ in detail to 568 | address new problems or concerns. 569 | 570 | Each version is given a distinguishing version number. If the 571 | Program specifies that a certain numbered version of the GNU General 572 | Public License "or any later version" applies to it, you have the 573 | option of following the terms and conditions either of that numbered 574 | version or of any later version published by the Free Software 575 | Foundation. If the Program does not specify a version number of the 576 | GNU General Public License, you may choose any version ever published 577 | by the Free Software Foundation. 578 | 579 | If the Program specifies that a proxy can decide which future 580 | versions of the GNU General Public License can be used, that proxy's 581 | public statement of acceptance of a version permanently authorizes you 582 | to choose that version for the Program. 
583 | 584 | Later license versions may give you additional or different 585 | permissions. However, no additional obligations are imposed on any 586 | author or copyright holder as a result of your choosing to follow a 587 | later version. 588 | 589 | 15. Disclaimer of Warranty. 590 | 591 | THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY 592 | APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT 593 | HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY 594 | OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, 595 | THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR 596 | PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM 597 | IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF 598 | ALL NECESSARY SERVICING, REPAIR OR CORRECTION. 599 | 600 | 16. Limitation of Liability. 601 | 602 | IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING 603 | WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS 604 | THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY 605 | GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE 606 | USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF 607 | DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD 608 | PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), 609 | EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF 610 | SUCH DAMAGES. 611 | 612 | 17. Interpretation of Sections 15 and 16. 613 | 614 | If the disclaimer of warranty and limitation of liability provided 615 | above cannot be given local legal effect according to their terms, 616 | reviewing courts shall apply local law that most closely approximates 617 | an absolute waiver of all civil liability in connection with the 618 | Program, unless a warranty or assumption of liability accompanies a 619 | copy of the Program in return for a fee. 620 | 621 | END OF TERMS AND CONDITIONS 622 | 623 | How to Apply These Terms to Your New Programs 624 | 625 | If you develop a new program, and you want it to be of the greatest 626 | possible use to the public, the best way to achieve this is to make it 627 | free software which everyone can redistribute and change under these terms. 628 | 629 | To do so, attach the following notices to the program. It is safest 630 | to attach them to the start of each source file to most effectively 631 | state the exclusion of warranty; and each file should have at least 632 | the "copyright" line and a pointer to where the full notice is found. 633 | 634 | <one line to give the program's name and a brief idea of what it does.> 635 | Copyright (C) <year> <name of author> 636 | 637 | This program is free software: you can redistribute it and/or modify 638 | it under the terms of the GNU General Public License as published by 639 | the Free Software Foundation, either version 3 of the License, or 640 | (at your option) any later version. 641 | 642 | This program is distributed in the hope that it will be useful, 643 | but WITHOUT ANY WARRANTY; without even the implied warranty of 644 | MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 645 | GNU General Public License for more details. 646 | 647 | You should have received a copy of the GNU General Public License 648 | along with this program. If not, see <https://www.gnu.org/licenses/>. 649 | 650 | Also add information on how to contact you by electronic and paper mail.
651 | 652 | If the program does terminal interaction, make it output a short 653 | notice like this when it starts in an interactive mode: 654 | 655 | <program> Copyright (C) <year> <name of author> 656 | This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. 657 | This is free software, and you are welcome to redistribute it 658 | under certain conditions; type `show c' for details. 659 | 660 | The hypothetical commands `show w' and `show c' should show the appropriate 661 | parts of the General Public License. Of course, your program's commands 662 | might be different; for a GUI interface, you would use an "about box". 663 | 664 | You should also get your employer (if you work as a programmer) or school, 665 | if any, to sign a "copyright disclaimer" for the program, if necessary. 666 | For more information on this, and how to apply and follow the GNU GPL, see 667 | <https://www.gnu.org/licenses/>. 668 | 669 | The GNU General Public License does not permit incorporating your program 670 | into proprietary programs. If your program is a subroutine library, you 671 | may consider it more useful to permit linking proprietary applications with 672 | the library. If this is what you want to do, use the GNU Lesser General 673 | Public License instead of this License. But first, please read 674 | <https://www.gnu.org/licenses/why-not-lgpl.html>. 675 |
--------------------------------------------------------------------------------
/config/conf.yaml:
--------------------------------------------------------------------------------
1 | data:
2 |   train_dir: ${env:PWD}/data/train
3 |   test_dir: ${env:PWD}/data/test
4 |   batch_size: 16
5 |   pin_memory: false
6 |   num_workers: 6
7 |   split_train_val: true
8 |   split_ratios: [0.75, 0.25]
9 |   consider_features: ${features.attributes}
10 | 
11 | model:
12 |   convolution_count: 0
13 |   convolution_algorithm: GraphConv # One of GraphConv, SAGEConv, TAGConv, DotGatConv (the algorithms supported by core/model.py)
14 |   input_dimension: ${features.size}
15 | 
16 | trainer:
17 |   max_epochs: 100
18 |   gpus: null
19 | 
20 | hydra:
21 |   run:
22 |     dir: output/${model.convolution_algorithm}/${features.name}-conv_count=${model.convolution_count}
23 |   sweep:
24 |     dir: output/${model.convolution_algorithm}
25 |     subdir: ${features.name}-conv_count=${model.convolution_count}
26 | 
27 | defaults:
28 |   - features: degree
29 |   - logger: wandb
--------------------------------------------------------------------------------
/config/features/degree.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: degree
3 | attributes: []
4 | size: 1
--------------------------------------------------------------------------------
/config/features/method_attributes.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: method
3 | attributes:
4 |   - external
5 |   - native
6 |   - public
7 |   - static
8 |   - codesize
9 | size: 5
--------------------------------------------------------------------------------
/config/features/method_summary.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: hybrid
3 | attributes:
4 |   - api
5 |   - user
6 | size: 247
7 | 
--------------------------------------------------------------------------------
/config/logger/wandb.yaml:
--------------------------------------------------------------------------------
1 | # @package _group_
2 | name: wandb
3 | args:
4 |   project: malware_homo
5 |   log_model: true
6 | hparams:
7 |   convolution_algorithm: ${model.convolution_algorithm}
8 |   features: ${features.name}
9 |   convolution_count: ${model.convolution_count}
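Together, these files form a composable Hydra configuration: conf.yaml pulls in one file from each of the features and logger groups (degree and wandb by default), and interpolations such as ${features.size} and ${features.attributes} wire the selected feature group into the model and data sections. A minimal sketch of composing and inspecting the result (illustrative only, not part of the repository; the exact import path depends on the Hydra version, e.g. hydra.experimental in Hydra 1.0):

# Sketch only -- not part of the repository.
from hydra import compose, initialize

with initialize(config_path="config"):
    cfg = compose(
        config_name="conf",
        overrides=["features=method_attributes", "model.convolution_count=2"],
    )
    print(cfg.model.input_dimension)   # ${features.size} resolves to 5
    print(cfg.data.consider_features)  # ${features.attributes} resolves to the five method attributes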
--------------------------------------------------------------------------------
/core/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinayakakv/android-malware-detection/1aab288ec599a3958982866ce989311a96cbffd9/core/__init__.py
--------------------------------------------------------------------------------
/core/callbacks.py:
--------------------------------------------------------------------------------
1 | import numbers
2 | from typing import Tuple, List, Union
3 | 
4 | import dgl
5 | import matplotlib.pyplot as plt
6 | import numpy as np
7 | import pytorch_lightning as pl
8 | import torch.nn
9 | import wandb
10 | from pytorch_lightning.callbacks import Callback
11 | from pytorch_lightning.metrics.metric import Metric
12 | 
13 | from core.model import MalwareDetector
14 | from core.utils import plot_confusion_matrix, plot_curve
15 | 
16 | 
17 | class InputMonitor(Callback):
18 |     """
19 |     Plots the histogram of input labels
20 |     """
21 | 
22 |     def __init__(self):
23 |         pass
24 | 
25 |     def on_train_batch_start(
26 |             self,
27 |             trainer: pl.Trainer,
28 |             pl_module: pl.LightningModule,
29 |             batch: Tuple[dgl.DGLHeteroGraph, torch.Tensor],
30 |             batch_idx: int,
31 |             dataloader_idx: int
32 |     ):
33 |         samples, labels = batch
34 |         trainer.logger.experiment.log({
35 |             'train_data_histogram': wandb.Histogram(labels.detach().cpu().numpy())
36 |         }, commit=False)
37 | 
38 | 
39 | class BestModelTagger(Callback):
40 |     """
41 |     Logs the "best_epoch" and its corresponding monitored metric value to the logger.
42 |     Inspired by https://github.com/PyTorchLightning/pytorch-lightning/blob/master/pytorch_lightning/callbacks/early_stopping.py
43 |     """
44 | 
45 |     def __init__(self, monitor: str = 'val_loss', mode: str = 'min'):
46 |         self.monitor = monitor
47 |         if mode not in ['min', 'max']:
48 |             raise ValueError(f"Invalid mode {mode}. Must be one of 'min' or 'max'")
49 |         self.mode = mode
50 |         self.monitor_op = torch.lt if mode == 'min' else torch.gt
51 |         torch_inf = torch.tensor(np.Inf)
52 |         self.best_score = torch_inf if self.monitor_op == torch.lt else -torch_inf
53 | 
54 |     def on_validation_end(self, trainer: pl.Trainer, pl_module: pl.LightningModule):
55 |         logs = trainer.callback_metrics
56 |         monitor_val = logs.get(self.monitor)
57 |         if monitor_val is None:
58 |             raise RuntimeError(f"{self.monitor} was expected to be logged by the model, but was not found")
59 |         if monitor_val is not None:
60 |             if isinstance(monitor_val, Metric):
61 |                 monitor_val = monitor_val.compute()
62 |             elif isinstance(monitor_val, numbers.Number):
63 |                 monitor_val = torch.tensor(monitor_val, device=pl_module.device, dtype=torch.float)
64 |             if self.monitor_op(monitor_val, self.best_score):
65 |                 self.best_score = monitor_val
66 |                 trainer.logger.experiment.log({
67 |                     f'{self.mode}_{self.monitor}': monitor_val.cpu().numpy(),
68 |                     'best_epoch': trainer.current_epoch
69 |                 }, commit=False)
70 | 
71 | 
72 | class MetricsLogger(Callback):
73 | 
74 |     def __init__(self, stages: Union[List[str], str]):
75 |         valid_stages = {'train', 'val', 'test'}
76 |         if stages == 'all':
77 |             self.stages = valid_stages
78 |         else:
79 |             for stage in stages:
80 |                 if stage not in valid_stages:
81 |                     raise ValueError(f"Stage {stage} is not valid. Must be one of {valid_stages}")
82 |             self.stages = set(stages) & valid_stages
83 | 
84 |     @staticmethod
85 |     def _plot_metrics(trainer: pl.Trainer, pl_module: MalwareDetector, stage: str):
86 |         confusion_matrix = pl_module.test_outputs['confusion_matrix'].compute().cpu().numpy()
87 |         plot_confusion_matrix(
88 |             confusion_matrix,
89 |             group_names=['TN', 'FP', 'FN', 'TP'],
90 |             categories=['Benign', 'Malware'],
91 |             cmap='binary'
92 |         )
93 |         trainer.logger.experiment.log({
94 |             f'{stage}_confusion_matrix': wandb.Image(plt)
95 |         }, commit=False)
96 |         if stage != 'test':
97 |             return
98 |         roc = pl_module.test_outputs['roc'].compute()
99 |         figure = plot_curve(roc[0].cpu(), roc[1].cpu(), 'roc')
100 |         trainer.logger.experiment.log({
101 |             'ROC': figure
102 |         }, commit=False)
103 |         prc = pl_module.test_outputs['prc'].compute()
104 |         figure = plot_curve(prc[1].cpu(), prc[0].cpu(), 'prc')
105 |         trainer.logger.experiment.log({
106 |             'PRC': figure
107 |         }, commit=False)
108 | 
109 |     @staticmethod
110 |     def compute_metrics(pl_module: MalwareDetector, stage: str):
111 |         metrics = {}
112 |         if stage == 'train':
113 |             metric_dict = pl_module.train_metrics
114 |         elif stage == 'val':
115 |             metric_dict = pl_module.val_metrics
116 |         elif stage == 'test':
117 |             metric_dict = pl_module.test_metrics
118 |         else:
119 |             raise ValueError(f"Invalid stage: {stage}")
120 |         for metric_name, metric in metric_dict.items():
121 |             metrics[metric_name] = metric.compute()
122 |         return metrics
123 | 
124 |     def on_train_epoch_end(self, trainer: pl.Trainer, pl_module: MalwareDetector, outputs):
125 |         if 'train' not in self.stages:
126 |             return
127 |         trainer.logger.experiment.log(self.compute_metrics(pl_module, 'train'), commit=False)
128 | 
129 |     def on_validation_end(self, trainer: pl.Trainer, pl_module: MalwareDetector):
130 |         if 'val' not in self.stages or trainer.running_sanity_check:
131 |             return
132 |         trainer.logger.experiment.log(self.compute_metrics(pl_module, 'val'), commit=False)
133 | 
134 |     def on_test_end(self, trainer: pl.Trainer, pl_module: MalwareDetector):
135 |         if 'test' not in self.stages:
136 |             return
137 |         trainer.logger.experiment.log(self.compute_metrics(pl_module, 'test'), commit=False)
138 |         self._plot_metrics(trainer, pl_module, 'test')
139 | 
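These three callbacks only observe the training loop; they are meant to be registered on the PyTorch Lightning trainer, presumably in train_model.py. A hedged sketch of that wiring (the trainer arguments mirror config/conf.yaml; the actual setup in train_model.py may differ):

# Sketch only -- not part of the repository.
import pytorch_lightning as pl

from core.callbacks import BestModelTagger, InputMonitor, MetricsLogger

trainer = pl.Trainer(
    max_epochs=100,  # trainer.max_epochs in config/conf.yaml
    callbacks=[
        InputMonitor(),                       # histogram of the labels in each training batch
        BestModelTagger(monitor='val_loss'),  # remembers the epoch with the lowest validation loss
        MetricsLogger(stages='all'),          # logs train/val/test metrics and the test curves
    ],
)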
--------------------------------------------------------------------------------
/core/data_module.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List, Dict, Tuple, Union
3 | 
4 | import dgl
5 | import pytorch_lightning as pl
6 | import torch
7 | from sklearn.model_selection import StratifiedShuffleSplit
8 | from torch.utils.data import DataLoader
9 | 
10 | from core.dataset import MalwareDataset
11 | 
12 | 
13 | def stratified_split_dataset(samples: List[str],
14 |                              labels: Dict[str, int],
15 |                              ratios: Tuple[float, float]) -> Tuple[List[str], List[str]]:
16 |     """
17 |     Split the dataset into train and validation datasets based on the given ratio
18 |     :param samples: List of file names
19 |     :param labels: Mapping from file name to label
20 |     :param ratios: Training ratio, validation ratio
21 |     :return: Lists of file names in the training and validation splits
22 |     """
23 |     if sum(ratios) != 1:
24 |         raise ValueError(f"Invalid ratios {ratios}; they must sum to 1")
25 |     train_ratio, val_ratio = ratios
26 |     sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio, random_state=0)
27 |     train_idx, val_idx = list(sss.split(samples, [labels[x] for x in samples]))[0]
28 |     train_list = [samples[x] for x in train_idx]
29 |     val_list = [samples[x] for x in val_idx]
30 |     return train_list, val_list
31 | 
32 | 
33 | @torch.no_grad()
34 | def collate(samples: List[Tuple[dgl.DGLGraph, int]]) -> Tuple[dgl.DGLGraph, torch.Tensor]:
35 |     """
36 |     Batches several graphs into one
37 |     :param samples: List of (graph, label) tuples
38 |     :return: Batched graph, and labels concatenated into a tensor
39 |     """
40 |     graphs, labels = map(list, zip(*samples))
41 |     batched_graph = dgl.batch(graphs)
42 |     labels = torch.tensor(labels)
43 |     return batched_graph, labels.float()
44 | 
45 | 
46 | class MalwareDataModule(pl.LightningDataModule):
47 |     """
48 |     Handler class for data loading, splitting and initializing datasets and dataloaders.
49 |     """
50 | 
51 |     def __init__(
52 |             self,
53 |             train_dir: Union[str, Path],
54 |             test_dir: Union[str, Path],
55 |             batch_size: int,
56 |             split_ratios: Tuple[float, float],
57 |             consider_features: List[str],
58 |             num_workers: int,
59 |             pin_memory: bool,
60 |             split_train_val: bool,
61 |     ):
62 |         """
63 |         Creates the MalwareDataModule
64 |         :param train_dir: The directory containing the training samples
65 |         :param test_dir: The directory containing the testing samples
66 |         :param batch_size: Number of graphs in a batch
67 |         :param split_ratios: Tuple of training and validation split ratios
68 |         :param consider_features: Feature types to consider
69 |         :param num_workers: Number of worker processes to use for data loading
70 |         :param pin_memory: If True, pin batches in page-locked memory, which can speed up GPU data transfer
71 |         :param split_train_val: If True, split the train dataset into train and validation,
72 |             else use the test dataset for validation
73 |         """
74 |         super().__init__()
75 |         self.train_dir = Path(train_dir)
76 |         if not self.train_dir.exists():
77 |             raise FileNotFoundError(f"Train directory {train_dir} does not exist. Could not read from it.")
78 |         self.test_dir = Path(test_dir)
79 |         if not self.test_dir.exists():
80 |             raise FileNotFoundError(f"Test directory {test_dir} does not exist. Could not read from it.")
81 |         self.dataloader_kwargs = {
82 |             'num_workers': num_workers,
83 |             'batch_size': batch_size,
84 |             'pin_memory': pin_memory,
85 |             'collate_fn': collate,
86 |             'drop_last': True
87 |         }
88 |         self.split_ratios = split_ratios
89 |         self.split = split_train_val
90 |         self.splitter = stratified_split_dataset
91 |         self.consider_features = consider_features
92 | 
93 |     @staticmethod
94 |     def get_samples(path: Union[str, Path]) -> Tuple[List[str], Dict[str, int]]:
95 |         """
96 |         Get samples and labels from the given path
97 |         :param path: The directory containing graphs
98 |         :return: The file list, and their label mapping
99 |         """
100 |         base_path = Path(path)
101 |         if not base_path.exists():
102 |             raise FileNotFoundError(f'{base_path} does not exist')
103 |         apk_list = sorted([x for x in base_path.iterdir()])
104 |         samples = []
105 |         labels = {}
106 |         for apk in apk_list:
107 |             samples.append(apk.name)
108 |             labels[apk.name] = int("Benig" not in apk.name)  # 1 = malware; benign samples carry "Benig" in their file name
109 |         return samples, labels
110 | 
111 |     def setup(self, stage=None):
112 |         samples, labels = self.get_samples(self.train_dir)
113 |         test_samples, test_labels = self.get_samples(self.test_dir)
114 |         if self.split:
115 |             train_samples, val_samples = self.splitter(samples, labels, self.split_ratios)
116 |             val_dir = self.train_dir
117 |             val_labels = labels
118 |         else:
119 |             train_samples = samples
120 |             val_dir = self.test_dir
121 |             val_samples, val_labels = test_samples, test_labels
122 |         self.train_dataset = MalwareDataset(
123 |             source_dir=self.train_dir,
124 |             samples=train_samples,
125 |             labels=labels,
126 |             consider_features=self.consider_features
127 |         )
128 |         self.val_dataset = MalwareDataset(
129 |             source_dir=val_dir,
130 |             samples=val_samples,
131 |             labels=val_labels,
132 |             consider_features=self.consider_features
133 |         )
134 |         self.test_dataset = MalwareDataset(
135 |             source_dir=self.test_dir,
136 |             samples=test_samples,
137 |             labels=test_labels,
138 |             consider_features=self.consider_features
139 |         )
140 | 
141 |     def train_dataloader(self):
142 |         return DataLoader(self.train_dataset, shuffle=True, **self.dataloader_kwargs)
143 | 
144 |     def val_dataloader(self):
145 |         return DataLoader(self.val_dataset, **self.dataloader_kwargs)
146 | 
147 |     def test_dataloader(self):
148 |         return DataLoader(self.test_dataset, **self.dataloader_kwargs)
149 | 
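A usage sketch for the data module, with argument values taken from config/conf.yaml (illustrative only; it assumes data/train and data/test are populated with serialized DGL graphs):

# Sketch only -- not part of the repository.
from core.data_module import MalwareDataModule

dm = MalwareDataModule(
    train_dir="data/train",
    test_dir="data/test",
    batch_size=16,
    split_ratios=(0.75, 0.25),
    consider_features=[],  # [] matches config/features/degree.yaml; MalwareDataset then falls back to degree features
    num_workers=6,
    pin_memory=False,
    split_train_val=True,
)
dm.setup()
batched_graph, labels = next(iter(dm.train_dataloader()))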
--------------------------------------------------------------------------------
/core/dataset.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import List, Dict, Tuple, Union
3 | 
4 | import dgl
5 | import torch
6 | from torch.utils.data import Dataset
7 | 
8 | attributes = {'external', 'entrypoint', 'native', 'public', 'static', 'codesize'}
9 | 
10 | 
11 | class MalwareDataset(Dataset):
12 |     def __init__(
13 |             self,
14 |             source_dir: Union[str, Path],
15 |             samples: List[str],
16 |             labels: Dict[str, int],
17 |             consider_features: List[str],
18 |     ):
19 |         self.source_dir = Path(source_dir)
20 |         self.samples = samples
21 |         self.labels = labels
22 |         self.consider_features = consider_features
23 | 
24 |     def __len__(self) -> int:
25 |         """Denotes the total number of samples"""
26 |         return len(self.samples)
27 | 
28 |     @staticmethod
29 |     def _process_node_attributes(g: dgl.DGLGraph):
30 |         for attribute in attributes & set(g.ndata.keys()):
31 |             g.ndata[attribute] = g.ndata[attribute].view(-1, 1)  # reshape scalar attributes into column vectors for concatenation
32 |         return g
33 | 
34 |     def __getitem__(self, index: int) -> Tuple[dgl.DGLGraph, int]:
35 |         """Generates one sample of data"""
36 |         name = self.samples[index]
37 |         graphs, _ = dgl.data.utils.load_graphs(str(self.source_dir / name))
38 |         graph: dgl.DGLGraph = dgl.add_self_loop(graphs[0])
39 |         g = self._process_node_attributes(graph)
40 |         if len(g.ndata.keys()) > 0:
41 |             features = torch.cat([g.ndata[x] for x in self.consider_features], dim=1).float()
42 |         else:
43 |             features = (g.in_degrees() + g.out_degrees()).view(-1, 1).float()  # fall back to the total degree of each node
44 |         g.ndata.clear()
45 |         g.ndata['features'] = features
46 |         return g, self.labels[name]
47 | 
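The feature logic in __getitem__ is two-way: when the stored graph carries node attributes, the requested ones are concatenated per node; otherwise the total degree of each node becomes a single scalar feature, matching size: 1 in config/features/degree.yaml. A toy illustration of that fallback on a synthetic graph (not from the dataset):

# Sketch only -- not part of the repository.
import dgl

g = dgl.graph(([0, 1], [1, 2]))  # a tiny "call graph" with 3 nodes and 2 edges
g = dgl.add_self_loop(g)
features = (g.in_degrees() + g.out_degrees()).view(-1, 1).float()
g.ndata['features'] = features   # shape (3, 1): one degree value per node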
| name = self.samples[index] 37 | graphs, _ = dgl.data.utils.load_graphs(str(self.source_dir / name)) 38 | graph: dgl.DGLGraph = dgl.add_self_loop(graphs[0]) 39 | g = self._process_node_attributes(graph) 40 | if len(g.ndata.keys()) > 0: 41 | features = torch.cat([g.ndata[x] for x in self.consider_features], dim=1).float() 42 | else: 43 | features = (g.in_degrees() + g.out_degrees()).view(-1, 1).float() 44 | g.ndata.clear() 45 | g.ndata['features'] = features 46 | return g, self.labels[name] 47 | -------------------------------------------------------------------------------- /core/model.py: -------------------------------------------------------------------------------- 1 | from typing import Mapping 2 | from typing import Tuple, Optional, Dict 3 | 4 | import dgl 5 | import dgl.nn.pytorch as graph_nn 6 | import pytorch_lightning as pl 7 | import pytorch_lightning.metrics as metrics 8 | import torch 9 | import torch.nn.functional as F 10 | from dgl.nn import Sequential 11 | from pytorch_lightning.metrics import Metric 12 | from torch import nn 13 | 14 | 15 | class MalwareDetector(pl.LightningModule): 16 | def __init__( 17 | self, 18 | input_dimension: int, 19 | convolution_algorithm: str, 20 | convolution_count: int, 21 | ): 22 | super().__init__() 23 | supported_algorithms = ['GraphConv', 'SAGEConv', 'TAGConv', 'DotGatConv'] 24 | if convolution_algorithm not in supported_algorithms: 25 | raise ValueError( 26 | f"{convolution_algorithm} is not supported. Supported algorithms are {supported_algorithms}") 27 | self.save_hyperparameters() 28 | self.convolution_layers = [] 29 | convolution_dimensions = [64, 32, 16] 30 | for dimension in convolution_dimensions[:convolution_count]: 31 | self.convolution_layers.append(self._get_convolution_layer( 32 | name=convolution_algorithm, 33 | input_dimension=input_dimension, 34 | output_dimension=dimension 35 | )) 36 | input_dimension = dimension 37 | self.convolution_layers = Sequential(*self.convolution_layers) 38 | self.last_dimension = input_dimension 39 | self.classify = nn.Linear(input_dimension, 1) 40 | # Metrics 41 | self.loss_func = nn.BCEWithLogitsLoss() 42 | self.train_metrics = self._get_metric_dict('train') 43 | self.val_metrics = self._get_metric_dict('val') 44 | self.test_metrics = self._get_metric_dict('test') 45 | self.test_outputs = nn.ModuleDict({ 46 | 'confusion_matrix': metrics.ConfusionMatrix(num_classes=2), 47 | 'prc': metrics.PrecisionRecallCurve(compute_on_step=False), 48 | 'roc': metrics.ROC(compute_on_step=False) 49 | }) 50 | 51 | @staticmethod 52 | def _get_convolution_layer( 53 | name: str, 54 | input_dimension: int, 55 | output_dimension: int 56 | ) -> Optional[nn.Module]: 57 | return { 58 | "GraphConv": graph_nn.GraphConv( 59 | input_dimension, 60 | output_dimension, 61 | activation=F.relu 62 | ), 63 | "SAGEConv": graph_nn.SAGEConv( 64 | input_dimension, 65 | output_dimension, 66 | activation=F.relu, 67 | aggregator_type='mean', 68 | norm=F.normalize 69 | ), 70 | "DotGatConv": graph_nn.DotGatConv( 71 | input_dimension, 72 | output_dimension, 73 | num_heads=1 74 | ), 75 | "TAGConv": graph_nn.TAGConv( 76 | input_dimension, 77 | output_dimension, 78 | k=4 79 | ) 80 | }.get(name, None) 81 | 82 | @staticmethod 83 | def _get_metric_dict(stage: str) -> Mapping[str, Metric]: 84 | return nn.ModuleDict({ 85 | f'{stage}_accuracy': metrics.Accuracy(), 86 | f'{stage}_precision': metrics.Precision(num_classes=1), 87 | f'{stage}_recall': metrics.Recall(num_classes=1), 88 | f'{stage}_f1': metrics.FBeta(num_classes=1) 89 | }) 90 | 91 | def 
forward(self, g: dgl.DGLGraph) -> torch.Tensor: 92 | with g.local_scope(): 93 | h = g.ndata['features'] 94 | h = self.convolution_layers(g, h) 95 | g.ndata['h'] = h if len(self.convolution_layers) > 0 else h[0] 96 | # Calculate graph representation by averaging all the node representations. 97 | hg = dgl.mean_nodes(g, 'h') 98 | return self.classify(hg).squeeze() 99 | 100 | def training_step(self, batch: Tuple[dgl.DGLGraph, torch.Tensor], batch_idx: int) -> torch.Tensor: 101 | bg, label = batch 102 | logits = self.forward(bg) 103 | loss = self.loss_func(logits, label) 104 | prediction = torch.sigmoid(logits) 105 | for metric_name, metric in self.train_metrics.items(): 106 | metric.update(prediction, label) 107 | self.log('train_loss', loss, on_step=True, on_epoch=True) 108 | return loss 109 | 110 | def validation_step(self, batch: Tuple[dgl.DGLGraph, torch.Tensor], batch_idx: int): 111 | bg, label = batch 112 | logits = self.forward(bg) 113 | loss = self.loss_func(logits, label) 114 | prediction = torch.sigmoid(logits) 115 | for metric_name, metric in self.val_metrics.items(): 116 | metric.update(prediction, label) 117 | self.log('val_loss', loss, on_step=False, on_epoch=True) 118 | return loss 119 | 120 | def test_step(self, batch: Tuple[dgl.DGLGraph, torch.Tensor], batch_idx: int): 121 | bg, label = batch 122 | logits = self.forward(bg) 123 | prediction = torch.sigmoid(logits) 124 | loss = self.loss_func(logits, label) 125 | for metric_name, metric in self.test_metrics.items(): 126 | metric.update(prediction, label) 127 | for metric_name, metric in self.test_outputs.items(): 128 | metric.update(prediction, label) 129 | self.log('test_loss', loss, on_step=False, on_epoch=True) 130 | return loss 131 | 132 | def configure_optimizers(self) -> torch.optim.Adam: 133 | optimizer = torch.optim.Adam(self.parameters(), lr=1e-3) 134 | return optimizer 135 | -------------------------------------------------------------------------------- /core/utils.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | import numpy as np 3 | import plotly.express as px 4 | import pytorch_lightning.metrics.functional as M 5 | import seaborn as sns 6 | 7 | 8 | def plot_curve(x, y, curve_type): 9 | """ 10 | Plots ROC or PRC 11 | Inspired from https://plotly.com/python/roc-and-pr-curves/ 12 | :param x: The x co-ordinates 13 | :param y: The y co-ordinates 14 | :param curve_type: one of 'roc' or 'prc' 15 | :return: Plotly figure 16 | """ 17 | auc = M.classification.auc(x, y) 18 | x, y = x.numpy(), y.numpy() 19 | if curve_type == 'roc': 20 | title = f"ROC, AUC = {auc}" 21 | labels = dict(x='FPR', y='TPR') 22 | elif curve_type == 'prc': 23 | title = f"PRC, mAP = {auc}" 24 | labels = dict(x='Recall', y='Precision') 25 | else: 26 | raise ValueError(f"Invalid curve type - {curve_type}. 
Must be one of 'roc' or 'prc'.") 27 | fig = px.area(x=x, y=y, labels=labels, title=title) 28 | fig.update_yaxes(scaleanchor="x", scaleratio=1) 29 | fig.update_xaxes(constrain='domain') 30 | return fig 31 | 32 | 33 | def plot_confusion_matrix(cf, 34 | group_names=None, 35 | categories='auto', 36 | count=True, 37 | percent=True, 38 | cbar=True, 39 | xyticks=True, 40 | xyplotlabels=True, 41 | sum_stats=True, 42 | fig_size=None, 43 | cmap='Blues', 44 | title=None): 45 | ''' 46 | From https://github.com/DTrimarchi10/confusion_matrix/blob/master/cf_matrix.py 47 | Blog https://medium.com/@dtuk81/confusion-matrix-visualization-fc31e3f30fea 48 | This function will make a pretty plot of an sklearn Confusion Matrix cf using a Seaborn heatmap visualization. 49 | Arguments 50 | --------- 51 | cf: confusion matrix to be passed in 52 | group_names: List of strings that represent the labels row by row to be shown in each square. 53 | categories: List of strings containing the categories to be displayed on the x,y axis. Default is 'auto' 54 | count: If True, show the raw number in the confusion matrix. Default is True. 55 | percent: If True, show the proportions for each category. Default is True. 56 | cbar: If True, show the color bar. The cbar values are based on the values in the confusion matrix. 57 | Default is True. 58 | xyticks: If True, show x and y ticks. Default is True. 59 | xyplotlabels: If True, show 'True Label' and 'Predicted Label' on the figure. Default is True. 60 | sum_stats: If True, display summary statistics below the figure. Default is True. 61 | fig_size: Tuple representing the figure size. Default will be the matplotlib rcParams value. 62 | cmap: Colormap of the values displayed from matplotlib.pyplot.cm. Default is 'Blues' 63 | See http://matplotlib.org/examples/color/colormaps_reference.html 64 | title: Title for the heatmap. Default is None.
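Example (an illustrative sketch, not part of the original source; it assumes a
2x2 numpy confusion matrix such as the one produced in the test stage):
    import numpy as np
    cf = np.array([[900, 40], [25, 870]])
    plot_confusion_matrix(cf, group_names=['TN', 'FP', 'FN', 'TP'],
                          categories=['Benign', 'Malware'], title='Test set')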
65 | ''' 66 | plt.clf() 67 | # CODE TO GENERATE TEXT INSIDE EACH SQUARE 68 | blanks = ['' for i in range(cf.size)] 69 | 70 | if group_names and len(group_names) == cf.size: 71 | group_labels = ["{}\n".format(value) for value in group_names] 72 | else: 73 | group_labels = blanks 74 | 75 | if count: 76 | group_counts = ["{0:0.0f}\n".format(value) for value in cf.flatten()] 77 | else: 78 | group_counts = blanks 79 | 80 | if percent: 81 | group_percentages = ["{0:.2%}".format(value) for value in cf.flatten() / np.sum(cf)] 82 | else: 83 | group_percentages = blanks 84 | 85 | box_labels = [f"{v1}{v2}{v3}".strip() for v1, v2, v3 in zip(group_labels, group_counts, group_percentages)] 86 | box_labels = np.asarray(box_labels).reshape(cf.shape[0], cf.shape[1]) 87 | 88 | # CODE TO GENERATE SUMMARY STATISTICS & TEXT FOR SUMMARY STATS 89 | if sum_stats: 90 | # Accuracy is sum of diagonal divided by total observations 91 | accuracy = np.trace(cf) / float(np.sum(cf)) 92 | 93 | # if it is a binary confusion matrix, show some more stats 94 | if len(cf) == 2: 95 | # Metrics for Binary Confusion Matrices 96 | precision = cf[1, 1] / sum(cf[:, 1]) 97 | recall = cf[1, 1] / sum(cf[1, :]) 98 | f1_score = 2 * precision * recall / (precision + recall) 99 | stats_text = "\n\nAccuracy={:0.4f}\nPrecision={:0.4f}\nRecall={:0.4f}\nF1 Score={:0.4f}".format( 100 | accuracy, precision, recall, f1_score) 101 | else: 102 | stats_text = "\n\nAccuracy={:0.3f}".format(accuracy) 103 | else: 104 | stats_text = "" 105 | 106 | # SET FIGURE PARAMETERS ACCORDING TO OTHER ARGUMENTS 107 | if fig_size is None: 108 | # Get default figure size if not set 109 | fig_size = plt.rcParams.get('figure.figsize') 110 | 111 | if not xyticks: 112 | # Do not show categories if xyticks is False 113 | categories = False 114 | 115 | # MAKE THE HEATMAP VISUALIZATION 116 | plt.figure(figsize=fig_size) 117 | sns.heatmap(cf, annot=box_labels, fmt="", cmap=cmap, cbar=cbar, xticklabels=categories, yticklabels=categories) 118 | 119 | if xyplotlabels: 120 | plt.ylabel('True label') 121 | plt.xlabel('Predicted label' + stats_text) 122 | else: 123 | plt.xlabel(stats_text) 124 | 125 | if title: 126 | plt.title(title) 127 | 128 | plt.tight_layout() 129 | plt.savefig("CM.png") 130 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Datasets 2 | 3 | To conduct the experiments, 2 datasets were used. 4 | 5 | 1. [`MalDroid2020`](https://www.unb.ca/cic/datasets/maldroid-2020.html) 6 | 2. [`AndroZoo`](https://androzoo.uni.lu/) 7 | 8 | While `MalDroid2020` was used as a base, `AndroZoo` was used to collect new benign APKs. 9 | 10 | ## Format 11 | This directory contains the hashes of APKs in `*.sha256` files. 12 | Each line in the file consists of `sha256 name` pairs. 13 | There are 4 such files 14 | 15 | 1. `train_old.sha256` - The training samples from `MalDroid2020` 16 | 2. `train_new.sha256` - The training samples from `AndroZoo` 17 | 3. `test_old.sha256` - The testing samples from `MalDroid2020` 18 | 4. 
`test_new.sha256` - The testing samples from `AndroZoo` -------------------------------------------------------------------------------- /malware-learning.def: -------------------------------------------------------------------------------- 1 | Bootstrap: docker 2 | From: pytorch/pytorch:1.7.0-cuda11.0-cudnn8-devel 3 | 4 | %files 5 | requirements.txt /mnt/requirements.txt 6 | dgl-0.6-cp38-cp38-linux_x86_64.whl /mnt/dgl-0.6-cp38-cp38-linux_x86_64.whl 7 | 8 | %post 9 | apt-get update && apt-get install -y git graphviz graphviz-dev && rm -rf /var/lib/apt/lists/* 10 | pip install -r /mnt/requirements.txt 11 | pip install /mnt/dgl-0.6-cp38-cp38-linux_x86_64.whl 12 | cd /mnt 13 | git clone --depth 1 https://github.com/androguard/androguard.git && cd androguard && python setup.py install -------------------------------------------------------------------------------- /metadata/api.list: -------------------------------------------------------------------------------- 1 | android 2 | android.accessibilityservice 3 | android.accounts 4 | android.animation 5 | android.annotation 6 | android.app 7 | android.app.admin 8 | android.app.assist 9 | android.app.backup 10 | android.app.blob 11 | android.app.job 12 | android.app.role 13 | android.app.slice 14 | android.app.usage 15 | android.appwidget 16 | android.bluetooth 17 | android.bluetooth.le 18 | android.companion 19 | android.content 20 | android.content.pm 21 | android.content.res 22 | android.content.res.loader 23 | android.database 24 | android.database.sqlite 25 | android.drm 26 | android.gesture 27 | android.graphics 28 | android.graphics.drawable 29 | android.graphics.drawable.shapes 30 | android.graphics.fonts 31 | android.graphics.pdf 32 | android.graphics.text 33 | android.hardware 34 | android.hardware.biometrics 35 | android.hardware.camera2 36 | android.hardware.camera2.params 37 | android.hardware.display 38 | android.hardware.fingerprint 39 | android.hardware.input 40 | android.hardware.usb 41 | android.icu.lang 42 | android.icu.math 43 | android.icu.number 44 | android.icu.text 45 | android.icu.util 46 | android.inputmethodservice 47 | android.location 48 | android.media 49 | android.media.audiofx 50 | android.media.browse 51 | android.media.effect 52 | android.media.midi 53 | android.media.projection 54 | android.media.session 55 | android.media.tv 56 | android.mtp 57 | android.net 58 | android.net.http 59 | android.net.nsd 60 | android.net.rtp 61 | android.net.sip 62 | android.net.ssl 63 | android.net.wifi 64 | android.net.wifi.aware 65 | android.net.wifi.hotspot2 66 | android.net.wifi.hotspot2.omadm 67 | android.net.wifi.hotspot2.pps 68 | android.net.wifi.p2p 69 | android.net.wifi.p2p.nsd 70 | android.net.wifi.rtt 71 | android.nfc 72 | android.nfc.cardemulation 73 | android.nfc.tech 74 | android.opengl 75 | android.os 76 | android.os.health 77 | android.os.storage 78 | android.os.strictmode 79 | android.preference 80 | android.print 81 | android.print.pdf 82 | android.printservice 83 | android.provider 84 | android.renderscript 85 | android.sax 86 | android.se.omapi 87 | android.security 88 | android.security.identity 89 | android.security.keystore 90 | android.service.autofill 91 | android.service.carrier 92 | android.service.chooser 93 | android.service.controls 94 | android.service.controls.actions 95 | android.service.controls.templates 96 | android.service.dreams 97 | android.service.media 98 | android.service.notification 99 | android.service.quickaccesswallet 100 | android.service.quicksettings 101 | 
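A minimal sketch for consuming the dataset lists described in `data/README.md` above (illustrative only, not a file from this repository; it assumes each line is a whitespace-separated `sha256 name` pair, and the helper name is made up):

    from pathlib import Path

    def load_sha256_list(path):
        """Return a {name: sha256} mapping from one of the *.sha256 files."""
        mapping = {}
        for line in Path(path).read_text().splitlines():
            if line.strip():
                sha256, name = line.split(maxsplit=1)
                mapping[name] = sha256
        return mapping

    # e.g. hashes = load_sha256_list('data/train_old.sha256')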
android.service.restrictions 102 | android.service.textservice 103 | android.service.voice 104 | android.service.vr 105 | android.service.wallpaper 106 | android.speech 107 | android.speech.tts 108 | android.system 109 | android.telecom 110 | android.telephony 111 | android.telephony.cdma 112 | android.telephony.data 113 | android.telephony.emergency 114 | android.telephony.euicc 115 | android.telephony.gsm 116 | android.telephony.ims 117 | android.telephony.ims.feature 118 | android.telephony.mbms 119 | android.test 120 | android.test.mock 121 | android.test.suitebuilder 122 | android.test.suitebuilder.annotation 123 | android.text 124 | android.text.format 125 | android.text.method 126 | android.text.style 127 | android.text.util 128 | android.transition 129 | android.util 130 | android.util.proto 131 | android.view 132 | android.view.accessibility 133 | android.view.animation 134 | android.view.autofill 135 | android.view.contentcapture 136 | android.view.inputmethod 137 | android.view.inspector 138 | android.view.textclassifier 139 | android.view.textservice 140 | android.webkit 141 | android.widget 142 | android.widget.inline 143 | com.google.android.collect 144 | com.google.android.gles_jni 145 | com.google.android.util 146 | dalvik.annotation 147 | dalvik.bytecode 148 | dalvik.system 149 | java.awt.font 150 | java.beans 151 | java.io 152 | java.lang 153 | java.lang.annotation 154 | java.lang.invoke 155 | java.lang.ref 156 | java.lang.reflect 157 | java.math 158 | java.net 159 | java.nio 160 | java.nio.channels 161 | java.nio.channels.spi 162 | java.nio.charset 163 | java.nio.charset.spi 164 | java.nio.file 165 | java.nio.file.attribute 166 | java.nio.file.spi 167 | java.security 168 | java.security.acl 169 | java.security.cert 170 | java.security.interfaces 171 | java.security.spec 172 | java.sql 173 | java.text 174 | java.time 175 | java.time.chrono 176 | java.time.format 177 | java.time.temporal 178 | java.time.zone 179 | java.util 180 | java.util.concurrent 181 | java.util.concurrent.atomic 182 | java.util.concurrent.locks 183 | java.util.function 184 | java.util.jar 185 | java.util.logging 186 | java.util.prefs 187 | java.util.regex 188 | java.util.stream 189 | java.util.zip 190 | javax.crypto 191 | javax.crypto.interfaces 192 | javax.crypto.spec 193 | javax.microedition.khronos.egl 194 | javax.microedition.khronos.opengles 195 | javax.net 196 | javax.net.ssl 197 | javax.security.auth 198 | javax.security.auth.callback 199 | javax.security.auth.login 200 | javax.security.auth.x500 201 | javax.security.cert 202 | javax.sql 203 | javax.xml 204 | javax.xml.datatype 205 | javax.xml.namespace 206 | javax.xml.parsers 207 | javax.xml.transform 208 | javax.xml.transform.dom 209 | javax.xml.transform.sax 210 | javax.xml.transform.stream 211 | javax.xml.validation 212 | javax.xml.xpath 213 | junit.framework 214 | junit.runner 215 | org.apache.http.conn 216 | org.apache.http.conn.scheme 217 | org.apache.http.conn.ssl 218 | org.apache.http.params 219 | org.json 220 | org.w3c.dom 221 | org.w3c.dom.ls 222 | org.xml.sax 223 | org.xml.sax.ext 224 | org.xml.sax.helpers 225 | org.xmlpull.v1 226 | org.xmlpull.v1.sax2 -------------------------------------------------------------------------------- /notebooks/2-GFeatures.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | 
"Using backend: pytorch\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import re\n", 18 | "import dgl\n", 19 | "import torch\n", 20 | "from torch.utils.data import DataLoader\n", 21 | "\n", 22 | "import torch.nn as nn\n", 23 | "import torch.nn.functional as F\n", 24 | "import networkx as nx\n", 25 | "\n", 26 | "from pathlib import Path\n", 27 | "from androguard.misc import AnalyzeAPK\n", 28 | "import pickle\n", 29 | "import pytorch_lightning as pl\n", 30 | "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", 31 | "from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n", 32 | "import sklearn.metrics as M\n", 33 | "\n", 34 | "from dgl.nn.pytorch import GraphConv, GATConv, SAGEConv\n", 35 | "from sklearn.model_selection import StratifiedShuffleSplit\n", 36 | "\n", 37 | "import joblib as J" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "#%xmode verbose" 47 | ] 48 | }, 49 | { 50 | "cell_type": "markdown", 51 | "metadata": {}, 52 | "source": [ 53 | "## Params" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": 3, 59 | "metadata": {}, 60 | "outputs": [], 61 | "source": [ 62 | "model_kwargs = {'in_dim': 15, 'hidden_dim': 30, 'n_classes': 5 }" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 4, 68 | "metadata": {}, 69 | "outputs": [], 70 | "source": [ 71 | "train = False" 72 | ] 73 | }, 74 | { 75 | "cell_type": "code", 76 | "execution_count": 5, 77 | "metadata": {}, 78 | "outputs": [], 79 | "source": [ 80 | "extract = False" 81 | ] 82 | }, 83 | { 84 | "cell_type": "markdown", 85 | "metadata": {}, 86 | "source": [ 87 | "## Dataset" 88 | ] 89 | }, 90 | { 91 | "cell_type": "code", 92 | "execution_count": 6, 93 | "metadata": {}, 94 | "outputs": [], 95 | "source": [ 96 | "def get_samples(base_path):\n", 97 | " base_path = Path(base_path)\n", 98 | " labels_dict = {x:i for i,x in enumerate(sorted([\"Adware\", \"Benigh\", \"Banking\", \"SMS\", \"Riskware\"]))}\n", 99 | " if not base_path.exists():\n", 100 | " raise Exception(f'{base_path} does not exist')\n", 101 | " apk_list = sorted([x for x in base_path.iterdir() if not x.is_dir()])\n", 102 | " samples = []\n", 103 | " labels = {}\n", 104 | " for apk in apk_list:\n", 105 | " samples.append(apk.name)\n", 106 | " labels[apk.name] = labels_dict[re.findall(r'[A-Z](?:[a-z]|[A-Z])+',apk.name)[0]]\n", 107 | " return samples, labels" 108 | ] 109 | }, 110 | { 111 | "cell_type": "code", 112 | "execution_count": 7, 113 | "metadata": {}, 114 | "outputs": [], 115 | "source": [ 116 | "samples, labels = get_samples('../data/large/raw')" 117 | ] 118 | }, 119 | { 120 | "cell_type": "code", 121 | "execution_count": 8, 122 | "metadata": {}, 123 | "outputs": [ 124 | { 125 | "data": { 126 | "text/plain": [ 127 | "'Adware0000.apk'" 128 | ] 129 | }, 130 | "execution_count": 8, 131 | "metadata": {}, 132 | "output_type": "execute_result" 133 | } 134 | ], 135 | "source": [ 136 | "samples[0]" 137 | ] 138 | }, 139 | { 140 | "cell_type": "code", 141 | "execution_count": 9, 142 | "metadata": {}, 143 | "outputs": [], 144 | "source": [ 145 | "raw_prefix = Path('../data/large/raw')\n", 146 | "processed_prefix = Path('../data/large/G-feat')" 147 | ] 148 | }, 149 | { 150 | "cell_type": "code", 151 | "execution_count": 10, 152 | "metadata": {}, 153 | "outputs": [], 154 | "source": [ 155 | "def process(file):\n", 156 | " _, _, dx = AnalyzeAPK(raw_prefix/file)\n", 157 | " cg = dx.get_call_graph()\n", 158 | " opcodes 
= {}\n", 159 | " for node in cg.nodes():\n", 160 | " sequence = [0] * 15\n", 161 | " if not node.is_external():\n", 162 | " for instr in node.get_method().get_instructions():\n", 163 | " value = instr.get_op_value()\n", 164 | " if value == 0x00: # nop\n", 165 | " sequence[0] = 1\n", 166 | " elif value >= 0x01 and value <= 0x0D: # mov\n", 167 | " sequence[1] = 1\n", 168 | " elif value >= 0x0E and value <= 0x11: # return\n", 169 | " sequence[2] = 1\n", 170 | " elif value == 0x1D or value == 0x1E: # monitor\n", 171 | " sequence[3] = 1\n", 172 | " elif value >= 0x32 and value <= 0x3D: # if\n", 173 | " sequence[4] = 1\n", 174 | " elif value == 0x27: # throw\n", 175 | " sequence[5] = 1\n", 176 | " elif value == 0x28 or value == 0x29: # goto\n", 177 | " sequence[6] = 1\n", 178 | " elif value >= 0x2F and value <= 0x31: # compare\n", 179 | " sequence[7] = 1\n", 180 | " elif value >= 0x7F and value <= 0x8F: # unop\n", 181 | " sequence[8] = 1\n", 182 | " elif value >= 0x90 and value <= 0xE2: # binop\n", 183 | " sequence[9] = 1\n", 184 | " elif value == 0x21 or (value >= 0x23 and value <= 0x26) or (value >= 0x44 and value <= 0x51): # aop\n", 185 | " sequence[10] = 1\n", 186 | " elif (value >= 0x52 and value <= 0x5F) or (value >= 0xF2 and value <= 0xF7): # instanceop\n", 187 | " sequence[11] = 1\n", 188 | " elif (value >= 0x60 and value <= 0x6D): # staticop\n", 189 | " sequence[12] = 1\n", 190 | " elif (value >= 0x6E and value <= 0x72) or (value >= 0x74 and value <= 0x78) or (value >= 0xF9 and value <= 0xFB): # invokeop\n", 191 | " sequence[13] = 1\n", 192 | " elif (value >= 0x22 and value <= 0x25): # newop (0x23-0x25 already match the aop branch)\n", 193 | " sequence[14] = 1\n", 194 | " opcodes[node] = {'sequence': sequence}\n", 195 | " nx.set_node_attributes(cg, opcodes)\n", 196 | " labels = {x: {'name': x.full_name} for x in cg.nodes()}\n", 197 | " nx.set_node_attributes(cg, labels)\n", 198 | " cg = nx.convert_node_labels_to_integers(cg)\n", 199 | " torch.save(cg, processed_prefix / (file.split('.')[0]+'.graph'))" 200 | ] 201 | }, 202 | { 203 | "cell_type": "code", 204 | "execution_count": 11, 205 | "metadata": {}, 206 | "outputs": [], 207 | "source": [ 208 | "if extract:\n", 209 | " J.Parallel(n_jobs=40)(J.delayed(process)(x) for x in samples);" 210 | ] 211 | }, 212 | { 213 | "cell_type": "code", 214 | "execution_count": 12, 215 | "metadata": {}, 216 | "outputs": [], 217 | "source": [ 218 | "class MalwareDataset(torch.utils.data.Dataset):\n", 219 | " def __init__(self, save_dir, list_IDs, labels):\n", 220 | " self.save_dir = Path(save_dir)\n", 221 | " self.list_IDs = list_IDs\n", 222 | " self.labels = labels\n", 223 | " self.cache = {}\n", 224 | "\n", 225 | " def __len__(self):\n", 226 | " 'Denotes the total number of samples'\n", 227 | " return len(self.list_IDs)\n", 228 | "\n", 229 | " def __getitem__(self, index):\n", 230 | " 'Generates one sample of data'\n", 231 | " # Select sample\n", 232 | " if index not in self.cache:\n", 233 | " ID = self.list_IDs[index]\n", 234 | " graph_path = self.save_dir / (ID.split('.')[0] + '.graph')\n", 235 | " cg = torch.load(graph_path)\n", 236 | " dg = dgl.from_networkx(cg, node_attrs=['sequence'], edge_attrs=['offset'])\n", 237 | " dg = dgl.add_self_loop(dg)\n", 238 | " self.cache[index] = (dg, self.labels[ID])\n", 239 | " return self.cache[index]" 240 | ] 241 | }, 242 | { 243 | "cell_type": "markdown", 244 | "metadata": {}, 245 | "source": [ 246 | "## Data Loading" 247 | ] 248 | }, 249 | { 250 | "cell_type": "code", 251 | "execution_count": 13, 252 | "metadata": {}, 253 | "outputs": [], 254 | "source": [ 255 | "def
split_dataset(samples, labels, ratios):\n", 256 | " if sum(ratios) != 1:\n", 257 | " raise Exception(\"Invalid ratios provided\")\n", 258 | " train_ratio, val_ratio, test_ratio = ratios\n", 259 | " sss = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=0)\n", 260 | " train_idx, test_idx = list(sss.split(samples, [labels[x] for x in samples]))[0]\n", 261 | " sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio/(1-test_ratio), random_state=0)\n", 262 | " test_list = [samples[x] for x in test_idx]\n", 263 | " train_list = [samples[x] for x in train_idx]\n", 264 | " train_idx, val_idx = list(sss.split(train_list, [labels[x] for x in train_list]))[0]\n", 265 | " train_list = [samples[x] for x in train_idx]\n", 266 | " val_list = [samples[x] for x in val_idx]\n", 267 | " return train_list, val_list, test_list" 268 | ] 269 | }, 270 | { 271 | "cell_type": "code", 272 | "execution_count": 14, 273 | "metadata": {}, 274 | "outputs": [], 275 | "source": [ 276 | "train_list, val_list, test_list = split_dataset(samples, labels, [0.6, 0.2, 0.2])" 277 | ] 278 | }, 279 | { 280 | "cell_type": "code", 281 | "execution_count": 15, 282 | "metadata": {}, 283 | "outputs": [ 284 | { 285 | "data": { 286 | "text/plain": [ 287 | "tensor([0.6000, 0.2000, 0.2000])" 288 | ] 289 | }, 290 | "execution_count": 15, 291 | "metadata": {}, 292 | "output_type": "execute_result" 293 | } 294 | ], 295 | "source": [ 296 | "torch.tensor([len(train_list), len(val_list), len(test_list)]).float()/len(samples)" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": 16, 302 | "metadata": {}, 303 | "outputs": [], 304 | "source": [ 305 | "def collate(samples):\n", 306 | " graphs, labels = [], []\n", 307 | " for graph, label in samples:\n", 308 | " graphs.append(graph)\n", 309 | " labels.append(label)\n", 310 | " batched_graph = dgl.batch(graphs)\n", 311 | " return batched_graph, torch.tensor(labels)" 312 | ] 313 | }, 314 | { 315 | "cell_type": "code", 316 | "execution_count": 17, 317 | "metadata": {}, 318 | "outputs": [], 319 | "source": [ 320 | "train_dataset = MalwareDataset(processed_prefix , train_list, labels)\n", 321 | "val_dataset = MalwareDataset(processed_prefix , val_list, labels)\n", 322 | "test_dataset = MalwareDataset(processed_prefix , test_list, labels)" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 18, 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [ 331 | "train_data = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate, num_workers=8)\n", 332 | "val_data = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate , num_workers=40)\n", 333 | "test_data = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate, num_workers=4)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 19, 339 | "metadata": {}, 340 | "outputs": [ 341 | { 342 | "data": { 343 | "text/plain": [ 344 | "0" 345 | ] 346 | }, 347 | "execution_count": 19, 348 | "metadata": {}, 349 | "output_type": "execute_result" 350 | } 351 | ], 352 | "source": [ 353 | "len(test_dataset.cache)" 354 | ] 355 | }, 356 | { 357 | "cell_type": "markdown", 358 | "metadata": {}, 359 | "source": [ 360 | "## Model" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": 20, 366 | "metadata": {}, 367 | "outputs": [], 368 | "source": [ 369 | "class MalwareClassifier(pl.LightningModule):\n", 370 | " def __init__(self, in_dim, hidden_dim, n_classes):\n", 371 | " super().__init__()\n", 372 | " 
self.conv1 = SAGEConv(in_dim, hidden_dim, aggregator_type='mean')\n", 373 | " self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggregator_type='mean')\n", 374 | " self.classify = nn.Linear(hidden_dim, n_classes)\n", 375 | " self.loss_func = nn.CrossEntropyLoss()\n", 376 | " \n", 377 | " \n", 378 | " def forward(self, g):\n", 379 | " h = g.ndata['sequence'].float()\n", 380 | " #h = torch.cat([g.ndata[x].view(-1,1).float() for x in ['public', 'entrypoint', 'external', 'native', 'codesize' ]], dim=1)\n", 381 | " # h = g.in_degrees().view(-1,1).float()\n", 382 | " # Perform graph convolution and activation function.\n", 383 | " h = F.relu(self.conv1(g, h))\n", 384 | " h = F.relu(self.conv2(g, h))\n", 385 | " g.ndata['h'] = h\n", 386 | " # Calculate graph representation by averaging all the node representations.\n", 387 | " hg = dgl.mean_nodes(g, 'h')\n", 388 | " return self.classify(hg) \n", 389 | " \n", 390 | " def training_step(self, batch, batch_idx):\n", 391 | " bg, label = batch\n", 392 | " #print(\"Outer\", len(label))\n", 393 | " prediction = self.forward(bg)\n", 394 | " loss = self.loss_func(prediction, label)\n", 395 | " return loss\n", 396 | " \n", 397 | " def validation_step(self, batch, batch_idx):\n", 398 | " bg, label = batch\n", 399 | " prediction = self.forward(bg)\n", 400 | " loss = self.loss_func(prediction, label)\n", 401 | " self.log('val_loss', loss)\n", 402 | " \n", 403 | " def configure_optimizers(self):\n", 404 | " optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n", 405 | " return optimizer" 406 | ] 407 | }, 408 | { 409 | "cell_type": "code", 410 | "execution_count": 21, 411 | "metadata": {}, 412 | "outputs": [], 413 | "source": [ 414 | "callbacks = [\n", 415 | " EarlyStopping(monitor='val_loss', patience=5, min_delta=0.01),\n", 416 | "]" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 22, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "checkpointer = ModelCheckpoint(filepath='../models/3Nov-{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', mode='min')" 426 | ] 427 | }, 428 | { 429 | "cell_type": "code", 430 | "execution_count": 23, 431 | "metadata": {}, 432 | "outputs": [ 433 | { 434 | "name": "stderr", 435 | "output_type": "stream", 436 | "text": [ 437 | "GPU available: True, used: True\n", 438 | "TPU available: False, using: 0 TPU cores\n", 439 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [2]\n" 440 | ] 441 | } 442 | ], 443 | "source": [ 444 | "classifier= MalwareClassifier(**model_kwargs)\n", 445 | "trainer = pl.Trainer(callbacks=callbacks, checkpoint_callback=checkpointer, gpus=[2])" 446 | ] 447 | }, 448 | { 449 | "cell_type": "code", 450 | "execution_count": 24, 451 | "metadata": {}, 452 | "outputs": [ 453 | { 454 | "data": { 455 | "text/plain": [ 456 | "False" 457 | ] 458 | }, 459 | "execution_count": 24, 460 | "metadata": {}, 461 | "output_type": "execute_result" 462 | } 463 | ], 464 | "source": [ 465 | "train" 466 | ] 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 25, 471 | "metadata": {}, 472 | "outputs": [], 473 | "source": [ 474 | "if train:\n", 475 | " trainer.fit(classifier, train_data, val_data)" 476 | ] 477 | }, 478 | { 479 | "cell_type": "markdown", 480 | "metadata": {}, 481 | "source": [ 482 | "## Testing " 483 | ] 484 | }, 485 | { 486 | "cell_type": "code", 487 | "execution_count": 26, 488 | "metadata": {}, 489 | "outputs": [], 490 | "source": [ 491 | "classifier_saved = MalwareClassifier.load_from_checkpoint('../models/3Nov-epoch=36-val_loss=0.51.pt.ckpt', **model_kwargs)" 492 
| ] 493 | }, 494 | { 495 | "cell_type": "code", 496 | "execution_count": 33, 497 | "metadata": {}, 498 | "outputs": [ 499 | { 500 | "data": { 501 | "text/plain": [ 502 | "tensor([[ -2.0630, 1.1638, -11.9895, 5.1457, -1.6285]])" 503 | ] 504 | }, 505 | "execution_count": 33, 506 | "metadata": {}, 507 | "output_type": "execute_result" 508 | } 509 | ], 510 | "source": [ 511 | "classifier_saved(train_dataset[0][0])" 512 | ] 513 | }, 514 | { 515 | "cell_type": "code", 516 | "execution_count": 31, 517 | "metadata": {}, 518 | "outputs": [], 519 | "source": [ 520 | "classifier_saved.freeze()" 521 | ] 522 | }, 523 | { 524 | "cell_type": "code", 525 | "execution_count": 43, 526 | "metadata": {}, 527 | "outputs": [ 528 | { 529 | "data": { 530 | "text/plain": [ 531 | "tensor([4, 2, 3, ..., 1, 3, 2])" 532 | ] 533 | }, 534 | "execution_count": 43, 535 | "metadata": {}, 536 | "output_type": "execute_result" 537 | } 538 | ], 539 | "source": [ 540 | "predicted = torch.argmax(classifier_saved(dgl.batch([g for g,l in test_dataset])),dim=1)\n", 541 | "predicted" 542 | ] 543 | }, 544 | { 545 | "cell_type": "code", 546 | "execution_count": 36, 547 | "metadata": {}, 548 | "outputs": [ 549 | { 550 | "data": { 551 | "text/plain": [ 552 | "3302" 553 | ] 554 | }, 555 | "execution_count": 36, 556 | "metadata": {}, 557 | "output_type": "execute_result" 558 | } 559 | ], 560 | "source": [ 561 | "len(test_dataset)" 562 | ] 563 | }, 564 | { 565 | "cell_type": "code", 566 | "execution_count": 37, 567 | "metadata": {}, 568 | "outputs": [ 569 | { 570 | "data": { 571 | "text/plain": [ 572 | "3302" 573 | ] 574 | }, 575 | "execution_count": 37, 576 | "metadata": {}, 577 | "output_type": "execute_result" 578 | } 579 | ], 580 | "source": [ 581 | "len(test_dataset.cache)" 582 | ] 583 | }, 584 | { 585 | "cell_type": "code", 586 | "execution_count": 44, 587 | "metadata": {}, 588 | "outputs": [ 589 | { 590 | "data": { 591 | "text/plain": [ 592 | "tensor([4, 2, 3, ..., 3, 3, 2])" 593 | ] 594 | }, 595 | "execution_count": 44, 596 | "metadata": {}, 597 | "output_type": "execute_result" 598 | } 599 | ], 600 | "source": [ 601 | "actual = torch.tensor([l for g,l in test_dataset])\n", 602 | "actual" 603 | ] 604 | }, 605 | { 606 | "cell_type": "code", 607 | "execution_count": 45, 608 | "metadata": {}, 609 | "outputs": [ 610 | { 611 | "name": "stdout", 612 | "output_type": "stream", 613 | "text": [ 614 | " precision recall f1-score support\n", 615 | "\n", 616 | " 0 0.8911 0.7318 0.8036 302\n", 617 | " 1 0.6159 0.8107 0.7000 449\n", 618 | " 2 0.9124 0.9282 0.9202 808\n", 619 | " 3 0.8524 0.7856 0.8176 779\n", 620 | " 4 0.9707 0.9295 0.9497 964\n", 621 | "\n", 622 | " accuracy 0.8610 3302\n", 623 | " macro avg 0.8485 0.8372 0.8382 3302\n", 624 | "weighted avg 0.8730 0.8610 0.8640 3302\n", 625 | "\n" 626 | ] 627 | } 628 | ], 629 | "source": [ 630 | "print(M.classification_report(actual, predicted, digits=4))" 631 | ] 632 | }, 633 | { 634 | "cell_type": "code", 635 | "execution_count": 46, 636 | "metadata": {}, 637 | "outputs": [ 638 | { 639 | "data": { 640 | "text/plain": [ 641 | "array([[221, 38, 5, 37, 1],\n", 642 | " [ 5, 364, 19, 39, 22],\n", 643 | " [ 5, 29, 750, 23, 1],\n", 644 | " [ 10, 106, 48, 612, 3],\n", 645 | " [ 7, 54, 0, 7, 896]])" 646 | ] 647 | }, 648 | "execution_count": 46, 649 | "metadata": {}, 650 | "output_type": "execute_result" 651 | } 652 | ], 653 | "source": [ 654 | "M.confusion_matrix(actual, predicted)" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": null, 660 | "metadata": {}, 661 | 
"outputs": [], 662 | "source": [ 663 | "\"Adware\", \"Benigh\", \"Banking\", \"SMS\", \"Riskware\"" 664 | ] 665 | }, 666 | { 667 | "cell_type": "markdown", 668 | "metadata": {}, 669 | "source": [ 670 | "## Results\n", 671 | "Accuracy - 86.10%,\n", 672 | "Precision - 0.8485,\n", 673 | "Recall - 0.8372,\n", 674 | "F1 - 0.8382" 675 | ] 676 | } 677 | ], 678 | "metadata": { 679 | "kernelspec": { 680 | "display_name": "Python 3", 681 | "language": "python", 682 | "name": "python3" 683 | }, 684 | "language_info": { 685 | "codemirror_mode": { 686 | "name": "ipython", 687 | "version": 3 688 | }, 689 | "file_extension": ".py", 690 | "mimetype": "text/x-python", 691 | "name": "python", 692 | "nbconvert_exporter": "python", 693 | "pygments_lexer": "ipython3", 694 | "version": "3.6.9" 695 | } 696 | }, 697 | "nbformat": 4, 698 | "nbformat_minor": 4 699 | } 700 | -------------------------------------------------------------------------------- /notebooks/3-APIFeatures.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": {}, 7 | "outputs": [ 8 | { 9 | "name": "stderr", 10 | "output_type": "stream", 11 | "text": [ 12 | "Using backend: pytorch\n" 13 | ] 14 | } 15 | ], 16 | "source": [ 17 | "import re\n", 18 | "import dgl\n", 19 | "import torch\n", 20 | "from torch.utils.data import DataLoader\n", 21 | "\n", 22 | "import torch.nn as nn\n", 23 | "import torch.nn.functional as F\n", 24 | "import networkx as nx\n", 25 | "\n", 26 | "from pathlib import Path\n", 27 | "from androguard.misc import AnalyzeAPK\n", 28 | "import pickle\n", 29 | "import pytorch_lightning as pl\n", 30 | "from pytorch_lightning.callbacks.early_stopping import EarlyStopping\n", 31 | "from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint\n", 32 | "import sklearn.metrics as M\n", 33 | "\n", 34 | "from dgl.nn.pytorch import GraphConv, GATConv, SAGEConv\n", 35 | "from sklearn.model_selection import StratifiedShuffleSplit\n", 36 | "\n", 37 | "import joblib as J" 38 | ] 39 | }, 40 | { 41 | "cell_type": "code", 42 | "execution_count": 2, 43 | "metadata": {}, 44 | "outputs": [], 45 | "source": [ 46 | "def get_api_list(file):\n", 47 | " apis = open(file).readlines()\n", 48 | " return {x.strip(): i for i, x in enumerate(apis)}" 49 | ] 50 | }, 51 | { 52 | "cell_type": "code", 53 | "execution_count": 3, 54 | "metadata": {}, 55 | "outputs": [], 56 | "source": [ 57 | "api_list = get_api_list('api.list')" 58 | ] 59 | }, 60 | { 61 | "cell_type": "code", 62 | "execution_count": 4, 63 | "metadata": {}, 64 | "outputs": [ 65 | { 66 | "data": { 67 | "text/plain": [ 68 | "226" 69 | ] 70 | }, 71 | "execution_count": 4, 72 | "metadata": {}, 73 | "output_type": "execute_result" 74 | } 75 | ], 76 | "source": [ 77 | "len(api_list)" 78 | ] 79 | }, 80 | { 81 | "cell_type": "markdown", 82 | "metadata": {}, 83 | "source": [ 84 | "## Params" 85 | ] 86 | }, 87 | { 88 | "cell_type": "code", 89 | "execution_count": 5, 90 | "metadata": {}, 91 | "outputs": [], 92 | "source": [ 93 | "model_kwargs = {'in_dim': len(api_list), 'hidden_dim': 64, 'n_classes': 5 }" 94 | ] 95 | }, 96 | { 97 | "cell_type": "code", 98 | "execution_count": 6, 99 | "metadata": {}, 100 | "outputs": [], 101 | "source": [ 102 | "train = True" 103 | ] 104 | }, 105 | { 106 | "cell_type": "code", 107 | "execution_count": 7, 108 | "metadata": {}, 109 | "outputs": [], 110 | "source": [ 111 | "extract = True" 112 | ] 113 | }, 114 | { 115 | "cell_type": "markdown", 116 
| "metadata": {}, 117 | "source": [ 118 | "## Dataset" 119 | ] 120 | }, 121 | { 122 | "cell_type": "code", 123 | "execution_count": 8, 124 | "metadata": {}, 125 | "outputs": [], 126 | "source": [ 127 | "def get_samples(base_path):\n", 128 | " base_path = Path(base_path)\n", 129 | " labels_dict = {x:i for i,x in enumerate(sorted([\"Adware\", \"Benigh\", \"Banking\", \"SMS\", \"Riskware\"]))}\n", 130 | " if not base_path.exists():\n", 131 | " raise Exception(f'{base_path} does not exist')\n", 132 | " apk_list = sorted([x for x in base_path.iterdir() if not x.is_dir()])\n", 133 | " samples = []\n", 134 | " labels = {}\n", 135 | " for apk in apk_list:\n", 136 | " samples.append(apk.name)\n", 137 | " labels[apk.name] = labels_dict[re.findall(r'[A-Z](?:[a-z]|[A-Z])+',apk.name)[0]]\n", 138 | " return samples, labels" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 9, 144 | "metadata": {}, 145 | "outputs": [], 146 | "source": [ 147 | "samples, labels = get_samples('../data/large/raw')" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 10, 153 | "metadata": {}, 154 | "outputs": [], 155 | "source": [ 156 | "raw_prefix = Path('../data/large/raw')\n", 157 | "processed_prefix = Path('../data/large/APIFeatures')" 158 | ] 159 | }, 160 | { 161 | "cell_type": "code", 162 | "execution_count": 11, 163 | "metadata": {}, 164 | "outputs": [], 165 | "source": [ 166 | "def process(file):\n", 167 | " _, _, dx = AnalyzeAPK(raw_prefix/file)\n", 168 | " cg = dx.get_call_graph()\n", 169 | " mappings = {}\n", 170 | " #print(set(map(lambda x: x.full_name.split(';')[0][1:], filter(lambda x: x.is_external(), cg.nodes()))))\n", 171 | " #return\n", 172 | " for node in cg.nodes():\n", 173 | " mapping = {\"api_package\": None}\n", 174 | " if node.is_external():\n", 175 | " name = '.'.join(map(str, node.full_name.split(';')[0][1:].split('/')[:-2]))\n", 176 | " index = api_list.get(name, None)\n", 177 | " mapping[\"api_package\"] = index\n", 178 | " mappings[node] = mapping\n", 179 | " nx.set_node_attributes(cg, mappings)\n", 180 | " labels = {x: {'name': x.full_name} for x in cg.nodes()}\n", 181 | " nx.set_node_attributes(cg, labels)\n", 182 | " cg = nx.convert_node_labels_to_integers(cg)\n", 183 | " #return cg\n", 184 | " torch.save(cg, processed_prefix/ (file.split('.')[0]+'.graph'))" 185 | ] 186 | }, 187 | { 188 | "cell_type": "code", 189 | "execution_count": null, 190 | "metadata": {}, 191 | "outputs": [ 192 | { 193 | "name": "stderr", 194 | "output_type": "stream", 195 | "text": [ 196 | "/home/vinayak/.local/lib/python3.6/site-packages/joblib/externals/loky/process_executor.py:691: UserWarning: A worker stopped while some jobs were given to the executor. 
This can be caused by a too short worker timeout or by a memory leak.\n", 197 | " \"timeout or by a memory leak.\", UserWarning\n" 198 | ] 199 | } 200 | ], 201 | "source": [ 202 | "if extract:\n", 203 | " J.Parallel(n_jobs=40)(J.delayed(process)(x) for x in samples);" 204 | ] 205 | }, 206 | { 207 | "cell_type": "code", 208 | "execution_count": 15, 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "nx.get_node_attributes(torch.load('../data/large/APIFeatures/Benigh0000.graph'), \"api_package\");" 213 | ] 214 | }, 215 | { 216 | "cell_type": "code", 217 | "execution_count": 24, 218 | "metadata": {}, 219 | "outputs": [], 220 | "source": [ 221 | "class MalwareDataset(torch.utils.data.Dataset):\n", 222 | " def __init__(self, save_dir, list_IDs, labels):\n", 223 | " self.save_dir = Path(save_dir)\n", 224 | " self.list_IDs = list_IDs\n", 225 | " self.labels = labels\n", 226 | " self.cache = {}\n", 227 | "\n", 228 | " def __len__(self):\n", 229 | " 'Denotes the total number of samples'\n", 230 | " return len(self.list_IDs)\n", 231 | " \n", 232 | " def get_node_vector(self, pos):\n", 233 | " vector = torch.zeros(len(api_list))\n", 234 | " if pos is not None: # index 0 ('android') is a valid position\n", 235 | " vector[pos] = 1\n", 236 | " return vector\n", 237 | "\n", 238 | " def __getitem__(self, index):\n", 239 | " 'Generates one sample of data'\n", 240 | " # Select sample\n", 241 | " if index not in self.cache:\n", 242 | " ID = self.list_IDs[index]\n", 243 | " graph_path = self.save_dir / (ID.split('.')[0] + '.graph')\n", 244 | " cg = torch.load(graph_path)\n", 245 | " feature = {n: self.get_node_vector(pos) for n, pos in nx.get_node_attributes(cg, 'api_package').items()}\n", 246 | " nx.set_node_attributes(cg, feature, 'feature')\n", 247 | " dg = dgl.from_networkx(cg, node_attrs=['feature'])\n", 248 | " dg = dgl.add_self_loop(dg)\n", 249 | " self.cache[index] = (dg, self.labels[ID])\n", 250 | " return self.cache[index]" 251 | ] 252 | }, 253 | { 254 | "cell_type": "markdown", 255 | "metadata": {}, 256 | "source": [ 257 | "## Data Loading" 258 | ] 259 | }, 260 | { 261 | "cell_type": "code", 262 | "execution_count": 25, 263 | "metadata": {}, 264 | "outputs": [], 265 | "source": [ 266 | "def split_dataset(samples, labels, ratios):\n", 267 | " if sum(ratios) != 1:\n", 268 | " raise Exception(\"Invalid ratios provided\")\n", 269 | " train_ratio, val_ratio, test_ratio = ratios\n", 270 | " sss = StratifiedShuffleSplit(n_splits=1, test_size=test_ratio, random_state=0)\n", 271 | " train_idx, test_idx = list(sss.split(samples, [labels[x] for x in samples]))[0]\n", 272 | " sss = StratifiedShuffleSplit(n_splits=1, test_size=val_ratio/(1-test_ratio), random_state=0)\n", 273 | " test_list = [samples[x] for x in test_idx]\n", 274 | " train_list = [samples[x] for x in train_idx]\n", 275 | " train_idx, val_idx = list(sss.split(train_list, [labels[x] for x in train_list]))[0]\n", 276 | " train_list = [samples[x] for x in train_idx]\n", 277 | " val_list = [samples[x] for x in val_idx]\n", 278 | " return train_list, val_list, test_list" 279 | ] 280 | }, 281 | { 282 | "cell_type": "code", 283 | "execution_count": 26, 284 | "metadata": {}, 285 | "outputs": [], 286 | "source": [ 287 | "train_list, val_list, test_list = split_dataset(samples, labels, [0.6, 0.2, 0.2])" 288 | ] 289 | }, 290 | { 291 | "cell_type": "code", 292 | "execution_count": 27, 293 | "metadata": {}, 294 | "outputs": [ 295 | { 296 | "data": { 297 | "text/plain": [ 298 | "tensor([0.6000, 0.2000, 0.2000])" 299 | ] 300 | }, 301 | "execution_count": 27, 302 | "metadata": {}, 303 |
"output_type": "execute_result" 304 | } 305 | ], 306 | "source": [ 307 | "torch.tensor([len(train_list), len(val_list), len(test_list)]).float()/len(samples)" 308 | ] 309 | }, 310 | { 311 | "cell_type": "code", 312 | "execution_count": 28, 313 | "metadata": {}, 314 | "outputs": [], 315 | "source": [ 316 | "def collate(samples):\n", 317 | " graphs, labels = [], []\n", 318 | " for graph, label in samples:\n", 319 | " graphs.append(graph)\n", 320 | " labels.append(label)\n", 321 | " batched_graph = dgl.batch(graphs)\n", 322 | " return batched_graph, torch.tensor(labels)" 323 | ] 324 | }, 325 | { 326 | "cell_type": "code", 327 | "execution_count": 29, 328 | "metadata": {}, 329 | "outputs": [], 330 | "source": [ 331 | "train_dataset = MalwareDataset(processed_prefix , train_list, labels)\n", 332 | "val_dataset = MalwareDataset(processed_prefix , val_list, labels)\n", 333 | "test_dataset = MalwareDataset(processed_prefix , test_list, labels)" 334 | ] 335 | }, 336 | { 337 | "cell_type": "code", 338 | "execution_count": 30, 339 | "metadata": {}, 340 | "outputs": [ 341 | { 342 | "data": { 343 | "text/plain": [ 344 | "(0, 0, 0)" 345 | ] 346 | }, 347 | "execution_count": 30, 348 | "metadata": {}, 349 | "output_type": "execute_result" 350 | } 351 | ], 352 | "source": [ 353 | "len(train_dataset.cache), len(val_dataset.cache), len(test_dataset.cache)" 354 | ] 355 | }, 356 | { 357 | "cell_type": "code", 358 | "execution_count": 44, 359 | "metadata": {}, 360 | "outputs": [], 361 | "source": [ 362 | "train_data = DataLoader(train_dataset, batch_size=8, shuffle=True, collate_fn=collate, num_workers=8)\n", 363 | "val_data = DataLoader(val_dataset, batch_size=8, shuffle=False, collate_fn=collate , num_workers=40)\n", 364 | "test_data = DataLoader(test_dataset, batch_size=8, shuffle=False, collate_fn=collate, num_workers=4)" 365 | ] 366 | }, 367 | { 368 | "cell_type": "markdown", 369 | "metadata": {}, 370 | "source": [ 371 | "## Model" 372 | ] 373 | }, 374 | { 375 | "cell_type": "code", 376 | "execution_count": 37, 377 | "metadata": {}, 378 | "outputs": [], 379 | "source": [ 380 | "class MalwareClassifier(pl.LightningModule):\n", 381 | " def __init__(self, in_dim, hidden_dim, n_classes):\n", 382 | " super().__init__()\n", 383 | " self.conv1 = SAGEConv(in_dim, hidden_dim, aggregator_type='mean')\n", 384 | " self.conv2 = SAGEConv(hidden_dim, hidden_dim, aggregator_type='mean')\n", 385 | " self.classify = nn.Linear(hidden_dim, n_classes)\n", 386 | " self.loss_func = nn.CrossEntropyLoss()\n", 387 | " \n", 388 | " \n", 389 | " def forward(self, g):\n", 390 | " h = g.ndata['feature']\n", 391 | " #h = torch.cat([g.ndata[x].view(-1,1).float() for x in ['public', 'entrypoint', 'external', 'native', 'codesize' ]], dim=1)\n", 392 | " # h = g.in_degrees().view(-1,1).float()\n", 393 | " # Perform graph convolution and activation function.\n", 394 | " h = F.relu(self.conv1(g, h))\n", 395 | " h = F.relu(self.conv2(g, h))\n", 396 | " g.ndata['h'] = h\n", 397 | " # Calculate graph representation by averaging all the node representations.\n", 398 | " hg = dgl.sum_nodes(g, 'h')\n", 399 | " return self.classify(hg) \n", 400 | " \n", 401 | " def training_step(self, batch, batch_idx):\n", 402 | " bg, label = batch\n", 403 | " #print(\"Outer\", len(label))\n", 404 | " prediction = self.forward(bg)\n", 405 | " loss = self.loss_func(prediction, label)\n", 406 | " return loss\n", 407 | " \n", 408 | " def validation_step(self, batch, batch_idx):\n", 409 | " bg, label = batch\n", 410 | " prediction = self.forward(bg)\n", 411 | " loss = 
self.loss_func(prediction, label)\n", 412 | " self.log('val_loss', loss)\n", 413 | " \n", 414 | " def configure_optimizers(self):\n", 415 | " optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n", 416 | " return optimizer" 417 | ] 418 | }, 419 | { 420 | "cell_type": "code", 421 | "execution_count": 38, 422 | "metadata": {}, 423 | "outputs": [], 424 | "source": [ 425 | "callbacks = [\n", 426 | " EarlyStopping(monitor='val_loss', patience=5, min_delta=0.01),\n", 427 | "]" 428 | ] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": 39, 433 | "metadata": {}, 434 | "outputs": [], 435 | "source": [ 436 | "checkpointer = ModelCheckpoint(filepath='../models/10Nov-{epoch:02d}-{val_loss:.2f}.pt', monitor='val_loss', mode='min', save_top_k=3)" 437 | ] 438 | }, 439 | { 440 | "cell_type": "code", 441 | "execution_count": 40, 442 | "metadata": {}, 443 | "outputs": [ 444 | { 445 | "name": "stderr", 446 | "output_type": "stream", 447 | "text": [ 448 | "GPU available: True, used: True\n", 449 | "TPU available: False, using: 0 TPU cores\n", 450 | "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [3]\n" 451 | ] 452 | } 453 | ], 454 | "source": [ 455 | "classifier= MalwareClassifier(**model_kwargs)\n", 456 | "trainer = pl.Trainer(callbacks=callbacks, checkpoint_callback=checkpointer, gpus=[3])" 457 | ] 458 | }, 459 | { 460 | "cell_type": "code", 461 | "execution_count": 53, 462 | "metadata": {}, 463 | "outputs": [ 464 | { 465 | "data": { 466 | "text/plain": [ 467 | "True" 468 | ] 469 | }, 470 | "execution_count": 53, 471 | "metadata": {}, 472 | "output_type": "execute_result" 473 | } 474 | ], 475 | "source": [ 476 | "train" 477 | ] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": null, 482 | "metadata": {}, 483 | "outputs": [], 484 | "source": [ 485 | "iter(train_data).next()" 486 | ] 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 59, 491 | "metadata": {}, 492 | "outputs": [ 493 | { 494 | "data": { 495 | "text/plain": [ 496 | "8" 497 | ] 498 | }, 499 | "execution_count": 59, 500 | "metadata": {}, 501 | "output_type": "execute_result" 502 | } 503 | ], 504 | "source": [ 505 | "len(train_dataset.cache)" 506 | ] 507 | }, 508 | { 509 | "cell_type": "code", 510 | "execution_count": null, 511 | "metadata": {}, 512 | "outputs": [ 513 | { 514 | "name": "stderr", 515 | "output_type": "stream", 516 | "text": [ 517 | "\n", 518 | " | Name | Type | Params\n", 519 | "-----------------------------------------------\n", 520 | "0 | conv1 | SAGEConv | 29 K \n", 521 | "1 | conv2 | SAGEConv | 8 K \n", 522 | "2 | classify | Linear | 325 \n", 523 | "3 | loss_func | CrossEntropyLoss | 0 \n" 524 | ] 525 | }, 526 | { 527 | "data": { 528 | "application/vnd.jupyter.widget-view+json": { 529 | "model_id": "", 530 | "version_major": 2, 531 | "version_minor": 0 532 | }, 533 | "text/plain": [ 534 | "HBox(children=(HTML(value='Validation sanity check'), FloatProgress(value=1.0, bar_style='info', layout=Layout…" 535 | ] 536 | }, 537 | "metadata": {}, 538 | "output_type": "display_data" 539 | }, 540 | { 541 | "data": { 542 | "application/vnd.jupyter.widget-view+json": { 543 | "model_id": "c0db099f0d424271bd5198e1442156d7", 544 | "version_major": 2, 545 | "version_minor": 0 546 | }, 547 | "text/plain": [ 548 | "HBox(children=(HTML(value='Training'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max…" 549 | ] 550 | }, 551 | "metadata": {}, 552 | "output_type": "display_data" 553 | }, 554 | { 555 | "data": { 556 | "application/vnd.jupyter.widget-view+json": { 557 | 
"model_id": "dd1593bce733449095083d8ac4d4201b", 558 | "version_major": 2, 559 | "version_minor": 0 560 | }, 561 | "text/plain": [ 562 | "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" 563 | ] 564 | }, 565 | "metadata": {}, 566 | "output_type": "display_data" 567 | }, 568 | { 569 | "name": "stderr", 570 | "output_type": "stream", 571 | "text": [ 572 | "IOPub message rate exceeded.\n", 573 | "The notebook server will temporarily stop sending output\n", 574 | "to the client in order to avoid crashing it.\n", 575 | "To change this limit, set the config variable\n", 576 | "`--NotebookApp.iopub_msg_rate_limit`.\n", 577 | "\n", 578 | "Current values:\n", 579 | "NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)\n", 580 | "NotebookApp.rate_limit_window=3.0 (secs)\n", 581 | "\n" 582 | ] 583 | }, 584 | { 585 | "data": { 586 | "application/vnd.jupyter.widget-view+json": { 587 | "model_id": "3feb31f9ddda4b1aab25aba9295c0a69", 588 | "version_major": 2, 589 | "version_minor": 0 590 | }, 591 | "text/plain": [ 592 | "HBox(children=(HTML(value='Validating'), FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), m…" 593 | ] 594 | }, 595 | "metadata": {}, 596 | "output_type": "display_data" 597 | } 598 | ], 599 | "source": [ 600 | "if train:\n", 601 | " trainer.fit(classifier, train_data, val_data)" 602 | ] 603 | }, 604 | { 605 | "cell_type": "code", 606 | "execution_count": null, 607 | "metadata": {}, 608 | "outputs": [], 609 | "source": [] 610 | }, 611 | { 612 | "cell_type": "markdown", 613 | "metadata": {}, 614 | "source": [ 615 | "## Testing " 616 | ] 617 | }, 618 | { 619 | "cell_type": "code", 620 | "execution_count": 54, 621 | "metadata": {}, 622 | "outputs": [], 623 | "source": [ 624 | "classifier_saved = MalwareClassifier.load_from_checkpoint('../models/10Nov-epoch=15-val_loss=0.67.pt.ckpt', **model_kwargs)" 625 | ] 626 | }, 627 | { 628 | "cell_type": "code", 629 | "execution_count": null, 630 | "metadata": {}, 631 | "outputs": [], 632 | "source": [ 633 | "predicted = torch.argmax(classifier(dgl.batch([g for g,l in test_dataset])),dim=1)\n", 634 | "predicted" 635 | ] 636 | }, 637 | { 638 | "cell_type": "code", 639 | "execution_count": 60, 640 | "metadata": {}, 641 | "outputs": [ 642 | { 643 | "data": { 644 | "text/plain": [ 645 | "tensor([4, 2, 3, ..., 1, 4, 2])" 646 | ] 647 | }, 648 | "execution_count": 60, 649 | "metadata": {}, 650 | "output_type": "execute_result" 651 | } 652 | ], 653 | "source": [ 654 | "predicted" 655 | ] 656 | }, 657 | { 658 | "cell_type": "code", 659 | "execution_count": 61, 660 | "metadata": {}, 661 | "outputs": [ 662 | { 663 | "data": { 664 | "text/plain": [ 665 | "tensor([4, 2, 3, ..., 3, 3, 2])" 666 | ] 667 | }, 668 | "execution_count": 61, 669 | "metadata": {}, 670 | "output_type": "execute_result" 671 | } 672 | ], 673 | "source": [ 674 | "actual = torch.tensor([l for g,l in test_dataset])\n", 675 | "actual" 676 | ] 677 | }, 678 | { 679 | "cell_type": "code", 680 | "execution_count": 62, 681 | "metadata": {}, 682 | "outputs": [ 683 | { 684 | "name": "stdout", 685 | "output_type": "stream", 686 | "text": [ 687 | " precision recall f1-score support\n", 688 | "\n", 689 | " 0 0.9139 0.8079 0.8576 302\n", 690 | " 1 0.7181 0.6526 0.6838 449\n", 691 | " 2 0.9272 0.8824 0.9042 808\n", 692 | " 3 0.7611 0.7728 0.7669 779\n", 693 | " 4 0.8641 0.9564 0.9079 964\n", 694 | "\n", 695 | " accuracy 0.8401 3302\n", 696 | " macro avg 0.8369 0.8144 0.8241 3302\n", 697 | "weighted avg 0.8399 0.8401 0.8387 3302\n", 
698 | "\n" 699 | ] 700 | } 701 | ], 702 | "source": [ 703 | "print(M.classification_report(actual, predicted, digits=4))" 704 | ] 705 | }, 706 | { 707 | "cell_type": "code", 708 | "execution_count": 63, 709 | "metadata": {}, 710 | "outputs": [ 711 | { 712 | "data": { 713 | "text/plain": [ 714 | "array([[244, 15, 2, 37, 4],\n", 715 | " [ 7, 293, 21, 70, 58],\n", 716 | " [ 5, 25, 713, 60, 5],\n", 717 | " [ 11, 55, 33, 602, 78],\n", 718 | " [ 0, 20, 0, 22, 922]])" 719 | ] 720 | }, 721 | "execution_count": 63, 722 | "metadata": {}, 723 | "output_type": "execute_result" 724 | } 725 | ], 726 | "source": [ 727 | "M.confusion_matrix(actual, predicted)" 728 | ] 729 | }, 730 | { 731 | "cell_type": "code", 732 | "execution_count": 64, 733 | "metadata": {}, 734 | "outputs": [ 735 | { 736 | "data": { 737 | "text/plain": [ 738 | "['Adware', 'Banking', 'Benigh', 'Riskware', 'SMS']" 739 | ] 740 | }, 741 | "execution_count": 64, 742 | "metadata": {}, 743 | "output_type": "execute_result" 744 | } 745 | ], 746 | "source": [ 747 | "sorted([\"Adware\", \"Benigh\", \"Banking\", \"SMS\", \"Riskware\"])" 748 | ] 749 | }, 750 | { 751 | "cell_type": "code", 752 | "execution_count": 75, 753 | "metadata": {}, 754 | "outputs": [ 755 | { 756 | "data": { 757 | "text/plain": [ 758 | "tensor(2)" 759 | ] 760 | }, 761 | "execution_count": 75, 762 | "metadata": {}, 763 | "output_type": "execute_result" 764 | } 765 | ], 766 | "source": [ 767 | "predicted[15]" 768 | ] 769 | }, 770 | { 771 | "cell_type": "code", 772 | "execution_count": 73, 773 | "metadata": {}, 774 | "outputs": [ 775 | { 776 | "data": { 777 | "text/plain": [ 778 | "tensor([ 8, 15, 20, 25, 50, 56, 57, 58, 62, 67, 72, 73,\n", 779 | " 74, 75, 76, 96, 97, 99, 100, 114, 120, 121, 123, 125,\n", 780 | " 126, 131, 138, 140, 143, 158, 159, 186, 187, 211, 212, 221,\n", 781 | " 233, 239, 241, 263, 271, 281, 284, 297, 298, 305, 306, 309,\n", 782 | " 310, 313, 316, 318, 321, 329, 333, 335, 342, 355, 359, 360,\n", 783 | " 368, 370, 388, 390, 392, 396, 407, 411, 412, 426, 429, 434,\n", 784 | " 435, 442, 450, 453, 456, 459, 467, 469, 470, 471, 475, 476,\n", 785 | " 498, 505, 512, 522, 526, 528, 541, 546, 552, 553, 555, 557,\n", 786 | " 560, 567, 573, 577, 578, 579, 589, 602, 616, 618, 619, 625,\n", 787 | " 632, 642, 652, 653, 656, 658, 663, 664, 666, 672, 677, 682,\n", 788 | " 685, 690, 705, 708, 714, 728, 733, 734, 735, 744, 747, 775,\n", 789 | " 780, 781, 785, 798, 800, 802, 804, 811, 820, 825, 829, 830,\n", 790 | " 833, 838, 842, 846, 865, 870, 876, 881, 902, 903, 905, 909,\n", 791 | " 918, 928, 934, 949, 952, 960, 967, 971, 975, 976, 985, 1016,\n", 792 | " 1020, 1021, 1023, 1028, 1029, 1054, 1070, 1091, 1095, 1097, 1107, 1108,\n", 793 | " 1114, 1119, 1120, 1135, 1136, 1174, 1195, 1199, 1205, 1207, 1222, 1223,\n", 794 | " 1224, 1229, 1233, 1236, 1252, 1265, 1268, 1269, 1275, 1278, 1282, 1284,\n", 795 | " 1294, 1299, 1300, 1303, 1305, 1316, 1317, 1318, 1328, 1342, 1343, 1358,\n", 796 | " 1366, 1377, 1386, 1387, 1389, 1392, 1395, 1412, 1414, 1415, 1420, 1424,\n", 797 | " 1425, 1427, 1433, 1434, 1439, 1447, 1448, 1451, 1457, 1461, 1481, 1501,\n", 798 | " 1507, 1510, 1512, 1521, 1528, 1556, 1562, 1563, 1565, 1566, 1567, 1569,\n", 799 | " 1570, 1575, 1584, 1589, 1592, 1601, 1607, 1608, 1612, 1613, 1621, 1622,\n", 800 | " 1624, 1630, 1638, 1653, 1659, 1664, 1682, 1684, 1690, 1691, 1693, 1694,\n", 801 | " 1715, 1717, 1718, 1745, 1750, 1766, 1770, 1771, 1793, 1795, 1811, 1817,\n", 802 | " 1827, 1849, 1854, 1878, 1882, 1886, 1888, 1896, 1898, 1901, 1913, 1924,\n", 803 | " 1928, 
1930, 1935, 1936, 1950, 1956, 1957, 1970, 1975, 1984, 1985, 1992,\n", 804 | " 1994, 2005, 2012, 2021, 2024, 2035, 2041, 2045, 2052, 2056, 2057, 2058,\n", 805 | " 2065, 2067, 2068, 2074, 2079, 2091, 2095, 2099, 2102, 2104, 2105, 2107,\n", 806 | " 2115, 2116, 2117, 2122, 2132, 2133, 2155, 2162, 2163, 2170, 2179, 2180,\n", 807 | " 2183, 2185, 2186, 2189, 2200, 2203, 2206, 2208, 2216, 2221, 2237, 2241,\n", 808 | " 2247, 2248, 2255, 2256, 2262, 2267, 2268, 2275, 2276, 2279, 2300, 2301,\n", 809 | " 2304, 2307, 2322, 2328, 2337, 2346, 2349, 2353, 2363, 2365, 2375, 2377,\n", 810 | " 2381, 2389, 2394, 2397, 2419, 2430, 2436, 2448, 2449, 2454, 2457, 2475,\n", 811 | " 2479, 2485, 2487, 2490, 2515, 2519, 2528, 2534, 2535, 2541, 2542, 2543,\n", 812 | " 2544, 2549, 2553, 2558, 2576, 2581, 2587, 2596, 2598, 2609, 2610, 2611,\n", 813 | " 2616, 2617, 2619, 2628, 2638, 2640, 2642, 2654, 2662, 2665, 2675, 2677,\n", 814 | " 2681, 2688, 2692, 2693, 2694, 2698, 2704, 2727, 2731, 2744, 2766, 2778,\n", 815 | " 2780, 2782, 2787, 2791, 2794, 2796, 2803, 2813, 2814, 2822, 2861, 2887,\n", 816 | " 2888, 2898, 2908, 2914, 2916, 2922, 2923, 2942, 2951, 2953, 2958, 2960,\n", 817 | " 2970, 2990, 2991, 2993, 2998, 3000, 3001, 3008, 3011, 3012, 3027, 3030,\n", 818 | " 3038, 3046, 3048, 3049, 3050, 3060, 3070, 3072, 3081, 3089, 3090, 3093,\n", 819 | " 3095, 3096, 3122, 3128, 3129, 3130, 3140, 3142, 3146, 3150, 3151, 3176,\n", 820 | " 3178, 3180, 3189, 3204, 3210, 3218, 3228, 3237, 3240, 3241, 3242, 3247,\n", 821 | " 3251, 3254, 3260, 3261, 3265, 3267, 3273, 3279, 3291, 3295, 3299, 3300])" 822 | ] 823 | }, 824 | "execution_count": 73, 825 | "metadata": {}, 826 | "output_type": "execute_result" 827 | } 828 | ], 829 | "source": [ 830 | "torch.where(actual!=predicted)[0]" 831 | ] 832 | }, 833 | { 834 | "cell_type": "code", 835 | "execution_count": 70, 836 | "metadata": {}, 837 | "outputs": [], 838 | "source": [ 839 | "import numpy as np" 840 | ] 841 | }, 842 | { 843 | "cell_type": "code", 844 | "execution_count": 71, 845 | "metadata": {}, 846 | "outputs": [], 847 | "source": [ 848 | "test_list_np = np.array(test_list)" 849 | ] 850 | }, 851 | { 852 | "cell_type": "code", 853 | "execution_count": 72, 854 | "metadata": {}, 855 | "outputs": [ 856 | { 857 | "data": { 858 | "text/plain": [ 859 | "array(['Adware0767.apk', 'Banking1835.apk', 'Riskware0594.apk',\n", 860 | " 'Banking0851.apk', 'Riskware4216.apk', 'Riskware2248.apk',\n", 861 | " 'Banking0452.apk', 'Benigh0310.apk', 'Banking1730.apk',\n", 862 | " 'Banking0032.apk', 'Benigh0353.apk', 'Banking1341.apk',\n", 863 | " 'Banking0427.apk', 'Benigh2838.apk', 'Banking1769.apk',\n", 864 | " 'Banking1240.apk', 'Riskware4175.apk', 'Riskware1607.apk',\n", 865 | " 'Riskware0891.apk', 'Riskware2847.apk', 'Riskware4066.apk',\n", 866 | " 'Banking1974.apk', 'Benigh0736.apk', 'Riskware1756.apk',\n", 867 | " 'Adware0635.apk', 'Adware0093.apk', 'Banking0844.apk',\n", 868 | " 'Riskware3207.apk', 'SMS2836.apk', 'Benigh0740.apk',\n", 869 | " 'Banking1832.apk', 'Riskware2993.apk', 'Banking0169.apk',\n", 870 | " 'Riskware0301.apk', 'Banking0434.apk', 'Riskware1234.apk',\n", 871 | " 'Riskware3747.apk', 'Banking0577.apk', 'Banking1982.apk',\n", 872 | " 'Benigh1585.apk', 'Banking0099.apk', 'Banking0505.apk',\n", 873 | " 'Banking2505.apk', 'Banking0772.apk', 'Benigh3860.apk',\n", 874 | " 'Riskware2045.apk', 'Riskware3325.apk', 'Riskware1556.apk',\n", 875 | " 'Banking1570.apk', 'SMS3149.apk', 'Banking0966.apk',\n", 876 | " 'Riskware1269.apk', 'Benigh3483.apk', 'Banking2474.apk',\n", 877 | " 
'Riskware1450.apk', 'Benigh2712.apk', 'Riskware0937.apk',\n", 878 | " 'SMS3697.apk', 'Banking2479.apk', 'Banking1733.apk',\n", 879 | " 'Adware1157.apk', 'SMS0266.apk', 'Adware0256.apk',\n", 880 | " 'Benigh1095.apk', 'Banking1418.apk', 'Riskware0442.apk',\n", 881 | " 'Riskware0396.apk', 'Riskware4114.apk', 'Riskware3205.apk',\n", 882 | " 'Adware1164.apk', 'Banking1098.apk', 'Banking2064.apk',\n", 883 | " 'Benigh0037.apk', 'Benigh3475.apk', 'Riskware3470.apk',\n", 884 | " 'Banking1285.apk', 'Banking2057.apk', 'Banking2073.apk',\n", 885 | " 'Banking1577.apk', 'Benigh1864.apk', 'Benigh0792.apk',\n", 886 | " 'Riskware1681.apk', 'Benigh1971.apk', 'Banking1405.apk',\n", 887 | " 'Riskware2764.apk', 'Riskware4039.apk', 'Banking0555.apk',\n", 888 | " 'Benigh2512.apk', 'Riskware0749.apk', 'SMS1263.apk',\n", 889 | " 'Riskware3435.apk', 'SMS2480.apk', 'Benigh0774.apk',\n", 890 | " 'Banking1152.apk', 'Riskware3737.apk', 'Benigh1511.apk',\n", 891 | " 'Banking2107.apk', 'Banking0744.apk', 'SMS1882.apk',\n", 892 | " 'Adware0851.apk', 'Benigh2390.apk', 'Riskware1853.apk',\n", 893 | " 'Benigh1048.apk', 'Benigh2096.apk', 'Riskware3455.apk',\n", 894 | " 'Benigh1203.apk', 'Banking2314.apk', 'SMS4618.apk',\n", 895 | " 'Banking0821.apk', 'Riskware3399.apk', 'Riskware2527.apk',\n", 896 | " 'Riskware3785.apk', 'Riskware2920.apk', 'Riskware3040.apk',\n", 897 | " 'Banking1558.apk', 'Riskware3602.apk', 'Benigh3157.apk',\n", 898 | " 'Benigh0566.apk', 'Riskware3317.apk', 'Adware0340.apk',\n", 899 | " 'Adware0703.apk', 'Banking2357.apk', 'Benigh2381.apk',\n", 900 | " 'Riskware3976.apk', 'SMS1072.apk', 'Banking0593.apk',\n", 901 | " 'Benigh1727.apk', 'Benigh1766.apk', 'Riskware0696.apk',\n", 902 | " 'Adware1217.apk', 'Benigh1855.apk', 'Benigh3369.apk',\n", 903 | " 'Riskware2471.apk', 'Benigh3685.apk', 'Banking0333.apk',\n", 904 | " 'Banking0718.apk', 'SMS0566.apk', 'Banking1504.apk',\n", 905 | " 'Riskware2216.apk', 'Banking0119.apk', 'Benigh3598.apk',\n", 906 | " 'Riskware1688.apk', 'Riskware3894.apk', 'Adware0449.apk',\n", 907 | " 'Banking0185.apk', 'Riskware2190.apk', 'Banking2200.apk',\n", 908 | " 'Banking1395.apk', 'Riskware3401.apk', 'Riskware3623.apk',\n", 909 | " 'Riskware4072.apk', 'Riskware0085.apk', 'Banking1550.apk',\n", 910 | " 'Adware0805.apk', 'Banking0309.apk', 'Riskware1574.apk',\n", 911 | " 'Banking1301.apk', 'Banking2225.apk', 'Banking1145.apk',\n", 912 | " 'Benigh3575.apk', 'Banking2331.apk', 'Benigh1259.apk',\n", 913 | " 'Banking1732.apk', 'Adware0431.apk', 'Adware0546.apk',\n", 914 | " 'Banking0839.apk', 'Banking1415.apk', 'Banking2355.apk',\n", 915 | " 'Adware0291.apk', 'Riskware3841.apk', 'Banking2252.apk',\n", 916 | " 'Banking1643.apk', 'Riskware4172.apk', 'Riskware0005.apk',\n", 917 | " 'Banking0339.apk', 'Riskware0950.apk', 'Riskware1184.apk',\n", 918 | " 'Banking2061.apk', 'Riskware2878.apk', 'Riskware3763.apk',\n", 919 | " 'Riskware2391.apk', 'Banking0042.apk', 'Banking1054.apk',\n", 920 | " 'Adware0526.apk', 'Benigh2876.apk', 'Banking1770.apk',\n", 921 | " 'Riskware0910.apk', 'SMS0311.apk', 'Banking2421.apk',\n", 922 | " 'Riskware0791.apk', 'Banking1325.apk', 'Adware0613.apk',\n", 923 | " 'Benigh0525.apk', 'Riskware0864.apk', 'Banking1290.apk',\n", 924 | " 'Banking1104.apk', 'Benigh2000.apk', 'Banking1738.apk',\n", 925 | " 'Riskware3268.apk', 'Benigh2831.apk', 'Riskware2095.apk',\n", 926 | " 'Benigh1315.apk', 'Benigh1070.apk', 'Riskware0519.apk',\n", 927 | " 'Banking1399.apk', 'Adware0509.apk', 'Banking2300.apk',\n", 928 | " 'Banking2376.apk', 'Riskware3438.apk', 'Riskware0450.apk',\n", 929 | 
" 'Riskware0347.apk', 'Riskware1758.apk', 'Riskware4079.apk',\n", 930 | " 'SMS1509.apk', 'Riskware2578.apk', 'SMS4525.apk', 'Adware1366.apk',\n", 931 | " 'Riskware2598.apk', 'Riskware3295.apk', 'Riskware2549.apk',\n", 932 | " 'Riskware0913.apk', 'SMS0928.apk', 'Banking1671.apk',\n", 933 | " 'Adware0861.apk', 'SMS2605.apk', 'Banking0365.apk',\n", 934 | " 'Adware0290.apk', 'Riskware1517.apk', 'SMS1008.apk',\n", 935 | " 'Banking0857.apk', 'SMS1574.apk', 'Benigh2745.apk',\n", 936 | " 'Adware0287.apk', 'Riskware0753.apk', 'Benigh2260.apk',\n", 937 | " 'Banking0488.apk', 'Riskware3329.apk', 'Banking1814.apk',\n", 938 | " 'Riskware4267.apk', 'Banking1529.apk', 'Riskware3121.apk',\n", 939 | " 'Adware0878.apk', 'Banking2173.apk', 'Riskware2561.apk',\n", 940 | " 'Riskware2245.apk', 'SMS0386.apk', 'Benigh1118.apk',\n", 941 | " 'Riskware2474.apk', 'Adware1108.apk', 'SMS4221.apk',\n", 942 | " 'Banking2434.apk', 'Adware1492.apk', 'Riskware3423.apk',\n", 943 | " 'Banking1538.apk', 'SMS1047.apk', 'Benigh3656.apk', 'SMS1980.apk',\n", 944 | " 'Benigh3011.apk', 'Benigh0128.apk', 'Riskware2808.apk',\n", 945 | " 'Banking0453.apk', 'Riskware2866.apk', 'Banking1628.apk',\n", 946 | " 'Benigh1507.apk', 'Riskware1915.apk', 'Banking0270.apk',\n", 947 | " 'Benigh1050.apk', 'Adware0837.apk', 'SMS2513.apk',\n", 948 | " 'Benigh2576.apk', 'Benigh1782.apk', 'Banking1030.apk',\n", 949 | " 'Benigh0981.apk', 'Banking1356.apk', 'Banking0020.apk',\n", 950 | " 'Adware0553.apk', 'Benigh0868.apk', 'Benigh1352.apk',\n", 951 | " 'Riskware2569.apk', 'Benigh1016.apk', 'SMS0983.apk',\n", 952 | " 'Riskware2127.apk', 'Benigh0087.apk', 'Benigh0073.apk',\n", 953 | " 'Riskware3855.apk', 'Benigh0844.apk', 'Riskware2247.apk',\n", 954 | " 'Adware1463.apk', 'Benigh0522.apk', 'Riskware0503.apk',\n", 955 | " 'Riskware2861.apk', 'Riskware1998.apk', 'SMS4184.apk',\n", 956 | " 'SMS2380.apk', 'Riskware1063.apk', 'Banking1420.apk',\n", 957 | " 'Benigh2884.apk', 'Banking2444.apk', 'Riskware0788.apk',\n", 958 | " 'Benigh0022.apk', 'SMS3378.apk', 'Banking1500.apk',\n", 959 | " 'Benigh3259.apk', 'Riskware0682.apk', 'Benigh2313.apk',\n", 960 | " 'Benigh1027.apk', 'Riskware3974.apk', 'Banking0336.apk',\n", 961 | " 'Banking0247.apk', 'Adware0625.apk', 'SMS3931.apk',\n", 962 | " 'Riskware4116.apk', 'Riskware2706.apk', 'Adware0244.apk',\n", 963 | " 'Banking1846.apk', 'Adware1393.apk', 'Benigh0922.apk',\n", 964 | " 'Banking2282.apk', 'Adware0975.apk', 'Benigh1472.apk',\n", 965 | " 'Adware1113.apk', 'Riskware2818.apk', 'Benigh3354.apk',\n", 966 | " 'Riskware3544.apk', 'Benigh2723.apk', 'Banking2102.apk',\n", 967 | " 'Riskware3272.apk', 'Banking1708.apk', 'Adware0329.apk',\n", 968 | " 'Riskware0964.apk', 'Riskware4263.apk', 'Banking2360.apk',\n", 969 | " 'Benigh0742.apk', 'Adware0111.apk', 'Banking1862.apk',\n", 970 | " 'Riskware2564.apk', 'Riskware1360.apk', 'Banking1074.apk',\n", 971 | " 'Adware0087.apk', 'Riskware0505.apk', 'SMS4480.apk',\n", 972 | " 'Adware0758.apk', 'Banking0322.apk', 'Riskware2606.apk',\n", 973 | " 'Banking2295.apk', 'Benigh3300.apk', 'Benigh2638.apk',\n", 974 | " 'Adware1297.apk', 'Benigh3951.apk', 'SMS2015.apk',\n", 975 | " 'Banking0682.apk', 'Riskware4104.apk', 'Banking1144.apk',\n", 976 | " 'Riskware3211.apk', 'Banking0210.apk', 'Banking2453.apk',\n", 977 | " 'Benigh3099.apk', 'Benigh1636.apk', 'Benigh1040.apk',\n", 978 | " 'SMS1414.apk', 'Banking1143.apk', 'Adware0146.apk',\n", 979 | " 'Banking1697.apk', 'Adware0099.apk', 'Adware1174.apk',\n", 980 | " 'Banking1860.apk', 'Banking0459.apk', 'Riskware4161.apk',\n", 981 | " 
'Riskware0405.apk', 'Riskware1578.apk', 'Riskware2554.apk',\n", 982 | " 'Banking0222.apk', 'Benigh0823.apk', 'SMS3522.apk',\n", 983 | " 'Riskware3061.apk', 'Riskware0979.apk', 'Riskware2884.apk',\n", 984 | " 'Benigh1675.apk', 'SMS4760.apk', 'Riskware3814.apk',\n", 985 | " 'Riskware4173.apk', 'Riskware2855.apk', 'Riskware4120.apk',\n", 986 | " 'Banking0946.apk', 'Riskware1887.apk', 'Riskware1453.apk',\n", 987 | " 'Benigh0225.apk', 'Benigh0968.apk', 'SMS0143.apk',\n", 988 | " 'Riskware0628.apk', 'SMS0530.apk', 'Banking1407.apk',\n", 989 | " 'Riskware3744.apk', 'Riskware2538.apk', 'Adware0451.apk',\n", 990 | " 'Banking1162.apk', 'Benigh3747.apk', 'Benigh2595.apk',\n", 991 | " 'Banking1239.apk', 'Banking1887.apk', 'Banking0722.apk',\n", 992 | " 'Adware0445.apk', 'Benigh2217.apk', 'Banking1466.apk',\n", 993 | " 'Banking1844.apk', 'Riskware2210.apk', 'Riskware0597.apk',\n", 994 | " 'Riskware2742.apk', 'Riskware0990.apk', 'Riskware3062.apk',\n", 995 | " 'Riskware3016.apk', 'Banking1977.apk', 'Riskware1713.apk',\n", 996 | " 'Banking1397.apk', 'Banking1896.apk', 'Riskware1646.apk',\n", 997 | " 'Riskware1019.apk', 'Riskware0453.apk', 'Banking1322.apk',\n", 998 | " 'Adware0106.apk', 'Adware1510.apk', 'Banking0742.apk',\n", 999 | " 'Adware0086.apk', 'Benigh3999.apk', 'Riskware1353.apk',\n", 1000 | " 'Riskware0151.apk', 'Banking1995.apk', 'Benigh3291.apk',\n", 1001 | " 'Benigh3780.apk', 'Banking2036.apk', 'Adware0912.apk',\n", 1002 | " 'Benigh1539.apk', 'Riskware2241.apk', 'Banking2056.apk',\n", 1003 | " 'Benigh3888.apk', 'Riskware1761.apk', 'SMS2112.apk',\n", 1004 | " 'Banking1499.apk', 'Banking0940.apk', 'Banking1118.apk',\n", 1005 | " 'Riskware3994.apk', 'Riskware3917.apk', 'Riskware4197.apk',\n", 1006 | " 'Riskware3835.apk', 'Banking2390.apk', 'Banking1024.apk',\n", 1007 | " 'Riskware0613.apk', 'Adware0951.apk', 'Riskware0115.apk',\n", 1008 | " 'Banking1690.apk', 'SMS3593.apk', 'SMS2882.apk', 'Banking2457.apk',\n", 1009 | " 'Benigh0510.apk', 'Adware1361.apk', 'Riskware3432.apk',\n", 1010 | " 'Riskware3215.apk', 'Riskware3197.apk', 'Benigh0230.apk',\n", 1011 | " 'Riskware2911.apk', 'Banking0920.apk', 'SMS4304.apk',\n", 1012 | " 'Banking2394.apk', 'Benigh0646.apk', 'Banking2398.apk',\n", 1013 | " 'Banking1076.apk', 'Riskware0611.apk', 'Benigh3941.apk',\n", 1014 | " 'Banking1163.apk', 'Banking2348.apk', 'Riskware0204.apk',\n", 1015 | " 'Riskware0091.apk', 'Banking1518.apk', 'Riskware3844.apk',\n", 1016 | " 'Banking2256.apk', 'Banking0125.apk', 'Riskware2369.apk',\n", 1017 | " 'Riskware3876.apk', 'Riskware1473.apk', 'Banking0680.apk',\n", 1018 | " 'Banking1414.apk', 'SMS2469.apk', 'Riskware0935.apk',\n", 1019 | " 'Riskware1698.apk', 'Riskware0617.apk', 'Riskware2778.apk',\n", 1020 | " 'Riskware3130.apk', 'Riskware2655.apk', 'Adware0884.apk',\n", 1021 | " 'Banking2324.apk', 'Banking1277.apk', 'Benigh1729.apk',\n", 1022 | " 'Banking1882.apk', 'Riskware2733.apk', 'Adware1255.apk',\n", 1023 | " 'Adware0974.apk', 'Adware1073.apk', 'SMS1516.apk',\n", 1024 | " 'Banking1935.apk', 'Riskware1419.apk', 'Riskware0352.apk',\n", 1025 | " 'Banking2261.apk', 'Riskware0176.apk', 'Benigh2412.apk',\n", 1026 | " 'Adware0712.apk', 'Banking1516.apk', 'Banking1132.apk',\n", 1027 | " 'Riskware3141.apk', 'Riskware4091.apk', 'Banking2464.apk',\n", 1028 | " 'Riskware2850.apk', 'Benigh1448.apk', 'Adware0209.apk',\n", 1029 | " 'Adware0168.apk', 'Banking1439.apk', 'Banking1248.apk',\n", 1030 | " 'SMS3394.apk', 'Adware0383.apk', 'Riskware2744.apk',\n", 1031 | " 'Riskware0958.apk', 'Benigh2119.apk', 'Benigh2978.apk',\n", 1032 | " 
'SMS3920.apk', 'Riskware0975.apk', 'Benigh1047.apk',\n", 1033 | " 'Adware0069.apk', 'Riskware1216.apk', 'Riskware0931.apk'],\n", 1034 | " dtype=')" 651 | ] 652 | }, 653 | "execution_count": 20, 654 | "metadata": {}, 655 | "output_type": "execute_result" 656 | } 657 | ], 658 | "source": [ 659 | "predicted = classifier(dgl.batch([g for g,l in test_dataset]))\n", 660 | "predicted" 661 | ] 662 | }, 663 | { 664 | "cell_type": "code", 665 | "execution_count": 21, 666 | "metadata": {}, 667 | "outputs": [], 668 | "source": [ 669 | "predicted_mod = predicted.detach()" 670 | ] 671 | }, 672 | { 673 | "cell_type": "code", 674 | "execution_count": 23, 675 | "metadata": {}, 676 | "outputs": [], 677 | "source": [ 678 | "predicted_mod[predicted_mod>0.5] = 1\n", 679 | "predicted_mod[predicted_mod<0.5] = 0" 680 | ] 681 | }, 682 | { 683 | "cell_type": "code", 684 | "execution_count": 29, 685 | "metadata": {}, 686 | "outputs": [ 687 | { 688 | "data": { 689 | "text/plain": [ 690 | "tensor([0, 1, 0, ..., 1, 0, 1])" 691 | ] 692 | }, 693 | "execution_count": 29, 694 | "metadata": {}, 695 | "output_type": "execute_result" 696 | } 697 | ], 698 | "source": [ 699 | "predicted_mod.long()" 700 | ] 701 | }, 702 | { 703 | "cell_type": "code", 704 | "execution_count": 25, 705 | "metadata": {}, 706 | "outputs": [ 707 | { 708 | "data": { 709 | "text/plain": [ 710 | "tensor([0, 1, 0, ..., 0, 0, 1])" 711 | ] 712 | }, 713 | "execution_count": 25, 714 | "metadata": {}, 715 | "output_type": "execute_result" 716 | } 717 | ], 718 | "source": [ 719 | "actual = torch.tensor([l for g,l in test_dataset])\n", 720 | "actual[actual!=2]=0\n", 721 | "actual[actual==2]=1\n", 722 | "actual" 723 | ] 724 | }, 725 | { 726 | "cell_type": "code", 727 | "execution_count": 47, 728 | "metadata": {}, 729 | "outputs": [ 730 | { 731 | "data": { 732 | "text/plain": [ 733 | "(3302, 2494)" 734 | ] 735 | }, 736 | "execution_count": 47, 737 | "metadata": {}, 738 | "output_type": "execute_result" 739 | } 740 | ], 741 | "source": [ 742 | "len(actual), len(torch.where(actual==0)[0])" 743 | ] 744 | }, 745 | { 746 | "cell_type": "code", 747 | "execution_count": 48, 748 | "metadata": {}, 749 | "outputs": [ 750 | { 751 | "data": { 752 | "text/plain": [ 753 | "808" 754 | ] 755 | }, 756 | "execution_count": 48, 757 | "metadata": {}, 758 | "output_type": "execute_result" 759 | } 760 | ], 761 | "source": [ 762 | "_[0]-_[1]" 763 | ] 764 | }, 765 | { 766 | "cell_type": "code", 767 | "execution_count": 31, 768 | "metadata": {}, 769 | "outputs": [ 770 | { 771 | "name": "stdout", 772 | "output_type": "stream", 773 | "text": [ 774 | " precision recall f1-score support\n", 775 | "\n", 776 | " 0 0.9594 0.9379 0.9485 2494\n", 777 | " 1 0.8206 0.8775 0.8481 808\n", 778 | "\n", 779 | " accuracy 0.9231 3302\n", 780 | " macro avg 0.8900 0.9077 0.8983 3302\n", 781 | "weighted avg 0.9254 0.9231 0.9239 3302\n", 782 | "\n" 783 | ] 784 | } 785 | ], 786 | "source": [ 787 | "print(M.classification_report(actual, predicted_mod.long(), digits=4))" 788 | ] 789 | }, 790 | { 791 | "cell_type": "code", 792 | "execution_count": 49, 793 | "metadata": {}, 794 | "outputs": [ 795 | { 796 | "data": { 797 | "text/plain": [ 798 | "array([[2339, 155],\n", 799 | " [ 99, 709]])" 800 | ] 801 | }, 802 | "execution_count": 49, 803 | "metadata": {}, 804 | "output_type": "execute_result" 805 | } 806 | ], 807 | "source": [ 808 | "M.confusion_matrix(actual, predicted_mod.long())" 809 | ] 810 | }, 811 | { 812 | "cell_type": "markdown", 813 | "metadata": {}, 814 | "source": [ 815 | "## Results\n", 816 | "Accuracy 
- 92.31%,\n",
817 |     "Precision - 0.9254,\n",
818 |     "Recall - 0.9231,\n",
819 |     "F1 - 0.9239"
820 |    ]
821 |   },
822 |   {
823 |    "cell_type": "code",
824 |    "execution_count": null,
825 |    "metadata": {},
826 |    "outputs": [],
827 |    "source": []
828 |   }
829 |  ],
830 |  "metadata": {
831 |   "kernelspec": {
832 |    "display_name": "Python 3",
833 |    "language": "python",
834 |    "name": "python3"
835 |   },
836 |   "language_info": {
837 |    "codemirror_mode": {
838 |     "name": "ipython",
839 |     "version": 3
840 |    },
841 |    "file_extension": ".py",
842 |    "mimetype": "text/x-python",
843 |    "name": "python",
844 |    "nbconvert_exporter": "python",
845 |    "pygments_lexer": "ipython3",
846 |    "version": "3.6.9"
847 |   }
848 |  },
849 |  "nbformat": 4,
850 |  "nbformat_minor": 4
851 | }
852 | 
-------------------------------------------------------------------------------- /readme.md: --------------------------------------------------------------------------------
1 | # AndMal-Detect
2 | 
3 | Android Malware Detection using Function Call Graphs and Graph Convolutional Networks
4 | 
5 | # What?
6 | 
7 | A research work carried out by me ([Vinayaka K V](https://github.com/vinayakakv)) during my MTech (Research) degree at the Department of IT, NITK.
8 | 
9 | The objectives of the research were:
10 | 
11 | 1. To evaluate whether GCNs are effective in detecting Android malware using FCGs, and which GCN algorithm is best for this task.
12 | 2. To enhance the FCGs by incorporating the callback information obtained from the framework code, and to evaluate them against the normal FCGs.
13 | 
14 | # Code organization
15 | 
16 | The code achieving the first objective is present on the `master` (current) branch, while the code achieving the second objective is present on the `experiment` branch.
17 | 
18 | # Methodology
19 | ![Architecture](assets/architecture.svg)
20 | 
21 | ## Datasets
22 | 
23 | Stored in the [`/data`](/data) folder. Currently, it contains the SHA256 hashes of the APKs contained in the training and testing splits.
24 | 
25 | 
26 | ## APK Size Balancer
27 | 
28 | Obtains the histogram of APK sizes, and adds APKs wherever there is a large imbalance in the number of APKs between the classes.
29 | 
30 | > **Note:** *The provided dataset is already APK Size balanced* 🥳
31 | 
32 | ## FCG Extractor
33 | 
34 | Implemented in [`scripts/process_dataset.py`](scripts/process_dataset.py).
35 | 
36 | The class `FeatureExtractors` provides two public methods:
37 | 
38 | 1. `get_user_features()` - Returns a 21-bit opcode-group feature vector for *internal* methods
39 | 2. `get_api_features()` - Returns a one-hot feature vector for *external* methods
40 | 
41 | The method `process` extracts the FCG and assigns node features.
42 | 
43 | ## Node Count Balancer
44 | 
45 | Balances the dataset so that the node count distribution of the APKs is exactly the same across the classes.
46 | 
47 | Implemented in [`scripts/split_dataset.py`](scripts/split_dataset.py).
48 | 
49 | > **Note:** *The provided dataset is already node-count balanced to ensure **reproducibility*** 🤩
50 | 
51 | ## GCN Classifier
52 | 
53 | A multi-layer GCN with a dense layer at the end.
54 | 
55 | Implemented in [`core/model.py`](core/model.py)
56 | 
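Below is a minimal, illustrative sketch of such a model in DGL. The class and argument names (`GCNClassifier`, `in_dim`, `hidden_dim`, `n_classes`, `convolution_count`) are assumptions for this example, not the actual API of [`core/model.py`](core/model.py):

    import torch
    import torch.nn as nn
    import dgl
    from dgl.nn import GraphConv

    class GCNClassifier(nn.Module):
        """Sketch: stacked graph convolutions followed by a dense classification layer."""
        def __init__(self, in_dim: int, hidden_dim: int, n_classes: int, convolution_count: int = 2):
            super().__init__()
            dims = [in_dim] + [hidden_dim] * convolution_count
            # FCGs are directed, so some nodes can have zero in-degree
            self.convs = nn.ModuleList(
                GraphConv(src, dst, allow_zero_in_degree=True)
                for src, dst in zip(dims, dims[1:])
            )
            self.classify = nn.Linear(dims[-1], n_classes)

        def forward(self, g: dgl.DGLGraph, features: torch.Tensor) -> torch.Tensor:
            h = features
            for conv in self.convs:
                h = torch.relu(conv(g, h))
            g.ndata['h'] = h
            hg = dgl.mean_nodes(g, 'h')  # graph-level readout over node embeddings
            return self.classify(hg)

With `convolution_count=0`, the sketch degenerates to a plain dense classifier over the raw node features, matching the `model.convolution_count=0` sweep value shown in the Configuration section below.
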
57 | # The Execution Pipeline
58 | 
59 | 1. Obtain the APKs using the given SHA256 hashes from [AndroZoo](https://androzoo.uni.lu/)
60 | 2. Build the container (either Singularity or Docker), and get into its shell
61 | 3. Run [`scripts/process_dataset.py`](scripts/process_dataset.py) on the downloaded dataset
62 | 
63 |     python process_dataset.py \
64 |         --source-dir <source-dir> \
65 |         --dest-dir <dest-dir> \
66 |         --override  # If you want to override existing processed files
67 |         --dry  # If you want to perform a dry run
68 | 
69 | 4. Train the model! For configuration, refer to the section below.
70 | 
71 |     python train_model.py
72 | 
73 | # Configuration
74 | 
75 | The configuration is achieved using [Hydra](https://hydra.cc/). Look into [`config/conf.yaml`](config/conf.yaml) for the available configuration options.
76 | 
77 | Any configuration option can be overridden on the command line. As an example, to change the number of convolution layers to 2, invoke the program as
78 | 
79 |     python train_model.py model.convolution_count=2
80 | 
81 | You can also perform a sweep (note Hydra's `-m`/`--multirun` flag), for example,
82 | 
83 |     python train_model.py -m \
84 |         model.convolution_count=0,1,2,3 \
85 |         model.convolution_algorithm=GraphConv,SAGEConv,TAGConv,SGConv,DotGatConv \
86 |         features=degree,method_attributes,method_summary
87 | 
88 | to train the model in all possible configurations! 🥳
89 | 
90 | # Stack
91 | 
92 | - [`androguard`](https://androguard.readthedocs.io/en/latest) - for FCG extraction and feature assignment
93 | - [`pytorch`](https://pytorch.org/) - for neural networks
94 | - [`dgl`](https://www.dgl.ai/) - for GCN modules
95 | - [`pytorch-lightning`](https://github.com/PyTorchLightning/pytorch-lightning) - for organization and pipeline 💖
96 | - [`hydra`](https://hydra.cc/) - for configuring experiments
97 | - [`wandb`](https://wandb.ai/) - for tracking experiments 🔥
98 | 
99 | # Cite as
100 | 
101 | The research paper corresponding to this work is available at [IEEE Xplore](https://ieeexplore.ieee.org/document/9478141). If you find this work helpful and use it, please cite it as
102 | 
103 |     @INPROCEEDINGS{9478141,
104 |       author={V, Vinayaka K and D, Jaidhar C},
105 |       booktitle={2021 2nd International Conference on Secure Cyber Computing and Communications (ICSCCC)},
106 |       title={Android Malware Detection using Function Call Graph with Graph Convolutional Networks},
107 |       year={2021},
108 |       volume={},
109 |       number={},
110 |       pages={279-287},
111 |       doi={10.1109/ICSCCC51823.2021.9478141}
112 |     }
113 | 
-------------------------------------------------------------------------------- /requirements.txt: --------------------------------------------------------------------------------
1 | pytorch-lightning==1.1.7
2 | wandb==0.10.18
3 | plotly==4.14.3
4 | scikit-learn==0.24.1
5 | joblib~=1.0.0
6 | hydra-core~=1.0.5
7 | pandas==1.2.1
8 | pygtrie==2.4.2
9 | seaborn==0.11.1
10 | pygraphviz==1.7
11 | dgl-cu110==0.6
-------------------------------------------------------------------------------- /scripts/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/vinayakakv/android-malware-detection/1aab288ec599a3958982866ce989311a96cbffd9/scripts/__init__.py
-------------------------------------------------------------------------------- /scripts/plot_callgraph.py: --------------------------------------------------------------------------------
1 | import argparse
2 | from pathlib import Path
3 | 
4 | import matplotlib.pyplot as plt
5 | import networkx as nx
6 | from androguard.misc import AnalyzeAPK
7 | 
8 | plt.figure(figsize=(10, 5))
9 | 
10 | 
11 | def plot_call_graph(cg: nx.classes.multidigraph.MultiDiGraph):
12 |     layout = nx.drawing.nx_agraph.graphviz_layout(cg, prog='dot')
13 |     labels, cm = {}, []
14 |     legend = 
'' 15 | node_list = [] 16 | for i, node in enumerate(nx.topological_sort(cg)): 17 | node_list.append(node) 18 | labels[node] = i 19 | cm.append('yellow' if node.is_external() else 'blue') 20 | legend += '%d, \\texttt{%s %s}\n' % (i, node.class_name.replace('$', '\\$'), node.name) 21 | plt.axis('off') 22 | nx.draw_networkx(cg, pos=layout, nodelist=node_list, node_color=cm, labels=labels, alpha=0.6, node_size=500, 23 | font_family='serif') 24 | with open("cg.table", "w") as f: 25 | f.write(legend) 26 | plt.tight_layout() 27 | plt.savefig("cg.pdf", dpi=300, bbox_inches="tight") 28 | plt.show() 29 | 30 | 31 | if __name__ == '__main__': 32 | parser = argparse.ArgumentParser(description='Draw FCG of small APKs') 33 | parser.add_argument( 34 | '-s', '--source-file', 35 | help='The APK file to analyze and draw', 36 | required=True 37 | ) 38 | args = parser.parse_args() 39 | if not Path(args.source_file).exists(): 40 | raise FileNotFoundError(f"{args.source_file} doesn't exist") 41 | a, d, dx = AnalyzeAPK(args.source_file) 42 | plot_call_graph(dx.get_call_graph()) 43 | -------------------------------------------------------------------------------- /scripts/process_dataset.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import json 3 | import multiprocessing 4 | import os 5 | import sys 6 | import traceback 7 | from collections import defaultdict 8 | from pathlib import Path 9 | from typing import Dict, List, Union, Optional 10 | 11 | import dgl 12 | import joblib as J 13 | import networkx as nx 14 | import torch 15 | from androguard.core.analysis.analysis import MethodAnalysis 16 | from androguard.core.api_specific_resources import load_permission_mappings 17 | from androguard.misc import AnalyzeAPK 18 | from pygtrie import StringTrie 19 | 20 | ATTRIBUTES = ['external', 'entrypoint', 'native', 'public', 'static', 'codesize', 'api', 'user'] 21 | package_directory = os.path.dirname(os.path.abspath(__file__)) 22 | 23 | stats: Dict[str, int] = defaultdict(int) 24 | 25 | 26 | def memoize(function): 27 | """ 28 | Alternative to @lru_cache which could not be pickled in ray 29 | :param function: Function to be cached 30 | :return: Wrapped function 31 | """ 32 | memo = {} 33 | 34 | def wrapper(*args): 35 | if args in memo: 36 | return memo[args] 37 | else: 38 | rv = function(*args) 39 | memo[args] = rv 40 | return rv 41 | 42 | return wrapper 43 | 44 | 45 | class FeatureExtractors: 46 | NUM_PERMISSION_GROUPS = 20 47 | NUM_API_PACKAGES = 226 48 | NUM_OPCODE_MAPPINGS = 21 49 | 50 | @staticmethod 51 | def _get_opcode_mapping() -> Dict[str, int]: 52 | """ 53 | Group opcodes and assign them an ID 54 | :return: Mapping from opcode group name to their ID 55 | """ 56 | mapping = {x: i for i, x in enumerate(['nop', 'mov', 'return', 57 | 'const', 'monitor', 'check-cast', 'instanceof', 'new', 58 | 'fill', 'throw', 'goto/switch', 'cmp', 'if', 'unused', 59 | 'arrayop', 'instanceop', 'staticop', 'invoke', 60 | 'unaryop', 'binop', 'inline'])} 61 | mapping['invalid'] = -1 62 | return mapping 63 | 64 | @staticmethod 65 | @memoize 66 | def _get_instruction_type(op_value: int) -> str: 67 | """ 68 | Get instruction group name from instruction 69 | :param op_value: Opcode value 70 | :return: String containing ID of :instr: 71 | """ 72 | if 0x00 == op_value: 73 | return 'nop' 74 | elif 0x01 <= op_value <= 0x0D: 75 | return 'mov' 76 | elif 0x0E <= op_value <= 0x11: 77 | return 'return' 78 | elif 0x12 <= op_value <= 0x1C: 79 | return 'const' 80 | elif 0x1D <= op_value <= 
0x1E:
81 |             return 'monitor'
82 |         elif 0x1F == op_value:
83 |             return 'check-cast'
84 |         elif 0x20 == op_value:
85 |             return 'instanceof'
86 |         elif 0x22 <= op_value <= 0x23:
87 |             return 'new'
88 |         elif 0x24 <= op_value <= 0x26:
89 |             return 'fill'
90 |         elif 0x27 == op_value:
91 |             return 'throw'
92 |         elif 0x28 <= op_value <= 0x2C:
93 |             return 'goto/switch'
94 |         elif 0x2D <= op_value <= 0x31:
95 |             return 'cmp'
96 |         elif 0x32 <= op_value <= 0x3D:
97 |             return 'if'
98 |         elif (0x3E <= op_value <= 0x43) or (op_value == 0x73) or (0x79 <= op_value <= 0x7A) or (
99 |                 0xE3 <= op_value <= 0xED):
100 |             return 'unused'
101 |         elif (0x44 <= op_value <= 0x51) or (op_value == 0x21):
102 |             return 'arrayop'
103 |         elif (0x52 <= op_value <= 0x5F) or (0xF2 <= op_value <= 0xF7):
104 |             return 'instanceop'
105 |         elif 0x60 <= op_value <= 0x6D:
106 |             return 'staticop'
107 |         elif (0x6E <= op_value <= 0x72) or (0x74 <= op_value <= 0x78) or (0xF0 == op_value) or (
108 |                 0xF8 <= op_value <= 0xFB):
109 |             return 'invoke'
110 |         elif 0x7B <= op_value <= 0x8F:
111 |             return 'unaryop'
112 |         elif 0x90 <= op_value <= 0xE2:
113 |             return 'binop'
114 |         elif 0xEE == op_value:
115 |             return 'inline'
116 |         else:
117 |             return 'invalid'
118 | 
119 |     @staticmethod
120 |     def _mapping_to_bitstring(mapping: List[int], max_len: int) -> torch.Tensor:
121 |         """
122 |         Convert opcode group IDs to a bitstring
123 |         :param max_len: Length of the output bit vector
124 |         :param mapping: List of IDs of the opcode groups present in a method
125 |         :return: Binary tensor of length `max_len` with value 1 at the positions given by `mapping`
126 |         """
127 |         size = torch.Size([1, max_len])
128 |         if len(mapping) > 0:
129 |             indices = torch.LongTensor([[0, x] for x in mapping]).t()
130 |             values = torch.LongTensor([1] * len(mapping))
131 |             tensor = torch.sparse.LongTensor(indices, values, size)
132 |         else:
133 |             tensor = torch.sparse.LongTensor(size)
134 |         # Convert the sparse tensor to a dense (normal) tensor before returning
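        # Illustrative example (not in the source): _mapping_to_bitstring([0, 2], max_len=4) -> tensor([1, 0, 1, 0])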
135 | return tensor.to_dense().squeeze() 136 | 137 | @staticmethod 138 | def _get_api_trie() -> StringTrie: 139 | apis = open(Path(package_directory).parent / "metadata" / "api.list").readlines() 140 | api_list = {x.strip(): i for i, x in enumerate(apis)} 141 | api_trie = StringTrie(separator='.') 142 | for k, v in api_list.items(): 143 | api_trie[k] = v 144 | return api_trie 145 | 146 | @staticmethod 147 | @memoize 148 | def get_api_features(api: MethodAnalysis) -> Optional[torch.Tensor]: 149 | if not api.is_external(): 150 | return None 151 | api_trie = FeatureExtractors._get_api_trie() 152 | name = str(api.class_name)[1:-1].replace('/', '.') 153 | _, index = api_trie.longest_prefix(name) 154 | if index is None: 155 | indices = [] 156 | else: 157 | indices = [index] 158 | feature_vector = FeatureExtractors._mapping_to_bitstring(indices, FeatureExtractors.NUM_API_PACKAGES) 159 | return feature_vector 160 | 161 | @staticmethod 162 | @memoize 163 | def get_user_features(user: MethodAnalysis) -> Optional[torch.Tensor]: 164 | if user.is_external(): 165 | return None 166 | opcode_mapping = FeatureExtractors._get_opcode_mapping() 167 | opcode_groups = set() 168 | for instr in user.get_method().get_instructions(): 169 | instruction_type = FeatureExtractors._get_instruction_type(instr.get_op_value()) 170 | instruction_id = opcode_mapping[instruction_type] 171 | if instruction_id >= 0: 172 | opcode_groups.add(instruction_id) 173 | # 1 subtraction for 'invalid' opcode group 174 | feature_vector = FeatureExtractors._mapping_to_bitstring(list(opcode_groups), len(opcode_mapping) - 1) 175 | return torch.LongTensor(feature_vector) 176 | 177 | 178 | def process(source_file: Path, dest_dir: Path): 179 | try: 180 | file_name = source_file.stem 181 | _, _, dx = AnalyzeAPK(source_file) 182 | cg = dx.get_call_graph() 183 | mappings = {} 184 | for node in cg.nodes(): 185 | features = { 186 | "api": torch.zeros(FeatureExtractors.NUM_API_PACKAGES), 187 | "user": torch.zeros(FeatureExtractors.NUM_OPCODE_MAPPINGS) 188 | } 189 | if node.is_external(): 190 | features["api"] = FeatureExtractors.get_api_features(node) 191 | else: 192 | features["user"] = FeatureExtractors.get_user_features(node) 193 | mappings[node] = features 194 | nx.set_node_attributes(cg, mappings) 195 | cg = nx.convert_node_labels_to_integers(cg) 196 | dg = dgl.from_networkx(cg, node_attrs=ATTRIBUTES) 197 | dest_dir = dest_dir / f'{file_name}.fcg' 198 | dgl.data.utils.save_graphs(str(dest_dir), [dg]) 199 | print(f"Processed {source_file}") 200 | except: 201 | print(f"Error while processing {source_file}") 202 | traceback.print_exception(*sys.exc_info()) 203 | return 204 | 205 | 206 | if __name__ == '__main__': 207 | parser = argparse.ArgumentParser(description='Preprocess APK Dataset into Graphs') 208 | parser.add_argument( 209 | '-s', '--source-dir', 210 | help='The directory containing apks', 211 | required=True 212 | ) 213 | parser.add_argument( 214 | '-d', '--dest-dir', 215 | help='The directory to store processed graphs', 216 | required=True 217 | ) 218 | parser.add_argument( 219 | '--override', 220 | help='Override existing processed files', 221 | action='store_true' 222 | ) 223 | parser.add_argument( 224 | '--dry', 225 | help='Run without actual processing', 226 | action='store_true' 227 | ) 228 | parser.add_argument( 229 | '--n-jobs', 230 | default=multiprocessing.cpu_count(), 231 | help='Number of jobs to be used for processing' 232 | ) 233 | parser.add_argument( 234 | '--limit', 235 | help='Run for n apks', 236 | default=-1 237 | ) 238 
| args = parser.parse_args()
239 |     source_dir = Path(args.source_dir)
240 |     if not source_dir.exists():
241 |         raise FileNotFoundError(f'{source_dir} not found')
242 |     dest_dir = Path(args.dest_dir)
243 |     if not dest_dir.exists():
244 |         raise FileNotFoundError(f'{dest_dir} not found')
245 |     n_jobs = int(args.n_jobs)  # argparse yields a string for --n-jobs; convert before comparing
246 |     if n_jobs < 2:
247 |         print(f"n_jobs={n_jobs} is too low. Switching to the number of CPUs on this machine instead")
248 |         n_jobs = multiprocessing.cpu_count()
249 |     files = [x for x in source_dir.iterdir() if x.is_file()]
250 |     source_files = set([x.stem for x in files])
251 |     dest_files = set([x.stem for x in dest_dir.iterdir() if x.is_file()])  # use stems so 'Foo.fcg' matches the source stem 'Foo'
252 |     unprocessed = [source_dir / f'{x}.apk' for x in source_files - dest_files]
253 |     print(f"Only {len(unprocessed)} out of {len(source_files)} remain to be processed")
254 |     if args.override:
255 |         print(f"--override specified. Ignoring {len(source_files) - len(unprocessed)} processed files")
256 |         unprocessed = [source_dir / f'{x}.apk' for x in source_files]
257 |     print(f"Starting dataset processing with {n_jobs} jobs")
258 |     limit = int(args.limit)
259 |     if limit != -1:
260 |         print(f"Limiting dataset processing to {limit} apks.")
261 |         unprocessed = unprocessed[:limit]
262 |     if not args.dry:
263 |         J.Parallel(n_jobs=n_jobs)(J.delayed(process)(x, dest_dir) for x in unprocessed)
264 |     print("DONE")
265 | 
-------------------------------------------------------------------------------- /scripts/split_dataset.py: --------------------------------------------------------------------------------
1 | import argparse
2 | import multiprocessing
3 | from pathlib import Path
4 | 
5 | import dgl
6 | import joblib as J
7 | import numpy as np
8 | import pandas as pd
9 | 
10 | 
11 | def extract_stats(file: str):
12 |     file = Path(file)
13 |     if not file.exists():
14 |         raise ValueError(f"{file} doesn't exist")
15 |     result = {}
16 |     graphs, labels = dgl.data.utils.load_graphs(str(file))
17 |     graph: dgl.DGLGraph = graphs[0]
18 |     result['label'] = 'Benign' if 'Benig' in file.stem else 'Malware'
19 |     result['file_name'] = str(file)
20 |     result['num_nodes'] = graph.num_nodes()
21 |     result['num_edges'] = graph.num_edges()
22 |     return result
23 | 
24 | 
25 | def save_list(dataframe, file_name):
26 |     with open(file_name, 'a') as target:
27 |         for file in dataframe['file_name']:
28 |             target.writelines(f'{file.split(".")[0]}\n')
29 | 
30 | 
31 | def get_dataset(df: pd.DataFrame, test_ratio: float, log_dir: Path):
32 |     assert 0 <= test_ratio < 1, "Ratio must be in [0, 1)"
33 |     q1 = df['num_nodes'].quantile(0.25)
34 |     q3 = df['num_nodes'].quantile(0.75)
35 |     iqr = q3 - q1
36 |     print(f"Initial range {df['num_nodes'].min(), df['num_nodes'].max()}")
37 |     print(f"IQR num_nodes = {iqr}")
38 |     df = df.query(f'{q1 - iqr} <= num_nodes <= {q3 + iqr}')
39 |     print(f"Final range {df['num_nodes'].min(), df['num_nodes'].max()}")
40 |     bins = np.arange(0, df['num_nodes'].max(), 500)
41 |     ben_hist, _ = np.histogram(df.query('label == "Benign"')['num_nodes'], bins=bins)
42 |     mal_hist, _ = np.histogram(df.query('label != "Benign"')['num_nodes'], bins=bins)
43 |     combined = np.concatenate([ben_hist[:, np.newaxis], mal_hist[:, np.newaxis]], axis=1)
44 |     np.savetxt(
45 |         log_dir / 'histogram.list',
46 |         combined
47 |     )
48 |     final_sizes = [(x, x) for x in np.min(combined, axis=1)]
49 |     final_train = []
50 |     final_test = []
51 |     for i, (ben_size, mal_size) in enumerate(final_sizes):
52 |         low, high = bins[i], bins[i + 1]
53 |         benign_samples = df.query(f'label == "Benign" and {low} <= num_nodes < 
{high}') 54 | malware_samples = df.query(f'label == "Malware" and {low} <= num_nodes < {high}') 55 | assert len(benign_samples) >= ben_size and len(malware_samples) >= mal_size, "Mismatch" 56 | benign_samples = benign_samples.sample(ben_size) 57 | malware_samples = malware_samples.sample(mal_size) 58 | if test_ratio > 0: 59 | benign_samples, benign_test_samples = np.split(benign_samples, 60 | [round((1 - test_ratio) * len(benign_samples))]) 61 | malware_samples, malware_test_samples = np.split(malware_samples, 62 | [round((1 - test_ratio) * len(malware_samples))]) 63 | final_test.append(benign_test_samples) 64 | final_test.append(malware_test_samples) 65 | final_train.append(benign_samples) 66 | final_train.append(malware_samples) 67 | final_train = pd.concat(final_train) 68 | if final_test: 69 | final_test = pd.concat(final_test) 70 | return final_train, final_test 71 | 72 | 73 | if __name__ == '__main__': 74 | parser = argparse.ArgumentParser( 75 | description='Split the input dataset into train and test partitions (80%, 20%) based on bin equalization' 76 | ) 77 | parser.add_argument( 78 | '-i', '--input-dirs', 79 | help="List of input paths", 80 | nargs='+', 81 | required=True 82 | ) 83 | parser.add_argument( 84 | '-o', '--output-dir', 85 | help="The path to write the result lists to", 86 | required=True 87 | ) 88 | parser.add_argument( 89 | '-s', '--strict', 90 | help="If set, program will terminate on error while in loop", 91 | action='store_true', 92 | default=False 93 | ) 94 | args = parser.parse_args() 95 | output_dir = Path(args.output_dir) 96 | if not output_dir.exists(): 97 | output_dir.mkdir(parents=True) 98 | 99 | input_stats = [] 100 | for input_dir in args.input_dirs: 101 | input_dir = Path(input_dir) 102 | if not input_dir.exists(): 103 | if args.strict: 104 | raise FileNotFoundError(f"{input_dir} does not exist. Halting") 105 | else: 106 | print(f"{input_dir} does not exist. Skipping...") 107 | continue 108 | stats = J.Parallel(n_jobs=multiprocessing.cpu_count())( 109 | J.delayed(extract_stats)(x) for x in input_dir.glob("*.fcg") 110 | ) 111 | input_stats.append(pd.DataFrame.from_records(stats)) 112 | input_stats = pd.concat(input_stats) 113 | zero_nodes = input_stats.query('num_nodes == 0') 114 | if len(zero_nodes) > 0: 115 | print(f"Warning: {len(zero_nodes)} APKs with num_nodes = 0 found. 
Writing their names to zero_nodes.list")
116 |         save_list(zero_nodes, 'zero_nodes.list')
117 |     input_stats = input_stats.query('num_nodes != 0')
118 |     train_list, test_list = get_dataset(input_stats, 0.2, output_dir)
119 |     save_list(train_list, 'train.list')
120 |     save_list(test_list, 'test.list')
121 | 
-------------------------------------------------------------------------------- /train_model.py: --------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | 
4 | import hydra
5 | import wandb
6 | import torch
7 | from omegaconf import DictConfig
8 | from pytorch_lightning import Trainer
9 | from pytorch_lightning.callbacks import ModelCheckpoint
10 | from pytorch_lightning.loggers import WandbLogger
11 | 
12 | from core.callbacks import InputMonitor, BestModelTagger, MetricsLogger
13 | from core.data_module import MalwareDataModule
14 | from core.model import MalwareDetector
15 | 
16 | 
17 | @hydra.main(config_path="config", config_name="conf")
18 | def train_model(cfg: DictConfig) -> None:
19 |     data_module = MalwareDataModule(**cfg['data'])
20 | 
21 |     model = MalwareDetector(**cfg['model'])
22 | 
23 |     callbacks = [ModelCheckpoint(
24 |         dirpath=os.getcwd(),
25 |         filename=str('{epoch:02d}-{val_loss:.2f}.pt'),
26 |         monitor='val_loss',
27 |         mode='min',
28 |         save_last=True,
29 |         save_top_k=-1
30 |     )]
31 | 
32 |     trainer_kwargs = dict(cfg['trainer'])
33 |     force_retrain = cfg.get('force_retrain', False)
34 |     if Path('last.ckpt').exists() and not force_retrain:
35 |         trainer_kwargs['resume_from_checkpoint'] = 'last.ckpt'
36 | 
37 |     if 'logger' in cfg:
38 |         # We use the WandB logger
39 |         logger = WandbLogger(
40 |             **cfg['logger']['args'],
41 |             tags=['testing' if "testing" in cfg else "training"]
42 |         )
43 |         if "testing" in cfg:
44 |             logger.experiment.summary["test_type"] = cfg["testing"]
45 |         logger.watch(model)
46 |         logger.log_hyperparams(cfg['logger']['hparams'])
47 |         if logger:
48 |             trainer_kwargs['logger'] = logger
49 |             callbacks.append(InputMonitor())
50 |             callbacks.append(BestModelTagger(monitor='val_loss', mode='min'))
51 |             callbacks.append(MetricsLogger(stages='all'))
52 | 
53 |     trainer = Trainer(
54 |         callbacks=callbacks,
55 |         **trainer_kwargs
56 |     )
57 |     testing = cfg.get('testing', '')
58 |     if not testing:
59 |         trainer.fit(model, datamodule=data_module)
60 |     else:
61 |         if testing not in ['last', 'best'] and 'epoch' not in testing:
62 |             raise ValueError(f"testing must be one of 'best', 'last', or 'epoch@N'. It is {testing}")
63 |         elif 'epoch' in testing:
64 |             # 'epoch@N': test the checkpoint saved at epoch N
65 |             epoch = testing.split('@')[1]
66 |             checkpoints = list(Path(os.getcwd()).glob(f"epoch={epoch}*.ckpt"))
67 |             if len(checkpoints) == 0:
68 |                 raise FileNotFoundError(f"Checkpoint at epoch = {epoch} not found.")
69 |             assert len(checkpoints) == 1, f"Multiple checkpoints corresponding to epoch = {epoch} found."
70 |             ckpt_path = checkpoints[0]
71 |         else:
72 |             if not Path('last.ckpt').exists():
73 |                 raise FileNotFoundError("No last.ckpt exists. Could not do any testing.")
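            # Checkpoint selection below: 'last' tests last.ckpt; 'best' reads the best_model_path recorded by ModelCheckpoint inside last.ckpt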
74 |             if testing == 'last':
75 |                 ckpt_path = 'last.ckpt'
76 |             else:
77 |                 # best
78 |                 last_checkpoint = torch.load('last.ckpt')
79 |                 ckpt_path = last_checkpoint['callbacks'][ModelCheckpoint]['best_model_path']
80 |         print(f"Using checkpoint {ckpt_path} for testing.")
81 |         model = MalwareDetector.load_from_checkpoint(ckpt_path, **cfg['model'])
82 |         trainer.test(model, datamodule=data_module, verbose=True)
83 |     wandb.finish()
84 | 
85 | 
86 | if __name__ == '__main__':
87 |     train_model()
88 | 
--------------------------------------------------------------------------------