├── README.md
├── boston_housing
│   ├── README.md
│   ├── bj_housing.csv
│   ├── boston_housing.ipynb
│   ├── boston_housing.zip
│   ├── housing.csv
│   └── visuals.py
├── creating_customer_segments
│   ├── README.md
│   ├── customer_segments.ipynb
│   ├── customer_segments.zip
│   ├── customers.csv
│   └── visuals.py
├── digit_recognition
│   ├── README.md
│   ├── digit_recognition.ipynb
│   ├── model.png
│   └── project_description.md
├── dog_vs_cat
│   ├── README.md
│   ├── dog_vs_cat (4).ipynb
│   ├── file (1).csv
│   ├── file (10).csv
│   ├── file (11).csv
│   ├── proposal.pdf
│   └── report2.pdf
├── finding_donors
│   ├── README.md
│   ├── census.csv
│   ├── finding_donors.ipynb
│   ├── finding_donors.zip
│   ├── project_description.md
│   └── visuals.py
├── smartcab
│   ├── README.md
│   ├── agent.py
│   ├── images
│   │   ├── car-black.png
│   │   ├── car-blue.png
│   │   ├── car-cyan.png
│   │   ├── car-green.png
│   │   ├── car-magenta.png
│   │   ├── car-orange.png
│   │   ├── car-red.png
│   │   ├── car-white.png
│   │   ├── car-yellow.png
│   │   ├── east-west.png
│   │   ├── logo.png
│   │   └── north-south.png
│   ├── logs
│   │   ├── sim_default-learning.csv
│   │   ├── sim_default-learning.txt
│   │   ├── sim_improved-learning.csv
│   │   ├── sim_improved-learning.txt
│   │   └── sim_no-learning.csv
│   ├── project_description.md
│   ├── smartcab.ipynb
│   ├── smartcab.zip
│   ├── smartcab
│   │   ├── agent.py
│   │   ├── environment.py
│   │   ├── environment.pyc
│   │   ├── planner.py
│   │   ├── planner.pyc
│   │   ├── simulator.py
│   │   └── simulator.pyc
│   ├── smartcab2.zip
│   ├── smartcab3.zip
│   └── visuals.py
├── student_intervention
│   ├── README.md
│   ├── student-data.csv
│   ├── student_intervention.ipynb
│   ├── student_intervention.zip
│   └── student_intervention2.zip
└── titanic_survival_exploration
    ├── README.md
    ├── titanic_data.csv
    ├── titanic_survival_exploration.ipynb
    ├── titanic_survival_exploration.zip
    └── titanic_visualizations.py
/README.md:
--------------------------------------------------------------------------------
1 | # Machine Learning & Deep Learning Projects
2 |
--------------------------------------------------------------------------------
/boston_housing/README.md:
--------------------------------------------------------------------------------
1 | # Project 1: Model Evaluation and Validation
2 | ## Predicting Boston Housing Prices
3 |
4 | ### Setup
5 |
6 | This project requires **Python 2.7** and the following Python libraries:
7 |
8 | - [NumPy](http://www.numpy.org/)
9 | - [matplotlib](http://matplotlib.org/)
10 | - [scikit-learn](http://scikit-learn.org/stable/)
11 |
12 | You will also need software installed to run and edit [ipynb](http://jupyter.org/) files.
13 |
14 | Udacity recommends that students install [Anaconda](https://www.continuum.io/downloads), a popular pre-packaged Python distribution that already contains all of the libraries needed for this project. Project P0 also walks through [how to set up your environment](https://github.com/udacity/machine-learning/blob/master/projects_cn/titanic_survival_exploration/README.md).
15 |
16 | ### Code
17 |
18 | Template code is provided in `boston_housing.ipynb`. You will also need `visuals.py` and the dataset file `housing.csv` to complete this project. Some of the code is already implemented for you, but you will need to implement additional functionality to complete the project.
19 |
20 | ### Run
21 |
22 | In a terminal or command window, navigate to the `boston_housing/` directory (the one containing this README) and run the following command:
23 |
24 | ```jupyter notebook boston_housing.ipynb```
25 |
26 | This will launch Jupyter Notebook and open the project file in your browser.
27 |
28 | ### Data
29 |
30 | The modified Boston housing dataset consists of 490 data points, each with three features. It is an edited version of the dataset found in the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Housing).
31 |
32 | **Features**
33 |
34 | 1. `RM`: average number of rooms per dwelling
35 | 2. `LSTAT`: percentage of residents considered lower class
36 | 3. `PTRATIO`: pupil-teacher ratio by town
37 |
38 | **Target Variable**
39 |
40 | 4. `MEDV`: median value of homes
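
As an illustration of the dataset's layout, the sketch below loads a few sample rows (copied verbatim from `housing.csv`) and separates the `MEDV` target from the three features. In the actual project you would call `pd.read_csv('housing.csv')` on the full file; the inline rows here only keep the example self-contained, and it assumes pandas is available (the notebook uses it even though it is not listed among the libraries above).

```python
import pandas as pd
from io import StringIO

# A few rows in the same format as housing.csv
# (RM, LSTAT, PTRATIO are features; MEDV is the target)
csv_text = """RM,LSTAT,PTRATIO,MEDV
6.575,4.98,15.3,504000.0
6.421,9.14,17.8,453600.0
7.185,4.03,17.8,728700.0
"""
data = pd.read_csv(StringIO(csv_text))

# Separate the target variable from the features
prices = data['MEDV']
features = data.drop('MEDV', axis=1)

print("Dataset has {} points with {} variables each.".format(*data.shape))
print("Median price: ${:,.2f}".format(prices.median()))
```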
--------------------------------------------------------------------------------
/boston_housing/boston_housing.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/boston_housing/boston_housing.zip
--------------------------------------------------------------------------------
/boston_housing/housing.csv:
--------------------------------------------------------------------------------
1 | RM,LSTAT,PTRATIO,MEDV
2 | 6.575,4.98,15.3,504000.0
3 | 6.421,9.14,17.8,453600.0
4 | 7.185,4.03,17.8,728700.0
5 | 6.998,2.94,18.7,701400.0
6 | 7.147,5.33,18.7,760200.0
7 | 6.43,5.21,18.7,602700.0
8 | 6.012,12.43,15.2,480900.0
9 | 6.172,19.15,15.2,569100.0
10 | 5.631,29.93,15.2,346500.0
11 | 6.004,17.1,15.2,396900.0
12 | 6.377,20.45,15.2,315000.0
13 | 6.009,13.27,15.2,396900.0
14 | 5.889,15.71,15.2,455700.0
15 | 5.949,8.26,21.0,428400.0
16 | 6.096,10.26,21.0,382200.0
17 | 5.834,8.47,21.0,417900.0
18 | 5.935,6.58,21.0,485100.0
19 | 5.99,14.67,21.0,367500.0
20 | 5.456,11.69,21.0,424200.0
21 | 5.727,11.28,21.0,382200.0
22 | 5.57,21.02,21.0,285600.0
23 | 5.965,13.83,21.0,411600.0
24 | 6.142,18.72,21.0,319200.0
25 | 5.813,19.88,21.0,304500.0
26 | 5.924,16.3,21.0,327600.0
27 | 5.599,16.51,21.0,291900.0
28 | 5.813,14.81,21.0,348600.0
29 | 6.047,17.28,21.0,310800.0
30 | 6.495,12.8,21.0,386400.0
31 | 6.674,11.98,21.0,441000.0
32 | 5.713,22.6,21.0,266700.0
33 | 6.072,13.04,21.0,304500.0
34 | 5.95,27.71,21.0,277200.0
35 | 5.701,18.35,21.0,275100.0
36 | 6.096,20.34,21.0,283500.0
37 | 5.933,9.68,19.2,396900.0
38 | 5.841,11.41,19.2,420000.0
39 | 5.85,8.77,19.2,441000.0
40 | 5.966,10.13,19.2,518700.0
41 | 6.595,4.32,18.3,646800.0
42 | 7.024,1.98,18.3,732900.0
43 | 6.77,4.84,17.9,558600.0
44 | 6.169,5.81,17.9,531300.0
45 | 6.211,7.44,17.9,518700.0
46 | 6.069,9.55,17.9,445200.0
47 | 5.682,10.21,17.9,405300.0
48 | 5.786,14.15,17.9,420000.0
49 | 6.03,18.8,17.9,348600.0
50 | 5.399,30.81,17.9,302400.0
51 | 5.602,16.2,17.9,407400.0
52 | 5.963,13.45,16.8,413700.0
53 | 6.115,9.43,16.8,430500.0
54 | 6.511,5.28,16.8,525000.0
55 | 5.998,8.43,16.8,491400.0
56 | 5.888,14.8,21.1,396900.0
57 | 7.249,4.81,17.9,743400.0
58 | 6.383,5.77,17.3,518700.0
59 | 6.816,3.95,15.1,663600.0
60 | 6.145,6.86,19.7,489300.0
61 | 5.927,9.22,19.7,411600.0
62 | 5.741,13.15,19.7,392700.0
63 | 5.966,14.44,19.7,336000.0
64 | 6.456,6.73,19.7,466200.0
65 | 6.762,9.5,19.7,525000.0
66 | 7.104,8.05,18.6,693000.0
67 | 6.29,4.67,16.1,493500.0
68 | 5.787,10.24,16.1,407400.0
69 | 5.878,8.1,18.9,462000.0
70 | 5.594,13.09,18.9,365400.0
71 | 5.885,8.79,18.9,438900.0
72 | 6.417,6.72,19.2,508200.0
73 | 5.961,9.88,19.2,455700.0
74 | 6.065,5.52,19.2,478800.0
75 | 6.245,7.54,19.2,491400.0
76 | 6.273,6.78,18.7,506100.0
77 | 6.286,8.94,18.7,449400.0
78 | 6.279,11.97,18.7,420000.0
79 | 6.14,10.27,18.7,436800.0
80 | 6.232,12.34,18.7,445200.0
81 | 5.874,9.1,18.7,426300.0
82 | 6.727,5.29,19.0,588000.0
83 | 6.619,7.22,19.0,501900.0
84 | 6.302,6.72,19.0,520800.0
85 | 6.167,7.51,19.0,480900.0
86 | 6.389,9.62,18.5,501900.0
87 | 6.63,6.53,18.5,558600.0
88 | 6.015,12.86,18.5,472500.0
89 | 6.121,8.44,18.5,466200.0
90 | 7.007,5.5,17.8,495600.0
91 | 7.079,5.7,17.8,602700.0
92 | 6.417,8.81,17.8,474600.0
93 | 6.405,8.2,17.8,462000.0
94 | 6.442,8.16,18.2,480900.0
95 | 6.211,6.21,18.2,525000.0
96 | 6.249,10.59,18.2,432600.0
97 | 6.625,6.65,18.0,596400.0
98 | 6.163,11.34,18.0,449400.0
99 | 8.069,4.21,18.0,812700.0
100 | 7.82,3.57,18.0,919800.0
101 | 7.416,6.19,18.0,697200.0
102 | 6.727,9.42,20.9,577500.0
103 | 6.781,7.67,20.9,556500.0
104 | 6.405,10.63,20.9,390600.0
105 | 6.137,13.44,20.9,405300.0
106 | 6.167,12.33,20.9,422100.0
107 | 5.851,16.47,20.9,409500.0
108 | 5.836,18.66,20.9,409500.0
109 | 6.127,14.09,20.9,428400.0
110 | 6.474,12.27,20.9,415800.0
111 | 6.229,15.55,20.9,407400.0
112 | 6.195,13.0,20.9,455700.0
113 | 6.715,10.16,17.8,478800.0
114 | 5.913,16.21,17.8,394800.0
115 | 6.092,17.09,17.8,392700.0
116 | 6.254,10.45,17.8,388500.0
117 | 5.928,15.76,17.8,384300.0
118 | 6.176,12.04,17.8,445200.0
119 | 6.021,10.3,17.8,403200.0
120 | 5.872,15.37,17.8,428400.0
121 | 5.731,13.61,17.8,405300.0
122 | 5.87,14.37,19.1,462000.0
123 | 6.004,14.27,19.1,426300.0
124 | 5.961,17.93,19.1,430500.0
125 | 5.856,25.41,19.1,363300.0
126 | 5.879,17.58,19.1,394800.0
127 | 5.986,14.81,19.1,449400.0
128 | 5.613,27.26,19.1,329700.0
129 | 5.693,17.19,21.2,340200.0
130 | 6.431,15.39,21.2,378000.0
131 | 5.637,18.34,21.2,300300.0
132 | 6.458,12.6,21.2,403200.0
133 | 6.326,12.26,21.2,411600.0
134 | 6.372,11.12,21.2,483000.0
135 | 5.822,15.03,21.2,386400.0
136 | 5.757,17.31,21.2,327600.0
137 | 6.335,16.96,21.2,380100.0
138 | 5.942,16.9,21.2,365400.0
139 | 6.454,14.59,21.2,359100.0
140 | 5.857,21.32,21.2,279300.0
141 | 6.151,18.46,21.2,373800.0
142 | 6.174,24.16,21.2,294000.0
143 | 5.019,34.41,21.2,302400.0
144 | 5.403,26.82,14.7,281400.0
145 | 5.468,26.42,14.7,327600.0
146 | 4.903,29.29,14.7,247800.0
147 | 6.13,27.8,14.7,289800.0
148 | 5.628,16.65,14.7,327600.0
149 | 4.926,29.53,14.7,306600.0
150 | 5.186,28.32,14.7,373800.0
151 | 5.597,21.45,14.7,323400.0
152 | 6.122,14.1,14.7,451500.0
153 | 5.404,13.28,14.7,411600.0
154 | 5.012,12.12,14.7,321300.0
155 | 5.709,15.79,14.7,407400.0
156 | 6.129,15.12,14.7,357000.0
157 | 6.152,15.02,14.7,327600.0
158 | 5.272,16.14,14.7,275100.0
159 | 6.943,4.59,14.7,867300.0
160 | 6.066,6.43,14.7,510300.0
161 | 6.51,7.39,14.7,489300.0
162 | 6.25,5.5,14.7,567000.0
163 | 5.854,11.64,14.7,476700.0
164 | 6.101,9.81,14.7,525000.0
165 | 5.877,12.14,14.7,499800.0
166 | 6.319,11.1,14.7,499800.0
167 | 6.402,11.32,14.7,468300.0
168 | 5.875,14.43,14.7,365400.0
169 | 5.88,12.03,14.7,401100.0
170 | 5.572,14.69,16.6,485100.0
171 | 6.416,9.04,16.6,495600.0
172 | 5.859,9.64,16.6,474600.0
173 | 6.546,5.33,16.6,617400.0
174 | 6.02,10.11,16.6,487200.0
175 | 6.315,6.29,16.6,516600.0
176 | 6.86,6.92,16.6,627900.0
177 | 6.98,5.04,17.8,781200.0
178 | 7.765,7.56,17.8,835800.0
179 | 6.144,9.45,17.8,760200.0
180 | 7.155,4.82,17.8,795900.0
181 | 6.563,5.68,17.8,682500.0
182 | 5.604,13.98,17.8,554400.0
183 | 6.153,13.15,17.8,621600.0
184 | 6.782,6.68,15.2,672000.0
185 | 6.556,4.56,15.2,625800.0
186 | 7.185,5.39,15.2,732900.0
187 | 6.951,5.1,15.2,777000.0
188 | 6.739,4.69,15.2,640500.0
189 | 7.178,2.87,15.2,764400.0
190 | 6.8,5.03,15.6,653100.0
191 | 6.604,4.38,15.6,611100.0
192 | 7.287,4.08,12.6,699300.0
193 | 7.107,8.61,12.6,636300.0
194 | 7.274,6.62,12.6,726600.0
195 | 6.975,4.56,17.0,732900.0
196 | 7.135,4.45,17.0,690900.0
197 | 6.162,7.43,14.7,506100.0
198 | 7.61,3.11,14.7,888300.0
199 | 7.853,3.81,14.7,1018500.0
200 | 5.891,10.87,18.6,474600.0
201 | 6.326,10.97,18.6,512400.0
202 | 5.783,18.06,18.6,472500.0
203 | 6.064,14.66,18.6,512400.0
204 | 5.344,23.09,18.6,420000.0
205 | 5.96,17.27,18.6,455700.0
206 | 5.404,23.98,18.6,405300.0
207 | 5.807,16.03,18.6,470400.0
208 | 6.375,9.38,18.6,590100.0
209 | 5.412,29.55,18.6,497700.0
210 | 6.182,9.47,18.6,525000.0
211 | 5.888,13.51,16.4,489300.0
212 | 6.642,9.69,16.4,602700.0
213 | 5.951,17.92,16.4,451500.0
214 | 6.373,10.5,16.4,483000.0
215 | 6.951,9.71,17.4,560700.0
216 | 6.164,21.46,17.4,455700.0
217 | 6.879,9.93,17.4,577500.0
218 | 6.618,7.6,17.4,632100.0
219 | 8.266,4.14,17.4,940800.0
220 | 8.04,3.13,17.4,789600.0
221 | 7.163,6.36,17.4,663600.0
222 | 7.686,3.92,17.4,980700.0
223 | 6.552,3.76,17.4,661500.0
224 | 5.981,11.65,17.4,510300.0
225 | 7.412,5.25,17.4,665700.0
226 | 8.337,2.47,17.4,875700.0
227 | 8.247,3.95,17.4,1014300.0
228 | 6.726,8.05,17.4,609000.0
229 | 6.086,10.88,17.4,504000.0
230 | 6.631,9.54,17.4,527100.0
231 | 7.358,4.73,17.4,661500.0
232 | 6.481,6.36,16.6,497700.0
233 | 6.606,7.37,16.6,489300.0
234 | 6.897,11.38,16.6,462000.0
235 | 6.095,12.4,16.6,422100.0
236 | 6.358,11.22,16.6,466200.0
237 | 6.393,5.19,16.6,497700.0
238 | 5.593,12.5,19.1,369600.0
239 | 5.605,18.46,19.1,388500.0
240 | 6.108,9.16,19.1,510300.0
241 | 6.226,10.15,19.1,430500.0
242 | 6.433,9.52,19.1,514500.0
243 | 6.718,6.56,19.1,550200.0
244 | 6.487,5.9,19.1,512400.0
245 | 6.438,3.59,19.1,520800.0
246 | 6.957,3.53,19.1,621600.0
247 | 8.259,3.54,19.1,898800.0
248 | 6.108,6.57,16.4,459900.0
249 | 5.876,9.25,16.4,438900.0
250 | 7.454,3.11,15.9,924000.0
251 | 7.333,7.79,13.0,756000.0
252 | 6.842,6.9,13.0,632100.0
253 | 7.203,9.59,13.0,709800.0
254 | 7.52,7.26,13.0,905100.0
255 | 8.398,5.91,13.0,1024800.0
256 | 7.327,11.25,13.0,651000.0
257 | 7.206,8.1,13.0,766500.0
258 | 5.56,10.45,13.0,478800.0
259 | 7.014,14.79,13.0,644700.0
260 | 7.47,3.16,13.0,913500.0
261 | 5.92,13.65,18.6,434700.0
262 | 5.856,13.0,18.6,443100.0
263 | 6.24,6.59,18.6,529200.0
264 | 6.538,7.73,18.6,512400.0
265 | 7.691,6.58,18.6,739200.0
266 | 6.758,3.53,17.6,680400.0
267 | 6.854,2.98,17.6,672000.0
268 | 7.267,6.05,17.6,697200.0
269 | 6.826,4.16,17.6,695100.0
270 | 6.482,7.19,17.6,611100.0
271 | 6.812,4.85,14.9,737100.0
272 | 7.82,3.76,14.9,953400.0
273 | 6.968,4.59,14.9,743400.0
274 | 7.645,3.01,14.9,966000.0
275 | 7.088,7.85,15.3,676200.0
276 | 6.453,8.23,15.3,462000.0
277 | 6.23,12.93,18.2,422100.0
278 | 6.209,7.14,16.6,487200.0
279 | 6.315,7.6,16.6,468300.0
280 | 6.565,9.51,16.6,520800.0
281 | 6.861,3.33,19.2,598500.0
282 | 7.148,3.56,19.2,783300.0
283 | 6.63,4.7,19.2,585900.0
284 | 6.127,8.58,16.0,501900.0
285 | 6.009,10.4,16.0,455700.0
286 | 6.678,6.27,16.0,600600.0
287 | 6.549,7.39,16.0,569100.0
288 | 5.79,15.84,16.0,426300.0
289 | 6.345,4.97,14.8,472500.0
290 | 7.041,4.74,14.8,609000.0
291 | 6.871,6.07,14.8,520800.0
292 | 6.59,9.5,16.1,462000.0
293 | 6.495,8.67,16.1,554400.0
294 | 6.982,4.86,16.1,695100.0
295 | 7.236,6.93,18.4,758100.0
296 | 6.616,8.93,18.4,596400.0
297 | 7.42,6.47,18.4,701400.0
298 | 6.849,7.53,18.4,592200.0
299 | 6.635,4.54,18.4,478800.0
300 | 5.972,9.97,18.4,426300.0
301 | 4.973,12.64,18.4,338100.0
302 | 6.122,5.98,18.4,464100.0
303 | 6.023,11.72,18.4,407400.0
304 | 6.266,7.9,18.4,453600.0
305 | 6.567,9.28,18.4,499800.0
306 | 5.705,11.5,18.4,340200.0
307 | 5.914,18.33,18.4,373800.0
308 | 5.782,15.94,18.4,415800.0
309 | 6.382,10.36,18.4,485100.0
310 | 6.113,12.73,18.4,441000.0
311 | 6.426,7.2,19.6,499800.0
312 | 6.376,6.87,19.6,485100.0
313 | 6.041,7.7,19.6,428400.0
314 | 5.708,11.74,19.6,388500.0
315 | 6.415,6.12,19.6,525000.0
316 | 6.431,5.08,19.6,516600.0
317 | 6.312,6.15,19.6,483000.0
318 | 6.083,12.79,19.6,466200.0
319 | 5.868,9.97,16.9,405300.0
320 | 6.333,7.34,16.9,474600.0
321 | 6.144,9.09,16.9,415800.0
322 | 5.706,12.43,16.9,359100.0
323 | 6.031,7.83,16.9,407400.0
324 | 6.316,5.68,20.2,466200.0
325 | 6.31,6.75,20.2,434700.0
326 | 6.037,8.01,20.2,443100.0
327 | 5.869,9.8,20.2,409500.0
328 | 5.895,10.56,20.2,388500.0
329 | 6.059,8.51,20.2,432600.0
330 | 5.985,9.74,20.2,399000.0
331 | 5.968,9.29,20.2,392700.0
332 | 7.241,5.49,15.5,686700.0
333 | 6.54,8.65,15.9,346500.0
334 | 6.696,7.18,17.6,501900.0
335 | 6.874,4.61,17.6,655200.0
336 | 6.014,10.53,18.8,367500.0
337 | 5.898,12.67,18.8,361200.0
338 | 6.516,6.36,17.9,485100.0
339 | 6.635,5.99,17.0,514500.0
340 | 6.939,5.89,19.7,558600.0
341 | 6.49,5.98,19.7,480900.0
342 | 6.579,5.49,18.3,506100.0
343 | 5.884,7.79,18.3,390600.0
344 | 6.728,4.5,17.0,632100.0
345 | 5.663,8.05,22.0,382200.0
346 | 5.936,5.57,22.0,432600.0
347 | 6.212,17.6,20.2,373800.0
348 | 6.395,13.27,20.2,455700.0
349 | 6.127,11.48,20.2,476700.0
350 | 6.112,12.67,20.2,474600.0
351 | 6.398,7.79,20.2,525000.0
352 | 6.251,14.19,20.2,417900.0
353 | 5.362,10.19,20.2,436800.0
354 | 5.803,14.64,20.2,352800.0
355 | 3.561,7.12,20.2,577500.0
356 | 4.963,14.0,20.2,459900.0
357 | 3.863,13.33,20.2,485100.0
358 | 4.906,34.77,20.2,289800.0
359 | 4.138,37.97,20.2,289800.0
360 | 7.313,13.44,20.2,315000.0
361 | 6.649,23.24,20.2,291900.0
362 | 6.794,21.24,20.2,279300.0
363 | 6.38,23.69,20.2,275100.0
364 | 6.223,21.78,20.2,214200.0
365 | 6.968,17.21,20.2,218400.0
366 | 6.545,21.08,20.2,228900.0
367 | 5.536,23.6,20.2,237300.0
368 | 5.52,24.56,20.2,258300.0
369 | 4.368,30.63,20.2,184800.0
370 | 5.277,30.81,20.2,151200.0
371 | 4.652,28.28,20.2,220500.0
372 | 5.0,31.99,20.2,155400.0
373 | 4.88,30.62,20.2,214200.0
374 | 5.39,20.85,20.2,241500.0
375 | 5.713,17.11,20.2,317100.0
376 | 6.051,18.76,20.2,487200.0
377 | 5.036,25.68,20.2,203700.0
378 | 6.193,15.17,20.2,289800.0
379 | 5.887,16.35,20.2,266700.0
380 | 6.471,17.12,20.2,275100.0
381 | 6.405,19.37,20.2,262500.0
382 | 5.747,19.92,20.2,178500.0
383 | 5.453,30.59,20.2,105000.0
384 | 5.852,29.97,20.2,132300.0
385 | 5.987,26.77,20.2,117600.0
386 | 6.343,20.32,20.2,151200.0
387 | 6.404,20.31,20.2,254100.0
388 | 5.349,19.77,20.2,174300.0
389 | 5.531,27.38,20.2,178500.0
390 | 5.683,22.98,20.2,105000.0
391 | 4.138,23.34,20.2,249900.0
392 | 5.608,12.13,20.2,585900.0
393 | 5.617,26.4,20.2,361200.0
394 | 6.852,19.78,20.2,577500.0
395 | 5.757,10.11,20.2,315000.0
396 | 6.657,21.22,20.2,361200.0
397 | 4.628,34.37,20.2,375900.0
398 | 5.155,20.08,20.2,342300.0
399 | 4.519,36.98,20.2,147000.0
400 | 6.434,29.05,20.2,151200.0
401 | 6.782,25.79,20.2,157500.0
402 | 5.304,26.64,20.2,218400.0
403 | 5.957,20.62,20.2,184800.0
404 | 6.824,22.74,20.2,176400.0
405 | 6.411,15.02,20.2,350700.0
406 | 6.006,15.7,20.2,298200.0
407 | 5.648,14.1,20.2,436800.0
408 | 6.103,23.29,20.2,281400.0
409 | 5.565,17.16,20.2,245700.0
410 | 5.896,24.39,20.2,174300.0
411 | 5.837,15.69,20.2,214200.0
412 | 6.202,14.52,20.2,228900.0
413 | 6.193,21.52,20.2,231000.0
414 | 6.38,24.08,20.2,199500.0
415 | 6.348,17.64,20.2,304500.0
416 | 6.833,19.69,20.2,296100.0
417 | 6.425,12.03,20.2,338100.0
418 | 6.436,16.22,20.2,300300.0
419 | 6.208,15.17,20.2,245700.0
420 | 6.629,23.27,20.2,281400.0
421 | 6.461,18.05,20.2,201600.0
422 | 6.152,26.45,20.2,182700.0
423 | 5.935,34.02,20.2,176400.0
424 | 5.627,22.88,20.2,268800.0
425 | 5.818,22.11,20.2,220500.0
426 | 6.406,19.52,20.2,359100.0
427 | 6.219,16.59,20.2,386400.0
428 | 6.485,18.85,20.2,323400.0
429 | 5.854,23.79,20.2,226800.0
430 | 6.459,23.98,20.2,247800.0
431 | 6.341,17.79,20.2,312900.0
432 | 6.251,16.44,20.2,264600.0
433 | 6.185,18.13,20.2,296100.0
434 | 6.417,19.31,20.2,273000.0
435 | 6.749,17.44,20.2,281400.0
436 | 6.655,17.73,20.2,319200.0
437 | 6.297,17.27,20.2,338100.0
438 | 7.393,16.74,20.2,373800.0
439 | 6.728,18.71,20.2,312900.0
440 | 6.525,18.13,20.2,296100.0
441 | 5.976,19.01,20.2,266700.0
442 | 5.936,16.94,20.2,283500.0
443 | 6.301,16.23,20.2,312900.0
444 | 6.081,14.7,20.2,420000.0
445 | 6.701,16.42,20.2,344400.0
446 | 6.376,14.65,20.2,371700.0
447 | 6.317,13.99,20.2,409500.0
448 | 6.513,10.29,20.2,424200.0
449 | 6.209,13.22,20.2,449400.0
450 | 5.759,14.13,20.2,417900.0
451 | 5.952,17.15,20.2,399000.0
452 | 6.003,21.32,20.2,401100.0
453 | 5.926,18.13,20.2,401100.0
454 | 5.713,14.76,20.2,422100.0
455 | 6.167,16.29,20.2,417900.0
456 | 6.229,12.87,20.2,411600.0
457 | 6.437,14.36,20.2,487200.0
458 | 6.98,11.66,20.2,625800.0
459 | 5.427,18.14,20.2,289800.0
460 | 6.162,24.1,20.2,279300.0
461 | 6.484,18.68,20.2,350700.0
462 | 5.304,24.91,20.2,252000.0
463 | 6.185,18.03,20.2,306600.0
464 | 6.229,13.11,20.2,449400.0
465 | 6.242,10.74,20.2,483000.0
466 | 6.75,7.74,20.2,497700.0
467 | 7.061,7.01,20.2,525000.0
468 | 5.762,10.42,20.2,457800.0
469 | 5.871,13.34,20.2,432600.0
470 | 6.312,10.58,20.2,445200.0
471 | 6.114,14.98,20.2,401100.0
472 | 5.905,11.45,20.2,432600.0
473 | 5.454,18.06,20.1,319200.0
474 | 5.414,23.97,20.1,147000.0
475 | 5.093,29.68,20.1,170100.0
476 | 5.983,18.07,20.1,285600.0
477 | 5.983,13.35,20.1,422100.0
478 | 5.707,12.01,19.2,457800.0
479 | 5.926,13.59,19.2,514500.0
480 | 5.67,17.6,19.2,485100.0
481 | 5.39,21.14,19.2,413700.0
482 | 5.794,14.1,19.2,384300.0
483 | 6.019,12.92,19.2,445200.0
484 | 5.569,15.1,19.2,367500.0
485 | 6.027,14.33,19.2,352800.0
486 | 6.593,9.67,21.0,470400.0
487 | 6.12,9.08,21.0,432600.0
488 | 6.976,5.64,21.0,501900.0
489 | 6.794,6.48,21.0,462000.0
490 | 6.03,7.88,21.0,249900.0
491 |
--------------------------------------------------------------------------------
/boston_housing/visuals.py:
--------------------------------------------------------------------------------
1 | ###########################################
2 | # Suppress matplotlib user warnings
3 | # Necessary for newer version of matplotlib
4 | import warnings
5 | warnings.filterwarnings("ignore", category = UserWarning, module = "matplotlib")
6 | ###########################################
7 |
8 | import matplotlib.pyplot as pl
9 | import numpy as np
10 | import sklearn.learning_curve as curves
11 | from sklearn.tree import DecisionTreeRegressor
12 | from sklearn.cross_validation import ShuffleSplit, train_test_split
13 |
14 | def ModelLearning(X, y):
15 | """ Calculates the performance of several models with varying sizes of training data.
16 | The learning and testing scores for each model are then plotted. """
17 |
18 | # Create 10 cross-validation sets for training and testing
19 | cv = ShuffleSplit(X.shape[0], n_iter = 10, test_size = 0.2, random_state = 0)
20 |
21 | # Generate the training set sizes increasing by 50
22 | train_sizes = np.rint(np.linspace(1, X.shape[0]*0.8 - 1, 9)).astype(int)
23 |
24 | # Create the figure window
25 | fig = pl.figure(figsize=(10,7))
26 |
27 | # Create three different models based on max_depth
28 | for k, depth in enumerate([1,3,6,10]):
29 |
30 | # Create a Decision tree regressor at max_depth = depth
31 | regressor = DecisionTreeRegressor(max_depth = depth)
32 |
33 | # Calculate the training and testing scores
34 | sizes, train_scores, test_scores = curves.learning_curve(regressor, X, y, \
35 | cv = cv, train_sizes = train_sizes, scoring = 'r2')
36 |
37 | # Find the mean and standard deviation for smoothing
38 | train_std = np.std(train_scores, axis = 1)
39 | train_mean = np.mean(train_scores, axis = 1)
40 | test_std = np.std(test_scores, axis = 1)
41 | test_mean = np.mean(test_scores, axis = 1)
42 |
43 | # Subplot the learning curve
44 | ax = fig.add_subplot(2, 2, k+1)
45 | ax.plot(sizes, train_mean, 'o-', color = 'r', label = 'Training Score')
46 | ax.plot(sizes, test_mean, 'o-', color = 'g', label = 'Testing Score')
47 | ax.fill_between(sizes, train_mean - train_std, \
48 | train_mean + train_std, alpha = 0.15, color = 'r')
49 | ax.fill_between(sizes, test_mean - test_std, \
50 | test_mean + test_std, alpha = 0.15, color = 'g')
51 |
52 | # Labels
53 | ax.set_title('max_depth = %s'%(depth))
54 | ax.set_xlabel('Number of Training Points')
55 | ax.set_ylabel('Score')
56 | ax.set_xlim([0, X.shape[0]*0.8])
57 | ax.set_ylim([-0.05, 1.05])
58 |
59 | # Visual aesthetics
60 | ax.legend(bbox_to_anchor=(1.05, 2.05), loc='lower left', borderaxespad = 0.)
61 | fig.suptitle('Decision Tree Regressor Learning Performances', fontsize = 16, y = 1.03)
62 | fig.tight_layout()
63 | fig.show()
64 |
65 |
66 | def ModelComplexity(X, y):
67 | """ Calculates the performance of the model as model complexity increases.
68 | The learning and testing errors rates are then plotted. """
69 |
70 | # Create 10 cross-validation sets for training and testing
71 | cv = ShuffleSplit(X.shape[0], n_iter = 10, test_size = 0.2, random_state = 0)
72 |
73 | # Vary the max_depth parameter from 1 to 10
74 | max_depth = np.arange(1,11)
75 |
76 | # Calculate the training and testing scores
77 | train_scores, test_scores = curves.validation_curve(DecisionTreeRegressor(), X, y, \
78 | param_name = "max_depth", param_range = max_depth, cv = cv, scoring = 'r2')
79 |
80 | # Find the mean and standard deviation for smoothing
81 | train_mean = np.mean(train_scores, axis=1)
82 | train_std = np.std(train_scores, axis=1)
83 | test_mean = np.mean(test_scores, axis=1)
84 | test_std = np.std(test_scores, axis=1)
85 |
86 | # Plot the validation curve
87 | pl.figure(figsize=(7, 5))
88 | pl.title('Decision Tree Regressor Complexity Performance')
89 | pl.plot(max_depth, train_mean, 'o-', color = 'r', label = 'Training Score')
90 | pl.plot(max_depth, test_mean, 'o-', color = 'g', label = 'Validation Score')
91 | pl.fill_between(max_depth, train_mean - train_std, \
92 | train_mean + train_std, alpha = 0.15, color = 'r')
93 | pl.fill_between(max_depth, test_mean - test_std, \
94 | test_mean + test_std, alpha = 0.15, color = 'g')
95 |
96 | # Visual aesthetics
97 | pl.legend(loc = 'lower right')
98 | pl.xlabel('Maximum Depth')
99 | pl.ylabel('Score')
100 | pl.ylim([-0.05,1.05])
101 | pl.show()
102 |
103 |
104 | def PredictTrials(X, y, fitter, data):
105 | """ Performs trials of fitting and predicting data. """
106 |
107 | # Store the predicted prices
108 | prices = []
109 |
110 | for k in range(10):
111 | # Split the data
112 | X_train, X_test, y_train, y_test = train_test_split(X, y, \
113 | test_size = 0.2, random_state = k)
114 |
115 | # Fit the data
116 | reg = fitter(X_train, y_train)
117 |
118 | # Make a prediction
119 | pred = reg.predict([data[0]])[0]
120 | prices.append(pred)
121 |
122 | # Result
123 | print "Trial {}: ${:,.2f}".format(k+1, pred)
124 |
125 | # Display price range
126 | print "\nRange in prices: ${:,.2f}".format(max(prices) - min(prices))
--------------------------------------------------------------------------------
/creating_customer_segments/README.md:
--------------------------------------------------------------------------------
1 | # Project 3: Unsupervised Learning
2 | ## Creating Customer Segments
3 |
4 | ### Install
5 |
6 | This project requires **Python 2.7** and the following Python packages:
7 |
8 | - [NumPy](http://www.numpy.org/)
9 | - [Pandas](http://pandas.pydata.org)
10 | - [scikit-learn](http://scikit-learn.org/stable/)
11 |
12 | You will also need software installed to run [Jupyter Notebook](http://jupyter.org/).
13 |
14 | Udacity recommends that students install [Anaconda](https://www.continuum.io/downloads), a pre-packaged Python distribution that contains all of the libraries and software needed for this project.
15 |
16 | ### Code
17 |
18 | Template code is provided in the `customer_segments.ipynb` notebook file. Some of the code is already implemented to help you get started, but you will need to implement additional functionality to complete the project.
19 |
20 | ### Run
21 |
22 | In a terminal or command window, navigate to the top-level directory containing `customer_segments.ipynb` (the one that also contains this README) and run the following command:
23 |
24 | ```jupyter notebook customer_segments.ipynb```
25 |
26 | This will launch Jupyter Notebook and open the project file in your browser.
27 |
28 | ## Data
29 |
30 | The dataset for this project is included in `customers.csv`. You can find more information about it on the [UCI Machine Learning Repository](https://archive.ics.uci.edu/ml/datasets/Wholesale+customers) page.
31 |
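To show the shape of this data, here is a minimal sketch that loads two sample rows (copied verbatim from `customers.csv`) and sets aside the `Channel` and `Region` identifier columns so that only the six annual-spending categories remain. In the project itself you would read the full file with `pd.read_csv('customers.csv')`; the inline rows are just for self-containment.

```python
import pandas as pd
from io import StringIO

# First two rows of customers.csv: annual spending per product category,
# plus the Channel and Region identifier columns
csv_text = """Channel,Region,Fresh,Milk,Grocery,Frozen,Detergents_Paper,Delicatessen
2,3,12669,9656,7561,214,2674,1338
2,3,7057,9810,9568,1762,3293,1776
"""
data = pd.read_csv(StringIO(csv_text))

# Drop the identifier columns, keeping only the six spending categories
spending = data.drop(['Channel', 'Region'], axis=1)
print(spending.columns.tolist())
```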
--------------------------------------------------------------------------------
/creating_customer_segments/customer_segments.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/creating_customer_segments/customer_segments.zip
--------------------------------------------------------------------------------
/creating_customer_segments/customers.csv:
--------------------------------------------------------------------------------
1 | Channel,Region,Fresh,Milk,Grocery,Frozen,Detergents_Paper,Delicatessen
2 | 2,3,12669,9656,7561,214,2674,1338
3 | 2,3,7057,9810,9568,1762,3293,1776
4 | 2,3,6353,8808,7684,2405,3516,7844
5 | 1,3,13265,1196,4221,6404,507,1788
6 | 2,3,22615,5410,7198,3915,1777,5185
7 | 2,3,9413,8259,5126,666,1795,1451
8 | 2,3,12126,3199,6975,480,3140,545
9 | 2,3,7579,4956,9426,1669,3321,2566
10 | 1,3,5963,3648,6192,425,1716,750
11 | 2,3,6006,11093,18881,1159,7425,2098
12 | 2,3,3366,5403,12974,4400,5977,1744
13 | 2,3,13146,1124,4523,1420,549,497
14 | 2,3,31714,12319,11757,287,3881,2931
15 | 2,3,21217,6208,14982,3095,6707,602
16 | 2,3,24653,9465,12091,294,5058,2168
17 | 1,3,10253,1114,3821,397,964,412
18 | 2,3,1020,8816,12121,134,4508,1080
19 | 1,3,5876,6157,2933,839,370,4478
20 | 2,3,18601,6327,10099,2205,2767,3181
21 | 1,3,7780,2495,9464,669,2518,501
22 | 2,3,17546,4519,4602,1066,2259,2124
23 | 1,3,5567,871,2010,3383,375,569
24 | 1,3,31276,1917,4469,9408,2381,4334
25 | 2,3,26373,36423,22019,5154,4337,16523
26 | 2,3,22647,9776,13792,2915,4482,5778
27 | 2,3,16165,4230,7595,201,4003,57
28 | 1,3,9898,961,2861,3151,242,833
29 | 1,3,14276,803,3045,485,100,518
30 | 2,3,4113,20484,25957,1158,8604,5206
31 | 1,3,43088,2100,2609,1200,1107,823
32 | 1,3,18815,3610,11107,1148,2134,2963
33 | 1,3,2612,4339,3133,2088,820,985
34 | 1,3,21632,1318,2886,266,918,405
35 | 1,3,29729,4786,7326,6130,361,1083
36 | 1,3,1502,1979,2262,425,483,395
37 | 2,3,688,5491,11091,833,4239,436
38 | 1,3,29955,4362,5428,1729,862,4626
39 | 2,3,15168,10556,12477,1920,6506,714
40 | 2,3,4591,15729,16709,33,6956,433
41 | 1,3,56159,555,902,10002,212,2916
42 | 1,3,24025,4332,4757,9510,1145,5864
43 | 1,3,19176,3065,5956,2033,2575,2802
44 | 2,3,10850,7555,14961,188,6899,46
45 | 2,3,630,11095,23998,787,9529,72
46 | 2,3,9670,7027,10471,541,4618,65
47 | 2,3,5181,22044,21531,1740,7353,4985
48 | 2,3,3103,14069,21955,1668,6792,1452
49 | 2,3,44466,54259,55571,7782,24171,6465
50 | 2,3,11519,6152,10868,584,5121,1476
51 | 2,3,4967,21412,28921,1798,13583,1163
52 | 1,3,6269,1095,1980,3860,609,2162
53 | 1,3,3347,4051,6996,239,1538,301
54 | 2,3,40721,3916,5876,532,2587,1278
55 | 2,3,491,10473,11532,744,5611,224
56 | 1,3,27329,1449,1947,2436,204,1333
57 | 1,3,5264,3683,5005,1057,2024,1130
58 | 2,3,4098,29892,26866,2616,17740,1340
59 | 2,3,5417,9933,10487,38,7572,1282
60 | 1,3,13779,1970,1648,596,227,436
61 | 1,3,6137,5360,8040,129,3084,1603
62 | 2,3,8590,3045,7854,96,4095,225
63 | 2,3,35942,38369,59598,3254,26701,2017
64 | 2,3,7823,6245,6544,4154,4074,964
65 | 2,3,9396,11601,15775,2896,7677,1295
66 | 1,3,4760,1227,3250,3724,1247,1145
67 | 2,3,85,20959,45828,36,24231,1423
68 | 1,3,9,1534,7417,175,3468,27
69 | 2,3,19913,6759,13462,1256,5141,834
70 | 1,3,2446,7260,3993,5870,788,3095
71 | 1,3,8352,2820,1293,779,656,144
72 | 1,3,16705,2037,3202,10643,116,1365
73 | 1,3,18291,1266,21042,5373,4173,14472
74 | 1,3,4420,5139,2661,8872,1321,181
75 | 2,3,19899,5332,8713,8132,764,648
76 | 2,3,8190,6343,9794,1285,1901,1780
77 | 1,3,20398,1137,3,4407,3,975
78 | 1,3,717,3587,6532,7530,529,894
79 | 2,3,12205,12697,28540,869,12034,1009
80 | 1,3,10766,1175,2067,2096,301,167
81 | 1,3,1640,3259,3655,868,1202,1653
82 | 1,3,7005,829,3009,430,610,529
83 | 2,3,219,9540,14403,283,7818,156
84 | 2,3,10362,9232,11009,737,3537,2342
85 | 1,3,20874,1563,1783,2320,550,772
86 | 2,3,11867,3327,4814,1178,3837,120
87 | 2,3,16117,46197,92780,1026,40827,2944
88 | 2,3,22925,73498,32114,987,20070,903
89 | 1,3,43265,5025,8117,6312,1579,14351
90 | 1,3,7864,542,4042,9735,165,46
91 | 1,3,24904,3836,5330,3443,454,3178
92 | 1,3,11405,596,1638,3347,69,360
93 | 1,3,12754,2762,2530,8693,627,1117
94 | 2,3,9198,27472,32034,3232,18906,5130
95 | 1,3,11314,3090,2062,35009,71,2698
96 | 2,3,5626,12220,11323,206,5038,244
97 | 1,3,3,2920,6252,440,223,709
98 | 2,3,23,2616,8118,145,3874,217
99 | 1,3,403,254,610,774,54,63
100 | 1,3,503,112,778,895,56,132
101 | 1,3,9658,2182,1909,5639,215,323
102 | 2,3,11594,7779,12144,3252,8035,3029
103 | 2,3,1420,10810,16267,1593,6766,1838
104 | 2,3,2932,6459,7677,2561,4573,1386
105 | 1,3,56082,3504,8906,18028,1480,2498
106 | 1,3,14100,2132,3445,1336,1491,548
107 | 1,3,15587,1014,3970,910,139,1378
108 | 2,3,1454,6337,10704,133,6830,1831
109 | 2,3,8797,10646,14886,2471,8969,1438
110 | 2,3,1531,8397,6981,247,2505,1236
111 | 2,3,1406,16729,28986,673,836,3
112 | 1,3,11818,1648,1694,2276,169,1647
113 | 2,3,12579,11114,17569,805,6457,1519
114 | 1,3,19046,2770,2469,8853,483,2708
115 | 1,3,14438,2295,1733,3220,585,1561
116 | 1,3,18044,1080,2000,2555,118,1266
117 | 1,3,11134,793,2988,2715,276,610
118 | 1,3,11173,2521,3355,1517,310,222
119 | 1,3,6990,3880,5380,1647,319,1160
120 | 1,3,20049,1891,2362,5343,411,933
121 | 1,3,8258,2344,2147,3896,266,635
122 | 1,3,17160,1200,3412,2417,174,1136
123 | 1,3,4020,3234,1498,2395,264,255
124 | 1,3,12212,201,245,1991,25,860
125 | 2,3,11170,10769,8814,2194,1976,143
126 | 1,3,36050,1642,2961,4787,500,1621
127 | 1,3,76237,3473,7102,16538,778,918
128 | 1,3,19219,1840,1658,8195,349,483
129 | 2,3,21465,7243,10685,880,2386,2749
130 | 1,3,140,8847,3823,142,1062,3
131 | 1,3,42312,926,1510,1718,410,1819
132 | 1,3,7149,2428,699,6316,395,911
133 | 1,3,2101,589,314,346,70,310
134 | 1,3,14903,2032,2479,576,955,328
135 | 1,3,9434,1042,1235,436,256,396
136 | 1,3,7388,1882,2174,720,47,537
137 | 1,3,6300,1289,2591,1170,199,326
138 | 1,3,4625,8579,7030,4575,2447,1542
139 | 1,3,3087,8080,8282,661,721,36
140 | 1,3,13537,4257,5034,155,249,3271
141 | 1,3,5387,4979,3343,825,637,929
142 | 1,3,17623,4280,7305,2279,960,2616
143 | 1,3,30379,13252,5189,321,51,1450
144 | 1,3,37036,7152,8253,2995,20,3
145 | 1,3,10405,1596,1096,8425,399,318
146 | 1,3,18827,3677,1988,118,516,201
147 | 2,3,22039,8384,34792,42,12591,4430
148 | 1,3,7769,1936,2177,926,73,520
149 | 1,3,9203,3373,2707,1286,1082,526
150 | 1,3,5924,584,542,4052,283,434
151 | 1,3,31812,1433,1651,800,113,1440
152 | 1,3,16225,1825,1765,853,170,1067
153 | 1,3,1289,3328,2022,531,255,1774
154 | 1,3,18840,1371,3135,3001,352,184
155 | 1,3,3463,9250,2368,779,302,1627
156 | 1,3,622,55,137,75,7,8
157 | 2,3,1989,10690,19460,233,11577,2153
158 | 2,3,3830,5291,14855,317,6694,3182
159 | 1,3,17773,1366,2474,3378,811,418
160 | 2,3,2861,6570,9618,930,4004,1682
161 | 2,3,355,7704,14682,398,8077,303
162 | 2,3,1725,3651,12822,824,4424,2157
163 | 1,3,12434,540,283,1092,3,2233
164 | 1,3,15177,2024,3810,2665,232,610
165 | 2,3,5531,15726,26870,2367,13726,446
166 | 2,3,5224,7603,8584,2540,3674,238
167 | 2,3,15615,12653,19858,4425,7108,2379
168 | 2,3,4822,6721,9170,993,4973,3637
169 | 1,3,2926,3195,3268,405,1680,693
170 | 1,3,5809,735,803,1393,79,429
171 | 1,3,5414,717,2155,2399,69,750
172 | 2,3,260,8675,13430,1116,7015,323
173 | 2,3,200,25862,19816,651,8773,6250
174 | 1,3,955,5479,6536,333,2840,707
175 | 2,3,514,7677,19805,937,9836,716
176 | 1,3,286,1208,5241,2515,153,1442
177 | 2,3,2343,7845,11874,52,4196,1697
178 | 1,3,45640,6958,6536,7368,1532,230
179 | 1,3,12759,7330,4533,1752,20,2631
180 | 1,3,11002,7075,4945,1152,120,395
181 | 1,3,3157,4888,2500,4477,273,2165
182 | 1,3,12356,6036,8887,402,1382,2794
183 | 1,3,112151,29627,18148,16745,4948,8550
184 | 1,3,694,8533,10518,443,6907,156
185 | 1,3,36847,43950,20170,36534,239,47943
186 | 1,3,327,918,4710,74,334,11
187 | 1,3,8170,6448,1139,2181,58,247
188 | 1,3,3009,521,854,3470,949,727
189 | 1,3,2438,8002,9819,6269,3459,3
190 | 2,3,8040,7639,11687,2758,6839,404
191 | 2,3,834,11577,11522,275,4027,1856
192 | 1,3,16936,6250,1981,7332,118,64
193 | 1,3,13624,295,1381,890,43,84
194 | 1,3,5509,1461,2251,547,187,409
195 | 2,3,180,3485,20292,959,5618,666
196 | 1,3,7107,1012,2974,806,355,1142
197 | 1,3,17023,5139,5230,7888,330,1755
198 | 1,1,30624,7209,4897,18711,763,2876
199 | 2,1,2427,7097,10391,1127,4314,1468
200 | 1,1,11686,2154,6824,3527,592,697
201 | 1,1,9670,2280,2112,520,402,347
202 | 2,1,3067,13240,23127,3941,9959,731
203 | 2,1,4484,14399,24708,3549,14235,1681
204 | 1,1,25203,11487,9490,5065,284,6854
205 | 1,1,583,685,2216,469,954,18
206 | 1,1,1956,891,5226,1383,5,1328
207 | 2,1,1107,11711,23596,955,9265,710
208 | 1,1,6373,780,950,878,288,285
209 | 2,1,2541,4737,6089,2946,5316,120
210 | 1,1,1537,3748,5838,1859,3381,806
211 | 2,1,5550,12729,16767,864,12420,797
212 | 1,1,18567,1895,1393,1801,244,2100
213 | 2,1,12119,28326,39694,4736,19410,2870
214 | 1,1,7291,1012,2062,1291,240,1775
215 | 1,1,3317,6602,6861,1329,3961,1215
216 | 2,1,2362,6551,11364,913,5957,791
217 | 1,1,2806,10765,15538,1374,5828,2388
218 | 2,1,2532,16599,36486,179,13308,674
219 | 1,1,18044,1475,2046,2532,130,1158
220 | 2,1,18,7504,15205,1285,4797,6372
221 | 1,1,4155,367,1390,2306,86,130
222 | 1,1,14755,899,1382,1765,56,749
223 | 1,1,5396,7503,10646,91,4167,239
224 | 1,1,5041,1115,2856,7496,256,375
225 | 2,1,2790,2527,5265,5612,788,1360
226 | 1,1,7274,659,1499,784,70,659
227 | 1,1,12680,3243,4157,660,761,786
228 | 2,1,20782,5921,9212,1759,2568,1553
229 | 1,1,4042,2204,1563,2286,263,689
230 | 1,1,1869,577,572,950,4762,203
231 | 1,1,8656,2746,2501,6845,694,980
232 | 2,1,11072,5989,5615,8321,955,2137
233 | 1,1,2344,10678,3828,1439,1566,490
234 | 1,1,25962,1780,3838,638,284,834
235 | 1,1,964,4984,3316,937,409,7
236 | 1,1,15603,2703,3833,4260,325,2563
237 | 1,1,1838,6380,2824,1218,1216,295
238 | 1,1,8635,820,3047,2312,415,225
239 | 1,1,18692,3838,593,4634,28,1215
240 | 1,1,7363,475,585,1112,72,216
241 | 1,1,47493,2567,3779,5243,828,2253
242 | 1,1,22096,3575,7041,11422,343,2564
243 | 1,1,24929,1801,2475,2216,412,1047
244 | 1,1,18226,659,2914,3752,586,578
245 | 1,1,11210,3576,5119,561,1682,2398
246 | 1,1,6202,7775,10817,1183,3143,1970
247 | 2,1,3062,6154,13916,230,8933,2784
248 | 1,1,8885,2428,1777,1777,430,610
249 | 1,1,13569,346,489,2077,44,659
250 | 1,1,15671,5279,2406,559,562,572
251 | 1,1,8040,3795,2070,6340,918,291
252 | 1,1,3191,1993,1799,1730,234,710
253 | 2,1,6134,23133,33586,6746,18594,5121
254 | 1,1,6623,1860,4740,7683,205,1693
255 | 1,1,29526,7961,16966,432,363,1391
256 | 1,1,10379,17972,4748,4686,1547,3265
257 | 1,1,31614,489,1495,3242,111,615
258 | 1,1,11092,5008,5249,453,392,373
259 | 1,1,8475,1931,1883,5004,3593,987
260 | 1,1,56083,4563,2124,6422,730,3321
261 | 1,1,53205,4959,7336,3012,967,818
262 | 1,1,9193,4885,2157,327,780,548
263 | 1,1,7858,1110,1094,6818,49,287
264 | 1,1,23257,1372,1677,982,429,655
265 | 1,1,2153,1115,6684,4324,2894,411
266 | 2,1,1073,9679,15445,61,5980,1265
267 | 1,1,5909,23527,13699,10155,830,3636
268 | 2,1,572,9763,22182,2221,4882,2563
269 | 1,1,20893,1222,2576,3975,737,3628
270 | 2,1,11908,8053,19847,1069,6374,698
271 | 1,1,15218,258,1138,2516,333,204
272 | 1,1,4720,1032,975,5500,197,56
273 | 1,1,2083,5007,1563,1120,147,1550
274 | 1,1,514,8323,6869,529,93,1040
275 | 1,3,36817,3045,1493,4802,210,1824
276 | 1,3,894,1703,1841,744,759,1153
277 | 1,3,680,1610,223,862,96,379
278 | 1,3,27901,3749,6964,4479,603,2503
279 | 1,3,9061,829,683,16919,621,139
280 | 1,3,11693,2317,2543,5845,274,1409
281 | 2,3,17360,6200,9694,1293,3620,1721
282 | 1,3,3366,2884,2431,977,167,1104
283 | 2,3,12238,7108,6235,1093,2328,2079
284 | 1,3,49063,3965,4252,5970,1041,1404
285 | 1,3,25767,3613,2013,10303,314,1384
286 | 1,3,68951,4411,12609,8692,751,2406
287 | 1,3,40254,640,3600,1042,436,18
288 | 1,3,7149,2247,1242,1619,1226,128
289 | 1,3,15354,2102,2828,8366,386,1027
290 | 1,3,16260,594,1296,848,445,258
291 | 1,3,42786,286,471,1388,32,22
292 | 1,3,2708,2160,2642,502,965,1522
293 | 1,3,6022,3354,3261,2507,212,686
294 | 1,3,2838,3086,4329,3838,825,1060
295 | 2,2,3996,11103,12469,902,5952,741
296 | 1,2,21273,2013,6550,909,811,1854
297 | 2,2,7588,1897,5234,417,2208,254
298 | 1,2,19087,1304,3643,3045,710,898
299 | 2,2,8090,3199,6986,1455,3712,531
300 | 2,2,6758,4560,9965,934,4538,1037
301 | 1,2,444,879,2060,264,290,259
302 | 2,2,16448,6243,6360,824,2662,2005
303 | 2,2,5283,13316,20399,1809,8752,172
304 | 2,2,2886,5302,9785,364,6236,555
305 | 2,2,2599,3688,13829,492,10069,59
306 | 2,2,161,7460,24773,617,11783,2410
307 | 2,2,243,12939,8852,799,3909,211
308 | 2,2,6468,12867,21570,1840,7558,1543
309 | 1,2,17327,2374,2842,1149,351,925
310 | 1,2,6987,1020,3007,416,257,656
311 | 2,2,918,20655,13567,1465,6846,806
312 | 1,2,7034,1492,2405,12569,299,1117
313 | 1,2,29635,2335,8280,3046,371,117
314 | 2,2,2137,3737,19172,1274,17120,142
315 | 1,2,9784,925,2405,4447,183,297
316 | 1,2,10617,1795,7647,1483,857,1233
317 | 2,2,1479,14982,11924,662,3891,3508
318 | 1,2,7127,1375,2201,2679,83,1059
319 | 1,2,1182,3088,6114,978,821,1637
320 | 1,2,11800,2713,3558,2121,706,51
321 | 2,2,9759,25071,17645,1128,12408,1625
322 | 1,2,1774,3696,2280,514,275,834
323 | 1,2,9155,1897,5167,2714,228,1113
324 | 1,2,15881,713,3315,3703,1470,229
325 | 1,2,13360,944,11593,915,1679,573
326 | 1,2,25977,3587,2464,2369,140,1092
327 | 1,2,32717,16784,13626,60869,1272,5609
328 | 1,2,4414,1610,1431,3498,387,834
329 | 1,2,542,899,1664,414,88,522
330 | 1,2,16933,2209,3389,7849,210,1534
331 | 1,2,5113,1486,4583,5127,492,739
332 | 1,2,9790,1786,5109,3570,182,1043
333 | 2,2,11223,14881,26839,1234,9606,1102
334 | 1,2,22321,3216,1447,2208,178,2602
335 | 2,2,8565,4980,67298,131,38102,1215
336 | 2,2,16823,928,2743,11559,332,3486
337 | 2,2,27082,6817,10790,1365,4111,2139
338 | 1,2,13970,1511,1330,650,146,778
339 | 1,2,9351,1347,2611,8170,442,868
340 | 1,2,3,333,7021,15601,15,550
341 | 1,2,2617,1188,5332,9584,573,1942
342 | 2,3,381,4025,9670,388,7271,1371
343 | 2,3,2320,5763,11238,767,5162,2158
344 | 1,3,255,5758,5923,349,4595,1328
345 | 2,3,1689,6964,26316,1456,15469,37
346 | 1,3,3043,1172,1763,2234,217,379
347 | 1,3,1198,2602,8335,402,3843,303
348 | 2,3,2771,6939,15541,2693,6600,1115
349 | 2,3,27380,7184,12311,2809,4621,1022
350 | 1,3,3428,2380,2028,1341,1184,665
351 | 2,3,5981,14641,20521,2005,12218,445
352 | 1,3,3521,1099,1997,1796,173,995
353 | 2,3,1210,10044,22294,1741,12638,3137
354 | 1,3,608,1106,1533,830,90,195
355 | 2,3,117,6264,21203,228,8682,1111
356 | 1,3,14039,7393,2548,6386,1333,2341
357 | 1,3,190,727,2012,245,184,127
358 | 1,3,22686,134,218,3157,9,548
359 | 2,3,37,1275,22272,137,6747,110
360 | 1,3,759,18664,1660,6114,536,4100
361 | 1,3,796,5878,2109,340,232,776
362 | 1,3,19746,2872,2006,2601,468,503
363 | 1,3,4734,607,864,1206,159,405
364 | 1,3,2121,1601,2453,560,179,712
365 | 1,3,4627,997,4438,191,1335,314
366 | 1,3,2615,873,1524,1103,514,468
367 | 2,3,4692,6128,8025,1619,4515,3105
368 | 1,3,9561,2217,1664,1173,222,447
369 | 1,3,3477,894,534,1457,252,342
370 | 1,3,22335,1196,2406,2046,101,558
371 | 1,3,6211,337,683,1089,41,296
372 | 2,3,39679,3944,4955,1364,523,2235
373 | 1,3,20105,1887,1939,8164,716,790
374 | 1,3,3884,3801,1641,876,397,4829
375 | 2,3,15076,6257,7398,1504,1916,3113
376 | 1,3,6338,2256,1668,1492,311,686
377 | 1,3,5841,1450,1162,597,476,70
378 | 2,3,3136,8630,13586,5641,4666,1426
379 | 1,3,38793,3154,2648,1034,96,1242
380 | 1,3,3225,3294,1902,282,68,1114
381 | 2,3,4048,5164,10391,130,813,179
382 | 1,3,28257,944,2146,3881,600,270
383 | 1,3,17770,4591,1617,9927,246,532
384 | 1,3,34454,7435,8469,2540,1711,2893
385 | 1,3,1821,1364,3450,4006,397,361
386 | 1,3,10683,21858,15400,3635,282,5120
387 | 1,3,11635,922,1614,2583,192,1068
388 | 1,3,1206,3620,2857,1945,353,967
389 | 1,3,20918,1916,1573,1960,231,961
390 | 1,3,9785,848,1172,1677,200,406
391 | 1,3,9385,1530,1422,3019,227,684
392 | 1,3,3352,1181,1328,5502,311,1000
393 | 1,3,2647,2761,2313,907,95,1827
394 | 1,3,518,4180,3600,659,122,654
395 | 1,3,23632,6730,3842,8620,385,819
396 | 1,3,12377,865,3204,1398,149,452
397 | 1,3,9602,1316,1263,2921,841,290
398 | 2,3,4515,11991,9345,2644,3378,2213
399 | 1,3,11535,1666,1428,6838,64,743
400 | 1,3,11442,1032,582,5390,74,247
401 | 1,3,9612,577,935,1601,469,375
402 | 1,3,4446,906,1238,3576,153,1014
403 | 1,3,27167,2801,2128,13223,92,1902
404 | 1,3,26539,4753,5091,220,10,340
405 | 1,3,25606,11006,4604,127,632,288
406 | 1,3,18073,4613,3444,4324,914,715
407 | 1,3,6884,1046,1167,2069,593,378
408 | 1,3,25066,5010,5026,9806,1092,960
409 | 2,3,7362,12844,18683,2854,7883,553
410 | 2,3,8257,3880,6407,1646,2730,344
411 | 1,3,8708,3634,6100,2349,2123,5137
412 | 1,3,6633,2096,4563,1389,1860,1892
413 | 1,3,2126,3289,3281,1535,235,4365
414 | 1,3,97,3605,12400,98,2970,62
415 | 1,3,4983,4859,6633,17866,912,2435
416 | 1,3,5969,1990,3417,5679,1135,290
417 | 2,3,7842,6046,8552,1691,3540,1874
418 | 2,3,4389,10940,10908,848,6728,993
419 | 1,3,5065,5499,11055,364,3485,1063
420 | 2,3,660,8494,18622,133,6740,776
421 | 1,3,8861,3783,2223,633,1580,1521
422 | 1,3,4456,5266,13227,25,6818,1393
423 | 2,3,17063,4847,9053,1031,3415,1784
424 | 1,3,26400,1377,4172,830,948,1218
425 | 2,3,17565,3686,4657,1059,1803,668
426 | 2,3,16980,2884,12232,874,3213,249
427 | 1,3,11243,2408,2593,15348,108,1886
428 | 1,3,13134,9347,14316,3141,5079,1894
429 | 1,3,31012,16687,5429,15082,439,1163
430 | 1,3,3047,5970,4910,2198,850,317
431 | 1,3,8607,1750,3580,47,84,2501
432 | 1,3,3097,4230,16483,575,241,2080
433 | 1,3,8533,5506,5160,13486,1377,1498
434 | 1,3,21117,1162,4754,269,1328,395
435 | 1,3,1982,3218,1493,1541,356,1449
436 | 1,3,16731,3922,7994,688,2371,838
437 | 1,3,29703,12051,16027,13135,182,2204
438 | 1,3,39228,1431,764,4510,93,2346
439 | 2,3,14531,15488,30243,437,14841,1867
440 | 1,3,10290,1981,2232,1038,168,2125
441 | 1,3,2787,1698,2510,65,477,52
442 |
--------------------------------------------------------------------------------
/creating_customer_segments/visuals.py:
--------------------------------------------------------------------------------
1 | ###########################################
2 | # Suppress matplotlib user warnings
3 | # Necessary for newer version of matplotlib
4 | import warnings
5 | warnings.filterwarnings("ignore", category = UserWarning, module = "matplotlib")
6 | #
7 | # Display inline matplotlib plots with IPython
8 | from IPython import get_ipython
9 | get_ipython().run_line_magic('matplotlib', 'inline')
10 | ###########################################
11 |
12 | import matplotlib.pyplot as plt
13 | import matplotlib.cm as cm
14 | import pandas as pd
15 | import numpy as np
16 |
17 | def pca_results(good_data, pca):
18 | '''
19 | Create a DataFrame of the PCA results
20 | Includes dimension feature weights and explained variance
21 | Visualizes the PCA results
22 | '''
23 |
24 | # Dimension indexing
25 | 	dimensions = ['Dimension {}'.format(i) for i in range(1,len(pca.components_)+1)]
26 |
27 | # PCA components
28 | components = pd.DataFrame(np.round(pca.components_, 4), columns = good_data.keys())
29 | components.index = dimensions
30 |
31 | # PCA explained variance
32 | ratios = pca.explained_variance_ratio_.reshape(len(pca.components_), 1)
33 | variance_ratios = pd.DataFrame(np.round(ratios, 4), columns = ['Explained Variance'])
34 | variance_ratios.index = dimensions
35 |
36 | # Create a bar plot visualization
37 | fig, ax = plt.subplots(figsize = (14,8))
38 |
39 | # Plot the feature weights as a function of the components
40 | components.plot(ax = ax, kind = 'bar');
41 | ax.set_ylabel("Feature Weights")
42 | ax.set_xticklabels(dimensions, rotation=0)
43 |
44 |
45 | # Display the explained variance ratios
46 | for i, ev in enumerate(pca.explained_variance_ratio_):
47 | ax.text(i-0.40, ax.get_ylim()[1] + 0.05, "Explained Variance\n %.4f"%(ev))
48 |
49 | # Return a concatenated DataFrame
50 | return pd.concat([variance_ratios, components], axis = 1)
51 |
52 | def cluster_results(reduced_data, preds, centers, pca_samples):
53 | '''
54 | Visualizes the PCA-reduced cluster data in two dimensions
55 | Adds cues for cluster centers and student-selected sample data
56 | '''
57 |
58 | predictions = pd.DataFrame(preds, columns = ['Cluster'])
59 | plot_data = pd.concat([predictions, reduced_data], axis = 1)
60 |
61 | # Generate the cluster plot
62 | fig, ax = plt.subplots(figsize = (14,8))
63 |
64 | # Color map
65 | cmap = cm.get_cmap('gist_rainbow')
66 |
67 | # Color the points based on assigned cluster
68 | for i, cluster in plot_data.groupby('Cluster'):
69 | cluster.plot(ax = ax, kind = 'scatter', x = 'Dimension 1', y = 'Dimension 2', \
70 | color = cmap((i)*1.0/(len(centers)-1)), label = 'Cluster %i'%(i), s=30);
71 |
72 | # Plot centers with indicators
73 | for i, c in enumerate(centers):
74 | ax.scatter(x = c[0], y = c[1], color = 'white', edgecolors = 'black', \
75 | alpha = 1, linewidth = 2, marker = 'o', s=200);
76 | ax.scatter(x = c[0], y = c[1], marker='$%d$'%(i), alpha = 1, s=100);
77 |
78 | # Plot transformed sample points
79 | ax.scatter(x = pca_samples[:,0], y = pca_samples[:,1], \
80 | s = 150, linewidth = 4, color = 'black', marker = 'x');
81 |
82 | # Set plot title
83 | ax.set_title("Cluster Learning on PCA-Reduced Data - Centroids Marked by Number\nTransformed Sample Data Marked by Black Cross");
84 |
85 |
86 | def biplot(good_data, reduced_data, pca):
87 | '''
88 | Produce a biplot that shows a scatterplot of the reduced
89 | data and the projections of the original features.
90 |
91 | good_data: original data, before transformation.
92 | Needs to be a pandas dataframe with valid column names
93 | reduced_data: the reduced data (the first two dimensions are plotted)
94 | pca: pca object that contains the components_ attribute
95 |
96 | return: a matplotlib AxesSubplot object (for any additional customization)
97 |
98 | This procedure is inspired by the script:
99 | https://github.com/teddyroland/python-biplot
100 | '''
101 |
102 | fig, ax = plt.subplots(figsize = (14,8))
103 | # scatterplot of the reduced data
104 | ax.scatter(x=reduced_data.loc[:, 'Dimension 1'], y=reduced_data.loc[:, 'Dimension 2'],
105 | facecolors='b', edgecolors='b', s=70, alpha=0.5)
106 |
107 | feature_vectors = pca.components_.T
108 |
109 | # we use scaling factors to make the arrows easier to see
110 | 	arrow_size, text_pos = 7.0, 8.0
111 |
112 | # projections of the original features
113 | for i, v in enumerate(feature_vectors):
114 | ax.arrow(0, 0, arrow_size*v[0], arrow_size*v[1],
115 | head_width=0.2, head_length=0.2, linewidth=2, color='red')
116 | ax.text(v[0]*text_pos, v[1]*text_pos, good_data.columns[i], color='black',
117 | ha='center', va='center', fontsize=18)
118 |
119 | ax.set_xlabel("Dimension 1", fontsize=14)
120 | ax.set_ylabel("Dimension 2", fontsize=14)
121 | ax.set_title("PC plane with original feature projections.", fontsize=16);
122 | return ax
123 |
124 |
125 | def channel_results(reduced_data, outliers, pca_samples):
126 | '''
127 | Visualizes the PCA-reduced cluster data in two dimensions using the full dataset
128 | Data is labeled by "Channel" and cues added for student-selected sample data
129 | '''
130 |
131 | # Check that the dataset is loadable
132 | 	try:
133 | 		full_data = pd.read_csv("customers.csv")
134 | 	except Exception:
135 | 		print("Dataset could not be loaded. Is the file missing?")
136 | 		return False
137 |
138 | # Create the Channel DataFrame
139 | channel = pd.DataFrame(full_data['Channel'], columns = ['Channel'])
140 | channel = channel.drop(channel.index[outliers]).reset_index(drop = True)
141 | labeled = pd.concat([reduced_data, channel], axis = 1)
142 |
143 | # Generate the cluster plot
144 | fig, ax = plt.subplots(figsize = (14,8))
145 |
146 | # Color map
147 | cmap = cm.get_cmap('gist_rainbow')
148 |
149 | # Color the points based on assigned Channel
150 | labels = ['Hotel/Restaurant/Cafe', 'Retailer']
151 | grouped = labeled.groupby('Channel')
152 | for i, channel in grouped:
153 | channel.plot(ax = ax, kind = 'scatter', x = 'Dimension 1', y = 'Dimension 2', \
154 | color = cmap((i-1)*1.0/2), label = labels[i-1], s=30);
155 |
156 | # Plot transformed sample points
157 | for i, sample in enumerate(pca_samples):
158 | ax.scatter(x = sample[0], y = sample[1], \
159 | s = 200, linewidth = 3, color = 'black', marker = 'o', facecolors = 'none');
160 | ax.scatter(x = sample[0]+0.25, y = sample[1]+0.3, marker='$%d$'%(i), alpha = 1, s=125);
161 |
162 | # Set plot title
163 | ax.set_title("PCA-Reduced Data Labeled by 'Channel'\nTransformed Sample Data Circled");
--------------------------------------------------------------------------------
/digit_recognition/README.md:
--------------------------------------------------------------------------------
1 | # Machine Learning Engineer Nanodegree
2 | # Deep Learning
3 | ## Project: Build a Digit Recognition Program
4 |
5 | ### Install
6 |
7 | This project requires **Python 2.7** and the following Python packages installed:
8 |
9 | - [NumPy](http://www.numpy.org/)
10 | - [SciPy](https://www.scipy.org/)
11 | - [scikit-learn](http://scikit-learn.org/0.17/install.html) (v0.17)
12 | - [TensorFlow](http://tensorflow.org)
13 |
14 | You will also need the software installed to run a [Jupyter Notebook](http://ipython.org/notebook.html).
15 |
16 | In addition to the above, for those optionally seeking to use image processing software, you may need to install one of the following:
17 | - [PyGame](http://pygame.org/)
18 | - Helpful links for installing PyGame:
19 | - [Getting Started](https://www.pygame.org/wiki/GettingStarted)
20 | - [PyGame Information](http://www.pygame.org/wiki/info)
21 | - [Google Group](https://groups.google.com/forum/#!forum/pygame-mirror-on-google-groups)
22 | - [PyGame subreddit](https://www.reddit.com/r/pygame/)
23 | - [OpenCV](http://opencv.org/)
24 |
25 | For those optionally seeking to deploy the application as an Android app:
26 | - Android SDK & NDK (see this [README](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/README.md))
27 |
28 | If you do not have Python installed yet, Udacity recommends installing [Anaconda](http://continuum.io/downloads), a pre-packaged Python distribution that contains all of the libraries and software needed for this project. Make sure that you install Python 2.7 and not Python 3.x. `pygame` and `OpenCV` can then be installed with the following commands:
29 |
30 | Mac:
31 | ```bash
32 | conda install -c https://conda.anaconda.org/quasiben pygame
33 | conda install -c menpo opencv=2.4.11
34 | ```
35 |
36 | Windows & Linux:
37 | ```bash
38 | conda install -c https://conda.anaconda.org/tlatorre pygame
39 | conda install -c menpo opencv=2.4.11
40 | ```
41 |
42 | ### Code
43 |
44 | A starting notebook is provided in the `digit_recognition.ipynb` file. No code is supplied inside it; to complete the project, you will need to implement the basic functionality in the notebook and answer questions about your implementation and results.
45 |
46 | ### Run
47 |
48 | In a terminal, make sure the current directory is the top level of the `digit_recognition/` folder (the one that contains this README), then run one of the following commands:
49 | ```bash
50 | ipython notebook digit_recognition.ipynb
51 | ```
52 | or
53 | ```bash
54 | jupyter notebook digit_recognition.ipynb
55 | ```
56 |
57 | This will open the Jupyter Notebook and the project file in your browser.
58 |
59 |
60 | ### Data
61 |
62 | Since no code is provided for this project, you must download and use the [Street View House Numbers (SVHN) dataset](http://ufldl.stanford.edu/housenumbers/) yourself, along with either the [notMNIST](http://yaroslavvb.blogspot.com/2011/09/notmnist-dataset.html) or the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset. If you have completed the course material, you already have the **notMNIST** dataset.
63 |
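The SVHN cropped-digits data is distributed as MATLAB `.mat` files. A minimal loading sketch, assuming `train_32x32.mat` has already been downloaded from the SVHN site; the function name `load_svhn` is illustrative, not part of the project files:

```python
import numpy as np
from scipy.io import loadmat

def load_svhn(path):
    """Load an SVHN cropped-digits .mat file into (images, labels) arrays."""
    data = loadmat(path)
    # The .mat file stores images as (32, 32, 3, N); move the sample axis first.
    images = np.transpose(data['X'], (3, 0, 1, 2))
    labels = data['y'].flatten().astype(np.int64)
    # SVHN labels the digit '0' as class 10; remap it to 0.
    labels[labels == 10] = 0
    return images, labels
```

Note the class-10 remapping: the SVHN label files use 1-10 where 10 stands for the digit zero.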
--------------------------------------------------------------------------------
/digit_recognition/model.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/digit_recognition/model.png
--------------------------------------------------------------------------------
/digit_recognition/project_description.md:
--------------------------------------------------------------------------------
1 | # Content: Deep Learning
2 | ## Project: Build a Digit Recognition Program
3 |
4 | ## Project Overview
5 |
6 | In this project, you will use what you have learned about deep neural networks and convolutional neural networks to build a live camera application or program that prints, in real time, the numbers it observes in the images it is given. First, you will design and implement a deep learning model architecture that can recognize sequences of digits. Next, you will train the model so that it can recognize digit sequences in real-world images like those in the [Street View House Numbers (SVHN) dataset](http://ufldl.stanford.edu/housenumbers/). Once the model is trained, you will test it either with a live camera application (optional) or on newly captured images. Finally, once you obtain meaningful results, you will refine your implementation to *localize the digits in an image* and test the localization on newly captured images.
7 |
8 | ## Software Requirements
9 | This project requires the following software and Python packages installed:
10 |
11 | - [Python 2.7](https://www.python.org/download/releases/2.7/)
12 | - [NumPy](http://www.numpy.org/)
13 | - [SciPy](https://www.scipy.org/)
14 | - [scikit-learn](http://scikit-learn.org/0.17/install.html) (v0.17)
15 | - [TensorFlow](http://tensorflow.org)
16 |
17 | You will also need the software installed to run a [Jupyter Notebook](http://ipython.org/notebook.html).
18 |
19 | In addition to the above, for those optionally seeking to use image processing software, you may need to install one of the following:
20 | - [PyGame](http://pygame.org/)
21 | - Helpful links for installing PyGame:
22 | - [Getting Started](https://www.pygame.org/wiki/GettingStarted)
23 | - [PyGame Information](http://www.pygame.org/wiki/info)
24 | - [Google Group](https://groups.google.com/forum/#!forum/pygame-mirror-on-google-groups)
25 | - [PyGame subreddit](https://www.reddit.com/r/pygame/)
26 | - [OpenCV](http://opencv.org/)
27 |
28 | For those optionally seeking to deploy the application as an Android app:
29 | - Android SDK & NDK (see this [README](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/examples/android/README.md))
30 |
31 | If you do not have Python installed yet, Udacity recommends installing [Anaconda](http://continuum.io/downloads), a pre-packaged Python distribution that contains all of the libraries and software needed for this project. Make sure that you install Python 2.7 and not Python 3.x. `pygame` and `OpenCV` can then be installed with the following commands:
32 |
33 |
34 | **OpenCV:**
35 | `conda install -c menpo opencv=2.4.11`
36 |
37 | **PyGame:**
38 | Mac: `conda install -c https://conda.anaconda.org/quasiben pygame`
39 | Windows: `conda install -c https://conda.anaconda.org/tlatorre pygame`
40 | Linux: `conda install -c https://conda.anaconda.org/prkrekel pygame`
41 |
42 | ## Starting the Project
43 |
44 | For this assignment, you can find the downloadable `digit_recognition.zip` archive containing the necessary project files in the **Resources** section. *You can also visit our [Machine Learning projects GitHub](https://github.com/udacity/machine-learning) for all of the projects in this Nanodegree.*
45 |
46 | - `digit_recognition.ipynb`: This is the main file you will be working on.
47 |
48 | In addition, you must download and use the [Street View House Numbers (SVHN) dataset](http://ufldl.stanford.edu/housenumbers/) yourself, along with either the [notMNIST](http://yaroslavvb.blogspot.com/2011/09/notmnist-dataset.html) or the [MNIST](http://yann.lecun.com/exdb/mnist/) dataset. If you have completed the course material, you already have the **notMNIST** dataset.
49 |
50 | In a terminal or command prompt, navigate to the folder containing the project files and run `jupyter notebook digit_recognition.ipynb` to open the notebook in a browser window or tab. Alternatively, run `jupyter notebook` or `ipython notebook` and navigate to the notebook file in the browser window that opens. Follow the instructions in the notebook and answer each question to complete the project. A **README** file is also provided with the project files, containing additional necessary information and instructions.
51 |
52 | ## Tasks
53 |
54 | ### Project Report
55 | As part of your submitted `digit_recognition.ipynb`, you will need to answer questions about your implementation. While completing the tasks below, include thorough, detailed answers to each question (shown in *italics* below).
56 |
57 | ### Step 1: Design and Test a Model Architecture
58 | Design and implement a deep learning model that learns to recognize sequences of digits. You can train it on synthetic data generated by concatenating character images from [notMNIST](http://yaroslavvb.blogspot.com/2011/09/notmnist-dataset.html) or [MNIST](http://yann.lecun.com/exdb/mnist/). To produce synthetic digit sequences for testing, you can, for example, limit a sequence to at most five digits and use five classifiers on top of your deep network. You will also want an additional "blank" character to handle relatively short sequences.
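The synthesis scheme above (fixed maximum length plus a "blank" class for unused positions) can be sketched in plain NumPy; this assumes 28x28 MNIST-style digit arrays, and the names `BLANK`, `MAX_LEN`, and `make_sequence` are illustrative, not from the project:

```python
import numpy as np

BLANK = 10   # extra class meaning "no digit at this position"
MAX_LEN = 5  # cap synthetic sequences at five digits

def make_sequence(digit_images, digit_labels, rng):
    """Concatenate 1 to MAX_LEN random 28x28 digits side by side.

    Returns a 28x(28*MAX_LEN) image and a length-MAX_LEN label vector
    padded with the BLANK class for unused positions.
    """
    n = rng.randint(1, MAX_LEN + 1)
    idx = rng.randint(0, len(digit_images), size=n)
    canvas = np.zeros((28, 28 * MAX_LEN), dtype=digit_images.dtype)
    labels = np.full(MAX_LEN, BLANK, dtype=np.int64)
    for pos, i in enumerate(idx):
        canvas[:, pos * 28:(pos + 1) * 28] = digit_images[i]
        labels[pos] = digit_labels[i]
    return canvas, labels
```

Each of the five classifier heads is then trained against one position of the padded label vector.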
59 |
60 | There are many aspects to consider when thinking about this problem:
61 | - Your model can be based on a deep neural network or a convolutional neural network.
62 | - You can experiment with sharing (or not sharing) weights between the softmax classifiers.
63 | - You can also replace the classification layers of the deep network with a recurrent network and output the digits of the sequence one at a time.
64 |
65 | Here is an example of a [published baseline model on this problem](http://static.googleusercontent.com/media/research.google.com/en//pubs/archive/42241.pdf) ([video](https://www.youtube.com/watch?v=vGPI_JvLoN0)).
66 |
67 | ***Question***
68 | _What approach did you take to solving this problem?_
69 |
70 | ***Question***
71 | _What does your final model architecture look like? (What type of model, number of layers, sizes, connectivity, etc.)_
72 |
73 | ***Question***
74 | _How did you train your model? How did you generate your synthetic dataset?_ Include a few example images of the synthetic data you created.
75 |
76 | ### Step 2: Train a Model on a Realistic Dataset
77 |
78 | Once you have settled on a good architecture, you can train your model on real data. In particular, the [Street View House Numbers (SVHN)](http://ufldl.stanford.edu/housenumbers/) dataset is a large-scale collection of house numbers taken from Google Street View imagery. Training on this more challenging dataset, where the digits are not neatly lined up and appear in various skews, fonts, and colors, will likely mean you have to do some hyperparameter exploration to obtain a model that performs well.
79 |
80 | ***Question***
81 | _Describe how you set up the training and testing data for your model. How does the model perform on the realistic dataset?_
82 |
83 | ***Question***
84 | _What changes did you make to the model, if any? Did any of them lead to a "good" result? Did anything you explored make the results worse?_
85 |
86 | ***Question***
87 | _What were your initial and final results when testing on the realistic dataset? Do you think your model is doing a good enough job at classifying numbers correctly?_
88 |
89 | ### Step 3: Test the Model on Newly Captured Images
90 |
91 | Take several pictures of numbers that you find around you (at least five), and run them through your classifier to produce example results. Alternatively (and optionally), you can try using OpenCV / SimpleCV / Pygame to capture live images from a webcam and run those through your classifier.
92 |
93 | ***Question***
94 | _Choose five candidate images of numbers you took around you and provide them in the report. Are there any particular qualities of any of the images that might make classification difficult?_
95 |
96 | ***Question***
97 | _Is your model able to perform equally well on captured pictures or a live camera stream compared to its performance on the realistic dataset?_
98 |
99 | ***Question***
100 | _If necessary, document how you built the interface that allows your model to load and classify newly acquired images._
101 |
102 | ### Step 4: Explore an Improvement to the Model
103 |
104 | Once your basic classifier is trained, there are many things you can do with it. One example would be to also localize where the numbers are in the image. The SVHN dataset provides bounding boxes that you can tune to train a localizer. Train a regression loss on the coordinates of the bounding boxes, and then test it.
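The regression loss mentioned above can be as simple as a mean squared error over the predicted box coordinates; a minimal NumPy sketch, where the function name and the (x, y, width, height) parameterization are assumptions for illustration, not the project's prescribed form:

```python
import numpy as np

def box_regression_loss(pred, target):
    """Mean squared error over predicted vs. ground-truth box coordinates.

    Both inputs have shape (N, 4), holding (x, y, width, height) per box.
    """
    diff = np.asarray(pred, dtype=float) - np.asarray(target, dtype=float)
    return float(np.mean(diff ** 2))
```

In a TensorFlow model this same quantity would be minimized by the optimizer alongside (or after) the classification loss.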
105 |
106 | ***Question***
107 | _How well does your model localize numbers on the test set of the realistic dataset? Do your classification results change at all after adding the location information?_
108 |
109 | ***Question***
110 | Test the localization on the images you captured in **Step 3**. _Does the model accurately compute bounding boxes for the numbers in the images you found? If you did not use a graphical interface, you may need to explore the bounding boxes by hand. Provide an example of the bounding boxes drawn on a captured image._
111 |
112 | ## Optional Step 5: Build an Application or Program for the Model
113 |
114 | Take your project one step further. If you are interested, you could build an Android application, or an even more robust Python program, that interfaces with input images and displays the classified numbers and even the bounding boxes. For example, you could try building an augmented-reality app that overlays your answers on top of the image, the way the [Word Lens](https://en.wikipedia.org/wiki/Word_Lens) app does.
115 |
116 | Example code showing how to load a TensorFlow model into a camera app on Android is available in the [TensorFlow Android demo app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/android), which you can adapt with some simple modifications.
117 |
118 |
119 | If you decide to explore this optional path, be sure to document your interface and implementation, along with any significant results you find. You can see what will be required of you by reading through [this rubric](https://review.udacity.com/#!/rubrics/413/view) that will be used to evaluate your work.
120 |
121 | ## Submitting the Project
122 |
123 | ### Evaluation
124 | Your project will be reviewed by a Udacity reviewer against the **Build a Digit Recognition Program rubric**. Be sure to review this rubric thoroughly and self-evaluate your project before submission. All criteria found in the rubric must be marked *meeting specifications* in order for you to pass.
125 |
126 | ### Submission Files
127 | When you are ready to submit your project, collect the following files and compress them into a single archive for upload. Alternatively, you may supply the following files in a folder named `digit_recognition` in your GitHub repo for ease of access:
128 | - The `digit_recognition.ipynb` notebook file with all questions answered and all code cells executed and displaying output.
129 | - An **HTML** export of the project notebook named **report.html**. This file *must* be included with your submission.
130 | - Any datasets or images used for the project other than SVHN, notMNIST, or MNIST.
131 |
132 | - For the optional image recognition software portion, include any relevant Python files needed to run your code.
133 | - For the optional Android application portion, provide documentation on how to obtain and use the application in a PDF report named **documentation.pdf**.
134 |
135 | Once you have collected these files and reviewed the project rubric, proceed to the project submission page.
136 |
137 | ### I'm Ready!
138 | When you're ready to submit your project, click the **Submit Project** button at the bottom of the page.
139 |
140 | If you are having any problems submitting your project or wish to check on the status of your submission, please email **machine-support@udacity.com** or visit the forums.
141 |
142 | ### What's Next?
143 | You will get an email as soon as your reviewer has feedback for you. In the meantime, you can start working on your next project or learn about related courses.
144 |
--------------------------------------------------------------------------------
/dog_vs_cat/README.md:
--------------------------------------------------------------------------------
1 | # Dogs vs. Cats
2 | # This project was trained on an AWS udacity-carND instance
3 | # Training one model takes about 75 minutes
4 | # The data used is from the Kaggle Dogs vs. Cats competition
5 | # The data preparation and model training code is in dog_vs_cat.ipynb
6 | # The file.csv files contain the predictions submitted to Kaggle for evaluation
7 |
8 | # model.h5
9 | https://drive.google.com/open?id=0B_jeca3axBM9dkdheXJaWHoxbE0
10 | # model_2.h5
11 | https://drive.google.com/open?id=0B_jeca3axBM9V2lrRXBPX0pYQTA
12 | # model_3.h5
13 | https://drive.google.com/open?id=0B_jeca3axBM9ZkdoTFItZllBQ1k
--------------------------------------------------------------------------------
/dog_vs_cat/proposal.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/dog_vs_cat/proposal.pdf
--------------------------------------------------------------------------------
/dog_vs_cat/report2.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/dog_vs_cat/report2.pdf
--------------------------------------------------------------------------------
/finding_donors/README.md:
--------------------------------------------------------------------------------
1 | # Machine Learning Nanodegree
2 | # Supervised Learning
3 | ## Project: Finding Donors for CharityML
4 | ### Install
5 |
6 | This project requires Python 2.7 and the following Python packages installed:
7 |
8 | - [Python 2.7](https://www.python.org/download/releases/2.7/)
9 | - [NumPy](http://www.numpy.org/)
10 | - [Pandas](http://pandas.pydata.org/)
11 | - [scikit-learn](http://scikit-learn.org/stable/)
12 | - [matplotlib](http://matplotlib.org/)
13 |
14 | You will also need the software installed to run an [iPython Notebook](http://ipython.org/notebook.html)
15 |
16 | Udacity recommends installing [Anaconda](https://www.continuum.io/downloads), a pre-packaged Python distribution that contains all of the libraries and software needed for this project.
17 |
18 | ### Code
19 |
20 | Template code is provided in the `finding_donors.ipynb` notebook file. You will also need `visuals.py` and the `census.csv` dataset file to complete the project.
21 | Some code has already been implemented to help you get started, but you will need to implement additional functionality to complete the project.
22 | Note that the code included in `visuals.py` is meant to be imported externally and is not intended to be modified by students. If you are curious about how the visualizations in the notebook are created, feel free to explore this file.
23 |
24 |
25 | ### Run
26 | In a terminal, make sure the current directory is the top level of the `finding_donors/` folder (the one that contains this README), then run the following command:
27 |
28 | ```bash
29 | jupyter notebook finding_donors.ipynb
30 | ```
31 |
32 | This will open the Jupyter Notebook and the project file in your browser.
33 |
34 | ### Data
35 |
36 | The modified census dataset consists of approximately 32,000 data points, with each data point having 13 features. This dataset is a modified version of the dataset published in the paper *"Scaling Up the Accuracy of Naive-Bayes Classifiers: a Decision-Tree Hybrid"* by Ron Kohavi. You can find the paper [online](https://www.aaai.org/Papers/KDD/1996/KDD96-033.pdf) and the original dataset hosted on the [UCI website](https://archive.ics.uci.edu/ml/datasets/Census+Income).
37 |
38 | **Features**
39 |
40 | - `age`: an integer giving the age of the respondent.
41 | - `workclass`: a categorical variable for the usual class of work, with allowed values {Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked}
42 | - `education_level`: a categorical variable for the level of education, with allowed values {Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool}
43 | - `education-num`: an integer giving the number of years spent in education
44 | - `marital-status`: a categorical variable, with allowed values {Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse}
45 | - `occupation`: a categorical variable for the general field of occupation, with allowed values {Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces}
46 | - `relationship`: a categorical variable for the family relationship, with allowed values {Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried}
47 | - `race`: a categorical variable for race, with allowed values {White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black}
48 | - `sex`: a categorical variable for sex, with allowed values {Female, Male}
49 | - `capital-gain`: continuous.
50 | - `capital-loss`: continuous.
51 | - `hours-per-week`: continuous.
52 | - `native-country`: a categorical variable for the country of origin, with allowed values {United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands}
53 |
54 | **Target Variable**
55 |
56 | - `income`: a categorical variable indicating the income class, with allowed values {<=50K, >50K}
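When loading this data into pandas, the usual first steps are to split off the `income` target and one-hot encode the categorical features. A minimal sketch on a tiny made-up stand-in for `census.csv` (using a few of the column names listed above; the records are invented for illustration):

```python
import pandas as pd

# A tiny stand-in for census.csv; values are illustrative, not real records.
data = pd.DataFrame({
    'age': [39, 50, 38],
    'workclass': ['State-gov', 'Self-emp-not-inc', 'Private'],
    'sex': ['Male', 'Male', 'Female'],
    'hours-per-week': [40, 13, 40],
    'income': ['<=50K', '<=50K', '>50K'],
})

# Binarize the target: 1 for ">50K", 0 for "<=50K".
income = (data['income'] == '>50K').astype(int)

# One-hot encode the categorical features; numeric columns pass through unchanged.
features = pd.get_dummies(data.drop('income', axis=1))
```

With the real file you would replace the constructed DataFrame with `pd.read_csv("census.csv")`.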
--------------------------------------------------------------------------------
/finding_donors/finding_donors.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/finding_donors/finding_donors.zip
--------------------------------------------------------------------------------
/finding_donors/project_description.md:
--------------------------------------------------------------------------------
1 | # Content: Supervised Learning
2 | ## Project: Finding Donors for CharityML
3 |
4 | ## Project Overview
5 | In this project, you will apply supervised learning techniques and your analytical skills to U.S. census data to help CharityML (a fictitious charity organization) identify the people most likely to donate to their cause. You will first explore the data to learn how the census data is recorded. Next, you will apply a series of transformations and preprocessing techniques to manipulate the data into a workable format. You will then evaluate several supervised learners of your choice on the data and consider which is best suited for the solution. Afterwards, you will optimize the model you have selected for CharityML. Finally, you will explore the chosen model and its predictive power.
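The evaluate-and-compare step of this workflow can be sketched with current scikit-learn on synthetic stand-in data. This is only an illustration of the pattern, not the project's code: the learners, the `make_classification` data, and the choice of the precision-weighted F_0.5 score are assumptions here:

```python
from sklearn.datasets import make_classification
from sklearn.metrics import fbeta_score
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
from sklearn.tree import DecisionTreeClassifier

# Synthetic stand-in data; the project itself uses census.csv.
X, y = make_classification(n_samples=400, n_features=10, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

scores = {}
for clf in (GaussianNB(), DecisionTreeClassifier(random_state=0)):
    clf.fit(X_train, y_train)
    # F_0.5 weights precision above recall, which suits targeting likely donors.
    scores[type(clf).__name__] = fbeta_score(y_test, clf.predict(X_test), beta=0.5)
```

Comparing the entries of `scores` (ideally alongside training time) is what drives the "which learner is best suited" decision.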
6 |
7 | ## Project Highlights
8 | This project is designed to get you acquainted with the many supervised learning algorithms available in sklearn, and to provide a method for evaluating how models perform on a certain type of data. It is important in machine learning to understand exactly when and where a certain algorithm should be used, and when one should be avoided.
9 |
10 | Things you will learn by completing this project:
11 | - How to identify when preprocessing is needed, and how to apply it.
12 | - How to establish a benchmark for a solution to the problem.
13 | - How several supervised learning algorithms perform on a specific dataset.
14 | - How to investigate whether a candidate solution model is adequate for the problem.
15 |
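One way to read the benchmark item above: score a naive predictor first, so any trained model has a floor to beat. The labels below are made up for illustration; they are not taken from the project data.

```python
from sklearn.metrics import accuracy_score, fbeta_score

# Toy labels: 1 marks an individual earning more than $50K.
y_true = [0, 0, 0, 1, 0, 1, 0, 0, 1, 0]

# A simple naive benchmark: predict the positive class for every record.
y_naive = [1] * len(y_true)

print(accuracy_score(y_true, y_naive))           # fraction of true positives
print(fbeta_score(y_true, y_naive, beta=0.5))    # precision-weighted F-score
```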
16 | ## Software Requirements
17 |
18 | This project requires **Python 2.7** and the following Python libraries installed:
19 |
20 | - [Python 2.7](https://www.python.org/download/releases/2.7/)
21 | - [NumPy](http://www.numpy.org/)
22 | - [Pandas](http://pandas.pydata.org/)
23 | - [scikit-learn](http://scikit-learn.org/stable/)
24 | - [matplotlib](http://matplotlib.org/)
25 |
26 | You will also need to have software installed to run and execute an [iPython Notebook](http://ipython.org/notebook.html).
27 |
28 | Udacity recommends that students install [Anaconda](https://www.continuum.io/downloads), a pre-packaged Python distribution that contains all of the libraries and software needed for this project. Make sure that you install version 2.7 and not 3.x.
29 |
30 | ## Starting the Project
31 |
32 | For this project, you can find the `finding_donors.zip` archive available for download in the **Resources** section. *You may also visit our [Machine Learning projects GitHub](https://github.com/udacity/machine-learning) to have access to all of the projects available for this Nanodegree.*
33 |
34 | This project contains the following files:
35 |
36 | - `finding_donors.ipynb`: This is the main file where you will be doing your work.
37 | - `census.csv`: The project dataset. You will need to load this data in the notebook.
38 | - `visuals.py`: A Python file containing visualization code. Do not modify it.
39 |
40 | In a terminal or command window, navigate to the folder containing the project files and run the command `jupyter notebook finding_donors.ipynb` to open the notebook in a browser window or tab. Alternatively, you can run `jupyter notebook` or `ipython notebook` and navigate to the notebook file in the browser window that opens. Follow the instructions in the notebook and answer each question to successfully complete the project. A **README** file has also been provided with the project, containing additional information and instructions you may need.
41 |
42 | ## Submitting the Project
43 |
44 | ### Evaluation
45 | Your project will be reviewed by a Udacity reviewer against the **Finding Donors for CharityML project rubric**. Be sure to review this rubric thoroughly and self-evaluate your project before submission. All criteria found in the rubric must be marked *meeting specifications* in order for you to pass.
46 |
47 | ### Submission Files
48 | When you are ready to submit your project, collect the following files and compress them into a single archive for upload. Alternatively, you may supply the following files in a folder named `finding_donors` in your GitHub repo for ease of access:
49 | - The `finding_donors.ipynb` notebook file with all questions answered and all code cells executed and displaying output.
50 | - An **HTML** export of the project notebook named **report.html**. This file *must* be included with your submission.
51 |
52 | Once you have collected these files and read the project rubric, proceed to the project submission page.
53 |
54 | ### I'm Ready!
55 | When you're ready to submit your project, click the **Submit Project** button at the bottom of the page.
56 |
57 | If you are having any problems submitting your project or wish to check on the status of your submission, please email **machine-support@udacity.com** or visit the forums.
58 |
59 | ### What's Next?
60 | You will get an email as soon as your reviewer has feedback for you. In the meantime, you can start working on your next project, or take related courses while you wait.
--------------------------------------------------------------------------------
/finding_donors/visuals.py:
--------------------------------------------------------------------------------
1 | ###########################################
2 | # Suppress matplotlib user warnings
3 | # Necessary for newer version of matplotlib
4 | import warnings
5 | warnings.filterwarnings("ignore", category = UserWarning, module = "matplotlib")
6 | #
7 | # Display inline matplotlib plots with IPython
8 | from IPython import get_ipython
9 | get_ipython().run_line_magic('matplotlib', 'inline')
10 | ###########################################
11 |
12 | import matplotlib.pyplot as pl
13 | import matplotlib.patches as mpatches
14 | import numpy as np
15 | import pandas as pd
16 | from time import time
17 | from sklearn.metrics import f1_score, accuracy_score
18 |
19 |
20 | def distribution(data, transformed = False):
21 | """
22 | Visualization code for displaying skewed distributions of features
23 | """
24 |
25 | # Create figure
26 | fig = pl.figure(figsize = (11,5));
27 |
28 | # Skewed feature plotting
29 | for i, feature in enumerate(['capital-gain','capital-loss']):
30 | ax = fig.add_subplot(1, 2, i+1)
31 | ax.hist(data[feature], bins = 25, color = '#00A0A0')
32 | ax.set_title("'%s' Feature Distribution"%(feature), fontsize = 14)
33 | ax.set_xlabel("Value")
34 | ax.set_ylabel("Number of Records")
35 | ax.set_ylim((0, 2000))
36 | ax.set_yticks([0, 500, 1000, 1500, 2000])
37 | ax.set_yticklabels([0, 500, 1000, 1500, ">2000"])
38 |
39 | # Plot aesthetics
40 | if transformed:
41 | fig.suptitle("Log-transformed Distributions of Continuous Census Data Features", \
42 | fontsize = 16, y = 1.03)
43 | else:
44 | fig.suptitle("Skewed Distributions of Continuous Census Data Features", \
45 | fontsize = 16, y = 1.03)
46 |
47 | fig.tight_layout()
48 | fig.show()
49 |
50 |
51 | def evaluate(results, accuracy, f1):
52 | """
53 | Visualization code to display results of various learners.
54 |
55 | inputs:
56 | - learners: a list of supervised learners
57 | - stats: a list of dictionaries of the statistic results from 'train_predict()'
58 | - accuracy: The score for the naive predictor
59 | - f1: The score for the naive predictor
60 | """
61 |
62 | # Create figure
63 | fig, ax = pl.subplots(2, 3, figsize = (11,7))
64 |
65 | # Constants
66 | bar_width = 0.3
67 | colors = ['#A00000','#00A0A0','#00A000']
68 |
69 | # Super loop to plot four panels of data
70 | for k, learner in enumerate(results.keys()):
71 | for j, metric in enumerate(['train_time', 'acc_train', 'f_train', 'pred_time', 'acc_test', 'f_test']):
72 | for i in np.arange(3):
73 |
74 | # Creative plot code
75 |                 ax[j//3, j%3].bar(i+k*bar_width, results[learner][i][metric], width = bar_width, color = colors[k])
76 |                 ax[j//3, j%3].set_xticks([0.45, 1.45, 2.45])
77 |                 ax[j//3, j%3].set_xticklabels(["1%", "10%", "100%"])
78 |                 ax[j//3, j%3].set_xlabel("Training Set Size")
79 |                 ax[j//3, j%3].set_xlim((-0.1, 3.0))
80 |
81 | # Add unique y-labels
82 | ax[0, 0].set_ylabel("Time (in seconds)")
83 | ax[0, 1].set_ylabel("Accuracy Score")
84 | ax[0, 2].set_ylabel("F-score")
85 | ax[1, 0].set_ylabel("Time (in seconds)")
86 | ax[1, 1].set_ylabel("Accuracy Score")
87 | ax[1, 2].set_ylabel("F-score")
88 |
89 | # Add titles
90 | ax[0, 0].set_title("Model Training")
91 | ax[0, 1].set_title("Accuracy Score on Training Subset")
92 | ax[0, 2].set_title("F-score on Training Subset")
93 | ax[1, 0].set_title("Model Predicting")
94 | ax[1, 1].set_title("Accuracy Score on Testing Set")
95 | ax[1, 2].set_title("F-score on Testing Set")
96 |
97 | # Add horizontal lines for naive predictors
98 | ax[0, 1].axhline(y = accuracy, xmin = -0.1, xmax = 3.0, linewidth = 1, color = 'k', linestyle = 'dashed')
99 | ax[1, 1].axhline(y = accuracy, xmin = -0.1, xmax = 3.0, linewidth = 1, color = 'k', linestyle = 'dashed')
100 | ax[0, 2].axhline(y = f1, xmin = -0.1, xmax = 3.0, linewidth = 1, color = 'k', linestyle = 'dashed')
101 | ax[1, 2].axhline(y = f1, xmin = -0.1, xmax = 3.0, linewidth = 1, color = 'k', linestyle = 'dashed')
102 |
103 | # Set y-limits for score panels
104 | ax[0, 1].set_ylim((0, 1))
105 | ax[0, 2].set_ylim((0, 1))
106 | ax[1, 1].set_ylim((0, 1))
107 | ax[1, 2].set_ylim((0, 1))
108 |
109 | # Create patches for the legend
110 | patches = []
111 | for i, learner in enumerate(results.keys()):
112 | patches.append(mpatches.Patch(color = colors[i], label = learner))
113 | pl.legend(handles = patches, bbox_to_anchor = (-.80, 2.53), \
114 | loc = 'upper center', borderaxespad = 0., ncol = 3, fontsize = 'x-large')
115 |
116 | # Aesthetics
117 | pl.suptitle("Performance Metrics for Three Supervised Learning Models", fontsize = 16, y = 1.10)
118 | pl.tight_layout()
119 | pl.show()
120 |
121 |
122 | def feature_plot(importances, X_train, y_train):
123 |
124 | # Display the five most important features
125 | indices = np.argsort(importances)[::-1]
126 | columns = X_train.columns.values[indices[:5]]
127 | values = importances[indices][:5]
128 |
129 |     # Create the plot
130 | fig = pl.figure(figsize = (9,5))
131 | pl.title("Normalized Weights for First Five Most Predictive Features", fontsize = 16)
132 | pl.bar(np.arange(5), values, width = 0.6, align="center", color = '#00A000', \
133 | label = "Feature Weight")
134 | pl.bar(np.arange(5) - 0.3, np.cumsum(values), width = 0.2, align = "center", color = '#00A0A0', \
135 | label = "Cumulative Feature Weight")
136 | pl.xticks(np.arange(5), columns)
137 | pl.xlim((-0.5, 4.5))
138 | pl.ylabel("Weight", fontsize = 12)
139 | pl.xlabel("Feature", fontsize = 12)
140 |
141 | pl.legend(loc = 'upper center')
142 | pl.tight_layout()
143 | pl.show()
144 |
--------------------------------------------------------------------------------
/smartcab/README.md:
--------------------------------------------------------------------------------
1 | # Machine Learning Engineer Nanodegree
2 | # Reinforcement Learning
3 | ## Project: Train a Smartcab How to Drive
4 |
5 | ### Install
6 |
7 | This project requires **Python 2.7** with the [pygame](https://www.pygame.org/wiki/GettingStarted) library installed.
9 |
10 | ### Code
11 |
12 | Template code is provided in the `smartcab/agent.py` python file. Additional supporting python code can be found in `smartcab/environment.py`, `smartcab/planner.py`, and `smartcab/simulator.py`. Supporting images for the graphical user interface can be found in the `images` folder. While some code has already been implemented to get you started, you will need to implement additional functionality for the `LearningAgent` class in `agent.py` when requested to successfully complete the project.
13 |
14 | ### Run
15 |
16 | In a terminal or command window, navigate to the top-level project directory `smartcab/` (that contains this README) and run one of the following commands:
17 |
18 | ```python smartcab/agent.py```
19 | ```python -m smartcab.agent```
20 |
21 | This will run the `agent.py` file and execute your agent code.
22 |
--------------------------------------------------------------------------------
/smartcab/agent.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/agent.py
--------------------------------------------------------------------------------
/smartcab/images/car-black.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-black.png
--------------------------------------------------------------------------------
/smartcab/images/car-blue.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-blue.png
--------------------------------------------------------------------------------
/smartcab/images/car-cyan.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-cyan.png
--------------------------------------------------------------------------------
/smartcab/images/car-green.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-green.png
--------------------------------------------------------------------------------
/smartcab/images/car-magenta.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-magenta.png
--------------------------------------------------------------------------------
/smartcab/images/car-orange.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-orange.png
--------------------------------------------------------------------------------
/smartcab/images/car-red.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-red.png
--------------------------------------------------------------------------------
/smartcab/images/car-white.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-white.png
--------------------------------------------------------------------------------
/smartcab/images/car-yellow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/car-yellow.png
--------------------------------------------------------------------------------
/smartcab/images/east-west.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/east-west.png
--------------------------------------------------------------------------------
/smartcab/images/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/logo.png
--------------------------------------------------------------------------------
/smartcab/images/north-south.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/images/north-south.png
--------------------------------------------------------------------------------
/smartcab/logs/sim_default-learning.csv:
--------------------------------------------------------------------------------
1 | trial,testing,parameters,initial_deadline,final_deadline,net_reward,actions,success
2 | 1,False,"{'a': 0.5, 'e': 0.95}",30,0,-141.4152219254666,"{0: 21, 1: 2, 2: 3, 3: 2, 4: 2}",0
3 | 2,False,"{'a': 0.5, 'e': 0.8999999999999999}",30,0,-118.74810356680284,"{0: 20, 1: 3, 2: 5, 3: 0, 4: 2}",0
4 | 3,False,"{'a': 0.5, 'e': 0.8499999999999999}",25,0,-134.30063983462142,"{0: 14, 1: 2, 2: 6, 3: 2, 4: 1}",0
5 | 4,False,"{'a': 0.5, 'e': 0.7999999999999998}",25,0,-122.97600732528981,"{0: 17, 1: 2, 2: 3, 3: 1, 4: 2}",0
6 | 5,False,"{'a': 0.5, 'e': 0.7499999999999998}",20,2,-9.639186715051562,"{0: 15, 1: 0, 2: 3, 3: 0, 4: 0}",1
7 | 6,False,"{'a': 0.5, 'e': 0.6999999999999997}",20,12,2.7898590683150184,"{0: 7, 1: 0, 2: 1, 3: 0, 4: 0}",1
8 | 7,False,"{'a': 0.5, 'e': 0.6499999999999997}",20,0,-32.30763679982498,"{0: 14, 1: 2, 2: 4, 3: 0, 4: 0}",0
9 | 8,False,"{'a': 0.5, 'e': 0.5999999999999996}",20,0,-80.37939015975518,"{0: 13, 1: 2, 2: 4, 3: 0, 4: 1}",0
10 | 9,False,"{'a': 0.5, 'e': 0.5499999999999996}",20,0,-61.96798385294584,"{0: 13, 1: 4, 2: 2, 3: 0, 4: 1}",0
11 | 10,False,"{'a': 0.5, 'e': 0.4999999999999996}",25,0,-84.82140927085712,"{0: 19, 1: 1, 2: 3, 3: 0, 4: 2}",0
12 | 11,False,"{'a': 0.5, 'e': 0.4499999999999996}",25,0,-25.103553666871605,"{0: 19, 1: 3, 2: 3, 3: 0, 4: 0}",1
13 | 12,False,"{'a': 0.5, 'e': 0.39999999999999963}",20,0,-28.16709246761365,"{0: 17, 1: 1, 2: 1, 3: 0, 4: 1}",0
14 | 13,False,"{'a': 0.5, 'e': 0.34999999999999964}",20,5,-29.19961200493926,"{0: 13, 1: 0, 2: 1, 3: 0, 4: 1}",1
15 | 14,False,"{'a': 0.5, 'e': 0.29999999999999966}",20,0,-2.918593524608614,"{0: 16, 1: 3, 2: 1, 3: 0, 4: 0}",0
16 | 15,False,"{'a': 0.5, 'e': 0.24999999999999967}",25,12,21.289498513974756,"{0: 13, 1: 0, 2: 0, 3: 0, 4: 0}",1
17 | 16,False,"{'a': 0.5, 'e': 0.19999999999999968}",20,0,-50.21845680672953,"{0: 18, 1: 0, 2: 0, 3: 0, 4: 2}",0
18 | 17,False,"{'a': 0.5, 'e': 0.1499999999999997}",20,16,10.304394010758255,"{0: 4, 1: 0, 2: 0, 3: 0, 4: 0}",1
19 | 18,False,"{'a': 0.5, 'e': 0.09999999999999969}",20,10,20.80318001076089,"{0: 10, 1: 0, 2: 0, 3: 0, 4: 0}",1
20 | 19,False,"{'a': 0.5, 'e': 0.049999999999999684}",25,7,12.021487368019777,"{0: 17, 1: 0, 2: 0, 3: 1, 4: 0}",1
21 | 20,False,"{'a': 0.5, 'e': -3.191891195797325e-16}",25,13,22.44349872235844,"{0: 12, 1: 0, 2: 0, 3: 0, 4: 0}",1
22 | 1,True,"{'a': 0, 'e': 0}",25,13,23.591695430775612,"{0: 12, 1: 0, 2: 0, 3: 0, 4: 0}",1
23 | 2,True,"{'a': 0, 'e': 0}",30,12,31.876211003456852,"{0: 18, 1: 0, 2: 0, 3: 0, 4: 0}",1
24 | 3,True,"{'a': 0, 'e': 0}",25,0,-25.997937197119185,"{0: 23, 1: 0, 2: 0, 3: 1, 4: 1}",0
25 | 4,True,"{'a': 0, 'e': 0}",20,9,19.688356652330302,"{0: 11, 1: 0, 2: 0, 3: 0, 4: 0}",1
26 | 5,True,"{'a': 0, 'e': 0}",25,12,25.61534719374939,"{0: 13, 1: 0, 2: 0, 3: 0, 4: 0}",1
27 | 6,True,"{'a': 0, 'e': 0}",35,12,44.99543335703007,"{0: 23, 1: 0, 2: 0, 3: 0, 4: 0}",1
28 | 7,True,"{'a': 0, 'e': 0}",25,15,18.10519012016503,"{0: 10, 1: 0, 2: 0, 3: 0, 4: 0}",1
29 | 8,True,"{'a': 0, 'e': 0}",25,9,6.303381484449328,"{0: 15, 1: 0, 2: 0, 3: 1, 4: 0}",1
30 | 9,True,"{'a': 0, 'e': 0}",25,15,22.30262088856527,"{0: 10, 1: 0, 2: 0, 3: 0, 4: 0}",1
31 | 10,True,"{'a': 0, 'e': 0}",25,3,31.601188162330747,"{0: 22, 1: 0, 2: 0, 3: 0, 4: 0}",1
32 |
--------------------------------------------------------------------------------
/smartcab/logs/sim_default-learning.txt:
--------------------------------------------------------------------------------
1 | /-----------------------------------------
2 | | State-action rewards from Q-Learning
3 | \-----------------------------------------
4 |
5 | ('right', 'green', False, False)
6 | -- forward : 0.90
7 | -- None : -3.63
8 | -- right : 1.72
9 | -- left : 0.49
10 |
11 | ('left', 'red', False, False)
12 | -- forward : 0.00
13 | -- None : 1.83
14 | -- right : 1.09
15 | -- left : -10.26
16 |
17 | ('right', 'green', True, False)
18 | -- forward : 0.28
19 | -- None : -3.54
20 | -- right : 1.72
21 | -- left : -9.57
22 |
23 | ('left', 'red', True, False)
24 | -- forward : -8.90
25 | -- None : 1.10
26 | -- right : 0.36
27 | -- left : 0.00
28 |
29 | ('left', 'green', True, True)
30 | -- forward : 0.00
31 | -- None : 0.00
32 | -- right : 0.06
33 | -- left : 0.00
34 |
35 | ('forward', 'green', True, True)
36 | -- forward : 0.00
37 | -- None : 0.00
38 | -- right : 0.15
39 | -- left : 0.00
40 |
41 | ('forward', 'red', False, True)
42 | -- forward : 0.00
43 | -- None : 1.73
44 | -- right : -9.64
45 | -- left : -29.85
46 |
47 | ('right', 'green', False, True)
48 | -- forward : 0.70
49 | -- None : 0.00
50 | -- right : 0.00
51 | -- left : 0.05
52 |
53 | ('right', 'red', True, True)
54 | -- forward : 0.00
55 | -- None : 2.08
56 | -- right : 0.00
57 | -- left : -19.85
58 |
59 | ('left', 'red', False, True)
60 | -- forward : -34.97
61 | -- None : 2.05
62 | -- right : -4.45
63 | -- left : -29.98
64 |
65 | ('left', 'green', False, False)
66 | -- forward : 0.59
67 | -- None : -3.26
68 | -- right : 1.20
69 | -- left : 2.29
70 |
71 | ('left', 'green', True, False)
72 | -- forward : 0.00
73 | -- None : -5.45
74 | -- right : 0.96
75 | -- left : 0.00
76 |
77 | ('forward', 'red', True, False)
78 | -- forward : -5.24
79 | -- None : 1.51
80 | -- right : 0.90
81 | -- left : 0.00
82 |
83 | ('forward', 'red', False, False)
84 | -- forward : -9.56
85 | -- None : 1.76
86 | -- right : 1.21
87 | -- left : -9.67
88 |
89 | ('forward', 'green', False, True)
90 | -- forward : 1.04
91 | -- None : -2.25
92 | -- right : 0.52
93 | -- left : 0.53
94 |
95 | ('right', 'red', False, False)
96 | -- forward : -9.61
97 | -- None : 2.06
98 | -- right : 0.96
99 | -- left : -10.42
100 |
101 | ('right', 'red', True, False)
102 | -- forward : -5.00
103 | -- None : 0.00
104 | -- right : 1.60
105 | -- left : -4.65
106 |
107 | ('right', 'green', True, True)
108 | -- forward : 0.00
109 | -- None : 0.00
110 | -- right : 0.00
111 | -- left : -9.93
112 |
113 | ('left', 'green', False, True)
114 | -- forward : 0.26
115 | -- None : -3.80
116 | -- right : 0.28
117 | -- left : 0.95
118 |
119 | ('left', 'red', True, True)
120 | -- forward : -19.51
121 | -- None : 0.00
122 | -- right : 0.00
123 | -- left : 0.00
124 |
125 | ('forward', 'red', True, True)
126 | -- forward : -19.91
127 | -- None : 0.00
128 | -- right : 0.00
129 | -- left : -30.35
130 |
131 | ('forward', 'green', True, False)
132 | -- forward : 1.63
133 | -- None : -4.39
134 | -- right : 0.00
135 | -- left : -9.75
136 |
137 | ('forward', 'green', False, False)
138 | -- forward : 1.53
139 | -- None : -4.92
140 | -- right : 0.31
141 | -- left : 1.06
142 |
143 | ('right', 'red', False, True)
144 | -- forward : -34.78
145 | -- None : 0.63
146 | -- right : -9.45
147 | -- left : 0.00
148 |
149 |
--------------------------------------------------------------------------------
/smartcab/logs/sim_no-learning.csv:
--------------------------------------------------------------------------------
1 | trial,testing,parameters,initial_deadline,final_deadline,net_reward,actions,success
2 | 1,False,"{'a': 0.5, 'e': 1.0}",20,0,-93.26123797920842,"{0: 12, 1: 2, 2: 4, 3: 1, 4: 1}",0
3 | 2,False,"{'a': 0.5, 'e': 1.0}",20,0,-93.77395441962415,"{0: 13, 1: 4, 2: 1, 3: 0, 4: 2}",0
4 | 3,False,"{'a': 0.5, 'e': 1.0}",20,0,-113.51560440630828,"{0: 14, 1: 0, 2: 3, 3: 1, 4: 2}",0
5 | 4,False,"{'a': 0.5, 'e': 1.0}",25,7,-18.289133567400253,"{0: 15, 1: 1, 2: 1, 3: 1, 4: 0}",1
6 | 5,False,"{'a': 0.5, 'e': 1.0}",20,0,-92.99727611241221,"{0: 12, 1: 2, 2: 4, 3: 1, 4: 1}",0
7 | 6,False,"{'a': 0.5, 'e': 1.0}",30,0,-217.88590661448606,"{0: 14, 1: 3, 2: 10, 3: 0, 4: 3}",0
8 | 7,False,"{'a': 0.5, 'e': 1.0}",25,0,-84.28899817200508,"{0: 16, 1: 3, 2: 5, 3: 0, 4: 1}",0
9 | 8,False,"{'a': 0.5, 'e': 1.0}",20,0,-54.38409527191148,"{0: 11, 1: 5, 2: 4, 3: 0, 4: 0}",0
10 | 9,False,"{'a': 0.5, 'e': 1.0}",30,0,-160.13202618839966,"{0: 20, 1: 3, 2: 3, 3: 1, 4: 3}",0
11 | 10,False,"{'a': 0.5, 'e': 1.0}",25,0,-78.39060580203846,"{0: 20, 1: 1, 2: 1, 3: 2, 4: 1}",0
12 | 11,False,"{'a': 0.5, 'e': 1.0}",20,0,-32.28564143160174,"{0: 15, 1: 1, 2: 4, 3: 0, 4: 0}",0
13 | 12,False,"{'a': 0.5, 'e': 1.0}",30,0,-251.8830454840801,"{0: 18, 1: 1, 2: 4, 3: 3, 4: 4}",0
14 | 13,False,"{'a': 0.5, 'e': 1.0}",20,0,-101.42265402380453,"{0: 11, 1: 1, 2: 7, 3: 0, 4: 1}",0
15 | 14,False,"{'a': 0.5, 'e': 1.0}",30,2,-244.1098376976274,"{0: 14, 1: 2, 2: 7, 3: 1, 4: 4}",1
16 | 15,False,"{'a': 0.5, 'e': 1.0}",20,0,-78.52448372238419,"{0: 14, 1: 1, 2: 3, 3: 1, 4: 1}",0
17 | 16,False,"{'a': 0.5, 'e': 1.0}",25,0,-112.07939409851832,"{0: 17, 1: 1, 2: 5, 3: 0, 4: 2}",0
18 | 17,False,"{'a': 0.5, 'e': 1.0}",20,0,-68.47204830674838,"{0: 9, 1: 8, 2: 2, 3: 1, 4: 0}",0
19 | 18,False,"{'a': 0.5, 'e': 1.0}",25,0,-138.69640533540485,"{0: 16, 1: 0, 2: 6, 3: 1, 4: 2}",0
20 | 19,False,"{'a': 0.5, 'e': 1.0}",25,17,3.0168561933500504,"{0: 7, 1: 0, 2: 1, 3: 0, 4: 0}",1
21 | 20,False,"{'a': 0.5, 'e': 1.0}",25,0,-121.22725948341478,"{0: 17, 1: 2, 2: 3, 3: 1, 4: 2}",0
22 | 1,True,"{'a': 0.5, 'e': 1.0}",20,0,-18.31492248877889,"{0: 15, 1: 3, 2: 2, 3: 0, 4: 0}",0
23 | 2,True,"{'a': 0.5, 'e': 1.0}",20,0,-123.63704697875012,"{0: 9, 1: 2, 2: 8, 3: 0, 4: 1}",0
24 | 3,True,"{'a': 0.5, 'e': 1.0}",20,0,-58.23968261471943,"{0: 14, 1: 2, 2: 3, 3: 0, 4: 1}",0
25 | 4,True,"{'a': 0.5, 'e': 1.0}",20,0,-44.78428957901549,"{0: 16, 1: 2, 2: 1, 3: 0, 4: 1}",0
26 | 5,True,"{'a': 0.5, 'e': 1.0}",25,0,-56.78857487544803,"{0: 18, 1: 0, 2: 6, 3: 1, 4: 0}",0
27 | 6,True,"{'a': 0.5, 'e': 1.0}",30,0,-234.07867260270754,"{0: 16, 1: 3, 2: 6, 3: 1, 4: 4}",0
28 | 7,True,"{'a': 0.5, 'e': 1.0}",25,0,-143.1088671029654,"{0: 13, 1: 4, 2: 6, 3: 0, 4: 2}",0
29 | 8,True,"{'a': 0.5, 'e': 1.0}",20,0,-182.78659171952157,"{0: 9, 1: 3, 2: 4, 3: 1, 4: 3}",0
30 | 9,True,"{'a': 0.5, 'e': 1.0}",25,5,-69.11794536993459,"{0: 14, 1: 0, 2: 5, 3: 0, 4: 1}",1
31 | 10,True,"{'a': 0.5, 'e': 1.0}",20,5,-35.64511240661868,"{0: 13, 1: 0, 2: 1, 3: 0, 4: 1}",1
32 |
--------------------------------------------------------------------------------
/smartcab/project_description.md:
--------------------------------------------------------------------------------
1 | # Content: Reinforcement Learning
2 | ## Project: Train a Smartcab How to Drive
3 |
4 | ## Project Overview
5 |
6 | In this project you will apply reinforcement learning techniques for a self-driving agent in a simplified world to aid it in effectively reaching its destinations in the allotted time. You will first investigate the environment the agent operates in by constructing a very basic driving implementation. Once your agent is successful at operating within the environment, you will then identify each possible state the agent can be in when considering such things as traffic lights and oncoming traffic at each intersection. With states identified, you will then implement a Q-Learning algorithm for the self-driving agent to guide the agent towards its destination within the allotted time. Finally, you will improve upon the Q-Learning algorithm to find the best configuration of learning and exploration factors to ensure the self-driving agent is reaching its destinations with consistently positive results.
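The Q-Learning step described above can be sketched as a simple tabular update. The state shape, action set, and learning rate below are illustrative assumptions for the example, not the project's `agent.py`, and the discounted future-reward term is omitted for brevity.

```python
import random
from collections import defaultdict

# Actions available to the agent at each intersection.
ACTIONS = [None, "forward", "left", "right"]

alpha = 0.5             # learning rate (illustrative value)
Q = defaultdict(float)  # Q[(state, action)] -> estimated value

def choose_action(state, epsilon):
    """Epsilon-greedy: explore a random action, else take the best known one."""
    if random.random() < epsilon:
        return random.choice(ACTIONS)
    return max(ACTIONS, key=lambda a: Q[(state, a)])

def learn(state, action, reward):
    """Move the Q-value estimate toward the observed reward."""
    Q[(state, action)] += alpha * (reward - Q[(state, action)])
```

With `alpha = 0.5`, two rewards of 2.0 for the same state-action pair move its Q-value from 0.0 to 1.0 and then to 1.5, so estimates converge gradually rather than jumping to the latest reward.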
7 |
8 | ## Description
9 | In the not-so-distant future, taxicab companies across the United States no longer employ human drivers to operate their fleet of vehicles. Instead, the taxicabs are operated by self-driving agents, known as *smartcabs*, to transport people from one location to another within the cities those companies operate. In major metropolitan areas, such as Chicago, New York City, and San Francisco, an increasing number of people have come to depend on *smartcabs* to get to where they need to go as safely and reliably as possible. Although *smartcabs* have become the transport of choice, concerns have arisen that a self-driving agent might not be as safe or reliable as human drivers, particularly when considering city traffic lights and other vehicles. To alleviate these concerns, your task as an employee for a national taxicab company is to use reinforcement learning techniques to construct a demonstration of a *smartcab* operating in real-time to prove that both safety and reliability can be achieved.
10 |
11 | ## Software Requirements
12 | This project uses the following software and Python libraries:
13 |
14 | - [Python 2.7](https://www.python.org/download/releases/2.7/)
15 | - [NumPy](http://www.numpy.org/)
16 | - [pandas](http://pandas.pydata.org/)
17 | - [matplotlib](http://matplotlib.org/)
18 | - [PyGame](http://pygame.org/)
19 |
20 | If you do not have Python installed yet, it is highly recommended that you install the [Anaconda](http://continuum.io/downloads) distribution of Python, which already has the above packages and more included. Make sure that you select the Python 2.7 installer and not the Python 3.x installer. `pygame` can then be installed using one of the following commands:
21 |
22 | Mac: `conda install -c https://conda.anaconda.org/quasiben pygame`
23 | Windows: `conda install -c https://conda.anaconda.org/tlatorre pygame`
24 | Linux: `conda install -c https://conda.anaconda.org/prkrekel pygame`
25 |
26 | ## Fixing Common PyGame Problems
27 |
28 | The PyGame library can in some cases require a bit of troubleshooting to work correctly for this project. While the PyGame aspect of the project is not required for a successful submission (you can complete the project without a visual simulation, although it is more difficult), it is very helpful to have it working! If you encounter an issue with PyGame, first see these helpful links below that are developed by communities of users working with the library:
29 | - [Getting Started](https://www.pygame.org/wiki/GettingStarted)
30 | - [PyGame Information](http://www.pygame.org/wiki/info)
31 | - [Google Group](https://groups.google.com/forum/#!forum/pygame-mirror-on-google-groups)
32 | - [PyGame subreddit](https://www.reddit.com/r/pygame/)
33 |
34 | ### Problems most often reported by students
35 | _"PyGame won't install on my machine; there was an issue with the installation."_
36 | **Solution:** As has been recommended for previous projects, Udacity suggests using the Anaconda distribution of Python, which then allows you to install PyGame through the `conda`-specific command.
37 |
38 | _"I'm seeing a black screen when running the code; output says that it can't load car images."_
39 | **Solution:** The code will not operate correctly unless it is run from the top-level directory for `smartcab`. The top-level directory is the one that contains the **README** and the project notebook.
40 |
41 | If you continue to have problems with the project code in regards to PyGame, you can also [use the discussion forums](https://discussions.udacity.com/c/nd009-reinforcement-learning) to find posts from students that encountered issues that you may be experiencing. Additionally, you can seek help from a swath of students in the [MLND Student Slack Community](http://mlnd.slack.com).
42 |
43 | ## Starting the Project
44 |
45 | For this assignment, you can find the `smartcab` folder containing the necessary project files on the [Machine Learning projects GitHub](https://github.com/udacity/machine-learning), under the `projects` folder. You may download all of the files for projects we'll use in this Nanodegree program directly from this repo. Please make sure that you use the most recent version of project files when completing a project!
46 |
47 | This project contains three directories:
48 |
49 | - `/logs/`: This folder will contain all log files that are given from the simulation when specific prerequisites are met.
50 | - `/images/`: This folder contains various images of cars to be used in the graphical user interface. You will not need to modify or create any files in this directory.
51 | - `/smartcab/`: This folder contains the Python scripts that create the environment, graphical user interface, the simulation, and the agents. You will not need to modify or create any files in this directory except for `agent.py`.
52 |
53 | It also contains two files:
54 | - `smartcab.ipynb`: This is the main file where you will answer questions and provide an analysis for your work.
55 | - `visuals.py`: This Python script provides supplementary visualizations for the analysis. Do not modify.
56 |
57 | Finally, in `/smartcab/` are the following four files:
58 | - **Modify:**
59 | - `agent.py`: This is the main Python file where you will be performing your work on the project.
60 | - **Do not modify:**
61 | - `environment.py`: This Python file will create the *smartcab* environment.
62 | - `planner.py`: This Python file creates a high-level planner for the agent to follow towards a set goal.
63 | - `simulator.py`: This Python file creates the simulation and graphical user interface.
64 |
65 | ### Running the Code
66 | In a terminal or command window, navigate to the top-level project directory `smartcab/` (that contains the three project directories) and run one of the following commands:
67 |
68 | `python smartcab/agent.py` or
69 | `python -m smartcab.agent`
70 |
71 | This will run the `agent.py` file and execute your implemented agent code into the environment. Additionally, use the command `jupyter notebook smartcab.ipynb` from this same directory to open up a browser window or tab to work with your analysis notebook. Alternatively, you can use the command `jupyter notebook` or `ipython notebook` and navigate to the notebook file in the browser window that opens. Follow the instructions in the notebook and answer each question presented to successfully complete the implementation necessary for your `agent.py` agent file. A **README** file has also been provided with the project files which may contain additional necessary information or instruction for the project.
72 |
73 | ## Definitions
74 |
75 | ### Environment
76 | The *smartcab* operates in an ideal, grid-like city (similar to New York City), with roads going in the North-South and East-West directions. Other vehicles will certainly be present on the road, but there will be no pedestrians to be concerned with. At each intersection there is a traffic light that either allows traffic in the North-South direction or the East-West direction. U.S. Right-of-Way rules apply:
77 | - On a green light, a left turn is permitted if there is no oncoming traffic making a right turn or coming straight through the intersection.
78 | - On a red light, a right turn is permitted if no oncoming traffic is approaching from your left through the intersection.
79 | To understand how to correctly yield to oncoming traffic when turning left, you may refer to [this official drivers' education video](https://www.youtube.com/watch?v=TW0Eq2Q-9Ac), or [this passionate exposition](https://www.youtube.com/watch?v=0EdkxI6NeuA).
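
The right-of-way rules above can be condensed into a small legality check. The sketch below is illustrative only: the helper name `action_is_legal` is hypothetical, and the `inputs` layout mirrors what the environment's `sense()` returns but is not part of the project files.

```python
def action_is_legal(action, inputs):
    """Return True if 'action' obeys the right-of-way rules above.

    'inputs' is the agent's egocentric view of its intersection, e.g.
    {'light': 'red', 'oncoming': None, 'left': 'forward', 'right': None},
    where each traffic value is that vehicle's intended move.
    """
    if action == 'forward':
        return inputs['light'] == 'green'
    if action == 'left':
        # Green light, and no oncoming car going straight or turning right
        return (inputs['light'] == 'green'
                and inputs['oncoming'] not in ('forward', 'right'))
    if action == 'right':
        # Right on red is allowed unless cross traffic from the left goes straight
        return not (inputs['light'] == 'red' and inputs['left'] == 'forward')
    return True  # Idling is always legal (though it may still be penalized)
```

This is the same logic the provided `DummyAgent` uses to decide whether its chosen waypoint is safe to take.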
80 |
81 | ### Inputs and Outputs
82 | Assume that the *smartcab* is assigned a route plan based on the passengers' starting location and destination. The route is split at each intersection into waypoints, and you may assume that the *smartcab*, at any instant, is at some intersection in the world. Therefore, the next waypoint to the destination, assuming the destination has not already been reached, is one intersection away in one direction (North, South, East, or West). The *smartcab* has only an egocentric view of the intersection it is at: It can determine the state of the traffic light for its direction of movement, and whether there is a vehicle at the intersection for each of the oncoming directions. For each action, the *smartcab* may either idle at the intersection, or drive to the next intersection to the left, right, or ahead of it. Finally, each trip has a time to reach the destination which decreases for each action taken (the passengers want to get there quickly). If the allotted time becomes zero before reaching the destination, the trip has failed.
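
One common way to encode this egocentric view is a hashable tuple of the next waypoint plus the relevant sensor readings. The sketch below is illustrative only; the helper name `build_state` is hypothetical, and which features to include (for example, whether to keep the deadline) is a design decision left to your implementation.

```python
def build_state(waypoint, inputs):
    """Encode the egocentric view as a hashable state tuple.

    waypoint: 'forward', 'left', or 'right' (direction of the next waypoint)
    inputs:   sensor dict such as the environment's sense() returns
    """
    return (waypoint, inputs['light'], inputs['oncoming'], inputs['left'])

# A tuple like this can serve directly as a dictionary key in a Q-table:
state = build_state('left', {'light': 'green', 'oncoming': 'forward',
                             'left': None, 'right': None})
```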
83 |
84 | ### Rewards and Goal
85 | The *smartcab* will receive positive or negative rewards based on the action it has taken. As expected, the *smartcab* will receive a small positive reward when making a good action, and a varying amount of negative reward dependent on the severity of the traffic violation it would have committed. Based on the rewards and penalties the *smartcab* receives, the self-driving agent implementation should learn an optimal policy for driving on the city roads while obeying traffic rules, avoiding accidents, and reaching passengers' destinations in the allotted time.
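
The learning rule the starter code asks for is a simplified Q-update that blends each immediate reward into the stored value using only the learning rate `alpha` (no discounted future-reward term), paired with epsilon-greedy action selection. Below is a minimal sketch, with hypothetical helpers `q_update` and `choose_action` and a Q-table mapping state tuples to `{action: value}` dictionaries; it is not the project's graded implementation.

```python
import random

def q_update(Q, state, action, reward, alpha):
    """Blend an immediate reward into the Q-table (no future-reward term)."""
    Q[state][action] = (1 - alpha) * Q[state][action] + alpha * reward

def choose_action(Q, state, valid_actions, epsilon):
    """Epsilon-greedy: explore with probability epsilon, otherwise exploit."""
    if random.random() < epsilon:
        return random.choice(valid_actions)
    best = max(Q[state].values())
    return random.choice([a for a, q in Q[state].items() if q == best])

# One update with alpha = 0.5: the stored value moves halfway toward the reward
Q = {('forward', 'green'): {None: 0.0, 'forward': 0.0}}
q_update(Q, ('forward', 'green'), 'forward', reward=2.0, alpha=0.5)
```

With `epsilon` decayed toward zero over the training trials (and set to zero for testing, as the `reset()` stub requires), the agent gradually shifts from exploration to exploiting its learned policy.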
86 |
87 | ## Submitting the Project
88 |
89 | ### Evaluation
90 | Your project will be reviewed by a Udacity reviewer against the **Train a Smartcab to Drive project rubric**. Be sure to review this rubric thoroughly and self-evaluate your project before submission. All criteria found in the rubric must be *meeting specifications* for you to pass.
91 |
92 | ### Submission Files
93 | When you are ready to submit your project, collect the following files and compress them into a single archive for upload. Alternatively, you may supply the following files on your GitHub Repo in a folder named `smartcab` for ease of access:
94 | - The `agent.py` Python file with all code implemented as required in the instructed tasks.
95 | - The `/logs/` folder which should contain **five** log files that were produced from your simulation and used in the analysis.
96 | - The `smartcab.ipynb` notebook file with all questions answered and all visualization cells executed and displaying results.
97 | - An **HTML** export of the project notebook with the name **report.html**. This file *must* be present for your project to be evaluated.
98 |
99 | Once you have collected these files and reviewed the project rubric, proceed to the project submission page.
100 |
101 | ### I'm Ready!
102 | When you're ready to submit your project, click on the **Submit Project** button at the bottom of the page.
103 |
104 | If you are having any problems submitting your project or wish to check on the status of your submission, please email us at **machine-support@udacity.com** or visit us in the discussion forums.
105 |
106 | ### What's Next?
107 | You will get an email as soon as your reviewer has feedback for you. In the meantime, review your next project and feel free to get started on it or the courses supporting it!
--------------------------------------------------------------------------------
/smartcab/smartcab.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/smartcab.zip
--------------------------------------------------------------------------------
/smartcab/smartcab/agent.py:
--------------------------------------------------------------------------------
1 | import random
2 | import math
3 | from environment import Agent, Environment
4 | from planner import RoutePlanner
5 | from simulator import Simulator
6 |
7 | class LearningAgent(Agent):
8 | """ An agent that learns to drive in the Smartcab world.
9 | This is the object you will be modifying. """
10 |
11 | def __init__(self, env, learning=False, epsilon=1.0, alpha=0.5):
12 | super(LearningAgent, self).__init__(env)     # Set the agent in the environment
13 | self.planner = RoutePlanner(self.env, self) # Create a route planner
14 | self.valid_actions = self.env.valid_actions # The set of valid actions
15 |
16 | # Set parameters of the learning agent
17 | self.learning = learning # Whether the agent is expected to learn
18 | self.Q = dict() # Create a Q-table which will be a dictionary of tuples
19 | self.epsilon = epsilon # Random exploration factor
20 | self.alpha = alpha # Learning factor
21 |
22 | ###########
23 | ## TO DO ##
24 | ###########
25 | # Set any additional class parameters as needed
26 |
27 |
28 | def reset(self, destination=None, testing=False):
29 | """ The reset function is called at the beginning of each trial.
30 | 'testing' is set to True if testing trials are being used
31 | once training trials have completed. """
32 |
33 | # Select the destination as the new location to route to
34 | self.planner.route_to(destination)
35 |
36 | ###########
37 | ## TO DO ##
38 | ###########
39 | # Update epsilon using a decay function of your choice
40 | # Update additional class parameters as needed
41 | # If 'testing' is True, set epsilon and alpha to 0
42 |
43 | return None
44 |
45 | def build_state(self):
46 | """ The build_state function is called when the agent requests data from the
47 | environment. The next waypoint, the intersection inputs, and the deadline
48 | are all features available to the agent. """
49 |
50 | # Collect data about the environment
51 | waypoint = self.planner.next_waypoint() # The next waypoint
52 | inputs = self.env.sense(self) # Visual input - intersection light and traffic
53 | deadline = self.env.get_deadline(self) # Remaining deadline
54 |
55 | ###########
56 | ## TO DO ##
57 | ###########
58 | # Set 'state' as a tuple of relevant data for the agent
59 | state = None
60 |
61 | return state
62 |
63 |
64 | def get_maxQ(self, state):
65 | """ The get_max_Q function is called when the agent is asked to find the
66 | maximum Q-value of all actions based on the 'state' the smartcab is in. """
67 |
68 | ###########
69 | ## TO DO ##
70 | ###########
71 | # Calculate the maximum Q-value of all actions for a given state
72 |
73 | maxQ = None
74 |
75 | return maxQ
76 |
77 |
78 | def createQ(self, state):
79 | """ The createQ function is called when a state is generated by the agent. """
80 |
81 | ###########
82 | ## TO DO ##
83 | ###########
84 | # When learning, check if the 'state' is not in the Q-table
85 | # If it is not, create a new dictionary for that state
86 | # Then, for each action available, set the initial Q-value to 0.0
87 |
88 | return
89 |
90 |
91 | def choose_action(self, state):
92 | """ The choose_action function is called when the agent is asked to choose
93 | which action to take, based on the 'state' the smartcab is in. """
94 |
95 | # Set the agent state and default action
96 | self.state = state
97 | self.next_waypoint = self.planner.next_waypoint()
98 | action = None
99 |
100 | ###########
101 | ## TO DO ##
102 | ###########
103 | # When not learning, choose a random action
104 | # When learning, choose a random action with 'epsilon' probability
105 | # Otherwise, choose an action with the highest Q-value for the current state
106 |
107 | return action
108 |
109 |
110 | def learn(self, state, action, reward):
111 | """ The learn function is called after the agent completes an action and
112 | receives a reward. This function does not consider future rewards
113 | when conducting learning. """
114 |
115 | ###########
116 | ## TO DO ##
117 | ###########
118 | # When learning, implement the value iteration update rule
119 | # Use only the learning rate 'alpha' (do not use the discount factor 'gamma')
120 |
121 | return
122 |
123 |
124 | def update(self):
125 | """ The update function is called when a time step is completed in the
126 | environment for a given trial. This function will build the agent
127 | state, choose an action, receive a reward, and learn if enabled. """
128 |
129 | state = self.build_state() # Get current state
130 | self.createQ(state) # Create 'state' in Q-table
131 | action = self.choose_action(state) # Choose an action
132 | reward = self.env.act(self, action) # Receive a reward
133 | self.learn(state, action, reward) # Q-learn
134 |
135 | return
136 |
137 |
138 | def run():
139 | """ Driving function for running the simulation.
140 | Press ESC to close the simulation, or [SPACE] to pause the simulation. """
141 |
142 | ##############
143 | # Create the environment
144 | # Flags:
145 | # verbose - set to True to display additional output from the simulation
146 | # num_dummies - discrete number of dummy agents in the environment, default is 100
147 | # grid_size - discrete number of intersections (columns, rows), default is (8, 6)
148 | env = Environment()
149 |
150 | ##############
151 | # Create the driving agent
152 | # Flags:
153 | # learning - set to True to force the driving agent to use Q-learning
154 | # * epsilon - continuous value for the exploration factor, default is 1
155 | # * alpha - continuous value for the learning rate, default is 0.5
156 | agent = env.create_agent(LearningAgent)
157 |
158 | ##############
159 | # Follow the driving agent
160 | # Flags:
161 | # enforce_deadline - set to True to enforce a deadline metric
162 | env.set_primary_agent(agent)
163 |
164 | ##############
165 | # Create the simulation
166 | # Flags:
167 | # update_delay - continuous time (in seconds) between actions, default is 2.0 seconds
168 | # display - set to False to disable the GUI if PyGame is enabled
169 | # log_metrics - set to True to log trial and simulation results to /logs
170 | # optimized - set to True to change the default log file name
171 | sim = Simulator(env)
172 |
173 | ##############
174 | # Run the simulator
175 | # Flags:
176 | # tolerance - epsilon tolerance before beginning testing, default is 0.05
177 | # n_test - discrete number of testing trials to perform, default is 0
178 | sim.run()
179 |
180 |
181 | if __name__ == '__main__':
182 | run()
--------------------------------------------------------------------------------
/smartcab/smartcab/environment.py:
--------------------------------------------------------------------------------
1 | import time
2 | import random
3 | import math
4 | from collections import OrderedDict
5 | from simulator import Simulator
6 |
7 |
8 | class TrafficLight(object):
9 | """A traffic light that switches periodically."""
10 |
11 | valid_states = [True, False] # True = NS open; False = EW open
12 |
13 | def __init__(self, state=None, period=None):
14 | self.state = state if state is not None else random.choice(self.valid_states)
15 | self.period = period if period is not None else random.choice([2, 3, 4, 5])
16 | self.last_updated = 0
17 |
18 | def reset(self):
19 | self.last_updated = 0
20 |
21 | def update(self, t):
22 | if t - self.last_updated >= self.period:
23 | self.state = not self.state # Assuming state is boolean
24 | self.last_updated = t
25 |
26 |
27 | class Environment(object):
28 | """Environment within which all agents operate."""
29 |
30 | valid_actions = [None, 'forward', 'left', 'right']
31 | valid_inputs = {'light': TrafficLight.valid_states, 'oncoming': valid_actions, 'left': valid_actions, 'right': valid_actions}
32 | valid_headings = [(1, 0), (0, -1), (-1, 0), (0, 1)] # E, N, W, S
33 | hard_time_limit = -100 # Set a hard time limit even if deadline is not enforced.
34 |
35 | def __init__(self, verbose=False, num_dummies=100, grid_size = (8, 6)):
36 | self.num_dummies = num_dummies # Number of dummy driver agents in the environment
37 | self.verbose = verbose # If debug output should be given
38 |
39 | # Initialize simulation variables
40 | self.done = False
41 | self.t = 0
42 | self.agent_states = OrderedDict()
43 | self.step_data = {}
44 | self.success = None
45 |
46 | # Road network
47 | self.grid_size = grid_size # (columns, rows)
48 | self.bounds = (1, 2, self.grid_size[0], self.grid_size[1] + 1)
49 | self.block_size = 100
50 | self.hang = 0.6
51 | self.intersections = OrderedDict()
52 | self.roads = []
53 | for x in xrange(self.bounds[0], self.bounds[2] + 1):
54 | for y in xrange(self.bounds[1], self.bounds[3] + 1):
55 | self.intersections[(x, y)] = TrafficLight() # A traffic light at each intersection
56 |
57 | for a in self.intersections:
58 | for b in self.intersections:
59 | if a == b:
60 | continue
61 | if (abs(a[0] - b[0]) + abs(a[1] - b[1])) == 1: # L1 distance = 1
62 | self.roads.append((a, b))
63 |
64 | # Add environment boundaries
65 | for x in xrange(self.bounds[0], self.bounds[2] + 1):
66 | self.roads.append(((x, self.bounds[1] - self.hang), (x, self.bounds[1])))
67 | self.roads.append(((x, self.bounds[3] + self.hang), (x, self.bounds[3])))
68 | for y in xrange(self.bounds[1], self.bounds[3] + 1):
69 | self.roads.append(((self.bounds[0] - self.hang, y), (self.bounds[0], y)))
70 | self.roads.append(((self.bounds[2] + self.hang, y), (self.bounds[2], y)))
71 |
72 | # Create dummy agents
73 | for i in xrange(self.num_dummies):
74 | self.create_agent(DummyAgent)
75 |
76 | # Primary agent and associated parameters
77 | self.primary_agent = None # to be set explicitly
78 | self.enforce_deadline = False
79 |
80 | # Trial data (updated at the end of each trial)
81 | self.trial_data = {
82 | 'testing': False, # if the trial is for testing a learned policy
83 | 'initial_distance': 0, # L1 distance from start to destination
84 | 'initial_deadline': 0, # given deadline (time steps) to start with
85 | 'net_reward': 0.0, # total reward earned in current trial
86 | 'final_deadline': None, # deadline value (time remaining) at the end
87 | 'actions': {0: 0, 1: 0, 2: 0, 3: 0, 4: 0}, # violations and accidents
88 | 'success': 0 # whether the agent reached the destination in time
89 | }
90 |
91 | def create_agent(self, agent_class, *args, **kwargs):
92 | """ When called, create_agent creates an agent in the environment. """
93 |
94 | agent = agent_class(self, *args, **kwargs)
95 | self.agent_states[agent] = {'location': random.choice(self.intersections.keys()), 'heading': (0, 1)}
96 | return agent
97 |
98 | def set_primary_agent(self, agent, enforce_deadline=False):
99 | """ When called, set_primary_agent sets 'agent' as the primary agent.
100 | The primary agent is the smartcab that is followed in the environment. """
101 |
102 | self.primary_agent = agent
103 | agent.primary_agent = True
104 | self.enforce_deadline = enforce_deadline
105 |
106 | def reset(self, testing=False):
107 | """ This function is called at the beginning of a new trial. """
108 |
109 | self.done = False
110 | self.t = 0
111 |
112 | # Reset status text
113 | self.step_data = {}
114 |
115 | # Reset traffic lights
116 | for traffic_light in self.intersections.itervalues():
117 | traffic_light.reset()
118 |
119 | # Pick a start and a destination
120 | start = random.choice(self.intersections.keys())
121 | destination = random.choice(self.intersections.keys())
122 |
123 | # Ensure starting location and destination are not too close
124 | while self.compute_dist(start, destination) < 4:
125 | start = random.choice(self.intersections.keys())
126 | destination = random.choice(self.intersections.keys())
127 |
128 | start_heading = random.choice(self.valid_headings)
129 | distance = self.compute_dist(start, destination)
130 | deadline = distance * 5 # 5 time steps per intersection away
131 | if(self.verbose == True): # Debugging
132 | print "Environment.reset(): Trial set up with start = {}, destination = {}, deadline = {}".format(start, destination, deadline)
133 |
134 | # Create a map of all possible initial positions
135 | positions = dict()
136 | for location in self.intersections:
137 | positions[location] = list()
138 | for heading in self.valid_headings:
139 | positions[location].append(heading)
140 |
141 | # Initialize agent(s)
142 | for agent in self.agent_states.iterkeys():
143 |
144 | if agent is self.primary_agent:
145 | self.agent_states[agent] = {
146 | 'location': start,
147 | 'heading': start_heading,
148 | 'destination': destination,
149 | 'deadline': deadline
150 | }
151 | # For dummy agents, make them choose one of the available
152 | # intersections and headings still in 'positions'
153 | else:
154 | intersection = random.choice(positions.keys())
155 | heading = random.choice(positions[intersection])
156 | self.agent_states[agent] = {
157 | 'location': intersection,
158 | 'heading': heading,
159 | 'destination': None,
160 | 'deadline': None
161 | }
162 | # Now delete the taken location and heading from 'positions'
163 | positions[intersection] = list(set(positions[intersection]) - set([heading]))
164 | if positions[intersection] == list(): # No headings available for intersection
165 | del positions[intersection] # Delete the intersection altogether
166 |
167 |
168 | agent.reset(destination=(destination if agent is self.primary_agent else None), testing=testing)
169 | if agent is self.primary_agent:
170 | # Reset metrics for this trial (step data will be set during the step)
171 | self.trial_data['testing'] = testing
172 | self.trial_data['initial_deadline'] = deadline
173 | self.trial_data['final_deadline'] = deadline
174 | self.trial_data['net_reward'] = 0.0
175 | self.trial_data['actions'] = {0: 0, 1: 0, 2: 0, 3: 0, 4: 0}
176 | self.trial_data['parameters'] = {'e': agent.epsilon, 'a': agent.alpha}
177 | self.trial_data['success'] = 0
178 |
179 | def step(self):
180 | """ This function is called when a time step is taken turing a trial. """
181 |
182 | # Pretty print to terminal
183 | print ""
184 | print "/-------------------"
185 | print "| Step {} Results".format(self.t)
186 | print "\-------------------"
187 | print ""
188 |
189 | if(self.verbose == True): # Debugging
190 | print "Environment.step(): t = {}".format(self.t)
191 |
192 | # Update agents, primary first
193 | if self.primary_agent is not None:
194 | self.primary_agent.update()
195 |
196 | for agent in self.agent_states.iterkeys():
197 | if agent is not self.primary_agent:
198 | agent.update()
199 |
200 | # Update traffic lights
201 | for intersection, traffic_light in self.intersections.iteritems():
202 | traffic_light.update(self.t)
203 |
204 | if self.primary_agent is not None:
205 | # Agent has taken an action: reduce the deadline by 1
206 | agent_deadline = self.agent_states[self.primary_agent]['deadline'] - 1
207 | self.agent_states[self.primary_agent]['deadline'] = agent_deadline
208 |
209 | if agent_deadline <= self.hard_time_limit:
210 | self.done = True
211 | self.success = False
212 | if self.verbose: # Debugging
213 | print "Environment.step(): Primary agent hit hard time limit ({})! Trial aborted.".format(self.hard_time_limit)
214 | elif self.enforce_deadline and agent_deadline <= 0:
215 | self.done = True
216 | self.success = False
217 | if self.verbose: # Debugging
218 | print "Environment.step(): Primary agent ran out of time! Trial aborted."
219 |
220 | self.t += 1
221 |
222 | def sense(self, agent):
223 | """ This function is called when information is requested about the sensor
224 | inputs from an 'agent' in the environment. """
225 |
226 | assert agent in self.agent_states, "Unknown agent!"
227 |
228 | state = self.agent_states[agent]
229 | location = state['location']
230 | heading = state['heading']
231 | light = 'green' if (self.intersections[location].state and heading[1] != 0) or ((not self.intersections[location].state) and heading[0] != 0) else 'red'
232 |
233 | # Populate oncoming, left, right
234 | oncoming = None
235 | left = None
236 | right = None
237 | for other_agent, other_state in self.agent_states.iteritems():
238 | if agent == other_agent or location != other_state['location'] or (heading[0] == other_state['heading'][0] and heading[1] == other_state['heading'][1]):
239 | continue
240 | # For dummy agents, ignore the primary agent
241 | # This is because the primary agent is not required to follow the waypoint
242 | if other_agent == self.primary_agent:
243 | continue
244 | other_heading = other_agent.get_next_waypoint()
245 | if (heading[0] * other_state['heading'][0] + heading[1] * other_state['heading'][1]) == -1:
246 | if oncoming != 'left': # we don't want to override oncoming == 'left'
247 | oncoming = other_heading
248 | elif (heading[1] == other_state['heading'][0] and -heading[0] == other_state['heading'][1]):
249 | if right != 'forward' and right != 'left': # we don't want to override right == 'forward or 'left'
250 | right = other_heading
251 | else:
252 | if left != 'forward': # we don't want to override left == 'forward'
253 | left = other_heading
254 |
255 | return {'light': light, 'oncoming': oncoming, 'left': left, 'right': right}
256 |
257 | def get_deadline(self, agent):
258 | """ Returns the deadline remaining for an agent. """
259 |
260 | return self.agent_states[agent]['deadline'] if agent is self.primary_agent else None
261 |
262 | def act(self, agent, action):
263 | """ Consider an action and perform the action if it is legal.
264 | Receive a reward for the agent based on traffic laws. """
265 |
266 | assert agent in self.agent_states, "Unknown agent!"
267 | assert action in self.valid_actions, "Invalid action!"
268 |
269 | state = self.agent_states[agent]
270 | location = state['location']
271 | heading = state['heading']
272 | light = 'green' if (self.intersections[location].state and heading[1] != 0) or ((not self.intersections[location].state) and heading[0] != 0) else 'red'
273 | inputs = self.sense(agent)
274 |
275 | # Assess whether the agent can move based on the action chosen.
276 | # Either the action is okay to perform, or falls under 4 types of violations:
277 | # 0: Action okay
278 | # 1: Minor traffic violation
279 | # 2: Major traffic violation
280 | # 3: Minor traffic violation causing an accident
281 | # 4: Major traffic violation causing an accident
282 | violation = 0
283 |
284 | # Reward scheme
285 | # First initialize reward uniformly random from [-1, 1]
286 | reward = 2 * random.random() - 1
287 |
288 | # Create a penalty factor as a function of remaining deadline
289 | # Scales reward multiplicatively from [0, 1]
290 | fnc = self.t * 1.0 / (self.t + state['deadline']) if agent.primary_agent else 0.0
291 | gradient = 10
292 |
293 | # No penalty given to an agent that has no enforced deadline
294 | penalty = 0
295 |
296 | # If the deadline is enforced, give a penalty based on time remaining
297 | if self.enforce_deadline:
298 | penalty = (math.pow(gradient, fnc) - 1) / (gradient - 1)
299 |
300 | # Agent wants to drive forward:
301 | if action == 'forward':
302 | if light != 'green': # Running red light
303 | violation = 2 # Major violation
304 | if inputs['left'] == 'forward' or inputs['right'] == 'forward': # Cross traffic
305 | violation = 4 # Accident
306 |
307 | # Agent wants to drive left:
308 | elif action == 'left':
309 | if light != 'green': # Running a red light
310 | violation = 2 # Major violation
311 | if inputs['left'] == 'forward' or inputs['right'] == 'forward': # Cross traffic
312 | violation = 4 # Accident
313 | elif inputs['oncoming'] == 'right': # Oncoming car turning right
314 | violation = 4 # Accident
315 | else: # Green light
316 | if inputs['oncoming'] == 'right' or inputs['oncoming'] == 'forward': # Incoming traffic
317 | violation = 3 # Accident
318 | else: # Valid move!
319 | heading = (heading[1], -heading[0])
320 |
321 | # Agent wants to drive right:
322 | elif action == 'right':
323 | if light != 'green' and inputs['left'] == 'forward': # Cross traffic
324 | violation = 3 # Accident
325 | else: # Valid move!
326 | heading = (-heading[1], heading[0])
327 |
328 | # Agent wants to perform no action:
329 | elif action == None:
330 | if light == 'green' and inputs['oncoming'] != 'left': # No oncoming traffic
331 | violation = 1 # Minor violation
332 |
333 |
334 | # Did the agent attempt a valid move?
335 | if violation == 0:
336 | if action == agent.get_next_waypoint(): # Was it the correct action?
337 | reward += 2 - penalty # (2, 1)
338 | elif action == None and light != 'green': # Was the agent stuck at a red light?
339 | reward += 2 - penalty # (2, 1)
340 | else: # Valid but incorrect
341 | reward += 1 - penalty # (1, 0)
342 |
343 | # Move the agent
344 | if action is not None:
345 | location = ((location[0] + heading[0] - self.bounds[0]) % (self.bounds[2] - self.bounds[0] + 1) + self.bounds[0],
346 | (location[1] + heading[1] - self.bounds[1]) % (self.bounds[3] - self.bounds[1] + 1) + self.bounds[1]) # wrap-around
347 | state['location'] = location
348 | state['heading'] = heading
349 | # Agent attempted invalid move
350 | else:
351 | if violation == 1: # Minor violation
352 | reward += -5
353 | elif violation == 2: # Major violation
354 | reward += -10
355 | elif violation == 3: # Minor accident
356 | reward += -20
357 | elif violation == 4: # Major accident
358 | reward += -40
359 |
360 | # Did agent reach the goal after a valid move?
361 | if agent is self.primary_agent:
362 | if state['location'] == state['destination']:
363 | # Did agent get to destination before deadline?
364 | if state['deadline'] >= 0:
365 | self.trial_data['success'] = 1
366 |
367 | # Stop the trial
368 | self.done = True
369 | self.success = True
370 |
371 | if(self.verbose == True): # Debugging
372 | print "Environment.act(): Primary agent has reached destination!"
373 |
374 | if(self.verbose == True): # Debugging
375 | print "Environment.act() [POST]: location: {}, heading: {}, action: {}, reward: {}".format(location, heading, action, reward)
376 |
377 | # Update metrics
378 | self.step_data['t'] = self.t
379 | self.step_data['violation'] = violation
380 | self.step_data['state'] = agent.get_state()
381 | self.step_data['deadline'] = state['deadline']
382 | self.step_data['waypoint'] = agent.get_next_waypoint()
383 | self.step_data['inputs'] = inputs
384 | self.step_data['light'] = light
385 | self.step_data['action'] = action
386 | self.step_data['reward'] = reward
387 |
388 | self.trial_data['final_deadline'] = state['deadline'] - 1
389 | self.trial_data['net_reward'] += reward
390 | self.trial_data['actions'][violation] += 1
391 |
392 | if(self.verbose == True): # Debugging
393 | print "Environment.act(): Step data: {}".format(self.step_data)
394 | return reward
395 |
396 | def compute_dist(self, a, b):
397 | """ Compute the Manhattan (L1) distance of a spherical world. """
398 |
399 | dx1 = abs(b[0] - a[0])
400 | dx2 = abs(self.grid_size[0] - dx1)
401 | dx = dx1 if dx1 < dx2 else dx2
402 |
403 | dy1 = abs(b[1] - a[1])
404 | dy2 = abs(self.grid_size[1] - dy1)
406 | dy = dy1 if dy1 < dy2 else dy2
407 |
408 | return dx + dy
409 |
410 |
411 | class Agent(object):
412 | """Base class for all agents."""
413 |
414 | def __init__(self, env):
415 | self.env = env
416 | self.state = None
417 | self.next_waypoint = None
418 | self.color = 'white'
419 | self.primary_agent = False
420 |
421 | def reset(self, destination=None, testing=False):
422 | pass
423 |
424 | def update(self):
425 | pass
426 |
427 | def get_state(self):
428 | return self.state
429 |
430 | def get_next_waypoint(self):
431 | return self.next_waypoint
432 |
433 |
434 | class DummyAgent(Agent):
435 | color_choices = ['cyan', 'red', 'blue', 'green', 'orange', 'magenta', 'yellow']
436 |
437 | def __init__(self, env):
438 | super(DummyAgent, self).__init__(env) # sets self.env = env, state = None, next_waypoint = None, and a default color
439 | self.next_waypoint = random.choice(Environment.valid_actions[1:])
440 | self.color = random.choice(self.color_choices)
441 |
442 | def update(self):
443 | """ Update a DummyAgent to move randomly under legal traffic laws. """
444 |
445 | inputs = self.env.sense(self)
446 |
447 | # Check if the chosen waypoint is safe to move to.
448 | action_okay = True
449 | if self.next_waypoint == 'right':
450 | if inputs['light'] == 'red' and inputs['left'] == 'forward':
451 | action_okay = False
452 | elif self.next_waypoint == 'forward':
453 | if inputs['light'] == 'red':
454 | action_okay = False
455 | elif self.next_waypoint == 'left':
456 | if inputs['light'] == 'red' or (inputs['oncoming'] == 'forward' or inputs['oncoming'] == 'right'):
457 | action_okay = False
458 |
459 | # Move to the next waypoint and choose a new one.
460 | action = None
461 | if action_okay:
462 | action = self.next_waypoint
463 | self.next_waypoint = random.choice(Environment.valid_actions[1:])
464 | reward = self.env.act(self, action)
--------------------------------------------------------------------------------
/smartcab/smartcab/environment.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/smartcab/environment.pyc
--------------------------------------------------------------------------------
/smartcab/smartcab/planner.py:
--------------------------------------------------------------------------------
1 | import random
2 |
3 | class RoutePlanner(object):
4 | """ Complex route planner that is meant for a perpendicular grid network. """
5 |
6 | def __init__(self, env, agent):
7 | self.env = env
8 | self.agent = agent
9 | self.destination = None
10 |
11 | def route_to(self, destination=None):
12 | """ Select the destination if one is provided, otherwise choose a random intersection. """
13 |
14 | self.destination = destination if destination is not None else random.choice(self.env.intersections.keys())
15 |
16 | def next_waypoint(self):
17 | """ Creates the next waypoint based on current heading, location,
18 | intended destination and L1 distance from destination. """
19 |
20 | # Collect global location details
21 | bounds = self.env.grid_size
22 | location = self.env.agent_states[self.agent]['location']
23 | heading = self.env.agent_states[self.agent]['heading']
24 |
25 | delta_a = (self.destination[0] - location[0], self.destination[1] - location[1])
26 | delta_b = (bounds[0] + delta_a[0] if delta_a[0] <= 0 else delta_a[0] - bounds[0], \
27 | bounds[1] + delta_a[1] if delta_a[1] <= 0 else delta_a[1] - bounds[1])
28 |
29 | # Calculate true difference in location based on world-wrap
30 | # This will pre-determine the need for U-turns from improper headings
31 | dx = delta_a[0] if abs(delta_a[0]) < abs(delta_b[0]) else delta_b[0]
32 | dy = delta_a[1] if abs(delta_a[1]) < abs(delta_b[1]) else delta_b[1]
33 |
34 | # First check if destination is at location
35 | if dx == 0 and dy == 0:
36 | return None
37 |
38 | # Next check if destination is cardinally East or West of location
39 | elif dx != 0:
40 |
41 | if dx * heading[0] > 0: # Heading the correct East or West direction
42 | return 'forward'
43 | elif dx * heading[0] < 0 and heading[0] < 0: # Heading West, destination East
44 | if dy > 0: # Destination also to the South
45 | return 'left'
46 | else:
47 | return 'right'
48 | elif dx * heading[0] < 0 and heading[0] > 0: # Heading East, destination West
49 | if dy < 0: # Destination also to the North
50 | return 'left'
51 | else:
52 | return 'right'
53 | elif dx * heading[1] > 0: # Heading North destination West; Heading South destination East
54 | return 'left'
55 | else:
56 | return 'right'
57 |
58 | # Finally, check if destination is cardinally North or South of location
59 | elif dy != 0:
60 |
61 | if dy * heading[1] > 0: # Heading the correct North or South direction
62 | return 'forward'
63 | elif dy * heading[1] < 0 and heading[1] < 0: # Heading North, destination South
64 | if dx < 0: # Destination also to the West
65 | return 'left'
66 | else:
67 | return 'right'
68 | elif dy * heading[1] < 0 and heading[1] > 0: # Heading South, destination North
69 | if dx > 0: # Destination also to the East
70 | return 'left'
71 | else:
72 | return 'right'
73 | elif dy * heading[0] > 0: # Heading West destination North; Heading East destination South
74 | return 'right'
75 | else:
76 | return 'left'
--------------------------------------------------------------------------------
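The world-wrap arithmetic at the top of `next_waypoint` (the `delta_a`/`delta_b` comparison) can be sketched in isolation. A minimal Python 3 example, with a hypothetical grid size and coordinates:

```python
def shortest_delta(location, destination, bounds):
    """Smallest signed (dx, dy) from location to destination on a
    grid whose edges wrap around (a torus), as the planner computes it."""
    # Direct difference, ignoring wrap-around
    delta_a = (destination[0] - location[0], destination[1] - location[1])
    # Difference going the other way around the torus
    delta_b = (bounds[0] + delta_a[0] if delta_a[0] <= 0 else delta_a[0] - bounds[0],
               bounds[1] + delta_a[1] if delta_a[1] <= 0 else delta_a[1] - bounds[1])
    # Keep whichever component has the smaller magnitude
    dx = delta_a[0] if abs(delta_a[0]) < abs(delta_b[0]) else delta_b[0]
    dy = delta_a[1] if abs(delta_a[1]) < abs(delta_b[1]) else delta_b[1]
    return dx, dy

# On an 8x6 grid, going from (1, 1) to (7, 1) is shorter heading West
# across the wrapped edge (2 blocks) than heading East (6 blocks):
print(shortest_delta((1, 1), (7, 1), (8, 6)))  # -> (-2, 0)
```

This pre-computed shortest delta is what lets the planner detect when the agent's current heading points the "long way" around and a turn is needed.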
/smartcab/smartcab/planner.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/smartcab/planner.pyc
--------------------------------------------------------------------------------
/smartcab/smartcab/simulator.py:
--------------------------------------------------------------------------------
1 | ###########################################
2 | # Suppress matplotlib user warnings
3 | # Necessary for newer version of matplotlib
4 | import warnings
5 | warnings.filterwarnings("ignore", category = UserWarning, module = "matplotlib")
6 | ###########################################
7 |
8 | import os
9 | import time
10 | import random
11 | import importlib
12 | import csv
13 |
14 | class Simulator(object):
15 | """Simulates agents in a dynamic smartcab environment.
16 |
17 | Uses PyGame to display GUI, if available.
18 | """
19 |
20 | colors = {
21 | 'black' : ( 0, 0, 0),
22 | 'white' : (255, 255, 255),
23 | 'red' : (255, 0, 0),
24 | 'green' : ( 0, 255, 0),
25 | 'dgreen' : ( 0, 228, 0),
26 | 'blue' : ( 0, 0, 255),
27 | 'cyan' : ( 0, 200, 200),
28 | 'magenta' : (200, 0, 200),
29 | 'yellow' : (255, 255, 0),
30 | 'mustard' : (200, 200, 0),
31 | 'orange' : (255, 128, 0),
32 | 'maroon' : (200, 0, 0),
33 | 'crimson' : (128, 0, 0),
34 | 'gray' : (155, 155, 155)
35 | }
36 |
37 | def __init__(self, env, size=None, update_delay=2.0, display=True, log_metrics=False, optimized=False):
38 | self.env = env
39 | self.size = size if size is not None else ((self.env.grid_size[0] + 1) * self.env.block_size, (self.env.grid_size[1] + 2) * self.env.block_size)
40 | self.width, self.height = self.size
41 | self.road_width = 44
42 |
43 | self.bg_color = self.colors['gray']
44 | self.road_color = self.colors['black']
45 | self.line_color = self.colors['mustard']
46 | self.boundary = self.colors['black']
47 | self.stop_color = self.colors['crimson']
48 |
49 | self.quit = False
50 | self.start_time = None
51 | self.current_time = 0.0
52 | self.last_updated = 0.0
53 | self.update_delay = update_delay # duration between each step (in seconds)
54 |
55 | self.display = display
56 | if self.display:
57 | try:
58 | self.pygame = importlib.import_module('pygame')
59 | self.pygame.init()
60 | self.screen = self.pygame.display.set_mode(self.size)
61 | self._logo = self.pygame.transform.smoothscale(self.pygame.image.load(os.path.join("images", "logo.png")), (self.road_width, self.road_width))
62 |
63 | self._ew = self.pygame.transform.smoothscale(self.pygame.image.load(os.path.join("images", "east-west.png")), (self.road_width, self.road_width))
64 | self._ns = self.pygame.transform.smoothscale(self.pygame.image.load(os.path.join("images", "north-south.png")), (self.road_width, self.road_width))
65 |
66 | self.frame_delay = max(1, int(self.update_delay * 1000)) # delay between GUI frames in ms (min: 1)
67 | self.agent_sprite_size = (32, 32)
68 | self.primary_agent_sprite_size = (42, 42)
69 | self.agent_circle_radius = 20 # radius of circle, when using simple representation
70 | for agent in self.env.agent_states:
71 | if agent.color == 'white':
72 | agent._sprite = self.pygame.transform.smoothscale(self.pygame.image.load(os.path.join("images", "car-{}.png".format(agent.color))), self.primary_agent_sprite_size)
73 | else:
74 | agent._sprite = self.pygame.transform.smoothscale(self.pygame.image.load(os.path.join("images", "car-{}.png".format(agent.color))), self.agent_sprite_size)
75 | agent._sprite_size = (agent._sprite.get_width(), agent._sprite.get_height())
76 |
77 | self.font = self.pygame.font.Font(None, 20)
78 | self.paused = False
79 | except ImportError as e:
80 | self.display = False
81 | print "Simulator.__init__(): Unable to import pygame; display disabled.\n{}: {}".format(e.__class__.__name__, e)
82 | except Exception as e:
83 | self.display = False
84 | print "Simulator.__init__(): Error initializing GUI objects; display disabled.\n{}: {}".format(e.__class__.__name__, e)
85 |
86 | # Setup metrics to report
87 | self.log_metrics = log_metrics
88 | self.optimized = optimized
89 |
90 | if self.log_metrics:
91 | a = self.env.primary_agent
92 |
93 | # Set log files
94 | if a.learning:
95 | if self.optimized: # Whether the user is optimizing the parameters and decay functions
96 | self.log_filename = os.path.join("logs", "sim_improved-learning.csv")
97 | self.table_filename = os.path.join("logs","sim_improved-learning.txt")
98 | else:
99 | self.log_filename = os.path.join("logs", "sim_default-learning.csv")
100 | self.table_filename = os.path.join("logs","sim_default-learning.txt")
101 |
102 | self.table_file = open(self.table_filename, 'wb')
103 | else:
104 | self.log_filename = os.path.join("logs", "sim_no-learning.csv")
105 |
106 | self.log_fields = ['trial', 'testing', 'parameters', 'initial_deadline', 'final_deadline', 'net_reward', 'actions', 'success']
107 | self.log_file = open(self.log_filename, 'wb')
108 | self.log_writer = csv.DictWriter(self.log_file, fieldnames=self.log_fields)
109 | self.log_writer.writeheader()
110 |
111 | def run(self, tolerance=0.05, n_test=0):
112 | """ Run a simulation of the environment.
113 |
114 | 'tolerance' is the minimum epsilon necessary to begin testing (if enabled)
115 | 'n_test' is the number of testing trials simulated
116 |
117 | Note that the minimum number of training trials is always 20. """
118 |
119 | self.quit = False
120 |
121 | # Get the primary agent
122 | a = self.env.primary_agent
123 |
124 | total_trials = 1
125 | testing = False
126 | trial = 1
127 |
128 | while True:
129 |
130 | # Flip testing switch
131 | if not testing:
132 | if total_trials > 20: # Must complete minimum 20 training trials
133 | if a.learning:
134 | if a.epsilon < tolerance: # assumes epsilon decays to 0
135 | testing = True
136 | trial = 1
137 | else:
138 | testing = True
139 | trial = 1
140 |
141 | # Break if we've reached the limit of testing trials
142 | else:
143 | if trial > n_test:
144 | break
145 |
146 | # Pretty print to terminal
147 | print
148 | print "/-------------------------"
149 | if testing:
150 | print "| Testing trial {}".format(trial)
151 | else:
152 | print "| Training trial {}".format(trial)
153 |
154 | print "\-------------------------"
155 | print
156 |
157 | self.env.reset(testing)
158 | self.current_time = 0.0
159 | self.last_updated = 0.0
160 | self.start_time = time.time()
161 | while True:
162 | try:
163 | # Update current time
164 | self.current_time = time.time() - self.start_time
165 |
166 | # Handle GUI events
167 | if self.display:
168 | for event in self.pygame.event.get():
169 | if event.type == self.pygame.QUIT:
170 | self.quit = True
171 | elif event.type == self.pygame.KEYDOWN:
172 | if event.key == 27: # Esc
173 | self.quit = True
174 | elif event.unicode == u' ':
175 | self.paused = True
176 |
177 | if self.paused:
178 | self.pause()
179 |
180 | # Update environment
181 | if self.current_time - self.last_updated >= self.update_delay:
182 | self.env.step()
183 | self.last_updated = self.current_time
184 |
185 | # Render text
186 | self.render_text(trial, testing)
187 |
188 | # Render GUI and sleep
189 | if self.display:
190 | self.render(trial, testing)
191 | self.pygame.time.wait(self.frame_delay)
192 |
193 | except KeyboardInterrupt:
194 | self.quit = True
195 | finally:
196 | if self.quit or self.env.done:
197 | break
198 |
199 | if self.quit:
200 | break
201 |
202 | # Collect metrics from trial
203 | if self.log_metrics:
204 | self.log_writer.writerow({
205 | 'trial': trial,
206 | 'testing': self.env.trial_data['testing'],
207 | 'parameters': self.env.trial_data['parameters'],
208 | 'initial_deadline': self.env.trial_data['initial_deadline'],
209 | 'final_deadline': self.env.trial_data['final_deadline'],
210 | 'net_reward': self.env.trial_data['net_reward'],
211 | 'actions': self.env.trial_data['actions'],
212 | 'success': self.env.trial_data['success']
213 | })
214 |
215 | # Trial finished
216 | if self.env.success == True:
217 | print "\nTrial Completed!"
218 | print "Agent reached the destination."
219 | else:
220 | print "\nTrial Aborted!"
221 | print "Agent did not reach the destination."
222 |
223 | # Increment
224 | total_trials = total_trials + 1
225 | trial = trial + 1
226 |
227 | # Clean up
228 | if self.log_metrics:
229 |
230 | if a.learning:
231 | f = self.table_file
232 |
233 | f.write("/-----------------------------------------\n")
234 | f.write("| State-action rewards from Q-Learning\n")
235 | f.write("\-----------------------------------------\n\n")
236 |
237 | for state in a.Q:
238 | f.write("{}\n".format(state))
239 | for action, reward in a.Q[state].iteritems():
240 | f.write(" -- {} : {:.2f}\n".format(action, reward))
241 | f.write("\n")
242 | self.table_file.close()
243 |
244 | self.log_file.close()
245 |
246 | print "\nSimulation ended. . . "
247 |
248 | # Report final metrics
249 | if self.display:
250 | self.pygame.display.quit() # shut down pygame
251 |
252 | def render_text(self, trial, testing=False):
253 | """ This is the non-GUI render display of the simulation.
254 | Simulated trial data will be rendered in the terminal/command prompt. """
255 |
256 | status = self.env.step_data
257 | if status and status['waypoint'] is not None: # Continuing the trial
258 |
259 | # Previous State
260 | if status['state']:
261 | print "Agent previous state: {}".format(status['state'])
262 | else:
263 |             print "!! Agent state has not been updated!"
264 |
265 | # Result
266 | if status['violation'] == 0: # Legal
267 | if status['waypoint'] == status['action']: # Followed waypoint
268 | print "Agent followed the waypoint {}. (rewarded {:.2f})".format(status['action'], status['reward'])
269 | elif status['action'] == None:
270 | if status['light'] == 'red': # Stuck at red light
271 | print "Agent properly idled at a red light. (rewarded {:.2f})".format(status['reward'])
272 | else:
273 | print "Agent idled at a green light with oncoming traffic. (rewarded {:.2f})".format(status['reward'])
274 | else: # Did not follow waypoint
275 | print "Agent drove {} instead of {}. (rewarded {:.2f})".format(status['action'], status['waypoint'], status['reward'])
276 | else: # Illegal
277 | if status['violation'] == 1: # Minor violation
278 | print "Agent idled at a green light with no oncoming traffic. (rewarded {:.2f})".format(status['reward'])
279 | elif status['violation'] == 2: # Major violation
280 | print "Agent attempted driving {} through a red light. (rewarded {:.2f})".format(status['action'], status['reward'])
281 | elif status['violation'] == 3: # Minor accident
282 |                 print "Agent attempted driving {} through traffic and caused a minor accident. (rewarded {:.2f})".format(status['action'], status['reward'])
283 |             elif status['violation'] == 4: # Major accident
284 |                 print "Agent attempted driving {} through a red light with traffic and caused a major accident. (rewarded {:.2f})".format(status['action'], status['reward'])
285 |
286 | # Time Remaining
287 | if self.env.enforce_deadline:
288 | time = (status['deadline'] - 1) * 100.0 / (status['t'] + status['deadline'])
289 | print "{:.0f}% of time remaining to reach destination.".format(time)
290 | else:
291 |             print "Agent is not required to meet a deadline."
292 |
293 | # Starting new trial
294 | else:
295 | a = self.env.primary_agent
296 | print "Simulating trial. . . "
297 | if a.learning:
298 | print "epsilon = {:.4f}; alpha = {:.4f}".format(a.epsilon, a.alpha)
299 | else:
300 | print "Agent not set to learn."
301 |
302 |
303 | def render(self, trial, testing=False):
304 | """ This is the GUI render display of the simulation.
305 | Supplementary trial data can be found from render_text. """
306 |
307 | # Reset the screen.
308 | self.screen.fill(self.bg_color)
309 |
310 | # Draw elements
311 | # * Static elements
312 |
313 | # Boundary
314 | self.pygame.draw.rect(self.screen, self.boundary, ((self.env.bounds[0] - self.env.hang)*self.env.block_size, (self.env.bounds[1]-self.env.hang)*self.env.block_size, (self.env.bounds[2] + self.env.hang/3)*self.env.block_size, (self.env.bounds[3] - 1 + self.env.hang/3)*self.env.block_size), 4)
315 |
316 | for road in self.env.roads:
317 | # Road
318 | self.pygame.draw.line(self.screen, self.road_color, (road[0][0] * self.env.block_size, road[0][1] * self.env.block_size), (road[1][0] * self.env.block_size, road[1][1] * self.env.block_size), self.road_width)
319 | # Center line
320 | self.pygame.draw.line(self.screen, self.line_color, (road[0][0] * self.env.block_size, road[0][1] * self.env.block_size), (road[1][0] * self.env.block_size, road[1][1] * self.env.block_size), 2)
321 |
322 | for intersection, traffic_light in self.env.intersections.iteritems():
323 | self.pygame.draw.circle(self.screen, self.road_color, (intersection[0] * self.env.block_size, intersection[1] * self.env.block_size), self.road_width/2)
324 |
325 | if traffic_light.state: # North-South is open
326 | self.screen.blit(self._ns,
327 | self.pygame.rect.Rect(intersection[0]*self.env.block_size - self.road_width/2, intersection[1]*self.env.block_size - self.road_width/2, intersection[0]*self.env.block_size + self.road_width, intersection[1]*self.env.block_size + self.road_width/2))
328 | self.pygame.draw.line(self.screen, self.stop_color, (intersection[0] * self.env.block_size - self.road_width/2, intersection[1] * self.env.block_size - self.road_width/2), (intersection[0] * self.env.block_size - self.road_width/2, intersection[1] * self.env.block_size + self.road_width/2), 2)
329 | self.pygame.draw.line(self.screen, self.stop_color, (intersection[0] * self.env.block_size + self.road_width/2 + 1, intersection[1] * self.env.block_size - self.road_width/2), (intersection[0] * self.env.block_size + self.road_width/2 + 1, intersection[1] * self.env.block_size + self.road_width/2), 2)
330 | else:
331 | self.screen.blit(self._ew,
332 | self.pygame.rect.Rect(intersection[0]*self.env.block_size - self.road_width/2, intersection[1]*self.env.block_size - self.road_width/2, intersection[0]*self.env.block_size + self.road_width, intersection[1]*self.env.block_size + self.road_width/2))
333 | self.pygame.draw.line(self.screen, self.stop_color, (intersection[0] * self.env.block_size - self.road_width/2, intersection[1] * self.env.block_size - self.road_width/2), (intersection[0] * self.env.block_size + self.road_width/2, intersection[1] * self.env.block_size - self.road_width/2), 2)
334 | self.pygame.draw.line(self.screen, self.stop_color, (intersection[0] * self.env.block_size + self.road_width/2, intersection[1] * self.env.block_size + self.road_width/2 + 1), (intersection[0] * self.env.block_size - self.road_width/2, intersection[1] * self.env.block_size + self.road_width/2 + 1), 2)
335 |
336 | # * Dynamic elements
337 | self.font = self.pygame.font.Font(None, 20)
338 | for agent, state in self.env.agent_states.iteritems():
339 | # Compute precise agent location here (back from the intersection some)
340 | agent_offset = (2 * state['heading'][0] * self.agent_circle_radius + self.agent_circle_radius * state['heading'][1] * 0.5, \
341 | 2 * state['heading'][1] * self.agent_circle_radius - self.agent_circle_radius * state['heading'][0] * 0.5)
342 |
343 |
344 | agent_pos = (state['location'][0] * self.env.block_size - agent_offset[0], state['location'][1] * self.env.block_size - agent_offset[1])
345 | agent_color = self.colors[agent.color]
346 |
347 | if hasattr(agent, '_sprite') and agent._sprite is not None:
348 | # Draw agent sprite (image), properly rotated
349 | rotated_sprite = agent._sprite if state['heading'] == (1, 0) else self.pygame.transform.rotate(agent._sprite, 180 if state['heading'][0] == -1 else state['heading'][1] * -90)
350 | self.screen.blit(rotated_sprite,
351 | self.pygame.rect.Rect(agent_pos[0] - agent._sprite_size[0] / 2, agent_pos[1] - agent._sprite_size[1] / 2,
352 | agent._sprite_size[0], agent._sprite_size[1]))
353 | else:
354 | # Draw simple agent (circle with a short line segment poking out to indicate heading)
355 | self.pygame.draw.circle(self.screen, agent_color, agent_pos, self.agent_circle_radius)
356 | self.pygame.draw.line(self.screen, agent_color, agent_pos, state['location'], self.road_width)
357 |
358 |
359 | if state['destination'] is not None:
360 | self.screen.blit(self._logo,
361 | self.pygame.rect.Rect(state['destination'][0] * self.env.block_size - self.road_width/2, \
362 | state['destination'][1]*self.env.block_size - self.road_width/2, \
363 | state['destination'][0]*self.env.block_size + self.road_width/2, \
364 | state['destination'][1]*self.env.block_size + self.road_width/2))
365 |
366 | # * Overlays
367 | self.font = self.pygame.font.Font(None, 50)
368 | if testing:
369 | self.screen.blit(self.font.render("Testing Trial %s"%(trial), True, self.colors['black'], self.bg_color), (10, 10))
370 | else:
371 | self.screen.blit(self.font.render("Training Trial %s"%(trial), True, self.colors['black'], self.bg_color), (10, 10))
372 |
373 | self.font = self.pygame.font.Font(None, 30)
374 |
375 | # Status text about each step
376 | status = self.env.step_data
377 | if status:
378 |
379 | # Previous State
380 | if status['state']:
381 | self.screen.blit(self.font.render("Previous State: {}".format(status['state']), True, self.colors['white'], self.bg_color), (350, 10))
382 | if not status['state']:
383 | self.screen.blit(self.font.render("!! Agent state not updated!", True, self.colors['maroon'], self.bg_color), (350, 10))
384 |
385 | # Action
386 | if status['violation'] == 0: # Legal
387 | if status['action'] == None:
388 | self.screen.blit(self.font.render("No action taken. (rewarded {:.2f})".format(status['reward']), True, self.colors['dgreen'], self.bg_color), (350, 40))
389 | else:
390 | self.screen.blit(self.font.render("Agent drove {}. (rewarded {:.2f})".format(status['action'], status['reward']), True, self.colors['dgreen'], self.bg_color), (350, 40))
391 | else: # Illegal
392 | if status['action'] == None:
393 | self.screen.blit(self.font.render("No action taken. (rewarded {:.2f})".format(status['reward']), True, self.colors['maroon'], self.bg_color), (350, 40))
394 | else:
395 |                     self.screen.blit(self.font.render("{} attempted. (rewarded {:.2f})".format(status['action'], status['reward']), True, self.colors['maroon'], self.bg_color), (350, 40))
396 |
397 | # Result
398 | if status['violation'] == 0: # Legal
399 | if status['waypoint'] == status['action']: # Followed waypoint
400 | self.screen.blit(self.font.render("Agent followed the waypoint!", True, self.colors['dgreen'], self.bg_color), (350, 70))
401 | elif status['action'] == None:
402 | if status['light'] == 'red': # Stuck at a red light
403 | self.screen.blit(self.font.render("Agent idled at a red light!", True, self.colors['dgreen'], self.bg_color), (350, 70))
404 | else:
405 | self.screen.blit(self.font.render("Agent idled at a green light with oncoming traffic.", True, self.colors['mustard'], self.bg_color), (350, 70))
406 | else: # Did not follow waypoint
407 | self.screen.blit(self.font.render("Agent did not follow the waypoint.", True, self.colors['mustard'], self.bg_color), (350, 70))
408 | else: # Illegal
409 | if status['violation'] == 1: # Minor violation
410 | self.screen.blit(self.font.render("There was a green light with no oncoming traffic.", True, self.colors['maroon'], self.bg_color), (350, 70))
411 | elif status['violation'] == 2: # Major violation
412 | self.screen.blit(self.font.render("There was a red light with no traffic.", True, self.colors['maroon'], self.bg_color), (350, 70))
413 | elif status['violation'] == 3: # Minor accident
414 | self.screen.blit(self.font.render("There was traffic with right-of-way.", True, self.colors['maroon'], self.bg_color), (350, 70))
415 | elif status['violation'] == 4: # Major accident
416 | self.screen.blit(self.font.render("There was a red light with traffic.", True, self.colors['maroon'], self.bg_color), (350, 70))
417 |
418 | # Time Remaining
419 | if self.env.enforce_deadline:
420 | time = (status['deadline'] - 1) * 100.0 / (status['t'] + status['deadline'])
421 | self.screen.blit(self.font.render("{:.0f}% of time remaining to reach destination.".format(time), True, self.colors['black'], self.bg_color), (350, 100))
422 | else:
423 |                 self.screen.blit(self.font.render("Agent is not required to meet a deadline.", True, self.colors['black'], self.bg_color), (350, 100))
424 |
425 | # Denote whether a trial was a success or failure
426 | if (state['destination'] != state['location'] and state['deadline'] > 0) or (self.env.enforce_deadline is not True and state['destination'] != state['location']):
427 | self.font = self.pygame.font.Font(None, 40)
428 | if self.env.success == True:
429 | self.screen.blit(self.font.render("Previous Trial: Success", True, self.colors['dgreen'], self.bg_color), (10, 50))
430 | if self.env.success == False:
431 | self.screen.blit(self.font.render("Previous Trial: Failure", True, self.colors['maroon'], self.bg_color), (10, 50))
432 |
433 | if self.env.primary_agent.learning:
434 | self.font = self.pygame.font.Font(None, 22)
435 | self.screen.blit(self.font.render("epsilon = {:.4f}".format(self.env.primary_agent.epsilon), True, self.colors['black'], self.bg_color), (10, 80))
436 | self.screen.blit(self.font.render("alpha = {:.4f}".format(self.env.primary_agent.alpha), True, self.colors['black'], self.bg_color), (10, 95))
437 |
438 | # Reset status text
439 | else:
440 | self.pygame.rect.Rect(350, 10, self.width, 200)
441 | self.font = self.pygame.font.Font(None, 40)
442 | self.screen.blit(self.font.render("Simulating trial. . .", True, self.colors['white'], self.bg_color), (400, 60))
443 |
444 |
445 | # Flip buffers
446 | self.pygame.display.flip()
447 |
448 | def pause(self):
449 | """ When the GUI is enabled, this function will pause the simulation. """
450 |
451 | abs_pause_time = time.time()
452 | self.font = self.pygame.font.Font(None, 30)
453 | pause_text = "Simulation Paused. Press any key to continue. . ."
454 | self.screen.blit(self.font.render(pause_text, True, self.colors['red'], self.bg_color), (400, self.height - 30))
455 | self.pygame.display.flip()
456 | print pause_text
457 | while self.paused:
458 | for event in self.pygame.event.get():
459 | if event.type == self.pygame.KEYDOWN:
460 | self.paused = False
461 | self.pygame.time.wait(self.frame_delay)
462 | self.screen.blit(self.font.render(pause_text, True, self.bg_color, self.bg_color), (400, self.height - 30))
463 | self.start_time += (time.time() - abs_pause_time)
464 |
--------------------------------------------------------------------------------
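The per-trial metrics logging in `Simulator.__init__` and `run` follows the standard `csv.DictWriter` pattern. A minimal Python 3 sketch using the simulator's `log_fields`, with an in-memory buffer standing in for the `logs/sim_*.csv` file and hypothetical row values:

```python
import csv
import io

# Field names copied from Simulator.log_fields
log_fields = ['trial', 'testing', 'parameters', 'initial_deadline',
              'final_deadline', 'net_reward', 'actions', 'success']

# An in-memory buffer stands in for the on-disk log file
buf = io.StringIO()
writer = csv.DictWriter(buf, fieldnames=log_fields)
writer.writeheader()

# One hypothetical training trial's worth of metrics
writer.writerow({
    'trial': 1,
    'testing': False,
    'parameters': "{'e': 1.0, 'a': 0.5}",
    'initial_deadline': 25,
    'final_deadline': 10,
    'net_reward': 12.5,
    'actions': '[18, 1, 0, 0, 0]',
    'success': True,
})

print(buf.getvalue().splitlines()[0])  # the CSV header row
```

Writing one dict per trial keyed by `log_fields` is what lets `visuals.plot_trials` later read the log back with `pd.read_csv` and address columns by name.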
/smartcab/smartcab/simulator.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/smartcab/simulator.pyc
--------------------------------------------------------------------------------
/smartcab/smartcab2.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/smartcab2.zip
--------------------------------------------------------------------------------
/smartcab/smartcab3.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/smartcab/smartcab3.zip
--------------------------------------------------------------------------------
/smartcab/visuals.py:
--------------------------------------------------------------------------------
1 | ###########################################
2 | # Suppress matplotlib user warnings
3 | # Necessary for newer version of matplotlib
4 | import warnings
5 | warnings.filterwarnings("ignore", category = UserWarning, module = "matplotlib")
6 | ###########################################
7 | #
8 | # Display inline matplotlib plots with IPython
9 | from IPython import get_ipython
10 | get_ipython().run_line_magic('matplotlib', 'inline')
11 | ###########################################
12 |
13 | import matplotlib.pyplot as plt
14 | import numpy as np
15 | import pandas as pd
16 | import os
17 | import ast
18 |
19 |
20 | def calculate_safety(data):
21 | """ Calculates the safety rating of the smartcab during testing. """
22 |
23 | good_ratio = data['good_actions'].sum() * 1.0 / \
24 | (data['initial_deadline'] - data['final_deadline']).sum()
25 |
26 | if good_ratio == 1: # Perfect driving
27 | return ("A+", "green")
28 | else: # Imperfect driving
29 | if data['actions'].apply(lambda x: ast.literal_eval(x)[4]).sum() > 0: # Major accident
30 | return ("F", "red")
31 | elif data['actions'].apply(lambda x: ast.literal_eval(x)[3]).sum() > 0: # Minor accident
32 | return ("D", "#EEC700")
33 | elif data['actions'].apply(lambda x: ast.literal_eval(x)[2]).sum() > 0: # Major violation
34 | return ("C", "#EEC700")
35 | else: # Minor violation
36 | minor = data['actions'].apply(lambda x: ast.literal_eval(x)[1]).sum()
37 | if minor >= len(data)/2: # Minor violation in at least half of the trials
38 | return ("B", "green")
39 | else:
40 | return ("A", "green")
41 |
42 |
43 | def calculate_reliability(data):
44 | """ Calculates the reliability rating of the smartcab during testing. """
45 |
46 | success_ratio = data['success'].sum() * 1.0 / len(data)
47 |
48 | if success_ratio == 1: # Always meets deadline
49 | return ("A+", "green")
50 | else:
51 | if success_ratio >= 0.90:
52 | return ("A", "green")
53 | elif success_ratio >= 0.80:
54 | return ("B", "green")
55 | elif success_ratio >= 0.70:
56 | return ("C", "#EEC700")
57 | elif success_ratio >= 0.60:
58 | return ("D", "#EEC700")
59 | else:
60 | return ("F", "red")
61 |
62 |
63 | def plot_trials(csv):
64 | """ Plots the data from logged metrics during a simulation."""
65 |
66 | data = pd.read_csv(os.path.join("logs", csv))
67 |
68 | if len(data) < 10:
69 | print "Not enough data collected to create a visualization."
70 |         print "At least 10 trials are required."
71 | return
72 |
73 | # Create additional features
74 | data['average_reward'] = (data['net_reward'] / (data['initial_deadline'] - data['final_deadline'])).rolling(window=10, center=False).mean()
75 |     data['reliability_rate'] = (data['success']*100).rolling(window=10, center=False).mean()  # 10-trial rolling success rate (%)
76 | data['good_actions'] = data['actions'].apply(lambda x: ast.literal_eval(x)[0])
77 | data['good'] = (data['good_actions'] * 1.0 / \
78 | (data['initial_deadline'] - data['final_deadline'])).rolling(window=10, center=False).mean()
79 | data['minor'] = (data['actions'].apply(lambda x: ast.literal_eval(x)[1]) * 1.0 / \
80 | (data['initial_deadline'] - data['final_deadline'])).rolling(window=10, center=False).mean()
81 | data['major'] = (data['actions'].apply(lambda x: ast.literal_eval(x)[2]) * 1.0 / \
82 | (data['initial_deadline'] - data['final_deadline'])).rolling(window=10, center=False).mean()
83 | data['minor_acc'] = (data['actions'].apply(lambda x: ast.literal_eval(x)[3]) * 1.0 / \
84 | (data['initial_deadline'] - data['final_deadline'])).rolling(window=10, center=False).mean()
85 | data['major_acc'] = (data['actions'].apply(lambda x: ast.literal_eval(x)[4]) * 1.0 / \
86 | (data['initial_deadline'] - data['final_deadline'])).rolling(window=10, center=False).mean()
87 | data['epsilon'] = data['parameters'].apply(lambda x: ast.literal_eval(x)['e'])
88 | data['alpha'] = data['parameters'].apply(lambda x: ast.literal_eval(x)['a'])
89 |
90 |
91 | # Create training and testing subsets
92 | training_data = data[data['testing'] == False]
93 | testing_data = data[data['testing'] == True]
94 |
95 | plt.figure(figsize=(12,8))
96 |
97 |
98 | ###############
99 | ### Average step reward plot
100 | ###############
101 |
102 | ax = plt.subplot2grid((6,6), (0,3), colspan=3, rowspan=2)
103 | ax.set_title("10-Trial Rolling Average Reward per Action")
104 | ax.set_ylabel("Reward per Action")
105 | ax.set_xlabel("Trial Number")
106 | ax.set_xlim((10, len(training_data)))
107 |
108 | # Create plot-specific data
109 | step = training_data[['trial','average_reward']].dropna()
110 |
111 | ax.axhline(xmin = 0, xmax = 1, y = 0, color = 'black', linestyle = 'dashed')
112 | ax.plot(step['trial'], step['average_reward'])
113 |
114 |
115 | ###############
116 | ### Parameters Plot
117 | ###############
118 |
119 | ax = plt.subplot2grid((6,6), (2,3), colspan=3, rowspan=2)
120 |
121 | # Check whether the agent was expected to learn
122 | if csv != 'sim_no-learning.csv':
123 | ax.set_ylabel("Parameter Value")
124 | ax.set_xlabel("Trial Number")
125 | ax.set_xlim((1, len(training_data)))
126 | ax.set_ylim((0, 1.05))
127 |
128 | ax.plot(training_data['trial'], training_data['epsilon'], color='blue', label='Exploration factor')
129 | ax.plot(training_data['trial'], training_data['alpha'], color='green', label='Learning factor')
130 |
131 | ax.legend(bbox_to_anchor=(0.5,1.19), fancybox=True, ncol=2, loc='upper center', fontsize=10)
132 |
133 | else:
134 | ax.axis('off')
135 | ax.text(0.52, 0.30, "Simulation completed\nwith learning disabled.", fontsize=24, ha='center', style='italic')
136 |
137 |
138 | ###############
139 | ### Bad Actions Plot
140 | ###############
141 |
142 | actions = training_data[['trial','good', 'minor','major','minor_acc','major_acc']].dropna()
143 | maximum = (1 - actions['good']).values.max()
144 |
145 | ax = plt.subplot2grid((6,6), (0,0), colspan=3, rowspan=4)
146 | ax.set_title("10-Trial Rolling Relative Frequency of Bad Actions")
147 | ax.set_ylabel("Relative Frequency")
148 | ax.set_xlabel("Trial Number")
149 |
150 | ax.set_ylim((0, maximum + 0.01))
151 | ax.set_xlim((10, len(training_data)))
152 |
153 | ax.set_yticks(np.linspace(0, maximum+0.01, 10))
154 |
155 | ax.plot(actions['trial'], (1 - actions['good']), color='black', label='Total Bad Actions', linestyle='dotted', linewidth=3)
156 | ax.plot(actions['trial'], actions['minor'], color='orange', label='Minor Violation', linestyle='dashed')
157 | ax.plot(actions['trial'], actions['major'], color='orange', label='Major Violation', linewidth=2)
158 | ax.plot(actions['trial'], actions['minor_acc'], color='red', label='Minor Accident', linestyle='dashed')
159 | ax.plot(actions['trial'], actions['major_acc'], color='red', label='Major Accident', linewidth=2)
160 |
161 | ax.legend(loc='upper right', fancybox=True, fontsize=10)
162 |
163 |
164 | ###############
165 | ### Rolling Success-Rate plot
166 | ###############
167 |
168 | ax = plt.subplot2grid((6,6), (4,0), colspan=4, rowspan=2)
169 | ax.set_title("10-Trial Rolling Rate of Reliability")
170 | ax.set_ylabel("Rate of Reliability")
171 | ax.set_xlabel("Trial Number")
172 | ax.set_xlim((10, len(training_data)))
173 | ax.set_ylim((-5, 105))
174 | ax.set_yticks(np.arange(0, 101, 20))
175 | ax.set_yticklabels(['0%', '20%', '40%', '60%', '80%', '100%'])
176 |
177 | # Create plot-specific data
178 | trial = training_data.dropna()['trial']
179 | rate = training_data.dropna()['reliability_rate']
180 |
181 | # Rolling success rate
182 | ax.plot(trial, rate, label="Reliability Rate", color='blue')
183 |
184 |
185 | ###############
186 | ### Test results
187 | ###############
188 |
189 | ax = plt.subplot2grid((6,6), (4,4), colspan=2, rowspan=2)
190 | ax.axis('off')
191 |
192 | if len(testing_data) > 0:
193 | safety_rating, safety_color = calculate_safety(testing_data)
194 | reliability_rating, reliability_color = calculate_reliability(testing_data)
195 |
196 | # Write success rate
197 | ax.text(0.40, .9, "{} testing trials simulated.".format(len(testing_data)), fontsize=14, ha='center')
198 | ax.text(0.40, 0.7, "Safety Rating:", fontsize=16, ha='center')
199 | ax.text(0.40, 0.42, "{}".format(safety_rating), fontsize=40, ha='center', color=safety_color)
200 | ax.text(0.40, 0.27, "Reliability Rating:", fontsize=16, ha='center')
201 | ax.text(0.40, 0, "{}".format(reliability_rating), fontsize=40, ha='center', color=reliability_color)
202 |
203 | else:
204 | ax.text(0.36, 0.30, "Simulation completed\nwith testing disabled.", fontsize=20, ha='center', style='italic')
205 |
206 | plt.tight_layout()
207 | plt.show()
208 |
--------------------------------------------------------------------------------
/student_intervention/README.md:
--------------------------------------------------------------------------------
1 | # Project 2: Supervised Learning
2 | ## Building a Student Intervention System
3 |
4 | ### Installation
5 |
6 | This project requires **Python 2.7** and the following Python packages:
7 |
8 | - [NumPy](http://www.numpy.org/)
9 | - [pandas](http://pandas.pydata.org)
10 | - [scikit-learn](http://scikit-learn.org/stable/)
11 |
12 | You will also need the software to run [Jupyter Notebook](http://jupyter.org/).
13 |
14 | Udacity recommends that students install [Anaconda](https://www.continuum.io/downloads), a pre-packaged Python distribution that contains all of the libraries and software needed for this project.
15 |
16 |
17 | ### Code
18 |
19 | The starter code is provided in the `student_intervention.ipynb` notebook file. Some of the code has already been implemented to help you get started, but you will need to implement additional functionality to complete the project.
20 |
21 | ### Run
22 |
23 | In a terminal, make sure the current directory is the top level of the `student_intervention/` folder (the one that contains this README), and run the following command:
24 |
25 | ```jupyter notebook student_intervention.ipynb```
26 |
27 | This will start the Jupyter Notebook and open the project file in your browser.
28 |
29 | ## Data
30 |
31 | The data for this project is contained in the `student-data.csv` file. The dataset has the following attributes:
32 |
33 | - `school` : the student's school (binary: "GP" or "MS")
34 | - `sex` : the student's sex (binary: "F" for female or "M" for male)
35 | - `age` : the student's age (numeric: from 15 to 22)
36 | - `address` : the student's home address type (binary: "U" for urban or "R" for rural)
37 | - `famsize` : family size (binary: "LE3" for 3 or fewer, or "GT3" for greater than 3)
38 | - `Pstatus` : the parents' cohabitation status (binary: "T" for living together or "A" for apart)
39 | - `Medu` : the mother's education level (numeric: 0 - none, 1 - primary education (4th grade), 2 - 5th to 9th grade, 3 - secondary education, or 4 - higher education)
40 | - `Fedu` : the father's education level (numeric: 0 - none, 1 - primary education (4th grade), 2 - 5th to 9th grade, 3 - secondary education, or 4 - higher education)
41 | - `Mjob` : the mother's job (categorical: "teacher", "health" for health-care related, "services" for civil services (e.g. administrative or police), "at_home", or "other")
42 | - `Fjob` : the father's job (categorical: "teacher", "health" for health-care related, "services" for civil services (e.g. administrative or police), "at_home", or "other")
43 | - `reason` : the reason for choosing this school (categorical: "home" for close to home, "reputation" for school reputation, "course" for course preference, or "other")
44 | - `guardian` : the student's guardian (categorical: "mother", "father", or "other")
45 | - `traveltime` : travel time to school (numeric: 1 - less than 15 min., 2 - 15 to 30 min., 3 - 30 min. to 1 hour, or 4 - more than 1 hour)
46 | - `studytime` : weekly study time (numeric: 1 - less than 2 hours, 2 - 2 to 5 hours, 3 - 5 to 10 hours, or 4 - more than 10 hours)
47 | - `failures` : number of past class failures (numeric: n if 1<=n<3, otherwise 4)
48 | - `schoolsup` : extra educational support (binary: yes or no)
49 | - `famsup` : family educational support (binary: yes or no)
50 | - `paid` : extra paid classes within the course subject (Math or Portuguese) (binary: yes or no)
51 | - `activities` : extracurricular activities (binary: yes or no)
52 | - `nursery` : attended nursery school (binary: yes or no)
53 | - `higher` : wants to pursue higher education (binary: yes or no)
54 | - `internet` : Internet access at home (binary: yes or no)
55 | - `romantic` : involved in a romantic relationship (binary: yes or no)
56 | - `famrel` : quality of family relationships (numeric: from 1 - very bad to 5 - excellent)
57 | - `freetime` : free time after school (numeric: from 1 - very low to 5 - very high)
58 | - `goout` : going out with friends (numeric: from 1 - very low to 5 - very high)
59 | - `Dalc` : workday alcohol consumption (numeric: from 1 - very low to 5 - very high)
60 | - `Walc` : weekend alcohol consumption (numeric: from 1 - very low to 5 - very high)
61 | - `health` : current health status (numeric: from 1 - very bad to 5 - very good)
62 | - `absences` : number of school absences (numeric: from 0 to 93)
63 | - `passed` : whether the student passed the final exam (binary: yes or no)
64 |
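Before modeling, the non-numeric attributes above have to be converted to numbers. The sketch below shows how this could look with pandas; it uses a tiny made-up three-row sample rather than the real `student-data.csv`, and Python 3 syntax rather than the project's Python 2.7 setup:

```python
import pandas as pd

# Made-up sample mirroring a few columns of student-data.csv
sample = pd.DataFrame({
    'school':   ['GP', 'MS', 'GP'],
    'internet': ['yes', 'no', 'yes'],
    'Mjob':     ['teacher', 'health', 'other'],
    'passed':   ['yes', 'no', 'yes'],
})

# Separate the target column from the features
features = sample.drop(columns='passed')

# Binary yes/no columns become 1/0; the remaining text
# columns become one dummy column per possible value
encoded = features.replace({'yes': 1, 'no': 0})
encoded = pd.get_dummies(encoded)

print(list(encoded.columns))
```

A column such as `school` with values "GP"/"MS" becomes two indicator columns, `school_GP` and `school_MS`, exactly as the notebook's `preprocess_features` routine does for the full dataset.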
--------------------------------------------------------------------------------
/student_intervention/student_intervention.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Machine Learning Engineer Nanodegree\n",
8 | "## Supervised Learning\n",
9 | "## Project 2: Building a Student Intervention System"
10 | ]
11 | },
12 | {
13 | "cell_type": "markdown",
14 | "metadata": {},
15 | "source": [
16 | "Welcome to the second project of the Machine Learning Engineer Nanodegree! In this notebook, some template code has already been provided for you, and you will need to implement additional functionality to successfully complete the project. You will not need to modify the included code beyond what is requested. Sections that begin with **'Exercise'** indicate that the following block of code will require additional functionality which you must provide. Each section comes with detailed instructions, and the parts you must implement are marked with **'TODO'** in the code comments. Please read all the hints carefully!\n",
17 | "\n",
18 | "In addition to implementing code, you **must** answer a number of questions related to the project and your implementation. Each question you need to answer is preceded by a **'Question X'** header. Read each question carefully and provide a complete answer in the **'Answer'** text box that follows. Your submission will be evaluated based on your answers to the questions and the functionality of your code.\n",
19 | "\n",
20 | ">**Note:** Code and Markdown cells can be executed using the **Shift + Enter** keyboard shortcut. In addition, Markdown cells can be edited by double-clicking the cell to enter edit mode."
21 | ]
22 | },
23 | {
24 | "cell_type": "markdown",
25 | "metadata": {},
26 | "source": [
27 | "### Question 1 - Classification vs. Regression\n",
28 | "*Your task in this project is to identify students who may fail to graduate if no intervention is given. Which type of supervised learning problem is this, classification or regression? Why?*"
29 | ]
30 | },
31 | {
32 | "cell_type": "markdown",
33 | "metadata": {},
34 | "source": [
35 | "**Answer:** This is a classification problem. The goal of the project is to identify students who may fail to graduate, so the output is whether or not a student graduates; that is, students are classified into those who graduate and those who do not. The output is discrete rather than continuous, so this is a classification problem."
36 | ]
37 | },
38 | {
39 | "cell_type": "markdown",
40 | "metadata": {},
41 | "source": [
42 | "## Exploring the Data\n",
43 | "Run the code cell below to load the student dataset, along with some Python libraries required for this project. Note that the last column of the dataset, `'passed'`, is our prediction target (it indicates whether a student graduated or not); the remaining columns are attributes of each student."
44 | ]
45 | },
46 | {
47 | "cell_type": "code",
48 | "execution_count": 4,
49 | "metadata": {
50 | "collapsed": false
51 | },
52 | "outputs": [
53 | {
54 | "name": "stdout",
55 | "output_type": "stream",
56 | "text": [
57 | "Student data read successfully!\n"
58 | ]
59 | }
60 | ],
61 | "source": [
62 | "# Import the required libraries\n",
63 | "import numpy as np\n",
64 | "import pandas as pd\n",
65 | "from time import time\n",
66 | "from sklearn.metrics import f1_score\n",
67 | "\n",
68 | "# Load the student dataset\n",
69 | "student_data = pd.read_csv(\"student-data.csv\")\n",
70 | "print \"Student data read successfully!\""
71 | ]
72 | },
73 | {
74 | "cell_type": "markdown",
75 | "metadata": {},
76 | "source": [
77 | "### Exercise: Exploring the Data\n",
78 | "Let's begin by investigating the dataset to determine how many students we have information on, and to learn the graduation rate among those students. In the code cell below, you will need to compute the following:\n",
79 | "- The total number of students, `n_students`.\n",
80 | "- The total number of features for each student, `n_features`.\n",
81 | "- The number of students who passed, `n_passed`.\n",
82 | "- The number of students who failed, `n_failed`.\n",
83 | "- The graduation rate of the class, `grad_rate`, as a percentage (%).\n"
84 | ]
85 | },
86 | {
87 | "cell_type": "code",
88 | "execution_count": 33,
89 | "metadata": {
90 | "collapsed": false
91 | },
92 | "outputs": [
93 | {
94 | "name": "stdout",
95 | "output_type": "stream",
96 | "text": [
97 | "Total number of students: 395\n",
98 | "Number of features: 30\n",
99 | "Number of students who passed: 265\n",
100 | "Number of students who failed: 130\n",
101 | "Graduation rate of the class: 67.09%\n"
102 | ]
103 | }
104 | ],
105 | "source": [
106 | "# TODO: Calculate the number of students\n",
107 | "n_students = len(student_data)\n",
108 | "\n",
109 | "# TODO: Calculate the number of features\n",
110 | "n_features = len(student_data.columns)-1\n",
111 | "\n",
112 | "# TODO: Calculate the number of students who passed\n",
113 | "n_passed = sum(student_data['passed']=='yes')\n",
114 | "\n",
115 | "# TODO: Calculate the number of students who failed\n",
116 | "n_failed = n_students-n_passed\n",
117 | "\n",
118 | "# TODO: Calculate the graduation rate (cast before dividing to avoid Python 2 integer division)\n",
119 | "grad_rate = float(n_passed)/n_students*100\n",
120 | "\n",
121 | "# Print the results\n",
122 | "print \"Total number of students: {}\".format(n_students)\n",
123 | "print \"Number of features: {}\".format(n_features)\n",
124 | "print \"Number of students who passed: {}\".format(n_passed)\n",
125 | "print \"Number of students who failed: {}\".format(n_failed)\n",
126 | "print \"Graduation rate of the class: {:.2f}%\".format(grad_rate)"
127 | ]
128 | },
129 | {
130 | "cell_type": "markdown",
131 | "metadata": {},
132 | "source": [
133 | "## Preparing the Data\n",
134 | "In this section, we will prepare the data for modeling, training, and testing.\n",
135 | "### Identify feature and target columns\n",
136 | "It is often the case that the data you obtain contains non-numeric features. This can be a problem, as most machine learning algorithms expect numeric features to perform computations with.\n",
137 | "\n",
138 | "Run the code cell below to separate the student data into feature and target columns, and to see whether any features are non-numeric."
139 | ]
140 | },
141 | {
142 | "cell_type": "code",
143 | "execution_count": 25,
144 | "metadata": {
145 | "collapsed": false
146 | },
147 | "outputs": [
148 | {
149 | "name": "stdout",
150 | "output_type": "stream",
151 | "text": [
152 | "Feature columns:\n",
153 | "['school', 'sex', 'age', 'address', 'famsize', 'Pstatus', 'Medu', 'Fedu', 'Mjob', 'Fjob', 'reason', 'guardian', 'traveltime', 'studytime', 'failures', 'schoolsup', 'famsup', 'paid', 'activities', 'nursery', 'higher', 'internet', 'romantic', 'famrel', 'freetime', 'goout', 'Dalc', 'Walc', 'health', 'absences']\n",
154 | "\n",
155 | "Target column: passed\n",
156 | "\n",
157 | "Feature values:\n",
158 | " school sex age address famsize Pstatus Medu Fedu Mjob Fjob \\\n",
159 | "0 GP F 18 U GT3 A 4 4 at_home teacher \n",
160 | "1 GP F 17 U GT3 T 1 1 at_home other \n",
161 | "2 GP F 15 U LE3 T 1 1 at_home other \n",
162 | "3 GP F 15 U GT3 T 4 2 health services \n",
163 | "4 GP F 16 U GT3 T 3 3 other other \n",
164 | "\n",
165 | " ... higher internet romantic famrel freetime goout Dalc Walc health \\\n",
166 | "0 ... yes no no 4 3 4 1 1 3 \n",
167 | "1 ... yes yes no 5 3 3 1 1 3 \n",
168 | "2 ... yes yes no 4 3 2 2 3 3 \n",
169 | "3 ... yes yes yes 3 2 2 1 1 5 \n",
170 | "4 ... yes no no 4 3 2 1 2 5 \n",
171 | "\n",
172 | " absences \n",
173 | "0 6 \n",
174 | "1 4 \n",
175 | "2 10 \n",
176 | "3 2 \n",
177 | "4 4 \n",
178 | "\n",
179 | "[5 rows x 30 columns]\n"
180 | ]
181 | }
182 | ],
183 | "source": [
184 | "# Extract feature columns\n",
185 | "feature_cols = list(student_data.columns[:-1])\n",
186 | "\n",
187 | "# Extract the target column 'passed'\n",
188 | "target_col = student_data.columns[-1] \n",
189 | "\n",
190 | "# Show the list of columns\n",
191 | "print \"Feature columns:\\n{}\".format(feature_cols)\n",
192 | "print \"\\nTarget column: {}\".format(target_col)\n",
193 | "\n",
194 | "# Separate the data into feature data and target data (i.e. X_all and y_all)\n",
195 | "X_all = student_data[feature_cols]\n",
196 | "y_all = student_data[target_col]\n",
197 | "\n",
198 | "# Show the feature information by printing the first five rows\n",
199 | "print \"\\nFeature values:\"\n",
200 | "print X_all.head()"
201 | ]
202 | },
203 | {
204 | "cell_type": "markdown",
205 | "metadata": {},
206 | "source": [
207 | "### Preprocess Feature Columns\n",
208 | "\n",
209 | "As you can see, there are several non-numeric columns that need to be converted! Many of them are simply `yes`/`no`, e.g. `internet`. These can be reasonably converted into `1`/`0` (binary) values.\n",
210 | "\n",
211 | "Other columns, like `Mjob` and `Fjob`, have more than two values and are known as _categorical variables_. The recommended way to handle such a column is to create as many columns as there are possible values (e.g. `Fjob_teacher`, `Fjob_other`, `Fjob_services`, etc.), and assign a `1` to one of them and `0` to the others.\n",
212 | "\n",
213 | "These generated columns are sometimes called _dummy variables_, and we will use the [`pandas.get_dummies()`](http://pandas.pydata.org/pandas-docs/stable/generated/pandas.get_dummies.html?highlight=get_dummies#pandas.get_dummies) function to perform this transformation. Run the code cell below to perform the preprocessing routine discussed here."
214 | ]
215 | },
216 | {
217 | "cell_type": "code",
218 | "execution_count": 26,
219 | "metadata": {
220 | "collapsed": false
221 | },
222 | "outputs": [
223 | {
224 | "name": "stdout",
225 | "output_type": "stream",
226 | "text": [
227 | "Processed feature columns (48 total features):\n",
228 | "['school_GP', 'school_MS', 'sex_F', 'sex_M', 'age', 'address_R', 'address_U', 'famsize_GT3', 'famsize_LE3', 'Pstatus_A', 'Pstatus_T', 'Medu', 'Fedu', 'Mjob_at_home', 'Mjob_health', 'Mjob_other', 'Mjob_services', 'Mjob_teacher', 'Fjob_at_home', 'Fjob_health', 'Fjob_other', 'Fjob_services', 'Fjob_teacher', 'reason_course', 'reason_home', 'reason_other', 'reason_reputation', 'guardian_father', 'guardian_mother', 'guardian_other', 'traveltime', 'studytime', 'failures', 'schoolsup', 'famsup', 'paid', 'activities', 'nursery', 'higher', 'internet', 'romantic', 'famrel', 'freetime', 'goout', 'Dalc', 'Walc', 'health', 'absences']\n"
229 | ]
230 | }
231 | ],
232 | "source": [
233 | "def preprocess_features(X):\n",
234 | "    ''' Preprocesses the student data: converts non-numeric binary features into binary (0/1) values and converts categorical variables into dummy variables\n",
235 | " '''\n",
236 | " \n",
237 | "    # Initialize a new output DataFrame\n",
238 | " output = pd.DataFrame(index = X.index)\n",
239 | "\n",
240 | "    # Investigate each feature column of the data\n",
241 | " for col, col_data in X.iteritems():\n",
242 | " \n",
243 | "        # If the data type is non-numeric, replace all yes/no values with 1/0\n",
244 | " if col_data.dtype == object:\n",
245 | " col_data = col_data.replace(['yes', 'no'], [1, 0])\n",
246 | "\n",
247 | "        # If the data type is categorical, convert it to dummy variables\n",
248 | " if col_data.dtype == object:\n",
249 | "            # Example: 'school' => 'school_GP' and 'school_MS'\n",
250 | " col_data = pd.get_dummies(col_data, prefix = col) \n",
251 | " \n",
252 | "        # Collect the converted columns\n",
253 | " output = output.join(col_data)\n",
254 | " \n",
255 | " return output\n",
256 | "\n",
257 | "X_all = preprocess_features(X_all)\n",
258 | "print \"Processed feature columns ({} total features):\\n{}\".format(len(X_all.columns), list(X_all.columns))"
259 | ]
260 | },
261 | {
262 | "cell_type": "markdown",
263 | "metadata": {},
264 | "source": [
265 | "### Implementation: Training and Testing Data Split\n",
266 | "So far, we have converted all _categorical_ features into numeric values. Next, we will split the data (both features and their corresponding labels) into training and test sets. In the code cell below, you will need to implement the following:\n",
267 | "- Randomly shuffle and split the data (`X_all`, `y_all`) into training and testing subsets.\n",
268 | " - Use 300 data points for the training set (approximately 76%) and 95 data points for the test set (approximately 24%).\n",
269 | " - Set a `random_state` for the function(s) you use, if provided.\n",
270 | " - Store the results in `X_train`, `X_test`, `y_train`, and `y_test`."
271 | ]
272 | },
273 | {
274 | "cell_type": "code",
275 | "execution_count": 44,
276 | "metadata": {
277 | "collapsed": false
278 | },
279 | "outputs": [
280 | {
281 | "name": "stdout",
282 | "output_type": "stream",
283 | "text": [
284 | "Training set has 300 samples.\n",
285 | "Testing set has 95 samples.\n"
286 | ]
287 | }
288 | ],
289 | "source": [
290 | "from sklearn.cross_validation import train_test_split\n",
291 | "\n",
292 | "# TODO: Set the number of training points\n",
293 | "num_train = 300\n",
294 | "\n",
295 | "# TODO: Set the number of testing points\n",
296 | "num_test = X_all.shape[0] - num_train\n",
297 | "\n",
298 | "# TODO: Shuffle and split the dataset into the training and testing sets defined above\n",
299 | "X_train, X_test, y_train, y_test = train_test_split(X_all, y_all, test_size=num_test, random_state=42)\n",
300 | "\n",
301 | "# Show the results of the split\n",
302 | "print \"Training set has {} samples.\".format(X_train.shape[0])\n",
303 | "print \"Testing set has {} samples.\".format(X_test.shape[0])"
304 | ]
305 | },
306 | {
307 | "cell_type": "markdown",
308 | "metadata": {},
309 | "source": [
310 | "## Training and Evaluating Models\n",
311 | "In this section, you will choose 3 supervised learning models that are appropriate for this problem and available in `scikit-learn`. You will first explain your reasoning for choosing these three models, considering the characteristics of the dataset as well as each model's strengths and weaknesses. You will then fit each model to different training set sizes (100 data points, 200 data points, and 300 data points) and measure the F1 score. You will need to produce three tables (one for each model) showing the training set size, training time, prediction time, F1 score on the training set, and F1 score on the test set.\n",
312 | "\n",
313 | "**The following are some of the supervised learning models currently available in** [`scikit-learn`](http://scikit-learn.org/stable/supervised_learning.html) **from which you may choose:**\n",
314 | "- Gaussian Naive Bayes (GaussianNB)\n",
315 | "- Decision Trees\n",
316 | "- Ensemble Methods (Bagging, AdaBoost, Random Forest, Gradient Boosting)\n",
317 | "- K-Nearest Neighbors (KNeighbors)\n",
318 | "- Stochastic Gradient Descent (SGDC)\n",
319 | "- Support Vector Machines (SVM)\n",
320 | "- Logistic Regression"
321 | ]
322 | },
323 | {
324 | "cell_type": "markdown",
325 | "metadata": {},
326 | "source": [
327 | "### Question 2 - Model Application\n",
328 | "*List three supervised learning models that are appropriate for this problem. For each model you choose:*\n",
329 | "\n",
330 | "- Describe one real-world application where the model is used. (You may need to do a bit of research for this, and cite your source.)\n",
331 | "- What are the strengths of the model, and when does it perform best?\n",
332 | "- What are the weaknesses of the model, and under what conditions does it perform poorly?\n",
333 | "- Given the characteristics of our dataset, why is this model a good candidate for the problem?"
334 | ]
335 | },
336 | {
337 | "cell_type": "markdown",
338 | "metadata": {},
339 | "source": [
340 | "**Answer:** I chose GaussianNB, Decision Trees, and AdaBoost.\n",
341 | "- GaussianNB: Real-world application: spam filtering (from Machine Learning in Action by Peter Harrington, section 4.6)\n",
342 | "- Strengths: it remains effective with small amounts of data and can handle multi-class problems. It performs best on document classification.\n",
343 | "- Weaknesses: it is sensitive to how the input data is prepared. It performs poorly when the input consists of multi-word phrases whose combined meaning differs markedly from that of the individual words.\n",
344 | "- Our dataset has a large number of features (30), and GaussianNB is well suited to data with many features.\n",
345 | "\n",
346 | "\n",
347 | "- Decision Tree: Real-world application: predicting contact lens type (from Machine Learning in Action by Peter Harrington, section 3.4)\n",
348 | "- Strengths: low computational cost, easy-to-interpret output, insensitivity to missing intermediate values, and the ability to handle irrelevant features. It performs best on simple boolean-style data.\n",
349 | "- Weaknesses: prone to overfitting. It performs poorly when it over-relies on the training data or when its parameters are poorly chosen.\n",
350 | "- Our data contains many boolean features, which suits a Decision Tree. Moreover, some features (such as nursery) may be only weakly related to our target (passed), and Decision Trees handle irrelevant features well.\n",
351 | "\n",
352 | "\n",
353 | "- AdaBoost: Real-world application: predicting whether a horse with colic will survive (from Machine Learning in Action by Peter Harrington, section 7.6)\n",
354 | "- Strengths: low generalization error, easy to implement, applicable to most classifiers, and few parameters to tune. It performs best at improving classifier performance by learning from errors.\n",
355 | "- Weaknesses: sensitive to outliers. It performs poorly when the input data contains many extreme values.\n",
356 | "- Our dataset has many features and is fairly complex; in successive iterations, the weights of misclassified samples are increased, and this ability to adapt to errors is exactly AdaBoost's strength."
357 | ]
358 | },
359 | {
360 | "cell_type": "markdown",
361 | "metadata": {},
362 | "source": [
363 | "### Setup\n",
364 | "Run the code cell below to initialize three helper functions for training and testing the three supervised learning models you chose above. The functions are:\n",
365 | "- `train_classifier` - takes as input a classifier and training data, and fits the classifier to the data.\n",
366 | "- `predict_labels` - takes as input a fit classifier, features, and a target labeling, makes predictions, and reports the F1 score.\n",
367 | "- `train_predict` - takes as input a classifier and the training and testing sets, and runs `train_classifier` and `predict_labels`.\n",
368 | " - This function reports the F1 score for the training set and the test set separately."
369 | ]
370 | },
371 | {
372 | "cell_type": "code",
373 | "execution_count": 45,
374 | "metadata": {
375 | "collapsed": false
376 | },
377 | "outputs": [],
378 | "source": [
379 | "def train_classifier(clf, X_train, y_train):\n",
380 | "    ''' Fits a classifier to the training data '''\n",
381 | " \n",
382 | "    # Start the clock, train the classifier, then stop the clock\n",
383 | " start = time()\n",
384 | " clf.fit(X_train, y_train)\n",
385 | " end = time()\n",
386 | " \n",
387 | " # Print the results\n",
388 | " print \"Trained model in {:.4f} seconds\".format(end - start)\n",
389 | "\n",
390 | " \n",
391 | "def predict_labels(clf, features, target):\n",
392 | "    ''' Makes predictions using a fit classifier and reports the F1 score '''\n",
393 | " \n",
394 | "    # Start the clock, make predictions, then stop the clock\n",
395 | " start = time()\n",
396 | " y_pred = clf.predict(features)\n",
397 | " end = time()\n",
398 | " \n",
399 | "    # Print and return the results\n",
400 | " print \"Made predictions in {:.4f} seconds.\".format(end - start)\n",
401 | " return f1_score(target.values, y_pred, pos_label='yes')\n",
402 | "\n",
403 | "\n",
404 | "def train_predict(clf, X_train, y_train, X_test, y_test):\n",
405 | "    ''' Trains and predicts using a classifier, and reports the F1 score '''\n",
406 | " \n",
407 | "    # Indicate the classifier and the training set size\n",
408 | " print \"Training a {} using a training set size of {}. . .\".format(clf.__class__.__name__, len(X_train))\n",
409 | " \n",
410 | "    # Train the classifier\n",
411 | " train_classifier(clf, X_train, y_train)\n",
412 | " \n",
413 | "    # Print the prediction results for training and testing\n",
414 | " print \"F1 score for training set: {:.4f}.\".format(predict_labels(clf, X_train, y_train))\n",
415 | " print \"F1 score for test set: {:.4f}.\".format(predict_labels(clf, X_test, y_test))"
416 | ]
417 | },
418 | {
419 | "cell_type": "markdown",
420 | "metadata": {},
421 | "source": [
422 | "### Exercise: Model Performance Metrics\n",
423 | "With the helper functions defined above, you will now import the three supervised learning models of your choice and run the `train_predict` function for each one. Remember that you will need to train and test each model on three different training set sizes (100, 200, and 300), so you should expect 9 different outputs below (three per model, one for each training set size). In the code cell below, you will need to implement the following:\n",
424 | "- Import the three supervised learning models you discussed above.\n",
425 | "- Initialize the three models and store them in `clf_A`, `clf_B`, and `clf_C`.\n",
426 | " - Set a `random_state` for each model, if possible.\n",
427 | " - **Note:** Use the default settings for each model here; you will tune one model's parameters in a later section.\n",
428 | "- Create the different training set sizes used to train each model.\n",
429 | " - *Do not reshuffle and resplit the data! The new training points should be drawn from `X_train` and `y_train`.*\n",
430 | "- Fit each model with each training set size, then test it on the test set (9 trainings and tests in total). \n",
431 | "**Note:** Three tables are provided after the following code cell which you can use to record your results."
432 | ]
433 | },
434 | {
435 | "cell_type": "code",
436 | "execution_count": 46,
437 | "metadata": {
438 | "collapsed": false
439 | },
440 | "outputs": [
441 | {
442 | "name": "stdout",
443 | "output_type": "stream",
444 | "text": [
445 | "Training a GaussianNB using a training set size of 100. . .\n",
446 | "Trained model in 0.0020 seconds\n",
447 | "Made predictions in 0.0000 seconds.\n",
448 | "F1 score for training set: 0.8467.\n",
449 | "Made predictions in 0.0010 seconds.\n",
450 | "F1 score for test set: 0.8029.\n",
451 | "Training a GaussianNB using a training set size of 200. . .\n",
452 | "Trained model in 0.0020 seconds\n",
453 | "Made predictions in 0.0010 seconds.\n",
454 | "F1 score for training set: 0.8406.\n",
455 | "Made predictions in 0.0010 seconds.\n",
456 | "F1 score for test set: 0.7244.\n",
457 | "Training a GaussianNB using a training set size of 300. . .\n",
458 | "Trained model in 0.0030 seconds\n",
459 | "Made predictions in 0.0020 seconds.\n",
460 | "F1 score for training set: 0.8038.\n",
461 | "Made predictions in 0.0010 seconds.\n",
462 | "F1 score for test set: 0.7634.\n",
463 | "Training a DecisionTreeClassifier using a training set size of 100. . .\n",
464 | "Trained model in 0.0020 seconds\n",
465 | "Made predictions in 0.0000 seconds.\n",
466 | "F1 score for training set: 1.0000.\n",
467 | "Made predictions in 0.0010 seconds.\n",
468 | "F1 score for test set: 0.6552.\n",
469 | "Training a DecisionTreeClassifier using a training set size of 200. . .\n",
470 | "Trained model in 0.0030 seconds\n",
471 | "Made predictions in 0.0000 seconds.\n",
472 | "F1 score for training set: 1.0000.\n",
473 | "Made predictions in 0.0000 seconds.\n",
474 | "F1 score for test set: 0.7500.\n",
475 | "Training a DecisionTreeClassifier using a training set size of 300. . .\n",
476 | "Trained model in 0.0040 seconds\n",
477 | "Made predictions in 0.0000 seconds.\n",
478 | "F1 score for training set: 1.0000.\n",
479 | "Made predictions in 0.0000 seconds.\n",
480 | "F1 score for test set: 0.6613.\n",
481 | "Training a AdaBoostClassifier using a training set size of 100. . .\n",
482 | "Trained model in 0.1430 seconds\n",
483 | "Made predictions in 0.0310 seconds.\n",
484 | "F1 score for training set: 0.9481.\n",
485 | "Made predictions in 0.0160 seconds.\n",
486 | "F1 score for test set: 0.7669.\n",
487 | "Training a AdaBoostClassifier using a training set size of 200. . .\n",
488 | "Trained model in 0.2020 seconds\n",
489 | "Made predictions in 0.0080 seconds.\n",
490 | "F1 score for training set: 0.8927.\n",
491 | "Made predictions in 0.0070 seconds.\n",
492 | "F1 score for test set: 0.8281.\n",
493 | "Training a AdaBoostClassifier using a training set size of 300. . .\n",
494 | "Trained model in 0.1660 seconds\n",
495 | "Made predictions in 0.0000 seconds.\n",
496 | "F1 score for training set: 0.8637.\n",
497 | "Made predictions in 0.0150 seconds.\n",
498 | "F1 score for test set: 0.7820.\n"
499 | ]
500 | }
501 | ],
502 | "source": [
503 | "# TODO: Import the three supervised learning models from sklearn\n",
504 | "from sklearn.naive_bayes import GaussianNB\n",
505 | "from sklearn.tree import DecisionTreeClassifier\n",
506 | "from sklearn.ensemble import AdaBoostClassifier\n",
507 | "\n",
508 | "# TODO: Initialize the three models\n",
509 | "clf_A = GaussianNB()\n",
510 | "clf_B = DecisionTreeClassifier(random_state=42)\n",
511 | "clf_C = AdaBoostClassifier(random_state=42)\n",
512 | "\n",
513 | "# TODO: Set up the training set sizes\n",
514 | "X_train_100 = X_train[0:100]\n",
515 | "y_train_100 = y_train[0:100]\n",
516 | "\n",
517 | "X_train_200 = X_train[0:200]\n",
518 | "y_train_200 = y_train[0:200]\n",
519 | "\n",
520 | "X_train_300 = X_train\n",
521 | "y_train_300 = y_train\n",
522 | "\n",
523 | "# TODO: Run 'train_predict' for each classifier and each training set size\n",
524 | "for clf in [clf_A, clf_B, clf_C]:\n",
525 | "    for size in [100, 200, 300]:\n",
526 | "        train_predict(clf, X_train[:size], y_train[:size], X_test, y_test)"
527 | ]
528 | },
529 | {
530 | "cell_type": "markdown",
531 | "metadata": {},
532 | "source": [
533 | "### Tabular Results\n",
534 | "Edit the cell below to see how a table can be designed in [Markdown](https://github.com/adam-p/markdown-here/wiki/Markdown-Cheatsheet#tables). You can record your results from above in the tables provided."
535 | ]
536 | },
537 | {
538 | "cell_type": "markdown",
539 | "metadata": {},
540 | "source": [
541 | "**Classifier 1 - GaussianNB**\n",
542 | "\n",
543 | "| Training Set Size | Training Time | Prediction Time (test) | F1 Score (train) | F1 Score (test) |\n",
544 | "| :---------------: | :-----------: | :--------------------: | :--------------: | :-------------: |\n",
545 | "| 100 | 0.0010 | 0.0010 | 0.8467 | 0.8029 |\n",
546 | "| 200 | 0.0020 | 0.0010 | 0.8102 | 0.7258 |\n",
547 | "| 300 | 0.0020 | 0.0010 | 0.8038 | 0.7634 |\n",
548 | "\n",
549 | "**Classifier 2 - DecisionTree**\n",
550 | "\n",
551 | "| Training Set Size | Training Time | Prediction Time (test) | F1 Score (train) | F1 Score (test) |\n",
552 | "| :---------------: | :-----------: | :--------------------: | :--------------: | :-------------: |\n",
553 | "| 100 | 0.0010 | 0.0010 | 1 | 0.6552 |\n",
554 | "| 200 | 0.0030 | 0.0010 | 1 | 0.7031 |\n",
555 | "| 300 | 0.0470 | 0.0010 | 1 | 0.6631 |\n",
556 | "\n",
557 | "**Classifier 3 - AdaBoost**\n",
558 | "\n",
559 | "| Training Set Size | Training Time | Prediction Time (test) | F1 Score (train) | F1 Score (test) |\n",
560 | "| :---------------: | :-----------: | :--------------------: | :--------------: | :-------------: |\n",
561 | "| 100 | 0.2270 | 0.0080 | 0.9481 | 0.7669 |\n",
562 | "| 200 | 0.1970 | 0.0230 | 0.8836 | 0.7344 |\n",
563 | "| 300 | 0.2610 | 0.0100 | 0.8637 | 0.7820 |"
564 | ]
565 | },
566 | {
567 | "cell_type": "markdown",
568 | "metadata": {},
569 | "source": [
570 | "## Choosing the Best Model\n",
571 | "In this final section, you will choose from the three supervised learning models the *best* model to use on the student data. You will then perform a grid search optimization for the model over the entire training set (`X_train` and `y_train`), tuning at least one parameter to improve upon the untuned model's F1 score."
572 | ]
573 | },
574 | {
575 | "cell_type": "markdown",
576 | "metadata": {},
577 | "source": [
578 | "### Question 3 - Choosing the Best Model\n",
579 | "*Based on the experiments you performed above, in one to two paragraphs, explain to the board of supervisors which model you chose as the best. Which model is the best choice overall, given the available data, limited resources, cost, and performance?*"
580 | ]
581 | },
582 | {
583 | "cell_type": "markdown",
584 | "metadata": {},
585 | "source": [
586 | "**Answer:** Since you have asked me to build the student intervention system, I am accountable to you, so I evaluated three models: GaussianNB, DecisionTree, and AdaBoost. Based on my evaluation, I believe AdaBoost is the best model.\n",
587 | "\n",
588 | "My reasoning is that AdaBoost achieved the best F1 score among the three models (although DecisionTree scored 1 on the training set, its test score was too low, indicating overfitting). While AdaBoost's training time is relatively long, its prediction (i.e. query) time is short. Once the model is trained (which, given the size of the dataset, will not take long), the main workload afterwards is querying, which will not consume excessive resources or cost, so I would still choose AdaBoost."
589 | ]
590 | },
591 | {
592 | "cell_type": "markdown",
593 | "metadata": {},
594 | "source": [
595 | "### Question 4 - Model in Layman's Terms\n",
596 | "*In one to two paragraphs, explain to the board of supervisors, in terms a layperson can understand, how the final model works. You will need to describe the major qualities of the chosen model, for example how it is trained and how it makes predictions. Avoid advanced mathematical or technical jargon; do not use formulas or algorithm-specific terminology.*"
597 | ]
598 | },
599 | {
600 | "cell_type": "markdown",
601 | "metadata": {},
602 | "source": [
603 | "Let's use our own data as an example. Suppose we have 10 teachers who judge, from each student's information (such as sex, age, and family situation), whether that student will pass the exam. In the first round of predictions, each of the 10 teachers casts one vote on whether a student will pass. In the second round, we take into account each teacher's accuracy in the first round: perhaps some teachers focused on family situation, while others assumed boys pass more often than girls. Teachers who were more accurate in the first round are given more votes (say, 5 each), while teachers with low accuracy are no longer allowed to vote at all. This is what we call weighting. Repeating this process over many rounds of prediction eventually produces the best-performing combination of predictions, and that is AdaBoost."
604 | ]
605 | },
606 | {
607 | "cell_type": "markdown",
608 | "metadata": {},
609 | "source": [
610 | "### Exercise: Model Tuning\n",
611 | "Fine-tune the chosen model. Use grid search (`GridSearchCV`) to tune at least one important parameter of the model, trying at least 3 different values for it. You will need to use the entire training set for this process. In the code cell below, you will need to implement the following:\n",
612 | "- Import [`sklearn.grid_search.GridSearchCV`](http://scikit-learn.org/stable/modules/generated/sklearn.grid_search.GridSearchCV.html) and [`sklearn.metrics.make_scorer`](http://scikit-learn.org/stable/modules/generated/sklearn.metrics.make_scorer.html).\n",
613 | "- Create a dictionary of the parameters you wish to tune for the chosen model.\n",
614 | " - Example: `parameters = {'parameter' : [list of values]}`.\n",
615 | "- Initialize the classifier you chose and store it in `clf`.\n",
616 | "- Create the F1 scoring function using `make_scorer` and store it in `f1_scorer`.\n",
617 | " - Be sure to set the `pos_label` parameter to the correct value!\n",
618 | "- Perform a grid search on the classifier `clf` using `f1_scorer` as the scoring method, and store the result in `grid_obj`.\n",
619 | "- Fit the grid search object to the training data (`X_train`, `y_train`), and store the result in `grid_obj`."
620 | ]
621 | },
622 | {
623 | "cell_type": "code",
624 | "execution_count": 47,
625 | "metadata": {
626 | "collapsed": false
627 | },
628 | "outputs": [
629 | {
630 | "name": "stdout",
631 | "output_type": "stream",
632 | "text": [
633 | "Made predictions in 0.0000 seconds.\n",
634 | "Tuned model has a training F1 score of 0.8299.\n",
635 | "Made predictions in 0.0000 seconds.\n",
636 | "Tuned model has a testing F1 score of 0.8000.\n"
637 | ]
638 | }
639 | ],
640 | "source": [
641 | "# TODO: Import 'GridSearchCV' and 'make_scorer'\n",
642 | "\n",
643 | "from sklearn.grid_search import GridSearchCV\n",
644 | "from sklearn.metrics import make_scorer\n",
645 | "\n",
646 | "\n",
647 | "# TODO: Create the list of parameters you wish to tune\n",
648 | "parameters = {'n_estimators': [1, 25, 50, 75, 100]}\n",
649 | "\n",
650 | "# TODO: Initialize the classifier\n",
651 | "clf = AdaBoostClassifier()\n",
652 | "\n",
653 | "# TODO: Create an F1 scoring function using 'make_scorer'\n",
654 | "f1_scorer = make_scorer(f1_score, pos_label='yes')\n",
655 | "\n",
656 | "# TODO: Perform grid search on the classifier using f1_scorer as the scoring method\n",
657 | "grid_obj = GridSearchCV(clf, parameters, scoring=f1_scorer)\n",
658 | "\n",
659 | "# TODO: Fit the grid search object to the training data and find the optimal parameters\n",
661 | "grid_obj = grid_obj.fit(X_train, y_train)\n",
662 | "\n",
663 | "# Get the best estimator\n",
665 | "clf = grid_obj.best_estimator_\n",
666 | "\n",
667 | "# Report the final F1 score for training and testing after parameter tuning\n",
669 | "print \"Tuned model has a training F1 score of {:.4f}.\".format(predict_labels(clf, X_train, y_train))\n",
670 | "print \"Tuned model has a testing F1 score of {:.4f}.\".format(predict_labels(clf, X_test, y_test))"
671 | ]
672 | },
673 | {
674 | "cell_type": "markdown",
675 | "metadata": {},
676 | "source": [
677 | "### Question 5 - Final F1 Score\n",
678 | "*What is the final model's F1 score for training and testing? How does that score compare to the untuned model?*"
679 | ]
680 | },
681 | {
682 | "cell_type": "markdown",
683 | "metadata": {},
684 | "source": [
685 | "**Answer:** The final model's F1 scores for training and testing are 0.83 and 0.80. Compared with the untuned model's 0.86 and 0.78, the training score dropped slightly, but the test score improved, so the tuned model performs reasonably well."
686 | ]
687 | },
688 | {
689 | "cell_type": "markdown",
690 | "metadata": {},
691 | "source": [
692 | "> **Note**: Once you have completed all of the code and answered all of the questions above, you can export the iPython Notebook as an HTML file. You can do this from the menu bar via **File -> Download as -> HTML (.html)**. Submit both the HTML file and the iPython notebook as your assignment."
693 | ]
694 | }
695 | ],
696 | "metadata": {
697 | "anaconda-cloud": {},
698 | "kernelspec": {
699 | "display_name": "Python [default]",
700 | "language": "python",
701 | "name": "python2"
702 | },
703 | "language_info": {
704 | "codemirror_mode": {
705 | "name": "ipython",
706 | "version": 2
707 | },
708 | "file_extension": ".py",
709 | "mimetype": "text/x-python",
710 | "name": "python",
711 | "nbconvert_exporter": "python",
712 | "pygments_lexer": "ipython2",
713 | "version": "2.7.12"
714 | }
715 | },
716 | "nbformat": 4,
717 | "nbformat_minor": 0
718 | }
719 |
--------------------------------------------------------------------------------
/student_intervention/student_intervention.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/student_intervention/student_intervention.zip
--------------------------------------------------------------------------------
/student_intervention/student_intervention2.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/student_intervention/student_intervention2.zip
--------------------------------------------------------------------------------
/titanic_survival_exploration/README.md:
--------------------------------------------------------------------------------
1 | # Project 0: Introduction and Foundations
2 | ## Predicting Titanic Passenger Survival
3 |
4 | ### Installation Requirements
5 | This project requires **Python 2.7** and the following Python libraries:
6 |
7 | - [NumPy](http://www.numpy.org/)
8 | - [Pandas](http://pandas.pydata.org)
9 | - [matplotlib](http://matplotlib.org/)
10 | - [scikit-learn](http://scikit-learn.org/stable/)
11 |
12 |
13 | You will also need to install and run [Jupyter Notebook](http://jupyter.readthedocs.io/en/latest/install.html#optional-for-experienced-python-developers-installing-jupyter-with-pip).
14 |
15 |
16 | Udacity recommends that students install [Anaconda](https://www.continuum.io/downloads), a Python distribution that includes all of the libraries and software needed for this project. Instructions for installing Anaconda can be found [here](https://classroom.udacity.com/nanodegrees/nd002/parts/0021345403/modules/317671873575460/lessons/5430778793/concepts/54140889150923).
17 |
18 | If you are on macOS and comfortable with the command line, you can instead install [homebrew](http://brew.sh/) and the brew version of Python:
19 |
20 | ```bash
21 | $ brew install python
22 | ```
23 |
24 | Then install the required Python libraries with the following command:
25 |
26 | ```bash
27 | $ pip install numpy pandas matplotlib scikit-learn scipy jupyter
28 | ```
29 |
30 | ### Code
31 |
32 | The example code is in the `titanic_survival_exploration.ipynb` file, and the helper code is in the `titanic_visualizations.py` file. Although some code has been provided to help you get started, you will need to supply additional code to successfully implement the required functionality.
33 |
34 | ### Run
35 |
36 | In a terminal, make sure your current directory is the top level of the `titanic_survival_exploration/` folder (the one containing this README), then run the following command:
37 |
38 | ```bash
39 | $ jupyter notebook titanic_survival_exploration.ipynb
40 | ```
41 |
42 | This launches Jupyter Notebook and opens the project file in your browser.
43 |
44 | If you are not familiar with Jupyter, these two links may help:
45 |
46 | - [Jupyter video tutorial](http://cn-static.udacity.com/mlnd/how_to_use_jupyter.mp4)
47 | - [Why use Jupyter?](https://www.zhihu.com/question/37490497)
48 |
63 | ### Data
64 |
65 | The data for this project is contained in `titanic_data.csv`. The file includes the following features:
66 |
67 | - **Survived**: whether the passenger survived (0 = no, 1 = yes)
68 | - **Pclass**: socio-economic class (1 = upper class, 2 = middle class, 3 = lower class)
69 | - **Name**: the passenger's name
70 | - **Sex**: the passenger's sex
71 | - **Age**: the passenger's age (may be `NaN`)
72 | - **SibSp**: number of siblings and spouses the passenger had aboard
73 | - **Parch**: number of parents and children the passenger had aboard
74 | - **Ticket**: the passenger's ticket number
75 | - **Fare**: the fare the passenger paid
76 | - **Cabin**: the passenger's cabin number (may be `NaN`)
77 | - **Embarked**: the passenger's port of embarkation (C = Cherbourg, Q = Queenstown, S = Southampton)
78 |
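The features above can be explored directly with pandas before opening the notebook. A minimal sketch (using a small hand-made sample in place of `titanic_data.csv`, so it runs anywhere):

```python
import pandas as pd

# Tiny hand-made sample with the same columns as titanic_data.csv;
# in the project itself you would load the real file with
# full_data = pd.read_csv('titanic_data.csv')
full_data = pd.DataFrame({
    'Survived': [0, 1, 1, 0],
    'Pclass':   [3, 1, 2, 3],
    'Sex':      ['male', 'female', 'female', 'male'],
    'Age':      [22.0, 38.0, 4.0, None],   # 'Age' may contain NaN
})

print(full_data['Survived'].mean())                 # overall survival rate: 0.5
print(full_data.groupby('Sex')['Survived'].mean())  # survival rate by sex
```

`groupby` is the quickest way to compare survival rates across any of the categorical features listed above.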
--------------------------------------------------------------------------------
/titanic_survival_exploration/titanic_survival_exploration.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ciozhang/machinelearning-deeplearning-project/d6e8d481990efcfb2ff9861f04c161a29c42cb1b/titanic_survival_exploration/titanic_survival_exploration.zip
--------------------------------------------------------------------------------
/titanic_survival_exploration/titanic_visualizations.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import pandas as pd
3 | import matplotlib.pyplot as plt
4 |
5 | def filter_data(data, condition):
6 | """
7 |     Remove rows that do not match the condition provided.
8 |     Takes a pandas DataFrame as input and returns a filtered DataFrame.
9 |     Conditions should be strings of the following format:
10 |     '<field> <op> <value>'
11 |     where the following operations are valid: >, <, >=, <=, ==, !=
12 |
13 | Example: ["Sex == 'male'", 'Age < 18']
14 | """
15 |
16 | field, op, value = condition.split(" ")
17 |
18 | # convert value into number or strip excess quotes if string
19 | try:
20 | value = float(value)
21 |     except ValueError:
22 | value = value.strip("\'\"")
23 |
24 | # get booleans for filtering
25 | if op == ">":
26 | matches = data[field] > value
27 | elif op == "<":
28 | matches = data[field] < value
29 | elif op == ">=":
30 | matches = data[field] >= value
31 | elif op == "<=":
32 | matches = data[field] <= value
33 | elif op == "==":
34 | matches = data[field] == value
35 | elif op == "!=":
36 | matches = data[field] != value
37 | else: # catch invalid operation codes
38 | raise Exception("Invalid comparison operator. Only >, <, >=, <=, ==, != allowed.")
39 |
40 | # filter data and outcomes
41 | data = data[matches].reset_index(drop = True)
42 | return data
43 |
44 | def survival_stats(data, outcomes, key, filters = []):
45 | """
46 | Print out selected statistics regarding survival, given a feature of
47 | interest and any number of filters (including no filters)
48 | """
49 |
50 | # Check that the key exists
51 |     if key not in data.columns.values:
52 | print "'{}' is not a feature of the Titanic data. Did you spell something wrong?".format(key)
53 | return False
54 |
55 |     # Return before visualizing if 'Cabin', 'PassengerId', or 'Ticket'
56 |     # is selected: too many unique categories to display
57 | if(key == 'Cabin' or key == 'PassengerId' or key == 'Ticket'):
58 | print "'{}' has too many unique categories to display! Try a different feature.".format(key)
59 | return False
60 |
61 | # Merge data and outcomes into single dataframe
62 | all_data = pd.concat([data, outcomes], axis = 1)
63 |
64 | # Apply filters to data
65 | for condition in filters:
66 | all_data = filter_data(all_data, condition)
67 |
68 | # Create outcomes DataFrame
69 | all_data = all_data[[key, 'Survived']]
70 |
71 | # Create plotting figure
72 | plt.figure(figsize=(8,6))
73 |
74 | # 'Numerical' features
75 | if(key == 'Age' or key == 'Fare'):
76 |
77 |         # Report and then remove passengers with missing values here,
78 |         # since the missing-value check after plotting only sees cleaned data
79 |         nan_outcomes = all_data[pd.isnull(all_data[key])]['Survived']
80 |         if len(nan_outcomes):
81 |             print "Passengers with missing '{}' values: {} ({} survived, {} did not survive)".format( \
82 |                 key, len(nan_outcomes), sum(nan_outcomes == 1), sum(nan_outcomes == 0))
83 |         all_data = all_data[~np.isnan(all_data[key])]
79 |
80 | # Divide the range of data into bins and count survival rates
81 | min_value = all_data[key].min()
82 | max_value = all_data[key].max()
83 | value_range = max_value - min_value
84 |
85 | # 'Fares' has larger range of values than 'Age' so create more bins
86 | if(key == 'Fare'):
87 | bins = np.arange(0, all_data['Fare'].max() + 20, 20)
88 | if(key == 'Age'):
89 | bins = np.arange(0, all_data['Age'].max() + 10, 10)
90 |
91 | # Overlay each bin's survival rates
92 | nonsurv_vals = all_data[all_data['Survived'] == 0][key].reset_index(drop = True)
93 | surv_vals = all_data[all_data['Survived'] == 1][key].reset_index(drop = True)
94 | plt.hist(nonsurv_vals, bins = bins, alpha = 0.6,
95 | color = 'red', label = 'Did not survive')
96 | plt.hist(surv_vals, bins = bins, alpha = 0.6,
97 | color = 'green', label = 'Survived')
98 |
99 | # Add legend to plot
100 | plt.xlim(0, bins.max())
101 | plt.legend(framealpha = 0.8)
102 |
103 | # 'Categorical' features
104 | else:
105 |
106 | # Set the various categories
107 | if(key == 'Pclass'):
108 | values = np.arange(1,4)
109 | if(key == 'Parch' or key == 'SibSp'):
110 | values = np.arange(0,np.max(data[key]) + 1)
111 | if(key == 'Embarked'):
112 | values = ['C', 'Q', 'S']
113 | if(key == 'Sex'):
114 | values = ['male', 'female']
115 |
116 | # Create DataFrame containing categories and count of each
117 | frame = pd.DataFrame(index = np.arange(len(values)), columns=(key,'Survived','NSurvived'))
118 | for i, value in enumerate(values):
119 | frame.loc[i] = [value, \
120 | len(all_data[(all_data['Survived'] == 1) & (all_data[key] == value)]), \
121 | len(all_data[(all_data['Survived'] == 0) & (all_data[key] == value)])]
122 |
123 | # Set the width of each bar
124 | bar_width = 0.4
125 |
126 | # Display each category's survival rates
127 | for i in np.arange(len(frame)):
128 | nonsurv_bar = plt.bar(i-bar_width, frame.loc[i]['NSurvived'], width = bar_width, color = 'r')
129 | surv_bar = plt.bar(i, frame.loc[i]['Survived'], width = bar_width, color = 'g')
130 |
131 | plt.xticks(np.arange(len(frame)), values)
132 | plt.legend((nonsurv_bar[0], surv_bar[0]),('Did not survive', 'Survived'), framealpha = 0.8)
133 |
134 | # Common attributes for plot formatting
135 | plt.xlabel(key)
136 | plt.ylabel('Number of Passengers')
137 | plt.title('Passenger Survival Statistics With \'%s\' Feature'%(key))
138 | plt.show()
139 |
140 | # Report number of passengers with missing values
141 | if sum(pd.isnull(all_data[key])):
142 | nan_outcomes = all_data[pd.isnull(all_data[key])]['Survived']
143 | print "Passengers with missing '{}' values: {} ({} survived, {} did not survive)".format( \
144 | key, len(nan_outcomes), sum(nan_outcomes == 1), sum(nan_outcomes == 0))
145 |
146 |
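The `'<field> <op> <value>'` condition grammar used by `filter_data` is a thin wrapper over pandas boolean indexing. As a self-contained sketch of the same idea, the if/elif chain can be replaced by a lookup table of comparison functions from the stdlib `operator` module (hand-made sample data and a hypothetical `apply_condition` helper, not part of this module):

```python
import operator
import pandas as pd

# Hand-made sample mirroring a slice of the Titanic schema
df = pd.DataFrame({
    'Sex':      ['male', 'female', 'female', 'male'],
    'Age':      [22.0, 38.0, 4.0, 35.0],
    'Survived': [0, 1, 1, 0],
})

# Map operator tokens to comparison functions instead of an if/elif chain
OPS = {'>': operator.gt, '<': operator.lt, '>=': operator.ge,
       '<=': operator.le, '==': operator.eq, '!=': operator.ne}

def apply_condition(data, condition):
    """Filter rows using the same '<field> <op> <value>' grammar as filter_data."""
    field, op, value = condition.split(" ")
    try:
        value = float(value)            # numeric comparison if possible
    except ValueError:
        value = value.strip("'\"")      # otherwise strip excess quotes
    return data[OPS[op](data[field], value)].reset_index(drop=True)

filtered = apply_condition(df, "Sex == 'male'")
filtered = apply_condition(filtered, 'Age < 30')
print(filtered)   # only the 22-year-old male remains
```

The lookup-table form raises a `KeyError` for an unknown operator token, where the original raises `Exception` explicitly; both reject invalid conditions.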
--------------------------------------------------------------------------------