├── .gitignore ├── ArtificialNeuralNetwork.ipynb ├── Churn_Modelling.csv ├── Evaluate_Improving_Tuning.ipynb ├── README.md ├── ann.py ├── evaluating_improving_tuning.py └── requirements.txt /.gitignore: -------------------------------------------------------------------------------- 1 | .ipynb_checkpoints/ 2 | -------------------------------------------------------------------------------- /ArtificialNeuralNetwork.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# REDES NEURAIS ARTIFICIAIS\n", 8 | "\n", 9 | "[Aula 3 de Deep Learning](http://bit.ly/dn-unb03) da Engenharia de Software da UnB" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "# Parte 1 - Pré-processamento dos Dados" 17 | ] 18 | }, 19 | { 20 | "cell_type": "markdown", 21 | "metadata": {}, 22 | "source": [ 23 | "### Importar as libs" 24 | ] 25 | }, 26 | { 27 | "cell_type": "code", 28 | "execution_count": 1, 29 | "metadata": { 30 | "colab": {}, 31 | "colab_type": "code", 32 | "id": "MxkJoQBkUIHC" 33 | }, 34 | "outputs": [], 35 | "source": [ 36 | "import numpy as np\n", 37 | "import pandas as pd\n", 38 | "import tensorflow as tf" 39 | ] 40 | }, 41 | { 42 | "cell_type": "code", 43 | "execution_count": 2, 44 | "metadata": { 45 | "colab": { 46 | "base_uri": "https://localhost:8080/", 47 | "height": 34 48 | }, 49 | "colab_type": "code", 50 | "id": "ZaTwK7ojXr2F", 51 | "outputId": "0b27a96d-d11a-43e8-ab4b-87c1f01896fe" 52 | }, 53 | "outputs": [ 54 | { 55 | "data": { 56 | "text/plain": [ 57 | "'2.4.1'" 58 | ] 59 | }, 60 | "execution_count": 2, 61 | "metadata": {}, 62 | "output_type": "execute_result" 63 | } 64 | ], 65 | "source": [ 66 | "tf.__version__" 67 | ] 68 | }, 69 | { 70 | "cell_type": "markdown", 71 | "metadata": {}, 72 | "source": [ 73 | "### Importar o dataset" 74 | ] 75 | }, 76 | { 77 | "cell_type": "code", 78 | 
"execution_count": 4, 79 | "metadata": {}, 80 | "outputs": [ 81 | { 82 | "data": { 83 | "text/html": [ 84 | "
\n", 85 | "\n", 98 | "\n", 99 | " \n", 100 | " \n", 101 | " \n", 102 | " \n", 103 | " \n", 104 | " \n", 105 | " \n", 106 | " \n", 107 | " \n", 108 | " \n", 109 | " \n", 110 | " \n", 111 | " \n", 112 | " \n", 113 | " \n", 114 | " \n", 115 | " \n", 116 | " \n", 117 | " \n", 118 | " \n", 119 | " \n", 120 | " \n", 121 | " \n", 122 | " \n", 123 | " \n", 124 | " \n", 125 | " \n", 126 | " \n", 127 | " \n", 128 | " \n", 129 | " \n", 130 | " \n", 131 | " \n", 132 | " \n", 133 | " \n", 134 | " \n", 135 | " \n", 136 | " \n", 137 | " \n", 138 | " \n", 139 | " \n", 140 | " \n", 141 | " \n", 142 | " \n", 143 | " \n", 144 | " \n", 145 | " \n", 146 | " \n", 147 | " \n", 148 | " \n", 149 | " \n", 150 | " \n", 151 | " \n", 152 | " \n", 153 | " \n", 154 | " \n", 155 | " \n", 156 | " \n", 157 | " \n", 158 | " \n", 159 | " \n", 160 | " \n", 161 | " \n", 162 | " \n", 163 | " \n", 164 | " \n", 165 | " \n", 166 | " \n", 167 | " \n", 168 | " \n", 169 | " \n", 170 | " \n", 171 | " \n", 172 | " \n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " 
\n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | " \n", 262 | " \n", 263 | " \n", 264 | " \n", 265 | " \n", 266 | " \n", 267 | " \n", 268 | " \n", 269 | " \n", 270 | " \n", 271 | " \n", 272 | " \n", 273 | " \n", 274 | " \n", 275 | " \n", 276 | " \n", 277 | " \n", 278 | " \n", 279 | " \n", 280 | " \n", 281 | " \n", 282 | " \n", 283 | " \n", 284 | " \n", 285 | " \n", 286 | " \n", 287 | " \n", 288 | " \n", 289 | " \n", 290 | " \n", 291 | " \n", 292 | " \n", 293 | " \n", 294 | " \n", 295 | " \n", 296 | " \n", 297 | " \n", 298 | " \n", 299 | " \n", 300 | " \n", 301 | " \n", 302 | " \n", 303 | " \n", 304 | " \n", 305 | " \n", 306 | " \n", 307 | " \n", 308 | " \n", 309 | " \n", 310 | " \n", 311 | " \n", 312 | " \n", 313 | " \n", 314 | " \n", 315 | " \n", 316 | " \n", 317 | " \n", 318 | " \n", 319 | " \n", 320 | " \n", 321 | " \n", 322 | " \n", 323 | " \n", 324 | " \n", 325 | " \n", 326 | " \n", 327 | " \n", 328 | " \n", 329 | " \n", 330 | " \n", 331 | " \n", 332 | " \n", 333 | " \n", 334 | " \n", 335 | " \n", 336 | " \n", 337 | " \n", 338 | " \n", 339 | " \n", 340 | " \n", 341 | " \n", 342 | " \n", 343 | " \n", 344 | " \n", 345 | " \n", 346 | " \n", 347 | " \n", 348 | " \n", 349 | " \n", 350 | " \n", 351 | " \n", 352 | " \n", 353 | " \n", 354 | " \n", 355 | " \n", 356 | " \n", 357 | " \n", 358 | " \n", 359 | " \n", 360 | " \n", 361 | " \n", 362 | " \n", 363 | " \n", 364 | " \n", 365 | " \n", 366 | " \n", 367 | " \n", 368 | " \n", 369 | " \n", 370 | " \n", 371 | " \n", 372 | " \n", 373 | " \n", 374 | " \n", 375 | " \n", 376 | " \n", 377 | " \n", 378 | " \n", 379 | " \n", 380 | " \n", 381 | " \n", 382 | " \n", 383 | " \n", 384 | " \n", 385 | " \n", 386 | " \n", 387 | " \n", 388 | " \n", 389 | " \n", 390 | " \n", 391 | " \n", 392 | " \n", 393 | " \n", 394 | " \n", 395 | " \n", 396 | " \n", 397 | " \n", 398 | " \n", 399 | " \n", 400 | " \n", 401 | " \n", 402 | " \n", 403 | " \n", 404 | 
" \n", 405 | " \n", 406 | " \n", 407 | " \n", 408 | " \n", 409 | " \n", 410 | " \n", 411 | " \n", 412 | " \n", 413 | " \n", 414 | " \n", 415 | " \n", 416 | " \n", 417 | " \n", 418 | " \n", 419 | " \n", 420 | " \n", 421 | " \n", 422 | " \n", 423 | " \n", 424 | " \n", 425 | " \n", 426 | " \n", 427 | " \n", 428 | " \n", 429 | " \n", 430 | " \n", 431 | " \n", 432 | " \n", 433 | " \n", 434 | " \n", 435 | " \n", 436 | " \n", 437 | " \n", 438 | " \n", 439 | " \n", 440 | " \n", 441 | " \n", 442 | " \n", 443 | " \n", 444 | " \n", 445 | " \n", 446 | " \n", 447 | " \n", 448 | " \n", 449 | " \n", 450 | " \n", 451 | " \n", 452 | " \n", 453 | " \n", 454 | " \n", 455 | " \n", 456 | " \n", 457 | " \n", 458 | " \n", 459 | " \n", 460 | " \n", 461 | " \n", 462 | " \n", 463 | " \n", 464 | " \n", 465 | " \n", 466 | " \n", 467 | " \n", 468 | " \n", 469 | " \n", 470 | " \n", 471 | " \n", 472 | " \n", 473 | " \n", 474 | " \n", 475 | " \n", 476 | " \n", 477 | " \n", 478 | " \n", 479 | " \n", 480 | " \n", 481 | " \n", 482 | " \n", 483 | " \n", 484 | " \n", 485 | " \n", 486 | " \n", 487 | " \n", 488 | " \n", 489 | " \n", 490 | " \n", 491 | " \n", 492 | " \n", 493 | " \n", 494 | " \n", 495 | " \n", 496 | " \n", 497 | " \n", 498 | " \n", 499 | " \n", 500 | " \n", 501 | " \n", 502 | " \n", 503 | " \n", 504 | " \n", 505 | " \n", 506 | " \n", 507 | " \n", 508 | " \n", 509 | " \n", 510 | " \n", 511 | " \n", 512 | " \n", 513 | " \n", 514 | " \n", 515 | " \n", 516 | " \n", 517 | " \n", 518 | " \n", 519 | " \n", 520 | " \n", 521 | " \n", 522 | " \n", 523 | " \n", 524 | " \n", 525 | " \n", 526 | " \n", 527 | " \n", 528 | " \n", 529 | " \n", 530 | " \n", 531 | " \n", 532 | " \n", 533 | " \n", 534 | " \n", 535 | " \n", 536 | " \n", 537 | " \n", 538 | " \n", 539 | " \n", 540 | " \n", 541 | " \n", 542 | " \n", 543 | " \n", 544 | " \n", 545 | " \n", 546 | " \n", 547 | " \n", 548 | " \n", 549 | " \n", 550 | " \n", 551 | " \n", 552 | " \n", 553 | " \n", 554 | " \n", 555 | " \n", 556 | " \n", 557 | " \n", 558 
| " \n", 559 | " \n", 560 | " \n", 561 | " \n", 562 | " \n", 563 | " \n", 564 | " \n", 565 | " \n", 566 | " \n", 567 | " \n", 568 | " \n", 569 | " \n", 570 | " \n", 571 | " \n", 572 | " \n", 573 | " \n", 574 | " \n", 575 | " \n", 576 | " \n", 577 | " \n", 578 | " \n", 579 | " \n", 580 | " \n", 581 | " \n", 582 | " \n", 583 | " \n", 584 | " \n", 585 | " \n", 586 | " \n", 587 | " \n", 588 | " \n", 589 | " \n", 590 | " \n", 591 | " \n", 592 | " \n", 593 | " \n", 594 | " \n", 595 | " \n", 596 | " \n", 597 | " \n", 598 | " \n", 599 | " \n", 600 | " \n", 601 | " \n", 602 | " \n", 603 | " \n", 604 | " \n", 605 | " \n", 606 | " \n", 607 | " \n", 608 | " \n", 609 | " \n", 610 | " \n", 611 | " \n", 612 | " \n", 613 | " \n", 614 | " \n", 615 | " \n", 616 | " \n", 617 | " \n", 618 | " \n", 619 | " \n", 620 | " \n", 621 | " \n", 622 | " \n", 623 | " \n", 624 | " \n", 625 | " \n", 626 | " \n", 627 | " \n", 628 | " \n", 629 | " \n", 630 | " \n", 631 | " \n", 632 | " \n", 633 | " \n", 634 | " \n", 635 | " \n", 636 | " \n", 637 | " \n", 638 | " \n", 639 | " \n", 640 | " \n", 641 | " \n", 642 | " \n", 643 | " \n", 644 | " \n", 645 | " \n", 646 | " \n", 647 | " \n", 648 | " \n", 649 | " \n", 650 | " \n", 651 | " \n", 652 | " \n", 653 | " \n", 654 | " \n", 655 | " \n", 656 | " \n", 657 | " \n", 658 | " \n", 659 | " \n", 660 | " \n", 661 | " \n", 662 | " \n", 663 | " \n", 664 | " \n", 665 | " \n", 666 | " \n", 667 | " \n", 668 | " \n", 669 | " \n", 670 | " \n", 671 | " \n", 672 | " \n", 673 | " \n", 674 | " \n", 675 | " \n", 676 | " \n", 677 | " \n", 678 | " \n", 679 | " \n", 680 | " \n", 681 | " \n", 682 | " \n", 683 | " \n", 684 | " \n", 685 | " \n", 686 | " \n", 687 | " \n", 688 | " \n", 689 | " \n", 690 | " \n", 691 | " \n", 692 | " \n", 693 | " \n", 694 | " \n", 695 | " \n", 696 | " \n", 697 | " \n", 698 | " \n", 699 | " \n", 700 | " \n", 701 | " \n", 702 | " \n", 703 | " \n", 704 | " \n", 705 | " \n", 706 | " \n", 707 | " \n", 708 | " \n", 709 | " \n", 710 | " \n", 711 | " \n", 
712 | " \n", 713 | " \n", 714 | " \n", 715 | " \n", 716 | " \n", 717 | " \n", 718 | " \n", 719 | " \n", 720 | " \n", 721 | " \n", 722 | " \n", 723 | " \n", 724 | " \n", 725 | " \n", 726 | " \n", 727 | " \n", 728 | " \n", 729 | " \n", 730 | " \n", 731 | " \n", 732 | " \n", 733 | " \n", 734 | " \n", 735 | " \n", 736 | " \n", 737 | " \n", 738 | " \n", 739 | " \n", 740 | " \n", 741 | " \n", 742 | " \n", 743 | " \n", 744 | " \n", 745 | " \n", 746 | " \n", 747 | " \n", 748 | " \n", 749 | " \n", 750 | " \n", 751 | " \n", 752 | " \n", 753 | " \n", 754 | " \n", 755 | " \n", 756 | " \n", 757 | " \n", 758 | " \n", 759 | " \n", 760 | " \n", 761 | " \n", 762 | " \n", 763 | " \n", 764 | " \n", 765 | " \n", 766 | " \n", 767 | " \n", 768 | " \n", 769 | " \n", 770 | " \n", 771 | " \n", 772 | " \n", 773 | " \n", 774 | " \n", 775 | " \n", 776 | " \n", 777 | " \n", 778 | " \n", 779 | " \n", 780 | " \n", 781 | " \n", 782 | " \n", 783 | " \n", 784 | " \n", 785 | " \n", 786 | " \n", 787 | " \n", 788 | " \n", 789 | " \n", 790 | " \n", 791 | " \n", 792 | " \n", 793 | " \n", 794 | " \n", 795 | " \n", 796 | " \n", 797 | " \n", 798 | " \n", 799 | " \n", 800 | " \n", 801 | " \n", 802 | " \n", 803 | " \n", 804 | " \n", 805 | " \n", 806 | " \n", 807 | " \n", 808 | " \n", 809 | " \n", 810 | " \n", 811 | " \n", 812 | " \n", 813 | " \n", 814 | " \n", 815 | " \n", 816 | " \n", 817 | " \n", 818 | " \n", 819 | " \n", 820 | " \n", 821 | " \n", 822 | " \n", 823 | " \n", 824 | " \n", 825 | " \n", 826 | " \n", 827 | " \n", 828 | " \n", 829 | " \n", 830 | " \n", 831 | " \n", 832 | " \n", 833 | " \n", 834 | " \n", 835 | " \n", 836 | " \n", 837 | " \n", 838 | " \n", 839 | " \n", 840 | " \n", 841 | " \n", 842 | " \n", 843 | " \n", 844 | " \n", 845 | " \n", 846 | " \n", 847 | " \n", 848 | " \n", 849 | " \n", 850 | " \n", 851 | " \n", 852 | " \n", 853 | " \n", 854 | " \n", 855 | " \n", 856 | " \n", 857 | " \n", 858 | " \n", 859 | " \n", 860 | " \n", 861 | " \n", 862 | " \n", 863 | " \n", 864 | " \n", 865 | " 
\n", 866 | " \n", 867 | " \n", 868 | " \n", 869 | " \n", 870 | " \n", 871 | " \n", 872 | " \n", 873 | " \n", 874 | " \n", 875 | " \n", 876 | " \n", 877 | " \n", 878 | " \n", 879 | " \n", 880 | " \n", 881 | " \n", 882 | " \n", 883 | " \n", 884 | " \n", 885 | " \n", 886 | " \n", 887 | " \n", 888 | " \n", 889 | " \n", 890 | " \n", 891 | " \n", 892 | " \n", 893 | " \n", 894 | " \n", 895 | " \n", 896 | " \n", 897 | " \n", 898 | " \n", 899 | " \n", 900 | " \n", 901 | " \n", 902 | " \n", 903 | " \n", 904 | " \n", 905 | " \n", 906 | " \n", 907 | " \n", 908 | " \n", 909 | " \n", 910 | " \n", 911 | " \n", 912 | " \n", 913 | " \n", 914 | " \n", 915 | " \n", 916 | " \n", 917 | " \n", 918 | " \n", 919 | " \n", 920 | " \n", 921 | " \n", 922 | " \n", 923 | " \n", 924 | " \n", 925 | " \n", 926 | " \n", 927 | " \n", 928 | " \n", 929 | " \n", 930 | " \n", 931 | " \n", 932 | " \n", 933 | " \n", 934 | " \n", 935 | " \n", 936 | " \n", 937 | " \n", 938 | " \n", 939 | " \n", 940 | " \n", 941 | " \n", 942 | " \n", 943 | " \n", 944 | " \n", 945 | " \n", 946 | " \n", 947 | " \n", 948 | " \n", 949 | " \n", 950 | " \n", 951 | " \n", 952 | " \n", 953 | " \n", 954 | " \n", 955 | " \n", 956 | " \n", 957 | " \n", 958 | " \n", 959 | " \n", 960 | " \n", 961 | " \n", 962 | " \n", 963 | " \n", 964 | " \n", 965 | " \n", 966 | " \n", 967 | " \n", 968 | " \n", 969 | " \n", 970 | "
RowNumberCustomerIdSurnameCreditScoreGeographyGenderAgeTenureBalanceNumOfProductsHasCrCardIsActiveMemberEstimatedSalaryExited
0115634602Hargrave619FranceFemale4220.00111101348.881
1215647311Hill608SpainFemale41183807.86101112542.580
2315619304Onio502FranceFemale428159660.80310113931.571
3415701354Boni699FranceFemale3910.0020093826.630
4515737888Mitchell850SpainFemale432125510.8211179084.100
5615574012Chu645SpainMale448113755.78210149756.711
6715592531Bartlett822FranceMale5070.0021110062.800
7815656148Obinna376GermanyFemale294115046.74410119346.881
8915792365He501FranceMale444142051.0720174940.500
91015592389H?684FranceMale272134603.8811171725.730
101115767821Bearce528FranceMale316102016.7220080181.120
111215737173Andrews497SpainMale2430.0021076390.010
121315632264Kay476FranceFemale34100.0021026260.980
131415691483Chin549FranceFemale2550.00200190857.790
141515600882Scott635SpainFemale3570.0021165951.650
151615643966Goforth616GermanyMale453143129.4120164327.260
161715737452Romeo653GermanyMale581132602.881105097.671
171815788218Henderson549SpainFemale2490.0021114406.410
181915661507Muldrow587SpainMale4560.00100158684.810
192015568982Hao726FranceFemale2460.0021154724.030
202115577657McDonald732FranceMale4180.00211170886.170
212215597945Dellucci636SpainFemale3280.00210138555.460
222315699309Gerasimov510SpainFemale3840.00110118913.531
232415725737Mosman669FranceMale4630.002018487.750
242515625047Yen846FranceFemale3850.00111187616.160
252615738191Maclean577FranceMale2530.00201124508.290
262715736816Young756GermanyMale362136815.64111170041.950
272815700772Nebechi571FranceMale4490.0020038433.350
282915728693McWilliams574GermanyFemale433141349.43111100187.430
293015656300Lucciano411FranceMale29059697.1721153483.210
303115589475Azikiwe591SpainFemale3930.00310140469.381
313215706552Odinakachukwu533FranceMale36785311.70101156731.910
323315750181Sanderson553GermanyMale419110112.5420081898.810
333415659428Maggard520SpainFemale4260.0021134410.550
343515732963Clements722SpainFemale2990.00211142033.070
353615794171Lombardo475FranceFemale450134264.0411027822.991
363715788448Watson490SpainMale313145260.23101114066.770
373815729599Lorenzo804SpainMale33776548.6010198453.450
383915717426Armstrong850FranceMale3670.0011140812.900
394015585768Cameron582GermanyMale41670349.48201178074.040
404115619360Hsiao472SpainMale4040.0011070154.220
414215738148Clarke465FranceFemale518122522.32100181297.651
424315687946Osborne556FranceFemale612117419.3511194153.830
434415755196Lavine834FranceFemale492131394.56100194365.761
444515684171Bianchi660SpainFemale615155931.11111158338.390
454615754849Tyler776GermanyFemale324109421.13211126517.460
464715602280Martin829GermanyFemale279112045.67111119708.211
474815771573Okagbue637GermanyFemale399137843.80111117622.801
484915766205Yin550GermanyMale382103391.3810190878.130
495015771873Buccho776GermanyFemale372103769.22210194099.120
\n", 971 | "
" 972 | ], 973 | "text/plain": [ 974 | " RowNumber CustomerId Surname CreditScore Geography Gender Age \\\n", 975 | "0 1 15634602 Hargrave 619 France Female 42 \n", 976 | "1 2 15647311 Hill 608 Spain Female 41 \n", 977 | "2 3 15619304 Onio 502 France Female 42 \n", 978 | "3 4 15701354 Boni 699 France Female 39 \n", 979 | "4 5 15737888 Mitchell 850 Spain Female 43 \n", 980 | "5 6 15574012 Chu 645 Spain Male 44 \n", 981 | "6 7 15592531 Bartlett 822 France Male 50 \n", 982 | "7 8 15656148 Obinna 376 Germany Female 29 \n", 983 | "8 9 15792365 He 501 France Male 44 \n", 984 | "9 10 15592389 H? 684 France Male 27 \n", 985 | "10 11 15767821 Bearce 528 France Male 31 \n", 986 | "11 12 15737173 Andrews 497 Spain Male 24 \n", 987 | "12 13 15632264 Kay 476 France Female 34 \n", 988 | "13 14 15691483 Chin 549 France Female 25 \n", 989 | "14 15 15600882 Scott 635 Spain Female 35 \n", 990 | "15 16 15643966 Goforth 616 Germany Male 45 \n", 991 | "16 17 15737452 Romeo 653 Germany Male 58 \n", 992 | "17 18 15788218 Henderson 549 Spain Female 24 \n", 993 | "18 19 15661507 Muldrow 587 Spain Male 45 \n", 994 | "19 20 15568982 Hao 726 France Female 24 \n", 995 | "20 21 15577657 McDonald 732 France Male 41 \n", 996 | "21 22 15597945 Dellucci 636 Spain Female 32 \n", 997 | "22 23 15699309 Gerasimov 510 Spain Female 38 \n", 998 | "23 24 15725737 Mosman 669 France Male 46 \n", 999 | "24 25 15625047 Yen 846 France Female 38 \n", 1000 | "25 26 15738191 Maclean 577 France Male 25 \n", 1001 | "26 27 15736816 Young 756 Germany Male 36 \n", 1002 | "27 28 15700772 Nebechi 571 France Male 44 \n", 1003 | "28 29 15728693 McWilliams 574 Germany Female 43 \n", 1004 | "29 30 15656300 Lucciano 411 France Male 29 \n", 1005 | "30 31 15589475 Azikiwe 591 Spain Female 39 \n", 1006 | "31 32 15706552 Odinakachukwu 533 France Male 36 \n", 1007 | "32 33 15750181 Sanderson 553 Germany Male 41 \n", 1008 | "33 34 15659428 Maggard 520 Spain Female 42 \n", 1009 | "34 35 15732963 Clements 722 Spain Female 29 \n", 
1010 | "35 36 15794171 Lombardo 475 France Female 45 \n", 1011 | "36 37 15788448 Watson 490 Spain Male 31 \n", 1012 | "37 38 15729599 Lorenzo 804 Spain Male 33 \n", 1013 | "38 39 15717426 Armstrong 850 France Male 36 \n", 1014 | "39 40 15585768 Cameron 582 Germany Male 41 \n", 1015 | "40 41 15619360 Hsiao 472 Spain Male 40 \n", 1016 | "41 42 15738148 Clarke 465 France Female 51 \n", 1017 | "42 43 15687946 Osborne 556 France Female 61 \n", 1018 | "43 44 15755196 Lavine 834 France Female 49 \n", 1019 | "44 45 15684171 Bianchi 660 Spain Female 61 \n", 1020 | "45 46 15754849 Tyler 776 Germany Female 32 \n", 1021 | "46 47 15602280 Martin 829 Germany Female 27 \n", 1022 | "47 48 15771573 Okagbue 637 Germany Female 39 \n", 1023 | "48 49 15766205 Yin 550 Germany Male 38 \n", 1024 | "49 50 15771873 Buccho 776 Germany Female 37 \n", 1025 | "\n", 1026 | " Tenure Balance NumOfProducts HasCrCard IsActiveMember \\\n", 1027 | "0 2 0.00 1 1 1 \n", 1028 | "1 1 83807.86 1 0 1 \n", 1029 | "2 8 159660.80 3 1 0 \n", 1030 | "3 1 0.00 2 0 0 \n", 1031 | "4 2 125510.82 1 1 1 \n", 1032 | "5 8 113755.78 2 1 0 \n", 1033 | "6 7 0.00 2 1 1 \n", 1034 | "7 4 115046.74 4 1 0 \n", 1035 | "8 4 142051.07 2 0 1 \n", 1036 | "9 2 134603.88 1 1 1 \n", 1037 | "10 6 102016.72 2 0 0 \n", 1038 | "11 3 0.00 2 1 0 \n", 1039 | "12 10 0.00 2 1 0 \n", 1040 | "13 5 0.00 2 0 0 \n", 1041 | "14 7 0.00 2 1 1 \n", 1042 | "15 3 143129.41 2 0 1 \n", 1043 | "16 1 132602.88 1 1 0 \n", 1044 | "17 9 0.00 2 1 1 \n", 1045 | "18 6 0.00 1 0 0 \n", 1046 | "19 6 0.00 2 1 1 \n", 1047 | "20 8 0.00 2 1 1 \n", 1048 | "21 8 0.00 2 1 0 \n", 1049 | "22 4 0.00 1 1 0 \n", 1050 | "23 3 0.00 2 0 1 \n", 1051 | "24 5 0.00 1 1 1 \n", 1052 | "25 3 0.00 2 0 1 \n", 1053 | "26 2 136815.64 1 1 1 \n", 1054 | "27 9 0.00 2 0 0 \n", 1055 | "28 3 141349.43 1 1 1 \n", 1056 | "29 0 59697.17 2 1 1 \n", 1057 | "30 3 0.00 3 1 0 \n", 1058 | "31 7 85311.70 1 0 1 \n", 1059 | "32 9 110112.54 2 0 0 \n", 1060 | "33 6 0.00 2 1 1 \n", 1061 | "34 9 0.00 2 1 1 \n", 
1062 | "35 0 134264.04 1 1 0 \n", 1063 | "36 3 145260.23 1 0 1 \n", 1064 | "37 7 76548.60 1 0 1 \n", 1065 | "38 7 0.00 1 1 1 \n", 1066 | "39 6 70349.48 2 0 1 \n", 1067 | "40 4 0.00 1 1 0 \n", 1068 | "41 8 122522.32 1 0 0 \n", 1069 | "42 2 117419.35 1 1 1 \n", 1070 | "43 2 131394.56 1 0 0 \n", 1071 | "44 5 155931.11 1 1 1 \n", 1072 | "45 4 109421.13 2 1 1 \n", 1073 | "46 9 112045.67 1 1 1 \n", 1074 | "47 9 137843.80 1 1 1 \n", 1075 | "48 2 103391.38 1 0 1 \n", 1076 | "49 2 103769.22 2 1 0 \n", 1077 | "\n", 1078 | " EstimatedSalary Exited \n", 1079 | "0 101348.88 1 \n", 1080 | "1 112542.58 0 \n", 1081 | "2 113931.57 1 \n", 1082 | "3 93826.63 0 \n", 1083 | "4 79084.10 0 \n", 1084 | "5 149756.71 1 \n", 1085 | "6 10062.80 0 \n", 1086 | "7 119346.88 1 \n", 1087 | "8 74940.50 0 \n", 1088 | "9 71725.73 0 \n", 1089 | "10 80181.12 0 \n", 1090 | "11 76390.01 0 \n", 1091 | "12 26260.98 0 \n", 1092 | "13 190857.79 0 \n", 1093 | "14 65951.65 0 \n", 1094 | "15 64327.26 0 \n", 1095 | "16 5097.67 1 \n", 1096 | "17 14406.41 0 \n", 1097 | "18 158684.81 0 \n", 1098 | "19 54724.03 0 \n", 1099 | "20 170886.17 0 \n", 1100 | "21 138555.46 0 \n", 1101 | "22 118913.53 1 \n", 1102 | "23 8487.75 0 \n", 1103 | "24 187616.16 0 \n", 1104 | "25 124508.29 0 \n", 1105 | "26 170041.95 0 \n", 1106 | "27 38433.35 0 \n", 1107 | "28 100187.43 0 \n", 1108 | "29 53483.21 0 \n", 1109 | "30 140469.38 1 \n", 1110 | "31 156731.91 0 \n", 1111 | "32 81898.81 0 \n", 1112 | "33 34410.55 0 \n", 1113 | "34 142033.07 0 \n", 1114 | "35 27822.99 1 \n", 1115 | "36 114066.77 0 \n", 1116 | "37 98453.45 0 \n", 1117 | "38 40812.90 0 \n", 1118 | "39 178074.04 0 \n", 1119 | "40 70154.22 0 \n", 1120 | "41 181297.65 1 \n", 1121 | "42 94153.83 0 \n", 1122 | "43 194365.76 1 \n", 1123 | "44 158338.39 0 \n", 1124 | "45 126517.46 0 \n", 1125 | "46 119708.21 1 \n", 1126 | "47 117622.80 1 \n", 1127 | "48 90878.13 0 \n", 1128 | "49 194099.12 0 " 1129 | ] 1130 | }, 1131 | "execution_count": 4, 1132 | "metadata": {}, 1133 | 
"output_type": "execute_result" 1134 | } 1135 | ], 1136 | "source": [ 1137 | "dataset = pd.read_csv('Churn_Modelling.csv')\n", 1138 | "dataset.head(50)" 1139 | ] 1140 | }, 1141 | { 1142 | "cell_type": "code", 1143 | "execution_count": 5, 1144 | "metadata": {}, 1145 | "outputs": [ 1146 | { 1147 | "data": { 1148 | "text/plain": [ 1149 | "France 5014\n", 1150 | "Germany 2509\n", 1151 | "Spain 2477\n", 1152 | "Name: Geography, dtype: int64" 1153 | ] 1154 | }, 1155 | "execution_count": 5, 1156 | "metadata": {}, 1157 | "output_type": "execute_result" 1158 | } 1159 | ], 1160 | "source": [ 1161 | "dataset.Geography.value_counts()" 1162 | ] 1163 | }, 1164 | { 1165 | "cell_type": "code", 1166 | "execution_count": 7, 1167 | "metadata": {}, 1168 | "outputs": [ 1169 | { 1170 | "name": "stdout", 1171 | "output_type": "stream", 1172 | "text": [ 1173 | "X >>\n", 1174 | " [[619 'France' 'Female' ... 1 1 101348.88]\n", 1175 | " [608 'Spain' 'Female' ... 0 1 112542.58]\n", 1176 | " [502 'France' 'Female' ... 1 0 113931.57]\n", 1177 | " ...\n", 1178 | " [709 'France' 'Female' ... 0 1 42085.58]\n", 1179 | " [772 'Germany' 'Male' ... 1 0 92888.52]\n", 1180 | " [792 'France' 'Female' ... 1 0 38190.78]]\n", 1181 | "y >>\n", 1182 | " [1 0 1 ... 1 1 0]\n" 1183 | ] 1184 | } 1185 | ], 1186 | "source": [ 1187 | "X = dataset.iloc[:, 3:13].values\n", 1188 | "y = dataset.iloc[:, 13].values\n", 1189 | "print(\"X >>\\n\",X)\n", 1190 | "print(\"y >>\\n\",y)" 1191 | ] 1192 | }, 1193 | { 1194 | "cell_type": "markdown", 1195 | "metadata": {}, 1196 | "source": [ 1197 | "### Transformando os dados categóricos\n" 1198 | ] 1199 | }, 1200 | { 1201 | "cell_type": "code", 1202 | "execution_count": 9, 1203 | "metadata": {}, 1204 | "outputs": [ 1205 | { 1206 | "name": "stdout", 1207 | "output_type": "stream", 1208 | "text": [ 1209 | "[0 0 0 ... 
0 1 0]\n" 1210 | ] 1211 | } 1212 | ], 1213 | "source": [ 1214 | "# Label Encoding the \"Gender\" column\n", 1215 | "from sklearn.preprocessing import LabelEncoder\n", 1216 | "le = LabelEncoder()\n", 1217 | "X[:, 2] = le.fit_transform(X[:, 2])\n", 1218 | "print(X[:, 2])" 1219 | ] 1220 | }, 1221 | { 1222 | "cell_type": "code", 1223 | "execution_count": 10, 1224 | "metadata": {}, 1225 | "outputs": [ 1226 | { 1227 | "name": "stdout", 1228 | "output_type": "stream", 1229 | "text": [ 1230 | "[[1.0 0.0 0.0 ... 1 1 101348.88]\n", 1231 | " [0.0 0.0 1.0 ... 0 1 112542.58]\n", 1232 | " [1.0 0.0 0.0 ... 1 0 113931.57]\n", 1233 | " ...\n", 1234 | " [1.0 0.0 0.0 ... 0 1 42085.58]\n", 1235 | " [0.0 1.0 0.0 ... 1 0 92888.52]\n", 1236 | " [1.0 0.0 0.0 ... 1 0 38190.78]]\n" 1237 | ] 1238 | } 1239 | ], 1240 | "source": [ 1241 | "# One Hot Encoding the \"Geography\" column\n", 1242 | "from sklearn.compose import ColumnTransformer\n", 1243 | "from sklearn.preprocessing import OneHotEncoder\n", 1244 | "ct = ColumnTransformer(transformers=[('encoder', OneHotEncoder(), [1])], remainder='passthrough')\n", 1245 | "X = np.array(ct.fit_transform(X))\n", 1246 | "print(X)" 1247 | ] 1248 | }, 1249 | { 1250 | "cell_type": "markdown", 1251 | "metadata": {}, 1252 | "source": [ 1253 | "# Dividindo o dataset em conjunto de treinamento e conjunto de teste" 1254 | ] 1255 | }, 1256 | { 1257 | "cell_type": "code", 1258 | "execution_count": 11, 1259 | "metadata": {}, 1260 | "outputs": [], 1261 | "source": [ 1262 | "from sklearn.model_selection import train_test_split\n", 1263 | "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 0)" 1264 | ] 1265 | }, 1266 | { 1267 | "cell_type": "markdown", 1268 | "metadata": {}, 1269 | "source": [ 1270 | "# Feature Scaling" 1271 | ] 1272 | }, 1273 | { 1274 | "cell_type": "code", 1275 | "execution_count": 12, 1276 | "metadata": {}, 1277 | "outputs": [], 1278 | "source": [ 1279 | "from sklearn.preprocessing import StandardScaler\n", 
1280 | "sc = StandardScaler()\n", 1281 | "X_train = sc.fit_transform(X_train)\n", 1282 | "X_test = sc.transform(X_test)" 1283 | ] 1284 | }, 1285 | { 1286 | "cell_type": "code", 1287 | "execution_count": 13, 1288 | "metadata": {}, 1289 | "outputs": [ 1290 | { 1291 | "data": { 1292 | "text/plain": [ 1293 | "(8000, 12)" 1294 | ] 1295 | }, 1296 | "execution_count": 13, 1297 | "metadata": {}, 1298 | "output_type": "execute_result" 1299 | } 1300 | ], 1301 | "source": [ 1302 | "np.shape(X_train)" 1303 | ] 1304 | }, 1305 | { 1306 | "cell_type": "code", 1307 | "execution_count": 14, 1308 | "metadata": {}, 1309 | "outputs": [ 1310 | { 1311 | "data": { 1312 | "text/plain": [ 1313 | "array([[-1.01460667, -0.5698444 , 1.74309049, ..., 0.64259497,\n", 1314 | " -1.03227043, 1.10643166],\n", 1315 | " [-1.01460667, 1.75486502, -0.57369368, ..., 0.64259497,\n", 1316 | " 0.9687384 , -0.74866447],\n", 1317 | " [ 0.98560362, -0.5698444 , -0.57369368, ..., 0.64259497,\n", 1318 | " -1.03227043, 1.48533467],\n", 1319 | " ...,\n", 1320 | " [ 0.98560362, -0.5698444 , -0.57369368, ..., 0.64259497,\n", 1321 | " -1.03227043, 1.41231994],\n", 1322 | " [-1.01460667, -0.5698444 , 1.74309049, ..., 0.64259497,\n", 1323 | " 0.9687384 , 0.84432121],\n", 1324 | " [-1.01460667, 1.75486502, -0.57369368, ..., 0.64259497,\n", 1325 | " -1.03227043, 0.32472465]])" 1326 | ] 1327 | }, 1328 | "execution_count": 14, 1329 | "metadata": {}, 1330 | "output_type": "execute_result" 1331 | } 1332 | ], 1333 | "source": [ 1334 | "X_train" 1335 | ] 1336 | }, 1337 | { 1338 | "cell_type": "markdown", 1339 | "metadata": {}, 1340 | "source": [ 1341 | "---\n", 1342 | "# Parte 2 -Vamos construir uma ANN!\n" 1343 | ] 1344 | }, 1345 | { 1346 | "cell_type": "markdown", 1347 | "metadata": { 1348 | "colab_type": "text", 1349 | "id": "KvdeScabXtlB" 1350 | }, 1351 | "source": [ 1352 | "### Initializing the ANN" 1353 | ] 1354 | }, 1355 | { 1356 | "cell_type": "code", 1357 | "execution_count": 15, 1358 | "metadata": { 1359 | "colab": 
{}, 1360 | "colab_type": "code", 1361 | "id": "3dtrScHxXQox" 1362 | }, 1363 | "outputs": [], 1364 | "source": [ 1365 | "ann = tf.keras.models.Sequential()" 1366 | ] 1367 | }, 1368 | { 1369 | "cell_type": "markdown", 1370 | "metadata": { 1371 | "colab_type": "text", 1372 | "id": "rP6urV6SX7kS" 1373 | }, 1374 | "source": [ 1375 | "### Adding the input layer and the first hidden layer" 1376 | ] 1377 | }, 1378 | { 1379 | "cell_type": "code", 1380 | "execution_count": 16, 1381 | "metadata": { 1382 | "colab": {}, 1383 | "colab_type": "code", 1384 | "id": "bppGycBXYCQr" 1385 | }, 1386 | "outputs": [], 1387 | "source": [ 1388 | "ann.add(tf.keras.layers.Dense(units=6, activation='relu'))" 1389 | ] 1390 | }, 1391 | { 1392 | "cell_type": "markdown", 1393 | "metadata": { 1394 | "colab_type": "text", 1395 | "id": "BELWAc_8YJze" 1396 | }, 1397 | "source": [ 1398 | "### Adding the second hidden layer" 1399 | ] 1400 | }, 1401 | { 1402 | "cell_type": "code", 1403 | "execution_count": 17, 1404 | "metadata": { 1405 | "colab": {}, 1406 | "colab_type": "code", 1407 | "id": "JneR0u0sYRTd" 1408 | }, 1409 | "outputs": [], 1410 | "source": [ 1411 | "ann.add(tf.keras.layers.Dense(units=6, activation='relu'))" 1412 | ] 1413 | }, 1414 | { 1415 | "cell_type": "markdown", 1416 | "metadata": { 1417 | "colab_type": "text", 1418 | "id": "OyNEe6RXYcU4" 1419 | }, 1420 | "source": [ 1421 | "### Adding the output layer" 1422 | ] 1423 | }, 1424 | { 1425 | "cell_type": "code", 1426 | "execution_count": 18, 1427 | "metadata": { 1428 | "colab": {}, 1429 | "colab_type": "code", 1430 | "id": "Cn3x41RBYfvY" 1431 | }, 1432 | "outputs": [], 1433 | "source": [ 1434 | "ann.add(tf.keras.layers.Dense(units=1, activation='sigmoid'))" 1435 | ] 1436 | }, 1437 | { 1438 | "cell_type": "markdown", 1439 | "metadata": { 1440 | "colab_type": "text", 1441 | "id": "8GWlJChhY_ZI" 1442 | }, 1443 | "source": [ 1444 | "### Compiling the ANN" 1445 | ] 1446 | }, 1447 | { 1448 | "cell_type": "code", 1449 | "execution_count": 19, 
1450 | "metadata": { 1451 | "colab": {}, 1452 | "colab_type": "code", 1453 | "id": "fG3RrwDXZEaS" 1454 | }, 1455 | "outputs": [], 1456 | "source": [ 1457 | "ann.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy'])" 1458 | ] 1459 | }, 1460 | { 1461 | "cell_type": "markdown", 1462 | "metadata": { 1463 | "colab_type": "text", 1464 | "id": "JT4u2S1_Y4WG" 1465 | }, 1466 | "source": [ 1467 | "## Part 3 - Training the ANN" 1468 | ] 1469 | }, 1470 | { 1471 | "cell_type": "markdown", 1472 | "metadata": { 1473 | "colab_type": "text", 1474 | "id": "0QR_G5u7ZLSM" 1475 | }, 1476 | "source": [ 1477 | "### Training the ANN on the Training set" 1478 | ] 1479 | }, 1480 | { 1481 | "cell_type": "code", 1482 | "execution_count": 20, 1483 | "metadata": { 1484 | "colab": { 1485 | "base_uri": "https://localhost:8080/", 1486 | "height": 1000 1487 | }, 1488 | "colab_type": "code", 1489 | "id": "nHZ-LKv_ZRb3", 1490 | "outputId": "718cc4b0-b5aa-40f0-9b20-d3d31730a531" 1491 | }, 1492 | "outputs": [ 1493 | { 1494 | "name": "stdout", 1495 | "output_type": "stream", 1496 | "text": [ 1497 | "Epoch 1/100\n", 1498 | "250/250 [==============================] - 2s 3ms/step - loss: 0.5819 - accuracy: 0.7604\n", 1499 | "Epoch 2/100\n", 1500 | "250/250 [==============================] - 1s 3ms/step - loss: 0.4819 - accuracy: 0.7966\n", 1501 | "Epoch 3/100\n", 1502 | "250/250 [==============================] - 1s 4ms/step - loss: 0.4565 - accuracy: 0.7976TA: 4\n", 1503 | "Epoch 4/100\n", 1504 | "250/250 [==============================] - 1s 4ms/step - loss: 0.4461 - accuracy: 0.8012\n", 1505 | "Epoch 5/100\n", 1506 | "250/250 [==============================] - 2s 6ms/step - loss: 0.4380 - accuracy: 0.8144\n", 1507 | "Epoch 6/100\n", 1508 | "250/250 [==============================] - 1s 5ms/step - loss: 0.4241 - accuracy: 0.8217\n", 1509 | "Epoch 7/100\n", 1510 | "250/250 [==============================] - 1s 5ms/step - loss: 0.4247 - accuracy: 0.8228\n", 1511 | "Epoch 
8/100\n", 1512 | "250/250 [==============================] - 1s 6ms/step - loss: 0.4212 - accuracy: 0.8239\n", 1513 | "Epoch 9/100\n", 1514 | "250/250 [==============================] - 2s 6ms/step - loss: 0.4153 - accuracy: 0.8270\n", 1515 | "Epoch 10/100\n", 1516 | "250/250 [==============================] - 1s 6ms/step - loss: 0.4188 - accuracy: 0.8269\n", 1517 | "Epoch 11/100\n", 1518 | "250/250 [==============================] - 2s 6ms/step - loss: 0.4013 - accuracy: 0.8392\n", 1519 | "Epoch 12/100\n", 1520 | "250/250 [==============================] - 1s 6ms/step - loss: 0.4047 - accuracy: 0.8339\n", 1521 | "Epoch 13/100\n", 1522 | "250/250 [==============================] - 1s 5ms/step - loss: 0.4108 - accuracy: 0.8234\n", 1523 | "Epoch 14/100\n", 1524 | "250/250 [==============================] - 1s 5ms/step - loss: 0.4088 - accuracy: 0.8298\n", 1525 | "Epoch 15/100\n", 1526 | "250/250 [==============================] - 1s 6ms/step - loss: 0.4024 - accuracy: 0.8309\n", 1527 | "Epoch 16/100\n", 1528 | "250/250 [==============================] - 2s 6ms/step - loss: 0.4004 - accuracy: 0.8277\n", 1529 | "Epoch 17/100\n", 1530 | "250/250 [==============================] - 2s 6ms/step - loss: 0.3997 - accuracy: 0.8324\n", 1531 | "Epoch 18/100\n", 1532 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3812 - accuracy: 0.8430: 1s - loss: - ETA: 0s - los\n", 1533 | "Epoch 19/100\n", 1534 | "250/250 [==============================] - 1s 6ms/step - loss: 0.3860 - accuracy: 0.8360\n", 1535 | "Epoch 20/100\n", 1536 | "250/250 [==============================] - 1s 5ms/step - loss: 0.3943 - accuracy: 0.8300\n", 1537 | "Epoch 21/100\n", 1538 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3885 - accuracy: 0.8331\n", 1539 | "Epoch 22/100\n", 1540 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3755 - accuracy: 0.8342\n", 1541 | "Epoch 23/100\n", 1542 | "250/250 [==============================] - 2s 8ms/step - loss: 
0.3746 - accuracy: 0.8395\n", 1543 | "Epoch 24/100\n", 1544 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3769 - accuracy: 0.8359\n", 1545 | "Epoch 25/100\n", 1546 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3792 - accuracy: 0.8390: 0s - l\n", 1547 | "Epoch 26/100\n", 1548 | "250/250 [==============================] - 2s 9ms/step - loss: 0.3752 - accuracy: 0.8437\n", 1549 | "Epoch 27/100\n", 1550 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3629 - accuracy: 0.8575\n", 1551 | "Epoch 28/100\n", 1552 | "250/250 [==============================] - 2s 6ms/step - loss: 0.3619 - accuracy: 0.8469: \n", 1553 | "Epoch 29/100\n", 1554 | "250/250 [==============================] - 1s 6ms/step - loss: 0.3646 - accuracy: 0.8520\n", 1555 | "Epoch 30/100\n", 1556 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3559 - accuracy: 0.8568\n", 1557 | "Epoch 31/100\n", 1558 | "250/250 [==============================] - 1s 5ms/step - loss: 0.3490 - accuracy: 0.8637\n", 1559 | "Epoch 32/100\n", 1560 | "250/250 [==============================] - 1s 5ms/step - loss: 0.3608 - accuracy: 0.8485\n", 1561 | "Epoch 33/100\n", 1562 | "250/250 [==============================] - 1s 6ms/step - loss: 0.3533 - accuracy: 0.8581\n", 1563 | "Epoch 34/100\n", 1564 | "250/250 [==============================] - 1s 5ms/step - loss: 0.3498 - accuracy: 0.8586\n", 1565 | "Epoch 35/100\n", 1566 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3616 - accuracy: 0.8514: 0s -\n", 1567 | "Epoch 36/100\n", 1568 | "250/250 [==============================] - 1s 4ms/step - loss: 0.3465 - accuracy: 0.8561\n", 1569 | "Epoch 37/100\n", 1570 | "250/250 [==============================] - 1s 6ms/step - loss: 0.3532 - accuracy: 0.8522: 1s - loss: 0.3\n", 1571 | "Epoch 38/100\n", 1572 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3506 - accuracy: 0.8518\n", 1573 | "Epoch 39/100\n", 1574 | "250/250 
[==============================] - 1s 6ms/step - loss: 0.3329 - accuracy: 0.8668\n", 1575 | "Epoch 40/100\n", 1576 | "250/250 [==============================] - 1s 6ms/step - loss: 0.3594 - accuracy: 0.8522: 0s - loss:\n", 1577 | "Epoch 41/100\n", 1578 | "250/250 [==============================] - 2s 6ms/step - loss: 0.3466 - accuracy: 0.8533\n", 1579 | "Epoch 42/100\n", 1580 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3600 - accuracy: 0.8502\n", 1581 | "Epoch 43/100\n", 1582 | "250/250 [==============================] - 2s 6ms/step - loss: 0.3510 - accuracy: 0.8531\n", 1583 | "Epoch 44/100\n", 1584 | "250/250 [==============================] - 1s 6ms/step - loss: 0.3483 - accuracy: 0.8568\n", 1585 | "Epoch 45/100\n", 1586 | "250/250 [==============================] - 2s 9ms/step - loss: 0.3460 - accuracy: 0.8601\n", 1587 | "Epoch 46/100\n", 1588 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3346 - accuracy: 0.8593\n", 1589 | "Epoch 47/100\n", 1590 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3413 - accuracy: 0.8607: 0s - loss: 0.3411 - accuracy: \n", 1591 | "Epoch 48/100\n", 1592 | "250/250 [==============================] - 2s 6ms/step - loss: 0.3569 - accuracy: 0.8516\n", 1593 | "Epoch 49/100\n", 1594 | "250/250 [==============================] - 1s 6ms/step - loss: 0.3346 - accuracy: 0.8662\n", 1595 | "Epoch 50/100\n", 1596 | "250/250 [==============================] - 2s 6ms/step - loss: 0.3404 - accuracy: 0.8630\n", 1597 | "Epoch 51/100\n", 1598 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3303 - accuracy: 0.8675\n", 1599 | "Epoch 52/100\n", 1600 | "250/250 [==============================] - 2s 6ms/step - loss: 0.3423 - accuracy: 0.8595\n", 1601 | "Epoch 53/100\n", 1602 | "250/250 [==============================] - 2s 6ms/step - loss: 0.3317 - accuracy: 0.8637\n", 1603 | "Epoch 54/100\n", 1604 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3361 - 
accuracy: 0.8626\n", 1605 | "Epoch 55/100\n", 1606 | "250/250 [==============================] - 1s 5ms/step - loss: 0.3498 - accuracy: 0.8522\n", 1607 | "Epoch 56/100\n", 1608 | "250/250 [==============================] - 1s 5ms/step - loss: 0.3446 - accuracy: 0.8582\n", 1609 | "Epoch 57/100\n", 1610 | "250/250 [==============================] - 1s 6ms/step - loss: 0.3295 - accuracy: 0.8666: 0s - loss: 0.3282 - accura\n", 1611 | "Epoch 58/100\n", 1612 | "250/250 [==============================] - 2s 6ms/step - loss: 0.3392 - accuracy: 0.8621\n", 1613 | "Epoch 59/100\n", 1614 | "250/250 [==============================] - 1s 5ms/step - loss: 0.3335 - accuracy: 0.8620\n", 1615 | "Epoch 60/100\n", 1616 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3355 - accuracy: 0.8670\n", 1617 | "Epoch 61/100\n", 1618 | "250/250 [==============================] - 1s 4ms/step - loss: 0.3264 - accuracy: 0.8667\n", 1619 | "Epoch 62/100\n", 1620 | "250/250 [==============================] - 1s 4ms/step - loss: 0.3400 - accuracy: 0.8600: 0s - loss:\n", 1621 | "Epoch 63/100\n", 1622 | "250/250 [==============================] - 1s 6ms/step - loss: 0.3512 - accuracy: 0.8541: 0s - loss: 0.3544 - ac\n", 1623 | "Epoch 64/100\n", 1624 | "250/250 [==============================] - 1s 4ms/step - loss: 0.3483 - accuracy: 0.8578\n", 1625 | "Epoch 65/100\n", 1626 | "250/250 [==============================] - 1s 4ms/step - loss: 0.3294 - accuracy: 0.8683\n", 1627 | "Epoch 66/100\n", 1628 | "250/250 [==============================] - 1s 4ms/step - loss: 0.3318 - accuracy: 0.8681\n", 1629 | "Epoch 67/100\n", 1630 | "250/250 [==============================] - 1s 4ms/step - loss: 0.3367 - accuracy: 0.8607\n", 1631 | "Epoch 68/100\n", 1632 | "250/250 [==============================] - 2s 6ms/step - loss: 0.3358 - accuracy: 0.8667\n", 1633 | "Epoch 69/100\n", 1634 | "250/250 [==============================] - 2s 6ms/step - loss: 0.3324 - accuracy: 0.8619\n", 1635 | "Epoch 70/100\n", 
1636 | "250/250 [==============================] - 2s 6ms/step - loss: 0.3338 - accuracy: 0.8634: 0s - loss:\n", 1637 | "Epoch 71/100\n", 1638 | "250/250 [==============================] - 3s 10ms/step - loss: 0.3324 - accuracy: 0.8639\n", 1639 | "Epoch 72/100\n", 1640 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3336 - accuracy: 0.8595\n", 1641 | "Epoch 73/100\n", 1642 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3276 - accuracy: 0.8686: 0s - loss: 0.3245 - \n", 1643 | "Epoch 74/100\n", 1644 | "250/250 [==============================] - 1s 5ms/step - loss: 0.3334 - accuracy: 0.8647\n", 1645 | "Epoch 75/100\n", 1646 | "250/250 [==============================] - 1s 6ms/step - loss: 0.3256 - accuracy: 0.8665\n", 1647 | "Epoch 76/100\n", 1648 | "250/250 [==============================] - 2s 6ms/step - loss: 0.3226 - accuracy: 0.8694\n", 1649 | "Epoch 77/100\n", 1650 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3385 - accuracy: 0.8571\n", 1651 | "Epoch 78/100\n", 1652 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3427 - accuracy: 0.8521: 0s - loss: 0.3433 - accuracy\n", 1653 | "Epoch 79/100\n", 1654 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3366 - accuracy: 0.8609: 0s\n", 1655 | "Epoch 80/100\n", 1656 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3314 - accuracy: 0.8625\n", 1657 | "Epoch 81/100\n", 1658 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3452 - accuracy: 0.8597\n", 1659 | "Epoch 82/100\n", 1660 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3302 - accuracy: 0.8609: 0s - loss: 0.3288 - \n", 1661 | "Epoch 83/100\n", 1662 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3370 - accuracy: 0.8608\n", 1663 | "Epoch 84/100\n", 1664 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3350 - accuracy: 0.8646\n", 1665 | "Epoch 85/100\n", 1666 | "250/250 
[==============================] - 2s 7ms/step - loss: 0.3524 - accuracy: 0.8557: 0s - loss: 0.3586 - accuracy: 0.85 - ETA: 0s - loss: 0.3\n", 1667 | "Epoch 86/100\n", 1668 | "250/250 [==============================] - 2s 9ms/step - loss: 0.3319 - accuracy: 0.8623\n", 1669 | "Epoch 87/100\n", 1670 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3357 - accuracy: 0.8626\n", 1671 | "Epoch 88/100\n", 1672 | "250/250 [==============================] - 2s 7ms/step - loss: 0.3367 - accuracy: 0.8612\n", 1673 | "Epoch 89/100\n", 1674 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3339 - accuracy: 0.8631\n", 1675 | "Epoch 90/100\n", 1676 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3390 - accuracy: 0.8546\n", 1677 | "Epoch 91/100\n", 1678 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3307 - accuracy: 0.8648: 1s - loss: 0.3297 - accu\n", 1679 | "Epoch 92/100\n", 1680 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3218 - accuracy: 0.8735\n", 1681 | "Epoch 93/100\n", 1682 | "250/250 [==============================] - 1s 2ms/step - loss: 0.3458 - accuracy: 0.8532\n", 1683 | "Epoch 94/100\n", 1684 | "250/250 [==============================] - 1s 6ms/step - loss: 0.3362 - accuracy: 0.8622\n", 1685 | "Epoch 95/100\n", 1686 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3364 - accuracy: 0.8617\n", 1687 | "Epoch 96/100\n", 1688 | "250/250 [==============================] - 2s 8ms/step - loss: 0.3329 - accuracy: 0.8617\n", 1689 | "Epoch 97/100\n", 1690 | "250/250 [==============================] - 1s 5ms/step - loss: 0.3260 - accuracy: 0.8705\n", 1691 | "Epoch 98/100\n", 1692 | "250/250 [==============================] - 1s 6ms/step - loss: 0.3452 - accuracy: 0.8588: 0s - loss: 0.3490 - ac\n", 1693 | "Epoch 99/100\n", 1694 | "250/250 [==============================] - 1s 6ms/step - loss: 0.3383 - accuracy: 0.8617\n", 1695 | "Epoch 100/100\n", 1696 | "250/250 
[==============================] - 1s 6ms/step - loss: 0.3389 - accuracy: 0.8612\n" 1697 | ] 1698 | }, 1699 | { 1700 | "data": { 1701 | "text/plain": [ 1702 | "" 1703 | ] 1704 | }, 1705 | "execution_count": 20, 1706 | "metadata": {}, 1707 | "output_type": "execute_result" 1708 | } 1709 | ], 1710 | "source": [ 1711 | "ann.fit(X_train, y_train, batch_size = 32, epochs = 100)" 1712 | ] 1713 | }, 1714 | { 1715 | "cell_type": "markdown", 1716 | "metadata": {}, 1717 | "source": [ 1718 | "# Parte 3 - Fazendo predições e avaliando o modelo\n", 1719 | "\n", 1720 | "## Prevendo os resultados com o conjunto de testes" 1721 | ] 1722 | }, 1723 | { 1724 | "cell_type": "code", 1725 | "execution_count": 21, 1726 | "metadata": {}, 1727 | "outputs": [ 1728 | { 1729 | "data": { 1730 | "text/plain": [ 1731 | "array([[0.3427144 ],\n", 1732 | " [0.3022073 ],\n", 1733 | " [0.1530391 ],\n", 1734 | " ...,\n", 1735 | " [0.19075722],\n", 1736 | " [0.138769 ],\n", 1737 | " [0.2750219 ]], dtype=float32)" 1738 | ] 1739 | }, 1740 | "execution_count": 21, 1741 | "metadata": {}, 1742 | "output_type": "execute_result" 1743 | } 1744 | ], 1745 | "source": [ 1746 | "y_pred = ann.predict(X_test)\n", 1747 | "y_pred" 1748 | ] 1749 | }, 1750 | { 1751 | "cell_type": "code", 1752 | "execution_count": 22, 1753 | "metadata": {}, 1754 | "outputs": [ 1755 | { 1756 | "data": { 1757 | "text/plain": [ 1758 | "array([[False],\n", 1759 | " [False],\n", 1760 | " [False],\n", 1761 | " ...,\n", 1762 | " [False],\n", 1763 | " [False],\n", 1764 | " [False]])" 1765 | ] 1766 | }, 1767 | "execution_count": 22, 1768 | "metadata": {}, 1769 | "output_type": "execute_result" 1770 | } 1771 | ], 1772 | "source": [ 1773 | "y_pred = (y_pred > 0.5)\n", 1774 | "y_pred" 1775 | ] 1776 | }, 1777 | { 1778 | "cell_type": "markdown", 1779 | "metadata": {}, 1780 | "source": [ 1781 | "# Criando uma Confusion Matrix" 1782 | ] 1783 | }, 1784 | { 1785 | "cell_type": "code", 1786 | "execution_count": 18, 1787 | "metadata": {}, 1788 | "outputs": 
[ 1789 | { 1790 | "name": "stdout", 1791 | "output_type": "stream", 1792 | "text": [ 1793 | "[[1519 76]\n", 1794 | " [ 200 205]]\n" 1795 | ] 1796 | } 1797 | ], 1798 | "source": [ 1799 | "from sklearn.metrics import confusion_matrix\n", 1800 | "cm = confusion_matrix(y_test, y_pred)\n", 1801 | "print(cm)" 1802 | ] 1803 | }, 1804 | { 1805 | "cell_type": "markdown", 1806 | "metadata": {}, 1807 | "source": [ 1808 | "# FIM" 1809 | ] 1810 | }, 1811 | { 1812 | "cell_type": "markdown", 1813 | "metadata": {}, 1814 | "source": [ 1815 | "> Professor Diego Dorgam \n", 1816 | "> [@diegodorgam](https://twitter.com/diegodorgam)" 1817 | ] 1818 | } 1819 | ], 1820 | "metadata": { 1821 | "kernelspec": { 1822 | "display_name": "Python 3", 1823 | "language": "python", 1824 | "name": "python3" 1825 | }, 1826 | "language_info": { 1827 | "codemirror_mode": { 1828 | "name": "ipython", 1829 | "version": 3 1830 | }, 1831 | "file_extension": ".py", 1832 | "mimetype": "text/x-python", 1833 | "name": "python", 1834 | "nbconvert_exporter": "python", 1835 | "pygments_lexer": "ipython3", 1836 | "version": "3.8.6" 1837 | } 1838 | }, 1839 | "nbformat": 4, 1840 | "nbformat_minor": 4 1841 | } 1842 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Construindo uma Rede Neural Artificial 2 | 3 | Construindo sua primeira rede neural usando Keras, Theano e Tensorflow 4 | Parte da matéria de DeepLearning da escola de Engenharia de Software da UnB 5 | 6 | ## Usage 7 | 8 | ### jupyter-notebook 9 | 10 | Open a `jupyter-notebook` in your desired `env` and run the command: 11 | 12 | ```sh 13 | jupyter-notebook ArtificialNeuralNetwork.ipynb 14 | ``` 15 | ### Python module 16 | 17 | Install the requirements in the desired `env`: 18 | 19 | ```sh 20 | pip install -r requirements.txt 21 | ``` 22 | 23 | Run the python module and make changes in a file editor: 24 | 25 | ```sh 26 | python 
ann.py 27 | ``` 28 | 29 | or 30 | 31 | ```sh 32 | python evaluating_improving_tuning.py 33 | ``` 34 | 35 | -------------------------------------------------------------------------------- /ann.py: -------------------------------------------------------------------------------- 1 | # Artificial Neural Network 2 | 3 | # Installing Theano 4 | # pip install --upgrade --no-deps git+git://github.com/Theano/Theano.git 5 | 6 | # Installing Tensorflow 7 | # pip install tensorflow 8 | 9 | # Installing Keras 10 | # pip install --upgrade keras 11 | 12 | # Part 1 - Data Preprocessing 13 | 14 | # Importing the libraries 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | import pandas as pd 18 | 19 | # Importing the dataset 20 | dataset = pd.read_csv('Churn_Modelling.csv') 21 | X = dataset.iloc[:, 3:13].values 22 | y = dataset.iloc[:, 13].values 23 | 24 | # Encoding categorical data 25 | from sklearn.preprocessing import LabelEncoder, OneHotEncoder 26 | labelencoder_X_1 = LabelEncoder() 27 | X[:, 1] = labelencoder_X_1.fit_transform(X[:, 1]) 28 | labelencoder_X_2 = LabelEncoder() 29 | X[:, 2] = labelencoder_X_2.fit_transform(X[:, 2]) 30 | onehotencoder = OneHotEncoder(categorical_features = [1]) 31 | X = onehotencoder.fit_transform(X).toarray() 32 | X = X[:, 1:] 33 | 34 | # Splitting the dataset into the Training set and Test set 35 | from sklearn.model_selection import train_test_split 36 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 0) 37 | 38 | # Feature Scaling 39 | from sklearn.preprocessing import StandardScaler 40 | sc = StandardScaler() 41 | X_train = sc.fit_transform(X_train) 42 | X_test = sc.transform(X_test) 43 | 44 | # Part 2 - Now let's make the ANN! 
45 | 46 | # Importing the Keras libraries and packages 47 | import keras 48 | from keras.models import Sequential 49 | from keras.layers import Dense 50 | 51 | # Initialising the ANN 52 | classifier = Sequential() 53 | 54 | # Adding the input layer and the first hidden layer 55 | classifier.add(Dense(units = 6, kernel_initializer = 'uniform', activation = 'relu', input_dim = 11)) 56 | 57 | # Adding the second hidden layer 58 | classifier.add(Dense(units = 6, kernel_initializer = 'uniform', activation = 'relu')) 59 | 60 | # Adding the output layer 61 | classifier.add(Dense(units = 1, kernel_initializer = 'uniform', activation = 'sigmoid')) 62 | 63 | # Compiling the ANN 64 | classifier.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy']) 65 | 66 | # Fitting the ANN to the Training set 67 | classifier.fit(X_train, y_train, batch_size = 10, epochs = 30) 68 | 69 | # Part 3 - Making predictions and evaluating the model 70 | 71 | # Predicting the Test set results 72 | y_pred = classifier.predict(X_test) 73 | y_pred = (y_pred > 0.5) 74 | 75 | # Making the Confusion Matrix 76 | from sklearn.metrics import confusion_matrix 77 | cm = confusion_matrix(y_test, y_pred) -------------------------------------------------------------------------------- /evaluating_improving_tuning.py: -------------------------------------------------------------------------------- 1 | # Artificial Neural Network 2 | 3 | # Installing Theano 4 | # pip install --upgrade --no-deps git+git://github.com/Theano/Theano.git 5 | 6 | # Installing Tensorflow 7 | # pip install tensorflow 8 | 9 | # Installing Keras 10 | # pip install --upgrade keras 11 | 12 | # Part 1 - Data Preprocessing 13 | 14 | # Importing the libraries 15 | import numpy as np 16 | import matplotlib.pyplot as plt 17 | import pandas as pd 18 | 19 | # Importing the dataset 20 | dataset = pd.read_csv('Churn_Modelling.csv') 21 | X = dataset.iloc[:, 3:13].values 22 | y = dataset.iloc[:, 13].values 23 | 24 | # Encoding 
categorical data 25 | from sklearn.preprocessing import LabelEncoder, OneHotEncoder 26 | labelencoder_X_1 = LabelEncoder() 27 | X[:, 1] = labelencoder_X_1.fit_transform(X[:, 1]) 28 | labelencoder_X_2 = LabelEncoder() 29 | X[:, 2] = labelencoder_X_2.fit_transform(X[:, 2]) 30 | onehotencoder = OneHotEncoder(categorical_features = [1]) 31 | X = onehotencoder.fit_transform(X).toarray() 32 | X = X[:, 1:] 33 | 34 | # Splitting the dataset into the Training set and Test set 35 | from sklearn.model_selection import train_test_split 36 | X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = 0.2, random_state = 0) 37 | 38 | # Feature Scaling 39 | from sklearn.preprocessing import StandardScaler 40 | sc = StandardScaler() 41 | X_train = sc.fit_transform(X_train) 42 | X_test = sc.transform(X_test) 43 | 44 | # Part 2 - Now let's make the ANN! 45 | 46 | # Importing the Keras libraries and packages 47 | import keras 48 | from keras.models import Sequential 49 | from keras.layers import Dense 50 | from keras.layers import Dropout 51 | 52 | # Initialising the ANN 53 | classifier = Sequential() 54 | 55 | # Adding the input layer and the first hidden layer 56 | classifier.add(Dense(units = 6, kernel_initializer = 'uniform', activation = 'relu', input_dim = 11)) 57 | # classifier.add(Dropout(rate = 0.1)) 58 | 59 | # Adding the second hidden layer 60 | classifier.add(Dense(units = 6, kernel_initializer = 'uniform', activation = 'relu')) 61 | # classifier.add(Dropout(rate = 0.1)) 62 | 63 | # Adding the output layer 64 | classifier.add(Dense(units = 1, kernel_initializer = 'uniform', activation = 'sigmoid')) 65 | 66 | # Compiling the ANN 67 | classifier.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy']) 68 | 69 | # Fitting the ANN to the Training set 70 | classifier.fit(X_train, y_train, batch_size = 10, epochs = 100) 71 | 72 | # Part 3 - Making predictions and evaluating the model 73 | 74 | # Predicting the Test set results 75 | y_pred = 
classifier.predict(X_test) 76 | y_pred = (y_pred > 0.5) 77 | 78 | # Predicting a single new observation 79 | """Predict if the customer with the following informations will leave the bank: 80 | Geography: France 81 | Credit Score: 600 82 | Gender: Male 83 | Age: 40 84 | Tenure: 3 85 | Balance: 60000 86 | Number of Products: 2 87 | Has Credit Card: Yes 88 | Is Active Member: Yes 89 | Estimated Salary: 50000""" 90 | new_prediction = classifier.predict(sc.transform(np.array([[0.0, 0, 600, 1, 40, 3, 60000, 2, 1, 1, 50000]]))) 91 | new_prediction = (new_prediction > 0.5) 92 | 93 | # Making the Confusion Matrix 94 | from sklearn.metrics import confusion_matrix 95 | cm = confusion_matrix(y_test, y_pred) 96 | 97 | # Part 4 - Evaluating, Improving and Tuning the ANN 98 | 99 | # Evaluating the ANN 100 | from keras.wrappers.scikit_learn import KerasClassifier 101 | from sklearn.model_selection import cross_val_score 102 | from keras.models import Sequential 103 | from keras.layers import Dense 104 | #from keras.layers import Dropout 105 | def build_classifier(): 106 | classifier = Sequential() 107 | classifier.add(Dense(units = 6, kernel_initializer = 'uniform', activation = 'relu', input_dim = 11)) 108 | classifier.add(Dense(units = 6, kernel_initializer = 'uniform', activation = 'relu')) 109 | classifier.add(Dense(units = 1, kernel_initializer = 'uniform', activation = 'sigmoid')) 110 | classifier.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy']) 111 | return classifier 112 | 113 | classifier = KerasClassifier(build_fn = build_classifier, batch_size = 10, epochs = 10) 114 | accuracies = cross_val_score(estimator = classifier, X = X_train, y = y_train, cv = 10, n_jobs = -1) 115 | mean = accuracies.mean() 116 | variance = accuracies.std() 117 | 118 | # Improving the ANN 119 | # Dropout Regularization to reduce overfitting if needed 120 | from keras.layers import Dropout 121 | 122 | def build_classifier(): 123 | classifier = Sequential() 124 | 
classifier.add(Dense(units = 6, kernel_initializer = 'uniform', activation = 'relu', input_dim = 11)) 125 | classifier.add(Dropout(rate = 0.1)) 126 | classifier.add(Dense(units = 6, kernel_initializer = 'uniform', activation = 'relu')) 127 | classifier.add(Dropout(rate = 0.1)) 128 | classifier.add(Dense(units = 1, kernel_initializer = 'uniform', activation = 'sigmoid')) 129 | classifier.compile(optimizer = 'adam', loss = 'binary_crossentropy', metrics = ['accuracy']) 130 | return classifier 131 | 132 | classifier = KerasClassifier(build_fn = build_classifier, batch_size = 10, epochs = 10) 133 | accuracies = cross_val_score(estimator = classifier, X = X_train, y = y_train, cv = 10, n_jobs = -1) 134 | mean = accuracies.mean() 135 | variance = accuracies.std() 136 | 137 | # Tuning the ANN 138 | from keras.wrappers.scikit_learn import KerasClassifier 139 | from sklearn.model_selection import GridSearchCV 140 | from keras.models import Sequential 141 | from keras.layers import Dense 142 | 143 | def build_classifier(optimizer): 144 | classifier = Sequential() 145 | classifier.add(Dense(units = 6, kernel_initializer = 'uniform', activation = 'relu', input_dim = 11)) 146 | classifier.add(Dropout(rate = 0.1)) 147 | classifier.add(Dense(units = 6, kernel_initializer = 'uniform', activation = 'relu')) 148 | classifier.add(Dropout(rate = 0.1)) 149 | classifier.add(Dense(units = 1, kernel_initializer = 'uniform', activation = 'sigmoid')) 150 | classifier.compile(optimizer = optimizer, loss = 'binary_crossentropy', metrics = ['accuracy']) 151 | return classifier 152 | 153 | classifier = KerasClassifier(build_fn = build_classifier) 154 | 155 | parameters = {'batch_size': [10, 25, 32], 156 | 'epochs': [100, 500], 157 | 'optimizer': ['adam', 'rmsprop']} 158 | 159 | grid_search = GridSearchCV(estimator = classifier, 160 | scoring = 'accuracy', 161 | param_grid = parameters, 162 | cv = 10) 163 | 164 | grid_search = grid_search.fit(X_train, y_train) 165 | 166 | best_parameters = 
grid_search.best_params_
best_accuracy = grid_search.best_score_
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
numpy==1.16.2
pandas==0.24.2
matplotlib==3.0.3
scikit-learn
Keras==2.4.3
tensorflow==2.3.1
--------------------------------------------------------------------------------