├── .DS_Store ├── PointsCoordinates.csv ├── README.md ├── Section-7-word2vec_model.bin ├── Section_2_1_unsupervised.ipynb ├── Section_2_2_DBSCANvsk-Means.ipynb ├── Section_3_1-Supervised-Learning.py ├── Section_3_2-Gradient-Descent-Small-Step.py ├── Section_3_3-Gradient-Descent-Big-Step.py ├── Section_3_4_SVM.ipynb ├── Section_6-TicTactoe ├── .DS_Store ├── Game.jpg ├── Game.py ├── Play Dumb Agent.py ├── Play Q-Learning.py ├── QLearning.py ├── README.md ├── Readme.txt ├── Train.py ├── player1states └── player2states ├── Section_7-Word-Embedding-3D.py ├── Section_7-Word-Embeddings.ipynb └── Section_7-Word2Vec.py /.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robbarto2/AIML-Algorithms-Training/4d8daf7a378f58ff7b3e47504b7c3761b36c103a/.DS_Store -------------------------------------------------------------------------------- /PointsCoordinates.csv: -------------------------------------------------------------------------------- 1 | Shape Name,Center X,Center Y 2 | Oval 4,194.5258,297.6739 3 | Oval 5,221.3741,406.0962 4 | Oval 6,256.5606,248.1152 5 | Oval 7,267.4167,253.2322 6 | Oval 8,271.9066,248.1152 7 | Oval 9,278.7228,245.95 8 | Oval 10,239.1838,393.42 9 | Oval 11,246.8677,365.775 10 | Oval 12,249.194,353.9753 11 | Oval 13,267.9001,334.8935 12 | Oval 14,274.3933,329.4432 13 | Oval 15,281.0491,335.1558 14 | Oval 16,261.6967,352.8928 15 | Oval 17,256.8146,350.8774 16 | Oval 18,272.3218,294.0628 17 | Oval 19,272.7131,302.5286 18 | Oval 20,270.2218,365.726 19 | Oval 21,273.7727,367.2584 20 | Oval 22,276.8045,368.7908 21 | Oval 23,281.9763,371.2182 22 | Oval 24,283.8144,349.0555 23 | Oval 25,292.0991,358.83 24 | Oval 26,286.3121,364.6177 25 | Oval 27,267.4167,395.487 26 | Oval 28,281.9763,398.2527 27 | Oval 29,279.65,403.4066 28 | Oval 30,274.4782,412.9363 29 | Oval 31,299.9237,403.4066 30 | Oval 32,296.7516,406.0962 31 | Oval 33,292.8745,401.1091 32 | Oval 34,294.4254,410.9509 33 | Oval 
35,304.5764,386.0987 34 | Oval 36,305.0198,391.3665 35 | Oval 37,316.9721,395.8841 36 | Oval 38,315.3231,388.526 37 | Oval 39,315.4881,381.711 38 | Oval 40,297.5974,373.1844 39 | Oval 41,303.3864,367.6555 40 | Oval 42,292.0991,337.5194 41 | Oval 43,301.3309,349.7333 42 | Oval 44,316.5211,347.3818 43 | Oval 45,305.2245,353.4259 44 | Oval 46,317.6859,364.6177 45 | Oval 47,325.7565,362.2403 46 | Oval 48,332.4338,368.3297 47 | Oval 49,327.7812,375.6118 48 | Oval 50,216.6979,453.3619 49 | Oval 51,258.8633,445.3986 50 | Oval 52,303.6572,454.4445 51 | Oval 53,306.0122,431.1167 52 | Oval 54,301.9053,446.7493 53 | Oval 55,301.7006,441.2393 54 | Oval 56,315.6225,410.4041 55 | Oval 57,311.8685,419.7925 56 | Oval 58,317.6859,422.7141 57 | Oval 59,313.7828,442.9712 58 | Oval 60,334.7602,444.2314 59 | Oval 61,325.6603,444.2314 60 | Oval 62,325.455,461.7011 61 | Oval 63,341.6745,466.2038 62 | Oval 64,326.1121,420.2719 63 | Oval 65,329.061,426.262 64 | Oval 66,333.6398,423.9317 65 | Oval 67,347.7356,449.912 66 | Oval 68,360.5498,446.6587 67 | Oval 69,353.9947,449.8829 68 | Oval 70,351.7995,436.3845 69 | Oval 71,405.3248,468.6311 70 | Oval 72,381.4155,481.8091 71 | Oval 73,386.0681,479.704 72 | Oval 74,374.5059,440.1919 73 | Oval 75,380.7457,451.2568 74 | Oval 76,373.8128,452.6452 75 | Oval 77,372.1796,459.2737 76 | Oval 78,340.7199,414.1104 77 | Oval 79,333.2341,414.1104 78 | Oval 80,340.3558,419.077 79 | Oval 81,326.3224,401.2421 80 | Oval 82,332.9103,404.0165 81 | Oval 83,340.3558,409.1437 82 | Oval 84,330.1075,384.1013 83 | Oval 85,328.1628,389.4101 84 | Oval 86,333.4267,392.5907 85 | Oval 87,337.7887,391.8586 86 | Oval 88,344.3384,395.257 87 | Oval 89,358.7986,407.9767 88 | Oval 90,369.8533,423.8347 89 | Oval 91,394.7828,437.7645 90 | Oval 92,425.8002,479.704 91 | Oval 93,441.309,474.6173 92 | Oval 94,441.309,486.3119 93 | Oval 95,458.2393,470.4786 94 | Oval 96,466.362,477.499 95 | Oval 97,428.5517,458.2166 96 | Oval 98,448.1469,464.2456 97 | Oval 99,305.5764,328.8101 98 | 
Oval 100,297.7513,323.1573 99 | Oval 101,302.2788,336.6659 100 | Oval 102,388.3945,422.7141 101 | Oval 103,391.385,426.262 102 | Oval 104,292.9033,288.4395 103 | Oval 105,294.5792,279.7856 104 | Oval 106,304.2317,318.0323 105 | Oval 107,306.558,321.5575 106 | Oval 108,316.5211,319.8595 107 | Oval 109,327.5685,323.1573 108 | Oval 110,328.5147,330.049 109 | Oval 111,340.8542,384.9154 110 | Oval 112,299.4298,275.6598 111 | Oval 113,299.4298,300.7245 112 | Oval 114,308.3385,311.0182 113 | Oval 115,326.1121,311.7204 114 | Oval 116,337.9106,330.0388 115 | Oval 117,342.8602,331.8705 116 | Oval 118,334.0376,348.8384 117 | Oval 119,360.8294,405.5493 118 | Oval 120,335.5604,354.7435 119 | Oval 121,349.7103,371.6595 120 | Oval 122,305.5068,293.477 121 | Oval 123,341.3839,358.2212 122 | Oval 124,344.7142,355.4899 123 | Oval 125,350.1722,351.1623 124 | Oval 126,344.3384,376.2097 125 | Oval 127,347.8459,374.0868 126 | Oval 128,305.7387,299.2159 127 | Oval 129,314.7997,300.3108 128 | Oval 130,315.3231,293.8366 129 | Oval 131,319.2984,291.6354 130 | Oval 132,321.9035,300.1012 131 | Oval 133,325.455,296.4435 132 | Oval 134,295.2008,264.2908 133 | Oval 135,299.0779,267.2104 134 | Oval 136,304.2317,273.8878 135 | Oval 137,306.3301,279.6981 136 | Oval 138,375.789,317.4322 137 | Oval 139,360.397,362.5589 138 | Oval 140,358.2235,388.065 139 | Oval 141,356.5516,393.0892 140 | Oval 142,362.0768,374.0868 141 | Oval 143,375.9795,403.1075 142 | Oval 144,371.0161,405.5493 143 | Oval 145,363.1953,401.2835 144 | Oval 146,366.6409,391.8586 145 | Oval 147,368.4493,407.0655 146 | Oval 148,364.1124,370.7973 147 | Oval 149,365.8961,381.244 148 | Oval 150,368.5362,374.8761 149 | Oval 151,373.6423,381.244 150 | Oval 152,377.685,375.6118 151 | Oval 153,382.8852,399.8819 152 | Oval 154,340.9798,322.6322 153 | Oval 155,345.1698,325.8021 154 | Oval 156,347.0405,331.1712 155 | Oval 157,304.4185,238.1819 156 | Oval 158,307.2849,242.7688 157 | Oval 159,304.0825,250.8544 158 | Oval 160,328.9872,304.9153 159 | 
Oval 161,332.1635,302.2241 160 | Oval 162,338.0646,315.5327 161 | Oval 163,342.3879,313.4187 162 | Oval 164,303.3864,250.2765 163 | Oval 165,315.3231,270.8051 164 | Oval 166,320.6495,282.6422 165 | Oval 167,336.7195,303.8926 166 | Oval 168,339.2332,306.8657 167 | Oval 169,358.2235,328.8101 168 | Oval 170,352.3367,318.6414 169 | Oval 171,351.4855,315.0048 170 | Oval 172,331.3134,281.1574 171 | Oval 173,334.4038,282.6414 172 | Oval 174,344.7142,305.6429 173 | Oval 175,352.9679,308.5511 174 | Oval 176,358.6474,308.5908 175 | Oval 177,365.6869,310.6795 176 | Oval 178,316.5211,251.562 177 | Oval 179,320.8024,260.5217 178 | Oval 180,318.4354,266.6178 179 | Oval 181,322.8583,267.4882 180 | Oval 182,325.1949,271.987 181 | Oval 183,331.0691,270.8538 182 | Oval 184,342.3879,283.7414 183 | Oval 185,349.2234,294.1823 184 | Oval 186,347.384,287.497 185 | Oval 187,353.9947,287.3157 186 | Oval 188,359.6216,293.2942 187 | Oval 189,362.8761,299.2615 188 | Oval 190,327.1781,242.2493 189 | Oval 191,301.3309,211.1485 190 | Oval 192,316.8396,227.331 191 | Oval 193,302.25,115.632 192 | Oval 194,295.0612,112.3598 193 | Oval 195,333.2341,186.1181 194 | Oval 196,333.2674,147.9969 195 | Oval 197,348.4026,170.2478 196 | Oval 198,356.9097,182.2281 197 | Oval 199,360.337,194.0278 198 | Oval 200,341.8996,221.6453 199 | Oval 201,337.7887,232.7874 200 | Oval 202,337.8607,238.4267 201 | Oval 203,375.1995,159.6126 202 | Oval 204,372.1796,130.2819 203 | Oval 205,505.9598,38.37374 204 | Oval 206,441.309,45.66122 205 | Oval 207,463.2277,52.06681 206 | Oval 208,471.0146,53.41531 207 | Oval 209,499.4665,87.12878 208 | Oval 210,516.2676,132.6419 209 | Oval 211,535.0073,209.1713 210 | Oval 212,541.4693,181.2634 211 | Oval 213,544.5389,195.0117 212 | Oval 214,438.4977,54.42673 213 | Oval 215,423.4739,60.15799 214 | Oval 216,385.3983,99.70508 215 | Oval 217,396.0377,101.5069 216 | Oval 218,399.4354,97.1387 217 | Oval 219,409.9774,95.46822 218 | Oval 220,423.5706,84.48531 219 | Oval 221,452.7995,78.16035 220 
| Oval 222,471.0146,70.53783 221 | Oval 223,473.4778,75.73295 222 | Oval 224,478.9705,94.40807 223 | Oval 225,490.602,112.4806 224 | Oval 226,505.7876,142.1614 225 | Oval 227,508.0493,147.8797 226 | Oval 228,503.5259,151.9253 227 | Oval 229,517.4191,164.0622 228 | Oval 230,476.1389,116.7966 229 | Oval 231,480.0161,122.1907 230 | Oval 232,476.462,105.6523 231 | Oval 233,480.3391,131.6304 232 | Oval 234,443.6673,100.9883 233 | Oval 235,454.5763,126.2363 234 | Oval 236,463.7446,123.5393 235 | Oval 237,455.913,119.4937 236 | Oval 238,443.6353,126.9106 237 | Oval 239,455.913,108.0797 238 | Oval 240,416.7672,116.1223 239 | Oval 241,436.6563,126.2363 240 | Oval 242,425.8002,134.6647 241 | Oval 243,399.4354,130.2819 242 | Oval 244,413.5315,136.6875 243 | Oval 245,406.9167,139.734 244 | Oval 246,448.1469,132.979 245 | Oval 247,443.6353,135.6761 246 | Oval 248,494.9568,144.5887 247 | Oval 249,447.7356,139.734 248 | Oval 250,471.0146,144.5887 249 | Oval 251,492.372,144.5887 250 | Oval 252,440.4453,139.734 251 | Oval 253,436.245,144.5887 252 | Oval 254,443.4943,149.1615 253 | Oval 255,455.913,147.8129 254 | Oval 256,482.162,149.4986 255 | Oval 257,471.0146,151.1842 256 | Oval 258,471.0146,157.2526 257 | Oval 259,482.162,154.2185 258 | Oval 260,499.4802,161.9726 259 | Oval 261,490.1104,167.7038 260 | Oval 262,466.362,159.2755 261 | Oval 263,461.7094,162.3096 262 | Oval 264,431.5923,148.4872 263 | Oval 265,428.5517,151.8586 264 | Oval 266,423.4739,153.207 265 | Oval 267,445.9616,153.207 266 | Oval 268,450.4732,151.1842 267 | Oval 269,450.4732,161.6354 268 | Oval 270,453.5867,158.6012 269 | Oval 271,443.6353,161.6354 270 | Oval 272,433.2043,165.681 271 | Oval 273,458.2393,173.098 272 | Oval 274,471.0146,172.4237 273 | Oval 275,504.0035,197.3716 274 | Oval 276,416.444,156.9156 275 | Oval 277,497.5417,186.1181 276 | Oval 278,502.0649,193.6631 277 | Oval 279,496.8954,192.9889 278 | Oval 280,506.9114,181.2634 279 | Oval 281,466.362,177.4808 280 | Oval 282,471.8957,181.2634 281 | Oval 
283,462.8919,183.6908 282 | Oval 284,474.9247,183.6908 283 | Oval 285,479.7268,186.1181 284 | Oval 286,477.5095,191.3032 285 | Oval 287,479.8357,197.3716 286 | Oval 288,516.9275,214.2283 287 | Oval 289,486.8794,200.4059 288 | Oval 290,491.4027,206.8114 289 | Oval 291,506.2653,204.4515 290 | Oval 292,497.8647,201.0801 291 | Oval 293,490.1104,198.7202 292 | Oval 294,452.7995,183.6908 293 | Oval 295,471.0146,188.6061 294 | Oval 296,443.6353,181.2634 295 | Oval 297,468.6884,191.9774 296 | Oval 298,475.571,196.0231 297 | Oval 299,471.932,201.4172 298 | Oval 300,460.5656,196.3602 299 | Oval 301,455.913,194.0003 300 | Oval 302,460.5656,191.6404 301 | Oval 303,484.4883,207.1485 302 | Oval 304,494.3106,210.1827 303 | Oval 305,504.3267,213.8912 304 | Oval 306,386.0681,181.2634 305 | Oval 307,382.3376,186.1181 306 | Oval 308,388.3945,190.9661 307 | Oval 309,445.9616,197.3716 308 | Oval 310,464.0357,203.1029 309 | Oval 311,468.6884,201.0801 310 | Oval 312,470.9308,206.1371 311 | Oval 313,477.1865,212.5426 312 | Oval 314,404.4611,204.1143 313 | Oval 315,397.6671,190.6289 314 | Oval 316,433.2043,181.2634 315 | Oval 317,501.0957,224.6795 316 | Oval 318,462.783,209.8456 317 | Oval 319,487.8487,211.8684 318 | Oval 320,508.5269,219.9597 319 | Oval 321,519.8353,241.0953 320 | Oval 322,417.0902,171.4123 321 | Oval 323,420.7363,173.7722 322 | Oval 324,428.1265,178.155 323 | Oval 325,468.6884,214.2283 324 | Oval 326,479.0159,229.0623 325 | Oval 327,492.372,232.4336 326 | Oval 328,525.9742,250.5425 327 | Oval 329,520.4816,252.9699 328 | Oval 330,377.2787,203.7772 329 | Oval 331,400.6722,166.3553 330 | Oval 332,418.7057,178.8293 331 | Oval 333,448.1469,204.1143 332 | Oval 334,486.8794,241.0953 333 | Oval 335,466.362,220.6339 334 | Oval 336,523.3895,273.2324 335 | Oval 337,529.5283,273.2324 336 | Oval 338,518.2199,291.0496 337 | Oval 339,546.8652,295.1073 338 | Oval 340,369.1602,215.2397 339 | Oval 341,397.1091,181.2634 340 | Oval 342,418.7057,178.8293 341 | Oval 343,468.6884,227.0394 342 
| Oval 344,489.7872,245.95 343 | Oval 345,465.2318,227.3766 344 | Oval 346,484.4883,246.6769 345 | Oval 347,496.2492,259.7415 346 | Oval 348,399.4354,176.8064 347 | Oval 349,405.3248,173.7722 348 | Oval 350,411.2461,174.1093 349 | Oval 351,409.9774,183.6908 350 | Oval 352,409.9774,190.6289 351 | Oval 353,416.7388,194.6745 352 | Oval 354,442.4161,208.1599 353 | Oval 355,421.1476,197.3716 354 | Oval 356,425.8002,205.8 355 | Oval 357,428.1265,210.1827 356 | Oval 358,432.4492,207.4856 357 | Oval 359,460.5656,233.1078 358 | Oval 360,436.6563,213.5541 359 | Oval 361,443.4943,214.5655 360 | Oval 362,451.7213,214.2283 361 | Oval 363,458.2393,222.6567 362 | Oval 364,438.9826,219.6225 363 | Oval 365,445.8206,220.6339 364 | Oval 366,449.072,222.9938 365 | Oval 367,468.3133,232.4336 366 | Oval 368,421.1476,209.8456 367 | Oval 369,450.4732,231.4222 368 | Oval 370,452.3676,237.1534 369 | Oval 371,474.2502,248.1152 370 | Oval 372,388.3945,209.8456 371 | Oval 373,433.2043,221.9824 372 | Oval 374,460.9543,245.6878 373 | Oval 375,463.9838,250.5425 374 | Oval 376,474.5734,252.9699 375 | Oval 377,482.162,252.9699 376 | Oval 378,484.3794,251.4642 377 | Oval 379,362.9003,229.0623 378 | Oval 380,491.6975,266.147 379 | Oval 381,489.1128,261.4271 380 | Oval 382,484.1132,258.0869 381 | Oval 383,492.6668,258.0869 382 | Oval 384,482.162,306.7016 383 | Oval 385,393.7113,222.6567 384 | Oval 386,391.385,232.4336 385 | Oval 387,379.9906,228.0508 386 | Oval 388,383.7418,231.0231 387 | Oval 389,354.8248,243.2604 388 | Oval 390,371.4865,233.445 389 | Oval 391,377.3226,236.4792 390 | Oval 392,380.7457,238.1649 391 | Oval 393,373.8128,241.0953 392 | Oval 394,372.9456,245.95 393 | Oval 395,333.1361,252.9699 394 | Oval 396,337.7887,253.5386 395 | Oval 397,339.4128,260.9238 396 | Oval 398,338.6902,266.2489 397 | Oval 399,350.7742,256.3189 398 | Oval 400,357.9507,241.0953 399 | Oval 401,365.2267,258.0869 400 | Oval 402,350.1722,265.94 401 | Oval 403,354.9198,276.2645 402 | Oval 404,360.5498,274.543 403 | 
Oval 405,362.9003,267.8326 404 | Oval 406,379.1585,260.7529 405 | Oval 407,376.1391,267.6129 406 | Oval 408,373.097,264.1891 407 | Oval 409,367.553,268.1698 408 | Oval 410,369.8533,275.6598 409 | Oval 411,513.022,317.3914 410 | Oval 412,390.1302,251.4729 411 | Oval 413,517.8685,264.7985 412 | Oval 414,397.3879,243.5227 413 | Oval 415,511.0835,264.4613 414 | Oval 416,402.9985,233.445 415 | Oval 417,496.8672,278.0872 416 | Oval 418,503.6522,274.9781 417 | Oval 419,507.2063,266.8213 418 | Oval 420,397.1091,258.0869 419 | Oval 421,402.9985,258.0869 420 | Oval 422,398.6804,254.5658 421 | Oval 423,403.8082,248.1152 422 | Oval 424,478.5618,290.8668 423 | Oval 425,479.8357,283.5848 424 | Oval 426,480.0498,280.5145 425 | Oval 427,484.4883,277.3582 426 | Oval 428,466.362,298.9175 427 | Oval 429,452.7995,295.9043 428 | Oval 430,468.6884,281.8482 429 | Oval 431,468.6884,288.4395 430 | Oval 432,376.8322,273.2324 431 | Oval 433,383.7418,279.7856 432 | Oval 434,378.4194,281.1193 433 | Oval 435,377.1035,286.0121 434 | Oval 436,374.5059,292.8192 435 | Oval 437,376.8322,297.6739 436 | Oval 438,377.3412,302.5286 437 | Oval 439,380.7457,297.882 438 | Oval 440,383.0721,294.0628 439 | Oval 441,400.6722,266.8213 440 | Oval 442,408.8992,267.4955 441 | Oval 443,393.7113,273.2324 442 | Oval 444,416.0927,270.8051 443 | Oval 445,413.5079,275.6598 444 | Oval 446,401.9203,282.6422 445 | Oval 447,402.6235,290.8668 446 | Oval 448,405.3248,227.7137 447 | Oval 449,417.385,266.8213 448 | Oval 450,455.913,281.7597 449 | Oval 451,445.7687,291.6354 450 | Oval 452,407.6511,245.6878 451 | Oval 453,430.878,278.0872 452 | Oval 454,440.5539,274.916 453 | Oval 455,438.9826,278.9744 454 | Oval 456,425.8002,273.2324 455 | Oval 457,423.4739,264.8851 456 | Oval 458,430.878,262.4578 457 | Oval 459,434.1217,270.8051 458 | Oval 460,407.6511,223.6681 459 | Oval 461,414.6643,255.6595 460 | Oval 462,421.3618,258.0869 461 | Oval 463,443.4424,265.4727 462 | Oval 464,416.9093,250.5425 463 | Oval 465,448.1469,273.2324 464 
| Oval 466,446.681,259.0834 465 | Oval 467,456.895,270.8051 466 | Oval 468,409.9774,228.7251 467 | Oval 469,415.1234,226.028 468 | Oval 470,428.4428,251.1271 469 | Oval 471,449.1289,263.4499 470 | Oval 472,420.3173,243.5227 471 | Oval 473,421.1476,236.1421 472 | Oval 474,438.9826,250.5425 473 | Oval 475,469.9934,270.8051 474 | Oval 476,425.8002,243.5227 475 | Oval 477,453.6931,253.2322 476 | Oval 478,455.5573,245.95 477 | Oval 479,443.6353,243.5227 478 | Oval 480,445.9616,250.8048 479 | Oval 481,450.9872,252.6616 480 | Oval 482,409.9774,215.914 481 | Oval 483,425.8002,234.4564 482 | Oval 484,423.4739,223.331 483 | Oval 485,430.4528,228.7251 484 | Oval 486,413.1848,214.2283 485 | Oval 487,433.2043,234.1193 486 | Oval 488,419.9698,220.2967 487 | Oval 489,413.8309,222.3196 488 | Oval 490,427.4076,292.8192 489 | Oval 491,486.8164,319.3231 490 | Oval 492,497.8018,319.9121 491 | Oval 493,504.2638,322.212 492 | Oval 494,514.3251,380.7104 493 | Oval 495,407.9427,300.7824 494 | Oval 496,421.1476,298.3971 495 | Oval 497,425.8002,300.1012 496 | Oval 498,438.9826,304.956 497 | Oval 499,443.4943,317.4847 498 | Oval 500,407.6511,304.956 499 | Oval 501,416.3812,309.293 500 | Oval 502,422.4311,313.1776 501 | Oval 503,462.8919,323.0824 502 | Oval 504,484.3598,334.2386 503 | Oval 505,402.6439,309.3 504 | Oval 506,435.1252,308.1094 505 | Oval 507,430.878,311.0182 506 | Oval 508,433.0718,317.4322 507 | Oval 509,476.8005,337.9972 508 | Oval 510,468.6626,339.2794 509 | Oval 511,471.0146,353.0625 510 | Oval 512,480.1351,360.8712 511 | Oval 513,494.2477,372.7725 512 | Oval 514,432.0403,330.7199 513 | Oval 515,402.9985,315.6049 514 | Oval 516,397.1091,315.6049 515 | Oval 517,402.1438,321.6648 516 | Oval 518,489.0782,384.3818 517 | Oval 519,400.2347,333.6717 518 | Oval 520,423.4739,339.9468 519 | Oval 521,427.4428,339.9468 520 | Oval 522,445.8206,339.0933 521 | Oval 523,450.4732,339.9468 522 | Oval 524,453.5867,343.3163 523 | Oval 525,413.1307,342.5586 524 | Oval 526,421.1476,344.5266 525 | 
Oval 527,460.5656,347.489 526 | Oval 528,431.7942,350.3412 527 | Oval 529,460.5656,359.813 528 | Oval 530,458.2393,353.9753 529 | Oval 531,472.4017,378.9415 530 | Oval 532,473.3946,374.0868 531 | Oval 533,471.0683,370.7571 532 | Oval 534,382.9138,311.7204 533 | Oval 535,387.5664,313.4455 534 | Oval 536,384.3408,317.9084 535 | Oval 537,442.6364,352.4808 536 | Oval 538,450.4732,346.4111 537 | Oval 539,365.3305,319.047 538 | Oval 540,449.3107,351.0181 539 | Oval 541,449.8861,355.9416 540 | Oval 542,450.7413,360.8712 541 | Oval 543,452.7995,364.6177 542 | Oval 544,371.6614,321.5575 543 | Oval 545,427.9918,347.3059 544 | Oval 546,437.4726,359.9926 545 | Oval 547,434.0197,364.7249 546 | Oval 548,464.0357,365.726 547 | Oval 549,385.9608,325.5281 548 | Oval 550,415.3539,358.3487 549 | Oval 551,461.7094,369.525 550 | Oval 552,464.0357,371.6595 551 | Oval 553,438.1692,366.2711 552 | Oval 554,389.3503,328.817 553 | Oval 555,447.1729,365.7574 554 | Oval 556,438.9826,371.6595 555 | Oval 557,453.2139,371.4296 556 | Oval 558,479.8357,383.6713 557 | Oval 559,373.6539,328.012 558 | Oval 560,420.477,363.269 559 | Oval 561,428.4261,367.1945 560 | Oval 562,455.8111,374.0868 561 | Oval 563,460.5656,374.5647 562 | Oval 564,457.5872,378.9415 563 | Oval 565,464.2241,378.0392 564 | Oval 566,468.6884,383.6713 565 | Oval 567,475.8312,387.736 566 | Oval 568,482.162,391.8586 567 | Oval 569,484.4883,399.1618 568 | Oval 570,484.4883,395.8474 569 | Oval 571,488.7551,397.6844 570 | Oval 572,492.3091,400.2287 571 | Oval 573,500.7601,411.2865 572 | Oval 574,460.6936,382.3921 573 | Oval 575,450.3204,381.4369 574 | Oval 576,446.1455,378.0392 575 | Oval 577,450.0425,375.3819 576 | Oval 578,445.8206,370.0829 577 | Oval 579,479.8357,405.5493 578 | Oval 580,460.5656,386.1302 579 | Oval 581,420.8468,368.2059 580 | Oval 582,426.9776,372.5102 581 | Oval 583,466.362,388.1741 582 | Oval 584,353.9632,342.7182 583 | Oval 585,380.7457,393.3808 584 | Oval 586,385.7754,403.6689 585 | Oval 587,391.385,419.7925 586 | 
Oval 588,497.9268,410.5448 587 | Oval 589,389.8866,393.0677 588 | Oval 590,392.9124,401.302 589 | Oval 591,384.6208,394.286 590 | Oval 592,389.8866,398.2527 591 | Oval 593,386.0681,387.8766 592 | Oval 594,392.7481,386.9828 593 | Oval 595,396.0377,415.2589 594 | Oval 596,395.7912,394.7334 595 | Oval 597,398.3083,411.5711 596 | Oval 598,400.8546,405.5493 597 | Oval 599,401.4142,425.1414 598 | Oval 600,406.0668,419.077 599 | Oval 601,471.0146,392.2272 600 | Oval 602,476.4773,400.8116 601 | Oval 603,474.5387,395.257 602 | Oval 604,356.6473,347.3059 603 | Oval 605,413.8605,441.2393 604 | Oval 606,411.0684,431.1167 605 | Oval 607,410.1963,435.4541 606 | Oval 608,409.2055,440.3326 607 | Oval 609,475.6627,411.3527 608 | Oval 610,468.5535,398.6388 609 | Oval 611,464.7209,394.2649 610 | Oval 612,460.5656,390.9534 611 | Oval 613,456.7149,389.7998 612 | Oval 614,487.1741,417.791 613 | Oval 615,480.3153,413.5327 614 | Oval 616,445.6968,455.7962 615 | Oval 617,439.6518,458.2166 616 | Oval 618,433.4772,454.4445 617 | Oval 619,419.3633,440.1919 618 | Oval 620,424.0151,445.8976 619 | Oval 621,380.7457,332.6716 620 | Oval 622,469.9182,460.644 621 | Oval 623,471.0146,406.4439 622 | Oval 624,407.3476,410.4041 623 | Oval 625,460.234,455.7892 624 | Oval 626,423.32,437.7645 625 | Oval 627,430.4528,441.2393 626 | Oval 628,432.5336,445.9701 627 | Oval 629,435.7313,446.4021 628 | Oval 630,432.3687,449.8829 629 | Oval 631,493.9247,431.1167 630 | Oval 632,482.162,427.5689 631 | Oval 633,479.653,419.7925 632 | Oval 634,486.9441,432.9098 633 | Oval 635,489.5808,428.493 634 | Oval 636,450.8868,385.4313 635 | Oval 637,446.368,393.5909 636 | Oval 638,433.2043,381.711 637 | Oval 639,429.8019,376.4393 638 | Oval 640,438.5349,384.1383 639 | Oval 641,364.7587,348.4501 640 | Oval 642,366.3069,356.017 641 | Oval 643,370.1953,354.1218 642 | Oval 644,413.4909,407.9767 643 | Oval 645,414.8528,413.2212 644 | Oval 646,473.8926,456.8463 645 | Oval 647,477.7774,461.349 646 | Oval 648,456.698,445.0467 647 | 
Oval 649,455.913,448.316 648 | Oval 650,453.5673,452.4949 649 | Oval 651,476.021,436.3845 650 | Oval 652,476.021,440.1919 651 | Oval 653,464.0357,441.1154 652 | Oval 654,467.088,446.0941 653 | Oval 655,470.6361,448.8294 654 | Oval 656,473.9521,451.2568 655 | Oval 657,460.5656,396.7133 656 | Oval 658,474.0107,433.9572 657 | Oval 659,460.5656,400.4341 658 | Oval 660,455.913,395.8474 659 | Oval 661,466.362,406.3643 660 | Oval 662,477.4505,423.8999 661 | Oval 663,466.1669,417.3651 662 | Oval 664,452.3495,446.4021 663 | Oval 665,448.8066,449.0861 664 | Oval 666,468.2866,431.1167 665 | Oval 667,473.553,423.8347 666 | Oval 668,471.9086,419.077 667 | Oval 669,369.9813,359.813 668 | Oval 670,374.6747,363.2986 669 | Oval 671,376.8322,360.3446 670 | Oval 672,373.8128,354.7435 671 | Oval 673,378.5852,367.6555 672 | Oval 674,381.9387,364.0844 673 | Oval 675,451.4075,441.5833 674 | Oval 676,428.5517,436.054 675 | Oval 677,455.7442,439.0252 676 | Oval 678,443.6353,442.7362 677 | Oval 679,447.493,446.4042 678 | Oval 680,470.9045,426.269 679 | Oval 681,463.9009,421.5359 680 | Oval 682,465.9603,426.1382 681 | Oval 683,468.2866,421.6213 682 | Oval 684,461.7094,434.0097 683 | Oval 685,391.385,334.2386 684 | Oval 686,445.8206,438.5063 685 | Oval 687,431.6357,435.3371 686 | Oval 688,436.6563,437.7645 687 | Oval 689,437.7661,442.6193 688 | Oval 690,460.5656,415.94 689 | Oval 691,455.913,418.4837 690 | Oval 692,461.6857,430.4963 691 | Oval 693,456.1006,428.6894 692 | Oval 694,459.9491,427.0352 693 | Oval 695,422.946,428.5871 694 | Oval 696,450.3385,432.0572 695 | Oval 697,445.8206,433.4961 696 | Oval 698,440.3016,433.9572 697 | Oval 699,428.1265,432.4383 698 | Oval 700,450.6481,435.9715 699 | Oval 701,402.9985,354.7435 700 | Oval 702,454.0879,400.1118 701 | Oval 703,460.3553,406.8236 702 | Oval 704,464.0357,401.5891 703 | Oval 705,405.3248,390.9534 704 | Oval 706,391.385,378.0392 705 | Oval 707,398.701,381.711 706 | Oval 708,413.0181,397.5623 707 | Oval 709,419.5837,419.5079 708 | Oval 
710,423.8991,424.6472 709 | Oval 711,442.7134,428.7209 710 | Oval 712,437.6021,431.8436 711 | Oval 713,432.4104,431.165 712 | Oval 714,428.5517,426.6518 713 | Oval 715,455.9942,405.5493 714 | Oval 716,448.0122,426.262 715 | Oval 717,452.0795,421.7588 716 | Oval 718,451.682,413.1985 717 | Oval 719,453.5673,409.2844 718 | Oval 720,447.9518,421.5044 719 | Oval 721,441.309,424.6472 720 | Oval 722,429.851,422.433 721 | Oval 723,424.7883,419.077 722 | Oval 724,421.1476,416.6497 723 | Oval 725,450.4417,401.2589 724 | Oval 726,422.3864,376.7071 725 | Oval 727,442.5777,420.5193 726 | Oval 728,447.3356,416.8 727 | Oval 729,449.8661,407.7803 728 | Oval 730,417.9831,406.7163 729 | Oval 731,435.9444,390.8865 730 | Oval 732,440.3847,394.2649 731 | Oval 733,437.7349,419.077 732 | Oval 734,433.2043,419.7925 733 | Oval 735,423.4739,411.3527 734 | Oval 736,433.2043,386.0987 735 | Oval 737,429.064,381.244 736 | Oval 738,429.4787,387.736 737 | Oval 739,448.1469,396.517 738 | Oval 740,445.8206,406.4439 739 | Oval 741,435.4262,415.3637 740 | Oval 742,442.6857,400.8386 741 | Oval 743,437.1157,398.2527 742 | Oval 744,439.9503,413.077 743 | Oval 745,443.6353,410.4041 744 | Oval 746,438.2927,404.0165 745 | Oval 747,441.309,406.4439 746 | Oval 748,430.3462,412.4195 747 | Oval 749,435.767,409.7223 748 | Oval 750,417.5889,399.4445 749 | Oval 751,432.1407,407.1671 750 | Oval 752,427.1809,408.3664 751 | Oval 753,422.2709,404.0165 752 | Oval 754,428.5517,392.2604 753 | Oval 755,425.7588,397.3012 754 | Oval 756,427.2105,401.6417 755 | Oval 757,432.4406,399.8229 756 | Oval 758,410.5805,390.1633 757 | Oval 759,421.2946,398.2527 758 | Oval 760,417.0167,392.9803 759 | Oval 761,400.3167,348.4047 760 | Oval 762,414.3116,388.0447 761 | Oval 763,423.4325,392.5907 762 | Oval 764,397.1091,344.8786 763 | Oval 765,391.0964,343.3496 764 | Oval 766,415.7991,371.6595 765 | Oval 767,370.5507,334.2349 766 | Oval 768,386.2746,338.4949 767 | Oval 769,397.4843,337.0556 768 | Oval 770,390.5417,352.1607 769 | Oval 
771,404.0828,363.2986 770 | Oval 772,415.4295,384.8281 771 | Oval 773,394.5869,373.4195 772 | Oval 774,412.7788,382.0067 773 | Oval 775,417.8685,379.6684 774 | Oval 776,404.4823,373.1844 775 | Oval 777,402.9985,380.0607 776 | Oval 778,411.655,375.1508 777 | Oval 779,366.2212,336.6623 778 | Oval 780,399.4302,364.1364 779 | Oval 781,396.2202,361.1339 780 | Oval 782,385.3983,343.9837 781 | Oval 783,369.1602,346.0227 782 | Oval 784,369.1602,341.0393 783 | Oval 785,374.5059,343.3134 784 | Oval 786,380.4241,343.9837 785 | Oval 787,379.5444,348.7349 786 | Oval 788,379.7998,353.7826 787 | Oval 789,475.1696,470.4786 788 | Oval 790,486.3164,464.1284 789 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # AIML-Algorithms-Training 2 | Training scripts for the AIML Algorithms Course offered through O'Reilly / Pearson 3 | 4 | Added the PointsCoordinates.csv file 5 | -------------------------------------------------------------------------------- /Section-7-word2vec_model.bin: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robbarto2/AIML-Algorithms-Training/4d8daf7a378f58ff7b3e47504b7c3761b36c103a/Section-7-word2vec_model.bin -------------------------------------------------------------------------------- /Section_2_1_unsupervised.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": null, 6 | "id": "8f4fb93f-ebc2-4160-a446-920a0b6cf9e8", 7 | "metadata": {}, 8 | "outputs": [], 9 | "source": [ 10 | "#Plotting all the points on an animated graph\n", 11 | "%matplotlib inline\n", 12 | "import pandas as pd\n", 13 | "from IPython.display import display, HTML\n", 14 | "import matplotlib.pyplot as plt\n", 15 | "import matplotlib.animation as animation\n", 16 | "\n", 17 | "# Read the CSV 
file\n", 18 | "df = pd.read_csv('PointsCoordinates.csv')\n", 19 | "\n", 20 | "# Extract the second and third columns for x and y coordinates\n", 21 | "x = df.iloc[:, 1].tolist()\n", 22 | "y = df.iloc[:, 2].tolist()\n", 23 | "\n", 24 | "# Number of points\n", 25 | "n = len(x)\n", 26 | "\n", 27 | "fig, ax = plt.subplots(figsize=(10, 14))\n", 28 | "sc = ax.scatter([], [], s=5, color='blue')\n", 29 | "\n", 30 | "def init():\n", 31 | " ax.set_xlim(min(x), max(x))\n", 32 | " ax.set_ylim(min(y), max(y))\n", 33 | " return sc,\n", 34 | "\n", 35 | "x_data, y_data = [], []\n", 36 | "\n", 37 | "def update(frame):\n", 38 | " # Plot two points at a time\n", 39 | " x_data.extend([x[2*frame], x[2*frame + 1]])\n", 40 | " y_data.extend([y[2*frame], y[2*frame + 1]])\n", 41 | " sc.set_offsets(list(zip(x_data, y_data)))\n", 42 | " return sc,\n", 43 | "\n", 44 | "global ani\n", 45 | "ani = animation.FuncAnimation(fig, update, frames=n//2, init_func=init, blit=True, repeat=False, interval=2)\n", 46 | "\n", 47 | "plt.title('Points from CSV on a 2D Graph')\n", 48 | "plt.xlabel('X-axis')\n", 49 | "plt.ylabel('Y-axis')\n", 50 | "plt.close(fig) # This will prevent the static plot from displaying\n", 51 | "display(HTML(ani.to_jshtml()))\n" 52 | ] 53 | }, 54 | { 55 | "cell_type": "code", 56 | "execution_count": null, 57 | "id": "751d9f8c-2c04-4805-a30b-ea790434b866", 58 | "metadata": {}, 59 | "outputs": [], 60 | "source": [ 61 | "# plotting the progression of 3 K means clustering on the same data\n", 62 | "import pandas as pd\n", 63 | "import numpy as np\n", 64 | "import matplotlib.pyplot as plt\n", 65 | "import matplotlib.animation as animation\n", 66 | "from sklearn.cluster import KMeans\n", 67 | "from IPython.display import display, HTML\n", 68 | "\n", 69 | "# Read the CSV file\n", 70 | "df = pd.read_csv('PointsCoordinates.csv')\n", 71 | "\n", 72 | "# Extract the second and third columns for x and y coordinates\n", 73 | "data = df.iloc[:, 1:3].values\n", 74 | "\n", 75 | "fig, ax = 
plt.subplots(figsize=(10, 14))\n", 76 | "centroid_paths = [[], [], []]\n", 77 | "\n", 78 | "def animate(i):\n", 79 | " ax.clear()\n", 80 | " \n", 81 | " # For the initial frame, plot all points in black\n", 82 | " if i == 0:\n", 83 | " ax.scatter(data[:, 0], data[:, 1], s=5, c='black')\n", 84 | " ax.set_title('Initial State')\n", 85 | " return\n", 86 | "\n", 87 | " # Fit KMeans with an increasing number of iterations and 'random' initialization\n", 88 | " kmeans = KMeans(n_clusters=3, init='random', n_init=1, max_iter=i, random_state=42)\n", 89 | " kmeans.fit(data)\n", 90 | " labels = kmeans.labels_\n", 91 | " \n", 92 | " # Plot points based on their cluster labels\n", 93 | " ax.scatter(data[labels == 0][:, 0], data[labels == 0][:, 1], s=5, c='green', label='Cluster 1')\n", 94 | " ax.scatter(data[labels == 1][:, 0], data[labels == 1][:, 1], s=5, c='red', label='Cluster 2')\n", 95 | " ax.scatter(data[labels == 2][:, 0], data[labels == 2][:, 1], s=5, c='blue', label='Cluster 3')\n", 96 | " \n", 97 | " # Plot cluster centers and their movement\n", 98 | " centers = kmeans.cluster_centers_\n", 99 | " for j, center in enumerate(centers):\n", 100 | " centroid_paths[j].append(center)\n", 101 | " path = np.array(centroid_paths[j])\n", 102 | " ax.plot(path[:, 0], path[:, 1], 'w--', linewidth=1)\n", 103 | " ax.scatter(center[0], center[1], c='black', s=100, marker='X')\n", 104 | " \n", 105 | " ax.set_title(f'Iteration: {i}')\n", 106 | " ax.legend()\n", 107 | "\n", 108 | "# Animate for 11 frames (1 initial + 10 iterations)\n", 109 | "ani = animation.FuncAnimation(fig, animate, frames=11, repeat=False, interval=500)\n", 110 | "\n", 111 | "\n", 112 | "plt.close(fig) # This will prevent the static plot from displaying\n", 113 | "display(HTML(ani.to_jshtml()))\n" 114 | ] 115 | }, 116 | { 117 | "cell_type": "code", 118 | "execution_count": null, 119 | "id": "5abcb880-8785-4b48-be2a-cfc1d5248dfd", 120 | "metadata": {}, 121 | "outputs": [], 122 | "source": [] 123 | }, 124 | { 125 | 
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

# Generate training data with 1000 samples
np.random.seed(0)
torch.manual_seed(0)
x_train = torch.linspace(-4, 4, 1000).unsqueeze(1)
y_train = torch.sin(x_train) - 0.1 * x_train**2


# Define the neural network architecture
class NeuralNetwork(nn.Module):
    """One-hidden-layer regressor: Linear(1, 4) -> sigmoid -> Linear(4, 1)."""

    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(1, 4)  # Input to hidden layer
        self.layer2 = nn.Linear(4, 1)  # Hidden to output layer

    def forward(self, x):
        """Return the network output for `x`, whose last dimension must be 1."""
        hidden = torch.sigmoid(self.layer1(x))
        return self.layer2(hidden)


def _train_epochs(model, criterion, optimizer, inputs, targets, count):
    """Run `count` full-batch gradient steps and return the last step's loss.

    The returned float is the loss measured just before the final parameter
    update; None is returned when `count` is not positive.
    """
    last_loss = None
    for _ in range(count):
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()
        last_loss = loss.item()
    return last_loss


# Implement the training loop with real-time plot
def train_model_with_real_time_plot(model, x_train, y_train, num_epochs=10000,
                                    learning_rate=0.15, epochs_per_frame=10):
    """Train `model` with full-batch SGD on an MSE loss while animating the fit.

    Every animation frame after the first runs `epochs_per_frame` optimizer
    steps and then redraws the model's prediction, so the epoch counter shown
    on the plot equals the number of steps really taken and the run ends
    after exactly `num_epochs` steps.

    Args:
        model: Module mapping `x_train` to predictions shaped like `y_train`.
        x_train: Input tensor; plotted flattened on the x axis.
        y_train: Target tensor; plotted flattened as the true function.
        num_epochs: Total number of gradient steps to run.
        learning_rate: Step size handed to `optim.SGD`.
        epochs_per_frame: Gradient steps performed per animation frame.

    Raises:
        ValueError: If `epochs_per_frame` is smaller than 1.

    The call blocks inside `plt.show()` until the window is closed.
    """
    # Imported lazily so the model definition above can be imported on
    # machines without a display or without matplotlib installed.
    import matplotlib.pyplot as plt
    import matplotlib.animation as animation

    if epochs_per_frame < 1:
        raise ValueError("epochs_per_frame must be at least 1")

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    # Flat NumPy views of the data, used for every draw call.
    x_values = x_train.detach().numpy().reshape(-1)
    y_values = y_train.detach().numpy().reshape(-1)

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(x_values, y_values, label='Training data', color='green', s=10)
    ax.plot(x_values, y_values, label='True function', color='blue')
    line, = ax.plot([], [], label='Trained Model', color='red')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title('Supervised Learning Fit: sin(x) - 0.1x^2')
    ax.legend()

    epoch_text = ax.text(-4, 0, 'Epoch 0', fontsize=12, ha='left')

    # One untrained initial frame plus ceil(num_epochs / epochs_per_frame)
    # training frames (integer ceiling division, clamped at zero).
    num_frames = max(0, -(-num_epochs // epochs_per_frame)) + 1
    epochs_done = 0

    def update(frame):
        nonlocal epochs_done
        if frame == 0:
            # Initial frame: show only the data, no model curve yet.
            line.set_data([], [])
            return line, epoch_text

        # Never overshoot num_epochs on the final frame.
        steps = min(epochs_per_frame, num_epochs - epochs_done)
        if steps > 0:
            _train_epochs(model, criterion, optimizer, x_train, y_train, steps)
            epochs_done += steps

        # Draw the prediction of the parameters as they are after training.
        with torch.no_grad():
            y_pred = model(x_train)
        line.set_data(x_values, y_pred.numpy()[:, 0])
        epoch_text.set_text(f'Epoch {epochs_done}')
        return line, epoch_text

    # Keep a reference: the animation stops if the object is garbage collected.
    ani = animation.FuncAnimation(fig, update, frames=num_frames, blit=True,
                                  interval=100, repeat=False)
    plt.show()


if __name__ == "__main__":
    # Create the model and train it with real-time plot
    model = NeuralNetwork()
    train_model_with_real_time_plot(model, x_train, y_train)
def train_model_with_real_time_plot(model, x_train, y_train, num_epochs=10000,
                                    learning_rate=0.02, epochs_per_frame=10):
    """Train ``model`` with full-batch SGD (small step size) and animate the fit.

    Every animation frame runs ``epochs_per_frame`` optimisation steps, then
    redraws the model's prediction and the epoch counter, so the number shown
    on screen is the number of steps that have actually been taken and the
    animation ends after exactly ``num_epochs`` steps.

    Args:
        model: torch module called as ``model(x_train)``.
        x_train: input tensor; also used as the x-axis of the plot.
        y_train: target tensor compared against ``model(x_train)``.
        num_epochs: total number of optimisation steps to run.
        learning_rate: step size passed to ``optim.SGD``.
        epochs_per_frame: optimisation steps performed per animation frame.

    Raises:
        ValueError: if ``epochs_per_frame`` is smaller than 1.
    """
    if epochs_per_frame < 1:
        raise ValueError("epochs_per_frame must be >= 1")

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=learning_rate)

    # Flatten once so matplotlib receives plain 1-D arrays.
    x_plot = x_train.detach().numpy().ravel()
    y_plot = y_train.detach().numpy().ravel()

    fig, ax = plt.subplots(figsize=(8, 6))
    ax.scatter(x_plot, y_plot, label='Training data', color='green', s=10)
    ax.plot(x_plot, y_plot, label='True function', color='blue')
    line, = ax.plot([], [], label='Trained Model', color='red')
    ax.set_xlabel('x')
    ax.set_ylabel('y')
    ax.set_title('Supervised Learning Fit: sin(x) - 0.1x^2')
    ax.legend()

    epoch_text = ax.text(-5, 0.75, 'Epoch 0', fontsize=12, ha='left')

    epochs_done = 0

    def update(frame):
        """Advance training by one frame's worth of epochs and redraw."""
        nonlocal epochs_done
        if frame == 0:
            # Frame 0 only shows the untouched starting state.
            line.set_data([], [])
            return line, epoch_text

        # Never run past num_epochs, even when it is not a multiple of
        # epochs_per_frame.
        for _ in range(min(epochs_per_frame, num_epochs - epochs_done)):
            loss = criterion(model(x_train), y_train)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epochs_done += 1

        # Plot the model as it stands after this frame's updates.
        with torch.no_grad():
            prediction = model(x_train).numpy().ravel()
        line.set_data(x_plot, prediction)
        epoch_text.set_text(f'Epoch {epochs_done}')
        return line, epoch_text

    # One leading "initial state" frame plus enough frames to cover every epoch.
    frame_count = -(-num_epochs // epochs_per_frame) + 1
    # Keep a reference: the animation stops if the object is garbage collected.
    ani = animation.FuncAnimation(fig, update, frames=frame_count, blit=True,
                                  interval=100, repeat=False)
    plt.show()
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.optim as optim 4 | import numpy as np 5 | import matplotlib.pyplot as plt 6 | import matplotlib.animation as animation 7 | 8 | # Generate training data with 10000 samples 9 | np.random.seed(0) 10 | torch.manual_seed(0) 11 | x_train = torch.linspace(-4, 4, 10000).unsqueeze(1) 12 | y_train = torch.sin(x_train) - 0.1 * x_train**2 13 | 14 | # Define the neural network architecture 15 | class NeuralNetwork(nn.Module): 16 | def __init__(self): 17 | super(NeuralNetwork, self).__init__() 18 | self.layer1 = nn.Linear(1, 4) # Input to hidden layer 19 | self.layer2 = nn.Linear(4, 1) # Hidden to output layer 20 | 21 | def forward(self, x): 22 | x = torch.sigmoid(self.layer1(x)) 23 | x = self.layer2(x) 24 | return x 25 | 26 | # Implement the training loop with real-time plot 27 | def train_model_with_real_time_plot(model, x_train, y_train, num_epochs=10000, learning_rate=0.4): 28 | criterion = nn.MSELoss() 29 | optimizer = optim.SGD(model.parameters(), lr=learning_rate) 30 | 31 | fig, ax = plt.subplots(figsize=(8, 6)) 32 | ax.scatter(x_train, y_train, label='Training data', color='green', s=10) 33 | ax.plot(x_train, y_train, label='True function', color='blue') 34 | line, = ax.plot([], [], label='Trained Model', color='red') 35 | ax.set_xlabel('x') 36 | ax.set_ylabel('y') 37 | ax.set_title(f'Supervised Learning Fit: sin(x) - 0.1x^2') 38 | ax.legend() 39 | 40 | epoch_text = ax.text(-5, 0.75, f'Epoch 0', fontsize=12, ha='left') 41 | 42 | def update(frame, model, x_train, line, epoch_text): 43 | if frame == 0: 44 | line.set_data([], []) 45 | return line, epoch_text 46 | 47 | y_pred = model(x_train) 48 | loss = criterion(y_pred, y_train) 49 | optimizer.zero_grad() 50 | loss.backward() 51 | optimizer.step() 52 | 53 | if frame % 1 == 0: # Update the plot every 1 epochs 54 | line.set_data(x_train, y_pred.detach().numpy()) 55 | epoch = frame * 10 
class TicTacToe:
    """Tic-tac-toe board used both for Q-learning self-play and for pygame play.

    The board is a list of nine cells holding ' ' (empty), 'X' or '0'
    (the character zero is used for the second player).  Moves are numbered
    1-9, row by row.  When constructed with ``traning=True`` no pygame window
    is created and only the board logic is used.
    """

    # Index triples that form a winning line (rows, columns, diagonals).
    _WIN_LINES = (
        (0, 1, 2), (3, 4, 5), (6, 7, 8),
        (0, 3, 6), (1, 4, 7), (2, 5, 8),
        (0, 4, 8), (2, 4, 6),
    )
    # Strip below the 225x225 grid where status messages are drawn.
    _STATUS_RECT = (0, 226, 225, 24)
    _BACKGROUND = (250, 250, 250)
    _INK = (10, 10, 10)

    def __init__(self, traning=False):
        self.board = [' '] * 9

        self.done = False
        self.humman = None      # True when the human plays X, False when O
        self.computer = None    # True when the computer plays X, False when O
        self.humanTurn = None
        self.training = traning
        self.player1 = None
        self.player2 = None
        self.aiplayer = None
        self.ai = None          # opponent object assigned by startGame()
        self.isAI = False
        # Only open a window when a human is going to play.
        if not self.training:
            pygame.init()
            self.ttt = pygame.display.set_mode((225, 250))
            pygame.display.set_caption('Tic-Tac-Toe')

    def reset(self):
        """Clear the board; outside training also redraw the grid and pick who starts."""
        self.board = [' '] * 9
        if self.training:
            return

        self.humanTurn = random.choice([True, False])

        self.surface = pygame.Surface(self.ttt.get_size())
        self.surface = self.surface.convert()
        self.surface.fill(self._BACKGROUND)
        # vertical grid lines
        pygame.draw.line(self.surface, (0, 0, 0), (75, 0), (75, 225), 2)
        pygame.draw.line(self.surface, (0, 0, 0), (150, 0), (150, 225), 2)
        # horizontal grid lines
        pygame.draw.line(self.surface, (0, 0, 0), (0, 75), (225, 75), 2)
        pygame.draw.line(self.surface, (0, 0, 0), (0, 150), (225, 150), 2)

    def evaluate(self, ch):
        """Return ``(reward, done)`` for mark ``ch``: 1.0 win, 0.5 draw, 0.0 otherwise."""
        board = self.board
        if any(board[a] == board[b] == board[c] == ch for a, b, c in self._WIN_LINES):
            return 1.0, True
        if ' ' not in board:
            return 0.5, True
        return 0.0, False

    def possible_moves(self):
        """Return the 1-based positions of all empty cells."""
        return [index + 1 for index, cell in enumerate(self.board) if cell == ' ']

    def step(self, isX, move):
        """Place a mark at ``move`` (1-9) and return ``(reward, done)``.

        Trying to overwrite an occupied cell returns ``(-5, True)`` and leaves
        the board unchanged.
        """
        ch = 'X' if isX else '0'
        if self.board[move - 1] != ' ':
            return -5, True

        self.board[move - 1] = ch
        return self.evaluate(ch)

    def _show_status(self, message):
        """Replace the text in the status strip below the grid with ``message``."""
        font = pygame.font.Font(None, 24)
        text = font.render(message, 1, self._INK)
        self.surface.fill(self._BACKGROUND, self._STATUS_RECT)
        self.surface.blit(text, (10, 230))

    def drawMove(self, pos, isX):
        """Play ``pos`` for X (``isX`` true) or O, draw it, and return ``(reward, done)``."""
        row = int((pos - 1) / 3)
        col = (pos - 1) % 3
        centerX = (col * 75) + 32
        centerY = (row * 75) + 32

        reward, done = self.step(isX, pos)
        if reward == -5:  # the cell was already taken
            self._show_status('Invalid move!')
            return reward, done

        # Valid move: wipe any earlier status text, then draw the mark.
        # step() has already stored the mark on the board.
        font = pygame.font.Font(None, 24)
        text = font.render('X' if isX else 'O', 1, self._INK)
        self.surface.fill(self._BACKGROUND, self._STATUS_RECT)
        self.surface.blit(text, (centerX, centerY))

        if reward == 1:
            # humman/computer say who plays X; the O player is the other one.
            human_moved = bool(self.humman) if isX else not self.humman
            computer_moved = bool(self.computer) if isX else not self.computer
            if human_moved:
                self._show_status('Humman won!')
            elif computer_moved:
                self._show_status('computer won!')
        elif reward == 0.5:
            self._show_status('Draw Game!')

        return reward, done

    def mouseClick(self):
        """Translate the current mouse position into a board position (1-9)."""
        mouseX, mouseY = pygame.mouse.get_pos()
        # Cells are 75 px wide; anything past the second boundary is the last one.
        row = min(int(mouseY // 75), 2)
        col = min(int(mouseX // 75), 2)
        return row * 3 + col + 1

    def updateState(self, isX):
        """Play the cell under the mouse cursor and return ``(reward, done)``."""
        pos = self.mouseClick()
        return self.drawMove(pos, isX)

    def showboard(self):
        """Copy the drawing surface to the window and flip the display."""
        self.ttt.blit(self.surface, (0, 0))
        pygame.display.flip()

    def startTraining(self, player1, player2):
        """Register two ``Qlearning`` agents for self-play; other types are ignored."""
        if isinstance(player1, Qlearning) and isinstance(player2, Qlearning):
            self.training = True
            self.player1 = player1
            self.player2 = player2

    def train(self, iterations):
        """Run ``iterations`` self-play games, updating both agents' Q-tables.

        player1 always plays X and player2 always plays O; who moves first is
        chosen at random for every game.
        """
        if not self.training:
            return
        for i in range(iterations):
            print("trainining", i)
            self.player1.game_begin()
            self.player2.game_begin()
            self.reset()
            done = False
            isX = random.choice([True, False])
            while not done:
                if isX:
                    mover, waiting = self.player1, self.player2
                else:
                    mover, waiting = self.player2, self.player1

                move = mover.epslion_greedy(self.board, self.possible_moves())
                reward, done = self.step(isX, move)

                if reward == 1:  # the mover won, the other player lost
                    x_reward, o_reward = (reward, -1 * reward) if isX else (-1 * reward, reward)
                    self.player1.updateQ(x_reward, self.board, self.possible_moves())
                    self.player2.updateQ(o_reward, self.board, self.possible_moves())
                elif reward == 0.5:  # draw
                    self.player1.updateQ(reward, self.board, self.possible_moves())
                    self.player2.updateQ(reward, self.board, self.possible_moves())
                elif reward == -5:  # illegal move
                    mover.updateQ(reward, self.board, self.possible_moves())
                elif reward == 0:
                    # The waiting player now sees the result of its previous
                    # move.  On the first move of a game it has not acted yet,
                    # so there is nothing to update.
                    if waiting.state_action_last is not None:
                        waiting.updateQ(reward, self.board, self.possible_moves())

                isX = not isX

    def saveStates(self):
        """Write both agents' Q-tables to ``player1states`` / ``player2states``."""
        self.player1.saveQtable("player1states")
        self.player2.saveQtable("player2states")

    def startGame(self, playerX, playerO):
        """Configure a human-vs-computer game.

        Exactly one side is expected to be a ``Humanplayer``; the other side is
        either a ``Qlearning`` agent (its saved Q-table is loaded and epsilon is
        set to 0 so it always plays greedily) or a ``Randomplayer``.
        """
        if isinstance(playerX, Humanplayer):
            self.humman, self.computer = True, False
            opponent, table_file = playerO, "player2states"
        elif isinstance(playerO, Humanplayer):
            self.humman, self.computer = False, True
            opponent, table_file = playerX, "player1states"
        else:
            return

        if isinstance(opponent, Qlearning):
            self.ai = opponent
            self.ai.loadQtable(table_file)  # load saved Q table
            self.ai.epsilon = 0             # always choose the greedy step
            self.isAI = True
        elif isinstance(opponent, Randomplayer):
            self.ai = opponent
            self.isAI = False

    def render(self):
        """Run the interactive game loop until the window is closed."""
        done = False
        pygame.event.clear()
        while True:
            if self.humanTurn:
                print("Human player turn")
                self.showboard()
                event = pygame.event.wait()
                while event.type not in (pygame.MOUSEBUTTONDOWN, pygame.QUIT):
                    self.showboard()
                    event = pygame.event.wait()
                if event.type == pygame.QUIT:
                    # Leave immediately; do not treat the quit as a click.
                    print("pressed quit")
                    break

                reward, done = self.updateState(self.humman)
                self.showboard()
            else:
                if self.isAI:
                    moves = self.ai.epslion_greedy(self.board, self.possible_moves())
                    reward, done = self.drawMove(moves, self.computer)
                    print("computer's AI player turn")
                else:
                    moves = self.ai.move(self.possible_moves())
                    reward, done = self.drawMove(moves, self.computer)
                    print("computer's random player turn")
                self.showboard()

            if done:
                # Leave the final position visible for a moment, then restart.
                time.sleep(1)
                self.reset()

            self.humanTurn = not self.humanTurn
class Qlearning:
    """Tabular Q-learning agent with epsilon-greedy action selection.

    ``Q`` maps ``(state, action)`` to a value, where ``state`` is the board as
    a tuple and ``action`` is a move.  Unseen pairs start at 1.0.
    """

    def __init__(self, epsilon=0.2, alpha=0.3, gamma=0.9):
        self.epsilon = epsilon   # probability of choosing a random move
        self.alpha = alpha       # learning rate
        self.gamma = gamma       # discount applied to the next state's best value
        self.Q = {}              # Q table
        self.last_board = None
        self.q_last = 0.0
        self.state_action_last = None

    def game_begin(self):
        """Forget the previous game's last state/action before a new game."""
        self.last_board = None
        self.q_last = 0.0
        self.state_action_last = None

    def epslion_greedy(self, state, possible_moves):
        """Choose a move from ``possible_moves`` for board ``state``.

        With probability ``epsilon`` the move is random; otherwise the move
        with the highest Q value is taken, ties being broken at random.  The
        chosen state/action pair and its Q value are remembered for the next
        ``updateQ`` call.
        """
        self.last_board = tuple(state)
        if random.random() < self.epsilon:
            move = random.choice(possible_moves)
        else:
            q_values = [self.getQ(self.last_board, action) for action in possible_moves]
            best_q = max(q_values)
            best_moves = [action for action, q in zip(possible_moves, q_values) if q == best_q]
            move = best_moves[0] if len(best_moves) == 1 else random.choice(best_moves)

        self.state_action_last = (self.last_board, move)
        self.q_last = self.getQ(self.last_board, move)
        return move

    def getQ(self, state, action):
        """Return the Q value for ``(state, action)``, creating it as 1.0 if unseen."""
        return self.Q.setdefault((state, action), 1.0)

    def updateQ(self, reward, state, possible_moves):
        """Apply the Q-learning update to the last chosen state/action pair.

        ``state`` and ``possible_moves`` describe the position reached after
        that action.  When no action has been chosen since ``game_begin`` there
        is nothing to update, so the call does nothing instead of storing a
        value under a ``None`` key.
        """
        if self.state_action_last is None:
            return

        next_state = tuple(state)
        # No remaining moves means there is no future value to bootstrap from.
        max_q_next = max((self.getQ(next_state, move) for move in possible_moves), default=0.0)
        self.Q[self.state_action_last] = self.q_last + self.alpha * (
            (reward + self.gamma * max_q_next) - self.q_last
        )

    def saveQtable(self, file_name):
        """Pickle the Q table to ``file_name``."""
        with open(file_name, 'wb') as handle:
            pickle.dump(self.Q, handle, protocol=pickle.HIGHEST_PROTOCOL)

    def loadQtable(self, file_name):
        """Replace the Q table with the one pickled in ``file_name``.

        NOTE(security): unpickling can execute arbitrary code; only load files
        that were written by ``saveQtable`` and come from a trusted source.
        """
        with open(file_name, 'rb') as handle:
            self.Q = pickle.load(handle)
5 | Requirements:
6 | python 3.5.2 and pygame 7 | 8 | Run "Play Q-Learning.py" to play against the trained agent, or "Play Dumb Agent.py" to play against a random player.
9 | ``` 10 | py -3 "Play Q-Learning.py" 11 | ``` 12 | Run Train.py to train the agent.
13 | ``` 14 | py -3 Train.py 15 | ``` 16 | 17 | Training:
18 | It took 200,000 iterations to master the game. 19 | ``` 20 | game = TicTacToe(True) #game instance, True means training 21 | player1= Qlearning() #player1 learning agent 22 | player2 =Qlearning() #player2 learning agent 23 | game.startTraining(player1,player2) #start training 24 | game.train(200000) #train for 200,000 iterations 25 | game.saveStates() #save Qtable 26 | ``` 27 | 28 | Playing
29 | 30 | Human player vs AI agent 31 | ``` 32 | game = TicTacToe() #game instance 33 | player1=Humanplayer() #human player 34 | player2=Qlearning() #agent 35 | game.startGame(player1,player2)#player1 is X, player2 is 0 36 | game.reset() #reset 37 | game.render() # render display 38 | ``` 39 | Random player instead of AI agent 40 | 41 | ``` 42 | #change player1 or player2 to Randomplayer() 43 | player2 =Randomplayer() 44 | ``` 45 | -------------------------------------------------------------------------------- /Section_6-TicTactoe/Readme.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robbarto2/AIML-Algorithms-Training/4d8daf7a378f58ff7b3e47504b7c3761b36c103a/Section_6-TicTactoe/Readme.txt -------------------------------------------------------------------------------- /Section_6-TicTactoe/Train.py: -------------------------------------------------------------------------------- 1 | from Game import TicTacToe 2 | from QLearning import Qlearning 3 | 4 | game = TicTacToe(True) #game instance, True means training 5 | player1= Qlearning() #player1 learning agent 6 | player2 =Qlearning() #player2 learning agent 7 | game.startTraining(player1,player2) #start training 8 | game.train(200000) #train for 200,000 iterations 9 | game.saveStates() #save Qtable -------------------------------------------------------------------------------- /Section_6-TicTactoe/player1states: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/robbarto2/AIML-Algorithms-Training/4d8daf7a378f58ff7b3e47504b7c3761b36c103a/Section_6-TicTactoe/player1states -------------------------------------------------------------------------------- /Section_6-TicTactoe/player2states: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/robbarto2/AIML-Algorithms-Training/4d8daf7a378f58ff7b3e47504b7c3761b36c103a/Section_6-TicTactoe/player2states -------------------------------------------------------------------------------- /Section_7-Word-Embedding-3D.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import matplotlib.pyplot as plt 3 | from mpl_toolkits.mplot3d import Axes3D 4 | from sklearn.decomposition import PCA 5 | from gensim.models import Word2Vec 6 | 7 | # Load the trained Word2Vec model 8 | model = Word2Vec.load("Section-7-word2vec_model.bin") 9 | 10 | # Get word embeddings for all words in the vocabulary 11 | words = list(model.wv.key_to_index.keys()) 12 | embeddings = [model.wv[word] for word in words] 13 | 14 | # Convert embeddings list to a NumPy array 15 | embeddings = np.array(embeddings) 16 | 17 | # Perform PCA to reduce the dimensionality to 3 18 | pca = PCA(n_components=3) 19 | embeddings_3d = pca.fit_transform(embeddings) 20 | 21 | # Plot the word embeddings in 3D space 22 | fig = plt.figure(figsize=(10, 8)) 23 | ax = fig.add_subplot(111, projection='3d') 24 | 25 | # Scatter plot for each word's 3D representation 26 | ax.scatter(embeddings_3d[:, 0], embeddings_3d[:, 1], embeddings_3d[:, 2], marker='o', color='b') 27 | 28 | # Annotate each point with the corresponding word 29 | for i, word in enumerate(words): 30 | ax.text(embeddings_3d[i, 0], embeddings_3d[i, 1], embeddings_3d[i, 2], word, fontsize=8) 31 | 32 | ax.set_xlabel('PCA Component 1') 33 | ax.set_ylabel('PCA Component 2') 34 | ax.set_zlabel('PCA Component 3') 35 | ax.set_title('3D Visualization of Word Embeddings using Word2Vec') 36 | 37 | plt.show() 38 | -------------------------------------------------------------------------------- /Section_7-Word-Embeddings.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "raw", 5 | "id": 
"49fb0a57-515c-4df6-9c86-de6f2b03a854", 6 | "metadata": {}, 7 | "source": [ 8 | "Load the libraries" 9 | ] 10 | }, 11 | { 12 | "cell_type": "code", 13 | "execution_count": 1, 14 | "id": "78052be8-e64d-4f97-a819-109a78a455a6", 15 | "metadata": {}, 16 | "outputs": [ 17 | { 18 | "name": "stderr", 19 | "output_type": "stream", 20 | "text": [ 21 | "/Users/robbarto/Library/Python/3.9/lib/python/site-packages/urllib3/__init__.py:35: NotOpenSSLWarning: urllib3 v2 only supports OpenSSL 1.1.1+, currently the 'ssl' module is compiled with 'LibreSSL 2.8.3'. See: https://github.com/urllib3/urllib3/issues/3020\n", 22 | " warnings.warn(\n", 23 | "/Users/robbarto/Library/Python/3.9/lib/python/site-packages/paramiko/pkey.py:100: CryptographyDeprecationWarning: TripleDES has been moved to cryptography.hazmat.decrepit.ciphers.algorithms.TripleDES and will be removed from this module in 48.0.0.\n", 24 | " \"cipher\": algorithms.TripleDES,\n", 25 | "/Users/robbarto/Library/Python/3.9/lib/python/site-packages/paramiko/transport.py:259: CryptographyDeprecationWarning: TripleDES has been moved to cryptography.hazmat.decrepit.ciphers.algorithms.TripleDES and will be removed from this module in 48.0.0.\n", 26 | " \"class\": algorithms.TripleDES,\n" 27 | ] 28 | }, 29 | { 30 | "data": { 31 | "image/png": "", 32 | "text/plain": [ 33 | "
" 34 | ] 35 | }, 36 | "metadata": {}, 37 | "output_type": "display_data" 38 | } 39 | ], 40 | "source": [ 41 | "import numpy as np\n", 42 | "import matplotlib.pyplot as plt\n", 43 | "from mpl_toolkits.mplot3d import Axes3D\n", 44 | "from sklearn.decomposition import PCA\n", 45 | "from gensim.models import Word2Vec" 46 | ] 47 | }, 48 | { 49 | "cell_type": "raw", 50 | "id": "db2bdbba-036f-4365-b856-3ffb1289bb2b", 51 | "metadata": {}, 52 | "source": [ 53 | "Load the trained Word2Vec model, then get the word embeddings for all words.\n" 54 | ] 55 | }, 56 | { 57 | "cell_type": "code", 58 | "execution_count": null, 59 | "id": "00d68634-5789-43b5-b2fd-ff64fb2a5469", 60 | "metadata": {}, 61 | "outputs": [], 62 | "source": [ 63 | "model = Word2Vec.load(\"Section-7-word2vec_model.bin\")\n", 64 | "\n", 65 | "# Get word embeddings for all words in the vocabulary\n", 66 | "words = list(model.wv.key_to_index.keys())\n", 67 | "embeddings = [model.wv[word] for word in words]\n", 68 | "\n", 69 | "# Convert embeddings list to a NumPy array\n", 70 | "embeddings = np.array(embeddings)" 71 | ] 72 | }, 73 | { 74 | "cell_type": "raw", 75 | "id": "b1e0915c-bdfc-4362-9149-b462018f5924", 76 | "metadata": {}, 77 | "source": [ 78 | "Perform PCA to reduce the dimensionality to 3" 79 | ] 80 | }, 81 | { 82 | "cell_type": "code", 83 | "execution_count": null, 84 | "id": "d50b65a6-a353-4707-b19a-7b9aa6e032a0", 85 | "metadata": {}, 86 | "outputs": [], 87 | "source": [ 88 | "pca = PCA(n_components=3)\n", 89 | "embeddings_3d = pca.fit_transform(embeddings)" 90 | ] 91 | }, 92 | { 93 | "cell_type": "raw", 94 | "id": "93ddd2f9-2327-462c-8fb7-6322a5bee4ac", 95 | "metadata": {}, 96 | "source": [ 97 | "Plot the word embeddings in 3D space" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "id": "c2b45bb2-3083-4243-aaf5-1c3a3832c0ca", 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "fig = plt.figure(figsize=(10, 8))\n", 108 | "ax = fig.add_subplot(111, 
projection='3d')\n", 109 | "\n", 110 | "# Scatter plot for each word's 3D representation\n", 111 | "ax.scatter(embeddings_3d[:, 0], embeddings_3d[:, 1], embeddings_3d[:, 2], marker='o', color='b')\n", 112 | "\n", 113 | "# Annotate each point with the corresponding word\n", 114 | "for i, word in enumerate(words):\n", 115 | " ax.text(embeddings_3d[i, 0], embeddings_3d[i, 1], embeddings_3d[i, 2], word, fontsize=8)\n", 116 | "\n", 117 | "ax.set_xlabel('PCA Component 1')\n", 118 | "ax.set_ylabel('PCA Component 2')\n", 119 | "ax.set_zlabel('PCA Component 3')\n", 120 | "ax.set_title('3D Visualization of Word Embeddings using Word2Vec')\n", 121 | "\n", 122 | "plt.show()" 123 | ] 124 | }, 125 | { 126 | "cell_type": "code", 127 | "execution_count": null, 128 | "id": "15423ce3-393d-4d5f-8f17-8b039d58cc44", 129 | "metadata": {}, 130 | "outputs": [], 131 | "source": [] 132 | } 133 | ], 134 | "metadata": { 135 | "kernelspec": { 136 | "display_name": "Python 3 (ipykernel)", 137 | "language": "python", 138 | "name": "python3" 139 | }, 140 | "language_info": { 141 | "codemirror_mode": { 142 | "name": "ipython", 143 | "version": 3 144 | }, 145 | "file_extension": ".py", 146 | "mimetype": "text/x-python", 147 | "name": "python", 148 | "nbconvert_exporter": "python", 149 | "pygments_lexer": "ipython3", 150 | "version": "3.9.6" 151 | } 152 | }, 153 | "nbformat": 4, 154 | "nbformat_minor": 5 155 | } 156 | -------------------------------------------------------------------------------- /Section_7-Word2Vec.py: -------------------------------------------------------------------------------- 1 | from gensim.models import Word2Vec 2 | 3 | # Sample corpus (list of sentences) 4 | corpus = [ 5 | ["The", "early", "bird", "gets", "the", "worm"], 6 | ["Success", "requires", "hard", "work"] 7 | ] 8 | 9 | # Train Word2Vec model on the corpus 10 | model = Word2Vec(sentences=corpus, vector_size=100, window=5, min_count=1, sg=0) 11 | model.save("Section-7-word2vec_model.bin") 12 | 
--------------------------------------------------------------------------------