├── README.md ├── boston_housing ├── README.md ├── bj_housing.csv ├── boston_housing.ipynb ├── boston_housing.zip ├── housing.csv └── visuals.py ├── creating_customer_segments ├── README.md ├── customer_segments.ipynb ├── customer_segments.zip ├── customers.csv └── visuals.py ├── digit_recognition ├── README.md ├── digit_recognition.ipynb ├── model.png └── project_description.md ├── dog_vs_cat ├── README.md ├── dog_vs_cat (4).ipynb ├── file (1).csv ├── file (10).csv ├── file (11).csv ├── proposal.pdf └── report2.pdf ├── finding_donors ├── README.md ├── census.csv ├── finding_donors.ipynb ├── finding_donors.zip ├── project_description.md └── visuals.py ├── smartcab ├── README.md ├── agent.py ├── images │ ├── car-black.png │ ├── car-blue.png │ ├── car-cyan.png │ ├── car-green.png │ ├── car-magenta.png │ ├── car-orange.png │ ├── car-red.png │ ├── car-white.png │ ├── car-yellow.png │ ├── east-west.png │ ├── logo.png │ └── north-south.png ├── logs │ ├── sim_default-learning.csv │ ├── sim_default-learning.txt │ ├── sim_improved-learning.csv │ ├── sim_improved-learning.txt │ └── sim_no-learning.csv ├── project_description.md ├── smartcab.ipynb ├── smartcab.zip ├── smartcab │ ├── agent.py │ ├── environment.py │ ├── environment.pyc │ ├── planner.py │ ├── planner.pyc │ ├── simulator.py │ └── simulator.pyc ├── smartcab2.zip ├── smartcab3.zip └── visuals.py ├── student_intervention ├── README.md ├── student-data.csv ├── student_intervention.ipynb ├── student_intervention.zip └── student_intervention2.zip └── titanic_survival_exploration ├── README.md ├── titanic_data.csv ├── titanic_survival_exploration.ipynb ├── titanic_survival_exploration.zip └── titanic_visualizations.py /README.md: -------------------------------------------------------------------------------- 1 | # machine learning & deep learning project 2 | -------------------------------------------------------------------------------- /boston_housing/README.md: 
-------------------------------------------------------------------------------- 1 | # Project 1: Model Evaluation and Validation 2 | ## Predicting Boston Housing Prices 3 | 4 | ### Setup 5 | 6 | This project requires **Python 2.7** and the following Python libraries: 7 | 8 | - [NumPy](http://www.numpy.org/) 9 | - [matplotlib](http://matplotlib.org/) 10 | - [scikit-learn](http://scikit-learn.org/stable/) 11 | 12 | You will also need software that can run and edit [ipynb](http://jupyter.org/) files. 13 | 14 | Udacity recommends installing [Anaconda](https://www.continuum.io/downloads), a widely used Python distribution that already includes all the libraries required for this project. Project P0 also walks through [how to set up the environment](https://github.com/udacity/machine-learning/blob/master/projects_cn/titanic_survival_exploration/README.md). 15 | 16 | ### Code 17 | 18 | Template code is provided in `boston_housing.ipynb`. You will also need `visuals.py` and the `housing.csv` dataset to complete this project. Some code is already provided for you, but additional functionality must be implemented before the project is complete. 19 | 20 | ### Run 21 | 22 | In a terminal or command window, navigate to the top-level `boston_housing/` directory (the one containing this README file) and run the following command: 23 | 24 | ```jupyter notebook boston_housing.ipynb``` 25 | 26 | This starts the Jupyter Notebook server and opens the file in your browser. 27 | 28 | ### Data 29 | 30 | The modified Boston housing dataset consists of 490 data points, each with three features. It was curated from the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Housing). 31 | 32 | **Features** 33 | 34 | 1. `RM`: average number of rooms per dwelling 35 | 2. `LSTAT`: percentage of the population considered lower status 36 | 3. `PTRATIO`: pupil-teacher ratio by town 37 | 38 | **Target Variable** 39 | 40 | 4.
`MEDV`: median value of homes -------------------------------------------------------------------------------- /boston_housing/boston_housing.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/boston_housing/boston_housing.zip -------------------------------------------------------------------------------- /boston_housing/housing.csv: -------------------------------------------------------------------------------- 1 | RM,LSTAT,PTRATIO,MEDV 2 | 6.575,4.98,15.3,504000.0 3 | 6.421,9.14,17.8,453600.0 4 | 7.185,4.03,17.8,728700.0 5 | 6.998,2.94,18.7,701400.0 6 | 7.147,5.33,18.7,760200.0 7 | 6.43,5.21,18.7,602700.0 8 | 6.012,12.43,15.2,480900.0 9 | 6.172,19.15,15.2,569100.0 10 | 5.631,29.93,15.2,346500.0 11 | 6.004,17.1,15.2,396900.0 12 | 6.377,20.45,15.2,315000.0 13 | 6.009,13.27,15.2,396900.0 14 | 5.889,15.71,15.2,455700.0 15 | 5.949,8.26,21.0,428400.0 16 | 6.096,10.26,21.0,382200.0 17 | 5.834,8.47,21.0,417900.0 18 | 5.935,6.58,21.0,485100.0 19 | 5.99,14.67,21.0,367500.0 20 | 5.456,11.69,21.0,424200.0 21 | 5.727,11.28,21.0,382200.0 22 | 5.57,21.02,21.0,285600.0 23 | 5.965,13.83,21.0,411600.0 24 | 6.142,18.72,21.0,319200.0 25 | 5.813,19.88,21.0,304500.0 26 | 5.924,16.3,21.0,327600.0 27 | 5.599,16.51,21.0,291900.0 28 | 5.813,14.81,21.0,348600.0 29 | 6.047,17.28,21.0,310800.0 30 | 6.495,12.8,21.0,386400.0 31 | 6.674,11.98,21.0,441000.0 32 | 5.713,22.6,21.0,266700.0 33 | 6.072,13.04,21.0,304500.0 34 | 5.95,27.71,21.0,277200.0 35 | 5.701,18.35,21.0,275100.0 36 | 6.096,20.34,21.0,283500.0 37 | 5.933,9.68,19.2,396900.0 38 | 5.841,11.41,19.2,420000.0 39 | 5.85,8.77,19.2,441000.0 40 | 5.966,10.13,19.2,518700.0 41 | 6.595,4.32,18.3,646800.0 42 | 7.024,1.98,18.3,732900.0 43 | 6.77,4.84,17.9,558600.0 44 | 6.169,5.81,17.9,531300.0 45 | 6.211,7.44,17.9,518700.0 46 | 6.069,9.55,17.9,445200.0 47 | 5.682,10.21,17.9,405300.0 48 |
5.786,14.15,17.9,420000.0 49 | 6.03,18.8,17.9,348600.0 50 | 5.399,30.81,17.9,302400.0 51 | 5.602,16.2,17.9,407400.0 52 | 5.963,13.45,16.8,413700.0 53 | 6.115,9.43,16.8,430500.0 54 | 6.511,5.28,16.8,525000.0 55 | 5.998,8.43,16.8,491400.0 56 | 5.888,14.8,21.1,396900.0 57 | 7.249,4.81,17.9,743400.0 58 | 6.383,5.77,17.3,518700.0 59 | 6.816,3.95,15.1,663600.0 60 | 6.145,6.86,19.7,489300.0 61 | 5.927,9.22,19.7,411600.0 62 | 5.741,13.15,19.7,392700.0 63 | 5.966,14.44,19.7,336000.0 64 | 6.456,6.73,19.7,466200.0 65 | 6.762,9.5,19.7,525000.0 66 | 7.104,8.05,18.6,693000.0 67 | 6.29,4.67,16.1,493500.0 68 | 5.787,10.24,16.1,407400.0 69 | 5.878,8.1,18.9,462000.0 70 | 5.594,13.09,18.9,365400.0 71 | 5.885,8.79,18.9,438900.0 72 | 6.417,6.72,19.2,508200.0 73 | 5.961,9.88,19.2,455700.0 74 | 6.065,5.52,19.2,478800.0 75 | 6.245,7.54,19.2,491400.0 76 | 6.273,6.78,18.7,506100.0 77 | 6.286,8.94,18.7,449400.0 78 | 6.279,11.97,18.7,420000.0 79 | 6.14,10.27,18.7,436800.0 80 | 6.232,12.34,18.7,445200.0 81 | 5.874,9.1,18.7,426300.0 82 | 6.727,5.29,19.0,588000.0 83 | 6.619,7.22,19.0,501900.0 84 | 6.302,6.72,19.0,520800.0 85 | 6.167,7.51,19.0,480900.0 86 | 6.389,9.62,18.5,501900.0 87 | 6.63,6.53,18.5,558600.0 88 | 6.015,12.86,18.5,472500.0 89 | 6.121,8.44,18.5,466200.0 90 | 7.007,5.5,17.8,495600.0 91 | 7.079,5.7,17.8,602700.0 92 | 6.417,8.81,17.8,474600.0 93 | 6.405,8.2,17.8,462000.0 94 | 6.442,8.16,18.2,480900.0 95 | 6.211,6.21,18.2,525000.0 96 | 6.249,10.59,18.2,432600.0 97 | 6.625,6.65,18.0,596400.0 98 | 6.163,11.34,18.0,449400.0 99 | 8.069,4.21,18.0,812700.0 100 | 7.82,3.57,18.0,919800.0 101 | 7.416,6.19,18.0,697200.0 102 | 6.727,9.42,20.9,577500.0 103 | 6.781,7.67,20.9,556500.0 104 | 6.405,10.63,20.9,390600.0 105 | 6.137,13.44,20.9,405300.0 106 | 6.167,12.33,20.9,422100.0 107 | 5.851,16.47,20.9,409500.0 108 | 5.836,18.66,20.9,409500.0 109 | 6.127,14.09,20.9,428400.0 110 | 6.474,12.27,20.9,415800.0 111 | 6.229,15.55,20.9,407400.0 112 | 6.195,13.0,20.9,455700.0 113 | 6.715,10.16,17.8,478800.0 
114 | 5.913,16.21,17.8,394800.0 115 | 6.092,17.09,17.8,392700.0 116 | 6.254,10.45,17.8,388500.0 117 | 5.928,15.76,17.8,384300.0 118 | 6.176,12.04,17.8,445200.0 119 | 6.021,10.3,17.8,403200.0 120 | 5.872,15.37,17.8,428400.0 121 | 5.731,13.61,17.8,405300.0 122 | 5.87,14.37,19.1,462000.0 123 | 6.004,14.27,19.1,426300.0 124 | 5.961,17.93,19.1,430500.0 125 | 5.856,25.41,19.1,363300.0 126 | 5.879,17.58,19.1,394800.0 127 | 5.986,14.81,19.1,449400.0 128 | 5.613,27.26,19.1,329700.0 129 | 5.693,17.19,21.2,340200.0 130 | 6.431,15.39,21.2,378000.0 131 | 5.637,18.34,21.2,300300.0 132 | 6.458,12.6,21.2,403200.0 133 | 6.326,12.26,21.2,411600.0 134 | 6.372,11.12,21.2,483000.0 135 | 5.822,15.03,21.2,386400.0 136 | 5.757,17.31,21.2,327600.0 137 | 6.335,16.96,21.2,380100.0 138 | 5.942,16.9,21.2,365400.0 139 | 6.454,14.59,21.2,359100.0 140 | 5.857,21.32,21.2,279300.0 141 | 6.151,18.46,21.2,373800.0 142 | 6.174,24.16,21.2,294000.0 143 | 5.019,34.41,21.2,302400.0 144 | 5.403,26.82,14.7,281400.0 145 | 5.468,26.42,14.7,327600.0 146 | 4.903,29.29,14.7,247800.0 147 | 6.13,27.8,14.7,289800.0 148 | 5.628,16.65,14.7,327600.0 149 | 4.926,29.53,14.7,306600.0 150 | 5.186,28.32,14.7,373800.0 151 | 5.597,21.45,14.7,323400.0 152 | 6.122,14.1,14.7,451500.0 153 | 5.404,13.28,14.7,411600.0 154 | 5.012,12.12,14.7,321300.0 155 | 5.709,15.79,14.7,407400.0 156 | 6.129,15.12,14.7,357000.0 157 | 6.152,15.02,14.7,327600.0 158 | 5.272,16.14,14.7,275100.0 159 | 6.943,4.59,14.7,867300.0 160 | 6.066,6.43,14.7,510300.0 161 | 6.51,7.39,14.7,489300.0 162 | 6.25,5.5,14.7,567000.0 163 | 5.854,11.64,14.7,476700.0 164 | 6.101,9.81,14.7,525000.0 165 | 5.877,12.14,14.7,499800.0 166 | 6.319,11.1,14.7,499800.0 167 | 6.402,11.32,14.7,468300.0 168 | 5.875,14.43,14.7,365400.0 169 | 5.88,12.03,14.7,401100.0 170 | 5.572,14.69,16.6,485100.0 171 | 6.416,9.04,16.6,495600.0 172 | 5.859,9.64,16.6,474600.0 173 | 6.546,5.33,16.6,617400.0 174 | 6.02,10.11,16.6,487200.0 175 | 6.315,6.29,16.6,516600.0 176 | 6.86,6.92,16.6,627900.0 177 | 
6.98,5.04,17.8,781200.0 178 | 7.765,7.56,17.8,835800.0 179 | 6.144,9.45,17.8,760200.0 180 | 7.155,4.82,17.8,795900.0 181 | 6.563,5.68,17.8,682500.0 182 | 5.604,13.98,17.8,554400.0 183 | 6.153,13.15,17.8,621600.0 184 | 6.782,6.68,15.2,672000.0 185 | 6.556,4.56,15.2,625800.0 186 | 7.185,5.39,15.2,732900.0 187 | 6.951,5.1,15.2,777000.0 188 | 6.739,4.69,15.2,640500.0 189 | 7.178,2.87,15.2,764400.0 190 | 6.8,5.03,15.6,653100.0 191 | 6.604,4.38,15.6,611100.0 192 | 7.287,4.08,12.6,699300.0 193 | 7.107,8.61,12.6,636300.0 194 | 7.274,6.62,12.6,726600.0 195 | 6.975,4.56,17.0,732900.0 196 | 7.135,4.45,17.0,690900.0 197 | 6.162,7.43,14.7,506100.0 198 | 7.61,3.11,14.7,888300.0 199 | 7.853,3.81,14.7,1018500.0 200 | 5.891,10.87,18.6,474600.0 201 | 6.326,10.97,18.6,512400.0 202 | 5.783,18.06,18.6,472500.0 203 | 6.064,14.66,18.6,512400.0 204 | 5.344,23.09,18.6,420000.0 205 | 5.96,17.27,18.6,455700.0 206 | 5.404,23.98,18.6,405300.0 207 | 5.807,16.03,18.6,470400.0 208 | 6.375,9.38,18.6,590100.0 209 | 5.412,29.55,18.6,497700.0 210 | 6.182,9.47,18.6,525000.0 211 | 5.888,13.51,16.4,489300.0 212 | 6.642,9.69,16.4,602700.0 213 | 5.951,17.92,16.4,451500.0 214 | 6.373,10.5,16.4,483000.0 215 | 6.951,9.71,17.4,560700.0 216 | 6.164,21.46,17.4,455700.0 217 | 6.879,9.93,17.4,577500.0 218 | 6.618,7.6,17.4,632100.0 219 | 8.266,4.14,17.4,940800.0 220 | 8.04,3.13,17.4,789600.0 221 | 7.163,6.36,17.4,663600.0 222 | 7.686,3.92,17.4,980700.0 223 | 6.552,3.76,17.4,661500.0 224 | 5.981,11.65,17.4,510300.0 225 | 7.412,5.25,17.4,665700.0 226 | 8.337,2.47,17.4,875700.0 227 | 8.247,3.95,17.4,1014300.0 228 | 6.726,8.05,17.4,609000.0 229 | 6.086,10.88,17.4,504000.0 230 | 6.631,9.54,17.4,527100.0 231 | 7.358,4.73,17.4,661500.0 232 | 6.481,6.36,16.6,497700.0 233 | 6.606,7.37,16.6,489300.0 234 | 6.897,11.38,16.6,462000.0 235 | 6.095,12.4,16.6,422100.0 236 | 6.358,11.22,16.6,466200.0 237 | 6.393,5.19,16.6,497700.0 238 | 5.593,12.5,19.1,369600.0 239 | 5.605,18.46,19.1,388500.0 240 | 6.108,9.16,19.1,510300.0 241 | 
6.226,10.15,19.1,430500.0 242 | 6.433,9.52,19.1,514500.0 243 | 6.718,6.56,19.1,550200.0 244 | 6.487,5.9,19.1,512400.0 245 | 6.438,3.59,19.1,520800.0 246 | 6.957,3.53,19.1,621600.0 247 | 8.259,3.54,19.1,898800.0 248 | 6.108,6.57,16.4,459900.0 249 | 5.876,9.25,16.4,438900.0 250 | 7.454,3.11,15.9,924000.0 251 | 7.333,7.79,13.0,756000.0 252 | 6.842,6.9,13.0,632100.0 253 | 7.203,9.59,13.0,709800.0 254 | 7.52,7.26,13.0,905100.0 255 | 8.398,5.91,13.0,1024800.0 256 | 7.327,11.25,13.0,651000.0 257 | 7.206,8.1,13.0,766500.0 258 | 5.56,10.45,13.0,478800.0 259 | 7.014,14.79,13.0,644700.0 260 | 7.47,3.16,13.0,913500.0 261 | 5.92,13.65,18.6,434700.0 262 | 5.856,13.0,18.6,443100.0 263 | 6.24,6.59,18.6,529200.0 264 | 6.538,7.73,18.6,512400.0 265 | 7.691,6.58,18.6,739200.0 266 | 6.758,3.53,17.6,680400.0 267 | 6.854,2.98,17.6,672000.0 268 | 7.267,6.05,17.6,697200.0 269 | 6.826,4.16,17.6,695100.0 270 | 6.482,7.19,17.6,611100.0 271 | 6.812,4.85,14.9,737100.0 272 | 7.82,3.76,14.9,953400.0 273 | 6.968,4.59,14.9,743400.0 274 | 7.645,3.01,14.9,966000.0 275 | 7.088,7.85,15.3,676200.0 276 | 6.453,8.23,15.3,462000.0 277 | 6.23,12.93,18.2,422100.0 278 | 6.209,7.14,16.6,487200.0 279 | 6.315,7.6,16.6,468300.0 280 | 6.565,9.51,16.6,520800.0 281 | 6.861,3.33,19.2,598500.0 282 | 7.148,3.56,19.2,783300.0 283 | 6.63,4.7,19.2,585900.0 284 | 6.127,8.58,16.0,501900.0 285 | 6.009,10.4,16.0,455700.0 286 | 6.678,6.27,16.0,600600.0 287 | 6.549,7.39,16.0,569100.0 288 | 5.79,15.84,16.0,426300.0 289 | 6.345,4.97,14.8,472500.0 290 | 7.041,4.74,14.8,609000.0 291 | 6.871,6.07,14.8,520800.0 292 | 6.59,9.5,16.1,462000.0 293 | 6.495,8.67,16.1,554400.0 294 | 6.982,4.86,16.1,695100.0 295 | 7.236,6.93,18.4,758100.0 296 | 6.616,8.93,18.4,596400.0 297 | 7.42,6.47,18.4,701400.0 298 | 6.849,7.53,18.4,592200.0 299 | 6.635,4.54,18.4,478800.0 300 | 5.972,9.97,18.4,426300.0 301 | 4.973,12.64,18.4,338100.0 302 | 6.122,5.98,18.4,464100.0 303 | 6.023,11.72,18.4,407400.0 304 | 6.266,7.9,18.4,453600.0 305 | 
6.567,9.28,18.4,499800.0 306 | 5.705,11.5,18.4,340200.0 307 | 5.914,18.33,18.4,373800.0 308 | 5.782,15.94,18.4,415800.0 309 | 6.382,10.36,18.4,485100.0 310 | 6.113,12.73,18.4,441000.0 311 | 6.426,7.2,19.6,499800.0 312 | 6.376,6.87,19.6,485100.0 313 | 6.041,7.7,19.6,428400.0 314 | 5.708,11.74,19.6,388500.0 315 | 6.415,6.12,19.6,525000.0 316 | 6.431,5.08,19.6,516600.0 317 | 6.312,6.15,19.6,483000.0 318 | 6.083,12.79,19.6,466200.0 319 | 5.868,9.97,16.9,405300.0 320 | 6.333,7.34,16.9,474600.0 321 | 6.144,9.09,16.9,415800.0 322 | 5.706,12.43,16.9,359100.0 323 | 6.031,7.83,16.9,407400.0 324 | 6.316,5.68,20.2,466200.0 325 | 6.31,6.75,20.2,434700.0 326 | 6.037,8.01,20.2,443100.0 327 | 5.869,9.8,20.2,409500.0 328 | 5.895,10.56,20.2,388500.0 329 | 6.059,8.51,20.2,432600.0 330 | 5.985,9.74,20.2,399000.0 331 | 5.968,9.29,20.2,392700.0 332 | 7.241,5.49,15.5,686700.0 333 | 6.54,8.65,15.9,346500.0 334 | 6.696,7.18,17.6,501900.0 335 | 6.874,4.61,17.6,655200.0 336 | 6.014,10.53,18.8,367500.0 337 | 5.898,12.67,18.8,361200.0 338 | 6.516,6.36,17.9,485100.0 339 | 6.635,5.99,17.0,514500.0 340 | 6.939,5.89,19.7,558600.0 341 | 6.49,5.98,19.7,480900.0 342 | 6.579,5.49,18.3,506100.0 343 | 5.884,7.79,18.3,390600.0 344 | 6.728,4.5,17.0,632100.0 345 | 5.663,8.05,22.0,382200.0 346 | 5.936,5.57,22.0,432600.0 347 | 6.212,17.6,20.2,373800.0 348 | 6.395,13.27,20.2,455700.0 349 | 6.127,11.48,20.2,476700.0 350 | 6.112,12.67,20.2,474600.0 351 | 6.398,7.79,20.2,525000.0 352 | 6.251,14.19,20.2,417900.0 353 | 5.362,10.19,20.2,436800.0 354 | 5.803,14.64,20.2,352800.0 355 | 3.561,7.12,20.2,577500.0 356 | 4.963,14.0,20.2,459900.0 357 | 3.863,13.33,20.2,485100.0 358 | 4.906,34.77,20.2,289800.0 359 | 4.138,37.97,20.2,289800.0 360 | 7.313,13.44,20.2,315000.0 361 | 6.649,23.24,20.2,291900.0 362 | 6.794,21.24,20.2,279300.0 363 | 6.38,23.69,20.2,275100.0 364 | 6.223,21.78,20.2,214200.0 365 | 6.968,17.21,20.2,218400.0 366 | 6.545,21.08,20.2,228900.0 367 | 5.536,23.6,20.2,237300.0 368 | 5.52,24.56,20.2,258300.0 369 
| 4.368,30.63,20.2,184800.0 370 | 5.277,30.81,20.2,151200.0 371 | 4.652,28.28,20.2,220500.0 372 | 5.0,31.99,20.2,155400.0 373 | 4.88,30.62,20.2,214200.0 374 | 5.39,20.85,20.2,241500.0 375 | 5.713,17.11,20.2,317100.0 376 | 6.051,18.76,20.2,487200.0 377 | 5.036,25.68,20.2,203700.0 378 | 6.193,15.17,20.2,289800.0 379 | 5.887,16.35,20.2,266700.0 380 | 6.471,17.12,20.2,275100.0 381 | 6.405,19.37,20.2,262500.0 382 | 5.747,19.92,20.2,178500.0 383 | 5.453,30.59,20.2,105000.0 384 | 5.852,29.97,20.2,132300.0 385 | 5.987,26.77,20.2,117600.0 386 | 6.343,20.32,20.2,151200.0 387 | 6.404,20.31,20.2,254100.0 388 | 5.349,19.77,20.2,174300.0 389 | 5.531,27.38,20.2,178500.0 390 | 5.683,22.98,20.2,105000.0 391 | 4.138,23.34,20.2,249900.0 392 | 5.608,12.13,20.2,585900.0 393 | 5.617,26.4,20.2,361200.0 394 | 6.852,19.78,20.2,577500.0 395 | 5.757,10.11,20.2,315000.0 396 | 6.657,21.22,20.2,361200.0 397 | 4.628,34.37,20.2,375900.0 398 | 5.155,20.08,20.2,342300.0 399 | 4.519,36.98,20.2,147000.0 400 | 6.434,29.05,20.2,151200.0 401 | 6.782,25.79,20.2,157500.0 402 | 5.304,26.64,20.2,218400.0 403 | 5.957,20.62,20.2,184800.0 404 | 6.824,22.74,20.2,176400.0 405 | 6.411,15.02,20.2,350700.0 406 | 6.006,15.7,20.2,298200.0 407 | 5.648,14.1,20.2,436800.0 408 | 6.103,23.29,20.2,281400.0 409 | 5.565,17.16,20.2,245700.0 410 | 5.896,24.39,20.2,174300.0 411 | 5.837,15.69,20.2,214200.0 412 | 6.202,14.52,20.2,228900.0 413 | 6.193,21.52,20.2,231000.0 414 | 6.38,24.08,20.2,199500.0 415 | 6.348,17.64,20.2,304500.0 416 | 6.833,19.69,20.2,296100.0 417 | 6.425,12.03,20.2,338100.0 418 | 6.436,16.22,20.2,300300.0 419 | 6.208,15.17,20.2,245700.0 420 | 6.629,23.27,20.2,281400.0 421 | 6.461,18.05,20.2,201600.0 422 | 6.152,26.45,20.2,182700.0 423 | 5.935,34.02,20.2,176400.0 424 | 5.627,22.88,20.2,268800.0 425 | 5.818,22.11,20.2,220500.0 426 | 6.406,19.52,20.2,359100.0 427 | 6.219,16.59,20.2,386400.0 428 | 6.485,18.85,20.2,323400.0 429 | 5.854,23.79,20.2,226800.0 430 | 6.459,23.98,20.2,247800.0 431 | 
6.341,17.79,20.2,312900.0 432 | 6.251,16.44,20.2,264600.0 433 | 6.185,18.13,20.2,296100.0 434 | 6.417,19.31,20.2,273000.0 435 | 6.749,17.44,20.2,281400.0 436 | 6.655,17.73,20.2,319200.0 437 | 6.297,17.27,20.2,338100.0 438 | 7.393,16.74,20.2,373800.0 439 | 6.728,18.71,20.2,312900.0 440 | 6.525,18.13,20.2,296100.0 441 | 5.976,19.01,20.2,266700.0 442 | 5.936,16.94,20.2,283500.0 443 | 6.301,16.23,20.2,312900.0 444 | 6.081,14.7,20.2,420000.0 445 | 6.701,16.42,20.2,344400.0 446 | 6.376,14.65,20.2,371700.0 447 | 6.317,13.99,20.2,409500.0 448 | 6.513,10.29,20.2,424200.0 449 | 6.209,13.22,20.2,449400.0 450 | 5.759,14.13,20.2,417900.0 451 | 5.952,17.15,20.2,399000.0 452 | 6.003,21.32,20.2,401100.0 453 | 5.926,18.13,20.2,401100.0 454 | 5.713,14.76,20.2,422100.0 455 | 6.167,16.29,20.2,417900.0 456 | 6.229,12.87,20.2,411600.0 457 | 6.437,14.36,20.2,487200.0 458 | 6.98,11.66,20.2,625800.0 459 | 5.427,18.14,20.2,289800.0 460 | 6.162,24.1,20.2,279300.0 461 | 6.484,18.68,20.2,350700.0 462 | 5.304,24.91,20.2,252000.0 463 | 6.185,18.03,20.2,306600.0 464 | 6.229,13.11,20.2,449400.0 465 | 6.242,10.74,20.2,483000.0 466 | 6.75,7.74,20.2,497700.0 467 | 7.061,7.01,20.2,525000.0 468 | 5.762,10.42,20.2,457800.0 469 | 5.871,13.34,20.2,432600.0 470 | 6.312,10.58,20.2,445200.0 471 | 6.114,14.98,20.2,401100.0 472 | 5.905,11.45,20.2,432600.0 473 | 5.454,18.06,20.1,319200.0 474 | 5.414,23.97,20.1,147000.0 475 | 5.093,29.68,20.1,170100.0 476 | 5.983,18.07,20.1,285600.0 477 | 5.983,13.35,20.1,422100.0 478 | 5.707,12.01,19.2,457800.0 479 | 5.926,13.59,19.2,514500.0 480 | 5.67,17.6,19.2,485100.0 481 | 5.39,21.14,19.2,413700.0 482 | 5.794,14.1,19.2,384300.0 483 | 6.019,12.92,19.2,445200.0 484 | 5.569,15.1,19.2,367500.0 485 | 6.027,14.33,19.2,352800.0 486 | 6.593,9.67,21.0,470400.0 487 | 6.12,9.08,21.0,432600.0 488 | 6.976,5.64,21.0,501900.0 489 | 6.794,6.48,21.0,462000.0 490 | 6.03,7.88,21.0,249900.0 491 | -------------------------------------------------------------------------------- 
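As a sanity check of the `housing.csv` schema above (`RM`, `LSTAT`, and `PTRATIO` features with the `MEDV` target), the load-split-fit workflow the project walks through can be sketched as follows. Note this sketch assumes the modern scikit-learn API (`sklearn.model_selection`) rather than the Python 2.7 / `sklearn.cross_validation` stack the README and `visuals.py` target, and it inlines the first five rows of the file so the snippet runs without the dataset on disk:

```python
import pandas as pd
from sklearn.model_selection import train_test_split  # modern home of train_test_split
from sklearn.tree import DecisionTreeRegressor

# First five data rows of housing.csv, inlined for a self-contained demo;
# in the project itself you would load the full file with pd.read_csv('housing.csv').
data = pd.DataFrame({
    'RM':      [6.575, 6.421, 7.185, 6.998, 7.147],
    'LSTAT':   [4.98, 9.14, 4.03, 2.94, 5.33],
    'PTRATIO': [15.3, 17.8, 17.8, 18.7, 18.7],
    'MEDV':    [504000.0, 453600.0, 728700.0, 701400.0, 760200.0],
})

features = data[['RM', 'LSTAT', 'PTRATIO']]
prices = data['MEDV']

# Same 80/20 split ratio that visuals.py uses for its shuffle-split folds
X_train, X_test, y_train, y_test = train_test_split(
    features, prices, test_size=0.2, random_state=0)

# A shallow decision tree, as in the max_depth experiments plotted by visuals.py
reg = DecisionTreeRegressor(max_depth=3, random_state=0)
reg.fit(X_train, y_train)
preds = reg.predict(X_test)
print(preds)
```

On the full 490-row dataset, `ModelLearning` and `ModelComplexity` in `visuals.py` repeat essentially this split-and-fit step across many shuffled folds and depths 1–10 to plot learning and validation curves.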
/boston_housing/visuals.py: -------------------------------------------------------------------------------- 1 | ########################################### 2 | # Suppress matplotlib user warnings 3 | # Necessary for newer version of matplotlib 4 | import warnings 5 | warnings.filterwarnings("ignore", category = UserWarning, module = "matplotlib") 6 | ########################################### 7 | 8 | import matplotlib.pyplot as pl 9 | import numpy as np 10 | import sklearn.learning_curve as curves 11 | from sklearn.tree import DecisionTreeRegressor 12 | from sklearn.cross_validation import ShuffleSplit, train_test_split 13 | 14 | def ModelLearning(X, y): 15 | """ Calculates the performance of several models with varying sizes of training data. 16 | The learning and testing scores for each model are then plotted. """ 17 | 18 | # Create 10 cross-validation sets for training and testing 19 | cv = ShuffleSplit(X.shape[0], n_iter = 10, test_size = 0.2, random_state = 0) 20 | 21 | # Generate the training set sizes increasing by 50 22 | train_sizes = np.rint(np.linspace(1, X.shape[0]*0.8 - 1, 9)).astype(int) 23 | 24 | # Create the figure window 25 | fig = pl.figure(figsize=(10,7)) 26 | 27 | # Create three different models based on max_depth 28 | for k, depth in enumerate([1,3,6,10]): 29 | 30 | # Create a Decision tree regressor at max_depth = depth 31 | regressor = DecisionTreeRegressor(max_depth = depth) 32 | 33 | # Calculate the training and testing scores 34 | sizes, train_scores, test_scores = curves.learning_curve(regressor, X, y, \ 35 | cv = cv, train_sizes = train_sizes, scoring = 'r2') 36 | 37 | # Find the mean and standard deviation for smoothing 38 | train_std = np.std(train_scores, axis = 1) 39 | train_mean = np.mean(train_scores, axis = 1) 40 | test_std = np.std(test_scores, axis = 1) 41 | test_mean = np.mean(test_scores, axis = 1) 42 | 43 | # Subplot the learning curve 44 | ax = fig.add_subplot(2, 2, k+1) 45 | ax.plot(sizes, train_mean, 'o-', color = 'r', 
label = 'Training Score') 46 | ax.plot(sizes, test_mean, 'o-', color = 'g', label = 'Testing Score') 47 | ax.fill_between(sizes, train_mean - train_std, \ 48 | train_mean + train_std, alpha = 0.15, color = 'r') 49 | ax.fill_between(sizes, test_mean - test_std, \ 50 | test_mean + test_std, alpha = 0.15, color = 'g') 51 | 52 | # Labels 53 | ax.set_title('max_depth = %s'%(depth)) 54 | ax.set_xlabel('Number of Training Points') 55 | ax.set_ylabel('Score') 56 | ax.set_xlim([0, X.shape[0]*0.8]) 57 | ax.set_ylim([-0.05, 1.05]) 58 | 59 | # Visual aesthetics 60 | ax.legend(bbox_to_anchor=(1.05, 2.05), loc='lower left', borderaxespad = 0.) 61 | fig.suptitle('Decision Tree Regressor Learning Performances', fontsize = 16, y = 1.03) 62 | fig.tight_layout() 63 | fig.show() 64 | 65 | 66 | def ModelComplexity(X, y): 67 | """ Calculates the performance of the model as model complexity increases. 68 | The learning and testing errors rates are then plotted. """ 69 | 70 | # Create 10 cross-validation sets for training and testing 71 | cv = ShuffleSplit(X.shape[0], n_iter = 10, test_size = 0.2, random_state = 0) 72 | 73 | # Vary the max_depth parameter from 1 to 10 74 | max_depth = np.arange(1,11) 75 | 76 | # Calculate the training and testing scores 77 | train_scores, test_scores = curves.validation_curve(DecisionTreeRegressor(), X, y, \ 78 | param_name = "max_depth", param_range = max_depth, cv = cv, scoring = 'r2') 79 | 80 | # Find the mean and standard deviation for smoothing 81 | train_mean = np.mean(train_scores, axis=1) 82 | train_std = np.std(train_scores, axis=1) 83 | test_mean = np.mean(test_scores, axis=1) 84 | test_std = np.std(test_scores, axis=1) 85 | 86 | # Plot the validation curve 87 | pl.figure(figsize=(7, 5)) 88 | pl.title('Decision Tree Regressor Complexity Performance') 89 | pl.plot(max_depth, train_mean, 'o-', color = 'r', label = 'Training Score') 90 | pl.plot(max_depth, test_mean, 'o-', color = 'g', label = 'Validation Score') 91 | pl.fill_between(max_depth, 
train_mean - train_std, \ 92 | train_mean + train_std, alpha = 0.15, color = 'r') 93 | pl.fill_between(max_depth, test_mean - test_std, \ 94 | test_mean + test_std, alpha = 0.15, color = 'g') 95 | 96 | # Visual aesthetics 97 | pl.legend(loc = 'lower right') 98 | pl.xlabel('Maximum Depth') 99 | pl.ylabel('Score') 100 | pl.ylim([-0.05,1.05]) 101 | pl.show() 102 | 103 | 104 | def PredictTrials(X, y, fitter, data): 105 | """ Performs trials of fitting and predicting data. """ 106 | 107 | # Store the predicted prices 108 | prices = [] 109 | 110 | for k in range(10): 111 | # Split the data 112 | X_train, X_test, y_train, y_test = train_test_split(X, y, \ 113 | test_size = 0.2, random_state = k) 114 | 115 | # Fit the data 116 | reg = fitter(X_train, y_train) 117 | 118 | # Make a prediction 119 | pred = reg.predict([data[0]])[0] 120 | prices.append(pred) 121 | 122 | # Result 123 | print "Trial {}: ${:,.2f}".format(k+1, pred) 124 | 125 | # Display price range 126 | print "\nRange in prices: ${:,.2f}".format(max(prices) - min(prices)) -------------------------------------------------------------------------------- /creating_customer_segments/README.md: -------------------------------------------------------------------------------- 1 | # Project 3: Unsupervised Learning 2 | ## Creating Customer Segments 3 | 4 | ### Install 5 | 6 | This project requires **Python 2.7** and the following Python packages: 7 | 8 | - [NumPy](http://www.numpy.org/) 9 | - [Pandas](http://pandas.pydata.org) 10 | - [scikit-learn](http://scikit-learn.org/stable/) 11 | 12 | You will also need software installed that can run [Jupyter Notebook](http://jupyter.org/). 13 | 14 | Udacity recommends installing [Anaconda](https://www.continuum.io/downloads), a pre-packaged Python distribution that contains all of the libraries and software needed for this project. 15 | 16 | ### Code 17 | 18 | Starter code is provided in the `customer_segments.ipynb` notebook file. Some of the code is already implemented to help you get started, but additional functionality must be implemented to complete the project. 19 | 20 | ### Run 21 | 22 | In a command window, make sure the current directory is the top level of the folder containing `customer_segments.ipynb` (and this README file), then run the following command: 23 | 24 | ```jupyter notebook customer_segments.ipynb``` 25 | 26 | This starts Jupyter Notebook and opens the project file in your browser. 27 | 28 | ### Data
29 | 30 | The data for this project is contained in `customers.csv`. You can find more information on the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Wholesale+customers) page. 31 | -------------------------------------------------------------------------------- /creating_customer_segments/customer_segments.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/creating_customer_segments/customer_segments.zip -------------------------------------------------------------------------------- /creating_customer_segments/customers.csv: -------------------------------------------------------------------------------- 1 | Channel,Region,Fresh,Milk,Grocery,Frozen,Detergents_Paper,Delicatessen 2 | 2,3,12669,9656,7561,214,2674,1338 3 | 2,3,7057,9810,9568,1762,3293,1776 4 | 2,3,6353,8808,7684,2405,3516,7844 5 | 1,3,13265,1196,4221,6404,507,1788 6 | 2,3,22615,5410,7198,3915,1777,5185 7 | 2,3,9413,8259,5126,666,1795,1451 8 | 2,3,12126,3199,6975,480,3140,545 9 | 2,3,7579,4956,9426,1669,3321,2566 10 | 1,3,5963,3648,6192,425,1716,750 11 | 2,3,6006,11093,18881,1159,7425,2098 12 | 2,3,3366,5403,12974,4400,5977,1744 13 | 2,3,13146,1124,4523,1420,549,497 14 | 2,3,31714,12319,11757,287,3881,2931 15 | 2,3,21217,6208,14982,3095,6707,602 16 | 2,3,24653,9465,12091,294,5058,2168 17 | 1,3,10253,1114,3821,397,964,412 18 | 2,3,1020,8816,12121,134,4508,1080 19 | 1,3,5876,6157,2933,839,370,4478 20 | 2,3,18601,6327,10099,2205,2767,3181 21 | 1,3,7780,2495,9464,669,2518,501 22 | 2,3,17546,4519,4602,1066,2259,2124 23 | 1,3,5567,871,2010,3383,375,569 24 | 1,3,31276,1917,4469,9408,2381,4334 25 | 2,3,26373,36423,22019,5154,4337,16523 26 | 2,3,22647,9776,13792,2915,4482,5778 27 | 2,3,16165,4230,7595,201,4003,57 28 | 1,3,9898,961,2861,3151,242,833 29 | 1,3,14276,803,3045,485,100,518 30 | 2,3,4113,20484,25957,1158,8604,5206 31 | 1,3,43088,2100,2609,1200,1107,823 32 | 1,3,18815,3610,11107,1148,2134,2963
33 | 1,3,2612,4339,3133,2088,820,985 34 | 1,3,21632,1318,2886,266,918,405 35 | 1,3,29729,4786,7326,6130,361,1083 36 | 1,3,1502,1979,2262,425,483,395 37 | 2,3,688,5491,11091,833,4239,436 38 | 1,3,29955,4362,5428,1729,862,4626 39 | 2,3,15168,10556,12477,1920,6506,714 40 | 2,3,4591,15729,16709,33,6956,433 41 | 1,3,56159,555,902,10002,212,2916 42 | 1,3,24025,4332,4757,9510,1145,5864 43 | 1,3,19176,3065,5956,2033,2575,2802 44 | 2,3,10850,7555,14961,188,6899,46 45 | 2,3,630,11095,23998,787,9529,72 46 | 2,3,9670,7027,10471,541,4618,65 47 | 2,3,5181,22044,21531,1740,7353,4985 48 | 2,3,3103,14069,21955,1668,6792,1452 49 | 2,3,44466,54259,55571,7782,24171,6465 50 | 2,3,11519,6152,10868,584,5121,1476 51 | 2,3,4967,21412,28921,1798,13583,1163 52 | 1,3,6269,1095,1980,3860,609,2162 53 | 1,3,3347,4051,6996,239,1538,301 54 | 2,3,40721,3916,5876,532,2587,1278 55 | 2,3,491,10473,11532,744,5611,224 56 | 1,3,27329,1449,1947,2436,204,1333 57 | 1,3,5264,3683,5005,1057,2024,1130 58 | 2,3,4098,29892,26866,2616,17740,1340 59 | 2,3,5417,9933,10487,38,7572,1282 60 | 1,3,13779,1970,1648,596,227,436 61 | 1,3,6137,5360,8040,129,3084,1603 62 | 2,3,8590,3045,7854,96,4095,225 63 | 2,3,35942,38369,59598,3254,26701,2017 64 | 2,3,7823,6245,6544,4154,4074,964 65 | 2,3,9396,11601,15775,2896,7677,1295 66 | 1,3,4760,1227,3250,3724,1247,1145 67 | 2,3,85,20959,45828,36,24231,1423 68 | 1,3,9,1534,7417,175,3468,27 69 | 2,3,19913,6759,13462,1256,5141,834 70 | 1,3,2446,7260,3993,5870,788,3095 71 | 1,3,8352,2820,1293,779,656,144 72 | 1,3,16705,2037,3202,10643,116,1365 73 | 1,3,18291,1266,21042,5373,4173,14472 74 | 1,3,4420,5139,2661,8872,1321,181 75 | 2,3,19899,5332,8713,8132,764,648 76 | 2,3,8190,6343,9794,1285,1901,1780 77 | 1,3,20398,1137,3,4407,3,975 78 | 1,3,717,3587,6532,7530,529,894 79 | 2,3,12205,12697,28540,869,12034,1009 80 | 1,3,10766,1175,2067,2096,301,167 81 | 1,3,1640,3259,3655,868,1202,1653 82 | 1,3,7005,829,3009,430,610,529 83 | 2,3,219,9540,14403,283,7818,156 84 | 
2,3,10362,9232,11009,737,3537,2342 85 | 1,3,20874,1563,1783,2320,550,772 86 | 2,3,11867,3327,4814,1178,3837,120 87 | 2,3,16117,46197,92780,1026,40827,2944 88 | 2,3,22925,73498,32114,987,20070,903 89 | 1,3,43265,5025,8117,6312,1579,14351 90 | 1,3,7864,542,4042,9735,165,46 91 | 1,3,24904,3836,5330,3443,454,3178 92 | 1,3,11405,596,1638,3347,69,360 93 | 1,3,12754,2762,2530,8693,627,1117 94 | 2,3,9198,27472,32034,3232,18906,5130 95 | 1,3,11314,3090,2062,35009,71,2698 96 | 2,3,5626,12220,11323,206,5038,244 97 | 1,3,3,2920,6252,440,223,709 98 | 2,3,23,2616,8118,145,3874,217 99 | 1,3,403,254,610,774,54,63 100 | 1,3,503,112,778,895,56,132 101 | 1,3,9658,2182,1909,5639,215,323 102 | 2,3,11594,7779,12144,3252,8035,3029 103 | 2,3,1420,10810,16267,1593,6766,1838 104 | 2,3,2932,6459,7677,2561,4573,1386 105 | 1,3,56082,3504,8906,18028,1480,2498 106 | 1,3,14100,2132,3445,1336,1491,548 107 | 1,3,15587,1014,3970,910,139,1378 108 | 2,3,1454,6337,10704,133,6830,1831 109 | 2,3,8797,10646,14886,2471,8969,1438 110 | 2,3,1531,8397,6981,247,2505,1236 111 | 2,3,1406,16729,28986,673,836,3 112 | 1,3,11818,1648,1694,2276,169,1647 113 | 2,3,12579,11114,17569,805,6457,1519 114 | 1,3,19046,2770,2469,8853,483,2708 115 | 1,3,14438,2295,1733,3220,585,1561 116 | 1,3,18044,1080,2000,2555,118,1266 117 | 1,3,11134,793,2988,2715,276,610 118 | 1,3,11173,2521,3355,1517,310,222 119 | 1,3,6990,3880,5380,1647,319,1160 120 | 1,3,20049,1891,2362,5343,411,933 121 | 1,3,8258,2344,2147,3896,266,635 122 | 1,3,17160,1200,3412,2417,174,1136 123 | 1,3,4020,3234,1498,2395,264,255 124 | 1,3,12212,201,245,1991,25,860 125 | 2,3,11170,10769,8814,2194,1976,143 126 | 1,3,36050,1642,2961,4787,500,1621 127 | 1,3,76237,3473,7102,16538,778,918 128 | 1,3,19219,1840,1658,8195,349,483 129 | 2,3,21465,7243,10685,880,2386,2749 130 | 1,3,140,8847,3823,142,1062,3 131 | 1,3,42312,926,1510,1718,410,1819 132 | 1,3,7149,2428,699,6316,395,911 133 | 1,3,2101,589,314,346,70,310 134 | 1,3,14903,2032,2479,576,955,328 135 | 
1,3,9434,1042,1235,436,256,396 136 | 1,3,7388,1882,2174,720,47,537 137 | 1,3,6300,1289,2591,1170,199,326 138 | 1,3,4625,8579,7030,4575,2447,1542 139 | 1,3,3087,8080,8282,661,721,36 140 | 1,3,13537,4257,5034,155,249,3271 141 | 1,3,5387,4979,3343,825,637,929 142 | 1,3,17623,4280,7305,2279,960,2616 143 | 1,3,30379,13252,5189,321,51,1450 144 | 1,3,37036,7152,8253,2995,20,3 145 | 1,3,10405,1596,1096,8425,399,318 146 | 1,3,18827,3677,1988,118,516,201 147 | 2,3,22039,8384,34792,42,12591,4430 148 | 1,3,7769,1936,2177,926,73,520 149 | 1,3,9203,3373,2707,1286,1082,526 150 | 1,3,5924,584,542,4052,283,434 151 | 1,3,31812,1433,1651,800,113,1440 152 | 1,3,16225,1825,1765,853,170,1067 153 | 1,3,1289,3328,2022,531,255,1774 154 | 1,3,18840,1371,3135,3001,352,184 155 | 1,3,3463,9250,2368,779,302,1627 156 | 1,3,622,55,137,75,7,8 157 | 2,3,1989,10690,19460,233,11577,2153 158 | 2,3,3830,5291,14855,317,6694,3182 159 | 1,3,17773,1366,2474,3378,811,418 160 | 2,3,2861,6570,9618,930,4004,1682 161 | 2,3,355,7704,14682,398,8077,303 162 | 2,3,1725,3651,12822,824,4424,2157 163 | 1,3,12434,540,283,1092,3,2233 164 | 1,3,15177,2024,3810,2665,232,610 165 | 2,3,5531,15726,26870,2367,13726,446 166 | 2,3,5224,7603,8584,2540,3674,238 167 | 2,3,15615,12653,19858,4425,7108,2379 168 | 2,3,4822,6721,9170,993,4973,3637 169 | 1,3,2926,3195,3268,405,1680,693 170 | 1,3,5809,735,803,1393,79,429 171 | 1,3,5414,717,2155,2399,69,750 172 | 2,3,260,8675,13430,1116,7015,323 173 | 2,3,200,25862,19816,651,8773,6250 174 | 1,3,955,5479,6536,333,2840,707 175 | 2,3,514,7677,19805,937,9836,716 176 | 1,3,286,1208,5241,2515,153,1442 177 | 2,3,2343,7845,11874,52,4196,1697 178 | 1,3,45640,6958,6536,7368,1532,230 179 | 1,3,12759,7330,4533,1752,20,2631 180 | 1,3,11002,7075,4945,1152,120,395 181 | 1,3,3157,4888,2500,4477,273,2165 182 | 1,3,12356,6036,8887,402,1382,2794 183 | 1,3,112151,29627,18148,16745,4948,8550 184 | 1,3,694,8533,10518,443,6907,156 185 | 1,3,36847,43950,20170,36534,239,47943 186 | 1,3,327,918,4710,74,334,11 187 
| 1,3,8170,6448,1139,2181,58,247 188 | 1,3,3009,521,854,3470,949,727 189 | 1,3,2438,8002,9819,6269,3459,3 190 | 2,3,8040,7639,11687,2758,6839,404 191 | 2,3,834,11577,11522,275,4027,1856 192 | 1,3,16936,6250,1981,7332,118,64 193 | 1,3,13624,295,1381,890,43,84 194 | 1,3,5509,1461,2251,547,187,409 195 | 2,3,180,3485,20292,959,5618,666 196 | 1,3,7107,1012,2974,806,355,1142 197 | 1,3,17023,5139,5230,7888,330,1755 198 | 1,1,30624,7209,4897,18711,763,2876 199 | 2,1,2427,7097,10391,1127,4314,1468 200 | 1,1,11686,2154,6824,3527,592,697 201 | 1,1,9670,2280,2112,520,402,347 202 | 2,1,3067,13240,23127,3941,9959,731 203 | 2,1,4484,14399,24708,3549,14235,1681 204 | 1,1,25203,11487,9490,5065,284,6854 205 | 1,1,583,685,2216,469,954,18 206 | 1,1,1956,891,5226,1383,5,1328 207 | 2,1,1107,11711,23596,955,9265,710 208 | 1,1,6373,780,950,878,288,285 209 | 2,1,2541,4737,6089,2946,5316,120 210 | 1,1,1537,3748,5838,1859,3381,806 211 | 2,1,5550,12729,16767,864,12420,797 212 | 1,1,18567,1895,1393,1801,244,2100 213 | 2,1,12119,28326,39694,4736,19410,2870 214 | 1,1,7291,1012,2062,1291,240,1775 215 | 1,1,3317,6602,6861,1329,3961,1215 216 | 2,1,2362,6551,11364,913,5957,791 217 | 1,1,2806,10765,15538,1374,5828,2388 218 | 2,1,2532,16599,36486,179,13308,674 219 | 1,1,18044,1475,2046,2532,130,1158 220 | 2,1,18,7504,15205,1285,4797,6372 221 | 1,1,4155,367,1390,2306,86,130 222 | 1,1,14755,899,1382,1765,56,749 223 | 1,1,5396,7503,10646,91,4167,239 224 | 1,1,5041,1115,2856,7496,256,375 225 | 2,1,2790,2527,5265,5612,788,1360 226 | 1,1,7274,659,1499,784,70,659 227 | 1,1,12680,3243,4157,660,761,786 228 | 2,1,20782,5921,9212,1759,2568,1553 229 | 1,1,4042,2204,1563,2286,263,689 230 | 1,1,1869,577,572,950,4762,203 231 | 1,1,8656,2746,2501,6845,694,980 232 | 2,1,11072,5989,5615,8321,955,2137 233 | 1,1,2344,10678,3828,1439,1566,490 234 | 1,1,25962,1780,3838,638,284,834 235 | 1,1,964,4984,3316,937,409,7 236 | 1,1,15603,2703,3833,4260,325,2563 237 | 1,1,1838,6380,2824,1218,1216,295 238 | 
1,1,8635,820,3047,2312,415,225 239 | 1,1,18692,3838,593,4634,28,1215 240 | 1,1,7363,475,585,1112,72,216 241 | 1,1,47493,2567,3779,5243,828,2253 242 | 1,1,22096,3575,7041,11422,343,2564 243 | 1,1,24929,1801,2475,2216,412,1047 244 | 1,1,18226,659,2914,3752,586,578 245 | 1,1,11210,3576,5119,561,1682,2398 246 | 1,1,6202,7775,10817,1183,3143,1970 247 | 2,1,3062,6154,13916,230,8933,2784 248 | 1,1,8885,2428,1777,1777,430,610 249 | 1,1,13569,346,489,2077,44,659 250 | 1,1,15671,5279,2406,559,562,572 251 | 1,1,8040,3795,2070,6340,918,291 252 | 1,1,3191,1993,1799,1730,234,710 253 | 2,1,6134,23133,33586,6746,18594,5121 254 | 1,1,6623,1860,4740,7683,205,1693 255 | 1,1,29526,7961,16966,432,363,1391 256 | 1,1,10379,17972,4748,4686,1547,3265 257 | 1,1,31614,489,1495,3242,111,615 258 | 1,1,11092,5008,5249,453,392,373 259 | 1,1,8475,1931,1883,5004,3593,987 260 | 1,1,56083,4563,2124,6422,730,3321 261 | 1,1,53205,4959,7336,3012,967,818 262 | 1,1,9193,4885,2157,327,780,548 263 | 1,1,7858,1110,1094,6818,49,287 264 | 1,1,23257,1372,1677,982,429,655 265 | 1,1,2153,1115,6684,4324,2894,411 266 | 2,1,1073,9679,15445,61,5980,1265 267 | 1,1,5909,23527,13699,10155,830,3636 268 | 2,1,572,9763,22182,2221,4882,2563 269 | 1,1,20893,1222,2576,3975,737,3628 270 | 2,1,11908,8053,19847,1069,6374,698 271 | 1,1,15218,258,1138,2516,333,204 272 | 1,1,4720,1032,975,5500,197,56 273 | 1,1,2083,5007,1563,1120,147,1550 274 | 1,1,514,8323,6869,529,93,1040 275 | 1,3,36817,3045,1493,4802,210,1824 276 | 1,3,894,1703,1841,744,759,1153 277 | 1,3,680,1610,223,862,96,379 278 | 1,3,27901,3749,6964,4479,603,2503 279 | 1,3,9061,829,683,16919,621,139 280 | 1,3,11693,2317,2543,5845,274,1409 281 | 2,3,17360,6200,9694,1293,3620,1721 282 | 1,3,3366,2884,2431,977,167,1104 283 | 2,3,12238,7108,6235,1093,2328,2079 284 | 1,3,49063,3965,4252,5970,1041,1404 285 | 1,3,25767,3613,2013,10303,314,1384 286 | 1,3,68951,4411,12609,8692,751,2406 287 | 1,3,40254,640,3600,1042,436,18 288 | 1,3,7149,2247,1242,1619,1226,128 289 | 
1,3,15354,2102,2828,8366,386,1027 290 | 1,3,16260,594,1296,848,445,258 291 | 1,3,42786,286,471,1388,32,22 292 | 1,3,2708,2160,2642,502,965,1522 293 | 1,3,6022,3354,3261,2507,212,686 294 | 1,3,2838,3086,4329,3838,825,1060 295 | 2,2,3996,11103,12469,902,5952,741 296 | 1,2,21273,2013,6550,909,811,1854 297 | 2,2,7588,1897,5234,417,2208,254 298 | 1,2,19087,1304,3643,3045,710,898 299 | 2,2,8090,3199,6986,1455,3712,531 300 | 2,2,6758,4560,9965,934,4538,1037 301 | 1,2,444,879,2060,264,290,259 302 | 2,2,16448,6243,6360,824,2662,2005 303 | 2,2,5283,13316,20399,1809,8752,172 304 | 2,2,2886,5302,9785,364,6236,555 305 | 2,2,2599,3688,13829,492,10069,59 306 | 2,2,161,7460,24773,617,11783,2410 307 | 2,2,243,12939,8852,799,3909,211 308 | 2,2,6468,12867,21570,1840,7558,1543 309 | 1,2,17327,2374,2842,1149,351,925 310 | 1,2,6987,1020,3007,416,257,656 311 | 2,2,918,20655,13567,1465,6846,806 312 | 1,2,7034,1492,2405,12569,299,1117 313 | 1,2,29635,2335,8280,3046,371,117 314 | 2,2,2137,3737,19172,1274,17120,142 315 | 1,2,9784,925,2405,4447,183,297 316 | 1,2,10617,1795,7647,1483,857,1233 317 | 2,2,1479,14982,11924,662,3891,3508 318 | 1,2,7127,1375,2201,2679,83,1059 319 | 1,2,1182,3088,6114,978,821,1637 320 | 1,2,11800,2713,3558,2121,706,51 321 | 2,2,9759,25071,17645,1128,12408,1625 322 | 1,2,1774,3696,2280,514,275,834 323 | 1,2,9155,1897,5167,2714,228,1113 324 | 1,2,15881,713,3315,3703,1470,229 325 | 1,2,13360,944,11593,915,1679,573 326 | 1,2,25977,3587,2464,2369,140,1092 327 | 1,2,32717,16784,13626,60869,1272,5609 328 | 1,2,4414,1610,1431,3498,387,834 329 | 1,2,542,899,1664,414,88,522 330 | 1,2,16933,2209,3389,7849,210,1534 331 | 1,2,5113,1486,4583,5127,492,739 332 | 1,2,9790,1786,5109,3570,182,1043 333 | 2,2,11223,14881,26839,1234,9606,1102 334 | 1,2,22321,3216,1447,2208,178,2602 335 | 2,2,8565,4980,67298,131,38102,1215 336 | 2,2,16823,928,2743,11559,332,3486 337 | 2,2,27082,6817,10790,1365,4111,2139 338 | 1,2,13970,1511,1330,650,146,778 339 | 1,2,9351,1347,2611,8170,442,868 340 | 
1,2,3,333,7021,15601,15,550 341 | 1,2,2617,1188,5332,9584,573,1942 342 | 2,3,381,4025,9670,388,7271,1371 343 | 2,3,2320,5763,11238,767,5162,2158 344 | 1,3,255,5758,5923,349,4595,1328 345 | 2,3,1689,6964,26316,1456,15469,37 346 | 1,3,3043,1172,1763,2234,217,379 347 | 1,3,1198,2602,8335,402,3843,303 348 | 2,3,2771,6939,15541,2693,6600,1115 349 | 2,3,27380,7184,12311,2809,4621,1022 350 | 1,3,3428,2380,2028,1341,1184,665 351 | 2,3,5981,14641,20521,2005,12218,445 352 | 1,3,3521,1099,1997,1796,173,995 353 | 2,3,1210,10044,22294,1741,12638,3137 354 | 1,3,608,1106,1533,830,90,195 355 | 2,3,117,6264,21203,228,8682,1111 356 | 1,3,14039,7393,2548,6386,1333,2341 357 | 1,3,190,727,2012,245,184,127 358 | 1,3,22686,134,218,3157,9,548 359 | 2,3,37,1275,22272,137,6747,110 360 | 1,3,759,18664,1660,6114,536,4100 361 | 1,3,796,5878,2109,340,232,776 362 | 1,3,19746,2872,2006,2601,468,503 363 | 1,3,4734,607,864,1206,159,405 364 | 1,3,2121,1601,2453,560,179,712 365 | 1,3,4627,997,4438,191,1335,314 366 | 1,3,2615,873,1524,1103,514,468 367 | 2,3,4692,6128,8025,1619,4515,3105 368 | 1,3,9561,2217,1664,1173,222,447 369 | 1,3,3477,894,534,1457,252,342 370 | 1,3,22335,1196,2406,2046,101,558 371 | 1,3,6211,337,683,1089,41,296 372 | 2,3,39679,3944,4955,1364,523,2235 373 | 1,3,20105,1887,1939,8164,716,790 374 | 1,3,3884,3801,1641,876,397,4829 375 | 2,3,15076,6257,7398,1504,1916,3113 376 | 1,3,6338,2256,1668,1492,311,686 377 | 1,3,5841,1450,1162,597,476,70 378 | 2,3,3136,8630,13586,5641,4666,1426 379 | 1,3,38793,3154,2648,1034,96,1242 380 | 1,3,3225,3294,1902,282,68,1114 381 | 2,3,4048,5164,10391,130,813,179 382 | 1,3,28257,944,2146,3881,600,270 383 | 1,3,17770,4591,1617,9927,246,532 384 | 1,3,34454,7435,8469,2540,1711,2893 385 | 1,3,1821,1364,3450,4006,397,361 386 | 1,3,10683,21858,15400,3635,282,5120 387 | 1,3,11635,922,1614,2583,192,1068 388 | 1,3,1206,3620,2857,1945,353,967 389 | 1,3,20918,1916,1573,1960,231,961 390 | 1,3,9785,848,1172,1677,200,406 391 | 1,3,9385,1530,1422,3019,227,684 392 | 
1,3,3352,1181,1328,5502,311,1000 393 | 1,3,2647,2761,2313,907,95,1827 394 | 1,3,518,4180,3600,659,122,654 395 | 1,3,23632,6730,3842,8620,385,819 396 | 1,3,12377,865,3204,1398,149,452 397 | 1,3,9602,1316,1263,2921,841,290 398 | 2,3,4515,11991,9345,2644,3378,2213 399 | 1,3,11535,1666,1428,6838,64,743 400 | 1,3,11442,1032,582,5390,74,247 401 | 1,3,9612,577,935,1601,469,375 402 | 1,3,4446,906,1238,3576,153,1014 403 | 1,3,27167,2801,2128,13223,92,1902 404 | 1,3,26539,4753,5091,220,10,340 405 | 1,3,25606,11006,4604,127,632,288 406 | 1,3,18073,4613,3444,4324,914,715 407 | 1,3,6884,1046,1167,2069,593,378 408 | 1,3,25066,5010,5026,9806,1092,960 409 | 2,3,7362,12844,18683,2854,7883,553 410 | 2,3,8257,3880,6407,1646,2730,344 411 | 1,3,8708,3634,6100,2349,2123,5137 412 | 1,3,6633,2096,4563,1389,1860,1892 413 | 1,3,2126,3289,3281,1535,235,4365 414 | 1,3,97,3605,12400,98,2970,62 415 | 1,3,4983,4859,6633,17866,912,2435 416 | 1,3,5969,1990,3417,5679,1135,290 417 | 2,3,7842,6046,8552,1691,3540,1874 418 | 2,3,4389,10940,10908,848,6728,993 419 | 1,3,5065,5499,11055,364,3485,1063 420 | 2,3,660,8494,18622,133,6740,776 421 | 1,3,8861,3783,2223,633,1580,1521 422 | 1,3,4456,5266,13227,25,6818,1393 423 | 2,3,17063,4847,9053,1031,3415,1784 424 | 1,3,26400,1377,4172,830,948,1218 425 | 2,3,17565,3686,4657,1059,1803,668 426 | 2,3,16980,2884,12232,874,3213,249 427 | 1,3,11243,2408,2593,15348,108,1886 428 | 1,3,13134,9347,14316,3141,5079,1894 429 | 1,3,31012,16687,5429,15082,439,1163 430 | 1,3,3047,5970,4910,2198,850,317 431 | 1,3,8607,1750,3580,47,84,2501 432 | 1,3,3097,4230,16483,575,241,2080 433 | 1,3,8533,5506,5160,13486,1377,1498 434 | 1,3,21117,1162,4754,269,1328,395 435 | 1,3,1982,3218,1493,1541,356,1449 436 | 1,3,16731,3922,7994,688,2371,838 437 | 1,3,29703,12051,16027,13135,182,2204 438 | 1,3,39228,1431,764,4510,93,2346 439 | 2,3,14531,15488,30243,437,14841,1867 440 | 1,3,10290,1981,2232,1038,168,2125 441 | 1,3,2787,1698,2510,65,477,52 442 | 
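The rows above close out the raw `customers.csv` data. As a minimal sketch of how the segmentation notebook typically loads it (this snippet is not part of the repository; it assumes the file begins with the UCI Wholesale Customers header row, which the `Channel` column lookup in `visuals.py` implies):

```python
import pandas as pd

def load_spending_features(path="customers.csv"):
    """Load the wholesale customers data and drop the label-like columns.

    `Channel` and `Region` are categorical labels (visuals.py uses `Channel`
    to color the final plot); the remaining six columns are the annual
    spending figures that the segmentation analysis actually clusters.
    """
    data = pd.read_csv(path)
    return data.drop(["Channel", "Region"], axis=1)

if __name__ == "__main__":
    features = load_spending_features()
    print("{} samples, {} spending features".format(*features.shape))
```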
-------------------------------------------------------------------------------- /creating_customer_segments/visuals.py: -------------------------------------------------------------------------------- 1 | ########################################### 2 | # Suppress matplotlib user warnings 3 | # Necessary for newer version of matplotlib 4 | import warnings 5 | warnings.filterwarnings("ignore", category = UserWarning, module = "matplotlib") 6 | # 7 | # Display inline matplotlib plots with IPython 8 | from IPython import get_ipython 9 | get_ipython().run_line_magic('matplotlib', 'inline') 10 | ########################################### 11 | 12 | import matplotlib.pyplot as plt 13 | import matplotlib.cm as cm 14 | import pandas as pd 15 | import numpy as np 16 | 17 | def pca_results(good_data, pca): 18 | ''' 19 | Create a DataFrame of the PCA results 20 | Includes dimension feature weights and explained variance 21 | Visualizes the PCA results 22 | ''' 23 | 24 | # Dimension indexing 25 | dimensions = dimensions = ['Dimension {}'.format(i) for i in range(1,len(pca.components_)+1)] 26 | 27 | # PCA components 28 | components = pd.DataFrame(np.round(pca.components_, 4), columns = good_data.keys()) 29 | components.index = dimensions 30 | 31 | # PCA explained variance 32 | ratios = pca.explained_variance_ratio_.reshape(len(pca.components_), 1) 33 | variance_ratios = pd.DataFrame(np.round(ratios, 4), columns = ['Explained Variance']) 34 | variance_ratios.index = dimensions 35 | 36 | # Create a bar plot visualization 37 | fig, ax = plt.subplots(figsize = (14,8)) 38 | 39 | # Plot the feature weights as a function of the components 40 | components.plot(ax = ax, kind = 'bar'); 41 | ax.set_ylabel("Feature Weights") 42 | ax.set_xticklabels(dimensions, rotation=0) 43 | 44 | 45 | # Display the explained variance ratios 46 | for i, ev in enumerate(pca.explained_variance_ratio_): 47 | ax.text(i-0.40, ax.get_ylim()[1] + 0.05, "Explained Variance\n %.4f"%(ev)) 48 | 49 | # Return a 
concatenated DataFrame 50 | return pd.concat([variance_ratios, components], axis = 1) 51 | 52 | def cluster_results(reduced_data, preds, centers, pca_samples): 53 | ''' 54 | Visualizes the PCA-reduced cluster data in two dimensions 55 | Adds cues for cluster centers and student-selected sample data 56 | ''' 57 | 58 | predictions = pd.DataFrame(preds, columns = ['Cluster']) 59 | plot_data = pd.concat([predictions, reduced_data], axis = 1) 60 | 61 | # Generate the cluster plot 62 | fig, ax = plt.subplots(figsize = (14,8)) 63 | 64 | # Color map 65 | cmap = cm.get_cmap('gist_rainbow') 66 | 67 | # Color the points based on assigned cluster 68 | for i, cluster in plot_data.groupby('Cluster'): 69 | cluster.plot(ax = ax, kind = 'scatter', x = 'Dimension 1', y = 'Dimension 2', \ 70 | color = cmap((i)*1.0/(len(centers)-1)), label = 'Cluster %i'%(i), s=30); 71 | 72 | # Plot centers with indicators 73 | for i, c in enumerate(centers): 74 | ax.scatter(x = c[0], y = c[1], color = 'white', edgecolors = 'black', \ 75 | alpha = 1, linewidth = 2, marker = 'o', s=200); 76 | ax.scatter(x = c[0], y = c[1], marker='$%d$'%(i), alpha = 1, s=100); 77 | 78 | # Plot transformed sample points 79 | ax.scatter(x = pca_samples[:,0], y = pca_samples[:,1], \ 80 | s = 150, linewidth = 4, color = 'black', marker = 'x'); 81 | 82 | # Set plot title 83 | ax.set_title("Cluster Learning on PCA-Reduced Data - Centroids Marked by Number\nTransformed Sample Data Marked by Black Cross"); 84 | 85 | 86 | def biplot(good_data, reduced_data, pca): 87 | ''' 88 | Produce a biplot that shows a scatterplot of the reduced 89 | data and the projections of the original features. 90 | 91 | good_data: original data, before transformation. 
92 | Needs to be a pandas dataframe with valid column names 93 | reduced_data: the reduced data (the first two dimensions are plotted) 94 | pca: pca object that contains the components_ attribute 95 | 96 | return: a matplotlib AxesSubplot object (for any additional customization) 97 | 98 | This procedure is inspired by the script: 99 | https://github.com/teddyroland/python-biplot 100 | ''' 101 | 102 | fig, ax = plt.subplots(figsize = (14,8)) 103 | # scatterplot of the reduced data 104 | ax.scatter(x=reduced_data.loc[:, 'Dimension 1'], y=reduced_data.loc[:, 'Dimension 2'], 105 | facecolors='b', edgecolors='b', s=70, alpha=0.5) 106 | 107 | feature_vectors = pca.components_.T 108 | 109 | # we use scaling factors to make the arrows easier to see 110 | arrow_size, text_pos = 7.0, 8.0, 111 | 112 | # projections of the original features 113 | for i, v in enumerate(feature_vectors): 114 | ax.arrow(0, 0, arrow_size*v[0], arrow_size*v[1], 115 | head_width=0.2, head_length=0.2, linewidth=2, color='red') 116 | ax.text(v[0]*text_pos, v[1]*text_pos, good_data.columns[i], color='black', 117 | ha='center', va='center', fontsize=18) 118 | 119 | ax.set_xlabel("Dimension 1", fontsize=14) 120 | ax.set_ylabel("Dimension 2", fontsize=14) 121 | ax.set_title("PC plane with original feature projections.", fontsize=16); 122 | return ax 123 | 124 | 125 | def channel_results(reduced_data, outliers, pca_samples): 126 | ''' 127 | Visualizes the PCA-reduced cluster data in two dimensions using the full dataset 128 | Data is labeled by "Channel" and cues added for student-selected sample data 129 | ''' 130 | 131 | # Check that the dataset is loadable 132 | try: 133 | full_data = pd.read_csv("customers.csv") 134 | except: 135 | print "Dataset could not be loaded. Is the file missing?" 
136 | return False 137 | 138 | # Create the Channel DataFrame 139 | channel = pd.DataFrame(full_data['Channel'], columns = ['Channel']) 140 | channel = channel.drop(channel.index[outliers]).reset_index(drop = True) 141 | labeled = pd.concat([reduced_data, channel], axis = 1) 142 | 143 | # Generate the cluster plot 144 | fig, ax = plt.subplots(figsize = (14,8)) 145 | 146 | # Color map 147 | cmap = cm.get_cmap('gist_rainbow') 148 | 149 | # Color the points based on assigned Channel 150 | labels = ['Hotel/Restaurant/Cafe', 'Retailer'] 151 | grouped = labeled.groupby('Channel') 152 | for i, channel in grouped: 153 | channel.plot(ax = ax, kind = 'scatter', x = 'Dimension 1', y = 'Dimension 2', \ 154 | color = cmap((i-1)*1.0/2), label = labels[i-1], s=30); 155 | 156 | # Plot transformed sample points 157 | for i, sample in enumerate(pca_samples): 158 | ax.scatter(x = sample[0], y = sample[1], \ 159 | s = 200, linewidth = 3, color = 'black', marker = 'o', facecolors = 'none'); 160 | ax.scatter(x = sample[0]+0.25, y = sample[1]+0.3, marker='$%d$'%(i), alpha = 1, s=125); 161 | 162 | # Set plot title 163 | ax.set_title("PCA-Reduced Data Labeled by 'Channel'\nTransformed Sample Data Circled"); -------------------------------------------------------------------------------- /digit_recognition/README.md: -------------------------------------------------------------------------------- 1 | # Machine Learning Engineer Nanodegree 2 | # Deep Learning 3 | ## Project: Build a Digit Recognition Program 4 | 5 | ### Install 6 | 7 | This project requires **Python 2.7** and the following Python packages: 8 | 9 | - [NumPy](http://www.numpy.org/) 10 | - [SciPy](https://www.scipy.org/) 11 | - [scikit-learn](http://scikit-learn.org/0.17/install.html) (v0.17) 12 | - [TensorFlow](http://tensorflow.org) 13 | 14 | You will also need the software installed to run [Jupyter Notebook](http://ipython.org/notebook.html).
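As an optional sanity check before opening the notebook (this snippet is not part of the project files), you can confirm that the required packages are importable:

```python
import importlib

def check_packages(names):
    """Return a dict mapping each package name to its version string,
    "unknown" if it imports but exposes no __version__, or None if it
    cannot be imported at all."""
    status = {}
    for name in names:
        try:
            module = importlib.import_module(name)
            status[name] = getattr(module, "__version__", "unknown")
        except ImportError:
            status[name] = None
    return status

if __name__ == "__main__":
    # The packages this project's README lists as requirements.
    for pkg, version in sorted(check_packages(["numpy", "scipy", "sklearn", "tensorflow"]).items()):
        print("%-12s %s" % (pkg, version or "MISSING"))
```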
15 | 16 | In addition to the above, if you wish to optionally use image-processing software, you may need to install one of the following: 17 | - [PyGame](http://pygame.org/) 18 | - Helpful links for installing PyGame: 19 | - [Getting Started](https://www.pygame.org/wiki/GettingStarted) 20 | - [PyGame Information](http://www.pygame.org/wiki/info) 21 | - [Google Group](https://groups.google.com/forum/#!forum/pygame-mirror-on-google-groups) 22 | - [PyGame subreddit](https://www.reddit.com/r/pygame/) 23 | - [OpenCV](http://opencv.org/) 24 | 25 | If you optionally choose to deploy the application as an Android app: 26 | - Android SDK & NDK (see this [README](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/README.md)) 27 | 28 | If you do not have Python installed yet, Udacity recommends installing [Anaconda](http://continuum.io/downloads), a pre-packaged Python distribution that contains all of the libraries and software needed for this project. Make sure you install Python 2.7 rather than Python 3.x. `pygame` and `OpenCV` can then be installed with the following commands: 29 | 30 | Mac: 31 | ```bash 32 | conda install -c https://conda.anaconda.org/quasiben pygame 33 | conda install -c menpo opencv=2.4.11 34 | ``` 35 | 36 | Windows & Linux: 37 | ```bash 38 | conda install -c https://conda.anaconda.org/tlatorre pygame 39 | conda install -c menpo opencv=2.4.11 40 | ``` 41 | 42 | ### Code 43 | 44 | The starter code is included in the `digit_recognition.ipynb` notebook file. No other code is provided: to complete the project, you will need to implement the basic functionality in the notebook and answer questions about your implementation and results. 45 | 46 | ### Run 47 | 48 | In a terminal or command window, make sure the current directory is the top level of the `digit_recognition/` folder (the one containing this README), then run one of the following commands: 49 | ```bash 50 | ipython notebook digit_recognition.ipynb 51 | ``` 52 | or 53 | ```bash 54 | jupyter notebook digit_recognition.ipynb 55 | ``` 56 | 57 | This will start Jupyter Notebook and open the project file in your browser. 58 | 59 | 60 | ### Data 61 | 62 | Because no code is provided for this project, you will need to download and use the [Street View House Numbers (SVHN) dataset](http://ufldl.stanford.edu/housenumbers/), together with either the [notMNIST](http://yaroslavvb.blogspot.com/2011/09/notmnist-dataset.html) or the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset. If you have already completed the course material, you will already have the **notMNIST** dataset. 63 | -------------------------------------------------------------------------------- /digit_recognition/model.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/digit_recognition/model.png -------------------------------------------------------------------------------- /digit_recognition/project_description.md: -------------------------------------------------------------------------------- 1 | # Content: Deep Learning 2 | ## Project: Build a Digit Recognition Program 3 | 4 | ## Project Overview 5 | 6 | In this project, you will use what you have learned about deep neural networks and convolutional neural networks to build a live camera application, or program, that prints in real time the digits it observes in the images it is given. You will first design and implement a deep learning model architecture that can recognize sequences of digits. Next, you will train that model so it can recognize sequences of digits in real-world images such as those in the [Street View House Numbers (SVHN) dataset](http://ufldl.stanford.edu/housenumbers/). After the model is trained, you will test it using a live camera application (optional) or on newly captured images. Finally, once you obtain meaningful results, you will refine your implementation to also *localize the digits in an image*, and test the localization on newly captured images. 7 | 8 | ## Software Requirements 9 | This project requires the following software and Python packages to be installed: 10 | 11 | - [Python 2.7](https://www.python.org/download/releases/2.7/) 12 | - [NumPy](http://www.numpy.org/) 13 | - [SciPy](https://www.scipy.org/) 14 | - [scikit-learn](http://scikit-learn.org/0.17/install.html) (v0.17) 15 | - [TensorFlow](http://tensorflow.org) 16 | 17 | You will also need the software installed to run [Jupyter Notebook](http://ipython.org/notebook.html).
18 | 19 | In addition to the above, if you wish to optionally use image-processing software, you may need to install one of the following: 20 | - [PyGame](http://pygame.org/) 21 | - Helpful links for installing PyGame: 22 | - [Getting Started](https://www.pygame.org/wiki/GettingStarted) 23 | - [PyGame Information](http://www.pygame.org/wiki/info) 24 | - [Google Group](https://groups.google.com/forum/#!forum/pygame-mirror-on-google-groups) 25 | - [PyGame subreddit](https://www.reddit.com/r/pygame/) 26 | - [OpenCV](http://opencv.org/) 27 | 28 | If you optionally choose to deploy the application as an Android app: 29 | - Android SDK & NDK (see this [README](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/README.md)) 30 | 31 | If you do not have Python installed yet, Udacity recommends installing [Anaconda](http://continuum.io/downloads), a pre-packaged Python distribution that contains all of the libraries and software needed for this project. Make sure you install Python 2.7 rather than Python 3.x. `pygame` and `OpenCV` can then be installed with the following commands: 32 | 33 | 34 | **OpenCV:** 35 | `conda install -c menpo opencv=2.4.11` 36 | 37 | **PyGame:** 38 | Mac: `conda install -c https://conda.anaconda.org/quasiben pygame` 39 | Windows: `conda install -c https://conda.anaconda.org/tlatorre pygame` 40 | Linux: `conda install -c https://conda.anaconda.org/prkrekel pygame` 41 | 42 | ## Getting Started 43 | 44 | For this assignment, you can find the downloadable `digit_recognition.zip` archive containing the necessary project files in the **Course Resources** section. *You may also visit our [Machine Learning projects GitHub](https://github.com/udacity/machine-learning) for all of the projects in this Nanodegree.* 45 | 46 | - `digit_recognition.ipynb`: this is the main file you will be modifying. 47 | 48 | In addition, you will need to download and use the [Street View House Numbers (SVHN) dataset](http://ufldl.stanford.edu/housenumbers/), together with either the [notMNIST](http://yaroslavvb.blogspot.com/2011/09/notmnist-dataset.html) or the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset. If you have already completed the course material, you will already have the **notMNIST** dataset. 49 | 50 | In a terminal or command prompt, navigate to the folder containing the project files and use the command `jupyter notebook digit_recognition.ipynb` to open the notebook in a browser window or tab. Alternatively, you can use the command `jupyter notebook` or `ipython notebook` and navigate to the notebook file in the browser window that opens. To complete the project, follow the instructions in the notebook and answer each of the questions it asks. A **README** file is also provided with the project files and contains additional necessary information and instructions for the project. 51 | 52 | ## Tasks 53 | 54 | ### Project Report 55 |
As part of your submitted `digit_recognition.ipynb`, you will need to answer questions about your implementation. As you complete the tasks below, include thorough, detailed answers to each question (shown below in *italics*). 56 | 57 | ### Step 1: Design and Test a Model Architecture 58 | Design and implement a deep learning model that can recognize sequences of digits. You can train it on synthetic data generated by concatenating character images from [notMNIST](http://yaroslavvb.blogspot.com/2011/09/notmnist-dataset.html) or [MNIST](http://yann.lecun.com/exdb/mnist/). To produce synthetic digit sequences for testing, you could, for example, limit a sequence to at most five digits and use five classifiers on top of your deep network. You will also want to prepare an additional "blank" character to handle shorter sequences. 59 | 60 | There are various aspects to consider when thinking about this problem: 61 | - Your model could be based on a deep neural network or a convolutional neural network. 62 | - You could experiment with whether or not to share weights between the softmax classifiers. 63 | - You could also replace the classification layers of your deep network with a recurrent network, and output the digits of the sequence one at a time. 64 | 65 | Here is an example of a [published baseline model on this problem](http://static.googleusercontent.com/media/research.google.com/en//pubs/archive/42241.pdf) ([video](https://www.youtube.com/watch?v=vGPI_JvLoN0)). 66 | 67 | ***Question*** 68 | _What approach did you take in coming up with a solution to this problem?_ 69 | 70 | ***Question*** 71 | _What does your final architecture look like? (Type of model, layers, sizes, connectivity, etc.)_ 72 | 73 | ***Question*** 74 | _How did you train your model? How did you generate your synthetic dataset?_ Include a few example images from the synthetic data you created. 75 | 76 | ### Step 2: Train a Model on a Realistic Dataset 77 | 78 | Once you have settled on a good architecture, you can train your model on real data. In particular, the [Street View House Numbers (SVHN)](http://ufldl.stanford.edu/housenumbers/) dataset is a large-scale dataset of house numbers collected from Google Street View. Training on this more challenging dataset, where the digits are not neatly lined up and come in a variety of skews, fonts, and colors, will likely mean you have to do some hyperparameter exploration to obtain a well-performing model. 79 | 80 | ***Question*** 81 | _Describe how you set up the training and testing data for your model. How does the model perform on a realistic dataset?_ 82 | 83 | ***Question*** 84 | _What changes did you make to your model, if any? If you made changes, did they produce a "good" result? Was there anything you explored that made the results worse?_ 85 | 86 | ***Question*** 87 | _What were your initial and final results when testing on the realistic dataset? Do you think your model is doing a good enough job at classifying digits correctly?_ 88 | 89 | ### Step 3: Test the Model on Newly Captured Images 90 | 91 | Take several photos of numbers around you (at least five), and run them through your classifier to produce predictions. Alternatively (and optionally), you can try using OpenCV / SimpleCV / Pygame to capture live images from a webcam and run those through your classifier. 92 | 93 | ***Question*** 94 | _Choose five candidate images of numbers you took around you and provide them in the report. Do any of them have particular qualities that might make classification difficult?_ 95 | 96 | ***Question*** 97 | _Is your model able to perform equally well on captured pictures or a live camera stream when compared to its performance on the realistic dataset?_ 98 | 99 | ***Question*** 100 | _If necessary, document how you built the interface that allows your model to load and classify newly acquired images._ 101 | 102 | ### Step 4: Explore One Way to Improve the Model 103 | 104 | Once your basic classifier is trained, there is a lot you can do. One example is to also localize where the digits are in the image. The SVHN dataset provides bounding boxes that you can tune to train a localizer. Train a regression loss on the coordinates of the bounding boxes, and then test it. 105 | 106 | ***Question*** 107 | _How well does your model localize numbers on the realistic test set? Does the classification result change at all when location information is included?_ 108 | 109 | ***Question*** 110 |
Test your localization function on the images you captured in **Step 3**. Does the model accurately compute bounding boxes for the numbers in the images you found? If you did not use a graphical interface, you may need to investigate the bounding boxes by hand. _Provide an example of the bounding boxes created on a captured image._ 111 | 112 | ## Optional Step 5: Wrap an Application or Program Around the Model 113 | 114 | Take your project one step further. If you are interested, build an Android application, or a more robust Python program, that can interact with input images and display the classified numbers and even the bounding boxes. For example, you could try to build an augmented-reality app by overlaying your answers on the image, the way the [Word Lens](https://en.wikipedia.org/wiki/Word_Lens) application does. 115 | 116 | Sample code for loading a TensorFlow model into a camera app on Android is available in the [TensorFlow Android demo app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android), which you can adapt with some simple modifications. 117 | 118 | 119 | If you decide to explore this optional path, be sure to document your interface and implementation, along with any significant results you find. You can see the criteria that will be used to assess your work by [following this link](https://review.udacity.com/#!/rubrics/413/view). 120 | 121 | ## Submitting the Project 122 | 123 | ### Evaluation 124 | Your project will be reviewed by a Udacity reviewer against the **Build a Digit Recognition Program project rubric**. Be sure to review this rubric carefully and do a thorough self-evaluation before submission. All criteria found in the rubric must be marked as *meeting specifications* in order for you to pass. 125 | 126 | ### Submission Files 127 | When you are ready to submit your project, collect the following files and compress them into a single archive for upload. Alternatively, you may supply the following files in a folder named `digit_recognition` in your GitHub repo for ease of access: 128 | - The `digit_recognition.ipynb` notebook file with all questions answered and all code cells executed and displaying output. 129 | - An **HTML** export of the project notebook named **report.html**. This file *must* be included with your submission. 130 | - Any additional datasets or images used for the project that are not SVHN, notMNIST, or MNIST. 131 | 132 | - For the optional image-recognition software portion, any relevant Python files needed to run your code. 133 | - For the optional Android application portion, documentation on how to obtain and use the app, provided as a PDF report named **documentation.pdf**. 134 | 135 | Once you have collected these files and reviewed the project rubric, proceed to the project submission page. 136 | 137 | ### I'm Ready! 138 | When you're ready to submit your project, click on the **Submit Project** button at the bottom of the page. 139 | 140 | If you are having any problems submitting your project or wish to check on the status of your submission, please email us at **machine-support@udacity.com** or visit the forums. 141 | 142 | ### What's Next?
143 | You will get an email as soon as your reviewer has feedback for you. In the meantime, you can start working on your next project or enroll in a related course. 144 | -------------------------------------------------------------------------------- /dog_vs_cat/README.md: -------------------------------------------------------------------------------- 1 | # Dogs vs. Cats 2 | - This project was trained on the AWS udacity-carND instance. 3 | - Training one model takes about 75 minutes. 4 | - The data used is the Dogs vs. Cats data from Kaggle. 5 | - The data-processing and model-training code are both in dog_vs_cat.ipynb. 6 | - The file.csv files contain the predictions submitted to Kaggle for evaluation. 7 | 8 | # model.h5 9 | https://drive.google.com/open?id=0B_jeca3axBM9dkdheXJaWHoxbE0 10 | # model_2.h5 11 | https://drive.google.com/open?id=0B_jeca3axBM9V2lrRXBPX0pYQTA 12 | # model_3.h5 13 | https://drive.google.com/open?id=0B_jeca3axBM9ZkdoTFItZllBQ1k -------------------------------------------------------------------------------- /dog_vs_cat/proposal.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/dog_vs_cat/proposal.pdf -------------------------------------------------------------------------------- /dog_vs_cat/report2.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/dog_vs_cat/report2.pdf -------------------------------------------------------------------------------- /finding_donors/README.md: -------------------------------------------------------------------------------- 1 | # Machine Learning Nanodegree 2 | # Supervised Learning 3 | ## Project: Finding Donors for CharityML 4 | ### Install 5 | 6 | This project requires **Python 2.7** and the following Python packages: 7 | 8 | - [Python 2.7](https://www.python.org/download/releases/2.7/) 9 | - [NumPy](http://www.numpy.org/) 10 | - [Pandas](http://pandas.pydata.org/) 11 | - [scikit-learn](http://scikit-learn.org/stable/) 12 | - [matplotlib](http://matplotlib.org/) 13 | 14 | You will also need the software installed to run [iPython Notebook](http://ipython.org/notebook.html) 15 | 16 |
Udacity recommends installing [Anaconda](https://www.continuum.io/downloads), a pre-packaged Python distribution that contains all of the libraries and software needed for this project. 17 | 18 | ### Code 19 | 20 | Starter code is provided in the `finding_donors.ipynb` notebook file. You will also need the `visuals.py` file and the `census.csv` dataset file to complete the project. Some code has already been implemented to get you started, but you will need to implement additional functionality to successfully complete the project. 21 | 22 | Note that the code included in `visuals.py` is meant to be imported and used as-is, not to be modified by students. If you are interested in how the visualizations in the notebook are created, feel free to explore this file. 23 | 24 | 25 | ### Run 26 | In a terminal or command window, make sure the current directory is the top level of the `finding_donors/` folder (the one containing this README), then run the following command: 27 | 28 | ```bash 29 | jupyter notebook finding_donors.ipynb 30 | ``` 31 | 32 | This will start Jupyter Notebook and open the project file in your browser. 33 | 34 | ### Data 35 | 36 | The modified census dataset consists of approximately 32,000 data points, with each data point having 13 features. This dataset is a modified version of the one published in the paper *"Scaling Up the Accuracy of Naive-Bayes Classifiers: a Decision-Tree Hybrid"* by Ron Kohavi. You can find the paper [here](https://www.aaai.org/Papers/KDD/1996/KDD96-033.pdf) and the original dataset on the [UCI website](https://archive.ics.uci.edu/ml/datasets/Census+Income). 37 | 38 | **Features** 39 | 40 | - `age`: an integer describing the individual's age. 41 | - `workclass`: a categorical variable describing the individual's usual class of employment; allowed values are {Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked} 42 | - `education_level`: a categorical variable describing the level of education; allowed values are {Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool} 43 | - `education-num`: an integer describing the number of years of schooling completed 44 | - `marital-status`: a categorical variable; allowed values are {Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse} 45 | - `occupation`: a categorical variable describing the general field of occupation; allowed values are {Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces} 46 | - `relationship`: a categorical variable describing the individual's role in the household; allowed values are {Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried} 47 | - `race`: a categorical variable describing the individual's race; allowed values are {White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black} 48 | - `sex`: a categorical variable describing the individual's sex; allowed values are
{Female, Male} 49 | - `capital-gain`: a continuous value. 50 | - `capital-loss`: a continuous value. 51 | - `hours-per-week`: a continuous value. 52 | - `native-country`: a categorical variable describing the country of origin; allowed values are {United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands} 53 | 54 | **Target Variable** 55 | 56 | - `income`: a categorical variable describing which income bracket the individual falls into; allowed values are {<=50K, >50K} -------------------------------------------------------------------------------- /finding_donors/finding_donors.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/finding_donors/finding_donors.zip -------------------------------------------------------------------------------- /finding_donors/project_description.md: -------------------------------------------------------------------------------- 1 | # Content: Supervised Learning 2 | ## Project: Finding Donors for CharityML 3 | 4 | ## Project Overview 5 | In this project, you will apply supervised learning techniques and your analytical skills to U.S. census data to help CharityML (a fictitious charity organization) identify the people most likely to donate to their cause. You will first explore the data to learn how the census data is recorded. Next, you will apply a series of transformations and preprocessing techniques to manipulate the data into a workable format. You will then evaluate several supervised learners of your choice on the data, and consider which is best suited for the solution. Afterwards, you will optimize the model you have selected for CharityML. Finally, you will explore the chosen model and its predictive power. 6 | 7 | ## Project Highlights 8 | This project is designed to get you acquainted with the many supervised learning algorithms available in sklearn, and to provide a method of evaluating how models perform on a certain type of data. It is important in machine learning to understand exactly when and where a certain algorithm should and should not be used. 9 | 10 | Things you will learn by completing this project: 11 | - How to identify when preprocessing is needed, and how to apply it. 12 | - How to establish a benchmark for a solution to the problem. 13 | - How several supervised learning algorithms perform on a specific dataset. 14 | - How to investigate whether a candidate solution model is adequate for the problem. 15 | 16 | ## Software Requirements 17 | 18 | This project requires **Python 2.7** and the following Python packages: 19 | 20 | - [Python 2.7](https://www.python.org/download/releases/2.7/) 21 | - [NumPy](http://www.numpy.org/) 22 | - [Pandas](http://pandas.pydata.org/) 23 | -
[scikit-learn](http://scikit-learn.org/stable/) 24 | - [matplotlib](http://matplotlib.org/) 25 | 26 | You will also need the software installed to run [iPython Notebook](http://ipython.org/notebook.html) 27 | 28 | Udacity recommends installing [Anaconda](https://www.continuum.io/downloads), a pre-packaged Python distribution that contains all of the libraries and software needed for this project. Please make sure that you install version 2.7 and not 3.x. 29 | 30 | ## Getting Started 31 | 32 | For this project, you can find the downloadable `find_donors.zip` in the **Resources** section. *You may also visit our [Machine Learning projects GitHub](https://github.com/udacity/machine-learning) for all of the projects in this Nanodegree.* 33 | 34 | This project contains the following files: 35 | 36 | - `find_donors.ipynb`: this is the main file you will be working in. 37 | - `census.csv`: the project dataset. You will need to load this dataset in the notebook. 38 | - `visuals.py`: a Python file implementing the visualization functions. Do not modify it. 39 | 40 | In a terminal or command prompt, navigate to the folder containing the project files and use the command `jupyter notebook finding_donors.ipynb` to open the notebook in a browser window or tab. Alternatively, you can use the command `jupyter notebook` or `ipython notebook` and navigate to the required folder in the browser window that opens. Follow the instructions in the notebook and answer each question to successfully complete the project. A **README** file is also provided with the project, containing additional information or instructions you may need. 41 | 42 | ## Submitting the Project 43 | 44 | ### Evaluation 45 | Your project will be reviewed by a Udacity reviewer against the **Finding Donors for CharityML project rubric**. Be sure to review this rubric carefully and do a thorough self-evaluation before submission. All criteria found in the rubric must be marked as *meeting specifications* in order for you to pass. 46 | 47 | ### Submission Files 48 | When you are ready to submit your project, collect the following files and compress them into a single archive for upload. Alternatively, you may supply the following files in a folder named `finding_donors` in your GitHub repo for ease of access: 49 | - The `finding_donors.ipynb` notebook file with all questions answered and all code cells executed and displaying output. 50 | - An **HTML** export of the project notebook named **report.html**. This file *must* be included with your submission. 51 | 52 | Once you have collected these files and reviewed the project rubric, proceed to the project submission page. 53 | 54 | ### I'm Ready! 55 | When you're ready to submit your project, click on the **Submit Project** button at the bottom of the page. 56 | 57 | If you are having any problems submitting your project or wish to check on the status of your submission, please email us at **machine-support@udacity.com** or visit the forums. 58 | 59 | ### What's Next?
60 | You will get a notification email as soon as your project reviewer has feedback for you. In the meantime, you can get started on your next project or study the related courses. -------------------------------------------------------------------------------- /finding_donors/visuals.py: -------------------------------------------------------------------------------- 1 | ########################################### 2 | # Suppress matplotlib user warnings 3 | # Necessary for newer version of matplotlib 4 | import warnings 5 | warnings.filterwarnings("ignore", category = UserWarning, module = "matplotlib") 6 | # 7 | # Display inline matplotlib plots with IPython 8 | from IPython import get_ipython 9 | get_ipython().run_line_magic('matplotlib', 'inline') 10 | ########################################### 11 | 12 | import matplotlib.pyplot as pl 13 | import matplotlib.patches as mpatches 14 | import numpy as np 15 | import pandas as pd 16 | from time import time 17 | from sklearn.metrics import f1_score, accuracy_score 18 | 19 | 20 | def distribution(data, transformed = False): 21 | """ 22 | Visualization code for displaying skewed distributions of features 23 | """ 24 | 25 | # Create figure 26 | fig = pl.figure(figsize = (11,5)); 27 | 28 | # Skewed feature plotting 29 | for i, feature in enumerate(['capital-gain','capital-loss']): 30 | ax = fig.add_subplot(1, 2, i+1) 31 | ax.hist(data[feature], bins = 25, color = '#00A0A0') 32 | ax.set_title("'%s' Feature Distribution"%(feature), fontsize = 14) 33 | ax.set_xlabel("Value") 34 | ax.set_ylabel("Number of Records") 35 | ax.set_ylim((0, 2000)) 36 | ax.set_yticks([0, 500, 1000, 1500, 2000]) 37 | ax.set_yticklabels([0, 500, 1000, 1500, ">2000"]) 38 | 39 | # Plot aesthetics 40 | if transformed: 41 | fig.suptitle("Log-transformed Distributions of Continuous Census Data Features", \ 42 | fontsize = 16, y = 1.03) 43 | else: 44 | fig.suptitle("Skewed Distributions of Continuous Census Data Features", \ 45 | fontsize = 16, y = 1.03) 46 | 47 | fig.tight_layout() 48 | fig.show() 49 | 50 | 51 | def evaluate(results, accuracy, f1): 52 | """ 53 |
Visualization code to display results of various learners. 54 | 55 | inputs: 56 | - results: a dictionary mapping each learner's name to a list of 57 | dictionaries of the statistic results from 'train_predict()' 58 | - accuracy: The accuracy score for the naive predictor 59 | - f1: The F-score for the naive predictor 60 | """ 61 | 62 | # Create figure 63 | fig, ax = pl.subplots(2, 3, figsize = (11,7)) 64 | 65 | # Constants 66 | bar_width = 0.3 67 | colors = ['#A00000','#00A0A0','#00A000'] 68 | 69 | # Super loop to plot six panels of data 70 | for k, learner in enumerate(results.keys()): 71 | for j, metric in enumerate(['train_time', 'acc_train', 'f_train', 'pred_time', 'acc_test', 'f_test']): 72 | for i in np.arange(3): 73 | 74 | # Creative plot code 75 | ax[j/3, j%3].bar(i+k*bar_width, results[learner][i][metric], width = bar_width, color = colors[k]) 76 | ax[j/3, j%3].set_xticks([0.45, 1.45, 2.45]) 77 | ax[j/3, j%3].set_xticklabels(["1%", "10%", "100%"]) 78 | ax[j/3, j%3].set_xlabel("Training Set Size") 79 | ax[j/3, j%3].set_xlim((-0.1, 3.0)) 80 | 81 | # Add unique y-labels 82 | ax[0, 0].set_ylabel("Time (in seconds)") 83 | ax[0, 1].set_ylabel("Accuracy Score") 84 | ax[0, 2].set_ylabel("F-score") 85 | ax[1, 0].set_ylabel("Time (in seconds)") 86 | ax[1, 1].set_ylabel("Accuracy Score") 87 | ax[1, 2].set_ylabel("F-score") 88 | 89 | # Add titles 90 | ax[0, 0].set_title("Model Training") 91 | ax[0, 1].set_title("Accuracy Score on Training Subset") 92 | ax[0, 2].set_title("F-score on Training Subset") 93 | ax[1, 0].set_title("Model Predicting") 94 | ax[1, 1].set_title("Accuracy Score on Testing Set") 95 | ax[1, 2].set_title("F-score on Testing Set") 96 | 97 | # Add horizontal lines for naive predictors 98 | ax[0, 1].axhline(y = accuracy, xmin = -0.1, xmax = 3.0, linewidth = 1, color = 'k', linestyle = 'dashed') 99 | ax[1, 1].axhline(y = accuracy, xmin = -0.1, xmax = 3.0, linewidth = 1, color = 'k', linestyle = 'dashed') 100 | ax[0, 2].axhline(y = f1, xmin = -0.1, xmax = 3.0,
linewidth = 1, color = 'k', linestyle = 'dashed') 101 | ax[1, 2].axhline(y = f1, xmin = -0.1, xmax = 3.0, linewidth = 1, color = 'k', linestyle = 'dashed') 102 | 103 | # Set y-limits for score panels 104 | ax[0, 1].set_ylim((0, 1)) 105 | ax[0, 2].set_ylim((0, 1)) 106 | ax[1, 1].set_ylim((0, 1)) 107 | ax[1, 2].set_ylim((0, 1)) 108 | 109 | # Create patches for the legend 110 | patches = [] 111 | for i, learner in enumerate(results.keys()): 112 | patches.append(mpatches.Patch(color = colors[i], label = learner)) 113 | pl.legend(handles = patches, bbox_to_anchor = (-.80, 2.53), \ 114 | loc = 'upper center', borderaxespad = 0., ncol = 3, fontsize = 'x-large') 115 | 116 | # Aesthetics 117 | pl.suptitle("Performance Metrics for Three Supervised Learning Models", fontsize = 16, y = 1.10) 118 | pl.tight_layout() 119 | pl.show() 120 | 121 | 122 | def feature_plot(importances, X_train, y_train): 123 | 124 | # Display the five most important features 125 | indices = np.argsort(importances)[::-1] 126 | columns = X_train.columns.values[indices[:5]] 127 | values = importances[indices][:5] 128 | 129 | # Create the plot 130 | fig = pl.figure(figsize = (9,5)) 131 | pl.title("Normalized Weights for First Five Most Predictive Features", fontsize = 16) 132 | pl.bar(np.arange(5), values, width = 0.6, align="center", color = '#00A000', \ 133 | label = "Feature Weight") 134 | pl.bar(np.arange(5) - 0.3, np.cumsum(values), width = 0.2, align = "center", color = '#00A0A0', \ 135 | label = "Cumulative Feature Weight") 136 | pl.xticks(np.arange(5), columns) 137 | pl.xlim((-0.5, 4.5)) 138 | pl.ylabel("Weight", fontsize = 12) 139 | pl.xlabel("Feature", fontsize = 12) 140 | 141 | pl.legend(loc = 'upper center') 142 | pl.tight_layout() 143 | pl.show() 144 | -------------------------------------------------------------------------------- /smartcab/README.md: -------------------------------------------------------------------------------- 1 | # Machine Learning Engineer Nanodegree 2 | # Reinforcement
Learning 3 | ## Project: Train a Smartcab How to Drive 4 | 5 | ### Install 6 | 7 | This project requires **Python 2.7** with the [pygame](https://www.pygame.org/wiki/GettingStarted) library installed. 8 | 9 | 10 | ### Code 11 | 12 | Template code is provided in the `smartcab/agent.py` Python file. Additional supporting Python code can be found in `smartcab/environment.py`, `smartcab/planner.py`, and `smartcab/simulator.py`. Supporting images for the graphical user interface can be found in the `images` folder. While some code has already been implemented to get you started, you will need to implement additional functionality for the `LearningAgent` class in `agent.py` when requested to successfully complete the project. 13 | 14 | ### Run 15 | 16 | In a terminal or command window, navigate to the top-level project directory `smartcab/` (that contains this README) and run one of the following commands: 17 | 18 | ```python smartcab/agent.py``` 19 | ```python -m smartcab.agent``` 20 | 21 | This will run the `agent.py` file and execute your agent code.
22 | -------------------------------------------------------------------------------- /smartcab/agent.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/agent.py -------------------------------------------------------------------------------- /smartcab/images/car-black.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-black.png -------------------------------------------------------------------------------- /smartcab/images/car-blue.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-blue.png -------------------------------------------------------------------------------- /smartcab/images/car-cyan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-cyan.png -------------------------------------------------------------------------------- /smartcab/images/car-green.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-green.png -------------------------------------------------------------------------------- /smartcab/images/car-magenta.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-magenta.png -------------------------------------------------------------------------------- /smartcab/images/car-orange.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-orange.png -------------------------------------------------------------------------------- /smartcab/images/car-red.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-red.png -------------------------------------------------------------------------------- /smartcab/images/car-white.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-white.png -------------------------------------------------------------------------------- /smartcab/images/car-yellow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-yellow.png -------------------------------------------------------------------------------- /smartcab/images/east-west.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/east-west.png -------------------------------------------------------------------------------- /smartcab/images/logo.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/logo.png -------------------------------------------------------------------------------- /smartcab/images/north-south.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/north-south.png -------------------------------------------------------------------------------- /smartcab/logs/sim_default-learning.csv: -------------------------------------------------------------------------------- 1 | trial,testing,parameters,initial_deadline,final_deadline,net_reward,actions,success 2 | 1,False,"{'a': 0.5, 'e': 0.95}",30,0,-141.4152219254666,"{0: 21, 1: 2, 2: 3, 3: 2, 4: 2}",0 3 | 2,False,"{'a': 0.5, 'e': 0.8999999999999999}",30,0,-118.74810356680284,"{0: 20, 1: 3, 2: 5, 3: 0, 4: 2}",0 4 | 3,False,"{'a': 0.5, 'e': 0.8499999999999999}",25,0,-134.30063983462142,"{0: 14, 1: 2, 2: 6, 3: 2, 4: 1}",0 5 | 4,False,"{'a': 0.5, 'e': 0.7999999999999998}",25,0,-122.97600732528981,"{0: 17, 1: 2, 2: 3, 3: 1, 4: 2}",0 6 | 5,False,"{'a': 0.5, 'e': 0.7499999999999998}",20,2,-9.639186715051562,"{0: 15, 1: 0, 2: 3, 3: 0, 4: 0}",1 7 | 6,False,"{'a': 0.5, 'e': 0.6999999999999997}",20,12,2.7898590683150184,"{0: 7, 1: 0, 2: 1, 3: 0, 4: 0}",1 8 | 7,False,"{'a': 0.5, 'e': 0.6499999999999997}",20,0,-32.30763679982498,"{0: 14, 1: 2, 2: 4, 3: 0, 4: 0}",0 9 | 8,False,"{'a': 0.5, 'e': 0.5999999999999996}",20,0,-80.37939015975518,"{0: 13, 1: 2, 2: 4, 3: 0, 4: 1}",0 10 | 9,False,"{'a': 0.5, 'e': 0.5499999999999996}",20,0,-61.96798385294584,"{0: 13, 1: 4, 2: 2, 3: 0, 4: 1}",0 11 | 10,False,"{'a': 0.5, 'e': 0.4999999999999996}",25,0,-84.82140927085712,"{0: 19, 1: 1, 2: 3, 3: 0, 4: 2}",0 12 | 11,False,"{'a': 0.5, 'e': 
0.4499999999999996}",25,0,-25.103553666871605,"{0: 19, 1: 3, 2: 3, 3: 0, 4: 0}",1 13 | 12,False,"{'a': 0.5, 'e': 0.39999999999999963}",20,0,-28.16709246761365,"{0: 17, 1: 1, 2: 1, 3: 0, 4: 1}",0 14 | 13,False,"{'a': 0.5, 'e': 0.34999999999999964}",20,5,-29.19961200493926,"{0: 13, 1: 0, 2: 1, 3: 0, 4: 1}",1 15 | 14,False,"{'a': 0.5, 'e': 0.29999999999999966}",20,0,-2.918593524608614,"{0: 16, 1: 3, 2: 1, 3: 0, 4: 0}",0 16 | 15,False,"{'a': 0.5, 'e': 0.24999999999999967}",25,12,21.289498513974756,"{0: 13, 1: 0, 2: 0, 3: 0, 4: 0}",1 17 | 16,False,"{'a': 0.5, 'e': 0.19999999999999968}",20,0,-50.21845680672953,"{0: 18, 1: 0, 2: 0, 3: 0, 4: 2}",0 18 | 17,False,"{'a': 0.5, 'e': 0.1499999999999997}",20,16,10.304394010758255,"{0: 4, 1: 0, 2: 0, 3: 0, 4: 0}",1 19 | 18,False,"{'a': 0.5, 'e': 0.09999999999999969}",20,10,20.80318001076089,"{0: 10, 1: 0, 2: 0, 3: 0, 4: 0}",1 20 | 19,False,"{'a': 0.5, 'e': 0.049999999999999684}",25,7,12.021487368019777,"{0: 17, 1: 0, 2: 0, 3: 1, 4: 0}",1 21 | 20,False,"{'a': 0.5, 'e': -3.191891195797325e-16}",25,13,22.44349872235844,"{0: 12, 1: 0, 2: 0, 3: 0, 4: 0}",1 22 | 1,True,"{'a': 0, 'e': 0}",25,13,23.591695430775612,"{0: 12, 1: 0, 2: 0, 3: 0, 4: 0}",1 23 | 2,True,"{'a': 0, 'e': 0}",30,12,31.876211003456852,"{0: 18, 1: 0, 2: 0, 3: 0, 4: 0}",1 24 | 3,True,"{'a': 0, 'e': 0}",25,0,-25.997937197119185,"{0: 23, 1: 0, 2: 0, 3: 1, 4: 1}",0 25 | 4,True,"{'a': 0, 'e': 0}",20,9,19.688356652330302,"{0: 11, 1: 0, 2: 0, 3: 0, 4: 0}",1 26 | 5,True,"{'a': 0, 'e': 0}",25,12,25.61534719374939,"{0: 13, 1: 0, 2: 0, 3: 0, 4: 0}",1 27 | 6,True,"{'a': 0, 'e': 0}",35,12,44.99543335703007,"{0: 23, 1: 0, 2: 0, 3: 0, 4: 0}",1 28 | 7,True,"{'a': 0, 'e': 0}",25,15,18.10519012016503,"{0: 10, 1: 0, 2: 0, 3: 0, 4: 0}",1 29 | 8,True,"{'a': 0, 'e': 0}",25,9,6.303381484449328,"{0: 15, 1: 0, 2: 0, 3: 1, 4: 0}",1 30 | 9,True,"{'a': 0, 'e': 0}",25,15,22.30262088856527,"{0: 10, 1: 0, 2: 0, 3: 0, 4: 0}",1 31 | 10,True,"{'a': 0, 'e': 0}",25,3,31.601188162330747,"{0: 22, 1: 0, 2: 
0, 3: 0, 4: 0}",1 32 | -------------------------------------------------------------------------------- /smartcab/logs/sim_default-learning.txt: -------------------------------------------------------------------------------- 1 | /----------------------------------------- 2 | | State-action rewards from Q-Learning 3 | \----------------------------------------- 4 | 5 | ('right', 'green', False, False) 6 | -- forward : 0.90 7 | -- None : -3.63 8 | -- right : 1.72 9 | -- left : 0.49 10 | 11 | ('left', 'red', False, False) 12 | -- forward : 0.00 13 | -- None : 1.83 14 | -- right : 1.09 15 | -- left : -10.26 16 | 17 | ('right', 'green', True, False) 18 | -- forward : 0.28 19 | -- None : -3.54 20 | -- right : 1.72 21 | -- left : -9.57 22 | 23 | ('left', 'red', True, False) 24 | -- forward : -8.90 25 | -- None : 1.10 26 | -- right : 0.36 27 | -- left : 0.00 28 | 29 | ('left', 'green', True, True) 30 | -- forward : 0.00 31 | -- None : 0.00 32 | -- right : 0.06 33 | -- left : 0.00 34 | 35 | ('forward', 'green', True, True) 36 | -- forward : 0.00 37 | -- None : 0.00 38 | -- right : 0.15 39 | -- left : 0.00 40 | 41 | ('forward', 'red', False, True) 42 | -- forward : 0.00 43 | -- None : 1.73 44 | -- right : -9.64 45 | -- left : -29.85 46 | 47 | ('right', 'green', False, True) 48 | -- forward : 0.70 49 | -- None : 0.00 50 | -- right : 0.00 51 | -- left : 0.05 52 | 53 | ('right', 'red', True, True) 54 | -- forward : 0.00 55 | -- None : 2.08 56 | -- right : 0.00 57 | -- left : -19.85 58 | 59 | ('left', 'red', False, True) 60 | -- forward : -34.97 61 | -- None : 2.05 62 | -- right : -4.45 63 | -- left : -29.98 64 | 65 | ('left', 'green', False, False) 66 | -- forward : 0.59 67 | -- None : -3.26 68 | -- right : 1.20 69 | -- left : 2.29 70 | 71 | ('left', 'green', True, False) 72 | -- forward : 0.00 73 | -- None : -5.45 74 | -- right : 0.96 75 | -- left : 0.00 76 | 77 | ('forward', 'red', True, False) 78 | -- forward : -5.24 79 | -- None : 1.51 80 | -- right : 0.90 81 | -- left : 
0.00 82 | 83 | ('forward', 'red', False, False) 84 | -- forward : -9.56 85 | -- None : 1.76 86 | -- right : 1.21 87 | -- left : -9.67 88 | 89 | ('forward', 'green', False, True) 90 | -- forward : 1.04 91 | -- None : -2.25 92 | -- right : 0.52 93 | -- left : 0.53 94 | 95 | ('right', 'red', False, False) 96 | -- forward : -9.61 97 | -- None : 2.06 98 | -- right : 0.96 99 | -- left : -10.42 100 | 101 | ('right', 'red', True, False) 102 | -- forward : -5.00 103 | -- None : 0.00 104 | -- right : 1.60 105 | -- left : -4.65 106 | 107 | ('right', 'green', True, True) 108 | -- forward : 0.00 109 | -- None : 0.00 110 | -- right : 0.00 111 | -- left : -9.93 112 | 113 | ('left', 'green', False, True) 114 | -- forward : 0.26 115 | -- None : -3.80 116 | -- right : 0.28 117 | -- left : 0.95 118 | 119 | ('left', 'red', True, True) 120 | -- forward : -19.51 121 | -- None : 0.00 122 | -- right : 0.00 123 | -- left : 0.00 124 | 125 | ('forward', 'red', True, True) 126 | -- forward : -19.91 127 | -- None : 0.00 128 | -- right : 0.00 129 | -- left : -30.35 130 | 131 | ('forward', 'green', True, False) 132 | -- forward : 1.63 133 | -- None : -4.39 134 | -- right : 0.00 135 | -- left : -9.75 136 | 137 | ('forward', 'green', False, False) 138 | -- forward : 1.53 139 | -- None : -4.92 140 | -- right : 0.31 141 | -- left : 1.06 142 | 143 | ('right', 'red', False, True) 144 | -- forward : -34.78 145 | -- None : 0.63 146 | -- right : -9.45 147 | -- left : 0.00 148 | 149 | -------------------------------------------------------------------------------- /smartcab/logs/sim_no-learning.csv: -------------------------------------------------------------------------------- 1 | trial,testing,parameters,initial_deadline,final_deadline,net_reward,actions,success 2 | 1,False,"{'a': 0.5, 'e': 1.0}",20,0,-93.26123797920842,"{0: 12, 1: 2, 2: 4, 3: 1, 4: 1}",0 3 | 2,False,"{'a': 0.5, 'e': 1.0}",20,0,-93.77395441962415,"{0: 13, 1: 4, 2: 1, 3: 0, 4: 2}",0 4 | 3,False,"{'a': 0.5, 'e': 
1.0}",20,0,-113.51560440630828,"{0: 14, 1: 0, 2: 3, 3: 1, 4: 2}",0 5 | 4,False,"{'a': 0.5, 'e': 1.0}",25,7,-18.289133567400253,"{0: 15, 1: 1, 2: 1, 3: 1, 4: 0}",1 6 | 5,False,"{'a': 0.5, 'e': 1.0}",20,0,-92.99727611241221,"{0: 12, 1: 2, 2: 4, 3: 1, 4: 1}",0 7 | 6,False,"{'a': 0.5, 'e': 1.0}",30,0,-217.88590661448606,"{0: 14, 1: 3, 2: 10, 3: 0, 4: 3}",0 8 | 7,False,"{'a': 0.5, 'e': 1.0}",25,0,-84.28899817200508,"{0: 16, 1: 3, 2: 5, 3: 0, 4: 1}",0 9 | 8,False,"{'a': 0.5, 'e': 1.0}",20,0,-54.38409527191148,"{0: 11, 1: 5, 2: 4, 3: 0, 4: 0}",0 10 | 9,False,"{'a': 0.5, 'e': 1.0}",30,0,-160.13202618839966,"{0: 20, 1: 3, 2: 3, 3: 1, 4: 3}",0 11 | 10,False,"{'a': 0.5, 'e': 1.0}",25,0,-78.39060580203846,"{0: 20, 1: 1, 2: 1, 3: 2, 4: 1}",0 12 | 11,False,"{'a': 0.5, 'e': 1.0}",20,0,-32.28564143160174,"{0: 15, 1: 1, 2: 4, 3: 0, 4: 0}",0 13 | 12,False,"{'a': 0.5, 'e': 1.0}",30,0,-251.8830454840801,"{0: 18, 1: 1, 2: 4, 3: 3, 4: 4}",0 14 | 13,False,"{'a': 0.5, 'e': 1.0}",20,0,-101.42265402380453,"{0: 11, 1: 1, 2: 7, 3: 0, 4: 1}",0 15 | 14,False,"{'a': 0.5, 'e': 1.0}",30,2,-244.1098376976274,"{0: 14, 1: 2, 2: 7, 3: 1, 4: 4}",1 16 | 15,False,"{'a': 0.5, 'e': 1.0}",20,0,-78.52448372238419,"{0: 14, 1: 1, 2: 3, 3: 1, 4: 1}",0 17 | 16,False,"{'a': 0.5, 'e': 1.0}",25,0,-112.07939409851832,"{0: 17, 1: 1, 2: 5, 3: 0, 4: 2}",0 18 | 17,False,"{'a': 0.5, 'e': 1.0}",20,0,-68.47204830674838,"{0: 9, 1: 8, 2: 2, 3: 1, 4: 0}",0 19 | 18,False,"{'a': 0.5, 'e': 1.0}",25,0,-138.69640533540485,"{0: 16, 1: 0, 2: 6, 3: 1, 4: 2}",0 20 | 19,False,"{'a': 0.5, 'e': 1.0}",25,17,3.0168561933500504,"{0: 7, 1: 0, 2: 1, 3: 0, 4: 0}",1 21 | 20,False,"{'a': 0.5, 'e': 1.0}",25,0,-121.22725948341478,"{0: 17, 1: 2, 2: 3, 3: 1, 4: 2}",0 22 | 1,True,"{'a': 0.5, 'e': 1.0}",20,0,-18.31492248877889,"{0: 15, 1: 3, 2: 2, 3: 0, 4: 0}",0 23 | 2,True,"{'a': 0.5, 'e': 1.0}",20,0,-123.63704697875012,"{0: 9, 1: 2, 2: 8, 3: 0, 4: 1}",0 24 | 3,True,"{'a': 0.5, 'e': 1.0}",20,0,-58.23968261471943,"{0: 14, 1: 2, 2: 3, 3: 0, 4: 1}",0 25 
| 4,True,"{'a': 0.5, 'e': 1.0}",20,0,-44.78428957901549,"{0: 16, 1: 2, 2: 1, 3: 0, 4: 1}",0 26 | 5,True,"{'a': 0.5, 'e': 1.0}",25,0,-56.78857487544803,"{0: 18, 1: 0, 2: 6, 3: 1, 4: 0}",0 27 | 6,True,"{'a': 0.5, 'e': 1.0}",30,0,-234.07867260270754,"{0: 16, 1: 3, 2: 6, 3: 1, 4: 4}",0 28 | 7,True,"{'a': 0.5, 'e': 1.0}",25,0,-143.1088671029654,"{0: 13, 1: 4, 2: 6, 3: 0, 4: 2}",0 29 | 8,True,"{'a': 0.5, 'e': 1.0}",20,0,-182.78659171952157,"{0: 9, 1: 3, 2: 4, 3: 1, 4: 3}",0 30 | 9,True,"{'a': 0.5, 'e': 1.0}",25,5,-69.11794536993459,"{0: 14, 1: 0, 2: 5, 3: 0, 4: 1}",1 31 | 10,True,"{'a': 0.5, 'e': 1.0}",20,5,-35.64511240661868,"{0: 13, 1: 0, 2: 1, 3: 0, 4: 1}",1 32 | -------------------------------------------------------------------------------- /smartcab/project_description.md: -------------------------------------------------------------------------------- 1 | # Content: Reinforcement Learning 2 | ## Project: Train a Smartcab How to Drive 3 | 4 | ## Project Overview 5 | 6 | In this project you will apply reinforcement learning techniques for a self-driving agent in a simplified world to aid it in effectively reaching its destinations in the allotted time. You will first investigate the environment the agent operates in by constructing a very basic driving implementation. Once your agent is successful at operating within the environment, you will then identify each possible state the agent can be in when considering such things as traffic lights and oncoming traffic at each intersection. With states identified, you will then implement a Q-Learning algorithm for the self-driving agent to guide the agent towards its destination within the allotted time. Finally, you will improve upon the Q-Learning algorithm to find the best configuration of learning and exploration factors to ensure the self-driving agent is reaching its destinations with consistently positive results. 
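The learning loop sketched in the overview (observe a state, pick an action, receive a reward, update a Q-table, decay exploration) can be illustrated in a few lines. This is a hedged sketch only, not the project's implementation: the `update_q` and `choose_action` names, the example state tuple, and the simplification of ignoring future rewards (a discount factor of 0) are all assumptions made for this example; the notebook defines the actual update you must implement.

```python
import random

def update_q(Q, state, action, reward, alpha=0.5):
    # Simplified Q-learning update with discount factor 0 (an assumption
    # for this sketch): move the stored estimate a fraction `alpha` of
    # the way toward the reward just observed.
    Q.setdefault(state, {})
    old = Q[state].get(action, 0.0)
    Q[state][action] = (1 - alpha) * old + alpha * reward

def choose_action(Q, state, actions, epsilon):
    # Epsilon-greedy exploration: act randomly with probability
    # `epsilon`, otherwise take a highest-valued known action.
    if state not in Q or random.random() < epsilon:
        return random.choice(actions)
    best = max(Q[state].values())
    return random.choice([a for a, q in Q[state].items() if q == best])

# Hypothetical state: (next waypoint, light color); actions as in the project.
Q = {}
state = ('forward', 'green')
update_q(Q, state, 'forward', 2.0)   # stored Q-value becomes 1.0
update_q(Q, state, 'forward', 2.0)   # stored Q-value becomes 1.5
print(choose_action(Q, state, [None, 'forward', 'left', 'right'], epsilon=0.0))
# prints: forward
```

Decaying `epsilon` from 1.0 toward 0 over the training trials is what moves the agent from pure exploration toward exploiting what it has learned, which is exactly the "learning and exploration factors" tuning the overview refers to.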
7 | 8 | ## Description 9 | In the not-so-distant future, taxicab companies across the United States no longer employ human drivers to operate their fleet of vehicles. Instead, the taxicabs are operated by self-driving agents, known as *smartcabs*, to transport people from one location to another within the cities those companies operate. In major metropolitan areas, such as Chicago, New York City, and San Francisco, an increasing number of people have come to depend on *smartcabs* to get to where they need to go as safely and reliably as possible. Although *smartcabs* have become the transport of choice, concerns have arisen that a self-driving agent might not be as safe or reliable as human drivers, particularly when considering city traffic lights and other vehicles. To alleviate these concerns, your task as an employee for a national taxicab company is to use reinforcement learning techniques to construct a demonstration of a *smartcab* operating in real-time to prove that both safety and reliability can be achieved. 10 | 11 | ## Software Requirements 12 | This project uses the following software and Python libraries: 13 | 14 | - [Python 2.7](https://www.python.org/download/releases/2.7/) 15 | - [NumPy](http://www.numpy.org/) 16 | - [pandas](http://pandas.pydata.org/) 17 | - [matplotlib](http://matplotlib.org/) 18 | - [PyGame](http://pygame.org/) 19 | 20 | If you do not have Python installed yet, it is highly recommended that you install the [Anaconda](http://continuum.io/downloads) distribution of Python, which already has the above packages and more included. Make sure that you select the Python 2.7 installer and not the Python 3.x installer.
`pygame` can then be installed using one of the following commands: 21 | 22 | Mac: `conda install -c https://conda.anaconda.org/quasiben pygame` 23 | Windows: `conda install -c https://conda.anaconda.org/tlatorre pygame` 24 | Linux: `conda install -c https://conda.anaconda.org/prkrekel pygame` 25 | 26 | ## Fixing Common PyGame Problems 27 | 28 | The PyGame library can in some cases require a bit of troubleshooting to work correctly for this project. While the PyGame aspect of the project is not required for a successful submission (you can complete the project without a visual simulation, although it is more difficult), it is very helpful to have it working! If you encounter an issue with PyGame, first see these helpful links below that are developed by communities of users working with the library: 29 | - [Getting Started](https://www.pygame.org/wiki/GettingStarted) 30 | - [PyGame Information](http://www.pygame.org/wiki/info) 31 | - [Google Group](https://groups.google.com/forum/#!forum/pygame-mirror-on-google-groups) 32 | - [PyGame subreddit](https://www.reddit.com/r/pygame/) 33 | 34 | ### Problems most often reported by students 35 | _"PyGame won't install on my machine; there was an issue with the installation."_ 36 | **Solution:** As has been recommended for previous projects, Udacity suggests using the Anaconda distribution of Python, which then allows you to install PyGame through the `conda`-specific command. 37 | 38 | _"I'm seeing a black screen when running the code; output says that it can't load car images."_ 39 | **Solution:** The code will not operate correctly unless it is run from the top-level directory for `smartcab`. The top-level directory is the one that contains the **README** and the project notebook.
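Before digging into the troubleshooting links, it can save time to confirm whether the library imports at all. The snippet below is a generic sketch, not part of the project files; the `have_module` helper is a name invented for this example.

```python
import importlib

def have_module(name):
    # Return True if `name` can be imported in the current environment.
    try:
        importlib.import_module(name)
        return True
    except ImportError:
        return False

# Any False here means that package still needs to be installed
# (e.g. pygame via the conda commands listed above).
for lib in ['numpy', 'pandas', 'matplotlib', 'pygame']:
    print(lib, have_module(lib))
```

If `pygame` imports but the simulator still shows a black screen, the working-directory issue described above is the more likely cause.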
40 | 41 | If you continue to have problems with the project code regarding PyGame, you can also [use the discussion forums](https://discussions.udacity.com/c/nd009-reinforcement-learning) to find posts from students who encountered issues that you may be experiencing. Additionally, you can seek help from a swath of students in the [MLND Student Slack Community](http://mlnd.slack.com). 42 | 43 | ## Starting the Project 44 | 45 | For this assignment, you can find the `smartcab` folder containing the necessary project files on the [Machine Learning projects GitHub](https://github.com/udacity/machine-learning), under the `projects` folder. You may download all of the files for projects we'll use in this Nanodegree program directly from this repo. Please make sure that you use the most recent version of project files when completing a project! 46 | 47 | This project contains three directories: 48 | 49 | - `/logs/`: This folder will contain all log files that are given from the simulation when specific prerequisites are met. 50 | - `/images/`: This folder contains various images of cars to be used in the graphical user interface. You will not need to modify or create any files in this directory. 51 | - `/smartcab/`: This folder contains the Python scripts that create the environment, graphical user interface, the simulation, and the agents. You will not need to modify or create any files in this directory except for `agent.py`. 52 | 53 | It also contains two files: 54 | - `smartcab.ipynb`: This is the main file where you will answer questions and provide an analysis for your work. 55 | - `visuals.py`: This Python script provides supplementary visualizations for the analysis. Do not modify. 56 | 57 | Finally, in `/smartcab/` are the following four files: 58 | - **Modify:** 59 | - `agent.py`: This is the main Python file where you will be performing your work on the project.
60 | - **Do not modify:** 61 | - `environment.py`: This Python file will create the *smartcab* environment. 62 | - `planner.py`: This Python file creates a high-level planner for the agent to follow towards a set goal. 63 | - `simulator.py`: This Python file creates the simulation and graphical user interface. 64 | 65 | ### Running the Code 66 | In a terminal or command window, navigate to the top-level project directory `smartcab/` (that contains the three project directories) and run one of the following commands: 67 | 68 | `python smartcab/agent.py` or 69 | `python -m smartcab.agent` 70 | 71 | This will run the `agent.py` file and execute your implemented agent code in the environment. Additionally, use the command `jupyter notebook smartcab.ipynb` from this same directory to open up a browser window or tab to work with your analysis notebook. Alternatively, you can use the command `jupyter notebook` or `ipython notebook` and navigate to the notebook file in the browser window that opens. Follow the instructions in the notebook and answer each question presented to successfully complete the implementation necessary for your `agent.py` agent file. A **README** file has also been provided with the project files which may contain additional necessary information or instruction for the project. 72 | 73 | ## Definitions 74 | 75 | ### Environment 76 | The *smartcab* operates in an ideal, grid-like city (similar to New York City), with roads going in the North-South and East-West directions. Other vehicles will certainly be present on the road, but there will be no pedestrians to be concerned with. At each intersection there is a traffic light that either allows traffic in the North-South direction or the East-West direction. U.S. Right-of-Way rules apply: 77 | - On a green light, a left turn is permitted if there is no oncoming traffic making a right turn or coming straight through the intersection.
78 | - On a red light, a right turn is permitted if no oncoming traffic is approaching from your left through the intersection. 79 | To understand how to correctly yield to oncoming traffic when turning left, you may refer to [this official drivers' education video](https://www.youtube.com/watch?v=TW0Eq2Q-9Ac), or [this passionate exposition](https://www.youtube.com/watch?v=0EdkxI6NeuA). 80 | 81 | ### Inputs and Outputs 82 | Assume that the *smartcab* is assigned a route plan based on the passengers' starting location and destination. The route is split at each intersection into waypoints, and you may assume that the *smartcab*, at any instant, is at some intersection in the world. Therefore, the next waypoint to the destination, assuming the destination has not already been reached, is one intersection away in one direction (North, South, East, or West). The *smartcab* has only an egocentric view of the intersection it is at: It can determine the state of the traffic light for its direction of movement, and whether there is a vehicle at the intersection for each of the oncoming directions. For each action, the *smartcab* may either idle at the intersection, or drive to the next intersection to the left, right, or ahead of it. Finally, each trip has a time to reach the destination which decreases for each action taken (the passengers want to get there quickly). If the allotted time becomes zero before reaching the destination, the trip has failed. 83 | 84 | ### Rewards and Goal 85 | The *smartcab* will receive positive or negative rewards based on the action it has taken. Expectedly, the *smartcab* will receive a small positive reward when making a good action, and a varying amount of negative reward dependent on the severity of the traffic violation it would have committed.
Based on the rewards and penalties the *smartcab* receives, the self-driving agent implementation should learn an optimal policy for driving on the city roads while obeying traffic rules, avoiding accidents, and reaching passengers' destinations in the allotted time. 86 | 87 | ## Submitting the Project 88 | 89 | ### Evaluation 90 | Your project will be reviewed by a Udacity reviewer against the **Train a Smartcab to Drive project rubric**. Be sure to review this rubric thoroughly and self-evaluate your project before submission. All criteria found in the rubric must be *meeting specifications* for you to pass. 91 | 92 | ### Submission Files 93 | When you are ready to submit your project, collect the following files and compress them into a single archive for upload. Alternatively, you may supply the following files on your GitHub Repo in a folder named `smartcab` for ease of access: 94 | - The `agent.py` Python file with all code implemented as required in the instructed tasks. 95 | - The `/logs/` folder which should contain **five** log files that were produced from your simulation and used in the analysis. 96 | - The `smartcab.ipynb` notebook file with all questions answered and all visualization cells executed and displaying results. 97 | - An **HTML** export of the project notebook with the name **report.html**. This file *must* be present for your project to be evaluated. 98 | 99 | Once you have collected these files and reviewed the project rubric, proceed to the project submission page. 100 | 101 | ### I'm Ready! 102 | When you're ready to submit your project, click on the **Submit Project** button at the bottom of the page. 103 | 104 | If you are having any problems submitting your project or wish to check on the status of your submission, please email us at **machine-support@udacity.com** or visit us in the discussion forums. 105 | 106 | ### What's Next? 107 | You will get an email as soon as your reviewer has feedback for you.
In the meantime, review your next project and feel free to get started on it or the courses supporting it! -------------------------------------------------------------------------------- /smartcab/smartcab.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/smartcab.zip -------------------------------------------------------------------------------- /smartcab/smartcab/agent.py: -------------------------------------------------------------------------------- 1 | import random 2 | import math 3 | from environment import Agent, Environment 4 | from planner import RoutePlanner 5 | from simulator import Simulator 6 | 7 | class LearningAgent(Agent): 8 | """ An agent that learns to drive in the Smartcab world. 9 | This is the object you will be modifying. """ 10 | 11 | def __init__(self, env, learning=False, epsilon=1.0, alpha=0.5): 12 | super(LearningAgent, self).__init__(env) # Set the agent in the environment 13 | self.planner = RoutePlanner(self.env, self) # Create a route planner 14 | self.valid_actions = self.env.valid_actions # The set of valid actions 15 | 16 | # Set parameters of the learning agent 17 | self.learning = learning # Whether the agent is expected to learn 18 | self.Q = dict() # Create a Q-table which will be a dictionary of tuples 19 | self.epsilon = epsilon # Random exploration factor 20 | self.alpha = alpha # Learning factor 21 | 22 | ########### 23 | ## TO DO ## 24 | ########### 25 | # Set any additional class parameters as needed 26 | 27 | 28 | def reset(self, destination=None, testing=False): 29 | """ The reset function is called at the beginning of each trial. 30 | 'testing' is set to True if testing trials are being used 31 | once training trials have completed.
""" 32 | 33 | # Select the destination as the new location to route to 34 | self.planner.route_to(destination) 35 | 36 | ########### 37 | ## TO DO ## 38 | ########### 39 | # Update epsilon using a decay function of your choice 40 | # Update additional class parameters as needed 41 | # If 'testing' is True, set epsilon and alpha to 0 42 | 43 | return None 44 | 45 | def build_state(self): 46 | """ The build_state function is called when the agent requests data from the 47 | environment. The next waypoint, the intersection inputs, and the deadline 48 | are all features available to the agent. """ 49 | 50 | # Collect data about the environment 51 | waypoint = self.planner.next_waypoint() # The next waypoint 52 | inputs = self.env.sense(self) # Visual input - intersection light and traffic 53 | deadline = self.env.get_deadline(self) # Remaining deadline 54 | 55 | ########### 56 | ## TO DO ## 57 | ########### 58 | # Set 'state' as a tuple of relevant data for the agent 59 | state = None 60 | 61 | return state 62 | 63 | 64 | def get_maxQ(self, state): 65 | """ The get_maxQ function is called when the agent is asked to find the 66 | maximum Q-value of all actions based on the 'state' the smartcab is in. """ 67 | 68 | ########### 69 | ## TO DO ## 70 | ########### 71 | # Calculate the maximum Q-value of all actions for a given state 72 | 73 | maxQ = None 74 | 75 | return maxQ 76 | 77 | 78 | def createQ(self, state): 79 | """ The createQ function is called when a state is generated by the agent. """ 80 | 81 | ########### 82 | ## TO DO ## 83 | ########### 84 | # When learning, check if the 'state' is not in the Q-table 85 | # If it is not, create a new dictionary for that state 86 | # Then, for each action available, set the initial Q-value to 0.0 87 | 88 | return 89 | 90 | 91 | def choose_action(self, state): 92 | """ The choose_action function is called when the agent is asked to choose 93 | which action to take, based on the 'state' the smartcab is in.
""" 94 | 95 | # Set the agent state and default action 96 | self.state = state 97 | self.next_waypoint = self.planner.next_waypoint() 98 | action = None 99 | 100 | ########### 101 | ## TO DO ## 102 | ########### 103 | # When not learning, choose a random action 104 | # When learning, choose a random action with 'epsilon' probability 105 | # Otherwise, choose an action with the highest Q-value for the current state 106 | 107 | return action 108 | 109 | 110 | def learn(self, state, action, reward): 111 | """ The learn function is called after the agent completes an action and 112 | receives a reward. This function does not consider future rewards 113 | when conducting learning. """ 114 | 115 | ########### 116 | ## TO DO ## 117 | ########### 118 | # When learning, implement the Q-learning update rule 119 | # Use only the learning rate 'alpha' (do not use the discount factor 'gamma') 120 | 121 | return 122 | 123 | 124 | def update(self): 125 | """ The update function is called when a time step is completed in the 126 | environment for a given trial. This function will build the agent 127 | state, choose an action, receive a reward, and learn if enabled. """ 128 | 129 | state = self.build_state() # Get current state 130 | self.createQ(state) # Create 'state' in Q-table 131 | action = self.choose_action(state) # Choose an action 132 | reward = self.env.act(self, action) # Receive a reward 133 | self.learn(state, action, reward) # Q-learn 134 | 135 | return 136 | 137 | 138 | def run(): 139 | """ Driving function for running the simulation. 140 | Press ESC to close the simulation, or [SPACE] to pause the simulation.
""" 141 | 142 | ############## 143 | # Create the environment 144 | # Flags: 145 | # verbose - set to True to display additional output from the simulation 146 | # num_dummies - discrete number of dummy agents in the environment, default is 100 147 | # grid_size - discrete number of intersections (columns, rows), default is (8, 6) 148 | env = Environment() 149 | 150 | ############## 151 | # Create the driving agent 152 | # Flags: 153 | # learning - set to True to force the driving agent to use Q-learning 154 | # * epsilon - continuous value for the exploration factor, default is 1 155 | # * alpha - continuous value for the learning rate, default is 0.5 156 | agent = env.create_agent(LearningAgent) 157 | 158 | ############## 159 | # Follow the driving agent 160 | # Flags: 161 | # enforce_deadline - set to True to enforce a deadline metric 162 | env.set_primary_agent(agent) 163 | 164 | ############## 165 | # Create the simulation 166 | # Flags: 167 | # update_delay - continuous time (in seconds) between actions, default is 2.0 seconds 168 | # display - set to False to disable the GUI if PyGame is enabled 169 | # log_metrics - set to True to log trial and simulation results to /logs 170 | # optimized - set to True to change the default log file name 171 | sim = Simulator(env) 172 | 173 | ############## 174 | # Run the simulator 175 | # Flags: 176 | # tolerance - epsilon tolerance before beginning testing, default is 0.05 177 | # n_test - discrete number of testing trials to perform, default is 0 178 | sim.run() 179 | 180 | 181 | if __name__ == '__main__': 182 | run() -------------------------------------------------------------------------------- /smartcab/smartcab/environment.py: -------------------------------------------------------------------------------- 1 | import time 2 | import random 3 | import math 4 | from collections import OrderedDict 5 | from simulator import Simulator 6 | 7 | 8 | class TrafficLight(object): 9 | """A traffic light that switches 
periodically.""" 10 | 11 | valid_states = [True, False] # True = NS open; False = EW open 12 | 13 | def __init__(self, state=None, period=None): 14 | self.state = state if state is not None else random.choice(self.valid_states) 15 | self.period = period if period is not None else random.choice([2, 3, 4, 5]) 16 | self.last_updated = 0 17 | 18 | def reset(self): 19 | self.last_updated = 0 20 | 21 | def update(self, t): 22 | if t - self.last_updated >= self.period: 23 | self.state = not self.state # Assuming state is boolean 24 | self.last_updated = t 25 | 26 | 27 | class Environment(object): 28 | """Environment within which all agents operate.""" 29 | 30 | valid_actions = [None, 'forward', 'left', 'right'] 31 | valid_inputs = {'light': TrafficLight.valid_states, 'oncoming': valid_actions, 'left': valid_actions, 'right': valid_actions} 32 | valid_headings = [(1, 0), (0, -1), (-1, 0), (0, 1)] # E, N, W, S 33 | hard_time_limit = -100 # Set a hard time limit even if deadline is not enforced. 34 | 35 | def __init__(self, verbose=False, num_dummies=100, grid_size = (8, 6)): 36 | self.num_dummies = num_dummies # Number of dummy driver agents in the environment 37 | self.verbose = verbose # If debug output should be given 38 | 39 | # Initialize simulation variables 40 | self.done = False 41 | self.t = 0 42 | self.agent_states = OrderedDict() 43 | self.step_data = {} 44 | self.success = None 45 | 46 | # Road network 47 | self.grid_size = grid_size # (columns, rows) 48 | self.bounds = (1, 2, self.grid_size[0], self.grid_size[1] + 1) 49 | self.block_size = 100 50 | self.hang = 0.6 51 | self.intersections = OrderedDict() 52 | self.roads = [] 53 | for x in xrange(self.bounds[0], self.bounds[2] + 1): 54 | for y in xrange(self.bounds[1], self.bounds[3] + 1): 55 | self.intersections[(x, y)] = TrafficLight() # A traffic light at each intersection 56 | 57 | for a in self.intersections: 58 | for b in self.intersections: 59 | if a == b: 60 | continue 61 | if (abs(a[0] - b[0]) + abs(a[1] 
- b[1])) == 1: # L1 distance = 1 62 | self.roads.append((a, b)) 63 | 64 | # Add environment boundaries 65 | for x in xrange(self.bounds[0], self.bounds[2] + 1): 66 | self.roads.append(((x, self.bounds[1] - self.hang), (x, self.bounds[1]))) 67 | self.roads.append(((x, self.bounds[3] + self.hang), (x, self.bounds[3]))) 68 | for y in xrange(self.bounds[1], self.bounds[3] + 1): 69 | self.roads.append(((self.bounds[0] - self.hang, y), (self.bounds[0], y))) 70 | self.roads.append(((self.bounds[2] + self.hang, y), (self.bounds[2], y))) 71 | 72 | # Create dummy agents 73 | for i in xrange(self.num_dummies): 74 | self.create_agent(DummyAgent) 75 | 76 | # Primary agent and associated parameters 77 | self.primary_agent = None # to be set explicitly 78 | self.enforce_deadline = False 79 | 80 | # Trial data (updated at the end of each trial) 81 | self.trial_data = { 82 | 'testing': False, # if the trial is for testing a learned policy 83 | 'initial_distance': 0, # L1 distance from start to destination 84 | 'initial_deadline': 0, # given deadline (time steps) to start with 85 | 'net_reward': 0.0, # total reward earned in current trial 86 | 'final_deadline': None, # deadline value (time remaining) at the end 87 | 'actions': {0: 0, 1: 0, 2: 0, 3: 0, 4: 0}, # violations and accidents 88 | 'success': 0 # whether the agent reached the destination in time 89 | } 90 | 91 | def create_agent(self, agent_class, *args, **kwargs): 92 | """ When called, create_agent creates an agent in the environment. """ 93 | 94 | agent = agent_class(self, *args, **kwargs) 95 | self.agent_states[agent] = {'location': random.choice(self.intersections.keys()), 'heading': (0, 1)} 96 | return agent 97 | 98 | def set_primary_agent(self, agent, enforce_deadline=False): 99 | """ When called, set_primary_agent sets 'agent' as the primary agent. 100 | The primary agent is the smartcab that is followed in the environment. 
""" 101 | 102 | self.primary_agent = agent 103 | agent.primary_agent = True 104 | self.enforce_deadline = enforce_deadline 105 | 106 | def reset(self, testing=False): 107 | """ This function is called at the beginning of a new trial. """ 108 | 109 | self.done = False 110 | self.t = 0 111 | 112 | # Reset status text 113 | self.step_data = {} 114 | 115 | # Reset traffic lights 116 | for traffic_light in self.intersections.itervalues(): 117 | traffic_light.reset() 118 | 119 | # Pick a start and a destination 120 | start = random.choice(self.intersections.keys()) 121 | destination = random.choice(self.intersections.keys()) 122 | 123 | # Ensure starting location and destination are not too close 124 | while self.compute_dist(start, destination) < 4: 125 | start = random.choice(self.intersections.keys()) 126 | destination = random.choice(self.intersections.keys()) 127 | 128 | start_heading = random.choice(self.valid_headings) 129 | distance = self.compute_dist(start, destination) 130 | deadline = distance * 5 # 5 time steps per intersection away 131 | if(self.verbose == True): # Debugging 132 | print "Environment.reset(): Trial set up with start = {}, destination = {}, deadline = {}".format(start, destination, deadline) 133 | 134 | # Create a map of all possible initial positions 135 | positions = dict() 136 | for location in self.intersections: 137 | positions[location] = list() 138 | for heading in self.valid_headings: 139 | positions[location].append(heading) 140 | 141 | # Initialize agent(s) 142 | for agent in self.agent_states.iterkeys(): 143 | 144 | if agent is self.primary_agent: 145 | self.agent_states[agent] = { 146 | 'location': start, 147 | 'heading': start_heading, 148 | 'destination': destination, 149 | 'deadline': deadline 150 | } 151 | # For dummy agents, make them choose one of the available 152 | # intersections and headings still in 'positions' 153 | else: 154 | intersection = random.choice(positions.keys()) 155 | heading = 
random.choice(positions[intersection]) 156 | self.agent_states[agent] = { 157 | 'location': intersection, 158 | 'heading': heading, 159 | 'destination': None, 160 | 'deadline': None 161 | } 162 | # Now delete the taken location and heading from 'positions' 163 | positions[intersection] = list(set(positions[intersection]) - set([heading])) 164 | if positions[intersection] == list(): # No headings available for intersection 165 | del positions[intersection] # Delete the intersection altogether 166 | 167 | 168 | agent.reset(destination=(destination if agent is self.primary_agent else None), testing=testing) 169 | if agent is self.primary_agent: 170 | # Reset metrics for this trial (step data will be set during the step) 171 | self.trial_data['testing'] = testing 172 | self.trial_data['initial_deadline'] = deadline 173 | self.trial_data['final_deadline'] = deadline 174 | self.trial_data['net_reward'] = 0.0 175 | self.trial_data['actions'] = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0} 176 | self.trial_data['parameters'] = {'e': agent.epsilon, 'a': agent.alpha} 177 | self.trial_data['success'] = 0 178 | 179 | def step(self): 180 | """ This function is called when a time step is taken during a trial.
""" 181 | 182 | # Pretty print to terminal 183 | print "" 184 | print "/-------------------" 185 | print "| Step {} Results".format(self.t) 186 | print "\-------------------" 187 | print "" 188 | 189 | if(self.verbose == True): # Debugging 190 | print "Environment.step(): t = {}".format(self.t) 191 | 192 | # Update agents, primary first 193 | if self.primary_agent is not None: 194 | self.primary_agent.update() 195 | 196 | for agent in self.agent_states.iterkeys(): 197 | if agent is not self.primary_agent: 198 | agent.update() 199 | 200 | # Update traffic lights 201 | for intersection, traffic_light in self.intersections.iteritems(): 202 | traffic_light.update(self.t) 203 | 204 | if self.primary_agent is not None: 205 | # Agent has taken an action: reduce the deadline by 1 206 | agent_deadline = self.agent_states[self.primary_agent]['deadline'] - 1 207 | self.agent_states[self.primary_agent]['deadline'] = agent_deadline 208 | 209 | if agent_deadline <= self.hard_time_limit: 210 | self.done = True 211 | self.success = False 212 | if self.verbose: # Debugging 213 | print "Environment.step(): Primary agent hit hard time limit ({})! Trial aborted.".format(self.hard_time_limit) 214 | elif self.enforce_deadline and agent_deadline <= 0: 215 | self.done = True 216 | self.success = False 217 | if self.verbose: # Debugging 218 | print "Environment.step(): Primary agent ran out of time! Trial aborted." 219 | 220 | self.t += 1 221 | 222 | def sense(self, agent): 223 | """ This function is called when information is requested about the sensor 224 | inputs from an 'agent' in the environment. """ 225 | 226 | assert agent in self.agent_states, "Unknown agent!" 
227 | 228 | state = self.agent_states[agent] 229 | location = state['location'] 230 | heading = state['heading'] 231 | light = 'green' if (self.intersections[location].state and heading[1] != 0) or ((not self.intersections[location].state) and heading[0] != 0) else 'red' 232 | 233 | # Populate oncoming, left, right 234 | oncoming = None 235 | left = None 236 | right = None 237 | for other_agent, other_state in self.agent_states.iteritems(): 238 | if agent == other_agent or location != other_state['location'] or (heading[0] == other_state['heading'][0] and heading[1] == other_state['heading'][1]): 239 | continue 240 | # For dummy agents, ignore the primary agent 241 | # This is because the primary agent is not required to follow the waypoint 242 | if other_agent == self.primary_agent: 243 | continue 244 | other_heading = other_agent.get_next_waypoint() 245 | if (heading[0] * other_state['heading'][0] + heading[1] * other_state['heading'][1]) == -1: 246 | if oncoming != 'left': # we don't want to override oncoming == 'left' 247 | oncoming = other_heading 248 | elif (heading[1] == other_state['heading'][0] and -heading[0] == other_state['heading'][1]): 249 | if right != 'forward' and right != 'left': # we don't want to override right == 'forward or 'left' 250 | right = other_heading 251 | else: 252 | if left != 'forward': # we don't want to override left == 'forward' 253 | left = other_heading 254 | 255 | return {'light': light, 'oncoming': oncoming, 'left': left, 'right': right} 256 | 257 | def get_deadline(self, agent): 258 | """ Returns the deadline remaining for an agent. """ 259 | 260 | return self.agent_states[agent]['deadline'] if agent is self.primary_agent else None 261 | 262 | def act(self, agent, action): 263 | """ Consider an action and perform the action if it is legal. 264 | Receive a reward for the agent based on traffic laws. """ 265 | 266 | assert agent in self.agent_states, "Unknown agent!" 267 | assert action in self.valid_actions, "Invalid action!" 
268 | 269 | state = self.agent_states[agent] 270 | location = state['location'] 271 | heading = state['heading'] 272 | light = 'green' if (self.intersections[location].state and heading[1] != 0) or ((not self.intersections[location].state) and heading[0] != 0) else 'red' 273 | inputs = self.sense(agent) 274 | 275 | # Assess whether the agent can move based on the action chosen. 276 | # Either the action is okay to perform, or falls under 4 types of violations: 277 | # 0: Action okay 278 | # 1: Minor traffic violation 279 | # 2: Major traffic violation 280 | # 3: Minor traffic violation causing an accident 281 | # 4: Major traffic violation causing an accident 282 | violation = 0 283 | 284 | # Reward scheme 285 | # First initialize reward uniformly random from [-1, 1] 286 | reward = 2 * random.random() - 1 287 | 288 | # Create a penalty factor as a function of remaining deadline 289 | # Scales reward multiplicatively from [0, 1] 290 | fnc = self.t * 1.0 / (self.t + state['deadline']) if agent.primary_agent else 0.0 291 | gradient = 10 292 | 293 | # No penalty given to an agent that has no enforced deadline 294 | penalty = 0 295 | 296 | # If the deadline is enforced, give a penalty based on time remaining 297 | if self.enforce_deadline: 298 | penalty = (math.pow(gradient, fnc) - 1) / (gradient - 1) 299 | 300 | # Agent wants to drive forward: 301 | if action == 'forward': 302 | if light != 'green': # Running red light 303 | violation = 2 # Major violation 304 | if inputs['left'] == 'forward' or inputs['right'] == 'forward': # Cross traffic 305 | violation = 4 # Accident 306 | 307 | # Agent wants to drive left: 308 | elif action == 'left': 309 | if light != 'green': # Running a red light 310 | violation = 2 # Major violation 311 | if inputs['left'] == 'forward' or inputs['right'] == 'forward': # Cross traffic 312 | violation = 4 # Accident 313 | elif inputs['oncoming'] == 'right': # Oncoming car turning right 314 | violation = 4 # Accident 315 | else: # Green light 316 
| if inputs['oncoming'] == 'right' or inputs['oncoming'] == 'forward': # Incoming traffic 317 | violation = 3 # Accident 318 | else: # Valid move! 319 | heading = (heading[1], -heading[0]) 320 | 321 | # Agent wants to drive right: 322 | elif action == 'right': 323 | if light != 'green' and inputs['left'] == 'forward': # Cross traffic 324 | violation = 3 # Accident 325 | else: # Valid move! 326 | heading = (-heading[1], heading[0]) 327 | 328 | # Agent wants to perform no action: 329 | elif action == None: 330 | if light == 'green' and inputs['oncoming'] != 'left': # No oncoming traffic 331 | violation = 1 # Minor violation 332 | 333 | 334 | # Did the agent attempt a valid move? 335 | if violation == 0: 336 | if action == agent.get_next_waypoint(): # Was it the correct action? 337 | reward += 2 - penalty # (2, 1) 338 | elif action == None and light != 'green': # Was the agent stuck at a red light? 339 | reward += 2 - penalty # (2, 1) 340 | else: # Valid but incorrect 341 | reward += 1 - penalty # (1, 0) 342 | 343 | # Move the agent 344 | if action is not None: 345 | location = ((location[0] + heading[0] - self.bounds[0]) % (self.bounds[2] - self.bounds[0] + 1) + self.bounds[0], 346 | (location[1] + heading[1] - self.bounds[1]) % (self.bounds[3] - self.bounds[1] + 1) + self.bounds[1]) # wrap-around 347 | state['location'] = location 348 | state['heading'] = heading 349 | # Agent attempted invalid move 350 | else: 351 | if violation == 1: # Minor violation 352 | reward += -5 353 | elif violation == 2: # Major violation 354 | reward += -10 355 | elif violation == 3: # Minor accident 356 | reward += -20 357 | elif violation == 4: # Major accident 358 | reward += -40 359 | 360 | # Did agent reach the goal after a valid move? 361 | if agent is self.primary_agent: 362 | if state['location'] == state['destination']: 363 | # Did agent get to destination before deadline? 
364 | if state['deadline'] >= 0: 365 | self.trial_data['success'] = 1 366 | 367 | # Stop the trial 368 | self.done = True 369 | self.success = True 370 | 371 | if(self.verbose == True): # Debugging 372 | print "Environment.act(): Primary agent has reached destination!" 373 | 374 | if(self.verbose == True): # Debugging 375 | print "Environment.act() [POST]: location: {}, heading: {}, action: {}, reward: {}".format(location, heading, action, reward) 376 | 377 | # Update metrics 378 | self.step_data['t'] = self.t 379 | self.step_data['violation'] = violation 380 | self.step_data['state'] = agent.get_state() 381 | self.step_data['deadline'] = state['deadline'] 382 | self.step_data['waypoint'] = agent.get_next_waypoint() 383 | self.step_data['inputs'] = inputs 384 | self.step_data['light'] = light 385 | self.step_data['action'] = action 386 | self.step_data['reward'] = reward 387 | 388 | self.trial_data['final_deadline'] = state['deadline'] - 1 389 | self.trial_data['net_reward'] += reward 390 | self.trial_data['actions'][violation] += 1 391 | 392 | if(self.verbose == True): # Debugging 393 | print "Environment.act(): Step data: {}".format(self.step_data) 394 | return reward 395 | 396 | def compute_dist(self, a, b): 397 | """ Compute the Manhattan (L1) distance of a spherical world. """ 398 | 399 | dx1 = abs(b[0] - a[0]) 400 | dx2 = abs(self.grid_size[0] - dx1) 401 | dx = dx1 if dx1 < dx2 else dx2 402 | 403 | dy1 = abs(b[1] - a[1]) 404 | dy2 = abs(self. 
405 | grid_size[1] - dy1) 406 | dy = dy1 if dy1 < dy2 else dy2 407 | 408 | return dx + dy 409 | 410 | 411 | class Agent(object): 412 | """Base class for all agents.""" 413 | 414 | def __init__(self, env): 415 | self.env = env 416 | self.state = None 417 | self.next_waypoint = None 418 | self.color = 'white' 419 | self.primary_agent = False 420 | 421 | def reset(self, destination=None, testing=False): 422 | pass 423 | 424 | def update(self): 425 | pass 426 | 427 | def get_state(self): 428 | return self.state 429 | 430 | def get_next_waypoint(self): 431 | return self.next_waypoint 432 | 433 | 434 | class DummyAgent(Agent): 435 | color_choices = ['cyan', 'red', 'blue', 'green', 'orange', 'magenta', 'yellow'] 436 | 437 | def __init__(self, env): 438 | super(DummyAgent, self).__init__(env) # sets self.env = env, state = None, next_waypoint = None, and a default color 439 | self.next_waypoint = random.choice(Environment.valid_actions[1:]) 440 | self.color = random.choice(self.color_choices) 441 | 442 | def update(self): 443 | """ Update a DummyAgent to move randomly under legal traffic laws. """ 444 | 445 | inputs = self.env.sense(self) 446 | 447 | # Check if the chosen waypoint is safe to move to. 448 | action_okay = True 449 | if self.next_waypoint == 'right': 450 | if inputs['light'] == 'red' and inputs['left'] == 'forward': 451 | action_okay = False 452 | elif self.next_waypoint == 'forward': 453 | if inputs['light'] == 'red': 454 | action_okay = False 455 | elif self.next_waypoint == 'left': 456 | if inputs['light'] == 'red' or (inputs['oncoming'] == 'forward' or inputs['oncoming'] == 'right'): 457 | action_okay = False 458 | 459 | # Move to the next waypoint and choose a new one. 
460 | action = None 461 | if action_okay: 462 | action = self.next_waypoint 463 | self.next_waypoint = random.choice(Environment.valid_actions[1:]) 464 | reward = self.env.act(self, action) -------------------------------------------------------------------------------- /smartcab/smartcab/environment.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/smartcab/environment.pyc -------------------------------------------------------------------------------- /smartcab/smartcab/planner.py: -------------------------------------------------------------------------------- 1 | import random 2 | 3 | class RoutePlanner(object): 4 | """ Complex route planner that is meant for a perpendicular grid network. """ 5 | 6 | def __init__(self, env, agent): 7 | self.env = env 8 | self.agent = agent 9 | self.destination = None 10 | 11 | def route_to(self, destination=None): 12 | """ Select the destination if one is provided, otherwise choose a random intersection. """ 13 | 14 | self.destination = destination if destination is not None else random.choice(self.env.intersections.keys()) 15 | 16 | def next_waypoint(self): 17 | """ Creates the next waypoint based on current heading, location, 18 | intended destination and L1 distance from destination. 
""" 19 | 20 | # Collect global location details 21 | bounds = self.env.grid_size 22 | location = self.env.agent_states[self.agent]['location'] 23 | heading = self.env.agent_states[self.agent]['heading'] 24 | 25 | delta_a = (self.destination[0] - location[0], self.destination[1] - location[1]) 26 | delta_b = (bounds[0] + delta_a[0] if delta_a[0] <= 0 else delta_a[0] - bounds[0], \ 27 | bounds[1] + delta_a[1] if delta_a[1] <= 0 else delta_a[1] - bounds[1]) 28 | 29 | # Calculate true difference in location based on world-wrap 30 | # This will pre-determine the need for U-turns from improper headings 31 | dx = delta_a[0] if abs(delta_a[0]) < abs(delta_b[0]) else delta_b[0] 32 | dy = delta_a[1] if abs(delta_a[1]) < abs(delta_b[1]) else delta_b[1] 33 | 34 | # First check if destination is at location 35 | if dx == 0 and dy == 0: 36 | return None 37 | 38 | # Next check if destination is cardinally East or West of location 39 | elif dx != 0: 40 | 41 | if dx * heading[0] > 0: # Heading the correct East or West direction 42 | return 'forward' 43 | elif dx * heading[0] < 0 and heading[0] < 0: # Heading West, destination East 44 | if dy > 0: # Destination also to the South 45 | return 'left' 46 | else: 47 | return 'right' 48 | elif dx * heading[0] < 0 and heading[0] > 0: # Heading East, destination West 49 | if dy < 0: # Destination also to the North 50 | return 'left' 51 | else: 52 | return 'right' 53 | elif dx * heading[1] > 0: # Heading North destination West; Heading South destination East 54 | return 'left' 55 | else: 56 | return 'right' 57 | 58 | # Finally, check if destination is cardinally North or South of location 59 | elif dy != 0: 60 | 61 | if dy * heading[1] > 0: # Heading the correct North or South direction 62 | return 'forward' 63 | elif dy * heading[1] < 0 and heading[1] < 0: # Heading North, destination South 64 | if dx < 0: # Destination also to the West 65 | return 'left' 66 | else: 67 | return 'right' 68 | elif dy * heading[1] < 0 and heading[1] > 0: # 
Heading South, destination North 69 | if dx > 0: # Destination also to the East 70 | return 'left' 71 | else: 72 | return 'right' 73 | elif dy * heading[0] > 0: # Heading West destination North; Heading East destination South 74 | return 'right' 75 | else: 76 | return 'left' -------------------------------------------------------------------------------- /smartcab/smartcab/planner.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/smartcab/planner.pyc -------------------------------------------------------------------------------- /smartcab/smartcab/simulator.py: -------------------------------------------------------------------------------- 1 | ########################################### 2 | # Suppress matplotlib user warnings 3 | # Necessary for newer version of matplotlib 4 | import warnings 5 | warnings.filterwarnings("ignore", category = UserWarning, module = "matplotlib") 6 | ########################################### 7 | 8 | import os 9 | import time 10 | import random 11 | import importlib 12 | import csv 13 | 14 | class Simulator(object): 15 | """Simulates agents in a dynamic smartcab environment. 16 | 17 | Uses PyGame to display GUI, if available. 
18 | """ 19 | 20 | colors = { 21 | 'black' : ( 0, 0, 0), 22 | 'white' : (255, 255, 255), 23 | 'red' : (255, 0, 0), 24 | 'green' : ( 0, 255, 0), 25 | 'dgreen' : ( 0, 228, 0), 26 | 'blue' : ( 0, 0, 255), 27 | 'cyan' : ( 0, 200, 200), 28 | 'magenta' : (200, 0, 200), 29 | 'yellow' : (255, 255, 0), 30 | 'mustard' : (200, 200, 0), 31 | 'orange' : (255, 128, 0), 32 | 'maroon' : (200, 0, 0), 33 | 'crimson' : (128, 0, 0), 34 | 'gray' : (155, 155, 155) 35 | } 36 | 37 | def __init__(self, env, size=None, update_delay=2.0, display=True, log_metrics=False, optimized=False): 38 | self.env = env 39 | self.size = size if size is not None else ((self.env.grid_size[0] + 1) * self.env.block_size, (self.env.grid_size[1] + 2) * self.env.block_size) 40 | self.width, self.height = self.size 41 | self.road_width = 44 42 | 43 | self.bg_color = self.colors['gray'] 44 | self.road_color = self.colors['black'] 45 | self.line_color = self.colors['mustard'] 46 | self.boundary = self.colors['black'] 47 | self.stop_color = self.colors['crimson'] 48 | 49 | self.quit = False 50 | self.start_time = None 51 | self.current_time = 0.0 52 | self.last_updated = 0.0 53 | self.update_delay = update_delay # duration between each step (in seconds) 54 | 55 | self.display = display 56 | if self.display: 57 | try: 58 | self.pygame = importlib.import_module('pygame') 59 | self.pygame.init() 60 | self.screen = self.pygame.display.set_mode(self.size) 61 | self._logo = self.pygame.transform.smoothscale(self.pygame.image.load(os.path.join("images", "logo.png")), (self.road_width, self.road_width)) 62 | 63 | self._ew = self.pygame.transform.smoothscale(self.pygame.image.load(os.path.join("images", "east-west.png")), (self.road_width, self.road_width)) 64 | self._ns = self.pygame.transform.smoothscale(self.pygame.image.load(os.path.join("images", "north-south.png")), (self.road_width, self.road_width)) 65 | 66 | self.frame_delay = max(1, int(self.update_delay * 1000)) # delay between GUI frames in ms (min: 1) 67 | 
self.agent_sprite_size = (32, 32) 68 | self.primary_agent_sprite_size = (42, 42) 69 | self.agent_circle_radius = 20 # radius of circle, when using simple representation 70 | for agent in self.env.agent_states: 71 | if agent.color == 'white': 72 | agent._sprite = self.pygame.transform.smoothscale(self.pygame.image.load(os.path.join("images", "car-{}.png".format(agent.color))), self.primary_agent_sprite_size) 73 | else: 74 | agent._sprite = self.pygame.transform.smoothscale(self.pygame.image.load(os.path.join("images", "car-{}.png".format(agent.color))), self.agent_sprite_size) 75 | agent._sprite_size = (agent._sprite.get_width(), agent._sprite.get_height()) 76 | 77 | self.font = self.pygame.font.Font(None, 20) 78 | self.paused = False 79 | except ImportError as e: 80 | self.display = False 81 | print "Simulator.__init__(): Unable to import pygame; display disabled.\n{}: {}".format(e.__class__.__name__, e) 82 | except Exception as e: 83 | self.display = False 84 | print "Simulator.__init__(): Error initializing GUI objects; display disabled.\n{}: {}".format(e.__class__.__name__, e) 85 | 86 | # Setup metrics to report 87 | self.log_metrics = log_metrics 88 | self.optimized = optimized 89 | 90 | if self.log_metrics: 91 | a = self.env.primary_agent 92 | 93 | # Set log files 94 | if a.learning: 95 | if self.optimized: # Whether the user is optimizing the parameters and decay functions 96 | self.log_filename = os.path.join("logs", "sim_improved-learning.csv") 97 | self.table_filename = os.path.join("logs","sim_improved-learning.txt") 98 | else: 99 | self.log_filename = os.path.join("logs", "sim_default-learning.csv") 100 | self.table_filename = os.path.join("logs","sim_default-learning.txt") 101 | 102 | self.table_file = open(self.table_filename, 'wb') 103 | else: 104 | self.log_filename = os.path.join("logs", "sim_no-learning.csv") 105 | 106 | self.log_fields = ['trial', 'testing', 'parameters', 'initial_deadline', 'final_deadline', 'net_reward', 'actions', 'success'] 
107 | self.log_file = open(self.log_filename, 'wb') 108 | self.log_writer = csv.DictWriter(self.log_file, fieldnames=self.log_fields) 109 | self.log_writer.writeheader() 110 | 111 | def run(self, tolerance=0.05, n_test=0): 112 | """ Run a simulation of the environment. 113 | 114 | 'tolerance' is the minimum epsilon necessary to begin testing (if enabled) 115 | 'n_test' is the number of testing trials simulated 116 | 117 | Note that the minimum number of training trials is always 20. """ 118 | 119 | self.quit = False 120 | 121 | # Get the primary agent 122 | a = self.env.primary_agent 123 | 124 | total_trials = 1 125 | testing = False 126 | trial = 1 127 | 128 | while True: 129 | 130 | # Flip testing switch 131 | if not testing: 132 | if total_trials > 20: # Must complete minimum 20 training trials 133 | if a.learning: 134 | if a.epsilon < tolerance: # assumes epsilon decays to 0 135 | testing = True 136 | trial = 1 137 | else: 138 | testing = True 139 | trial = 1 140 | 141 | # Break if we've reached the limit of testing trials 142 | else: 143 | if trial > n_test: 144 | break 145 | 146 | # Pretty print to terminal 147 | print 148 | print "/-------------------------" 149 | if testing: 150 | print "| Testing trial {}".format(trial) 151 | else: 152 | print "| Training trial {}".format(trial) 153 | 154 | print "\-------------------------" 155 | print 156 | 157 | self.env.reset(testing) 158 | self.current_time = 0.0 159 | self.last_updated = 0.0 160 | self.start_time = time.time() 161 | while True: 162 | try: 163 | # Update current time 164 | self.current_time = time.time() - self.start_time 165 | 166 | # Handle GUI events 167 | if self.display: 168 | for event in self.pygame.event.get(): 169 | if event.type == self.pygame.QUIT: 170 | self.quit = True 171 | elif event.type == self.pygame.KEYDOWN: 172 | if event.key == 27: # Esc 173 | self.quit = True 174 | elif event.unicode == u' ': 175 | self.paused = True 176 | 177 | if self.paused: 178 | self.pause() 179 | 180 | # 
Update environment 181 | if self.current_time - self.last_updated >= self.update_delay: 182 | self.env.step() 183 | self.last_updated = self.current_time 184 | 185 | # Render text 186 | self.render_text(trial, testing) 187 | 188 | # Render GUI and sleep 189 | if self.display: 190 | self.render(trial, testing) 191 | self.pygame.time.wait(self.frame_delay) 192 | 193 | except KeyboardInterrupt: 194 | self.quit = True 195 | finally: 196 | if self.quit or self.env.done: 197 | break 198 | 199 | if self.quit: 200 | break 201 | 202 | # Collect metrics from trial 203 | if self.log_metrics: 204 | self.log_writer.writerow({ 205 | 'trial': trial, 206 | 'testing': self.env.trial_data['testing'], 207 | 'parameters': self.env.trial_data['parameters'], 208 | 'initial_deadline': self.env.trial_data['initial_deadline'], 209 | 'final_deadline': self.env.trial_data['final_deadline'], 210 | 'net_reward': self.env.trial_data['net_reward'], 211 | 'actions': self.env.trial_data['actions'], 212 | 'success': self.env.trial_data['success'] 213 | }) 214 | 215 | # Trial finished 216 | if self.env.success == True: 217 | print "\nTrial Completed!" 218 | print "Agent reached the destination." 219 | else: 220 | print "\nTrial Aborted!" 221 | print "Agent did not reach the destination." 222 | 223 | # Increment 224 | total_trials = total_trials + 1 225 | trial = trial + 1 226 | 227 | # Clean up 228 | if self.log_metrics: 229 | 230 | if a.learning: 231 | f = self.table_file 232 | 233 | f.write("/-----------------------------------------\n") 234 | f.write("| State-action rewards from Q-Learning\n") 235 | f.write("\-----------------------------------------\n\n") 236 | 237 | for state in a.Q: 238 | f.write("{}\n".format(state)) 239 | for action, reward in a.Q[state].iteritems(): 240 | f.write(" -- {} : {:.2f}\n".format(action, reward)) 241 | f.write("\n") 242 | self.table_file.close() 243 | 244 | self.log_file.close() 245 | 246 | print "\nSimulation ended. . . 
" 247 | 248 | # Report final metrics 249 | if self.display: 250 | self.pygame.display.quit() # shut down pygame 251 | 252 | def render_text(self, trial, testing=False): 253 | """ This is the non-GUI render display of the simulation. 254 | Simulated trial data will be rendered in the terminal/command prompt. """ 255 | 256 | status = self.env.step_data 257 | if status and status['waypoint'] is not None: # Continuing the trial 258 | 259 | # Previous State 260 | if status['state']: 261 | print "Agent previous state: {}".format(status['state']) 262 | else: 263 | print "!! Agent state not been updated!" 264 | 265 | # Result 266 | if status['violation'] == 0: # Legal 267 | if status['waypoint'] == status['action']: # Followed waypoint 268 | print "Agent followed the waypoint {}. (rewarded {:.2f})".format(status['action'], status['reward']) 269 | elif status['action'] == None: 270 | if status['light'] == 'red': # Stuck at red light 271 | print "Agent properly idled at a red light. (rewarded {:.2f})".format(status['reward']) 272 | else: 273 | print "Agent idled at a green light with oncoming traffic. (rewarded {:.2f})".format(status['reward']) 274 | else: # Did not follow waypoint 275 | print "Agent drove {} instead of {}. (rewarded {:.2f})".format(status['action'], status['waypoint'], status['reward']) 276 | else: # Illegal 277 | if status['violation'] == 1: # Minor violation 278 | print "Agent idled at a green light with no oncoming traffic. (rewarded {:.2f})".format(status['reward']) 279 | elif status['violation'] == 2: # Major violation 280 | print "Agent attempted driving {} through a red light. (rewarded {:.2f})".format(status['action'], status['reward']) 281 | elif status['violation'] == 3: # Minor accident 282 | print "Agent attempted driving {} through traffic and cause a minor accident. 
(rewarded {:.2f})".format(status['action'], status['reward']) 283 | elif status['violation'] == 4: # Major accident 284 | print "Agent attempted driving {} through a red light with traffic and cause a major accident. (rewarded {:.2f})".format(status['action'], status['reward']) 285 | 286 | # Time Remaining 287 | if self.env.enforce_deadline: 288 | time = (status['deadline'] - 1) * 100.0 / (status['t'] + status['deadline']) 289 | print "{:.0f}% of time remaining to reach destination.".format(time) 290 | else: 291 | print "Agent not enforced to meet deadline." 292 | 293 | # Starting new trial 294 | else: 295 | a = self.env.primary_agent 296 | print "Simulating trial. . . " 297 | if a.learning: 298 | print "epsilon = {:.4f}; alpha = {:.4f}".format(a.epsilon, a.alpha) 299 | else: 300 | print "Agent not set to learn." 301 | 302 | 303 | def render(self, trial, testing=False): 304 | """ This is the GUI render display of the simulation. 305 | Supplementary trial data can be found from render_text. """ 306 | 307 | # Reset the screen. 
308 | self.screen.fill(self.bg_color) 309 | 310 | # Draw elements 311 | # * Static elements 312 | 313 | # Boundary 314 | self.pygame.draw.rect(self.screen, self.boundary, ((self.env.bounds[0] - self.env.hang)*self.env.block_size, (self.env.bounds[1]-self.env.hang)*self.env.block_size, (self.env.bounds[2] + self.env.hang/3)*self.env.block_size, (self.env.bounds[3] - 1 + self.env.hang/3)*self.env.block_size), 4) 315 | 316 | for road in self.env.roads: 317 | # Road 318 | self.pygame.draw.line(self.screen, self.road_color, (road[0][0] * self.env.block_size, road[0][1] * self.env.block_size), (road[1][0] * self.env.block_size, road[1][1] * self.env.block_size), self.road_width) 319 | # Center line 320 | self.pygame.draw.line(self.screen, self.line_color, (road[0][0] * self.env.block_size, road[0][1] * self.env.block_size), (road[1][0] * self.env.block_size, road[1][1] * self.env.block_size), 2) 321 | 322 | for intersection, traffic_light in self.env.intersections.iteritems(): 323 | self.pygame.draw.circle(self.screen, self.road_color, (intersection[0] * self.env.block_size, intersection[1] * self.env.block_size), self.road_width/2) 324 | 325 | if traffic_light.state: # North-South is open 326 | self.screen.blit(self._ns, 327 | self.pygame.rect.Rect(intersection[0]*self.env.block_size - self.road_width/2, intersection[1]*self.env.block_size - self.road_width/2, intersection[0]*self.env.block_size + self.road_width, intersection[1]*self.env.block_size + self.road_width/2)) 328 | self.pygame.draw.line(self.screen, self.stop_color, (intersection[0] * self.env.block_size - self.road_width/2, intersection[1] * self.env.block_size - self.road_width/2), (intersection[0] * self.env.block_size - self.road_width/2, intersection[1] * self.env.block_size + self.road_width/2), 2) 329 | self.pygame.draw.line(self.screen, self.stop_color, (intersection[0] * self.env.block_size + self.road_width/2 + 1, intersection[1] * self.env.block_size - self.road_width/2), (intersection[0] * 
self.env.block_size + self.road_width/2 + 1, intersection[1] * self.env.block_size + self.road_width/2), 2) 330 | else: 331 | self.screen.blit(self._ew, 332 | self.pygame.rect.Rect(intersection[0]*self.env.block_size - self.road_width/2, intersection[1]*self.env.block_size - self.road_width/2, intersection[0]*self.env.block_size + self.road_width, intersection[1]*self.env.block_size + self.road_width/2)) 333 | self.pygame.draw.line(self.screen, self.stop_color, (intersection[0] * self.env.block_size - self.road_width/2, intersection[1] * self.env.block_size - self.road_width/2), (intersection[0] * self.env.block_size + self.road_width/2, intersection[1] * self.env.block_size - self.road_width/2), 2) 334 | self.pygame.draw.line(self.screen, self.stop_color, (intersection[0] * self.env.block_size + self.road_width/2, intersection[1] * self.env.block_size + self.road_width/2 + 1), (intersection[0] * self.env.block_size - self.road_width/2, intersection[1] * self.env.block_size + self.road_width/2 + 1), 2) 335 | 336 | # * Dynamic elements 337 | self.font = self.pygame.font.Font(None, 20) 338 | for agent, state in self.env.agent_states.iteritems(): 339 | # Compute precise agent location here (back from the intersection some) 340 | agent_offset = (2 * state['heading'][0] * self.agent_circle_radius + self.agent_circle_radius * state['heading'][1] * 0.5, \ 341 | 2 * state['heading'][1] * self.agent_circle_radius - self.agent_circle_radius * state['heading'][0] * 0.5) 342 | 343 | 344 | agent_pos = (state['location'][0] * self.env.block_size - agent_offset[0], state['location'][1] * self.env.block_size - agent_offset[1]) 345 | agent_color = self.colors[agent.color] 346 | 347 | if hasattr(agent, '_sprite') and agent._sprite is not None: 348 | # Draw agent sprite (image), properly rotated 349 | rotated_sprite = agent._sprite if state['heading'] == (1, 0) else self.pygame.transform.rotate(agent._sprite, 180 if state['heading'][0] == -1 else state['heading'][1] * -90) 350 | 
self.screen.blit(rotated_sprite, 351 | self.pygame.rect.Rect(agent_pos[0] - agent._sprite_size[0] / 2, agent_pos[1] - agent._sprite_size[1] / 2, 352 | agent._sprite_size[0], agent._sprite_size[1])) 353 | else: 354 | # Draw simple agent (circle with a short line segment poking out to indicate heading) 355 | self.pygame.draw.circle(self.screen, agent_color, agent_pos, self.agent_circle_radius) 356 | self.pygame.draw.line(self.screen, agent_color, agent_pos, state['location'], self.road_width) 357 | 358 | 359 | if state['destination'] is not None: 360 | self.screen.blit(self._logo, 361 | self.pygame.rect.Rect(state['destination'][0] * self.env.block_size - self.road_width/2, \ 362 | state['destination'][1]*self.env.block_size - self.road_width/2, \ 363 | state['destination'][0]*self.env.block_size + self.road_width/2, \ 364 | state['destination'][1]*self.env.block_size + self.road_width/2)) 365 | 366 | # * Overlays 367 | self.font = self.pygame.font.Font(None, 50) 368 | if testing: 369 | self.screen.blit(self.font.render("Testing Trial %s"%(trial), True, self.colors['black'], self.bg_color), (10, 10)) 370 | else: 371 | self.screen.blit(self.font.render("Training Trial %s"%(trial), True, self.colors['black'], self.bg_color), (10, 10)) 372 | 373 | self.font = self.pygame.font.Font(None, 30) 374 | 375 | # Status text about each step 376 | status = self.env.step_data 377 | if status: 378 | 379 | # Previous State 380 | if status['state']: 381 | self.screen.blit(self.font.render("Previous State: {}".format(status['state']), True, self.colors['white'], self.bg_color), (350, 10)) 382 | if not status['state']: 383 | self.screen.blit(self.font.render("!! Agent state not updated!", True, self.colors['maroon'], self.bg_color), (350, 10)) 384 | 385 | # Action 386 | if status['violation'] == 0: # Legal 387 | if status['action'] == None: 388 | self.screen.blit(self.font.render("No action taken. 
(rewarded {:.2f})".format(status['reward']), True, self.colors['dgreen'], self.bg_color), (350, 40)) 389 | else: 390 | self.screen.blit(self.font.render("Agent drove {}. (rewarded {:.2f})".format(status['action'], status['reward']), True, self.colors['dgreen'], self.bg_color), (350, 40)) 391 | else: # Illegal 392 | if status['action'] == None: 393 | self.screen.blit(self.font.render("No action taken. (rewarded {:.2f})".format(status['reward']), True, self.colors['maroon'], self.bg_color), (350, 40)) 394 | else: 395 | self.screen.blit(self.font.render("{} attempted (rewarded {:.2f})".format(status['action'], status['reward']), True, self.colors['maroon'], self.bg_color), (350, 40)) 396 | 397 | # Result 398 | if status['violation'] == 0: # Legal 399 | if status['waypoint'] == status['action']: # Followed waypoint 400 | self.screen.blit(self.font.render("Agent followed the waypoint!", True, self.colors['dgreen'], self.bg_color), (350, 70)) 401 | elif status['action'] == None: 402 | if status['light'] == 'red': # Stuck at a red light 403 | self.screen.blit(self.font.render("Agent idled at a red light!", True, self.colors['dgreen'], self.bg_color), (350, 70)) 404 | else: 405 | self.screen.blit(self.font.render("Agent idled at a green light with oncoming traffic.", True, self.colors['mustard'], self.bg_color), (350, 70)) 406 | else: # Did not follow waypoint 407 | self.screen.blit(self.font.render("Agent did not follow the waypoint.", True, self.colors['mustard'], self.bg_color), (350, 70)) 408 | else: # Illegal 409 | if status['violation'] == 1: # Minor violation 410 | self.screen.blit(self.font.render("There was a green light with no oncoming traffic.", True, self.colors['maroon'], self.bg_color), (350, 70)) 411 | elif status['violation'] == 2: # Major violation 412 | self.screen.blit(self.font.render("There was a red light with no traffic.", True, self.colors['maroon'], self.bg_color), (350, 70)) 413 | elif status['violation'] == 3: # Minor accident 414 | 
self.screen.blit(self.font.render("There was traffic with right-of-way.", True, self.colors['maroon'], self.bg_color), (350, 70)) 415 | elif status['violation'] == 4: # Major accident 416 | self.screen.blit(self.font.render("There was a red light with traffic.", True, self.colors['maroon'], self.bg_color), (350, 70)) 417 | 418 | # Time Remaining 419 | if self.env.enforce_deadline: 420 | time = (status['deadline'] - 1) * 100.0 / (status['t'] + status['deadline']) 421 | self.screen.blit(self.font.render("{:.0f}% of time remaining to reach destination.".format(time), True, self.colors['black'], self.bg_color), (350, 100)) 422 | else: 423 | self.screen.blit(self.font.render("Agent not enforced to meet deadline.", True, self.colors['black'], self.bg_color), (350, 100)) 424 | 425 | # Denote whether a trial was a success or failure 426 | if (state['destination'] != state['location'] and state['deadline'] > 0) or (self.env.enforce_deadline is not True and state['destination'] != state['location']): 427 | self.font = self.pygame.font.Font(None, 40) 428 | if self.env.success == True: 429 | self.screen.blit(self.font.render("Previous Trial: Success", True, self.colors['dgreen'], self.bg_color), (10, 50)) 430 | if self.env.success == False: 431 | self.screen.blit(self.font.render("Previous Trial: Failure", True, self.colors['maroon'], self.bg_color), (10, 50)) 432 | 433 | if self.env.primary_agent.learning: 434 | self.font = self.pygame.font.Font(None, 22) 435 | self.screen.blit(self.font.render("epsilon = {:.4f}".format(self.env.primary_agent.epsilon), True, self.colors['black'], self.bg_color), (10, 80)) 436 | self.screen.blit(self.font.render("alpha = {:.4f}".format(self.env.primary_agent.alpha), True, self.colors['black'], self.bg_color), (10, 95)) 437 | 438 | # Reset status text 439 | else: 440 | self.pygame.rect.Rect(350, 10, self.width, 200) 441 | self.font = self.pygame.font.Font(None, 40) 442 | self.screen.blit(self.font.render("Simulating trial. . 
.", True, self.colors['white'], self.bg_color), (400, 60)) 443 | 444 | 445 | # Flip buffers 446 | self.pygame.display.flip() 447 | 448 | def pause(self): 449 | """ When the GUI is enabled, this function will pause the simulation. """ 450 | 451 | abs_pause_time = time.time() 452 | self.font = self.pygame.font.Font(None, 30) 453 | pause_text = "Simulation Paused. Press any key to continue. . ." 454 | self.screen.blit(self.font.render(pause_text, True, self.colors['red'], self.bg_color), (400, self.height - 30)) 455 | self.pygame.display.flip() 456 | print pause_text 457 | while self.paused: 458 | for event in self.pygame.event.get(): 459 | if event.type == self.pygame.KEYDOWN: 460 | self.paused = False 461 | self.pygame.time.wait(self.frame_delay) 462 | self.screen.blit(self.font.render(pause_text, True, self.bg_color, self.bg_color), (400, self.height - 30)) 463 | self.start_time += (time.time() - abs_pause_time) 464 | -------------------------------------------------------------------------------- /smartcab/smartcab/simulator.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/smartcab/simulator.pyc -------------------------------------------------------------------------------- /smartcab/smartcab2.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/smartcab2.zip -------------------------------------------------------------------------------- /smartcab/smartcab3.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/smartcab3.zip 
-------------------------------------------------------------------------------- /smartcab/visuals.py: -------------------------------------------------------------------------------- 1 | ########################################### 2 | # Suppress matplotlib user warnings 3 | # Necessary for newer version of matplotlib 4 | import warnings 5 | warnings.filterwarnings("ignore", category = UserWarning, module = "matplotlib") 6 | ########################################### 7 | # 8 | # Display inline matplotlib plots with IPython 9 | from IPython import get_ipython 10 | get_ipython().run_line_magic('matplotlib', 'inline') 11 | ########################################### 12 | 13 | import matplotlib.pyplot as plt 14 | import numpy as np 15 | import pandas as pd 16 | import os 17 | import ast 18 | 19 | 20 | def calculate_safety(data): 21 | """ Calculates the safety rating of the smartcab during testing. """ 22 | 23 | good_ratio = data['good_actions'].sum() * 1.0 / \ 24 | (data['initial_deadline'] - data['final_deadline']).sum() 25 | 26 | if good_ratio == 1: # Perfect driving 27 | return ("A+", "green") 28 | else: # Imperfect driving 29 | if data['actions'].apply(lambda x: ast.literal_eval(x)[4]).sum() > 0: # Major accident 30 | return ("F", "red") 31 | elif data['actions'].apply(lambda x: ast.literal_eval(x)[3]).sum() > 0: # Minor accident 32 | return ("D", "#EEC700") 33 | elif data['actions'].apply(lambda x: ast.literal_eval(x)[2]).sum() > 0: # Major violation 34 | return ("C", "#EEC700") 35 | else: # Minor violation 36 | minor = data['actions'].apply(lambda x: ast.literal_eval(x)[1]).sum() 37 | if minor >= len(data)/2: # Minor violation in at least half of the trials 38 | return ("B", "green") 39 | else: 40 | return ("A", "green") 41 | 42 | 43 | def calculate_reliability(data): 44 | """ Calculates the reliability rating of the smartcab during testing. 
""" 45 | 46 | success_ratio = data['success'].sum() * 1.0 / len(data) 47 | 48 | if success_ratio == 1: # Always meets deadline 49 | return ("A+", "green") 50 | else: 51 | if success_ratio >= 0.90: 52 | return ("A", "green") 53 | elif success_ratio >= 0.80: 54 | return ("B", "green") 55 | elif success_ratio >= 0.70: 56 | return ("C", "#EEC700") 57 | elif success_ratio >= 0.60: 58 | return ("D", "#EEC700") 59 | else: 60 | return ("F", "red") 61 | 62 | 63 | def plot_trials(csv): 64 | """ Plots the data from logged metrics during a simulation.""" 65 | 66 | data = pd.read_csv(os.path.join("logs", csv)) 67 | 68 | if len(data) < 10: 69 | print "Not enough data collected to create a visualization." 70 | print "At least 20 trials are required." 71 | return 72 | 73 | # Create additional features 74 | data['average_reward'] = (data['net_reward'] / (data['initial_deadline'] - data['final_deadline'])).rolling(window=10, center=False).mean() 75 | data['reliability_rate'] = (data['success']*100).rolling(window=10, center=False).mean() # compute avg. 
net reward with window=10 76 | data['good_actions'] = data['actions'].apply(lambda x: ast.literal_eval(x)[0]) 77 | data['good'] = (data['good_actions'] * 1.0 / \ 78 | (data['initial_deadline'] - data['final_deadline'])).rolling(window=10, center=False).mean() 79 | data['minor'] = (data['actions'].apply(lambda x: ast.literal_eval(x)[1]) * 1.0 / \ 80 | (data['initial_deadline'] - data['final_deadline'])).rolling(window=10, center=False).mean() 81 | data['major'] = (data['actions'].apply(lambda x: ast.literal_eval(x)[2]) * 1.0 / \ 82 | (data['initial_deadline'] - data['final_deadline'])).rolling(window=10, center=False).mean() 83 | data['minor_acc'] = (data['actions'].apply(lambda x: ast.literal_eval(x)[3]) * 1.0 / \ 84 | (data['initial_deadline'] - data['final_deadline'])).rolling(window=10, center=False).mean() 85 | data['major_acc'] = (data['actions'].apply(lambda x: ast.literal_eval(x)[4]) * 1.0 / \ 86 | (data['initial_deadline'] - data['final_deadline'])).rolling(window=10, center=False).mean() 87 | data['epsilon'] = data['parameters'].apply(lambda x: ast.literal_eval(x)['e']) 88 | data['alpha'] = data['parameters'].apply(lambda x: ast.literal_eval(x)['a']) 89 | 90 | 91 | # Create training and testing subsets 92 | training_data = data[data['testing'] == False] 93 | testing_data = data[data['testing'] == True] 94 | 95 | plt.figure(figsize=(12,8)) 96 | 97 | 98 | ############### 99 | ### Average step reward plot 100 | ############### 101 | 102 | ax = plt.subplot2grid((6,6), (0,3), colspan=3, rowspan=2) 103 | ax.set_title("10-Trial Rolling Average Reward per Action") 104 | ax.set_ylabel("Reward per Action") 105 | ax.set_xlabel("Trial Number") 106 | ax.set_xlim((10, len(training_data))) 107 | 108 | # Create plot-specific data 109 | step = training_data[['trial','average_reward']].dropna() 110 | 111 | ax.axhline(xmin = 0, xmax = 1, y = 0, color = 'black', linestyle = 'dashed') 112 | ax.plot(step['trial'], step['average_reward']) 113 | 114 | 115 | ############### 116 | 
### Parameters Plot 117 | ############### 118 | 119 | ax = plt.subplot2grid((6,6), (2,3), colspan=3, rowspan=2) 120 | 121 | # Check whether the agent was expected to learn 122 | if csv != 'sim_no-learning.csv': 123 | ax.set_ylabel("Parameter Value") 124 | ax.set_xlabel("Trial Number") 125 | ax.set_xlim((1, len(training_data))) 126 | ax.set_ylim((0, 1.05)) 127 | 128 | ax.plot(training_data['trial'], training_data['epsilon'], color='blue', label='Exploration factor') 129 | ax.plot(training_data['trial'], training_data['alpha'], color='green', label='Learning factor') 130 | 131 | ax.legend(bbox_to_anchor=(0.5,1.19), fancybox=True, ncol=2, loc='upper center', fontsize=10) 132 | 133 | else: 134 | ax.axis('off') 135 | ax.text(0.52, 0.30, "Simulation completed\nwith learning disabled.", fontsize=24, ha='center', style='italic') 136 | 137 | 138 | ############### 139 | ### Bad Actions Plot 140 | ############### 141 | 142 | actions = training_data[['trial','good', 'minor','major','minor_acc','major_acc']].dropna() 143 | maximum = (1 - actions['good']).values.max() 144 | 145 | ax = plt.subplot2grid((6,6), (0,0), colspan=3, rowspan=4) 146 | ax.set_title("10-Trial Rolling Relative Frequency of Bad Actions") 147 | ax.set_ylabel("Relative Frequency") 148 | ax.set_xlabel("Trial Number") 149 | 150 | ax.set_ylim((0, maximum + 0.01)) 151 | ax.set_xlim((10, len(training_data))) 152 | 153 | ax.set_yticks(np.linspace(0, maximum+0.01, 10)) 154 | 155 | ax.plot(actions['trial'], (1 - actions['good']), color='black', label='Total Bad Actions', linestyle='dotted', linewidth=3) 156 | ax.plot(actions['trial'], actions['minor'], color='orange', label='Minor Violation', linestyle='dashed') 157 | ax.plot(actions['trial'], actions['major'], color='orange', label='Major Violation', linewidth=2) 158 | ax.plot(actions['trial'], actions['minor_acc'], color='red', label='Minor Accident', linestyle='dashed') 159 | ax.plot(actions['trial'], actions['major_acc'], color='red', label='Major Accident', 
linewidth=2) 160 | 161 | ax.legend(loc='upper right', fancybox=True, fontsize=10) 162 | 163 | 164 | ############### 165 | ### Rolling Success-Rate plot 166 | ############### 167 | 168 | ax = plt.subplot2grid((6,6), (4,0), colspan=4, rowspan=2) 169 | ax.set_title("10-Trial Rolling Rate of Reliability") 170 | ax.set_ylabel("Rate of Reliability") 171 | ax.set_xlabel("Trial Number") 172 | ax.set_xlim((10, len(training_data))) 173 | ax.set_ylim((-5, 105)) 174 | ax.set_yticks(np.arange(0, 101, 20)) 175 | ax.set_yticklabels(['0%', '20%', '40%', '60%', '80%', '100%']) 176 | 177 | # Create plot-specific data 178 | trial = training_data.dropna()['trial'] 179 | rate = training_data.dropna()['reliability_rate'] 180 | 181 | # Rolling success rate 182 | ax.plot(trial, rate, label="Reliability Rate", color='blue') 183 | 184 | 185 | ############### 186 | ### Test results 187 | ############### 188 | 189 | ax = plt.subplot2grid((6,6), (4,4), colspan=2, rowspan=2) 190 | ax.axis('off') 191 | 192 | if len(testing_data) > 0: 193 | safety_rating, safety_color = calculate_safety(testing_data) 194 | reliability_rating, reliability_color = calculate_reliability(testing_data) 195 | 196 | # Write success rate 197 | ax.text(0.40, .9, "{} testing trials simulated.".format(len(testing_data)), fontsize=14, ha='center') 198 | ax.text(0.40, 0.7, "Safety Rating:", fontsize=16, ha='center') 199 | ax.text(0.40, 0.42, "{}".format(safety_rating), fontsize=40, ha='center', color=safety_color) 200 | ax.text(0.40, 0.27, "Reliability Rating:", fontsize=16, ha='center') 201 | ax.text(0.40, 0, "{}".format(reliability_rating), fontsize=40, ha='center', color=reliability_color) 202 | 203 | else: 204 | ax.text(0.36, 0.30, "Simulation completed\nwith testing disabled.", fontsize=20, ha='center', style='italic') 205 | 206 | plt.tight_layout() 207 | plt.show() 208 | -------------------------------------------------------------------------------- /student_intervention/README.md: 
-------------------------------------------------------------------------------- 1 | # Project 2: Supervised Learning 2 | ## Building a Student Intervention System 3 | 4 | ### Install 5 | 6 | This project requires **Python 2.7** and the following Python libraries installed: 7 | 8 | - [NumPy](http://www.numpy.org/) 9 | - [pandas](http://pandas.pydata.org) 10 | - [scikit-learn](http://scikit-learn.org/stable/) 11 | 12 | You will also need software installed to run and edit a [Jupyter Notebook](http://jupyter.org/). 13 | 14 | Udacity recommends that students install [Anaconda](https://www.continuum.io/downloads), a pre-packaged Python distribution that contains all of the libraries and software needed for this project. 15 | 16 | 17 | ### Code 18 | 19 | Template code is provided in the `student_intervention.ipynb` notebook file. Some code has already been implemented to help you get started, but you will need to implement additional functionality to complete the project. 20 | 21 | ### Run 22 | 23 | In a terminal or command window, make sure the current directory is the top level of the `student_intervention/` folder (the one that contains this README), and run the following command: 24 | 25 | ```jupyter notebook student_intervention.ipynb``` 26 | 27 | This will start Jupyter Notebook and open the project file in your browser. 28 | 29 | ## Data 30 | 31 | The data for this project is contained in the `student-data.csv` file. The dataset has the following attributes: 32 | 33 | - `school` : the student's school (binary feature: "GP" or "MS") 34 | - `sex` : the student's sex (binary feature: "F" - female or "M" - male) 35 | - `age` : the student's age (numeric feature: from 15 to 22) 36 | - `address` : the student's home address type (binary feature: "U" - urban or "R" - rural) 37 | - `famsize` : family size (binary feature: "LE3" - less than or equal to 3 or "GT3" - greater than 3) 38 | - `Pstatus` : parents' cohabitation status (binary feature: "T" - living together or "A" - apart) 39 | - `Medu` : mother's education level (numeric feature: 0 - none, 1 - primary education (4th grade), 2 - 5th to 9th grade, 3 - secondary education, or 4 - higher education) 40 | - `Fedu` : father's education level (numeric feature: 0 - none, 1 - primary education (4th grade), 2 - 5th to 9th grade, 3 - secondary education, or 4 - higher education) 41 | - `Mjob` : mother's job (nominal feature: "teacher", "health" for health-care related work, "services" for civil services (e.g. administrative or police), "at_home", or "other") 42 | - `Fjob` : father's job (nominal feature: "teacher", "health" for health-care related work, "services" for civil services (e.g. administrative or police), "at_home", or "other") 43 | - `reason` : reason for choosing this school (nominal feature: "home" for close to home, "reputation" for school reputation, "course" for course preference, or "other") 44 | - `guardian` : the student's guardian (nominal feature: "mother", "father", or "other") 45 | - `traveltime` : home-to-school travel time (numeric feature: 1 - less than 15 min., 2 - 15 to 30 min., 3 - 30 min. to 1 hour, or 4 - more than 1 hour) 46 | - `studytime` : weekly study time (numeric feature: 1 - less than 2 hours, 2 - 2 to 5 hours, 3 - 5 to 10 hours, or 4 - more than 10 hours) 47 | - 
`failures` : number of past class failures (numeric feature: n if 1<=n<3, else 4) 48 | - `schoolsup` : extra educational support (binary feature: yes or no) 49 | - `famsup` : family educational support (binary feature: yes or no) 50 | - `paid` : extra paid classes within the course subject (Math or Portuguese) (binary feature: yes or no) 51 | - `activities` : extra-curricular activities (binary feature: yes or no) 52 | - `nursery` : attended nursery school (binary feature: yes or no) 53 | - `higher` : wants to pursue higher education (binary feature: yes or no) 54 | - `internet` : Internet access at home (binary feature: yes or no) 55 | - `romantic` : in a romantic relationship (binary feature: yes or no) 56 | - `famrel` : quality of family relationships (numeric feature: from 1 - very bad to 5 - very good) 57 | - `freetime` : free time after school (numeric feature: from 1 - very low to 5 - very high) 58 | - `goout` : going out with friends (numeric feature: from 1 - very low to 5 - very high) 59 | - `Dalc` : workday alcohol consumption (numeric feature: from 1 - very low to 5 - very high) 60 | - `Walc` : weekend alcohol consumption (numeric feature: from 1 - very low to 5 - very high) 61 | - `health` : current health status (numeric feature: from 1 - very bad to 5 - very good) 62 | - `absences` : number of school absences (numeric feature: from 0 to 93) 63 | - `passed` : did the student pass the final exam (binary feature: yes or no) 64 | -------------------------------------------------------------------------------- /student_intervention/student_intervention.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Machine Learning Engineer Nanodegree\n", 8 | "## Supervised Learning\n", 9 | "## Project 2: Building a Student Intervention System" 10 | ] 11 | }, 12 | { 13 | "cell_type": "markdown", 14 | "metadata": {}, 15 | "source": [ 16 | "Welcome to the second project of the Machine Learning Engineer Nanodegree! In this file, some template code has already been provided for you, but you will need to implement additional functionality for the project to run successfully. You will not need to modify the provided code unless explicitly asked to do so. A section header beginning with **'Implementation'** indicates that the code block which follows contains functionality you must implement. Each section comes with detailed instructions, and the parts you need to implement are marked with a **'TODO'** comment. Please read all of the instructions carefully!\n", 17 | "\n", 18 | "In addition to implementing code, you **must** answer some questions related to the project and your implementation. Each question you need to answer is headed with **'Question X'**. Read each question carefully and provide a complete answer in the **'Answer'** text box that follows it. Your submission will be graded based on your answers to the questions and the functionality of your implementation.\n", 19 | "\n", 20 | ">**Tip:** Code and Markdown cells can be run with the **Shift + Enter** keyboard shortcut. In addition, Markdown cells can be edited by double-clicking them." 21 | ] 22 | }, 23 | { 24 | "cell_type": "markdown", 25 | "metadata": {}, 26 | "source": [ 27 | "### Question 1 - Classification vs. 
回归\n", 28 | "*在这个项目中你的任务是找出那些如果不给予帮助,最终可能无法毕业的学生。你觉得这个问题是哪种类型的监督学习问题,是分类问题还是回归问题?为什么?*" 29 | ] 30 | }, 31 | { 32 | "cell_type": "markdown", 33 | "metadata": {}, 34 | "source": [ 35 | "**答案: **分类问题。因为该项目的目标是找到可能无法毕业的学生,输出的数据是学生能否毕业,也就是把学生分类为毕业的和不能毕业的,结果是不连续的,所以属于分类问题。" 36 | ] 37 | }, 38 | { 39 | "cell_type": "markdown", 40 | "metadata": {}, 41 | "source": [ 42 | "## 分析数据\n", 43 | "运行下面区域的代码以载入学生数据集,以及一些此项目所需的Python库。注意数据集的最后一列`'passed'`是我们的预测的目标(表示学生是毕业了还是没有毕业),其他的列是每个学生的属性。" 44 | ] 45 | }, 46 | { 47 | "cell_type": "code", 48 | "execution_count": 4, 49 | "metadata": { 50 | "collapsed": false 51 | }, 52 | "outputs": [ 53 | { 54 | "name": "stdout", 55 | "output_type": "stream", 56 | "text": [ 57 | "Student data read successfully!\n" 58 | ] 59 | } 60 | ], 61 | "source": [ 62 | "# 载入所需要的库\n", 63 | "import numpy as np\n", 64 | "import pandas as pd\n", 65 | "from time import time\n", 66 | "from sklearn.metrics import f1_score\n", 67 | "\n", 68 | "# 载入学生数据集\n", 69 | "student_data = pd.read_csv(\"student-data.csv\")\n", 70 | "print \"Student data read successfully!\"" 71 | ] 72 | }, 73 | { 74 | "cell_type": "markdown", 75 | "metadata": {}, 76 | "source": [ 77 | "### 练习: 分析数据\n", 78 | "我们首先通过调查数据,以确定有多少学生的信息,并了解这些学生的毕业率。在下面的代码单元中,你需要完成如下的运算:\n", 79 | "- 学生的总数, `n_students`。\n", 80 | "- 每个学生的特征总数, `n_features`。\n", 81 | "- 毕业的学生的数量, `n_passed`。\n", 82 | "- 未毕业的学生的数量, `n_failed`。\n", 83 | "- 班级的毕业率, `grad_rate`, 用百分数表示(%)。\n" 84 | ] 85 | }, 86 | { 87 | "cell_type": "code", 88 | "execution_count": 33, 89 | "metadata": { 90 | "collapsed": false 91 | }, 92 | "outputs": [ 93 | { 94 | "name": "stdout", 95 | "output_type": "stream", 96 | "text": [ 97 | "Total number of students: 395\n", 98 | "Number of features: 30\n", 99 | "Number of students who passed: 265\n", 100 | "Number of students who failed: 130\n", 101 | "Graduation rate of the class: 67.09%\n" 102 | ] 103 | } 104 | ], 105 | "source": [ 106 | "# TODO: 计算学生的数量\n", 107 | "n_students = len(student_data)\n", 108 | "\n", 109 | 
"# TODO: 计算特征数量\n", 110 | "n_features = len(student_data.columns)-1\n", 111 | "\n", 112 | "# TODO: 计算通过的学生数\n", 113 | "n_passed = sum(student_data['passed']=='yes')\n", 114 | " \n", 115 | "# TODO: 计算未通过的学生数\n", 116 | "n_failed = n_students-n_passed\n", 117 | "\n", 118 | "# TODO: 计算通过率\n", 119 | "grad_rate = float(n_passed/n_students)*100\n", 120 | "\n", 121 | "# 输出结果\n", 122 | "print \"Total number of students: {}\".format(n_students)\n", 123 | "print \"Number of features: {}\".format(n_features)\n", 124 | "print \"Number of students who passed: {}\".format(n_passed)\n", 125 | "print \"Number of students who failed: {}\".format(n_failed)\n", 126 | "print \"Graduation rate of the class: {:.2f}%\".format(grad_rate)" 127 | ] 128 | }, 129 | { 130 | "cell_type": "markdown", 131 | "metadata": {}, 132 | "source": [ 133 | "## 数据准备\n", 134 | "在这个部分中,我们将要为建模、训练和测试准备数据\n", 135 | "### 识别特征和目标列\n", 136 | "你获取的数据中通常都会包含一些非数字的特征,这会导致一些问题,因为大多数的机器学习算法都会期望输入数字特征进行计算。\n", 137 | "\n", 138 | "运行下面的代码单元将学生数据分成特征和目标列看一看他们中是否有非数字特征。" 139 | ] 140 | }, 141 | { 142 | "cell_type": "code", 143 | "execution_count": 25, 144 | "metadata": { 145 | "collapsed": false 146 | }, 147 | "outputs": [ 148 | { 149 | "name": "stdout", 150 | "output_type": "stream", 151 | "text": [ 152 | "Feature columns:\n", 153 | "['school', 'sex', 'age', 'address', 'famsize', 'Pstatus', 'Medu', 'Fedu', 'Mjob', 'Fjob', 'reason', 'guardian', 'traveltime', 'studytime', 'failures', 'schoolsup', 'famsup', 'paid', 'activities', 'nursery', 'higher', 'internet', 'romantic', 'famrel', 'freetime', 'goout', 'Dalc', 'Walc', 'health', 'absences']\n", 154 | "\n", 155 | "Target column: passed\n", 156 | "\n", 157 | "Feature values:\n", 158 | " school sex age address famsize Pstatus Medu Fedu Mjob Fjob \\\n", 159 | "0 GP F 18 U GT3 A 4 4 at_home teacher \n", 160 | "1 GP F 17 U GT3 T 1 1 at_home other \n", 161 | "2 GP F 15 U LE3 T 1 1 at_home other \n", 162 | "3 GP F 15 U GT3 T 4 2 health services \n", 163 | "4 GP F 16 U GT3 T 3 3 other 
other \n", 164 | "\n", 165 | " ... higher internet romantic famrel freetime goout Dalc Walc health \\\n", 166 | "0 ... yes no no 4 3 4 1 1 3 \n", 167 | "1 ... yes yes no 5 3 3 1 1 3 \n", 168 | "2 ... yes yes no 4 3 2 2 3 3 \n", 169 | "3 ... yes yes yes 3 2 2 1 1 5 \n", 170 | "4 ... yes no no 4 3 2 1 2 5 \n", 171 | "\n", 172 | " absences \n", 173 | "0 6 \n", 174 | "1 4 \n", 175 | "2 10 \n", 176 | "3 2 \n", 177 | "4 4 \n", 178 | "\n", 179 | "[5 rows x 30 columns]\n" 180 | ] 181 | } 182 | ], 183 | "source": [ 184 | "# 提取特征列\n", 185 | "feature_cols = list(student_data.columns[:-1])\n", 186 | "\n", 187 | "# 提取目标列 ‘passed’\n", 188 | "target_col = student_data.columns[-1] \n", 189 | "\n", 190 | "# 显示列的列表\n", 191 | "print \"Feature columns:\\n{}\".format(feature_cols)\n", 192 | "print \"\\nTarget column: {}\".format(target_col)\n", 193 | "\n", 194 | "# 将数据分割成特征数据和目标数据(即X_all 和 y_all)\n", 195 | "X_all = student_data[feature_cols]\n", 196 | "y_all = student_data[target_col]\n", 197 | "\n", 198 | "# 通过打印前5行显示特征信息\n", 199 | "print \"\\nFeature values:\"\n", 200 | "print X_all.head()" 201 | ] 202 | }, 203 | { 204 | "cell_type": "markdown", 205 | "metadata": {}, 206 | "source": [ 207 | "### 预处理特征列\n", 208 | "\n", 209 | "正如你所见,我们这里有几个非数值的列需要做一定的转换!它们中很多是简单的`yes`/`no`,比如`internet`。这些可以合理地转化为`1`/`0`(二元值,binary)值。\n", 210 | "\n", 211 | "其他的列,如`Mjob`和`Fjob`,有两个以上的值,被称为_分类变量(categorical variables)_。处理这样的列的推荐方法是创建和可能值一样多的列(如:`Fjob_teacher`,`Fjob_other`,`Fjob_services`等),然后将其中一个的值设为`1`另外的设为`0`。\n", 212 | "\n", 213 | "这些创建的列有时候叫做 _虚拟变量(dummy variables)_,我们将用[`pandas.get_dummies()`](http://pandas.pydata.org/pandas-docs/stable/generated/pandas.get_dummies.html?highlight=get_dummies#pandas.get_dummies)函数来完成这个转换。运行下面代码单元的代码来完成这里讨论的预处理步骤。" 214 | ] 215 | }, 216 | { 217 | "cell_type": "code", 218 | "execution_count": 26, 219 | "metadata": { 220 | "collapsed": false 221 | }, 222 | "outputs": [ 223 | { 224 | "name": "stdout", 225 | "output_type": "stream", 226 | "text": [ 227 | "Processed feature 
columns (48 total features):\n", 228 | "['school_GP', 'school_MS', 'sex_F', 'sex_M', 'age', 'address_R', 'address_U', 'famsize_GT3', 'famsize_LE3', 'Pstatus_A', 'Pstatus_T', 'Medu', 'Fedu', 'Mjob_at_home', 'Mjob_health', 'Mjob_other', 'Mjob_services', 'Mjob_teacher', 'Fjob_at_home', 'Fjob_health', 'Fjob_other', 'Fjob_services', 'Fjob_teacher', 'reason_course', 'reason_home', 'reason_other', 'reason_reputation', 'guardian_father', 'guardian_mother', 'guardian_other', 'traveltime', 'studytime', 'failures', 'schoolsup', 'famsup', 'paid', 'activities', 'nursery', 'higher', 'internet', 'romantic', 'famrel', 'freetime', 'goout', 'Dalc', 'Walc', 'health', 'absences']\n" 229 | ] 230 | } 231 | ], 232 | "source": [ 233 | "def preprocess_features(X):\n", 234 | " ''' 预处理学生数据,将非数字的二元特征转化成二元值(0或1),将分类的变量转换成虚拟变量\n", 235 | " '''\n", 236 | " \n", 237 | " # 初始化一个用于输出的DataFrame\n", 238 | " output = pd.DataFrame(index = X.index)\n", 239 | "\n", 240 | " # 查看数据的每一个特征列\n", 241 | " for col, col_data in X.iteritems():\n", 242 | " \n", 243 | " # 如果数据是非数字类型,将所有的yes/no替换成1/0\n", 244 | " if col_data.dtype == object:\n", 245 | " col_data = col_data.replace(['yes', 'no'], [1, 0])\n", 246 | "\n", 247 | " # 如果数据类型是类别的(categorical),将它转换成虚拟变量\n", 248 | " if col_data.dtype == object:\n", 249 | " # 例子: 'school' => 'school_GP' and 'school_MS'\n", 250 | " col_data = pd.get_dummies(col_data, prefix = col) \n", 251 | " \n", 252 | " # 收集转换后的列\n", 253 | " output = output.join(col_data)\n", 254 | " \n", 255 | " return output\n", 256 | "\n", 257 | "X_all = preprocess_features(X_all)\n", 258 | "print \"Processed feature columns ({} total features):\\n{}\".format(len(X_all.columns), list(X_all.columns))" 259 | ] 260 | }, 261 | { 262 | "cell_type": "markdown", 263 | "metadata": {}, 264 | "source": [ 265 | "### 实现: 将数据分成训练集和测试集\n", 266 | "现在我们已经将所有的 _分类的(categorical)_ 特征转换成数值了。下一步我们将把数据(包括特征和对应的标签数据)分割成训练集和测试集。在下面的代码单元中,你需要完成下列功能:\n", 267 | "- 随机混洗切分数据(`X_all`, `y_all`) 为训练子集和测试子集。\n", 268 | " - 
使用300个数据点作为训练集(约76%),使用95个数据点作为测试集(约24%)。\n", 269 | " - 如果可能的话,为你使用的函数设置一个`random_state`。\n", 270 | " - 将结果存储在`X_train`, `X_test`, `y_train`和 `y_test`中。" 271 | ] 272 | }, 273 | { 274 | "cell_type": "code", 275 | "execution_count": 44, 276 | "metadata": { 277 | "collapsed": false 278 | }, 279 | "outputs": [ 280 | { 281 | "name": "stdout", 282 | "output_type": "stream", 283 | "text": [ 284 | "Training set has 300 samples.\n", 285 | "Testing set has 95 samples.\n" 286 | ] 287 | } 288 | ], 289 | "source": [ 290 | "from sklearn.cross_validation import train_test_split\n", 291 | "\n", 292 | "# TODO:设置训练集的数量\n", 293 | "num_train = 300\n", 294 | "\n", 295 | "# TODO:设置测试集的数量\n", 296 | "num_test = X_all.shape[0] - num_train\n", 297 | "\n", 298 | "# TODO:把数据集混洗和分割成上面定义的训练集和测试集(直接使用num_test,保证测试集恰好95个)\n", 299 | "X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, test_size=num_test, random_state=42)\n", 300 | "\n", 301 | "# 显示分割的结果\n", 302 | "print \"Training set has {} samples.\".format(X_train.shape[0])\n", 303 | "print \"Testing set has {} samples.\".format(X_test.shape[0])" 304 | ] 305 | }, 306 | { 307 | "cell_type": "markdown", 308 | "metadata": {}, 309 | "source": [ 310 | "## 训练和评价模型\n", 311 | "在这个部分,你将选择3个适合这个问题并且在`scikit-learn`中已有的监督学习的模型。首先你需要说明你选择这三个模型的原因,包括这些数据集有哪些特点,每个模型的优点和缺点各是什么。然后,你需要将这些模型用不同大小的训练集(100个数据点,200个数据点,300个数据点)进行训练,并用F1的值来衡量。你需要制作三个表,每个表要显示训练集大小,训练时间,预测时间,训练集上的F1值和测试集上的F1值(每个模型一个表)。\n", 312 | "\n", 313 | "**这是目前** [`scikit-learn`](http://scikit-learn.org/stable/supervised_learning.html) **里有的监督学习模型,你可以从中选择:**\n", 314 | "- Gaussian Naive Bayes (GaussianNB) 朴素贝叶斯\n", 315 | "- Decision Trees 决策树\n", 316 | "- Ensemble Methods (Bagging, AdaBoost, Random Forest, Gradient Boosting)\n", 317 | "- K-Nearest Neighbors (KNeighbors)\n", 318 | "- Stochastic Gradient Descent (SGDC)\n", 319 | "- Support Vector Machines (SVM) 支持向量机\n", 320 | "- Logistic Regression 逻辑回归" 321 | ] 322 | }, 323 | { 324 | "cell_type": "markdown", 325 | "metadata": {}, 326 | "source": [ 327 | "### 问题 2 - 
应用模型\n", 328 | "*列出三个适合这个问题的监督学习算法模型。每一个你选择的模型:*\n", 329 | "\n", 330 | "- 描述一个该模型在真实世界的一个应用场景。(你需要为此做点研究,并给出你的引用出处)\n", 331 | "- 这个模型的优势是什么?它什么情况下表现最好?\n", 332 | "- 这个模型的缺点是什么?什么条件下它表现很差?\n", 333 | "- 根据我们当前数据集的特点,为什么这个模型适合这个问题。" 334 | ] 335 | }, 336 | { 337 | "cell_type": "markdown", 338 | "metadata": {}, 339 | "source": [ 340 | "**回答: **我选择的是GaussianNB, Decision Trees, AdaBoost这三种算法。\n", 341 | "- GaussianNB: 真实应用场景:过滤垃圾邮件(来自Machine Learning in Action | by Peter Harrington 内容4.6) \n", 342 | "- 优势是在数据较少的情况下仍然有效,可以处理多类别问题。对于文档分类它的表现最好\n", 343 | "- 缺点是对于输入数据的准备方式较为敏感。当输入数据是由多个单词组成且意义明显不同的短语时,它的表现很差。\n", 344 | "- 对于我们的数据集,它的特征数很多(30个),而GaussianNB适合处理大量特征数的数据。\n", 345 | "\n", 346 | "\n", 347 | "- Decision Tree: 真实应用场景:预测隐形眼镜类型(来自Machine Learning in Action | by Peter Harrington 内容3.4) \n", 348 | "- 优势是计算复杂度不高,输出结果易于理解,对中间值的缺失不敏感,可以处理不相关特征数据。对于简单的布尔型数据它的表现最好\n", 349 | "- 缺点是可能出现过拟合问题。当过于依赖数据或参数设置不好时,它的表现很差。\n", 350 | "- 我们的数据中有大量布尔型特征,适合用Decision Tree,而且它的一些特征(如nursery)对于我们的目标(passed)可能相关程度并不高,而Decision Tree适合处理不相关特征数的数据。\n", 351 | "\n", 352 | "\n", 353 | "- AdaBoost: 真实应用场景:预测患有疝病的马是否存活(来自Machine Learning in Action | by Peter Harrington 内容7.6) \n", 354 | "- 优势是泛化错误率低,易编码,可以应用在大部分分类器上,无参数调整。对于基于错误提升分类器性能的场景它的表现最好\n", 355 | "- 缺点是对离群点敏感。当输入数据有不少极端值时,它的表现很差。\n", 356 | "- 我们的数据集特征很多,较为复杂,在后续迭代中,出现错误的数据权重可能增大,而针对这种错误的调节能力正是AdaBoost的长处。" 357 | ] 358 | }, 359 | { 360 | "cell_type": "markdown", 361 | "metadata": {}, 362 | "source": [ 363 | "### 准备\n", 364 | "运行下面的代码单元以初始化三个帮助函数,这三个函数将能够帮你训练和测试你上面所选择的三个监督学习算法。这些函数是:\n", 365 | "- `train_classifier` - 输入一个分类器和训练集,用数据来训练这个分类器。\n", 366 | "- `predict_labels` - 输入一个训练好的分类器、特征以及一个目标标签,这个函数将帮你做预测并给出F1的值.\n", 367 | "- `train_predict` - 输入一个分类器以及训练集和测试集,它可以运行`train_classifier`和`predict_labels`.\n", 368 | " - 这个函数将分别输出训练集的F1值和测试集的F1值" 369 | ] 370 | }, 371 | { 372 | "cell_type": "code", 373 | "execution_count": 45, 374 | "metadata": { 375 | "collapsed": false 376 | }, 377 | "outputs": [], 378 | "source": [ 379 | "def train_classifier(clf, 
X_train, y_train):\n", 380 | " ''' 用训练集训练分类器 '''\n", 381 | " \n", 382 | " # 开始计时,训练分类器,然后停止计时\n", 383 | " start = time()\n", 384 | " clf.fit(X_train, y_train)\n", 385 | " end = time()\n", 386 | " \n", 387 | " # Print the results\n", 388 | " print \"Trained model in {:.4f} seconds\".format(end - start)\n", 389 | "\n", 390 | " \n", 391 | "def predict_labels(clf, features, target):\n", 392 | " ''' 用训练好的分类器做预测并输出F1值'''\n", 393 | " \n", 394 | " # 开始计时,作出预测,然后停止计时\n", 395 | " start = time()\n", 396 | " y_pred = clf.predict(features)\n", 397 | " end = time()\n", 398 | " \n", 399 | " # 输出并返回结果\n", 400 | " print \"Made predictions in {:.4f} seconds.\".format(end - start)\n", 401 | " return f1_score(target.values, y_pred, pos_label='yes')\n", 402 | "\n", 403 | "\n", 404 | "def train_predict(clf, X_train, y_train, X_test, y_test):\n", 405 | " ''' 用一个分类器训练和预测,并输出F1值 '''\n", 406 | " \n", 407 | " # 输出分类器名称和训练集大小\n", 408 | " print \"Training a {} using a training set size of {}. . .\".format(clf.__class__.__name__, len(X_train))\n", 409 | " \n", 410 | " # 训练一个分类器\n", 411 | " train_classifier(clf, X_train, y_train)\n", 412 | " \n", 413 | " # 输出训练和测试的预测结果\n", 414 | " print \"F1 score for training set: {:.4f}.\".format(predict_labels(clf, X_train, y_train))\n", 415 | " print \"F1 score for test set: {:.4f}.\".format(predict_labels(clf, X_test, y_test))" 416 | ] 417 | }, 418 | { 419 | "cell_type": "markdown", 420 | "metadata": {}, 421 | "source": [ 422 | "### 练习: 模型评价指标\n", 423 | "借助于上面定义的函数,你现在需要导入三个你选择的监督学习模型,然后为每一个模型运行`train_predict`函数。请记住,对于每一个模型你需要在不同大小的训练集(100,200和300)上进行训练和测试。所以,你在下面应该会有9个不同的输出(每个模型都有训练集大小不同的三个输出)。在接下来的代码单元中,你将需要实现以下功能:\n", 424 | "- 引入三个你在上面讨论过的监督式学习算法模型。\n", 425 | "- 初始化三个模型并将它们存储在`clf_A`, `clf_B` 和 `clf_C`中。\n", 426 | " - 如果可能对每一个模型都设置一个`random_state`。\n", 427 | " - **注意:** 这里先使用每一个模型的默认参数,在接下来的部分中你将需要对某一个模型的参数进行调整。\n", 428 | "- 创建不同大小的训练集用来训练每一个模型。\n", 429 | " - *不要再混洗和再分割数据!新的训练集要取自`X_train`和`y_train`.*\n", 430 | "- 
对于每一个模型要用不同大小的训练集来训练它,然后在测试集上做测试(总共需要9次训练测试) \n", 431 | "**注意:** 在下面的代码单元后面我们提供了三个表用来存储你的结果。" 432 | ] 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": 46, 437 | "metadata": { 438 | "collapsed": false 439 | }, 440 | "outputs": [ 441 | { 442 | "name": "stdout", 443 | "output_type": "stream", 444 | "text": [ 445 | "Training a GaussianNB using a training set size of 100. . .\n", 446 | "Trained model in 0.0020 seconds\n", 447 | "Made predictions in 0.0000 seconds.\n", 448 | "F1 score for training set: 0.8467.\n", 449 | "Made predictions in 0.0010 seconds.\n", 450 | "F1 score for test set: 0.8029.\n", 451 | "Training a GaussianNB using a training set size of 200. . .\n", 452 | "Trained model in 0.0020 seconds\n", 453 | "Made predictions in 0.0010 seconds.\n", 454 | "F1 score for training set: 0.8406.\n", 455 | "Made predictions in 0.0010 seconds.\n", 456 | "F1 score for test set: 0.7244.\n", 457 | "Training a GaussianNB using a training set size of 300. . .\n", 458 | "Trained model in 0.0030 seconds\n", 459 | "Made predictions in 0.0020 seconds.\n", 460 | "F1 score for training set: 0.8038.\n", 461 | "Made predictions in 0.0010 seconds.\n", 462 | "F1 score for test set: 0.7634.\n", 463 | "Training a DecisionTreeClassifier using a training set size of 100. . .\n", 464 | "Trained model in 0.0020 seconds\n", 465 | "Made predictions in 0.0000 seconds.\n", 466 | "F1 score for training set: 1.0000.\n", 467 | "Made predictions in 0.0010 seconds.\n", 468 | "F1 score for test set: 0.6552.\n", 469 | "Training a DecisionTreeClassifier using a training set size of 200. . .\n", 470 | "Trained model in 0.0030 seconds\n", 471 | "Made predictions in 0.0000 seconds.\n", 472 | "F1 score for training set: 1.0000.\n", 473 | "Made predictions in 0.0000 seconds.\n", 474 | "F1 score for test set: 0.7500.\n", 475 | "Training a DecisionTreeClassifier using a training set size of 300. . 
.\n", 476 | "Trained model in 0.0040 seconds\n", 477 | "Made predictions in 0.0000 seconds.\n", 478 | "F1 score for training set: 1.0000.\n", 479 | "Made predictions in 0.0000 seconds.\n", 480 | "F1 score for test set: 0.6613.\n", 481 | "Training a AdaBoostClassifier using a training set size of 100. . .\n", 482 | "Trained model in 0.1430 seconds\n", 483 | "Made predictions in 0.0310 seconds.\n", 484 | "F1 score for training set: 0.9481.\n", 485 | "Made predictions in 0.0160 seconds.\n", 486 | "F1 score for test set: 0.7669.\n", 487 | "Training a AdaBoostClassifier using a training set size of 200. . .\n", 488 | "Trained model in 0.2020 seconds\n", 489 | "Made predictions in 0.0080 seconds.\n", 490 | "F1 score for training set: 0.8927.\n", 491 | "Made predictions in 0.0070 seconds.\n", 492 | "F1 score for test set: 0.8281.\n", 493 | "Training a AdaBoostClassifier using a training set size of 300. . .\n", 494 | "Trained model in 0.1660 seconds\n", 495 | "Made predictions in 0.0000 seconds.\n", 496 | "F1 score for training set: 0.8637.\n", 497 | "Made predictions in 0.0150 seconds.\n", 498 | "F1 score for test set: 0.7820.\n" 499 | ] 500 | } 501 | ], 502 | "source": [ 503 | "# TODO:从sklearn中引入三个监督学习模型\n", 504 | "from sklearn.naive_bayes import GaussianNB\n", 505 | "from sklearn.tree import DecisionTreeClassifier\n", 506 | "from sklearn.ensemble import AdaBoostClassifier\n", 507 | "\n", 508 | "# TODO:初始化三个模型\n", 509 | "clf_A = GaussianNB()\n", 510 | "clf_B = DecisionTreeClassifier(random_state=42)\n", 511 | "clf_C = AdaBoostClassifier(random_state=42)\n", 512 | "\n", 513 | "# TODO:设置训练集大小\n", 514 | "X_train_100 = X_train[0:100]\n", 515 | "y_train_100 = y_train[0:100]\n", 516 | "\n", 517 | "X_train_200 = X_train[100:300]\n", 518 | "y_train_200 = y_train[100:300]\n", 519 | "\n", 520 | "X_train_300 = X_train\n", 521 | "y_train_300 = y_train\n", 522 | "\n", 523 | "# TODO:对每一个分类器和每一个训练集大小运行'train_predict' \n", 524 | "for clf in [clf_A, clf_B, clf_C]:\n", 525 | " for Size 
in [100, 200, 300]:\n", 526 | " train_predict(clf, X_train[:Size], y_train[:Size], X_test, y_test)" 527 | ] 528 | }, 529 | { 530 | "cell_type": "markdown", 531 | "metadata": {}, 532 | "source": [ 533 | "### 结果表格\n", 534 | "编辑下面的表格看看在[Markdown](https://github.com/adam-p/markdown-here/wiki/Markdown-Cheatsheet#tables)中如何设计一个表格。你需要把上面的结果记录在表格中。" 535 | ] 536 | }, 537 | { 538 | "cell_type": "markdown", 539 | "metadata": {}, 540 | "source": [ 541 | "** 分类器 1 - GaussianNB** \n", 542 | "\n", 543 | "| 训练集大小 | 训练时间 | 预测时间 (测试) | F1值 (训练) | F1值 (测试) |\n", 544 | "| :---------------: | :---------------------: | :--------------------: | :--------------: | :-------------: |\n", 545 | "| 100 | 0.0020 | 0.0010 | 0.8467 | 0.8029 |\n", 546 | "| 200 | 0.0020 | 0.0010 | 0.8406 | 0.7244 |\n", 547 | "| 300 | 0.0030 | 0.0010 | 0.8038 | 0.7634 |\n", 548 | "\n", 549 | "** 分类器 2 - DecisionTree** \n", 550 | "\n", 551 | "| 训练集大小 | 训练时间 | 预测时间 (测试) | F1值 (训练) | F1值 (测试) |\n", 552 | "| :---------------: | :---------------------: | :--------------------: | :--------------: | :-------------: |\n", 553 | "| 100 | 0.0020 | 0.0010 | 1.0000 | 0.6552 |\n", 554 | "| 200 | 0.0030 | 0.0000 | 1.0000 | 0.7500 |\n", 555 | "| 300 | 0.0040 | 0.0000 | 1.0000 | 0.6613 |\n", 556 | "\n", 557 | "** 分类器 3 - AdaBoost** \n", 558 | "\n", 559 | "| 训练集大小 | 训练时间 | 预测时间 (测试) | F1值 (训练) | F1值 (测试) |\n", 560 | "| :---------------: | :---------------------: | :--------------------: | :--------------: | :-------------: |\n", 561 | "| 100 | 0.1430 | 0.0160 | 0.9481 | 0.7669 |\n", 562 | "| 200 | 0.2020 | 0.0070 | 0.8927 | 0.8281 |\n", 563 | "| 300 | 0.1660 | 0.0150 | 0.8637 | 0.7820 |" 564 | ] 565 | }, 566 | { 567 | "cell_type": "markdown", 568 | "metadata": {}, 569 | "source": [ 570 | "## 选择最佳模型\n", 571 | "在最后这一部分中,你将从三个监督学习模型中选择一个用在学生数据上的最佳模型。然后你将在最佳模型上用全部的训练集(`X_train`和`y_train`)运行一个网格搜索算法,在这个过程中,你要至少调整一个参数以提高模型的F1值(相比于没有调参的模型的分值有所提高)。 " 572 | ] 573 | }, 574 | { 575 | "cell_type": "markdown", 576 | "metadata": {}, 577 | "source": [ 578 | 
"### 问题 3 - 选择最佳模型\n", 579 | "*给予你上面做的实验,用一到两段话,向(学校)监事会解释你将选择哪个模型作为最佳的模型。哪个模型在现有的数据,有限的资源、开支和模型表现综合来看是最好的选择?*" 580 | ] 581 | }, 582 | { 583 | "cell_type": "markdown", 584 | "metadata": {}, 585 | "source": [ 586 | "**回答: **你好,既然你们让我来构建学生干预系统,那我就要对你们负责,所以我选择了3个模型来进行预测,分别是GaussianNB, DecisionTree, Adaboost.基于我对模型的评估,我认为AdaBoost是最好的模型。\n", 587 | "\n", 588 | "我的理由是AdaBoost的F1得分在三个模型中是最好的(虽然DecisionTree训练得分为1,可测试得分太低,出现过拟合),虽然它的训练时间较长,但考虑到它的预测时间短,也就是查询时间短,我们一旦把模型训练出来(考虑到数据集大小也花不了多少时间),之后的主要任务就只有查询了,并不会对过多消耗资源和开支,所以我还是使用AdaBoost" 589 | ] 590 | }, 591 | { 592 | "cell_type": "markdown", 593 | "metadata": {}, 594 | "source": [ 595 | "### 问题 4 - 用通俗的语言解释模型\n", 596 | "*用一到两段话,向(学校)监事会用外行也听得懂的话来解释最终模型是如何工作的。你需要解释所选模型的主要特点。例如,这个模型是怎样被训练的,它又是如何做出预测的。避免使用高级的数学或技术术语,不要使用公式或特定的算法名词。*" 597 | ] 598 | }, 599 | { 600 | "cell_type": "markdown", 601 | "metadata": {}, 602 | "source": [ 603 | "让我们以自己的数据为例,根据我们的数据集,让我们假设我们有10名老师,他们根据学生的相关情况,比如性别,年龄,家庭情况,来判断学生能否通过考试,第一次预测,10位老师每人一票投票他们是否能通过,然后第二次预测的时候,我们考虑到第一次投票时每位老师的准确度,可能一些老师侧重于家庭情况,一些老师觉得男生肯定比女生通过的多,然后对第一次预测时准确度高的老师,我们给他更多的票数,给他5张,而对准确度低的老师我们就不让他再参与投票了,这个就叫加权,这样反复进行多次预测,最终得到最好的预测组合,这就是AdaBoost了" 604 | ] 605 | }, 606 | { 607 | "cell_type": "markdown", 608 | "metadata": {}, 609 | "source": [ 610 | "### 练习: 模型调参\n", 611 | "细调选择的模型的参数。使用网格搜索(`GridSearchCV`)来至少调整模型的重要参数(至少调整一个),这个参数至少需给出并尝试3个不同的值。你要使用整个训练集来完成这个过程。在接下来的代码单元中,你需要实现以下功能:\n", 612 | "- 导入 [`sklearn.grid_search.gridSearchCV`](http://scikit-learn.org/stable/modules/generated/sklearn.grid_search.GridSearchCV.html) 和 [`sklearn.metrics.make_scorer`](http://scikit-learn.org/stable/modules/generated/sklearn.metrics.make_scorer.html).\n", 613 | "- 创建一个对于这个模型你希望调整参数的字典。\n", 614 | " - 例如: `parameters = {'parameter' : [list of values]}`。\n", 615 | "- 初始化你选择的分类器,并将其存储在`clf`中。\n", 616 | "- 使用`make_scorer` 创建F1评分函数并将其存储在`f1_scorer`中。\n", 617 | " - 需正确设定参数`pos_label`的值!\n", 618 | "- 在分类器`clf`上用`f1_scorer` 作为评价函数运行网格搜索,并将结果存储在`grid_obj`中。\n", 619 | "- 用训练集(`X_train`, `y_train`)训练grid search 
object,并将结果存储在`grid_obj`中。" 620 | ] 621 | }, 622 | { 623 | "cell_type": "code", 624 | "execution_count": 47, 625 | "metadata": { 626 | "collapsed": false 627 | }, 628 | "outputs": [ 629 | { 630 | "name": "stdout", 631 | "output_type": "stream", 632 | "text": [ 633 | "Made predictions in 0.0000 seconds.\n", 634 | "Tuned model has a training F1 score of 0.8299.\n", 635 | "Made predictions in 0.0000 seconds.\n", 636 | "Tuned model has a testing F1 score of 0.8000.\n" 637 | ] 638 | } 639 | ], 640 | "source": [ 641 | "# TODO: 导入 'GridSearchCV' 和 'make_scorer'\n", 642 | "\n", 643 | "from sklearn.grid_search import GridSearchCV\n", 644 | "from sklearn.metrics import make_scorer\n", 645 | "\n", 646 | "\n", 647 | "# TODO:创建你希望调整的参数列表\n", 648 | "parameters ={'n_estimators':[1,25,50,75,100]}\n", 649 | "\n", 650 | "# TODO:初始化分类器\n", 651 | "clf = AdaBoostClassifier()\n", 652 | "\n", 653 | "# TODO:用'make_scorer'创建一个f1评分函数\n", 654 | "f1_scorer = make_scorer(f1_score,pos_label='yes')\n", 655 | "\n", 656 | "# TODO:在分类器上使用f1_scorer作为评分函数运行网格搜索\n", 657 | "grid_obj = GridSearchCV(clf,parameters,f1_scorer)\n", 658 | "\n", 659 | "# TODO: Fit the grid search object to the training data and find the optimal parameters\n", 660 | "# TODO:用训练集训练grid search object来寻找最佳参数\n", 661 | "grid_obj = grid_obj.fit(X_train,y_train)\n", 662 | "\n", 663 | "# Get the estimator\n", 664 | "# 得到预测的结果\n", 665 | "clf = grid_obj.best_estimator_\n", 666 | "\n", 667 | "# Report the final F1 score for training and testing after parameter tuning\n", 668 | "# 输出经过调参之后的训练集和测试集的F1值\n", 669 | "print \"Tuned model has a training F1 score of {:.4f}.\".format(predict_labels(clf, X_train, y_train))\n", 670 | "print \"Tuned model has a testing F1 score of {:.4f}.\".format(predict_labels(clf, X_test, y_test))" 671 | ] 672 | }, 673 | { 674 | "cell_type": "markdown", 675 | "metadata": {}, 676 | "source": [ 677 | "### 问题 5 - 最终的 F1 值\n", 678 | "*最终模型的训练和测试的F1值是多少?这个值相比于没有调整过参数的模型怎么样?*" 679 | ] 680 | }, 681 | { 682 | 
"cell_type": "markdown", 683 | "metadata": {}, 684 | "source": [ 685 | "**回答: **最终模型的训练和测试的f1得分为0.83和0.80,相比之前的0.86和0.78,虽然训练得分下降了。但测试得分提升了,表现还是可以的" 686 | ] 687 | }, 688 | { 689 | "cell_type": "markdown", 690 | "metadata": {}, 691 | "source": [ 692 | "> **注意**: 当你写完了所有的代码,并且回答了所有的问题。你就可以把你的 iPython Notebook 导出成 HTML 文件。你可以在菜单栏,这样导出**File -> Download as -> HTML (.html)**把这个 HTML 和这个 iPython notebook 一起做为你的作业提交。 " 693 | ] 694 | } 695 | ], 696 | "metadata": { 697 | "anaconda-cloud": {}, 698 | "kernelspec": { 699 | "display_name": "Python [default]", 700 | "language": "python", 701 | "name": "python2" 702 | }, 703 | "language_info": { 704 | "codemirror_mode": { 705 | "name": "ipython", 706 | "version": 2 707 | }, 708 | "file_extension": ".py", 709 | "mimetype": "text/x-python", 710 | "name": "python", 711 | "nbconvert_exporter": "python", 712 | "pygments_lexer": "ipython2", 713 | "version": "2.7.12" 714 | } 715 | }, 716 | "nbformat": 4, 717 | "nbformat_minor": 0 718 | } 719 | -------------------------------------------------------------------------------- /student_intervention/student_intervention.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/student_intervention/student_intervention.zip -------------------------------------------------------------------------------- /student_intervention/student_intervention2.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/student_intervention/student_intervention2.zip -------------------------------------------------------------------------------- /titanic_survival_exploration/README.md: -------------------------------------------------------------------------------- 1 | # 项目 0: 入门与基础 2 | ## 预测泰坦尼克号乘客幸存率 3 | 4 | ### 
Installation Requirements 5 | This project requires **Python 2.7** and the following Python libraries: 6 | 7 | - [NumPy](http://www.numpy.org/) 8 | - [Pandas](http://pandas.pydata.org) 9 | - [matplotlib](http://matplotlib.org/) 10 | - [scikit-learn](http://scikit-learn.org/stable/) 11 | 12 | 13 | You will also need to install and run [Jupyter Notebook](http://jupyter.readthedocs.io/en/latest/install.html#optional-for-experienced-python-developers-installing-jupyter-with-pip). 14 | 15 | 16 | Udacity recommends installing [Anaconda](https://www.continuum.io/downloads), a Python distribution that includes all of the libraries and software needed for this project. [This page](https://classroom.udacity.com/nanodegrees/nd002/parts/0021345403/modules/317671873575460/lessons/5430778793/concepts/54140889150923) explains how to install Anaconda. 17 | 18 | If you use macOS and are comfortable with the command line, you can install [homebrew](http://brew.sh/) and the brew version of Python: 19 | 20 | ```bash 21 | $ brew install python 22 | ``` 23 | 24 | Then install the required Python libraries with the following command: 25 | 26 | ```bash 27 | $ pip install numpy pandas matplotlib scikit-learn scipy jupyter 28 | ``` 29 | 30 | ### Code 31 | 32 | The example code is in the `titanic_survival_exploration.ipynb` file, with helper code in the `titanic_visualizations.py` file. Although some code has been provided to help you get started, you will need to supply additional code to implement the functionality the project requires. 33 | 34 | ### Run 35 | 36 | In a terminal, make sure the current directory is the top level of the `titanic_survival_exploration/` folder (the one containing this README), then run the following command: 37 | 38 | ```bash 39 | $ jupyter notebook titanic_survival_exploration.ipynb 40 | ``` 41 | 42 | This will start Jupyter Notebook and open the project file in your browser. 43 | 44 | If you are new to Jupyter, these two links may help: 45 | 46 | - [Jupyter video tutorial](http://cn-static.udacity.com/mlnd/how_to_use_jupyter.mp4) 47 | - [Why use Jupyter?](https://www.zhihu.com/question/37490497) 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | ### Data 64 | 65 | The data for this project is contained in the `titanic_data.csv` file. The file includes the following features: 66 | 67 | - **Survived**: whether the passenger survived (0 for no, 1 for yes) 68 | - **Pclass**: socio-economic class (1 for upper class, 2 for middle class, 3 for lower class) 69 | - **Name**: the passenger's name 70 | - **Sex**: the passenger's sex 71 | - **Age**: the passenger's age (may contain `NaN`) 72 | - **SibSp**: number of siblings and spouses the passenger had aboard 73 | - **Parch**: number of parents and children the passenger had aboard 74 | - **Ticket**: the passenger's ticket number 75 | - **Fare**: the fare the passenger paid for the ticket 76 | - 
**Cabin**: the passenger's cabin number (may contain `NaN`) 77 | - **Embarked**: the passenger's port of embarkation (C for Cherbourg, Q for Queenstown, S for Southampton) 78 | -------------------------------------------------------------------------------- /titanic_survival_exploration/titanic_survival_exploration.zip: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/titanic_survival_exploration/titanic_survival_exploration.zip -------------------------------------------------------------------------------- /titanic_survival_exploration/titanic_visualizations.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import pandas as pd 3 | import matplotlib.pyplot as plt 4 | 5 | def filter_data(data, condition): 6 | """ 7 | Remove elements that do not match the condition provided. 8 | Takes a data list as input and returns a filtered list. 9 | Conditions should be a list of strings of the following format: 10 | '<field> <op> <value>' 11 | where the following operations are valid: >, <, >=, <=, ==, != 12 | 13 | Example: ["Sex == 'male'", 'Age < 18'] 14 | """ 15 | 16 | field, op, value = condition.split(" ") 17 | 18 | # convert value into number or strip excess quotes if string 19 | try: 20 | value = float(value) 21 | except: 22 | value = value.strip("\'\"") 23 | 24 | # get booleans for filtering 25 | if op == ">": 26 | matches = data[field] > value 27 | elif op == "<": 28 | matches = data[field] < value 29 | elif op == ">=": 30 | matches = data[field] >= value 31 | elif op == "<=": 32 | matches = data[field] <= value 33 | elif op == "==": 34 | matches = data[field] == value 35 | elif op == "!=": 36 | matches = data[field] != value 37 | else: # catch invalid operation codes 38 | raise Exception("Invalid comparison operator. 
Only >, <, >=, <=, ==, != allowed.") 39 | 40 | # filter data and outcomes 41 | data = data[matches].reset_index(drop = True) 42 | return data 43 | 44 | def survival_stats(data, outcomes, key, filters = []): 45 | """ 46 | Print out selected statistics regarding survival, given a feature of 47 | interest and any number of filters (including no filters) 48 | """ 49 | 50 | # Check that the key exists 51 | if key not in data.columns.values : 52 | print "'{}' is not a feature of the Titanic data. Did you spell something wrong?".format(key) 53 | return False 54 | 55 | # Return the function before visualizing if 'Cabin' or 'Ticket' 56 | # is selected: too many unique categories to display 57 | if(key == 'Cabin' or key == 'PassengerId' or key == 'Ticket'): 58 | print "'{}' has too many unique categories to display! Try a different feature.".format(key) 59 | return False 60 | 61 | # Merge data and outcomes into single dataframe 62 | all_data = pd.concat([data, outcomes], axis = 1) 63 | 64 | # Apply filters to data 65 | for condition in filters: 66 | all_data = filter_data(all_data, condition) 67 | 68 | # Create outcomes DataFrame 69 | all_data = all_data[[key, 'Survived']] 70 | 71 | # Create plotting figure 72 | plt.figure(figsize=(8,6)) 73 | 74 | # 'Numerical' features 75 | if(key == 'Age' or key == 'Fare'): 76 | 77 | # Remove NaN values from Age data 78 | all_data = all_data[~np.isnan(all_data[key])] 79 | 80 | # Divide the range of data into bins and count survival rates 81 | min_value = all_data[key].min() 82 | max_value = all_data[key].max() 83 | value_range = max_value - min_value 84 | 85 | # 'Fares' has larger range of values than 'Age' so create more bins 86 | if(key == 'Fare'): 87 | bins = np.arange(0, all_data['Fare'].max() + 20, 20) 88 | if(key == 'Age'): 89 | bins = np.arange(0, all_data['Age'].max() + 10, 10) 90 | 91 | # Overlay each bin's survival rates 92 | nonsurv_vals = all_data[all_data['Survived'] == 0][key].reset_index(drop = True) 93 | surv_vals = 
all_data[all_data['Survived'] == 1][key].reset_index(drop = True) 94 | plt.hist(nonsurv_vals, bins = bins, alpha = 0.6, 95 | color = 'red', label = 'Did not survive') 96 | plt.hist(surv_vals, bins = bins, alpha = 0.6, 97 | color = 'green', label = 'Survived') 98 | 99 | # Add legend to plot 100 | plt.xlim(0, bins.max()) 101 | plt.legend(framealpha = 0.8) 102 | 103 | # 'Categorical' features 104 | else: 105 | 106 | # Set the various categories 107 | if(key == 'Pclass'): 108 | values = np.arange(1,4) 109 | if(key == 'Parch' or key == 'SibSp'): 110 | values = np.arange(0,np.max(data[key]) + 1) 111 | if(key == 'Embarked'): 112 | values = ['C', 'Q', 'S'] 113 | if(key == 'Sex'): 114 | values = ['male', 'female'] 115 | 116 | # Create DataFrame containing categories and count of each 117 | frame = pd.DataFrame(index = np.arange(len(values)), columns=(key,'Survived','NSurvived')) 118 | for i, value in enumerate(values): 119 | frame.loc[i] = [value, \ 120 | len(all_data[(all_data['Survived'] == 1) & (all_data[key] == value)]), \ 121 | len(all_data[(all_data['Survived'] == 0) & (all_data[key] == value)])] 122 | 123 | # Set the width of each bar 124 | bar_width = 0.4 125 | 126 | # Display each category's survival rates 127 | for i in np.arange(len(frame)): 128 | nonsurv_bar = plt.bar(i-bar_width, frame.loc[i]['NSurvived'], width = bar_width, color = 'r') 129 | surv_bar = plt.bar(i, frame.loc[i]['Survived'], width = bar_width, color = 'g') 130 | 131 | plt.xticks(np.arange(len(frame)), values) 132 | plt.legend((nonsurv_bar[0], surv_bar[0]),('Did not survive', 'Survived'), framealpha = 0.8) 133 | 134 | # Common attributes for plot formatting 135 | plt.xlabel(key) 136 | plt.ylabel('Number of Passengers') 137 | plt.title('Passenger Survival Statistics With \'%s\' Feature'%(key)) 138 | plt.show() 139 | 140 | # Report number of passengers with missing values 141 | if sum(pd.isnull(all_data[key])): 142 | nan_outcomes = all_data[pd.isnull(all_data[key])]['Survived'] 143 | print 
"Passengers with missing '{}' values: {} ({} survived, {} did not survive)".format( \ 144 | key, len(nan_outcomes), sum(nan_outcomes == 1), sum(nan_outcomes == 0)) 145 | 146 | --------------------------------------------------------------------------------