├── .gitignore
├── README.md
├── algorithm
│   ├── DecisionTree.py
│   └── __init__.py
├── data
│   ├── Iris.csv
│   └── watermelon.csv
├── example
│   ├── __init__.py
│   ├── decision_tree_testes.py
│   └── sample_testes.py
├── fetch
│   ├── __init__.py
│   └── fetch_house.py
├── logistic
│   ├── __init__.py
│   └── logistic.py
├── match
│   ├── .ipynb_checkpoints
│   │   ├── cheetsheet-checkpoint.ipynb
│   │   ├── explore_product-checkpoint.ipynb
│   │   ├── fraud_detection-Copy1-checkpoint.ipynb
│   │   ├── fraud_detection-checkpoint.ipynb
│   │   ├── instacart-checkpoint.ipynb
│   │   ├── instacart_ffm-checkpoint.ipynb
│   │   ├── instacart_gbm-checkpoint.ipynb
│   │   ├── instacart_lr-checkpoint.ipynb
│   │   ├── instacart_match-checkpoint.ipynb
│   │   ├── instacart_running-checkpoint.ipynb
│   │   ├── instcart_predict-checkpoint.ipynb
│   │   └── titanic-checkpoint.ipynb
│   ├── __init__.py
│   ├── cheetsheet.ipynb
│   ├── dstool
│   │   └── __init__.py
│   ├── explore_product.ipynb
│   ├── fraud_detection.ipynb
│   ├── instacart.ipynb
│   ├── instacart
│   │   ├── .ipynb_checkpoints
│   │   │   └── tunning-checkpoint.ipynb
│   │   ├── __init__.py
│   │   ├── const.py
│   │   ├── data
│   │   │   └── __init__.py
│   │   ├── eda.py
│   │   ├── generate.py
│   │   ├── run.py
│   │   ├── sample.py
│   │   ├── tunning.ipynb
│   │   ├── tunning.py
│   │   └── utils.py
│   ├── instacart_ffm.ipynb
│   ├── instacart_gbm.ipynb
│   ├── instacart_lr.ipynb
│   ├── instacart_match.ipynb
│   ├── instacart_running.ipynb
│   ├── instcart_predict.ipynb
│   └── titanic.ipynb
├── network
│   └── __init__.py
├── preproccessing
│   ├── StandardScaler.py
│   └── __init__.py
├── profile
│   ├── __init__.py
│   ├── objgraph.py
│   └── profile.py
├── py_lightgbm
│   ├── __init__.py
│   ├── application
│   │   ├── __init__.py
│   │   └── classifier.py
│   ├── boosting
│   │   ├── __init__.py
│   │   ├── boosting.py
│   │   └── gbdt.py
│   ├── config
│   │   ├── __init__.py
│   │   └── tree_config.py
│   ├── io
│   │   ├── __init__.py
│   │   ├── bin.py
│   │   └── dataset.py
│   ├── logmanager
│   │   ├── __init__.py
│   │   └── logger.py
│   ├── metric
│   │   └── __init__.py
│   ├── objective
│   │   ├── __init__.py
│   │   └── objective_function.py
│   ├── testes
│   │   ├── .ipynb_checkpoints
│   │   │   └── test_xgboost-checkpoint.ipynb
│   │   ├── __init__.py
│   │   ├── test_dataset.py
│   │   ├── test_lightgbm.py
│   │   ├── test_py_lightgbm.py
│   │   └── test_xgboost.ipynb
│   ├── tree
│   │   ├── __init__.py
│   │   ├── data_partition.py
│   │   ├── feature_histogram.py
│   │   ├── leaf_splits.py
│   │   ├── split_info.py
│   │   ├── tree.py
│   │   └── tree_learner.py
│   └── utils
│       ├── __init__.py
│       ├── conf.py
│       └── const.py
├── testes
│   ├── __init__.py
│   ├── decision_tree_test.py
│   └── logistic_test.py
├── tree
│   ├── DecisionTreeClassifier.py
│   └── __init__.py
└── utils
    ├── __init__.py
    ├── cmd_table.py
    ├── const.py
    ├── file_utils.py
    ├── formula.py
    ├── logger.py
    └── sample.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea/*
2 | *.xml
3 | .DS_Store
4 | *.pyc
5 | *.dot
6 | *.pdf
7 | data/*
8 | data/mnist_test.csv
9 | data/mnist_train.csv
10 | *.txt
11 | *.doc
12 | match/fraud/creditcard.csv
13 | match/instacart/data/*
14 | *.prof
15 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | ### README
2 |
3 | #### data format
4 |
5 | Support both np.array and list, while saving space.
6 | target we can store `target` using one C-like array
7 |
8 | data: choose size=(num, num_feature)
9 | target: choose size=(1, num)
10 |
11 | #### Complexity Analysis
--------------------------------------------------------------------------------
/algorithm/DecisionTree.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: DecisionTree.py
8 | @time: 2017/3/18 12:57
9 | @change_time:
10 | 1.2017/3/18 12:57 Make a simple decision tree
11 | """
12 | from utils import logger
13 | from collections import deque
14 | from graphviz import Digraph
15 | from collections import defaultdict
16 | from utils import formula
17 | import random
18 | from collections import Counter
19 |
20 |
21 | class DecisionNode(object):
22 | """
23 | """
24 |
25 | def __init__(self, parent=None):
26 | self.parent = parent
27 | self._children = []
28 | self._decide_info = []
29 | return
30 |
31 | def is_leaf(self):
32 | return False
33 |
34 | def is_decide(self):
35 | return False
36 |
37 | def get_decide_func(self):
38 | if not self.parent:
39 |             logger.error("node has no parent")
40 | return None
41 |
42 | return self.parent.get_child_decide_func(self)
43 |
44 | def index(self, child):
45 | return self.children.index(child)
46 |
47 | @property
48 | def children(self):
49 | return self._children
50 |
51 | @property
52 | def decide_info(self):
53 | return self._decide_info
54 |
55 | def add_node(self, node, func):
56 | self.children.append(node)
57 | self._decide_info.append(func)
58 | return
59 |
60 |
61 | class DecideNode(DecisionNode):
62 | def __init__(self, parent, decide_index=None):
63 | super(DecideNode, self).__init__(parent=parent)
64 | self._decide_index = decide_index
65 | return
66 |
67 | def decide(self, item):
68 | feature = item[self.decide_index]
69 |
70 | for index, child in enumerate(self.children):
71 | func = self.decide_info[index]
72 | if func(feature):
73 | if child.is_leaf():
74 | return child.category
75 | else:
76 | return child.decide(item)
77 | return None
78 |
79 | @property
80 | def decide_index(self):
81 | return self._decide_index
82 |
83 | @decide_index.setter
84 | def decide_index(self, value):
85 | self._decide_index = value
86 | return
87 |
88 | def is_decide(self):
89 | return True
90 |
91 | def get_child_decide_func(self, child):
92 | func_index = self.index(child)
93 | if func_index is None:
94 | return None
95 | return self._decide_info[func_index]
96 |
97 | def __repr__(self):
98 | if not self.parent:
99 | return "R index({0})".format(self.decide_index)
100 |
101 | return "D index({0})".format(self.decide_index)
102 |
103 |
104 | class LeafNode(DecisionNode):
105 |
106 | def __init__(self, parent, category=None):
107 | super(LeafNode, self).__init__(parent)
108 | self._category = category
109 | self._num = 1
110 | return
111 |
112 | def is_leaf(self):
113 | return True
114 |
115 | @property
116 | def num(self):
117 | return self._num
118 |
119 | @property
120 | def category(self):
121 | return self._category
122 |
123 | @category.setter
124 | def category(self, value):
125 | self._category = value
126 | return
127 |
128 | @num.setter
129 | def num(self, value):
130 | self._num = value
131 | return
132 |
133 | def __repr__(self):
134 | return "L category({0}), num({1})".format(self.category, self.num)
135 |
136 |
137 | def choose_random(data, target, data_index_list, index_left_list):
138 |     if not index_left_list:  # guard first: random.sample raises ValueError on an empty population
139 |         logger.error("failed to choose sample for index_left_list({0})".format(index_left_list))
140 |         return None
141 |     sample = random.sample(index_left_list, 1)
142 |     return sample[0]
143 |
144 |
145 | def choose_information_entropy(data, target, data_index_list, index_left_list):
146 | """
147 |     choose an index that minimizes the information entropy
148 | """
149 | min_entropy = float("inf")
150 | result_index = None
151 |
152 | for index in index_left_list:
153 | target_dict = defaultdict(list)
154 | for index_list in data_index_list:
155 | target_dict[data[index_list][index]].append(index_list)
156 |
157 | entropy = 0
158 | for value, index_list in target_dict.iteritems():
159 | entropy_part = formula.calculate_entropy(target, index_list)
160 | entropy += len(index_list) * entropy_part / float(len(data_index_list))
161 |
162 | if entropy < min_entropy:
163 | result_index = index
164 | min_entropy = entropy
165 |
166 | return result_index
167 |
168 |
169 | def choose_gini(data, target, data_index_list, index_left_list):
170 |     """
171 |     choose an index that minimizes the Gini index
172 |     """
173 |     min_gini = float("inf")
174 |     result_index = None
175 |
176 |     for index in index_left_list:
177 |         target_dict = defaultdict(list)
178 |         for index_list in data_index_list:
179 |             target_dict[data[index_list][index]].append(index_list)
180 |
181 |         gini = 0
182 |         for value, index_list in target_dict.iteritems():
183 |             gini_part = formula.calculate_gini(target, index_list)
184 |             gini += len(index_list) * gini_part / float(len(data_index_list))
185 |
186 |         if gini < min_gini:
187 |             result_index = index
188 |             min_gini = gini
189 |
190 |     return result_index
191 |
192 |
193 | def choose_gain_ratio(data, target, data_index_list, index_left_list):
194 |     max_gain_ratio = 0  # choose the index that maximizes the gain ratio
195 |     result_index = None
196 |
197 |     cur_entropy_part = formula.calculate_entropy(target, data_index_list)
198 |
199 |     for index in index_left_list:
200 |         target_dict = defaultdict(list)
201 |         for index_list in data_index_list:
202 |             target_dict[data[index_list][index]].append(index_list)
203 |
204 |         entropy = 0
205 |         iv = formula.calculate_iv(target_dict)
206 |         for value, index_list in target_dict.iteritems():
207 |             entropy_part = formula.calculate_entropy(target, index_list)
208 |             entropy += len(index_list) * entropy_part / float(len(data_index_list))
209 |
210 |         gain_ratio = (cur_entropy_part - entropy) / float(iv)
211 |
212 |         if gain_ratio > max_gain_ratio:
213 |             result_index = index
214 |             max_gain_ratio = gain_ratio
215 |
216 |     return result_index
217 |
218 | CHOOSE_RANDOM = 1    # choose an index at random
219 | CHOOSE_INFO_ENTROPY = 2 # information entropy
220 | CHOOSE_GAIN_RATIO = 3 # gain ratio
221 | CHOOSE_GINI = 4 # gini
222 |
223 | CHOOSE_FUNC_DICT = {
224 | CHOOSE_RANDOM: choose_random,
225 | CHOOSE_INFO_ENTROPY: choose_information_entropy,
226 | CHOOSE_GAIN_RATIO: choose_gain_ratio,
227 | CHOOSE_GINI: choose_gini,
228 | }
229 |
230 |
231 | def get_choose_func(key):
232 | return CHOOSE_FUNC_DICT.get(key, None)
233 |
234 |
235 | class DecisionTree(object):
236 |
237 | def __init__(self, depth=10):
238 | self._root = DecideNode(None, 0)
239 | self._max_depth = depth
240 | return
241 |
242 | def decide(self, item):
243 | result = self.root.decide(item)
244 | return result
245 |
246 | @property
247 | def max_depth(self):
248 | return self._max_depth
249 |
250 | def make_tree(self, data, target, choose_func=CHOOSE_INFO_ENTROPY):
251 | index_left_list = range(len(data[0]))
252 | data_index_list = range(len(data))
253 | depth = 1
254 | self.make_tree_recursive(data, target, data_index_list, index_left_list, self.root, depth, choose_func)
255 | return
256 |
257 |     def choose_index(self, choose_func_key, data, target, data_index_list, index_left_list):
258 |         choose_func = get_choose_func(choose_func_key)
259 |
260 |         if not choose_func:
261 |             logger.error("failed to find choose_func for key({0})".format(choose_func_key))
262 |             return None
263 |
264 |         index = choose_func(data, target, data_index_list, index_left_list)
265 |         if index is None:
266 |             logger.error("failed to find index with data_index_list({0}), index_left_list({1}), choose_func({2})"
267 |                          .format(data_index_list, index_left_list, choose_func_key))
268 |             return None
269 |         return index
270 |
271 | def check_finish(self, data, target, data_index_list, index_left_list, depth):
272 |
273 | if len(data_index_list) <= 1: # left one data
274 | return True
275 |
276 | cnt = Counter()
277 | for index in data_index_list:
278 | cnt[target[index]] += 1
279 |
280 | if len(cnt) == 1: # left one category
281 | return True
282 |
283 | if len(index_left_list) == 0: # no attribute left
284 | return True
285 |
286 | if depth == self.max_depth: # reach the max depth
287 | return True
288 |
289 | return False
290 |
291 | def get_majority_category(self, target, data_index_list):
292 | cnt = Counter()
293 | for index in data_index_list:
294 | cnt[target[index]] += 1
295 |
296 | most_common_list = cnt.most_common(1)
297 | if not most_common_list:
298 | logger.error("can't not find most_common_list for data({0})".format(data_index_list))
299 | return None
300 |
301 | category = most_common_list[0][0]
302 | return category
303 |
304 | def make_tree_recursive(self, data, target, data_index_list, index_left_list, root_node, depth, choose_func):
305 | if self.check_finish(data, target, data_index_list, index_left_list, depth):
306 | leaf = LeafNode(root_node)
307 | category = self.get_majority_category(target, data_index_list)
308 | leaf.num = len(data_index_list)
309 | leaf.category = category
310 | root_node.add_node(leaf, None)
311 | return
312 |
313 | index = self.choose_index(choose_func, data, target, data_index_list, index_left_list)
314 | if index is None:
315 | return
316 |
317 | new_index_left_list = [x for x in index_left_list]
318 | new_index_left_list.remove(index)
319 | root_node.decide_index = index
320 |
321 | types_dict = defaultdict(list)
322 | for idx in data_index_list:
323 | types_dict[data[idx][index]].append(idx)
324 |
325 | for key, new_index_list in types_dict.iteritems():
326 | new_depth = depth + 1
327 |             if self.check_finish(data, target, new_index_list, new_index_left_list, new_depth):  # of course there are other stopping conditions
328 | leaf = LeafNode(root_node)
329 | leaf.num = len(new_index_list)
330 | category = self.get_majority_category(target, new_index_list)
331 | leaf.category = category
332 | root_node.add_node(leaf, formula.equal(key))
333 | continue
334 |
335 | decide_node = DecideNode(root_node)
336 | root_node.add_node(decide_node, formula.equal(key))
337 | self.make_tree_recursive(data, target, new_index_list, new_index_left_list, decide_node, new_depth, choose_func)
338 | return
339 |
340 | @property
341 | def root(self):
342 | return self._root
343 |
344 | def get_node_queue_with_level(self):
345 | level = 1
346 | node_queue = deque()
347 | result_queue = deque()
348 | node_queue.append(self._root)
349 | node_queue.append(level)
350 | result_queue.append(self._root)
351 | result_queue.append(level)
352 |
353 | while len(node_queue):
354 | node = node_queue.popleft()
355 | level = node_queue.popleft()
356 | for child in node.children:
357 | node_queue.append(child)
358 | node_queue.append(level + 1)
359 | result_queue.append(child)
360 | result_queue.append(level + 1)
361 |
362 | return result_queue
363 |
364 | def get_node_queue(self):
365 | node_queue = deque()
366 | result_queue = deque()
367 | node_queue.append(self._root)
368 | result_queue.append(self._root)
369 |
370 | while len(node_queue):
371 | node = node_queue.popleft()
372 | for child in node.children:
373 | node_queue.append(child)
374 | result_queue.append(child)
375 |
376 | return result_queue
377 |
378 | def show(self):
379 | """use queue to traversal"""
380 | node_queue = self.get_node_queue_with_level()
381 |
382 | cur_level = 1
383 | cur_list = []
384 | while len(node_queue):
385 | node = node_queue.popleft()
386 | level = node_queue.popleft()
387 |
388 | if cur_level != level:
389 | print "\t".join(cur_list)
390 | cur_level = level
391 | cur_list = []
392 |
393 | cur_list.append(repr(node))
394 |
395 | if cur_list:
396 | print "\t".join(cur_list)
397 | return
398 |
399 | def save(self, filename):
400 | dot_tree = Digraph(comment="decision_tree_test")
401 | node_queue = self.get_node_queue()
402 |
403 | for item in node_queue:
404 | key_item = str(id(item))
405 | dot_tree.node(key_item, repr(item))
406 | for index, child in enumerate(item.children):
407 | key_child = str(id(child))
408 | decide_func = child.get_decide_func()
409 | dot_tree.edge(key_item, key_child, label=repr(decide_func))
410 |
411 | dot_tree.render(filename, view=True)
412 | return dot_tree
413 |
--------------------------------------------------------------------------------
/algorithm/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
9 | @file: __init__.py
10 | @time: 2017/3/18 12:56
11 | @change_time:
12 | 1.2017/3/18 12:56
13 | """
14 |
15 | if __name__ == '__main__':
16 | pass
17 |
--------------------------------------------------------------------------------
/data/Iris.csv:
--------------------------------------------------------------------------------
1 | Id,SepalLengthCm,SepalWidthCm,PetalLengthCm,PetalWidthCm,Species
2 | 51,7.0,3.2,4.7,1.4,Iris-versicolor
3 | 52,6.4,3.2,4.5,1.5,Iris-versicolor
4 | 53,6.9,3.1,4.9,1.5,Iris-versicolor
5 | 54,5.5,2.3,4.0,1.3,Iris-versicolor
6 | 55,6.5,2.8,4.6,1.5,Iris-versicolor
7 | 56,5.7,2.8,4.5,1.3,Iris-versicolor
8 | 57,6.3,3.3,4.7,1.6,Iris-versicolor
9 | 58,4.9,2.4,3.3,1.0,Iris-versicolor
10 | 59,6.6,2.9,4.6,1.3,Iris-versicolor
11 | 60,5.2,2.7,3.9,1.4,Iris-versicolor
12 | 61,5.0,2.0,3.5,1.0,Iris-versicolor
13 | 62,5.9,3.0,4.2,1.5,Iris-versicolor
14 | 63,6.0,2.2,4.0,1.0,Iris-versicolor
15 | 64,6.1,2.9,4.7,1.4,Iris-versicolor
16 | 65,5.6,2.9,3.6,1.3,Iris-versicolor
17 | 66,6.7,3.1,4.4,1.4,Iris-versicolor
18 | 67,5.6,3.0,4.5,1.5,Iris-versicolor
19 | 68,5.8,2.7,4.1,1.0,Iris-versicolor
20 | 69,6.2,2.2,4.5,1.5,Iris-versicolor
21 | 70,5.6,2.5,3.9,1.1,Iris-versicolor
22 | 71,5.9,3.2,4.8,1.8,Iris-versicolor
23 | 72,6.1,2.8,4.0,1.3,Iris-versicolor
24 | 73,6.3,2.5,4.9,1.5,Iris-versicolor
25 | 74,6.1,2.8,4.7,1.2,Iris-versicolor
26 | 75,6.4,2.9,4.3,1.3,Iris-versicolor
27 | 76,6.6,3.0,4.4,1.4,Iris-versicolor
28 | 77,6.8,2.8,4.8,1.4,Iris-versicolor
29 | 78,6.7,3.0,5.0,1.7,Iris-versicolor
30 | 79,6.0,2.9,4.5,1.5,Iris-versicolor
31 | 80,5.7,2.6,3.5,1.0,Iris-versicolor
32 | 81,5.5,2.4,3.8,1.1,Iris-versicolor
33 | 82,5.5,2.4,3.7,1.0,Iris-versicolor
34 | 83,5.8,2.7,3.9,1.2,Iris-versicolor
35 | 84,6.0,2.7,5.1,1.6,Iris-versicolor
36 | 85,5.4,3.0,4.5,1.5,Iris-versicolor
37 | 86,6.0,3.4,4.5,1.6,Iris-versicolor
38 | 87,6.7,3.1,4.7,1.5,Iris-versicolor
39 | 88,6.3,2.3,4.4,1.3,Iris-versicolor
40 | 89,5.6,3.0,4.1,1.3,Iris-versicolor
41 | 90,5.5,2.5,4.0,1.3,Iris-versicolor
42 | 91,5.5,2.6,4.4,1.2,Iris-versicolor
43 | 92,6.1,3.0,4.6,1.4,Iris-versicolor
44 | 93,5.8,2.6,4.0,1.2,Iris-versicolor
45 | 94,5.0,2.3,3.3,1.0,Iris-versicolor
46 | 95,5.6,2.7,4.2,1.3,Iris-versicolor
47 | 96,5.7,3.0,4.2,1.2,Iris-versicolor
48 | 97,5.7,2.9,4.2,1.3,Iris-versicolor
49 | 98,6.2,2.9,4.3,1.3,Iris-versicolor
50 | 99,5.1,2.5,3.0,1.1,Iris-versicolor
51 | 100,5.7,2.8,4.1,1.3,Iris-versicolor
52 | 101,6.3,3.3,6.0,2.5,Iris-virginica
53 | 102,5.8,2.7,5.1,1.9,Iris-virginica
54 | 103,7.1,3.0,5.9,2.1,Iris-virginica
55 | 104,6.3,2.9,5.6,1.8,Iris-virginica
56 | 105,6.5,3.0,5.8,2.2,Iris-virginica
57 | 106,7.6,3.0,6.6,2.1,Iris-virginica
58 | 107,4.9,2.5,4.5,1.7,Iris-virginica
59 | 108,7.3,2.9,6.3,1.8,Iris-virginica
60 | 109,6.7,2.5,5.8,1.8,Iris-virginica
61 | 110,7.2,3.6,6.1,2.5,Iris-virginica
62 | 111,6.5,3.2,5.1,2.0,Iris-virginica
63 | 112,6.4,2.7,5.3,1.9,Iris-virginica
64 | 113,6.8,3.0,5.5,2.1,Iris-virginica
65 | 114,5.7,2.5,5.0,2.0,Iris-virginica
66 | 115,5.8,2.8,5.1,2.4,Iris-virginica
67 | 116,6.4,3.2,5.3,2.3,Iris-virginica
68 | 117,6.5,3.0,5.5,1.8,Iris-virginica
69 | 118,7.7,3.8,6.7,2.2,Iris-virginica
70 | 119,7.7,2.6,6.9,2.3,Iris-virginica
71 | 120,6.0,2.2,5.0,1.5,Iris-virginica
72 | 121,6.9,3.2,5.7,2.3,Iris-virginica
73 | 122,5.6,2.8,4.9,2.0,Iris-virginica
74 | 123,7.7,2.8,6.7,2.0,Iris-virginica
75 | 124,6.3,2.7,4.9,1.8,Iris-virginica
76 | 125,6.7,3.3,5.7,2.1,Iris-virginica
77 | 126,7.2,3.2,6.0,1.8,Iris-virginica
78 | 127,6.2,2.8,4.8,1.8,Iris-virginica
79 | 128,6.1,3.0,4.9,1.8,Iris-virginica
80 | 129,6.4,2.8,5.6,2.1,Iris-virginica
81 | 130,7.2,3.0,5.8,1.6,Iris-virginica
82 | 131,7.4,2.8,6.1,1.9,Iris-virginica
83 | 132,7.9,3.8,6.4,2.0,Iris-virginica
84 | 133,6.4,2.8,5.6,2.2,Iris-virginica
85 | 134,6.3,2.8,5.1,1.5,Iris-virginica
86 | 135,6.1,2.6,5.6,1.4,Iris-virginica
87 | 136,7.7,3.0,6.1,2.3,Iris-virginica
88 | 137,6.3,3.4,5.6,2.4,Iris-virginica
89 | 138,6.4,3.1,5.5,1.8,Iris-virginica
90 | 139,6.0,3.0,4.8,1.8,Iris-virginica
91 | 140,6.9,3.1,5.4,2.1,Iris-virginica
92 | 141,6.7,3.1,5.6,2.4,Iris-virginica
93 | 142,6.9,3.1,5.1,2.3,Iris-virginica
94 | 143,5.8,2.7,5.1,1.9,Iris-virginica
95 | 144,6.8,3.2,5.9,2.3,Iris-virginica
96 | 145,6.7,3.3,5.7,2.5,Iris-virginica
97 | 146,6.7,3.0,5.2,2.3,Iris-virginica
98 | 147,6.3,2.5,5.0,1.9,Iris-virginica
99 | 148,6.5,3.0,5.2,2.0,Iris-virginica
100 | 149,6.2,3.4,5.4,2.3,Iris-virginica
101 | 150,5.9,3.0,5.1,1.8,Iris-virginica
102 |
--------------------------------------------------------------------------------
/data/watermelon.csv:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/xiaoqiangkx/toyplay/c78ab9a67739412797f9311b2ddd70531d2f2684/data/watermelon.csv
--------------------------------------------------------------------------------
/example/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
9 | @file: __init__.py
10 | @time: 2017/3/15 21:39
11 | @change_time:
12 | 1.2017/3/15 21:39
13 | """
14 |
15 | if __name__ == '__main__':
16 | pass
17 |
--------------------------------------------------------------------------------
/example/decision_tree_testes.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: decision_tree_testes.py
8 | @time: 2017/3/19 15:39
9 | @change_time:
10 | 1.2017/3/19 15:39
11 | """
12 | from utils import file_utils
13 | from utils import sample
14 | from algorithm import DecisionTree as DT
15 | from tree import DecisionTreeClassifier as DTC
16 | from utils import logger
17 | from utils import formula
18 |
19 |
20 | if __name__ == '__main__':
21 | train_filename = "../data/mnist_train.csv"
22 | test_filename = "../data/mnist_test.csv"
23 | training_data, training_target, _, _ = file_utils.load_mnist_data(train_filename, test_filename)
24 | # training_data = training_data[0:100, :]
25 | # training_target = training_target[0:100]
26 | train_index, cv_index, test_index = sample.sample_target_data(training_target)
27 |
28 | train_data = training_data[train_index, :]
29 | train_target = training_target[train_index]
30 | cv_data = training_data[cv_index, :]
31 | cv_target = training_target[cv_index]
32 | test_data = training_data[test_index, :]
33 | test_target = training_target[test_index]
34 |
35 | tree = DTC.DecisionTreeClassifier(depth=3)
36 | logger.info("start making tree")
37 | tree.fit(train_data, train_target, choose_func=DT.CHOOSE_GAIN_RATIO)
38 | logger.info("finish making tree")
39 | # tree.show()
40 |
41 | logger.info("start predicting")
42 | cv_result = tree.predict(cv_data)
43 | logger.info("calculate precision")
44 | formula.cal_new(cv_result, cv_target)
45 |
--------------------------------------------------------------------------------
/example/sample_testes.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
9 | @file: sample_testes.py
10 | @time: 2017/3/15 20:59
11 | @change_time:
12 | 1.2017/3/15 20:59
13 | """
14 | from logistic.logistic import LogisticRegression
15 | from utils import file_utils as FU
16 | from utils import formula as FORMULA
17 | from utils import sample as SAMPLE
18 | from matplotlib import pyplot as plt
19 | from preproccessing.StandardScaler import StandardScaler
20 | from utils import logger as LOGGER
21 |
22 |
23 | def model_test(X, Y, method):
24 | amount = 0
25 | times = 0
26 | for cv_data, cv_target, test_data, test_target in SAMPLE.iter_sample_data(X, Y, method):
27 | times += 1
28 | model = LogisticRegression(delta=0.01, alpha=0.01)
29 | model.fit(cv_data, cv_target)
30 | predict_y = model.predict(test_data)
31 | amount += FORMULA.cal(predict_y, test_target)
32 |
33 | return float(amount) / times
34 |
35 | if __name__ == '__main__':
36 |     filename = "../data/Iris.csv"
37 | X, Y = FU.load_iris_data(filename)
38 | X = StandardScaler().fit_transform(X)
39 | X = FORMULA.plus_one(X)
40 |
41 | LOGGER.setLevel(LOGGER.LEVEL_NORMAL)
42 | print u"10折交叉法:", model_test(X, Y, 10)
43 | print u"留一法:", model_test(X, Y, 1)
44 |
45 | # model.draw_data(X, Y)
46 | # model.draw_line(X)
47 | # model.draw_loss()
48 | # plt.show()
49 |
--------------------------------------------------------------------------------
/fetch/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/3/30 21:52
9 | @change_time:
10 | 1.2017/3/30 21:52
11 | """
12 |
13 | if __name__ == '__main__':
14 | pass
15 |
--------------------------------------------------------------------------------
/fetch/fetch_house.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: fetch_house.py
8 | @time: 2017/3/30 21:53
9 | @change_time:
10 | 1.2017/3/30 21:53
11 | """
12 | import os
13 | import tarfile
14 | import requests
15 |
16 |
17 | EXTRACT_HOUSING_PATH = "../data/housing"
18 | HOUSING_URL = "https://github.com/ageron/handson-ml/raw/master/datasets/housing/housing.tgz"
19 |
20 |
21 | def download_file(url):
22 | print url
23 |
24 | local_filename = url.split('/')[-1]
25 |     # stream=True downloads the body in chunks instead of loading it all into memory
26 | r = requests.get(url, stream=True)
27 | with open(local_filename, 'wb') as f:
28 | for chunk in r.iter_content(chunk_size=1024):
29 | if chunk: # filter out keep-alive new chunks
30 | f.write(chunk)
31 |
32 | return local_filename
33 |
34 |
35 | def fetch_housing_data(housing_url=HOUSING_URL, extract_path=EXTRACT_HOUSING_PATH):
36 | if not os.path.isdir(extract_path):
37 | os.makedirs(extract_path)
38 |
39 | filename = download_file(housing_url)
40 | housing_tgz = tarfile.open(filename)
41 | housing_tgz.extractall(path=extract_path)
42 | housing_tgz.close()
43 | os.remove(filename)
44 |
45 |
46 | if __name__ == '__main__':
47 | fetch_housing_data()
48 |
--------------------------------------------------------------------------------
/logistic/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
9 | @file: __init__.py
10 | @time: 2017/3/12 23:46
11 | @change_time:
12 | 1.2017/3/12 23:46
13 | """
14 |
15 | if __name__ == '__main__':
16 | pass
17 |
--------------------------------------------------------------------------------
/logistic/logistic.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # -*- coding: utf-8 -*-
3 |
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
9 | @file: logistic.py
10 | @time: 2017/3/12 23:47
11 | @change_time:
12 | 1.2017/3/12 23:47
13 | """
14 | import numpy as np
15 | from utils import file_utils as FU
16 | from utils import formula
17 | import matplotlib.pyplot as plt
18 |
19 |
20 | class LogisticRegression(object):
21 | def __init__(self, delta=0.0001, alpha=0.1):
22 | """
23 | init hyper-parameters
24 | """
25 | self.w = None
26 | self.size = None
27 | self.delta = delta
28 |         self.threshold = 0.5
29 | self.loss_data = []
30 | self.alpha = alpha
31 | return
32 |
33 | def init_params(self, item_size):
34 | self.w = np.random.rand(*item_size)
35 | self.size = item_size
36 | self.loss_data = []
37 | return
38 |
39 | def fit(self, X, Y):
40 | """
41 | Train model using training data and hyper-parameters
42 | """
43 | item_size = (X.shape[0], 1)
44 | self.init_params(item_size)
45 |
46 | num = 0
47 | last_loss = np.inf
48 | step = 0
49 | while True:
50 | self.w -= self.cal_delta(X, Y).astype(float)
51 | num += 1
52 | loss = self.cal_loss(X, Y)
53 | if np.abs(last_loss - loss) <= self.delta:
54 | break
55 | else:
56 | step += 1
57 | last_loss = loss
58 | self.loss_data.append(loss)
59 | # print "step:", step, "loss:", loss
60 |
61 | return
62 |
63 | def predict(self, X):
64 |         data = formula.sigmoid(X, self.w) >= self.threshold
65 | y = np.zeros((1, X.shape[1]))
66 | for index, value in enumerate(data[0, :]):
67 | if value:
68 | y[0, index] = 1
69 | return y
70 |
71 |     def cal_delta(self, X, Y):
72 |         y = formula.sigmoid(X, self.w)
73 |         delta_w_1 = - np.dot(X, (Y - y).T)  # gradient of the negative log-likelihood
74 |
75 |         delta_w_2 = np.zeros((self.size[0], self.size[0]))  # Hessian
76 |         for i in xrange(X.shape[1]):
77 |             delta_w_2 += (np.dot(X[:, i:i+1], X[:, i:i+1].T) * y[0, i] * (1 - y[0, i])).astype(float)
78 |         return np.dot(np.linalg.inv(delta_w_2), delta_w_1)  # Newton step: inv(H) * g
79 | # return self.alpha * delta_w_1
80 |
81 | def cal_loss(self, X, Y):
82 | loss = 0
83 | for i in xrange(X.shape[1]):
84 | temp = np.dot(self.w.T, X[:, i])[[0]]
85 | loss += -Y[0, i] * temp + np.log(1 + np.exp(temp.astype(float)))
86 | return loss
87 |
88 | def draw_data(self, X, Y):
89 | np_index1 = [index for index, x in enumerate(Y[0]) if x == 0]
90 | np_index2 = [index for index, x in enumerate(Y[0]) if x == 1]
91 |
92 | plt.scatter(X[0, np_index1], X[1, np_index1], c='y', marker='o')
93 | plt.scatter(X[0, np_index2], X[1, np_index2], c='b', marker='v')
94 | return
95 |
96 | def draw_line(self, X):
97 | min1 = np.min(X[0, :])
98 | max1 = np.max(X[0, :])
99 |
100 | min2 = - (self.w[2] + self.w[0] * min1) / self.w[1]
101 | max2 = - (self.w[2] + self.w[0] * max1) / self.w[1]
102 |
103 | plt.plot([min1, max1], [min2, max2], 'r')
104 | return
105 |
106 | def draw_loss(self):
107 | plt.plot(self.loss_data)
108 | return
109 |
110 |
111 | if __name__ == '__main__':
112 | filename = u"../data/watermelon.csv"
113 | X, Y = FU.load_water_melon_data(filename)
114 | X = formula.plus_one(X)
115 | model = LogisticRegression()
116 | model.fit(X, Y)
117 |
118 | # model.draw_data(X, Y)
119 | # model.draw_line(X)
120 | model.draw_loss()
121 | plt.show()
122 |
123 |
--------------------------------------------------------------------------------
/match/.ipynb_checkpoints/cheetsheet-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [],
3 | "metadata": {},
4 | "nbformat": 4,
5 | "nbformat_minor": 0
6 | }
7 |
--------------------------------------------------------------------------------
/match/.ipynb_checkpoints/instacart_match-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [],
3 | "metadata": {},
4 | "nbformat": 4,
5 | "nbformat_minor": 2
6 | }
7 |
--------------------------------------------------------------------------------
/match/.ipynb_checkpoints/instacart_running-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [],
3 | "metadata": {},
4 | "nbformat": 4,
5 | "nbformat_minor": 2
6 | }
7 |
--------------------------------------------------------------------------------
/match/.ipynb_checkpoints/instcart_predict-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "%matplotlib inline\n",
12 | "import matplotlib\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "import pandas as pd\n",
15 | "import numpy as np\n",
16 | "import seaborn as sns\n",
17 | "import scipy.sparse\n",
18 | "pd.set_option(\"display.max_columns\",101)\n",
19 | "RANDOM_STATE = 42\n",
20 | "DATA_PATH = \"../data/instacart/\""
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "metadata": {
27 | "collapsed": true
28 | },
29 | "outputs": [],
30 | "source": [
31 | "positive_X = pd.read_csv(DATA_PATH + \"positive_X.csv\")\n",
32 | "negative_X = pd.read_csv(DATA_PATH + \"negative_X.csv\")\n",
33 | "positive_X.drop(['Unnamed: 0', ], axis=1, inplace=True)\n",
34 | "negative_X.drop(['Unnamed: 0', ], axis=1, inplace=True)"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 3,
40 | "metadata": {
41 | "collapsed": true
42 | },
43 | "outputs": [],
44 | "source": [
45 | "positive_X.fillna(0, inplace=True)\n",
46 | "negative_X.fillna(0, inplace=True)"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 4,
52 | "metadata": {
53 | "collapsed": true
54 | },
55 | "outputs": [],
56 | "source": [
57 | "from sklearn.preprocessing import StandardScaler\n",
58 | "pX = StandardScaler().fit_transform(positive_X)"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 5,
64 | "metadata": {
65 | "collapsed": true
66 | },
67 | "outputs": [],
68 | "source": [
69 | "nX = StandardScaler().fit_transform(negative_X)"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": 6,
75 | "metadata": {
76 | "collapsed": true
77 | },
78 | "outputs": [],
79 | "source": [
80 | "py = np.ones(positive_X.shape[0])\n",
81 | "ny = np.zeros(negative_X.shape[0])"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 7,
87 | "metadata": {
88 | "collapsed": true
89 | },
90 | "outputs": [],
91 | "source": [
92 | "m, n = nX.shape"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 8,
98 | "metadata": {
99 | "collapsed": false
100 | },
101 | "outputs": [
102 | {
103 | "name": "stdout",
104 | "output_type": "stream",
105 | "text": [
106 | "1384617 73\n"
107 | ]
108 | }
109 | ],
110 | "source": [
111 | "print m, n"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 25,
117 | "metadata": {
118 | "collapsed": false
119 | },
120 | "outputs": [],
121 | "source": [
122 | "percentage = 1\n",
123 | "sample_m = int(m * percentage)\n",
124 | "X = np.concatenate([pX[:sample_m, :], nX[:sample_m, :]])\n",
125 | "y = np.concatenate([py[:sample_m], ny[:sample_m]])"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "metadata": {
132 | "collapsed": true
133 | },
134 | "outputs": [],
135 | "source": []
136 | },
137 | {
138 | "cell_type": "code",
139 | "execution_count": 10,
140 | "metadata": {
141 | "collapsed": false,
142 | "scrolled": false
143 | },
144 | "outputs": [
145 | {
146 | "data": {
147 | "text/plain": [
148 | "LogisticRegression(C=1.0, class_weight=None, dual=False, fit_intercept=True,\n",
149 | " intercept_scaling=1, max_iter=100, multi_class='ovr', n_jobs=1,\n",
150 | " penalty='l1', random_state=None, solver='liblinear', tol=0.0001,\n",
151 | " verbose=0, warm_start=False)"
152 | ]
153 | },
154 | "execution_count": 10,
155 | "metadata": {},
156 | "output_type": "execute_result"
157 | }
158 | ],
159 | "source": [
160 | "from sklearn.linear_model import LogisticRegression\n",
161 | "lr = LogisticRegression(penalty='l1')\n",
162 | "lr.fit(X, y)"
163 | ]
164 | },
165 | {
166 | "cell_type": "code",
167 | "execution_count": 23,
168 | "metadata": {
169 | "collapsed": true
170 | },
171 | "outputs": [],
172 | "source": [
173 | "test_X = np.concatenate([pX[sample_m:2*sample_m, :], nX[sample_m:2*sample_m, :]])\n",
174 | "test_y = np.concatenate([py[sample_m:2*sample_m], ny[sample_m:2*sample_m]])"
175 | ]
176 | },
177 | {
178 | "cell_type": "code",
179 | "execution_count": 12,
180 | "metadata": {
181 | "collapsed": false
182 | },
183 | "outputs": [
184 | {
185 | "data": {
186 | "text/plain": [
187 | "0.50556117290192115"
188 | ]
189 | },
190 | "execution_count": 12,
191 | "metadata": {},
192 | "output_type": "execute_result"
193 | }
194 | ],
195 | "source": [
196 | "lr.score(test_X, test_y)"
197 | ]
198 | },
199 | {
200 | "cell_type": "code",
201 | "execution_count": 13,
202 | "metadata": {
203 | "collapsed": false
204 | },
205 | "outputs": [
206 | {
207 | "data": {
208 | "text/plain": [
209 | "0.51686407626751407"
210 | ]
211 | },
212 | "execution_count": 13,
213 | "metadata": {},
214 | "output_type": "execute_result"
215 | }
216 | ],
217 | "source": [
218 | "lr.score(X, y)"
219 | ]
220 | },
221 | {
222 | "cell_type": "markdown",
223 | "metadata": {},
224 | "source": [
225 | "#### The model is Low bias"
226 | ]
227 | },
228 | {
229 | "cell_type": "code",
230 | "execution_count": 21,
231 | "metadata": {
232 | "collapsed": false
233 | },
234 | "outputs": [
235 | {
236 | "data": {
237 | "text/plain": [
238 | "DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=None,\n",
239 | " max_features=None, max_leaf_nodes=None,\n",
240 | " min_impurity_split=1e-07, min_samples_leaf=1,\n",
241 | " min_samples_split=2, min_weight_fraction_leaf=0.0,\n",
242 | " presort=False, random_state=None, splitter='best')"
243 | ]
244 | },
245 | "execution_count": 21,
246 | "metadata": {},
247 | "output_type": "execute_result"
248 | }
249 | ],
250 | "source": [
251 | "from sklearn.tree import DecisionTreeClassifier\n",
252 | "clf = DecisionTreeClassifier()\n",
253 | "clf.fit(X, y)"
254 | ]
255 | },
256 | {
257 | "cell_type": "code",
258 | "execution_count": 22,
259 | "metadata": {
260 | "collapsed": false
261 | },
262 | "outputs": [
263 | {
264 | "data": {
265 | "text/plain": [
266 | "1.0"
267 | ]
268 | },
269 | "execution_count": 22,
270 | "metadata": {},
271 | "output_type": "execute_result"
272 | }
273 | ],
274 | "source": [
275 | "clf.score(X, y)"
276 | ]
277 | },
278 | {
279 | "cell_type": "code",
280 | "execution_count": 24,
281 | "metadata": {
282 | "collapsed": false
283 | },
284 | "outputs": [
285 | {
286 | "data": {
287 | "text/plain": [
288 | "0.99962083185879058"
289 | ]
290 | },
291 | "execution_count": 24,
292 | "metadata": {},
293 | "output_type": "execute_result"
294 | }
295 | ],
296 | "source": [
297 | "clf.score(test_X, test_y)"
298 | ]
299 | },
300 | {
301 | "cell_type": "code",
302 | "execution_count": null,
303 | "metadata": {
304 | "collapsed": true
305 | },
306 | "outputs": [],
307 | "source": []
308 | },
309 | {
310 | "cell_type": "markdown",
311 | "metadata": {},
312 | "source": [
313 | "#### Try add test item into data"
314 | ]
315 | },
316 | {
317 | "cell_type": "code",
318 | "execution_count": 17,
319 | "metadata": {
320 | "collapsed": true
321 | },
322 | "outputs": [],
323 | "source": [
324 | "orders = pd.read_csv(DATA_PATH + \"orders.csv\")"
325 | ]
326 | },
327 | {
328 | "cell_type": "code",
329 | "execution_count": 18,
330 | "metadata": {
331 | "collapsed": true
332 | },
333 | "outputs": [],
334 | "source": [
335 | "test_order = orders.loc[orders.eval_set == 'test']"
336 | ]
337 | },
338 | {
339 | "cell_type": "code",
340 | "execution_count": 22,
341 | "metadata": {
342 | "collapsed": false
343 | },
344 | "outputs": [
345 | {
346 | "data": {
347 | "text/plain": [
348 | "(75000, 7)"
349 | ]
350 | },
351 | "execution_count": 22,
352 | "metadata": {},
353 | "output_type": "execute_result"
354 | }
355 | ],
356 | "source": [
357 | "test_order.shape"
358 | ]
359 | },
360 | {
361 | "cell_type": "code",
362 | "execution_count": 23,
363 | "metadata": {
364 | "collapsed": false
365 | },
366 | "outputs": [
367 | {
368 | "data": {
369 | "text/html": [
370 | "
\n",
371 | "
\n",
372 | " \n",
373 | " \n",
374 | " | \n",
375 | " order_id | \n",
376 | " user_id | \n",
377 | " eval_set | \n",
378 | " order_number | \n",
379 | " order_dow | \n",
380 | " order_hour_of_day | \n",
381 | " days_since_prior_order | \n",
382 | "
\n",
383 | " \n",
384 | " \n",
385 | " \n",
386 | " 38 | \n",
387 | " 2774568 | \n",
388 | " 3 | \n",
389 | " test | \n",
390 | " 13 | \n",
391 | " 5 | \n",
392 | " 15 | \n",
393 | " 11.0 | \n",
394 | "
\n",
395 | " \n",
396 | " 44 | \n",
397 | " 329954 | \n",
398 | " 4 | \n",
399 | " test | \n",
400 | " 6 | \n",
401 | " 3 | \n",
402 | " 12 | \n",
403 | " 30.0 | \n",
404 | "
\n",
405 | " \n",
406 | " 53 | \n",
407 | " 1528013 | \n",
408 | " 6 | \n",
409 | " test | \n",
410 | " 4 | \n",
411 | " 3 | \n",
412 | " 16 | \n",
413 | " 22.0 | \n",
414 | "
\n",
415 | " \n",
416 | " 96 | \n",
417 | " 1376945 | \n",
418 | " 11 | \n",
419 | " test | \n",
420 | " 8 | \n",
421 | " 6 | \n",
422 | " 11 | \n",
423 | " 8.0 | \n",
424 | "
\n",
425 | " \n",
426 | " 102 | \n",
427 | " 1356845 | \n",
428 | " 12 | \n",
429 | " test | \n",
430 | " 6 | \n",
431 | " 1 | \n",
432 | " 20 | \n",
433 | " 30.0 | \n",
434 | "
\n",
435 | " \n",
436 | "
\n",
437 | "
"
438 | ],
439 | "text/plain": [
440 | " order_id user_id eval_set order_number order_dow order_hour_of_day \\\n",
441 | "38 2774568 3 test 13 5 15 \n",
442 | "44 329954 4 test 6 3 12 \n",
443 | "53 1528013 6 test 4 3 16 \n",
444 | "96 1376945 11 test 8 6 11 \n",
445 | "102 1356845 12 test 6 1 20 \n",
446 | "\n",
447 | " days_since_prior_order \n",
448 | "38 11.0 \n",
449 | "44 30.0 \n",
450 | "53 22.0 \n",
451 | "96 8.0 \n",
452 | "102 30.0 "
453 | ]
454 | },
455 | "execution_count": 23,
456 | "metadata": {},
457 | "output_type": "execute_result"
458 | }
459 | ],
460 | "source": [
461 | "test_order.head()"
462 | ]
463 | },
464 | {
465 | "cell_type": "markdown",
466 | "metadata": {},
467 | "source": [
468 | "#### product_list of user_id "
469 | ]
470 | },
471 | {
472 | "cell_type": "code",
473 | "execution_count": 19,
474 | "metadata": {
475 | "collapsed": true
476 | },
477 | "outputs": [],
478 | "source": [
479 | "test_user_id_list = np.unique(test_order['user_id'])\n",
480 | "test_order_ix = test_order.ix\n",
481 | "test_order_index = list(test_order.index)"
482 | ]
483 | },
484 | {
485 | "cell_type": "code",
486 | "execution_count": 57,
487 | "metadata": {
488 | "collapsed": false
489 | },
490 | "outputs": [
491 | {
492 | "data": {
493 | "text/plain": [
494 | "order_id 2774568\n",
495 | "user_id 3\n",
496 | "eval_set test\n",
497 | "order_number 13\n",
498 | "order_dow 5\n",
499 | "order_hour_of_day 15\n",
500 | "days_since_prior_order 11\n",
501 | "Name: 38, dtype: object"
502 | ]
503 | },
504 | "execution_count": 57,
505 | "metadata": {},
506 | "output_type": "execute_result"
507 | }
508 | ],
509 | "source": [
510 | "test_order_ix[38]"
511 | ]
512 | },
513 | {
514 | "cell_type": "code",
515 | "execution_count": 20,
516 | "metadata": {
517 | "collapsed": true
518 | },
519 | "outputs": [],
520 | "source": [
521 | "test_X = pd.read_csv(DATA_PATH + \"test_X.csv\")"
522 | ]
523 | },
524 | {
525 | "cell_type": "code",
526 | "execution_count": 21,
527 | "metadata": {
528 | "collapsed": true
529 | },
530 | "outputs": [],
531 | "source": [
532 | "test_X.drop(['Unnamed: 0', ], axis=1, inplace=True)\n",
533 | "test_X.fillna(0, inplace=True)\n",
534 | "from sklearn.preprocessing import StandardScaler\n",
535 | "testX = StandardScaler().fit_transform(test_X)"
536 | ]
537 | },
538 | {
539 | "cell_type": "code",
540 | "execution_count": 50,
541 | "metadata": {
542 | "collapsed": false
543 | },
544 | "outputs": [],
545 | "source": [
546 | "predict_y = lr.predict(test_X)"
547 | ]
548 | },
549 | {
550 | "cell_type": "code",
551 | "execution_count": 51,
552 | "metadata": {
553 | "collapsed": false
554 | },
555 | "outputs": [
556 | {
557 | "data": {
558 | "text/plain": [
559 | "(4833292,)"
560 | ]
561 | },
562 | "execution_count": 51,
563 | "metadata": {},
564 | "output_type": "execute_result"
565 | }
566 | ],
567 | "source": [
568 | "predict_y.shape"
569 | ]
570 | },
571 | {
572 | "cell_type": "code",
573 | "execution_count": 55,
574 | "metadata": {
575 | "collapsed": true
576 | },
577 | "outputs": [],
578 | "source": [
579 | "prior = pd.read_csv(DATA_PATH + \"order_products__prior.csv\")\n",
580 | "prior_order = pd.merge(prior, orders, on='order_id')\n",
581 | "item_user_reordered = pd.DataFrame(prior_order.groupby(['user_id', 'product_id']).agg({'reordered': np.sum}))"
582 | ]
583 | },
584 | {
585 | "cell_type": "code",
586 | "execution_count": 60,
587 | "metadata": {
588 | "collapsed": true
589 | },
590 | "outputs": [],
591 | "source": [
592 | "item_user_reordered_idx = item_user_reordered.ix"
593 | ]
594 | },
595 | {
596 | "cell_type": "code",
597 | "execution_count": 64,
598 | "metadata": {
599 | "collapsed": false
600 | },
601 | "outputs": [
602 | {
603 | "data": {
604 | "text/plain": [
605 | "array([ 1., 1., 0., ..., 1., 0., 1.])"
606 | ]
607 | },
608 | "execution_count": 64,
609 | "metadata": {},
610 | "output_type": "execute_result"
611 | }
612 | ],
613 | "source": [
614 | "predict_y"
615 | ]
616 | },
617 | {
618 | "cell_type": "code",
619 | "execution_count": 62,
620 | "metadata": {
621 | "collapsed": false,
622 | "scrolled": true
623 | },
624 | "outputs": [
625 | {
626 | "name": "stdout",
627 | "output_type": "stream",
628 | "text": [
629 | "1\n",
630 | "10001\n",
631 | "20001\n",
632 | "30001\n",
633 | "40001\n",
634 | "50001\n",
635 | "60001\n",
636 | "70001\n",
637 | "80001\n",
638 | "90001\n",
639 | "100001\n",
640 | "110001\n",
641 | "120001\n",
642 | "130001\n",
643 | "140001\n",
644 | "150001\n",
645 | "160001\n",
646 | "170001\n",
647 | "180001\n",
648 | "190001\n",
649 | "200001\n",
650 | "210001\n",
651 | "220001\n",
652 | "230001\n",
653 | "240001\n",
654 | "250001\n",
655 | "260001\n",
656 | "270001\n",
657 | "280001\n",
658 | "290001\n",
659 | "300001\n",
660 | "310001\n",
661 | "320001\n",
662 | "330001\n",
663 | "340001\n",
664 | "350001\n",
665 | "360001\n",
666 | "370001\n",
667 | "380001\n",
668 | "390001\n",
669 | "400001\n",
670 | "410001\n",
671 | "420001\n",
672 | "430001\n",
673 | "440001\n",
674 | "450001\n",
675 | "460001\n",
676 | "470001\n",
677 | "480001\n",
678 | "490001\n",
679 | "500001\n",
680 | "510001\n",
681 | "520001\n",
682 | "530001\n",
683 | "540001\n",
684 | "550001\n",
685 | "560001\n",
686 | "570001\n",
687 | "580001\n",
688 | "590001\n",
689 | "600001\n",
690 | "610001\n",
691 | "620001\n",
692 | "630001\n",
693 | "640001\n",
694 | "650001\n",
695 | "660001\n",
696 | "670001\n",
697 | "680001\n",
698 | "690001\n",
699 | "700001\n",
700 | "710001\n",
701 | "720001\n",
702 | "730001\n",
703 | "740001\n",
704 | "750001\n",
705 | "760001\n",
706 | "770001\n",
707 | "780001\n",
708 | "790001\n",
709 | "800001\n",
710 | "810001\n",
711 | "820001\n",
712 | "830001\n",
713 | "840001\n",
714 | "850001\n",
715 | "860001\n",
716 | "870001\n",
717 | "880001\n",
718 | "890001\n",
719 | "900001\n",
720 | "910001\n",
721 | "920001\n",
722 | "930001\n",
723 | "940001\n",
724 | "950001\n",
725 | "960001\n",
726 | "970001\n",
727 | "980001\n",
728 | "990001\n",
729 | "1000001\n",
730 | "1010001\n",
731 | "1020001\n",
732 | "1030001\n",
733 | "1040001\n",
734 | "1050001\n",
735 | "1060001\n",
736 | "1070001\n",
737 | "1080001\n",
738 | "1090001\n",
739 | "1100001\n",
740 | "1110001\n",
741 | "1120001\n",
742 | "1130001\n",
743 | "1140001\n",
744 | "1150001\n",
745 | "1160001\n",
746 | "1170001\n",
747 | "1180001\n",
748 | "1190001\n",
749 | "1200001\n",
750 | "1210001\n",
751 | "1220001\n",
752 | "1230001\n",
753 | "1240001\n",
754 | "1250001\n",
755 | "1260001\n",
756 | "1270001\n",
757 | "1280001\n",
758 | "1290001\n",
759 | "1300001\n",
760 | "1310001\n",
761 | "1320001\n",
762 | "1330001\n",
763 | "1340001\n",
764 | "1350001\n",
765 | "1360001\n",
766 | "1370001\n",
767 | "1380001\n",
768 | "1390001\n",
769 | "1400001\n",
770 | "1410001\n",
771 | "1420001\n",
772 | "1430001\n",
773 | "1440001\n",
774 | "1450001\n",
775 | "1460001\n",
776 | "1470001\n",
777 | "1480001\n",
778 | "1490001\n",
779 | "1500001\n",
780 | "1510001\n",
781 | "1520001\n",
782 | "1530001\n",
783 | "1540001\n",
784 | "1550001\n",
785 | "1560001\n",
786 | "1570001\n",
787 | "1580001\n",
788 | "1590001\n",
789 | "1600001\n",
790 | "1610001\n",
791 | "1620001\n",
792 | "1630001\n",
793 | "1640001\n",
794 | "1650001\n",
795 | "1660001\n",
796 | "1670001\n",
797 | "1680001\n",
798 | "1690001\n",
799 | "1700001\n",
800 | "1710001\n",
801 | "1720001\n",
802 | "1730001\n",
803 | "1740001\n",
804 | "1750001\n",
805 | "1760001\n",
806 | "1770001\n",
807 | "1780001\n",
808 | "1790001\n",
809 | "1800001\n",
810 | "1810001\n",
811 | "1820001\n",
812 | "1830001\n",
813 | "1840001\n",
814 | "1850001\n",
815 | "1860001\n",
816 | "1870001\n",
817 | "1880001\n",
818 | "1890001\n",
819 | "1900001\n",
820 | "1910001\n",
821 | "1920001\n",
822 | "1930001\n",
823 | "1940001\n",
824 | "1950001\n",
825 | "1960001\n",
826 | "1970001\n",
827 | "1980001\n",
828 | "1990001\n",
829 | "2000001\n",
830 | "2010001\n",
831 | "2020001\n",
832 | "2030001\n",
833 | "2040001\n",
834 | "2050001\n",
835 | "2060001\n",
836 | "2070001\n",
837 | "2080001\n",
838 | "2090001\n",
839 | "2100001\n",
840 | "2110001\n",
841 | "2120001\n",
842 | "2130001\n",
843 | "2140001\n",
844 | "2150001\n",
845 | "2160001\n",
846 | "2170001\n",
847 | "2180001\n",
848 | "2190001\n",
849 | "2200001\n",
850 | "2210001\n",
851 | "2220001\n",
852 | "2230001\n",
853 | "2240001\n",
854 | "2250001\n",
855 | "2260001\n",
856 | "2270001\n",
857 | "2280001\n",
858 | "2290001\n",
859 | "2300001\n",
860 | "2310001\n",
861 | "2320001\n",
862 | "2330001\n",
863 | "2340001\n",
864 | "2350001\n",
865 | "2360001\n",
866 | "2370001\n",
867 | "2380001\n",
868 | "2390001\n",
869 | "2400001\n",
870 | "2410001\n",
871 | "2420001\n",
872 | "2430001\n",
873 | "2440001\n",
874 | "2450001\n",
875 | "2460001\n",
876 | "2470001\n",
877 | "2480001\n",
878 | "2490001\n",
879 | "2500001\n",
880 | "2510001\n",
881 | "2520001\n",
882 | "2530001\n",
883 | "2540001\n",
884 | "2550001\n",
885 | "2560001\n",
886 | "2570001\n",
887 | "2580001\n",
888 | "2590001\n",
889 | "2600001\n",
890 | "2610001\n",
891 | "2620001\n",
892 | "2630001\n",
893 | "2640001\n",
894 | "2650001\n",
895 | "2660001\n",
896 | "2670001\n",
897 | "2680001\n",
898 | "2690001\n",
899 | "2700001\n",
900 | "2710001\n",
901 | "2720001\n",
902 | "2730001\n",
903 | "2740001\n",
904 | "2750001\n",
905 | "2760001\n",
906 | "2770001\n",
907 | "2780001\n",
908 | "2790001\n",
909 | "2800001\n",
910 | "2810001\n",
911 | "2820001\n",
912 | "2830001\n",
913 | "2840001\n",
914 | "2850001\n",
915 | "2860001\n",
916 | "2870001\n",
917 | "2880001\n",
918 | "2890001\n",
919 | "2900001\n",
920 | "2910001\n",
921 | "2920001\n",
922 | "2930001\n",
923 | "2940001\n",
924 | "2950001\n",
925 | "2960001\n",
926 | "2970001\n",
927 | "2980001\n",
928 | "2990001\n",
929 | "3000001\n",
930 | "3010001\n",
931 | "3020001\n",
932 | "3030001\n",
933 | "3040001\n",
934 | "3050001\n",
935 | "3060001\n",
936 | "3070001\n",
937 | "3080001\n",
938 | "3090001\n",
939 | "3100001\n",
940 | "3110001\n",
941 | "3120001\n",
942 | "3130001\n",
943 | "3140001\n",
944 | "3150001\n",
945 | "3160001\n",
946 | "3170001\n",
947 | "3180001\n",
948 | "3190001\n",
949 | "3200001\n",
950 | "3210001\n",
951 | "3220001\n",
952 | "3230001\n",
953 | "3240001\n",
954 | "3250001\n",
955 | "3260001\n",
956 | "3270001\n",
957 | "3280001\n",
958 | "3290001\n",
959 | "3300001\n",
960 | "3310001\n",
961 | "3320001\n",
962 | "3330001\n",
963 | "3340001\n",
964 | "3350001\n",
965 | "3360001\n",
966 | "3370001\n",
967 | "3380001\n",
968 | "3390001\n",
969 | "3400001\n",
970 | "3410001\n",
971 | "3420001\n",
972 | "3430001\n",
973 | "3440001\n",
974 | "3450001\n",
975 | "3460001\n",
976 | "3470001\n",
977 | "3480001\n",
978 | "3490001\n",
979 | "3500001\n",
980 | "3510001\n",
981 | "3520001\n",
982 | "3530001\n",
983 | "3540001\n",
984 | "3550001\n",
985 | "3560001\n",
986 | "3570001\n",
987 | "3580001\n",
988 | "3590001\n",
989 | "3600001\n",
990 | "3610001\n",
991 | "3620001\n",
992 | "3630001\n",
993 | "3640001\n",
994 | "3650001\n",
995 | "3660001\n",
996 | "3670001\n",
997 | "3680001\n",
998 | "3690001\n",
999 | "3700001\n",
1000 | "3710001\n",
1001 | "3720001\n",
1002 | "3730001\n",
1003 | "3740001\n",
1004 | "3750001\n",
1005 | "3760001\n",
1006 | "3770001\n",
1007 | "3780001\n",
1008 | "3790001\n",
1009 | "3800001\n",
1010 | "3810001\n",
1011 | "3820001\n",
1012 | "3830001\n",
1013 | "3840001\n",
1014 | "3850001\n",
1015 | "3860001\n",
1016 | "3870001\n",
1017 | "3880001\n",
1018 | "3890001\n",
1019 | "3900001\n",
1020 | "3910001\n",
1021 | "3920001\n",
1022 | "3930001\n",
1023 | "3940001\n",
1024 | "3950001\n",
1025 | "3960001\n",
1026 | "3970001\n",
1027 | "3980001\n",
1028 | "3990001\n",
1029 | "4000001\n",
1030 | "4010001\n",
1031 | "4020001\n",
1032 | "4030001\n",
1033 | "4040001\n",
1034 | "4050001\n",
1035 | "4060001\n",
1036 | "4070001\n",
1037 | "4080001\n",
1038 | "4090001\n",
1039 | "4100001\n",
1040 | "4110001\n",
1041 | "4120001\n",
1042 | "4130001\n",
1043 | "4140001\n",
1044 | "4150001\n",
1045 | "4160001\n",
1046 | "4170001\n",
1047 | "4180001\n",
1048 | "4190001\n",
1049 | "4200001\n",
1050 | "4210001\n",
1051 | "4220001\n",
1052 | "4230001\n",
1053 | "4240001\n",
1054 | "4250001\n",
1055 | "4260001\n",
1056 | "4270001\n",
1057 | "4280001\n",
1058 | "4290001\n",
1059 | "4300001\n",
1060 | "4310001\n",
1061 | "4320001\n",
1062 | "4330001\n",
1063 | "4340001\n",
1064 | "4350001\n",
1065 | "4360001\n",
1066 | "4370001\n",
1067 | "4380001\n",
1068 | "4390001\n",
1069 | "4400001\n",
1070 | "4410001\n",
1071 | "4420001\n",
1072 | "4430001\n",
1073 | "4440001\n",
1074 | "4450001\n",
1075 | "4460001\n",
1076 | "4470001\n",
1077 | "4480001\n",
1078 | "4490001\n",
1079 | "4500001\n",
1080 | "4510001\n",
1081 | "4520001\n",
1082 | "4530001\n",
1083 | "4540001\n",
1084 | "4550001\n",
1085 | "4560001\n",
1086 | "4570001\n",
1087 | "4580001\n",
1088 | "4590001\n",
1089 | "4600001\n",
1090 | "4610001\n",
1091 | "4620001\n",
1092 | "4630001\n",
1093 | "4640001\n",
1094 | "4650001\n",
1095 | "4660001\n",
1096 | "4670001\n",
1097 | "4680001\n",
1098 | "4690001\n",
1099 | "4700001\n",
1100 | "4710001\n",
1101 | "4720001\n",
1102 | "4730001\n",
1103 | "4740001\n",
1104 | "4750001\n",
1105 | "4760001\n",
1106 | "4770001\n",
1107 | "4780001\n",
1108 | "4790001\n",
1109 | "4800001\n",
1110 | "4810001\n",
1111 | "4820001\n",
1112 | "4830001\n"
1113 | ]
1114 | }
1115 | ],
1116 | "source": [
1117 | "result_file = DATA_PATH + \"submission.csv\"\n",
1118 | "with open(result_file, \"w\") as f:\n",
1119 | " f.write(\"order_id,products\\n\")\n",
1120 | " idx = 0\n",
1121 | " for line_id in test_order_index:\n",
1122 | " test_order = test_order_ix[line_id]\n",
1123 | " order_id = test_order['order_id']\n",
1124 | " user_id = test_order['user_id']\n",
1125 | " product_id_list = list(item_user_reordered_idx[user_id].index)\n",
1126 | " \n",
1127 | " final_order_list = []\n",
1128 | " for product_id in product_id_list:\n",
1129 | " if predict_y[idx]:\n",
1130 | " final_order_list.append(str(product_id))\n",
1131 | " idx += 1\n",
1132 | " \n",
1133 | " if idx % 10000 == 1:\n",
1134 | " print idx\n",
1135 | " if final_order_list:\n",
1136 | " f.write(\"%d,%s\\n\" % (order_id, \" \".join(final_order_list)))\n",
1137 | " else:\n",
1138 | " f.write(\"%d, None\\n\" % order_id)"
1139 | ]
1140 | }
1141 | ],
1142 | "metadata": {
1143 | "kernelspec": {
1144 | "display_name": "Python 2",
1145 | "language": "python",
1146 | "name": "python2"
1147 | },
1148 | "language_info": {
1149 | "codemirror_mode": {
1150 | "name": "ipython",
1151 | "version": 2
1152 | },
1153 | "file_extension": ".py",
1154 | "mimetype": "text/x-python",
1155 | "name": "python",
1156 | "nbconvert_exporter": "python",
1157 | "pygments_lexer": "ipython2",
1158 | "version": "2.7.13"
1159 | }
1160 | },
1161 | "nbformat": 4,
1162 | "nbformat_minor": 2
1163 | }
1164 |
--------------------------------------------------------------------------------
/match/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/7/14 07:16
9 | @contact: bywangqiang@foxmail.com
10 | @change_time:
11 | 1.2017/7/14 07:16
12 | """
13 |
14 | if __name__ == '__main__':
15 | pass
16 |
--------------------------------------------------------------------------------
/match/cheetsheet.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 2,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "%matplotlib inline\n",
12 | "import matplotlib\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "import pandas as pd\n",
15 | "import numpy as np\n",
16 | "import seaborn as sns\n",
17 | "pd.set_option(\"display.max_columns\",101)\n",
18 | "RANDOM_STATE = 42"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": null,
24 | "metadata": {
25 | "collapsed": true
26 | },
27 | "outputs": [],
28 | "source": [
29 | ""
30 | ]
31 | }
32 | ],
33 | "metadata": {
34 | "kernelspec": {
35 | "display_name": "Python 2",
36 | "language": "python",
37 | "name": "python2"
38 | },
39 | "language_info": {
40 | "codemirror_mode": {
41 | "name": "ipython",
42 | "version": 2.0
43 | },
44 | "file_extension": ".py",
45 | "mimetype": "text/x-python",
46 | "name": "python",
47 | "nbconvert_exporter": "python",
48 | "pygments_lexer": "ipython2",
49 | "version": "2.7.11"
50 | }
51 | },
52 | "nbformat": 4,
53 | "nbformat_minor": 0
54 | }
--------------------------------------------------------------------------------
/match/dstool/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/7/13 20:47
9 | @change_time:
10 | 1.2017/7/13 20:47
11 | """
12 |
13 | if __name__ == '__main__':
14 | pass
15 |
--------------------------------------------------------------------------------
/match/instacart/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/7/11 20:26
9 | @change_time:
10 | 1.2017/7/11 20:26
11 | """
12 |
13 | if __name__ == '__main__':
14 | pass
15 |
--------------------------------------------------------------------------------
/match/instacart/const.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: const.py
8 | @time: 2017/7/14 07:16
9 | @contact: bywangqiang@foxmail.com
10 | @change_time:
11 | 1.2017/7/14 07:16
12 | """
13 |
14 | RANDOM_STATE = 42
15 |
16 | SAMPLE_TRAIN_PATH = "data/sample_train.csv"
17 | EXTEND_PRIOR_PATH = "data/extend_prior_order.csv"
18 | EXTEND_TRAIN_PATH = "data/extend_train_order.csv"
19 | EXTEND_NEGATIVE_TRAIN_PATH = "data/extend_negative_train.csv"
20 |
21 | NEGATIVE_TRAIN_DATA = "data/negative_train.csv"
22 | TOTAL_TRAIN_DATA = "data/total_train.csv"
23 |
24 |
25 | TRAIN_ORDERS_PATH = "../../data/instacart/order_products__train.csv"
26 | AISLES_PATH = "../../data/instacart/aisles.csv"
27 | DEPARTMENTS = "../../data/instacart/departments.csv"
28 | PRIOR_ORDERS_PATH = "../../data/instacart/order_products__prior.csv"
29 | ORDERS_PATH = "../../data/instacart/orders.csv"
30 | PRODUCTS_PATH = "../../data/instacart/products.csv"
31 |
32 |
33 | SAMPLE_TRAIN_FC_PATH = "data/sample_train_fc.csv"
34 |
35 |
36 | UID = 'user_id'
37 | PID = 'product_id'
38 | OID = 'order_id'
39 | AID = 'aisle_id'
40 | DID = 'department_id'
41 | ORDER_NUM = "order_number"
42 | ADD_TO_CART_NUM = "add_to_cart_order"
43 | REORDER = "reordered"
44 | ORDER_DOW = "order_dow"
45 | ORDER_HOUR_OF_DAY = "order_hour_of_day"
46 | DAYS_SINCE_PRIOR_ORDER = "days_since_prior_order"
47 | LABEL = "label"
48 |
--------------------------------------------------------------------------------
/match/instacart/data/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/7/17 21:53
9 | @contact: bywangqiang@foxmail.com
10 | @change_time:
11 | 1.2017/7/17 21:53
12 | """
13 |
14 | if __name__ == '__main__':
15 | pass
16 |
--------------------------------------------------------------------------------
/match/instacart/eda.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: eda.py
8 | @time: 2017/7/13 22:46
9 | @contact: bywangqiang@foxmail.com
10 | @change_time:
11 |     1.2017/7/13 22:46 common EDA analysis helpers
12 | """
13 |
14 |
15 | def unique_analyse():
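16 |     # placeholder: a typical first pass here is per-column cardinality,
17 |     # e.g. df.nunique() and df[col].value_counts(), before deeper analysis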
16 |
17 | return
18 |
19 |
20 | if __name__ == '__main__':
21 | pass
22 |
--------------------------------------------------------------------------------
/match/instacart/generate.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: generate.py
8 | @time: 2017/7/14 07:16
9 | @contact: bywangqiang@foxmail.com
10 | @change_time:
11 |     1.2017/7/14 07:16 initial version
12 |     2.2017/7/18 15:48 build feature_engineering
13 | """
14 | import sample
15 | import const
16 | from utils import dec_timer
17 | import pandas as pd
18 | import numpy as np
19 |
20 |
21 | def load_prior_data():
22 | df = pd.read_csv(const.EXTEND_PRIOR_PATH)
23 | return df
24 |
25 |
26 | @dec_timer
27 | def ratio(df):
28 |
29 | return df
30 |
31 |
32 | @dec_timer
33 | def diff(df):
34 |     # how many orders have passed since this product was last bought
35 | df['order_number-user_product_order_number_max'] = df['order_number'] - df['user_product_order_number_max']
36 |
37 |     # TODO: could also track how long since the last purchase from this aisle / department
38 | return df
39 |
40 |
41 | @dec_timer
42 | def build_field_feature(prior, group_id):
43 |     # TODO: how should order_dow and order_hour_of_day be handled?
44 | field_key = {
45 | const.PID: "product",
46 | const.UID: "user",
47 | const.AID: "aisle",
48 | const.DID: "departments"
49 | }[group_id]
50 |
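51 |     # NOTE: dict-of-dicts .agg() aggregates and renames columns in one pass;
52 |     #       later pandas releases deprecate this idiom in favour of named aggregation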
51 | df = prior.groupby(group_id).agg({
52 | const.UID: {"%s_bought_time" % field_key: np.size},
53 | const.PID: {"%s_item_unique_number" % field_key: pd.Series.nunique},
54 | const.OID: {"%s_order_number" % field_key: pd.Series.nunique},
55 | const.ORDER_NUM: {"%s_order_number_max" % field_key: np.max},
56 | const.ADD_TO_CART_NUM: {
57 | "%s_item_add_to_cart_mean" % field_key: np.mean,
58 | },
59 | const.REORDER: {
60 | "%s_reordered_times" % field_key: np.sum,
61 | "%s_reordered_ratio" % field_key: np.mean,
62 | },
63 | const.DAYS_SINCE_PRIOR_ORDER: {
64 | "%s_days_prior_order_mean" % field_key: np.mean,
65 |             "%s_active_days" % field_key: np.sum,  # total time span over which the user has been active
66 | }
67 |
68 | }).reset_index()
69 |
70 | df.columns = [group_id] + list(df.columns.droplevel(0))[1:]
71 | return df
72 |
73 |
74 | @dec_timer
75 | def build_interactive_feature(prior_data, key_1, key_2):
76 |
77 | group_id = [key_1, key_2]
78 | field_key = {
79 | (const.UID, const.PID): "user_product",
80 | (const.UID, const.AID): "user_aisle",
81 | (const.UID, const.DID): "user_department"
82 | }[(key_1, key_2)]
83 |
84 | df = prior_data.groupby(group_id).agg({
85 | const.OID: {"%s_bought_times" % field_key: pd.Series.nunique},
86 | const.ORDER_NUM: {
87 | "%s_order_number_max" % field_key: np.max,
88 | "%s_order_number_min" % field_key: np.min,
89 | },
90 | const.ADD_TO_CART_NUM: {
91 | "%s_item_add_to_cart_mean" % field_key: np.mean,
92 | "%s_item_add_to_cart_min" % field_key: np.min,
93 | "%s_item_add_to_cart_max" % field_key: np.max,
94 | },
95 | const.REORDER: {
96 | "%s_reordered_times" % field_key: np.sum,
97 | "%s_reordered_ratio" % field_key: np.mean,
98 | },
99 | const.DAYS_SINCE_PRIOR_ORDER: {
100 | "%s_days_prior_order_mean" % field_key: np.mean,
101 | "%s_days_prior_order_min" % field_key: np.min,
102 | "%s_days_prior_order_max" % field_key: np.max,
103 | }
104 |
105 | }).reset_index()
106 |
107 | df.columns = group_id + list(df.columns.droplevel(0))[2:]
108 | return df
109 |
110 |
111 | def feature_creation(df):
112 | """
113 |     Build feature columns on top of the raw training frame.
114 | """
115 | prior_data = load_prior_data()
116 |
117 | # ------------ 1. field feature -----------------------
118 | # 1.1 user feature
119 | user_feature = build_field_feature(prior_data, const.UID)
120 | df = df.merge(user_feature, how='left', on=[const.UID])
121 |
122 | # 1.2 product feature
123 | product_feature = build_field_feature(prior_data, const.PID)
124 | df = df.merge(product_feature, how='left', on=[const.PID])
125 |
126 |     # 1.3 order-level fields (order_dow, order_hour_of_day, ...) are already in df
127 | aisle_feature = build_field_feature(prior_data, const.AID)
128 | df = df.merge(aisle_feature, how='left', on=[const.AID])
129 |
130 | department_feature = build_field_feature(prior_data, const.DID)
131 | df = df.merge(department_feature, how='left', on=[const.DID])
132 |
133 | # ------------ 2. interactive feature -----------------
134 | # 2.1 user-product feature
135 | user_product_feature = build_interactive_feature(prior_data, const.UID, const.PID)
136 |     # user_product_feature.set_index([const.UID, const.PID])  # no-op: result discarded
137 |     # df.set_index([const.UID, const.PID])  # no-op: merge below joins on columns
138 |
139 | df = df.merge(user_product_feature, how='left', on=[const.UID, const.PID])
140 |
141 | # 2.2 user-order feature
142 | user_aisle_feature = build_interactive_feature(prior_data, const.UID, const.AID)
143 |     # user_aisle_feature.set_index([const.UID, const.AID])  # no-op: result discarded
144 |     # df.set_index([const.UID, const.AID])  # no-op: merge below joins on columns
145 |
146 | df = df.merge(user_aisle_feature, how='left', on=[const.UID, const.AID])
147 |
148 | # 2.3 product-order feature
149 | user_department_feature = build_interactive_feature(prior_data, const.UID, const.DID)
150 |     # user_department_feature.set_index([const.UID, const.DID])  # no-op: result discarded
151 |     # df.set_index([const.UID, const.DID])  # no-op: merge below joins on columns
152 |
153 | df = df.merge(user_department_feature, how='left', on=[const.UID, const.DID])
154 |
155 | # ------------ 3. experience feature ------------------
156 |
157 | # ratio
158 | df = ratio(df)
159 |
160 | # diff
161 | df = diff(df)
162 |
163 | # rank
164 |
165 | # rule
166 |
167 |     del prior_data  # free memory
168 | return df
169 |
170 |
171 | if __name__ == '__main__':
172 | raw_train_df = sample.read_sample_train(const.SAMPLE_TRAIN_PATH)
173 |     # raw_train_df.set_index([const.OID, const.UID, const.PID])  # no-op: result discarded
174 | train_feature_df = feature_creation(raw_train_df)
175 | train_feature_df.to_csv(const.SAMPLE_TRAIN_FC_PATH, index=False)
176 |
--------------------------------------------------------------------------------
/match/instacart/run.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: run.py
8 | @time: 2017/7/13 20:52
9 | @contact: bywangqiang@foxmail.com
10 | @change_time:
11 | 1.2017/7/13 20:52
12 | """
13 |
14 | if __name__ == '__main__':
15 | pass
16 |
--------------------------------------------------------------------------------
/match/instacart/sample.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: sample.py
8 | @time: 2017/7/13 20:51
9 | @contact: bywangqiang@foxmail.com
10 | @change_time:
11 | 1.2017/7/13 20:51 sample test data
12 | """
13 | import pandas as pd
14 | import const
15 | import random
16 | from utils import dec_timer
17 | import numpy as np
18 |
19 | random.seed(const.RANDOM_STATE)
20 |
21 |
22 | def sample_train(from_path, sample_path, key, percentage=0.1):
23 | """
24 | sample a dataset according to keys
25 | """
26 | from_df = pd.read_csv(from_path)
27 | order_id_list = list(from_df[key].value_counts().index)
28 | sample_order_id_set = set(random.sample(order_id_list, int(len(order_id_list) * percentage)))
29 |
30 | order_id_data_list = from_df[key].values
31 |     sample_index_flag = [value in sample_order_id_set for value in order_id_data_list]
37 |
38 | sample_df = from_df[sample_index_flag]
39 | sample_df.to_csv(sample_path, index=False)
40 | return
41 |
42 |
43 | @dec_timer
44 | def extend_data(path_a, save_path, path_b_list, key_list):
45 | a_df = pd.read_csv(path_a)
46 | for idx, path_b in enumerate(path_b_list):
47 | key = key_list[idx]
48 |
49 | b_df = pd.read_csv(path_b)
50 | a_df = pd.merge(a_df, b_df, on=key, how='left')
51 | a_df.to_csv(save_path, index=False)
52 | return
53 |
54 |
55 | def read_sample_train(from_path, nrows=None):
56 | return pd.read_csv(from_path, nrows=nrows)
57 |
58 |
59 | def make_user_product_list():
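60 |     # Build negative samples: for each train order, collect every product the
61 |     # user bought in prior orders but not in that train order, and write the
62 |     # resulting (order_id, product_id) pairs out as label-0 candidates.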
60 | prior_order = pd.read_csv(const.EXTEND_PRIOR_PATH)
61 | prior_user_product_list_df = prior_order[prior_order.eval_set == 'prior'].groupby(const.UID)[const.PID].apply(
62 | set).reset_index().set_index(const.UID)
63 |
64 | del prior_order
65 |
66 | train_df = pd.read_csv(const.EXTEND_TRAIN_PATH)
67 | train_user_product_list_df = train_df[train_df.eval_set == 'train'].groupby([const.OID, const.UID])[
68 | const.PID].apply(
69 | set).reset_index().set_index([const.OID, const.UID])
70 |
71 | del train_df
72 |
73 | order_user_product_list = {}
74 | cnt = 0
75 | for order_user_id, train_user_product_list in train_user_product_list_df.iterrows():
76 | order_id, user_id = order_user_id
77 | if cnt % 10000 == 1:
78 | print cnt
79 | prior_order_product_list = prior_user_product_list_df.ix[user_id][const.PID]
80 |
81 | not_in_clude_in_train_product_list = prior_order_product_list - train_user_product_list[const.PID]
82 | order_user_product_list[order_user_id] = not_in_clude_in_train_product_list
83 |
84 | cnt += 1
85 |
86 | total_len = sum([len(value) for value in order_user_product_list.itervalues()])
87 | output_data = np.zeros((total_len, 2))
88 | cnt = 0
89 | for order_user_id, product_list in order_user_product_list.iteritems():
90 | order_id, user_id = order_user_id
91 | for product_id in product_list:
92 | output_data[cnt, 0] = order_id
93 | output_data[cnt, 1] = product_id
94 | cnt += 1
95 |
96 | new_dataframe = pd.DataFrame(output_data, columns=[const.OID, const.PID], dtype="int")
97 | new_dataframe.to_csv(const.NEGATIVE_TRAIN_DATA, index=False)
98 | return
99 |
100 |
101 | def merge_train_negative_data(train_path, negative_path, total_train):
102 | train_df = pd.read_csv(train_path)
103 | train_df[const.LABEL] = 1
104 | negative_df = pd.read_csv(negative_path)
105 | negative_df[const.LABEL] = 0
106 | final_train_data = pd.concat([train_df, negative_df])
107 | final_train_data.to_csv(total_train, index=False)
108 | pass
109 |
110 |
111 | if __name__ == '__main__':
112 | # extend_data(const.PRIOR_ORDERS_PATH, const.EXTEND_PRIOR_PATH, [const.ORDERS_PATH, const.PRODUCTS_PATH], [const.OID, const.PID])
113 | # extend_data(const.TRAIN_ORDERS_PATH, const.EXTEND_TRAIN_PATH, [const.ORDERS_PATH, const.PRODUCTS_PATH], [const.OID, const.PID])
114 |
115 | #
116 |     # # build the set of products each user bought before, create negative examples, and generate sample_train_failure_path
117 | # make_user_product_list()
118 | # extend_data(const.NEGATIVE_TRAIN_DATA, const.EXTEND_NEGATIVE_TRAIN_PATH, [const.ORDERS_PATH, const.PRODUCTS_PATH], [const.OID, const.PID])
119 |
120 | # merge train and negative data
121 | # merge_train_negative_data(const.EXTEND_TRAIN_PATH, const.EXTEND_NEGATIVE_TRAIN_PATH, const.TOTAL_TRAIN_DATA)
122 |
123 | raw_train_orders_path = const.TOTAL_TRAIN_DATA
124 | sample_path = const.SAMPLE_TRAIN_PATH
125 | sample_train(raw_train_orders_path, sample_path, const.OID, percentage=0.1)
126 |
--------------------------------------------------------------------------------
/match/instacart/tunning.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: tunning.py
8 | @time: 2017/7/14 07:16
9 | @contact: bywangqiang@foxmail.com
10 | @change_time:
11 | 1.2017/7/14 07:16
12 |     2.2017/7/18 22:11 tune LightGBM parameters via coordinate descent
13 | """
14 | import lightgbm as lgb
15 | import pandas as pd
16 | import const
17 | import numpy as np
18 |
19 | import random
20 | random.seed(const.RANDOM_STATE)
21 | from sklearn import metrics
22 |
23 |
24 | def split_train_test(data, key, percentage=0.2):
25 | """
26 |     Split the data into train and test sets by the given key.
27 | """
28 | key_id_list = list(data[key].value_counts().index)
29 | sample_key_set = set(random.sample(key_id_list, int(len(key_id_list) * percentage)))
30 |
31 | key_id_data_list = data[key].values
32 |     sample_index_flag = [value in sample_key_set for value in key_id_data_list]
38 |
39 | test_df = data[sample_index_flag]
40 | train_df = data[[not x for x in sample_index_flag]]
41 | return train_df, test_df
42 |
43 |
44 | def training_model(train_data, key):
45 | params = {
46 | 'boosting_type': 'gbdt',
47 | 'objective': 'binary',
48 | 'num_leaves': 72,
49 | 'max_depth': 10,
50 | 'feature_fraction': 0.85,
51 | 'bagging_fraction': 0.95,
52 | 'bagging_freq': 5,
53 | 'learning_rate': 0.1,
54 | 'min_child_samples': 50,
55 | 'reg_lambda': 0.7,
56 | 'n_estimators': 50,
57 | 'silent': True,
58 | 'metric': ['auc', ]
59 | }
60 |
61 | train_df = train_data.drop([key], axis=1)
62 | train_labels = train_data[key]
63 | print('light GBM train :-)')
64 | clf_list = []
65 | clf = lgb.LGBMClassifier(**params)
66 | categorical_feature = ['order_dow', 'order_hour_of_day', ]
67 | clf.fit(
68 | train_df,
69 | train_labels,
70 | categorical_feature=categorical_feature,
71 | )
72 | clf_list.append(clf)
73 | return clf_list
74 |
75 |
76 | def mean_f1_score(order_id_index, predict_y):
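77 |     # Per-order mean F1: score each order's basket separately and average,
78 |     # mirroring the competition metric; an empty basket predicted empty
79 |     # counts as F1 = 1 by convention.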
77 | f_score_total = 0
78 | n = 0
79 | idx = 0
80 | for _, value in order_id_index.iterrows():
81 | label_list = value['label']
82 | predict_label_list = predict_y[idx:idx + len(label_list)]
83 | if not np.any(label_list) and not np.any(predict_label_list):
84 | f_score = 1
85 | else:
86 |             f_score = metrics.f1_score(label_list, predict_label_list)  # (y_true, y_pred)
87 | f_score_total += f_score
88 | idx += len(label_list)
89 | n += 1
90 |
91 | return f_score_total / n
92 |
93 |
94 | def cal_score(clf, train_validate_data, key):
95 | validate_df = train_validate_data.drop([key], axis=1)
96 | validate_labels = train_validate_data['label']
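97 |     # NOTE: order_id was dropped from train_validate_data, so the grouping
98 |     # below reads the module-level n_test built in __main__; the loop then
99 |     # sweeps thresholds 0.01..0.99 and keeps the best mean-F1 cut-off.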
97 | order_id_index = n_test.groupby(const.OID)[key].apply(list).reset_index()
98 | predict_y = clf.predict_proba(validate_df)[:, 1]
99 |
100 | best_score = 0
101 | best_margin = 0
102 |
103 | for margin in range(1, 100):
104 | threshold = margin / float(100)
105 | predict_result = predict_y >= threshold
106 | # score = metrics.f1_score(validate_labels.values, predict_result)
107 | score = mean_f1_score(order_id_index, predict_result)
108 | if score > best_score:
109 | best_score = score
110 | best_margin = threshold
111 |
112 | return best_score, best_margin
113 |
114 |
115 | def get_sample_negative_data(data, sample_percentage):
116 | order_id_index = data.groupby(const.OID)[const.PID].apply(list).reset_index()
117 | result = data.set_index([const.OID, const.PID])
118 | sample_index_list = []
119 |     for _, value in order_id_index.iterrows():
120 | order_id = value['order_id']
121 | product_id_list = value['product_id']
122 | num = int(len(product_id_list) * sample_percentage)
123 | if num <= 0:
124 | continue
125 |
126 | product_sample = random.sample(product_id_list, num)
127 | sample_index_list.extend([(order_id, product_id) for product_id in product_sample])
128 |
129 | new_data = result.ix[sample_index_list]
130 | new_data.reset_index(inplace=True)
131 | return new_data
132 |
133 |
134 | if __name__ == '__main__':
135 | total_train = pd.read_csv(const.SAMPLE_TRAIN_FC_PATH)
136 | num_data = total_train.shape[0]
137 |
138 | n_train, n_test = split_train_test(total_train, const.OID, 0.2)
139 | negative_data = n_train[n_train.label == 0]
140 | positive_data = n_train[n_train.label == 1]
141 |
142 | num_positive = positive_data.shape[0]
143 | num_negative = negative_data.shape[0]
144 |
145 | target_num = 2 * num_positive
146 | sample_negative_data = get_sample_negative_data(negative_data, target_num / float(num_negative))
147 | train_data = pd.concat([sample_negative_data, positive_data])
148 |
149 | drop_list = ['order_id', 'product_id', 'reordered', 'add_to_cart_order', 'user_id', 'eval_set', 'order_number', 'P']
150 | train_data.drop(drop_list, axis=1, inplace=True)
151 | train_validate_data = n_test.drop(drop_list, axis=1)
152 |
153 | clf_list = training_model(train_data, 'label')
154 |
155 | for clf in clf_list:
156 | result = cal_score(clf, train_validate_data, 'label')
157 | print "result", result
158 |
--------------------------------------------------------------------------------
/match/instacart/utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: utils.py
8 | @time: 2017/7/13 20:48
9 | @change_time:
10 | 1.2017/7/13 20:48
11 | """
12 | import datetime
13 | import time
14 |
15 |
16 | def dec_timer(func):
17 | def _wrapper(*args, **kwargs):
18 | start_time = datetime.datetime.now()
19 | start = time.time()
20 | ret = func(*args, **kwargs)
21 |
22 | end_time = datetime.datetime.now()
23 | print '_'*70
24 |         print '{} takes {:.2f}s, from {} to {}'.format(func.__name__, time.time() - start, start_time, end_time)
25 | print '_'*70
26 | # print 'end at', datetime.datetime.now()
27 | return ret
28 |
29 | return _wrapper
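30 |
31 |
32 | if __name__ == '__main__':
33 |     # minimal usage sketch: dec_timer wraps a function and prints its runtime
34 |     @dec_timer
35 |     def _demo(n):
36 |         return sum(xrange(n))
37 |
38 |     _demo(10 ** 6)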
--------------------------------------------------------------------------------
/match/instcart_predict.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "%matplotlib inline\n",
12 | "import matplotlib\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "import pandas as pd\n",
15 | "import numpy as np\n",
16 | "import seaborn as sns\n",
17 | "import scipy.sparse\n",
18 | "pd.set_option(\"display.max_columns\",101)\n",
19 | "RANDOM_STATE = 42\n",
20 | "DATA_PATH = \"../data/instacart/\""
21 | ]
22 | },
23 | {
24 | "cell_type": "code",
25 | "execution_count": 2,
26 | "metadata": {
27 | "collapsed": true
28 | },
29 | "outputs": [],
30 | "source": [
31 | "positive_X = pd.read_csv(DATA_PATH + \"positive_X.csv\")\n",
32 | "negative_X = pd.read_csv(DATA_PATH + \"negative_X.csv\")\n",
33 | "positive_X.drop(['Unnamed: 0', ], axis=1, inplace=True)\n",
34 | "negative_X.drop(['Unnamed: 0', ], axis=1, inplace=True)"
35 | ]
36 | },
37 | {
38 | "cell_type": "code",
39 | "execution_count": 3,
40 | "metadata": {
41 | "collapsed": true
42 | },
43 | "outputs": [],
44 | "source": [
45 | "positive_X.fillna(0, inplace=True)\n",
46 | "negative_X.fillna(0, inplace=True)"
47 | ]
48 | },
49 | {
50 | "cell_type": "code",
51 | "execution_count": 4,
52 | "metadata": {
53 | "collapsed": true
54 | },
55 | "outputs": [],
56 | "source": [
57 | "from sklearn.preprocessing import StandardScaler\n",
58 | "pX = StandardScaler().fit_transform(positive_X)"
59 | ]
60 | },
61 | {
62 | "cell_type": "code",
63 | "execution_count": 5,
64 | "metadata": {
65 | "collapsed": true
66 | },
67 | "outputs": [],
68 | "source": [
69 | "nX = StandardScaler().fit_transform(negative_X)"
70 | ]
71 | },
72 | {
73 | "cell_type": "code",
74 | "execution_count": 6,
75 | "metadata": {
76 | "collapsed": true
77 | },
78 | "outputs": [],
79 | "source": [
80 | "py = np.ones(positive_X.shape[0])\n",
81 | "ny = np.zeros(negative_X.shape[0])"
82 | ]
83 | },
84 | {
85 | "cell_type": "code",
86 | "execution_count": 7,
87 | "metadata": {
88 | "collapsed": true
89 | },
90 | "outputs": [],
91 | "source": [
92 | "m, n = nX.shape"
93 | ]
94 | },
95 | {
96 | "cell_type": "code",
97 | "execution_count": 8,
98 | "metadata": {
99 | "collapsed": false
100 | },
101 | "outputs": [
102 | {
103 | "name": "stdout",
104 | "output_type": "stream",
105 | "text": [
106 | "1384617 73\n"
107 | ]
108 | }
109 | ],
110 | "source": [
111 | "print m, n"
112 | ]
113 | },
114 | {
115 | "cell_type": "code",
116 | "execution_count": 115,
117 | "metadata": {
118 | "collapsed": false
119 | },
120 | "outputs": [],
121 | "source": [
122 | "import random\n",
123 | "percentage = 0.1\n",
124 | "sample_m = int(m * percentage)\n",
125 | "sample_list = random.sample(range(m), sample_m)\n",
126 | "X = np.concatenate([pX[sample_list, :], nX[sample_list, :]])\n",
127 | "y = np.concatenate([py[sample_list], ny[sample_list]])"
128 | ]
129 | },
130 | {
131 | "cell_type": "code",
132 | "execution_count": 118,
133 | "metadata": {
134 | "collapsed": true
135 | },
136 | "outputs": [],
137 | "source": [
138 | "test_sample_list = random.sample(range(m), sample_m)\n",
139 | "test_X = np.concatenate([pX[test_sample_list, :], nX[test_sample_list, :]])\n",
140 | "test_y = np.concatenate([py[test_sample_list], ny[test_sample_list]])"
141 | ]
142 | },
143 | {
144 | "cell_type": "markdown",
145 | "metadata": {},
146 | "source": [
147 | "#### The model is Low bias"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 121,
153 | "metadata": {
154 | "collapsed": false,
155 | "scrolled": true
156 | },
157 | "outputs": [
158 | {
159 | "data": {
160 | "text/plain": [
161 | "DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=10,\n",
162 | " max_features=None, max_leaf_nodes=None,\n",
163 | " min_impurity_split=1e-07, min_samples_leaf=10,\n",
164 | " min_samples_split=10, min_weight_fraction_leaf=0.0,\n",
165 | " presort=False, random_state=None, splitter='best')"
166 | ]
167 | },
168 | "execution_count": 121,
169 | "metadata": {},
170 | "output_type": "execute_result"
171 | }
172 | ],
173 | "source": [
174 | "from sklearn.tree import DecisionTreeClassifier\n",
175 | "clf = DecisionTreeClassifier(max_depth=10, min_samples_leaf=10, min_samples_split=10)\n",
176 | "clf.fit(X, y)"
177 | ]
178 | },
179 | {
180 | "cell_type": "code",
181 | "execution_count": 122,
182 | "metadata": {
183 | "collapsed": false
184 | },
185 | "outputs": [
186 | {
187 | "data": {
188 | "text/plain": [
189 | "0.99760221289749462"
190 | ]
191 | },
192 | "execution_count": 122,
193 | "metadata": {},
194 | "output_type": "execute_result"
195 | }
196 | ],
197 | "source": [
198 | "clf.score(X, y)"
199 | ]
200 | },
201 | {
202 | "cell_type": "code",
203 | "execution_count": 123,
204 | "metadata": {
205 | "collapsed": false
206 | },
207 | "outputs": [
208 | {
209 | "data": {
210 | "text/plain": [
211 | "0.99758415727172267"
212 | ]
213 | },
214 | "execution_count": 123,
215 | "metadata": {},
216 | "output_type": "execute_result"
217 | }
218 | ],
219 | "source": [
220 | "clf.score(test_X, test_y)"
221 | ]
222 | },
223 | {
224 | "cell_type": "markdown",
225 | "metadata": {},
226 | "source": [
227 | "#### Try add test item into data"
228 | ]
229 | },
230 | {
231 | "cell_type": "code",
232 | "execution_count": 27,
233 | "metadata": {
234 | "collapsed": true
235 | },
236 | "outputs": [],
237 | "source": [
238 | "orders = pd.read_csv(DATA_PATH + \"orders.csv\")"
239 | ]
240 | },
241 | {
242 | "cell_type": "code",
243 | "execution_count": 28,
244 | "metadata": {
245 | "collapsed": true
246 | },
247 | "outputs": [],
248 | "source": [
249 | "test_order = orders.loc[orders.eval_set == 'test']"
250 | ]
251 | },
252 | {
253 | "cell_type": "code",
254 | "execution_count": 22,
255 | "metadata": {
256 | "collapsed": false
257 | },
258 | "outputs": [
259 | {
260 | "data": {
261 | "text/plain": [
262 | "(75000, 7)"
263 | ]
264 | },
265 | "execution_count": 22,
266 | "metadata": {},
267 | "output_type": "execute_result"
268 | }
269 | ],
270 | "source": [
271 | "test_order.shape"
272 | ]
273 | },
274 | {
275 | "cell_type": "code",
276 | "execution_count": 23,
277 | "metadata": {
278 | "collapsed": false
279 | },
280 | "outputs": [
281 | {
282 | "data": {
283 | "text/html": [
284 | "\n",
285 | "
\n",
286 | " \n",
287 | " \n",
288 | " | \n",
289 | " order_id | \n",
290 | " user_id | \n",
291 | " eval_set | \n",
292 | " order_number | \n",
293 | " order_dow | \n",
294 | " order_hour_of_day | \n",
295 | " days_since_prior_order | \n",
296 | "
\n",
297 | " \n",
298 | " \n",
299 | " \n",
300 | " 38 | \n",
301 | " 2774568 | \n",
302 | " 3 | \n",
303 | " test | \n",
304 | " 13 | \n",
305 | " 5 | \n",
306 | " 15 | \n",
307 | " 11.0 | \n",
308 | "
\n",
309 | " \n",
310 | " 44 | \n",
311 | " 329954 | \n",
312 | " 4 | \n",
313 | " test | \n",
314 | " 6 | \n",
315 | " 3 | \n",
316 | " 12 | \n",
317 | " 30.0 | \n",
318 | "
\n",
319 | " \n",
320 | " 53 | \n",
321 | " 1528013 | \n",
322 | " 6 | \n",
323 | " test | \n",
324 | " 4 | \n",
325 | " 3 | \n",
326 | " 16 | \n",
327 | " 22.0 | \n",
328 | "
\n",
329 | " \n",
330 | " 96 | \n",
331 | " 1376945 | \n",
332 | " 11 | \n",
333 | " test | \n",
334 | " 8 | \n",
335 | " 6 | \n",
336 | " 11 | \n",
337 | " 8.0 | \n",
338 | "
\n",
339 | " \n",
340 | " 102 | \n",
341 | " 1356845 | \n",
342 | " 12 | \n",
343 | " test | \n",
344 | " 6 | \n",
345 | " 1 | \n",
346 | " 20 | \n",
347 | " 30.0 | \n",
348 | "
\n",
349 | " \n",
350 | "
\n",
351 | "
"
352 | ],
353 | "text/plain": [
354 | " order_id user_id eval_set order_number order_dow order_hour_of_day \\\n",
355 | "38 2774568 3 test 13 5 15 \n",
356 | "44 329954 4 test 6 3 12 \n",
357 | "53 1528013 6 test 4 3 16 \n",
358 | "96 1376945 11 test 8 6 11 \n",
359 | "102 1356845 12 test 6 1 20 \n",
360 | "\n",
361 | " days_since_prior_order \n",
362 | "38 11.0 \n",
363 | "44 30.0 \n",
364 | "53 22.0 \n",
365 | "96 8.0 \n",
366 | "102 30.0 "
367 | ]
368 | },
369 | "execution_count": 23,
370 | "metadata": {},
371 | "output_type": "execute_result"
372 | }
373 | ],
374 | "source": [
375 | "test_order.head()"
376 | ]
377 | },
378 | {
379 | "cell_type": "markdown",
380 | "metadata": {},
381 | "source": [
382 | "#### product_list of user_id "
383 | ]
384 | },
385 | {
386 | "cell_type": "code",
387 | "execution_count": 29,
388 | "metadata": {
389 | "collapsed": true
390 | },
391 | "outputs": [],
392 | "source": [
393 | "test_user_id_list = np.unique(test_order['user_id'])\n",
394 | "test_order_ix = test_order.ix\n",
395 | "test_order_index = list(test_order.index)"
396 | ]
397 | },
398 | {
399 | "cell_type": "code",
400 | "execution_count": 57,
401 | "metadata": {
402 | "collapsed": false,
403 | "scrolled": true
404 | },
405 | "outputs": [
406 | {
407 | "data": {
408 | "text/plain": [
409 | "order_id 2774568\n",
410 | "user_id 3\n",
411 | "eval_set test\n",
412 | "order_number 13\n",
413 | "order_dow 5\n",
414 | "order_hour_of_day 15\n",
415 | "days_since_prior_order 11\n",
416 | "Name: 38, dtype: object"
417 | ]
418 | },
419 | "execution_count": 57,
420 | "metadata": {},
421 | "output_type": "execute_result"
422 | }
423 | ],
424 | "source": [
425 | "test_order_ix[38]"
426 | ]
427 | },
428 | {
429 | "cell_type": "code",
430 | "execution_count": 60,
431 | "metadata": {
432 | "collapsed": true
433 | },
434 | "outputs": [],
435 | "source": [
436 | "predict_X = pd.read_csv(DATA_PATH + \"test_X.csv\")"
437 | ]
438 | },
439 | {
440 | "cell_type": "code",
441 | "execution_count": 61,
442 | "metadata": {
443 | "collapsed": true
444 | },
445 | "outputs": [],
446 | "source": [
447 | "predict_X.drop(['Unnamed: 0', ], axis=1, inplace=True)\n",
448 | "predict_X.fillna(0, inplace=True)\n",
449 | "from sklearn.preprocessing import StandardScaler\n",
450 | "std_predict_X = StandardScaler().fit_transform(predict_X)"
451 | ]
452 | },
453 | {
454 | "cell_type": "code",
455 | "execution_count": 124,
456 | "metadata": {
457 | "collapsed": false
458 | },
459 | "outputs": [],
460 | "source": [
461 | "predict_y = clf.predict(std_predict_X)"
462 | ]
463 | },
464 | {
465 | "cell_type": "code",
466 | "execution_count": 63,
467 | "metadata": {
468 | "collapsed": false
469 | },
470 | "outputs": [
471 | {
472 | "data": {
473 | "text/plain": [
474 | "(4833292,)"
475 | ]
476 | },
477 | "execution_count": 63,
478 | "metadata": {},
479 | "output_type": "execute_result"
480 | }
481 | ],
482 | "source": [
483 | "predict_y.shape"
484 | ]
485 | },
486 | {
487 | "cell_type": "code",
488 | "execution_count": 125,
489 | "metadata": {
490 | "collapsed": false,
491 | "scrolled": true
492 | },
493 | "outputs": [
494 | {
495 | "data": {
496 | "text/plain": [
497 | "Counter({0.0: 4741879, 1.0: 91413})"
498 | ]
499 | },
500 | "execution_count": 125,
501 | "metadata": {},
502 | "output_type": "execute_result"
503 | }
504 | ],
505 | "source": [
506 | "Counter(predict_y)"
507 | ]
508 | },
509 | {
510 | "cell_type": "code",
511 | "execution_count": null,
512 | "metadata": {
513 | "collapsed": true
514 | },
515 | "outputs": [],
516 | "source": []
517 | },
518 | {
519 | "cell_type": "code",
520 | "execution_count": 72,
521 | "metadata": {
522 | "collapsed": false
523 | },
524 | "outputs": [
525 | {
526 | "data": {
527 | "text/plain": [
528 | "Counter({0.0: 4751, 1.0: 249})"
529 | ]
530 | },
531 | "execution_count": 72,
532 | "metadata": {},
533 | "output_type": "execute_result"
534 | }
535 | ],
536 | "source": [
537 | "Counter(predict_y[0:5000])"
538 | ]
539 | },
540 | {
541 | "cell_type": "code",
542 | "execution_count": 34,
543 | "metadata": {
544 | "collapsed": true
545 | },
546 | "outputs": [],
547 | "source": [
548 | "prior = pd.read_csv(DATA_PATH + \"order_products__prior.csv\")\n",
549 | "prior_order = pd.merge(prior, orders, on='order_id')\n",
550 | "item_user_reordered = pd.DataFrame(prior_order.groupby(['user_id', 'product_id']).agg({'reordered': np.sum}))"
551 | ]
552 | },
553 | {
554 | "cell_type": "code",
555 | "execution_count": 35,
556 | "metadata": {
557 | "collapsed": true
558 | },
559 | "outputs": [],
560 | "source": [
561 | "item_user_reordered_idx = item_user_reordered.ix"
562 | ]
563 | },
564 | {
565 | "cell_type": "code",
566 | "execution_count": 126,
567 | "metadata": {
568 | "collapsed": false,
569 | "scrolled": true
570 | },
571 | "outputs": [
572 | {
573 | "name": "stdout",
574 | "output_type": "stream",
575 | "text": [
576 | "1\n",
577 | "10001\n",
578 | "20001\n",
579 | "30001\n",
580 | "40001\n",
581 | "50001\n",
582 | "60001\n",
583 | "70001\n",
584 | "80001\n",
585 | "90001\n",
586 | "100001\n",
587 | "110001\n",
588 | "120001\n",
589 | "130001\n",
590 | "140001\n",
591 | "150001\n",
592 | "160001\n",
593 | "170001\n",
594 | "180001\n",
595 | "190001\n",
596 | "200001\n",
597 | "210001\n",
598 | "220001\n",
599 | "230001\n",
600 | "240001\n",
601 | "250001\n",
602 | "260001\n",
603 | "270001\n",
604 | "280001\n",
605 | "290001\n",
606 | "300001\n",
607 | "310001\n",
608 | "320001\n",
609 | "330001\n",
610 | "340001\n",
611 | "350001\n",
612 | "360001\n",
613 | "370001\n",
614 | "380001\n",
615 | "390001\n",
616 | "400001\n",
617 | "410001\n",
618 | "420001\n",
619 | "430001\n",
620 | "440001\n",
621 | "450001\n",
622 | "460001\n",
623 | "470001\n",
624 | "480001\n",
625 | "490001\n",
626 | "500001\n",
627 | "510001\n",
628 | "520001\n",
629 | "530001\n",
630 | "540001\n",
631 | "550001\n",
632 | "560001\n",
633 | "570001\n",
634 | "580001\n",
635 | "590001\n",
636 | "600001\n",
637 | "610001\n",
638 | "620001\n",
639 | "630001\n",
640 | "640001\n",
641 | "650001\n",
642 | "660001\n",
643 | "670001\n",
644 | "680001\n",
645 | "690001\n",
646 | "700001\n",
647 | "710001\n",
648 | "720001\n",
649 | "730001\n",
650 | "740001\n",
651 | "750001\n",
652 | "760001\n",
653 | "770001\n",
654 | "780001\n",
655 | "790001\n",
656 | "800001\n",
657 | "810001\n",
658 | "820001\n",
659 | "830001\n",
660 | "840001\n",
661 | "850001\n",
662 | "860001\n",
663 | "870001\n",
664 | "880001\n",
665 | "890001\n",
666 | "900001\n",
667 | "910001\n",
668 | "920001\n",
669 | "930001\n",
670 | "940001\n",
671 | "950001\n",
672 | "960001\n",
673 | "970001\n",
674 | "980001\n",
675 | "990001\n",
676 | "1000001\n",
677 | "1010001\n",
678 | "1020001\n",
679 | "1030001\n",
680 | "1040001\n",
681 | "1050001\n",
682 | "1060001\n",
683 | "1070001\n",
684 | "1080001\n",
685 | "1090001\n",
686 | "1100001\n",
687 | "1110001\n",
688 | "1120001\n",
689 | "1130001\n",
690 | "1140001\n",
691 | "1150001\n",
692 | "1160001\n",
693 | "1170001\n",
694 | "1180001\n",
695 | "1190001\n",
696 | "1200001\n",
697 | "1210001\n",
698 | "1220001\n",
699 | "1230001\n",
700 | "1240001\n",
701 | "1250001\n",
702 | "1260001\n",
703 | "1270001\n",
704 | "1280001\n",
705 | "1290001\n",
706 | "1300001\n",
707 | "1310001\n",
708 | "1320001\n",
709 | "1330001\n",
710 | "1340001\n",
711 | "1350001\n",
712 | "1360001\n",
713 | "1370001\n",
714 | "1380001\n",
715 | "1390001\n",
716 | "1400001\n",
717 | "1410001\n",
718 | "1420001\n",
719 | "1430001\n",
720 | "1440001\n",
721 | "1450001\n",
722 | "1460001\n",
723 | "1470001\n",
724 | "1480001\n",
725 | "1490001\n",
726 | "1500001\n",
727 | "1510001\n",
728 | "1520001\n",
729 | "1530001\n",
730 | "1540001\n",
731 | "1550001\n",
732 | "1560001\n",
733 | "1570001\n",
734 | "1580001\n",
735 | "1590001\n",
736 | "1600001\n",
737 | "1610001\n",
738 | "1620001\n",
739 | "1630001\n",
740 | "1640001\n",
741 | "1650001\n",
742 | "1660001\n",
743 | "1670001\n",
744 | "1680001\n",
745 | "1690001\n",
746 | "1700001\n",
747 | "1710001\n",
748 | "1720001\n",
749 | "1730001\n",
750 | "1740001\n",
751 | "1750001\n",
752 | "1760001\n",
753 | "1770001\n",
754 | "1780001\n",
755 | "1790001\n",
756 | "1800001\n",
757 | "1810001\n",
758 | "1820001\n",
759 | "1830001\n",
760 | "1840001\n",
761 | "1850001\n",
762 | "1860001\n",
763 | "1870001\n",
764 | "1880001\n",
765 | "1890001\n",
766 | "1900001\n",
767 | "1910001\n",
768 | "1920001\n",
769 | "1930001\n",
770 | "1940001\n",
771 | "1950001\n",
772 | "1960001\n",
773 | "1970001\n",
774 | "1980001\n",
775 | "1990001\n",
776 | "2000001\n",
777 | "2010001\n",
778 | "2020001\n",
779 | "2030001\n",
780 | "2040001\n",
781 | "2050001\n",
782 | "2060001\n",
783 | "2070001\n",
784 | "2080001\n",
785 | "2090001\n",
786 | "2100001\n",
787 | "2110001\n",
788 | "2120001\n",
789 | "2130001\n",
790 | "2140001\n",
791 | "2150001\n",
792 | "2160001\n",
793 | "2170001\n",
794 | "2180001\n",
795 | "2190001\n",
796 | "2200001\n",
797 | "2210001\n",
798 | "2220001\n",
799 | "2230001\n",
800 | "2240001\n",
801 | "2250001\n",
802 | "2260001\n",
803 | "2270001\n",
804 | "2280001\n",
805 | "2290001\n",
806 | "2300001\n",
807 | "2310001\n",
808 | "2320001\n",
809 | "2330001\n",
810 | "2340001\n",
811 | "2350001\n",
812 | "2360001\n",
813 | "2370001\n",
814 | "2380001\n",
815 | "2390001\n",
816 | "2400001\n",
817 | "2410001\n",
818 | "2420001\n",
819 | "2430001\n",
820 | "2440001\n",
821 | "2450001\n",
822 | "2460001\n",
823 | "2470001\n",
824 | "2480001\n",
825 | "2490001\n",
826 | "2500001\n",
827 | "2510001\n",
828 | "2520001\n",
829 | "2530001\n",
830 | "2540001\n",
831 | "2550001\n",
832 | "2560001\n",
833 | "2570001\n",
834 | "2580001\n",
835 | "2590001\n",
836 | "2600001\n",
837 | "2610001\n",
838 | "2620001\n",
839 | "2630001\n",
840 | "2640001\n",
841 | "2650001\n",
842 | "2660001\n",
843 | "2670001\n",
844 | "2680001\n",
845 | "2690001\n",
846 | "2700001\n",
847 | "2710001\n",
848 | "2720001\n",
849 | "2730001\n",
850 | "2740001\n",
851 | "2750001\n",
852 | "2760001\n",
853 | "2770001\n",
854 | "2780001\n",
855 | "2790001\n",
856 | "2800001\n",
857 | "2810001\n",
858 | "2820001\n",
859 | "2830001\n",
860 | "2840001\n",
861 | "2850001\n",
862 | "2860001\n",
863 | "2870001\n",
864 | "2880001\n",
865 | "2890001\n",
866 | "2900001\n",
867 | "2910001\n",
868 | "2920001\n",
869 | "2930001\n",
870 | "2940001\n",
871 | "2950001\n",
872 | "2960001\n",
873 | "2970001\n",
874 | "2980001\n",
875 | "2990001\n",
876 | "3000001\n",
877 | "3010001\n",
878 | "3020001\n",
879 | "3030001\n",
880 | "3040001\n",
881 | "3050001\n",
882 | "3060001\n",
883 | "3070001\n",
884 | "3080001\n",
885 | "3090001\n",
886 | "3100001\n",
887 | "3110001\n",
888 | "3120001\n",
889 | "3130001\n",
890 | "3140001\n",
891 | "3150001\n",
892 | "3160001\n",
893 | "3170001\n",
894 | "3180001\n",
895 | "3190001\n",
896 | "3200001\n",
897 | "3210001\n",
898 | "3220001\n",
899 | "3230001\n",
900 | "3240001\n",
901 | "3250001\n",
902 | "3260001\n",
903 | "3270001\n",
904 | "3280001\n",
905 | "3290001\n",
906 | "3300001\n",
907 | "3310001\n",
908 | "3320001\n",
909 | "3330001\n",
910 | "3340001\n",
911 | "3350001\n",
912 | "3360001\n",
913 | "3370001\n",
914 | "3380001\n",
915 | "3390001\n",
916 | "3400001\n",
917 | "3410001\n",
918 | "3420001\n",
919 | "3430001\n",
920 | "3440001\n",
921 | "3450001\n",
922 | "3460001\n",
923 | "3470001\n",
924 | "3480001\n",
925 | "3490001\n",
926 | "3500001\n",
927 | "3510001\n",
928 | "3520001\n",
929 | "3530001\n",
930 | "3540001\n",
931 | "3550001\n",
932 | "3560001\n",
933 | "3570001\n",
934 | "3580001\n",
935 | "3590001\n",
936 | "3600001\n",
937 | "3610001\n",
938 | "3620001\n",
939 | "3630001\n",
940 | "3640001\n",
941 | "3650001\n",
942 | "3660001\n",
943 | "3670001\n",
944 | "3680001\n",
945 | "3690001\n",
946 | "3700001\n",
947 | "3710001\n",
948 | "3720001\n",
949 | "3730001\n",
950 | "3740001\n",
951 | "3750001\n",
952 | "3760001\n",
953 | "3770001\n",
954 | "3780001\n",
955 | "3790001\n",
956 | "3800001\n",
957 | "3810001\n",
958 | "3820001\n",
959 | "3830001\n",
960 | "3840001\n",
961 | "3850001\n",
962 | "3860001\n",
963 | "3870001\n",
964 | "3880001\n",
965 | "3890001\n",
966 | "3900001\n",
967 | "3910001\n",
968 | "3920001\n",
969 | "3930001\n",
970 | "3940001\n",
971 | "3950001\n",
972 | "3960001\n",
973 | "3970001\n",
974 | "3980001\n",
975 | "3990001\n",
976 | "4000001\n",
977 | "4010001\n",
978 | "4020001\n",
979 | "4030001\n",
980 | "4040001\n",
981 | "4050001\n",
982 | "4060001\n",
983 | "4070001\n",
984 | "4080001\n",
985 | "4090001\n",
986 | "4100001\n",
987 | "4110001\n",
988 | "4120001\n",
989 | "4130001\n",
990 | "4140001\n",
991 | "4150001\n",
992 | "4160001\n",
993 | "4170001\n",
994 | "4180001\n",
995 | "4190001\n",
996 | "4200001\n",
997 | "4210001\n",
998 | "4220001\n",
999 | "4230001\n",
1000 | "4240001\n",
1001 | "4250001\n",
1002 | "4260001\n",
1003 | "4270001\n",
1004 | "4280001\n",
1005 | "4290001\n",
1006 | "4300001\n",
1007 | "4310001\n",
1008 | "4320001\n",
1009 | "4330001\n",
1010 | "4340001\n",
1011 | "4350001\n",
1012 | "4360001\n",
1013 | "4370001\n",
1014 | "4380001\n",
1015 | "4390001\n",
1016 | "4400001\n",
1017 | "4410001\n",
1018 | "4420001\n",
1019 | "4430001\n",
1020 | "4440001\n",
1021 | "4450001\n",
1022 | "4460001\n",
1023 | "4470001\n",
1024 | "4480001\n",
1025 | "4490001\n",
1026 | "4500001\n",
1027 | "4510001\n",
1028 | "4520001\n",
1029 | "4530001\n",
1030 | "4540001\n",
1031 | "4550001\n",
1032 | "4560001\n",
1033 | "4570001\n",
1034 | "4580001\n",
1035 | "4590001\n",
1036 | "4600001\n",
1037 | "4610001\n",
1038 | "4620001\n",
1039 | "4630001\n",
1040 | "4640001\n",
1041 | "4650001\n",
1042 | "4660001\n",
1043 | "4670001\n",
1044 | "4680001\n",
1045 | "4690001\n",
1046 | "4700001\n",
1047 | "4710001\n",
1048 | "4720001\n",
1049 | "4730001\n",
1050 | "4740001\n",
1051 | "4750001\n",
1052 | "4760001\n",
1053 | "4770001\n",
1054 | "4780001\n",
1055 | "4790001\n",
1056 | "4800001\n",
1057 | "4810001\n",
1058 | "4820001\n",
1059 | "4830001\n"
1060 | ]
1061 | }
1062 | ],
1063 | "source": [
1064 | "result_file = DATA_PATH + \"submission.csv\"\n",
1065 | "with open(result_file, \"w\") as f:\n",
1066 | " f.write(\"order_id,products\\n\")\n",
1067 | " idx = 0\n",
1068 | " for line_id in test_order_index:\n",
1069 | " test_order = test_order_ix[line_id]\n",
1070 | " order_id = test_order['order_id']\n",
1071 | " user_id = test_order['user_id']\n",
1072 | " product_id_list = list(item_user_reordered_idx[user_id].index)\n",
1073 | " \n",
1074 | " final_order_list = []\n",
1075 | " for product_id in product_id_list:\n",
1076 | " if predict_y[idx]:\n",
1077 | " final_order_list.append(str(product_id))\n",
1078 | " idx += 1\n",
1079 | " \n",
1080 | " if idx % 10000 == 1:\n",
1081 | " print idx\n",
1082 | " if final_order_list:\n",
1083 | " f.write(\"%d,%s\\n\" % (order_id, \" \".join(final_order_list)))\n",
1084 | " else:\n",
1085 | " f.write(\"%d, None\\n\" % order_id)"
1086 | ]
1087 | },
1088 | {
1089 | "cell_type": "code",
1090 | "execution_count": null,
1091 | "metadata": {
1092 | "collapsed": true
1093 | },
1094 | "outputs": [],
1095 | "source": []
1096 | }
1097 | ],
1098 | "metadata": {
1099 | "kernelspec": {
1100 | "display_name": "Python 2",
1101 | "language": "python",
1102 | "name": "python2"
1103 | },
1104 | "language_info": {
1105 | "codemirror_mode": {
1106 | "name": "ipython",
1107 | "version": 2
1108 | },
1109 | "file_extension": ".py",
1110 | "mimetype": "text/x-python",
1111 | "name": "python",
1112 | "nbconvert_exporter": "python",
1113 | "pygments_lexer": "ipython2",
1114 | "version": "2.7.13"
1115 | }
1116 | },
1117 | "nbformat": 4,
1118 | "nbformat_minor": 2
1119 | }
1120 |
--------------------------------------------------------------------------------
/network/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/3/19 17:24
9 | @change_time:
10 | 1.2017/3/19 17:24
11 | """
12 |
13 | if __name__ == '__main__':
14 | pass
15 |
--------------------------------------------------------------------------------
/preproccessing/StandardScaler.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/env python
3 | # encoding: utf-8
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
9 | @file: StandardScaler.py
10 | @time: 2017/3/15 22:38
11 | @change_time:
12 | 1.2017/3/15 22:38
13 | """
14 | import numpy as np
15 | import math
16 |
17 |
18 | class StandardScaler(object):
19 | def __init__(self):
20 | return
21 |
22 | def fit_transform(self, X):
23 | """
24 |         Standardize the data row by row.
25 | """
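26 |         # NOTE: unlike sklearn's StandardScaler, which standardizes each
27 |         #       column, this scales each ROW to zero mean and unit variance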
26 | rows, cols = X.shape
27 | result = np.zeros(X.shape)
28 | for row in xrange(rows):
29 | x_mean = np.mean(X[row, :])
30 | x_var = np.var(X[row, :])
31 | result[row, :] = (X[row, :] - x_mean) / math.sqrt(x_var)
32 | return result
33 |
34 |
35 | if __name__ == '__main__':
36 | data = np.array([[1, 2, 3], [2, 2, 3]])
37 | result = StandardScaler().fit_transform(data)
38 | print result
--------------------------------------------------------------------------------
/preproccessing/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/env python
3 | # encoding: utf-8
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
9 | @file: __init__.py
10 | @time: 2017/3/15 22:36
11 | @change_time:
12 | 1.2017/3/15 22:36
13 | """
14 |
15 | if __name__ == '__main__':
16 | pass
17 |
--------------------------------------------------------------------------------
/profile/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
4 |
5 | """
6 | @version: 1.0
7 | @author: xiaoqiangkx
8 | @file: __init__.py
9 | @time: 2017/7/14 15:46
10 | @change_time:
11 | 1.2017/7/14 15:46
12 | """
13 |
14 | from profile import (
15 | enable_profile,
16 | close_profile,
17 | )
--------------------------------------------------------------------------------
/profile/profile.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
4 |
5 | """
6 | @version: 1.0
7 | @author: xiaoqiangkx
8 | @file: profile.py
9 | @time: 2017/7/14 15:58
10 | @change_time:
11 | 1.2017/7/14 15:58
12 | """
13 | PROFILE = None
14 | import time
15 |
16 |
17 | def enable_profile():
18 | global PROFILE
19 | import cProfile
20 | PROFILE = cProfile.Profile()
21 | PROFILE.enable()
22 | return
23 |
24 |
25 | def close_profile():
26 | PROFILE.disable()
27 | PROFILE.dump_stats('profile_%d.prof' % time.time())
28 | return
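29 |
30 |
31 | if __name__ == '__main__':
32 |     # minimal usage sketch: bracket any code block with the two calls;
33 |     # the stats land in a timestamped .prof file readable via pstats
34 |     enable_profile()
35 |     sum(range(10 ** 6))
36 |     close_profile()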
--------------------------------------------------------------------------------
/py_lightgbm/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/6/29 21:44
9 | @change_time:
10 | 1.2017/6/29 21:44
11 | """
12 |
13 |
14 | from application.classifier import LGBMClassifier
--------------------------------------------------------------------------------
/py_lightgbm/application/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/7/1 11:28
9 | @change_time:
10 | 1.2017/7/1 11:28
11 | """
12 |
13 | if __name__ == '__main__':
14 | pass
15 |
--------------------------------------------------------------------------------
/py_lightgbm/application/classifier.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: classifier.py
8 | @time: 2017/7/9 12:03
9 | @change_time:
10 | 1.2017/7/9 12:03 add category support
11 | """
12 | from py_lightgbm.boosting import boosting
13 | from py_lightgbm.io.dataset import Dataset
14 | from py_lightgbm.logmanager import logger
15 | from py_lightgbm.utils import const
16 | from py_lightgbm.config.tree_config import TreeConfig
17 |
18 |
19 | _LOGGER = logger.get_logger("LGBMClassifier")
20 |
21 |
22 | class LGBMClassifier(object):
23 |
24 | def __init__(self, boosting_type='gbdt', num_leaves=31, max_depth=-1,
25 | learning_rate=0.1, n_estimators=10, max_bin=const.DEFAULT_MAX_BIN,
26 | subsample_for_bin=50000, objective=None, min_split_gain=0,
27 | min_child_weight=5, min_child_samples=10, subsample=1, subsample_freq=1,
28 | colsample_bytree=1, reg_alpha=0, reg_lambda=0,
29 | seed=0, nthread=-1, silent=True, **kwargs):
30 |
31 |         self._boosting_type = boosting_type  # only 'gbdt' is supported
32 | self._boosting = boosting.Booster()
33 | self._train_data = None
34 | self._silent = silent
35 |
36 | self._tree_config = TreeConfig()
37 | self._tree_config.learning_rate = learning_rate
38 | self._tree_config.n_estimators = n_estimators
39 |
40 | self._tree_config.max_bin = max_bin
41 | self._tree_config.num_leaves = num_leaves
42 | self._tree_config.max_depth = max_depth
43 |
44 | self._tree_config.min_child_samples = min_child_samples
45 | self._tree_config.min_split_gain = min_split_gain
46 |
47 | self._tree_config.reg_alpha = reg_alpha
48 | self._tree_config.reg_lambda = reg_lambda
49 |
50 | self._bin_mappers = None
51 | self._categorical_feature = []
52 | return
53 |
54 | def fit(self, X, y, sample_weight=None, init_score=None,
55 | group=None, eval_set=None, eval_names=None, eval_sample_weight=None,
56 | eval_init_score=None, eval_group=None, eval_metric=None, early_stopping_rounds=None,
57 | verbose=True, feature_name=None, categorical_feature=None, callbacks=None):
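58 |         # training pipeline: wrap the raw data, bucket every feature into
59 |         # histogram bins, then run n_estimators boosting iterations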
58 | self._train_data = Dataset(X, y, feature_name, categorical_feature, self._tree_config)
59 | self._bin_mappers = self._train_data.create_bin_mapper(self._tree_config.max_bin)
60 | self._train_data.construct(self._bin_mappers)
61 | self._boosting.init(self._train_data, self._tree_config)
62 |
63 | for i in xrange(self._tree_config.n_estimators):
64 | _LOGGER.info("iteration-------{0}".format(i + 1))
65 | self._boosting.train_one_iter(self._train_data)
66 | return
67 |
68 | def show(self):
69 | self._boosting.show()
70 | return
71 |
72 | def print_bin_mappers(self):
73 | for bin_mapper in self._bin_mappers:
74 | print ""
75 | print "bins", bin_mapper._num_bins
76 | print "bin_upper_bound", bin_mapper._bin_upper_bound
77 | print "min", bin_mapper._min_value
78 | print "max", bin_mapper._max_value
79 | print ""
80 | return
81 |
82 | def predict(self, X, raw_score=False, num_iteration=0):
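83 |         # TODO: hard-label prediction is not implemented yet; use predict_proba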
83 | return
84 |
85 | def predict_proba(self, X, raw_score=False, num_iteration=0):
86 | return self._boosting.predict_proba(X)
87 |
88 | def score(self, X, y):
89 |
90 | return
91 |
--------------------------------------------------------------------------------
/py_lightgbm/boosting/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/7/1 11:28
9 | @change_time:
10 | 1.2017/7/1 11:28
11 | """
12 |
13 | if __name__ == '__main__':
14 | pass
15 |
--------------------------------------------------------------------------------
/py_lightgbm/boosting/boosting.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: boosting.py
8 | @time: 2017/7/1 11:29
9 | @change_time:
10 | 1.2017/7/1 11:29
11 | """
12 | from py_lightgbm.boosting.gbdt import Gbdt
13 |
14 |
15 | class Booster(object):
16 | """
17 |     Thin wrapper over the concrete boosting implementations.
18 | """
19 | def __init__(self):
20 | self._booster = Gbdt()
21 | return
22 |
23 | def init(self, train_data, tree_config):
24 | self._booster.init(train_data, tree_config)
25 | return
26 |
27 | def train_one_iter(self, train_data, gradients=None, hessians=None):
28 | self._booster.train_one_iter(train_data)
29 | return
30 |
31 | def show(self):
32 | self._booster.show()
33 | return
34 |
35 | def predict_proba(self, X):
36 | return self._booster.predict_proba(X)
37 |
38 |
39 | if __name__ == '__main__':
40 | pass
41 |
--------------------------------------------------------------------------------
/py_lightgbm/boosting/gbdt.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: gbdt.py
8 | @time: 2017/7/1 11:25
9 | @change_time:
10 | 1.2017/7/1 11:25
11 | """
12 | from py_lightgbm.tree.tree_learner import TreeLearner
13 | from py_lightgbm.objective import BinaryObjective
14 | import numpy as np
15 | from py_lightgbm.utils import const
16 |
17 |
18 | class Gbdt(object):
19 | def __init__(self):
20 | self._train_data = None
21 | self._scores = None
22 | self._object_function = None
23 | self._gradients = None
24 | self._hessians = None
25 |
26 | self._scores = None
27 | self._tree_config = None
28 |
29 | self._models = []
30 | self._coefs = []
31 |         self._tree_list = []  # all trained trees
32 |         self._tree_coef = []  # per-tree shrinkage coefficients
33 | return
34 |
35 | def init(self, train_data, tree_config):
36 | self._train_data = train_data
37 | self._scores = train_data.init_score
38 | self._tree_config = tree_config
39 | self._object_function = BinaryObjective(self._train_data.labels)
40 | return
41 |
42 | def train_one_iter(self, train_data, gradients=None, hessians=None):
43 | """
44 |         Run one boosting iteration: refresh gradients/hessians from the current scores when not supplied, then fit one tree to them.
45 | """
46 |
47 | if gradients is None or hessians is None:
48 | self._boosting()
49 | gradients = self._gradients
50 | hessians = self._hessians
51 |
52 |         # exactly one tree is fitted per iteration
53 | tree_learner = TreeLearner(self._tree_config, self._train_data)
54 | tree = tree_learner.train(gradients, hessians)
55 | self._tree_list.append(tree)
56 |
57 | self._update_scores(tree, tree_learner)
58 | return
59 |
60 | def _update_scores(self, tree, tree_learner):
61 | """
62 |         Update every sample's score with the new tree's leaf outputs.
63 | """
64 | for i in xrange(self._tree_config.num_leaves):
65 | output = tree.output_of_leaf(i)
66 |
67 | indices = tree_learner.get_indices_of_leaf(i)
68 | for index in indices:
69 | self._scores[index] += self._tree_config.learning_rate * output
70 | return
71 |
72 | def _boosting(self):
73 | if not self._object_function:
74 | return
75 |
76 | self._gradients, self._hessians = self._object_function.get_gradients(self._scores)
77 | return
78 |
79 | def predict_proba(self, X):
80 | """
81 | predict result according to tree_list
82 | """
83 | result = np.zeros((X.shape[0], ))
84 | for idx, tree in enumerate(self._tree_list):
85 | predict_y = tree.predict_prob(X)
86 | result += predict_y
87 |
88 | convert_result = self._object_function.convert_output(result)
89 | return self._train_data.convert_labels(convert_result)
90 |
91 | def show(self):
92 | """
93 |         Print the structure of every trained tree.
94 | """
95 | for index, tree in enumerate(self._tree_list):
96 | print "\n======================show tree {0}".format(index + 1)
97 | tree.show()
98 | return
99 |
100 | if __name__ == '__main__':
101 | pass
102 |
--------------------------------------------------------------------------------
/py_lightgbm/config/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/7/8 10:57
9 | @change_time:
10 | 1.2017/7/8 10:57
11 | """
12 |
13 | if __name__ == '__main__':
14 | pass
15 |
--------------------------------------------------------------------------------
/py_lightgbm/config/tree_config.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: tree_config.py
8 | @time: 2017/7/2 22:38
9 | @change_time:
10 | 1.2017/7/2 22:38
11 | """
12 | from py_lightgbm.utils import const
13 |
14 |
15 | class TreeConfig(object):
16 | def __init__(self):
17 |         # model hyperparameters
18 | self.num_leaves = const.DEFAULT_NUM_LEAVES
19 | self.max_depth = const.DEFAULT_MAX_DEPTH
20 | self.learning_rate = const.DEFAULT_LEARNING_RATE
21 | self.n_estimators = const.DEFAULT_NUM_ESTIMATORS
22 | self.max_bin = const.DEFAULT_MAX_BIN
23 | self.min_split_gain = const.DEFAULT_MIN_SPLIT_GAIN
24 | self.reg_alpha = const.DEFAULT_REG_ALPHA
25 | self.reg_lambda = const.DEFAULT_REG_LAMBDA
26 | self.min_child_samples = const.DEFAULT_MIN_CHILD_SAMPLES
27 | return
28 |
29 |
30 | if __name__ == '__main__':
31 | pass
32 |
--------------------------------------------------------------------------------
/py_lightgbm/io/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/7/8 10:57
9 | @change_time:
10 | 1.2017/7/8 10:57
11 | """
12 |
13 | if __name__ == '__main__':
14 | pass
15 |
--------------------------------------------------------------------------------
/py_lightgbm/io/bin.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: bin.py
8 | @time: 2017/7/8 11:58
9 | @change_time:
10 | 1.2017/7/8 11:58
11 | """
12 | from collections import Counter
13 | from py_lightgbm.utils import const
14 |
15 | import math
16 |
17 |
18 | class HistogramBinEntry(object):
19 |
20 | """
21 | store data for one histogram bin
22 | """
23 | def __init__(self):
24 | self._sum_gradients = 0
25 | self._sum_hessians = 0
26 | self._cnt = 0
27 | return
28 |
29 |
30 | class BinMapper(object):
31 | """
32 | Convert feature values into bin and store some meta information for bin
33 | Every feature will get a BinMapper
34 | """
35 |
36 | def __init__(self):
37 | self._num_bins = 0 # number of bins
38 | self._bin_upper_bound = []
39 |
40 | self._bin_type = const.TYPE_NUMERICAL
41 |
42 | self._category2bin = {} # int to unsigned int
43 | self._bin2category = {}
44 |
45 | self._is_trivial = False
46 |
47 | self._min_value = 0
48 | self._max_value = 0
49 | self._default_bin = 0
50 | return
51 |
52 | def __repr__(self):
53 | return self.__str__()
54 |
55 | def __str__(self):
56 | repr_str = "nb:{0}, bound:{1}".format(self._num_bins, self._bin_upper_bound)
57 | return repr_str
58 |
59 | def __len__(self):
60 | return len(self._bin_upper_bound)
61 |
62 | def upper_at(self, bin_index):
63 | return self._bin_upper_bound[bin_index]
64 |
65 | def find_lower_bound(self, threshold):
67 | index = self._bin_upper_bound.index(threshold)
68 |
69 | if index == 0:
70 | return -float("inf")
71 |
72 | return self._bin_upper_bound[index - 1]
73 |
74 | def find_bin(self, values, max_bin, bin_type=const.TYPE_NUMERICAL, min_data_in_bin=0, min_split_data=0):
75 | """
76 | Construct feature values to binMapper
77 | """
78 | self._bin_type = bin_type
79 |
80 | if bin_type is const.TYPE_NUMERICAL:
81 | distinct_values = sorted(set(values)) # dedupe with set, then sort ascending
82 | self._min_value = distinct_values[0]
83 | self._max_value = distinct_values[-1]
84 | counts = Counter(values)
85 | else: # TYPE_CATEGORY: fill category2bin and bin2category
86 | category_counts = Counter(values)
87 | bin_num = 1
88 | counts = {}
89 | max_bin = min(max_bin, int(math.ceil(len(category_counts) * 0.98)))
90 |
91 | for key, cnt in category_counts.most_common(max_bin):
92 | self._category2bin[key] = bin_num
93 | self._bin2category[bin_num] = key
94 | counts[bin_num] = cnt
95 | bin_num += 1
96 |
97 | distinct_values = self._category2bin.values()
98 |
99 | self._bin_upper_bound = self.greedy_find_bin(
100 | distinct_values, counts, max_bin,
101 | min_data_in_bin=min_data_in_bin,
102 | )
103 |
104 | self._num_bins = len(self._bin_upper_bound)
105 | return
106 |
107 | def greedy_find_bin(self, distinct_values, counts, max_bin, min_data_in_bin=0):
108 | # update upper_bound
109 | num_total_values = len(distinct_values)
110 | bin_upper_bound = []
111 |
112 | mean_data_in_bin = num_total_values / max_bin
113 | mean_data_in_bin = max(min_data_in_bin, mean_data_in_bin)
114 | mean_data_in_bin = max(1, mean_data_in_bin)
115 |
116 | cnt = 0
117 | for value in distinct_values:
118 | num_value = counts[value]
119 | cnt += num_value
120 |
121 | if cnt >= mean_data_in_bin:
122 | bin_upper_bound.append(value)
123 | cnt = 0
124 |
125 | bin_upper_bound.append(float("inf"))
126 | return bin_upper_bound
127 |
128 | def find_bin_idx(self, value):
129 | # find the bin for value, use bi_search
130 | if self._bin_type is const.TYPE_CATEGORY:
131 | value = self._category2bin.get(value, float("inf"))
132 |
133 | st = 0
134 | data = self._bin_upper_bound
135 | ed = len(data) - 1
136 |
137 | while st <= ed:
138 | mid = (st + ed) / 2
139 | if data[mid] == value: # value sits exactly on an upper bound
140 | return mid
141 | elif data[mid] < value:
142 | st = mid + 1
143 | else:
144 | ed = mid - 1
145 |
146 | # no exact hit: st is the first index whose upper bound exceeds value
147 | return st
148 |
149 |
150 | class OrderedBin(object):
151 | pass
152 |
153 |
154 | if __name__ == '__main__':
155 | bin_mapper = BinMapper()
156 | bin_mapper._bin_upper_bound = [0, 1, 3, 4, 5, float("inf")]
157 | result = bin_mapper.find_bin_idx(3.5)
158 | print result
159 |
--------------------------------------------------------------------------------
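`greedy_find_bin` walks the sorted distinct values and closes a bin whenever the accumulated count reaches `mean_data_in_bin`; each recorded value is an inclusive upper bound, with a final `float("inf")` as a catch-all, and `find_bin_idx` later binary-searches this list. A standalone replay on toy data (illustrative, not repo code):

```python
from collections import Counter

values = [1, 1, 2, 3, 5, 5, 8, 9, 9, 9]
distinct_values = sorted(set(values))          # [1, 2, 3, 5, 8, 9]
counts = Counter(values)

max_bin = 3
# as in greedy_find_bin: distinct-value count over max_bin, at least 1
mean_data_in_bin = max(1, len(distinct_values) // max_bin)

bin_upper_bound = []
cnt = 0
for value in distinct_values:
    cnt += counts[value]
    if cnt >= mean_data_in_bin:
        bin_upper_bound.append(value)          # close the bin at this value
        cnt = 0
bin_upper_bound.append(float("inf"))           # catch-all last bin

print(bin_upper_bound)                         # [1, 3, 5, 9, inf]
```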
/py_lightgbm/io/dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: dataset.py
8 | @time: 2017/7/8 10:59
9 | @change_time:
10 | 1.2017/7/8 10:59 init dataset, add BinMapper
11 | """
12 | from py_lightgbm.io.bin import BinMapper
13 | from py_lightgbm.tree.feature_histogram import FeatureHistogram
14 | from py_lightgbm.utils import const
15 | from collections import Counter
16 |
17 | import numpy as np
18 |
19 |
20 | MIN_DATA_IN_BIN = 10
21 |
22 |
23 | class MetaData(object):
24 | """
25 | store some metadata used for the training data
26 | """
27 | def __init__(self):
28 | self._labels = []
29 | self._init_scores = []
30 | return
31 |
32 |
33 | class Dataset(object):
34 | """
35 | The main class of dataset
36 | """
37 | def __init__(self, X, y, feature_name, categorical_feature, tree_config):
38 | self._train_X = X
39 | self._feature_names = feature_name if feature_name else []
40 | self._categorical_feature = categorical_feature if categorical_feature else []
41 | self._num_data = self._train_X.shape[0]
42 | self._num_features = self._train_X.shape[1]
43 |
44 | self._labels = None
45 | self._label2real = {}
46 | self._real2label = {}
47 | self._init_score = None
48 | self._bin_mappers = None
49 | self._tree_config = tree_config
50 | self.create_label(y)
51 | return
52 |
53 | @property
54 | def train_X(self):
55 | return self._train_X
56 |
57 | @property
58 | def num_data(self):
59 | return self._num_data
60 |
61 | @property
62 | def num_features(self):
63 | return self._num_features
64 |
65 | @property
66 | def labels(self):
67 | return self._labels
68 |
69 | @property
70 | def init_score(self):
71 | return self._init_score
72 |
73 | def split(self, feature_index, threshold_bin, indices, begin, end):
74 | left_indices = []
75 | right_indices = []
76 |
77 | bin_mapper = self._bin_mappers[feature_index]
78 | low_bound = -float("inf")
79 | if threshold_bin > 0:
80 | low_bound = bin_mapper.upper_at(threshold_bin - 1)
81 |
82 | upper_bound = bin_mapper.upper_at(threshold_bin)
83 |
84 | for i in range(begin, end):
85 | index = indices[i]
86 | value = self._train_X[index, feature_index]
87 | if upper_bound >= value > low_bound:
88 | left_indices.append(index)
89 | else:
90 | right_indices.append(index)
91 |
92 | return left_indices, right_indices
93 |
94 | def convert_labels(self, y, threshold=0.5):
95 | result = np.zeros(y.shape)
96 | for i in xrange(y.shape[0]):
97 | value = 1 if y[i] >= threshold else -1
98 | result[i] = self._label2real[value]
99 | return result
100 |
101 | def create_label(self, y):
102 | labels = np.zeros(y.shape)
103 | self._init_score = np.zeros(y.shape)
104 |
105 | raw_labels = Counter(y).keys()
106 | self._label2real = {
107 | 1: raw_labels[0],
108 | -1: raw_labels[1],
109 | }
110 |
111 | self._real2label = {
112 | raw_labels[0]: 1,
113 | raw_labels[1]: -1,
114 | }
115 |
116 | for i in xrange(y.shape[0]):
117 | labels[i] = self._real2label[y[i]]
118 |
119 | mean_y = np.mean(labels)
120 |
121 | self._init_score[:] = 1.0 / 2 * np.log((1 + mean_y) / (1 - mean_y))
122 |
123 | self._labels = labels
124 | return
125 |
126 | def construct_histograms(self, is_feature_used, data_indices, leaf_idx, gradients, hessians):
127 | if not data_indices:
128 | return []
129 |
130 | feature_histograms = []
131 | # build one histogram per feature; these bins drive the split search that follows
132 | for feature_index in xrange(self._num_features):
133 | feature_histogram = FeatureHistogram(feature_index, self._bin_mappers[feature_index])
134 | feature_histogram.init(self._train_X, data_indices, gradients, hessians, self._tree_config)
135 | feature_histograms.append(feature_histogram)
136 |
137 | return feature_histograms
138 |
139 | def fix_histogram(self, feature_idx, sum_gradients, sum_hessian, num_data, histogram_data):
140 |
141 | return
142 |
143 | def construct(self, bin_mappers):
144 | self._bin_mappers = bin_mappers
145 | return
146 |
147 | def create_bin_mapper(self, max_bin):
148 | bin_mappers = []
149 | for i in xrange(self._num_features):
150 | bin_mapper = BinMapper()
151 | values = self._train_X[:, i]
152 | bin_type = const.TYPE_CATEGORY if i in self._categorical_feature else const.TYPE_NUMERICAL
153 | bin_mapper.find_bin(values, max_bin, min_data_in_bin=MIN_DATA_IN_BIN, bin_type=bin_type)
154 | bin_mappers.append(bin_mapper)
155 | return bin_mappers
156 |
157 |
158 | if __name__ == '__main__':
159 | pass
160 |
--------------------------------------------------------------------------------
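`create_label` maps the two raw classes onto {+1, -1} and initializes every score to half the log odds: with mean label `y_bar = 2p - 1` (`p` = fraction of +1 labels), `0.5 * log((1 + y_bar) / (1 - y_bar)) = 0.5 * log(p / (1 - p))`, which is exactly the score at which BinaryObjective's sigmoid (sigma = 2) returns probability `p`. A small numeric check (standalone sketch):

```python
import numpy as np

labels = np.array([1, 1, 1, -1])           # 3 positives, 1 negative -> p = 0.75
mean_y = labels.mean()                     # 2p - 1 = 0.5
init_score = 0.5 * np.log((1 + mean_y) / (1 - mean_y))

# With sigma = 2 (BinaryObjective._sigmoid), the model's probability at this
# initial score recovers p exactly:
p = 1.0 / (1.0 + np.exp(-2.0 * init_score))
print(init_score)   # 0.549306...
print(p)            # 0.75
```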
/py_lightgbm/logmanager/__init__.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
4 |
5 | """
6 | @version: 1.0
7 | @author: xiaoqiangkx
8 | @file: __init__.py
9 | @time: 2017/7/12 20:21
10 | @change_time:
11 | 1.2017/7/12 20:21
12 | """
13 |
14 | if __name__ == '__main__':
15 | pass
--------------------------------------------------------------------------------
/py_lightgbm/logmanager/logger.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
4 |
5 | """
6 | @version: 1.0
7 | @author: xiaoqiangkx
8 | @file: logger.py
9 | @time: 2017/7/12 20:21
10 | @change_time:
11 | 1.2017/7/12 20:21
12 | """
13 | import logging
14 |
15 |
16 | LOG_LEVEL = logging.INFO
17 | logging.basicConfig()
18 |
19 |
20 | def get_logger(name):
21 | logger = logging.getLogger(name)
22 | logger.setLevel(LOG_LEVEL)
23 | return logger
--------------------------------------------------------------------------------
/py_lightgbm/metric/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/7/1 11:28
9 | @change_time:
10 | 1.2017/7/1 11:28
11 | """
12 |
13 | if __name__ == '__main__':
14 | pass
15 |
--------------------------------------------------------------------------------
/py_lightgbm/objective/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/7/9 17:40
9 | @change_time:
10 | 1.2017/7/9 17:40
11 | """
12 |
13 | from py_lightgbm.objective.objective_function import BinaryObjective
14 |
--------------------------------------------------------------------------------
/py_lightgbm/objective/objective_function.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: objective_function.py
8 | @time: 2017/7/9 17:41
9 | @change_time:
10 | 1.2017/7/9 17:41
11 | """
12 | import numpy as np
13 |
14 |
15 | class BinaryObjective(object):
16 | def __init__(self, labels):
17 | self._sigmoid = 2.0
18 | self._labels = labels
19 | return
20 |
21 | def get_gradients(self, scores):
22 | gradients = np.zeros(scores.shape)
23 | hessians = np.zeros(scores.shape)
24 |
25 | for i in xrange(self._labels.shape[0]):
26 | label = self._labels[i]
27 | response = -label * self._sigmoid / (1 + np.exp(label * self._sigmoid * scores[i]))
28 | abs_response = np.abs(response)
29 | gradients[i] = response
30 | hessians[i] = abs_response * (self._sigmoid - abs_response)
31 |
32 | return gradients, hessians
33 |
34 | def convert_output(self, value):
35 | return 1.0 / (1.0 + np.exp(-self._sigmoid * value))
36 |
37 |
38 | if __name__ == '__main__':
39 | labels = np.array([[1], [1], [-1], [1], [-1]])
40 | scores = np.array([[0.5], [-1], [-1], [1], [0.5]])
41 | bo = BinaryObjective(labels)
42 | gradients, hessians = bo.get_gradients(scores)
43 | print gradients
44 | print hessians
--------------------------------------------------------------------------------
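`get_gradients` implements the scaled binary logloss `L(s) = log(1 + exp(-sigma*y*s))` with `sigma = 2`: the gradient is `-sigma*y / (1 + exp(sigma*y*s))`, and `|g|*(sigma - |g|)` expands to `sigma^2 * e / (1 + e)^2` with `e = exp(sigma*y*s)`, which is the exact second derivative. A standalone finite-difference check of both formulas (a sketch, not repo code):

```python
import numpy as np

sigma, label, s, eps = 2.0, 1.0, 0.3, 1e-5

def loss(score):
    # scaled binary logloss: L(s) = log(1 + exp(-sigma * y * s))
    return np.log(1.0 + np.exp(-sigma * label * score))

grad = -label * sigma / (1.0 + np.exp(label * sigma * s))  # as in get_gradients
hess = abs(grad) * (sigma - abs(grad))

num_grad = (loss(s + eps) - loss(s - eps)) / (2 * eps)
num_hess = (loss(s + eps) - 2 * loss(s) + loss(s - eps)) / eps ** 2

print("gradient: %.6f vs numeric %.6f" % (grad, num_grad))  # both ~ -0.708689
print("hessian:  %.6f vs numeric %.6f" % (hess, num_hess))  # both ~ 0.915142
```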
/py_lightgbm/testes/.ipynb_checkpoints/test_xgboost-checkpoint.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 1,
6 | "metadata": {
7 | "collapsed": true
8 | },
9 | "outputs": [],
10 | "source": [
11 | "%matplotlib inline\n",
12 | "import matplotlib\n",
13 | "import matplotlib.pyplot as plt\n",
14 | "import pandas as pd\n",
15 | "import numpy as np\n",
16 | "import seaborn as sns\n",
17 | "pd.set_option(\"display.max_columns\",101)\n",
18 | "RANDOM_STATE = 42"
19 | ]
20 | },
21 | {
22 | "cell_type": "code",
23 | "execution_count": 2,
24 | "metadata": {
25 | "collapsed": true
26 | },
27 | "outputs": [],
28 | "source": [
29 | "DATA_PATH = \"../../data/mnist_train.csv\"\n",
30 | "mnist_train = pd.read_csv(DATA_PATH)"
31 | ]
32 | },
33 | {
34 | "cell_type": "code",
35 | "execution_count": 3,
36 | "metadata": {
37 | "collapsed": false
38 | },
39 | "outputs": [
40 | {
41 | "data": {
42 | "text/html": [
43 | "\n",
44 | "
\n",
45 | " \n",
46 | " \n",
47 | " | \n",
48 | " label | \n",
49 | " pixel0 | \n",
50 | " pixel1 | \n",
51 | " pixel2 | \n",
52 | " pixel3 | \n",
53 | " pixel4 | \n",
54 | " pixel5 | \n",
55 | " pixel6 | \n",
56 | " pixel7 | \n",
57 | " pixel8 | \n",
58 | " pixel9 | \n",
59 | " pixel10 | \n",
60 | " pixel11 | \n",
61 | " pixel12 | \n",
62 | " pixel13 | \n",
63 | " pixel14 | \n",
64 | " pixel15 | \n",
65 | " pixel16 | \n",
66 | " pixel17 | \n",
67 | " pixel18 | \n",
68 | " pixel19 | \n",
69 | " pixel20 | \n",
70 | " pixel21 | \n",
71 | " pixel22 | \n",
72 | " pixel23 | \n",
73 | " pixel24 | \n",
74 | " pixel25 | \n",
75 | " pixel26 | \n",
76 | " pixel27 | \n",
77 | " pixel28 | \n",
78 | " pixel29 | \n",
79 | " pixel30 | \n",
80 | " pixel31 | \n",
81 | " pixel32 | \n",
82 | " pixel33 | \n",
83 | " pixel34 | \n",
84 | " pixel35 | \n",
85 | " pixel36 | \n",
86 | " pixel37 | \n",
87 | " pixel38 | \n",
88 | " pixel39 | \n",
89 | " pixel40 | \n",
90 | " pixel41 | \n",
91 | " pixel42 | \n",
92 | " pixel43 | \n",
93 | " pixel44 | \n",
94 | " pixel45 | \n",
95 | " pixel46 | \n",
96 | " pixel47 | \n",
97 | " pixel48 | \n",
98 | " ... | \n",
99 | " pixel734 | \n",
100 | " pixel735 | \n",
101 | " pixel736 | \n",
102 | " pixel737 | \n",
103 | " pixel738 | \n",
104 | " pixel739 | \n",
105 | " pixel740 | \n",
106 | " pixel741 | \n",
107 | " pixel742 | \n",
108 | " pixel743 | \n",
109 | " pixel744 | \n",
110 | " pixel745 | \n",
111 | " pixel746 | \n",
112 | " pixel747 | \n",
113 | " pixel748 | \n",
114 | " pixel749 | \n",
115 | " pixel750 | \n",
116 | " pixel751 | \n",
117 | " pixel752 | \n",
118 | " pixel753 | \n",
119 | " pixel754 | \n",
120 | " pixel755 | \n",
121 | " pixel756 | \n",
122 | " pixel757 | \n",
123 | " pixel758 | \n",
124 | " pixel759 | \n",
125 | " pixel760 | \n",
126 | " pixel761 | \n",
127 | " pixel762 | \n",
128 | " pixel763 | \n",
129 | " pixel764 | \n",
130 | " pixel765 | \n",
131 | " pixel766 | \n",
132 | " pixel767 | \n",
133 | " pixel768 | \n",
134 | " pixel769 | \n",
135 | " pixel770 | \n",
136 | " pixel771 | \n",
137 | " pixel772 | \n",
138 | " pixel773 | \n",
139 | " pixel774 | \n",
140 | " pixel775 | \n",
141 | " pixel776 | \n",
142 | " pixel777 | \n",
143 | " pixel778 | \n",
144 | " pixel779 | \n",
145 | " pixel780 | \n",
146 | " pixel781 | \n",
147 | " pixel782 | \n",
148 | " pixel783 | \n",
149 | "
\n",
150 | " \n",
151 | " \n",
152 | " \n",
153 | " 0 | \n",
154 | " 1 | \n",
155 | " 0 | \n",
156 | " 0 | \n",
157 | " 0 | \n",
158 | " 0 | \n",
159 | " 0 | \n",
160 | " 0 | \n",
161 | " 0 | \n",
162 | " 0 | \n",
163 | " 0 | \n",
164 | " 0 | \n",
165 | " 0 | \n",
166 | " 0 | \n",
167 | " 0 | \n",
168 | " 0 | \n",
169 | " 0 | \n",
170 | " 0 | \n",
171 | " 0 | \n",
172 | " 0 | \n",
173 | " 0 | \n",
174 | " 0 | \n",
175 | " 0 | \n",
176 | " 0 | \n",
177 | " 0 | \n",
178 | " 0 | \n",
179 | " 0 | \n",
180 | " 0 | \n",
181 | " 0 | \n",
182 | " 0 | \n",
183 | " 0 | \n",
184 | " 0 | \n",
185 | " 0 | \n",
186 | " 0 | \n",
187 | " 0 | \n",
188 | " 0 | \n",
189 | " 0 | \n",
190 | " 0 | \n",
191 | " 0 | \n",
192 | " 0 | \n",
193 | " 0 | \n",
194 | " 0 | \n",
195 | " 0 | \n",
196 | " 0 | \n",
197 | " 0 | \n",
198 | " 0 | \n",
199 | " 0 | \n",
200 | " 0 | \n",
201 | " 0 | \n",
202 | " 0 | \n",
203 | " 0 | \n",
204 | " ... | \n",
205 | " 0 | \n",
206 | " 0 | \n",
207 | " 0 | \n",
208 | " 0 | \n",
209 | " 0 | \n",
210 | " 0 | \n",
211 | " 0 | \n",
212 | " 0 | \n",
213 | " 0 | \n",
214 | " 0 | \n",
215 | " 0 | \n",
216 | " 0 | \n",
217 | " 0 | \n",
218 | " 0 | \n",
219 | " 0 | \n",
220 | " 0 | \n",
221 | " 0 | \n",
222 | " 0 | \n",
223 | " 0 | \n",
224 | " 0 | \n",
225 | " 0 | \n",
226 | " 0 | \n",
227 | " 0 | \n",
228 | " 0 | \n",
229 | " 0 | \n",
230 | " 0 | \n",
231 | " 0 | \n",
232 | " 0 | \n",
233 | " 0 | \n",
234 | " 0 | \n",
235 | " 0 | \n",
236 | " 0 | \n",
237 | " 0 | \n",
238 | " 0 | \n",
239 | " 0 | \n",
240 | " 0 | \n",
241 | " 0 | \n",
242 | " 0 | \n",
243 | " 0 | \n",
244 | " 0 | \n",
245 | " 0 | \n",
246 | " 0 | \n",
247 | " 0 | \n",
248 | " 0 | \n",
249 | " 0 | \n",
250 | " 0 | \n",
251 | " 0 | \n",
252 | " 0 | \n",
253 | " 0 | \n",
254 | " 0 | \n",
255 | "
\n",
256 | " \n",
257 | " 1 | \n",
258 | " 0 | \n",
259 | " 0 | \n",
260 | " 0 | \n",
261 | " 0 | \n",
262 | " 0 | \n",
263 | " 0 | \n",
264 | " 0 | \n",
265 | " 0 | \n",
266 | " 0 | \n",
267 | " 0 | \n",
268 | " 0 | \n",
269 | " 0 | \n",
270 | " 0 | \n",
271 | " 0 | \n",
272 | " 0 | \n",
273 | " 0 | \n",
274 | " 0 | \n",
275 | " 0 | \n",
276 | " 0 | \n",
277 | " 0 | \n",
278 | " 0 | \n",
279 | " 0 | \n",
280 | " 0 | \n",
281 | " 0 | \n",
282 | " 0 | \n",
283 | " 0 | \n",
284 | " 0 | \n",
285 | " 0 | \n",
286 | " 0 | \n",
287 | " 0 | \n",
288 | " 0 | \n",
289 | " 0 | \n",
290 | " 0 | \n",
291 | " 0 | \n",
292 | " 0 | \n",
293 | " 0 | \n",
294 | " 0 | \n",
295 | " 0 | \n",
296 | " 0 | \n",
297 | " 0 | \n",
298 | " 0 | \n",
299 | " 0 | \n",
300 | " 0 | \n",
301 | " 0 | \n",
302 | " 0 | \n",
303 | " 0 | \n",
304 | " 0 | \n",
305 | " 0 | \n",
306 | " 0 | \n",
307 | " 0 | \n",
308 | " ... | \n",
309 | " 0 | \n",
310 | " 0 | \n",
311 | " 0 | \n",
312 | " 0 | \n",
313 | " 0 | \n",
314 | " 0 | \n",
315 | " 0 | \n",
316 | " 0 | \n",
317 | " 0 | \n",
318 | " 0 | \n",
319 | " 0 | \n",
320 | " 0 | \n",
321 | " 0 | \n",
322 | " 0 | \n",
323 | " 0 | \n",
324 | " 0 | \n",
325 | " 0 | \n",
326 | " 0 | \n",
327 | " 0 | \n",
328 | " 0 | \n",
329 | " 0 | \n",
330 | " 0 | \n",
331 | " 0 | \n",
332 | " 0 | \n",
333 | " 0 | \n",
334 | " 0 | \n",
335 | " 0 | \n",
336 | " 0 | \n",
337 | " 0 | \n",
338 | " 0 | \n",
339 | " 0 | \n",
340 | " 0 | \n",
341 | " 0 | \n",
342 | " 0 | \n",
343 | " 0 | \n",
344 | " 0 | \n",
345 | " 0 | \n",
346 | " 0 | \n",
347 | " 0 | \n",
348 | " 0 | \n",
349 | " 0 | \n",
350 | " 0 | \n",
351 | " 0 | \n",
352 | " 0 | \n",
353 | " 0 | \n",
354 | " 0 | \n",
355 | " 0 | \n",
356 | " 0 | \n",
357 | " 0 | \n",
358 | " 0 | \n",
359 | "
\n",
360 | " \n",
361 | " 2 | \n",
362 | " 1 | \n",
363 | " 0 | \n",
364 | " 0 | \n",
365 | " 0 | \n",
366 | " 0 | \n",
367 | " 0 | \n",
368 | " 0 | \n",
369 | " 0 | \n",
370 | " 0 | \n",
371 | " 0 | \n",
372 | " 0 | \n",
373 | " 0 | \n",
374 | " 0 | \n",
375 | " 0 | \n",
376 | " 0 | \n",
377 | " 0 | \n",
378 | " 0 | \n",
379 | " 0 | \n",
380 | " 0 | \n",
381 | " 0 | \n",
382 | " 0 | \n",
383 | " 0 | \n",
384 | " 0 | \n",
385 | " 0 | \n",
386 | " 0 | \n",
387 | " 0 | \n",
388 | " 0 | \n",
389 | " 0 | \n",
390 | " 0 | \n",
391 | " 0 | \n",
392 | " 0 | \n",
393 | " 0 | \n",
394 | " 0 | \n",
395 | " 0 | \n",
396 | " 0 | \n",
397 | " 0 | \n",
398 | " 0 | \n",
399 | " 0 | \n",
400 | " 0 | \n",
401 | " 0 | \n",
402 | " 0 | \n",
403 | " 0 | \n",
404 | " 0 | \n",
405 | " 0 | \n",
406 | " 0 | \n",
407 | " 0 | \n",
408 | " 0 | \n",
409 | " 0 | \n",
410 | " 0 | \n",
411 | " 0 | \n",
412 | " ... | \n",
413 | " 0 | \n",
414 | " 0 | \n",
415 | " 0 | \n",
416 | " 0 | \n",
417 | " 0 | \n",
418 | " 0 | \n",
419 | " 0 | \n",
420 | " 0 | \n",
421 | " 0 | \n",
422 | " 0 | \n",
423 | " 0 | \n",
424 | " 0 | \n",
425 | " 0 | \n",
426 | " 0 | \n",
427 | " 0 | \n",
428 | " 0 | \n",
429 | " 0 | \n",
430 | " 0 | \n",
431 | " 0 | \n",
432 | " 0 | \n",
433 | " 0 | \n",
434 | " 0 | \n",
435 | " 0 | \n",
436 | " 0 | \n",
437 | " 0 | \n",
438 | " 0 | \n",
439 | " 0 | \n",
440 | " 0 | \n",
441 | " 0 | \n",
442 | " 0 | \n",
443 | " 0 | \n",
444 | " 0 | \n",
445 | " 0 | \n",
446 | " 0 | \n",
447 | " 0 | \n",
448 | " 0 | \n",
449 | " 0 | \n",
450 | " 0 | \n",
451 | " 0 | \n",
452 | " 0 | \n",
453 | " 0 | \n",
454 | " 0 | \n",
455 | " 0 | \n",
456 | " 0 | \n",
457 | " 0 | \n",
458 | " 0 | \n",
459 | " 0 | \n",
460 | " 0 | \n",
461 | " 0 | \n",
462 | " 0 | \n",
463 | "
\n",
464 | " \n",
465 | " 3 | \n",
466 | " 4 | \n",
467 | " 0 | \n",
468 | " 0 | \n",
469 | " 0 | \n",
470 | " 0 | \n",
471 | " 0 | \n",
472 | " 0 | \n",
473 | " 0 | \n",
474 | " 0 | \n",
475 | " 0 | \n",
476 | " 0 | \n",
477 | " 0 | \n",
478 | " 0 | \n",
479 | " 0 | \n",
480 | " 0 | \n",
481 | " 0 | \n",
482 | " 0 | \n",
483 | " 0 | \n",
484 | " 0 | \n",
485 | " 0 | \n",
486 | " 0 | \n",
487 | " 0 | \n",
488 | " 0 | \n",
489 | " 0 | \n",
490 | " 0 | \n",
491 | " 0 | \n",
492 | " 0 | \n",
493 | " 0 | \n",
494 | " 0 | \n",
495 | " 0 | \n",
496 | " 0 | \n",
497 | " 0 | \n",
498 | " 0 | \n",
499 | " 0 | \n",
500 | " 0 | \n",
501 | " 0 | \n",
502 | " 0 | \n",
503 | " 0 | \n",
504 | " 0 | \n",
505 | " 0 | \n",
506 | " 0 | \n",
507 | " 0 | \n",
508 | " 0 | \n",
509 | " 0 | \n",
510 | " 0 | \n",
511 | " 0 | \n",
512 | " 0 | \n",
513 | " 0 | \n",
514 | " 0 | \n",
515 | " 0 | \n",
516 | " ... | \n",
517 | " 0 | \n",
518 | " 0 | \n",
519 | " 0 | \n",
520 | " 0 | \n",
521 | " 0 | \n",
522 | " 0 | \n",
523 | " 0 | \n",
524 | " 0 | \n",
525 | " 0 | \n",
526 | " 0 | \n",
527 | " 0 | \n",
528 | " 0 | \n",
529 | " 0 | \n",
530 | " 0 | \n",
531 | " 0 | \n",
532 | " 0 | \n",
533 | " 0 | \n",
534 | " 0 | \n",
535 | " 0 | \n",
536 | " 0 | \n",
537 | " 0 | \n",
538 | " 0 | \n",
539 | " 0 | \n",
540 | " 0 | \n",
541 | " 0 | \n",
542 | " 0 | \n",
543 | " 0 | \n",
544 | " 0 | \n",
545 | " 0 | \n",
546 | " 0 | \n",
547 | " 0 | \n",
548 | " 0 | \n",
549 | " 0 | \n",
550 | " 0 | \n",
551 | " 0 | \n",
552 | " 0 | \n",
553 | " 0 | \n",
554 | " 0 | \n",
555 | " 0 | \n",
556 | " 0 | \n",
557 | " 0 | \n",
558 | " 0 | \n",
559 | " 0 | \n",
560 | " 0 | \n",
561 | " 0 | \n",
562 | " 0 | \n",
563 | " 0 | \n",
564 | " 0 | \n",
565 | " 0 | \n",
566 | " 0 | \n",
567 | "
\n",
568 | " \n",
569 | " 4 | \n",
570 | " 0 | \n",
571 | " 0 | \n",
572 | " 0 | \n",
573 | " 0 | \n",
574 | " 0 | \n",
575 | " 0 | \n",
576 | " 0 | \n",
577 | " 0 | \n",
578 | " 0 | \n",
579 | " 0 | \n",
580 | " 0 | \n",
581 | " 0 | \n",
582 | " 0 | \n",
583 | " 0 | \n",
584 | " 0 | \n",
585 | " 0 | \n",
586 | " 0 | \n",
587 | " 0 | \n",
588 | " 0 | \n",
589 | " 0 | \n",
590 | " 0 | \n",
591 | " 0 | \n",
592 | " 0 | \n",
593 | " 0 | \n",
594 | " 0 | \n",
595 | " 0 | \n",
596 | " 0 | \n",
597 | " 0 | \n",
598 | " 0 | \n",
599 | " 0 | \n",
600 | " 0 | \n",
601 | " 0 | \n",
602 | " 0 | \n",
603 | " 0 | \n",
604 | " 0 | \n",
605 | " 0 | \n",
606 | " 0 | \n",
607 | " 0 | \n",
608 | " 0 | \n",
609 | " 0 | \n",
610 | " 0 | \n",
611 | " 0 | \n",
612 | " 0 | \n",
613 | " 0 | \n",
614 | " 0 | \n",
615 | " 0 | \n",
616 | " 0 | \n",
617 | " 0 | \n",
618 | " 0 | \n",
619 | " 0 | \n",
620 | " ... | \n",
621 | " 0 | \n",
622 | " 0 | \n",
623 | " 0 | \n",
624 | " 0 | \n",
625 | " 0 | \n",
626 | " 0 | \n",
627 | " 0 | \n",
628 | " 0 | \n",
629 | " 0 | \n",
630 | " 0 | \n",
631 | " 0 | \n",
632 | " 0 | \n",
633 | " 0 | \n",
634 | " 0 | \n",
635 | " 0 | \n",
636 | " 0 | \n",
637 | " 0 | \n",
638 | " 0 | \n",
639 | " 0 | \n",
640 | " 0 | \n",
641 | " 0 | \n",
642 | " 0 | \n",
643 | " 0 | \n",
644 | " 0 | \n",
645 | " 0 | \n",
646 | " 0 | \n",
647 | " 0 | \n",
648 | " 0 | \n",
649 | " 0 | \n",
650 | " 0 | \n",
651 | " 0 | \n",
652 | " 0 | \n",
653 | " 0 | \n",
654 | " 0 | \n",
655 | " 0 | \n",
656 | " 0 | \n",
657 | " 0 | \n",
658 | " 0 | \n",
659 | " 0 | \n",
660 | " 0 | \n",
661 | " 0 | \n",
662 | " 0 | \n",
663 | " 0 | \n",
664 | " 0 | \n",
665 | " 0 | \n",
666 | " 0 | \n",
667 | " 0 | \n",
668 | " 0 | \n",
669 | " 0 | \n",
670 | " 0 | \n",
671 | "
\n",
672 | " \n",
673 | "
\n",
674 | "
5 rows × 785 columns
\n",
675 | "
"
676 | ],
677 | "text/plain": [
678 | " label pixel0 pixel1 pixel2 pixel3 pixel4 pixel5 pixel6 pixel7 \\\n",
679 | "0 1 0 0 0 0 0 0 0 0 \n",
680 | "1 0 0 0 0 0 0 0 0 0 \n",
681 | "2 1 0 0 0 0 0 0 0 0 \n",
682 | "3 4 0 0 0 0 0 0 0 0 \n",
683 | "4 0 0 0 0 0 0 0 0 0 \n",
684 | "\n",
685 | " pixel8 pixel9 pixel10 pixel11 pixel12 pixel13 pixel14 pixel15 \\\n",
686 | "0 0 0 0 0 0 0 0 0 \n",
687 | "1 0 0 0 0 0 0 0 0 \n",
688 | "2 0 0 0 0 0 0 0 0 \n",
689 | "3 0 0 0 0 0 0 0 0 \n",
690 | "4 0 0 0 0 0 0 0 0 \n",
691 | "\n",
692 | " pixel16 pixel17 pixel18 pixel19 pixel20 pixel21 pixel22 pixel23 \\\n",
693 | "0 0 0 0 0 0 0 0 0 \n",
694 | "1 0 0 0 0 0 0 0 0 \n",
695 | "2 0 0 0 0 0 0 0 0 \n",
696 | "3 0 0 0 0 0 0 0 0 \n",
697 | "4 0 0 0 0 0 0 0 0 \n",
698 | "\n",
699 | " pixel24 pixel25 pixel26 pixel27 pixel28 pixel29 pixel30 pixel31 \\\n",
700 | "0 0 0 0 0 0 0 0 0 \n",
701 | "1 0 0 0 0 0 0 0 0 \n",
702 | "2 0 0 0 0 0 0 0 0 \n",
703 | "3 0 0 0 0 0 0 0 0 \n",
704 | "4 0 0 0 0 0 0 0 0 \n",
705 | "\n",
706 | " pixel32 pixel33 pixel34 pixel35 pixel36 pixel37 pixel38 pixel39 \\\n",
707 | "0 0 0 0 0 0 0 0 0 \n",
708 | "1 0 0 0 0 0 0 0 0 \n",
709 | "2 0 0 0 0 0 0 0 0 \n",
710 | "3 0 0 0 0 0 0 0 0 \n",
711 | "4 0 0 0 0 0 0 0 0 \n",
712 | "\n",
713 | " pixel40 pixel41 pixel42 pixel43 pixel44 pixel45 pixel46 pixel47 \\\n",
714 | "0 0 0 0 0 0 0 0 0 \n",
715 | "1 0 0 0 0 0 0 0 0 \n",
716 | "2 0 0 0 0 0 0 0 0 \n",
717 | "3 0 0 0 0 0 0 0 0 \n",
718 | "4 0 0 0 0 0 0 0 0 \n",
719 | "\n",
720 | " pixel48 ... pixel734 pixel735 pixel736 pixel737 pixel738 \\\n",
721 | "0 0 ... 0 0 0 0 0 \n",
722 | "1 0 ... 0 0 0 0 0 \n",
723 | "2 0 ... 0 0 0 0 0 \n",
724 | "3 0 ... 0 0 0 0 0 \n",
725 | "4 0 ... 0 0 0 0 0 \n",
726 | "\n",
727 | " pixel739 pixel740 pixel741 pixel742 pixel743 pixel744 pixel745 \\\n",
728 | "0 0 0 0 0 0 0 0 \n",
729 | "1 0 0 0 0 0 0 0 \n",
730 | "2 0 0 0 0 0 0 0 \n",
731 | "3 0 0 0 0 0 0 0 \n",
732 | "4 0 0 0 0 0 0 0 \n",
733 | "\n",
734 | " pixel746 pixel747 pixel748 pixel749 pixel750 pixel751 pixel752 \\\n",
735 | "0 0 0 0 0 0 0 0 \n",
736 | "1 0 0 0 0 0 0 0 \n",
737 | "2 0 0 0 0 0 0 0 \n",
738 | "3 0 0 0 0 0 0 0 \n",
739 | "4 0 0 0 0 0 0 0 \n",
740 | "\n",
741 | " pixel753 pixel754 pixel755 pixel756 pixel757 pixel758 pixel759 \\\n",
742 | "0 0 0 0 0 0 0 0 \n",
743 | "1 0 0 0 0 0 0 0 \n",
744 | "2 0 0 0 0 0 0 0 \n",
745 | "3 0 0 0 0 0 0 0 \n",
746 | "4 0 0 0 0 0 0 0 \n",
747 | "\n",
748 | " pixel760 pixel761 pixel762 pixel763 pixel764 pixel765 pixel766 \\\n",
749 | "0 0 0 0 0 0 0 0 \n",
750 | "1 0 0 0 0 0 0 0 \n",
751 | "2 0 0 0 0 0 0 0 \n",
752 | "3 0 0 0 0 0 0 0 \n",
753 | "4 0 0 0 0 0 0 0 \n",
754 | "\n",
755 | " pixel767 pixel768 pixel769 pixel770 pixel771 pixel772 pixel773 \\\n",
756 | "0 0 0 0 0 0 0 0 \n",
757 | "1 0 0 0 0 0 0 0 \n",
758 | "2 0 0 0 0 0 0 0 \n",
759 | "3 0 0 0 0 0 0 0 \n",
760 | "4 0 0 0 0 0 0 0 \n",
761 | "\n",
762 | " pixel774 pixel775 pixel776 pixel777 pixel778 pixel779 pixel780 \\\n",
763 | "0 0 0 0 0 0 0 0 \n",
764 | "1 0 0 0 0 0 0 0 \n",
765 | "2 0 0 0 0 0 0 0 \n",
766 | "3 0 0 0 0 0 0 0 \n",
767 | "4 0 0 0 0 0 0 0 \n",
768 | "\n",
769 | " pixel781 pixel782 pixel783 \n",
770 | "0 0 0 0 \n",
771 | "1 0 0 0 \n",
772 | "2 0 0 0 \n",
773 | "3 0 0 0 \n",
774 | "4 0 0 0 \n",
775 | "\n",
776 | "[5 rows x 785 columns]"
777 | ]
778 | },
779 | "execution_count": 3,
780 | "metadata": {},
781 | "output_type": "execute_result"
782 | }
783 | ],
784 | "source": [
785 | "mnist_train.head()"
786 | ]
787 | },
788 | {
789 | "cell_type": "code",
790 | "execution_count": null,
791 | "metadata": {
792 | "collapsed": true
793 | },
794 | "outputs": [],
795 | "source": [
796 | "X = mnist_train.ix[mnist_train.label == 1 or mnist_train.label == 0]\n"
797 | ]
798 | }
799 | ],
800 | "metadata": {
801 | "kernelspec": {
802 | "display_name": "Python 2",
803 | "language": "python",
804 | "name": "python2"
805 | },
806 | "language_info": {
807 | "codemirror_mode": {
808 | "name": "ipython",
809 | "version": 2
810 | },
811 | "file_extension": ".py",
812 | "mimetype": "text/x-python",
813 | "name": "python",
814 | "nbconvert_exporter": "python",
815 | "pygments_lexer": "ipython2",
816 | "version": "2.7.13"
817 | }
818 | },
819 | "nbformat": 4,
820 | "nbformat_minor": 2
821 | }
822 |
--------------------------------------------------------------------------------
/py_lightgbm/testes/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/7/9 09:33
9 | @change_time:
10 | 1.2017/7/9 09:33
11 | """
12 |
13 | if __name__ == '__main__':
14 | pass
15 |
--------------------------------------------------------------------------------
/py_lightgbm/testes/test_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: test_dataset.py
8 | @time: 2017/7/9 16:14
9 | @change_time:
10 | 1.2017/7/9 16:14
11 | """
12 | import py_lightgbm as lgb
13 | import numpy as np
14 | import random
15 |
16 |
17 | def main():
18 | params = {
19 | "max_bin": 10,
20 | }
21 | clf = lgb.LGBMClassifier(**params)
22 | X_train = np.zeros((100, 2))
23 | y_train = np.zeros((100, 1))
24 | for i in xrange(100):
25 | X_train[i, 0] = random.randint(0, 10)
26 | X_train[i, 1] = random.randint(0, 20)
27 |
28 | clf.fit(X_train, y_train)
29 | clf.print_bin_mappers()
30 | return
31 |
32 |
33 | if __name__ == "__main__":
34 | main()
35 |
--------------------------------------------------------------------------------
/py_lightgbm/testes/test_lightgbm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: test_lightgbm.py
8 | @time: 2017/7/9 09:33
9 | @change_time:
10 | 1.2017/7/9 09:33
11 | """
12 | import pandas as pd
13 | import numpy as np
14 | import lightgbm as lgb
15 | from sklearn import model_selection
16 |
17 | DATA_PATH = "../../data/mnist_train.csv"
18 |
19 |
20 | def main():
21 | mnist_train = pd.read_csv(DATA_PATH)
22 | X = mnist_train[(mnist_train.label == 6) | (mnist_train.label == 8)]
23 | y = X['label'].values
24 | X = X.drop('label', axis=1).values
25 |
26 | X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.33, random_state=42)
27 |
28 | params = {
29 | 'boosting_type': 'gbdt',
30 | 'objective': 'binary',
31 | 'num_leaves': 50,
32 | 'max_depth': 10,
33 | 'learning_rate': 0.1,
34 | # 'reg_lambda': 0.7,
35 | 'n_estimators': 100,
36 | # 'silent': True
37 | }
38 |
39 | print('light GBM train :-)')
40 | clf = lgb.LGBMClassifier(**params)
41 | result_base = clf.fit(X_train, y_train)
42 | score = result_base.score(X_test, y_test)
43 | print "score:", score
44 |
45 |
46 | if __name__ == '__main__':
47 | main()
48 |
--------------------------------------------------------------------------------
/py_lightgbm/testes/test_py_lightgbm.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: test_py_lightgbm.py
8 | @time: 2017/7/9 10:34
9 | @change_time:
10 | 1.2017/7/9 10:34
11 | """
12 |
13 | import pandas as pd
14 | import py_lightgbm as lgb
15 | from sklearn import model_selection
16 | from collections import Counter
17 | from sklearn import metrics
18 | import profile
19 | import time
20 | from py_lightgbm.logmanager import logger
21 |
22 | DATA_PATH = "../../data/mnist_train.csv"
23 |
24 | _LOGGER = logger.get_logger("Test")
25 |
26 |
27 | def main():
28 | mnist_train = pd.read_csv(DATA_PATH, nrows=1000) # training on 1000 rows takes ~1.6s; the full dataset takes ~70s
29 | X = mnist_train[(mnist_train.label == 6) | (mnist_train.label == 8)]
30 | y = X['label'].values
31 | X = X.drop('label', axis=1).values
32 |
33 | X_train, X_test, y_train, y_test = model_selection.train_test_split(X, y, test_size=0.33, random_state=42)
34 |
35 | params = {
36 | 'boosting_type': 'gbdt',
37 | 'objective': 'binary',
38 | 'num_leaves': 50,
39 | 'max_depth': 10,
40 | 'learning_rate': 0.1,
41 | 'reg_lambda': 1.0,
42 | 'reg_alpha': 1.0,
43 | 'n_estimators': 2,
44 | 'min_child_samples': 20,
45 | # 'silent': True
46 | }
47 |
48 | _LOGGER.critical('light GBM train :-)')
49 | clf = lgb.LGBMClassifier(**params)
50 | _LOGGER.critical("data_shape: {0}, {1}".format(X_train.shape, Counter(y_train)))
51 | _LOGGER.info("y_test:{0}".format(y_test))
52 |
53 | profile.enable_profile()
54 | timestamp_start = time.time()
55 | _LOGGER.critical("starting profile")
56 | clf.fit(X_train, y_train)
57 | # clf.fit(X_train, y_train, categorical_feature=range(X_train.shape[1]))
58 | # clf.show()
59 | y_predict = clf.predict_proba(X_test)
60 | _LOGGER.info(y_predict)
61 | score = metrics.accuracy_score(y_test, y_predict)
62 | _LOGGER.critical("score:{0}".format(score))
63 | profile.close_profile()
64 | timestamp_end = time.time()
65 | _LOGGER.critical("finish profile:{0}".format(timestamp_end - timestamp_start))
66 |
67 |
68 | if __name__ == '__main__':
69 | main()
70 |
--------------------------------------------------------------------------------
/py_lightgbm/tree/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/7/1 11:28
9 | @change_time:
10 | 1.2017/7/1 11:28
11 | """
12 |
13 | if __name__ == '__main__':
14 | pass
15 |
--------------------------------------------------------------------------------
/py_lightgbm/tree/data_partition.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: data_partition.py
8 | @time: 2017/7/10 21:08
9 | @change_time:
10 | 1.2017/7/10 21:08
11 | """
12 |
13 |
14 | class DataPartition(object):
15 |
16 | def __init__(self, num_data, num_leaves):
17 | self.num_data = num_data
18 | self.num_leaves = num_leaves
19 |
20 | # the three lists below hold the leaf-wise partition of sample indices
21 | self._leaf_begin = None
22 | self._leaf_count = None
23 | self._indices = None
24 |
25 | self._temp_left_indices = None
26 | self._temp_right_indices = None
27 | return
28 |
29 | def init(self):
30 | self._leaf_begin = [0] * self.num_leaves
31 | self._leaf_count = [0] * self.num_leaves
32 | self._indices = range(self.num_data)
33 |
34 | self._temp_left_indices = [0] * self.num_data
35 | self._temp_right_indices = [0] * self.num_data
36 | self._leaf_count[0] = self.num_data
37 | return
38 |
39 | def __str__(self):
40 | leaf_desc = "leaf_begin:{lb}\nleaf_count:{lc}\nindices:{ind}\n".format(
41 | lb=self._leaf_begin,
42 | lc=self._leaf_count,
43 | ind=self._indices,
44 | )
45 | return leaf_desc
46 |
47 | def counts_of_leaf(self, leaf):
48 | return self._leaf_count[leaf]
49 |
50 | def get_indices_of_leaf(self, leaf):
51 | begin = self._leaf_begin[leaf]
52 | cnt = self._leaf_count[leaf]
53 |
54 | return self._indices[begin:begin+cnt]
55 |
56 | def split(self, left_leaf, train_data, feature_index, threshold_bin, right_leaf):
57 |
58 | begin = self._leaf_begin[left_leaf]
59 | cnt = self._leaf_count[left_leaf]
60 |
61 | # split this leaf's index range into a left part and a right part
62 | left_indices, right_indices = train_data.split(feature_index, threshold_bin, self._indices, begin, begin + cnt)
63 | left_cnt = len(left_indices)
64 |
65 | self._indices[begin:begin+left_cnt] = left_indices
66 | self._indices[begin+left_cnt:begin+cnt] = right_indices
67 | # record the partition info for the new (right) leaf
68 |
69 | self._leaf_count[left_leaf] = left_cnt
70 |
71 | self._leaf_begin[right_leaf] = begin + left_cnt
72 | self._leaf_count[right_leaf] = cnt - left_cnt
73 |
74 | return
75 |
76 |
77 |
--------------------------------------------------------------------------------
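DataPartition keeps all sample indices in one flat array: each leaf owns a contiguous slice described by `_leaf_begin`/`_leaf_count`, and `split` rewrites that slice in place so the left part stays at the front and the new right leaf takes the tail. The same bookkeeping on toy data (the indices and split result here are hypothetical):

```python
# Leaf 0 owns indices[2:7]; suppose a split sends {3, 5, 6} left, {2, 4} right.
indices = list(range(10))
leaf_begin = {0: 2}
leaf_count = {0: 5}

left = [3, 5, 6]
right = [2, 4]

begin, cnt = leaf_begin[0], leaf_count[0]
indices[begin:begin + len(left)] = left            # left part at the front
indices[begin + len(left):begin + cnt] = right     # right part takes the tail

leaf_count[0] = len(left)                          # left leaf shrinks in place
leaf_begin[1] = begin + len(left)                  # right leaf starts after it
leaf_count[1] = len(right)

print(indices)  # [0, 1, 3, 5, 6, 2, 4, 7, 8, 9]
```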
/py_lightgbm/tree/feature_histogram.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: feature_histogram.py
8 | @time: 2017/7/2 21:49
9 | @change_time:
10 | 1.2017/7/2 21:49 build the FeatureHistogram structure
11 | """
12 | from py_lightgbm.tree.split_info import SplitInfo
13 | from py_lightgbm.utils import const
14 |
15 |
16 | class FeatureEntryMeta(object):
17 | def __init__(self):
18 | self.sum_hessians = 0
19 | self.sum_gradients = 0
20 | self.cnt = 0
21 | return
22 |
23 | def __str__(self):
24 | repr_str = "cnt:{0}, h:{1}, g:{2}".format(
25 | self.cnt,
26 | self.sum_hessians,
27 | self.sum_gradients,
28 | )
29 | return repr_str
30 |
31 | def __repr__(self):
32 | return self.__str__()
33 |
34 |
35 | class FeatureHistogram(object):
36 | """
37 | used to construct and store a histogram for a feature.
38 | """
39 |
40 | def __init__(self, feature_index, bin_mapper):
41 | self._meta = None
42 | self._is_splittable = True
43 | self.find_best_threshold_fun_ = self.find_best_threshold_sequence
44 |
45 | self._feature_index = feature_index
46 | self._bin_mapper = bin_mapper
47 |
48 | self._bin_entry = [FeatureEntryMeta() for x in xrange(len(self._bin_mapper))]
49 |
50 | self._min_gain_split = 0.01 # TODO: set min_gain_split
51 |
52 | self._tree_config = None
53 | return
54 |
55 | def __str__(self):
56 | repr_str = "index:{0}, bin_mapper:{1}".format(self._feature_index, self._bin_mapper)
57 | return repr_str
58 |
59 | def __repr__(self):
60 | return self.__str__()
61 |
62 | def init(self, train_X, data_indices, ordered_gradients, ordered_hessians, tree_config):
63 | # build feature histogram
64 | for data_index in data_indices:
65 | value = train_X[data_index, self._feature_index]
66 | bin = self._bin_mapper.find_bin_idx(value)
67 | if bin < 0:
68 | continue
69 |
70 | self._bin_entry[bin].sum_gradients += ordered_gradients[data_index]
71 | self._bin_entry[bin].sum_hessians += ordered_hessians[data_index]
72 | self._bin_entry[bin].cnt += 1
73 |
74 | self._tree_config = tree_config
75 | return
76 |
77 | def __sub__(self, other):
78 | return
79 |
80 | def find_best_threshold(self, sum_gradient, sum_hessian, num_data):
81 | # scan bins cumulatively to locate the best split threshold
82 | split_info = SplitInfo()
83 |
84 | best_sum_left_gradients = 0
85 | best_sum_left_hessians = const.Epsion
86 | best_left_count = 0
87 | best_threshold_bin = 0
88 | best_threshold = float("inf")
89 |
90 | sum_left_gradients = 0
91 | sum_left_hessians = const.Epsion
92 | left_count = 0
93 |
94 | # a split must beat the parent's own gain (computed below) by at least min_split_gain
95 | min_gain_shift = self._tree_config.min_split_gain + self.get_leaf_split_gain(sum_gradient, sum_hessian, self._tree_config.reg_alpha, self._tree_config.reg_lambda)
96 | best_gain = min_gain_shift
97 |
98 | for bin in xrange(len(self._bin_entry)):
99 | sum_left_gradients += self._bin_entry[bin].sum_gradients
100 | sum_left_hessians += self._bin_entry[bin].sum_hessians
101 | left_count += self._bin_entry[bin].cnt
102 |
103 | sum_right_gradients = sum_gradient - sum_left_gradients
104 | sum_right_hessians = sum_hessian - sum_left_hessians
105 |
106 | current_gain = self.get_leaf_split_gain(sum_left_gradients, sum_left_hessians, self._tree_config.reg_alpha, self._tree_config.reg_lambda) +\
107 | self.get_leaf_split_gain(sum_right_gradients, sum_right_hessians, self._tree_config.reg_alpha, self._tree_config.reg_lambda)
108 |
109 | if current_gain > best_gain and left_count > self._tree_config.min_child_samples:
110 | best_sum_left_gradients = sum_left_gradients
111 | best_sum_left_hessians = sum_left_hessians
112 | best_gain = current_gain
113 | best_left_count = left_count
114 | best_threshold_bin = bin
115 | best_threshold = self._bin_mapper.upper_at(bin)
116 |
117 | split_info.threshold_bin = best_threshold_bin
118 | split_info.threshold = best_threshold
119 | split_info.feature_index = self._feature_index
120 | split_info.gain = best_gain
121 |
122 | split_info.left_count = best_left_count
123 | split_info.right_count = num_data - best_left_count
124 |
125 | split_info.left_sum_gradients = best_sum_left_gradients
126 | split_info.left_sum_hessians = best_sum_left_hessians
127 | split_info.right_sum_gradients = sum_gradient - best_sum_left_gradients
128 | split_info.right_sum_hessians = sum_hessian - best_sum_left_hessians
129 |
130 | split_info.left_output = self.get_splitted_leaf_output(
131 | best_sum_left_gradients,
132 | best_sum_left_hessians,
133 | self._tree_config.reg_alpha,
134 | self._tree_config.reg_lambda
135 | )
136 | split_info.right_output = self.get_splitted_leaf_output(
137 | sum_gradient - best_sum_left_gradients,
138 | sum_hessian - best_sum_left_hessians,
139 | self._tree_config.reg_alpha,
140 | self._tree_config.reg_lambda,
141 | )
142 |
143 | return split_info
144 |
145 | def find_best_threshold_numerical(self, sum_gradient, sum_hessian, num_data):
146 | return
147 |
148 | def find_best_threshold_categorical(self, sum_gradient, sum_hessian, num_data):
149 | return
150 |
151 | def get_leaf_split_gain(self, sum_gradient, sum_hessian, l1, l2):
152 | abs_sum_gradients = abs(sum_gradient)
153 | reg_abs_sum_gradients = max(0.0, abs_sum_gradients - l1)
154 | return (reg_abs_sum_gradients * reg_abs_sum_gradients) / (sum_hessian + l2)
155 |
156 | def get_splitted_leaf_output(self, sum_gradient, sum_hessian, l1, l2):
157 | abs_sum_gradients = abs(sum_gradient)
158 | reg_abs_sum_gradients = max(0.0, abs_sum_gradients - l1)
159 | if sum_gradient > 0:
160 | return -reg_abs_sum_gradients / (sum_hessian + l2)
161 | else:
162 | return reg_abs_sum_gradients / (sum_hessian + l2)
163 |
164 | def find_best_threshold_sequence(self, sum_gradient, sum_hessian, num_data, min_gain_shift):
165 | pass
166 |
167 |
168 | class HistogramPool(object):
169 | def __init__(self):
170 | self._pool = []
171 | self._data = []
172 | self._feature_metas = []
173 |
174 | self._cache_size = 0
175 | self._total_size = 0
176 | self._is_enough = False
177 |
178 | self._mapper = []
179 | self._inverse_mapper = []
180 | self._last_used_time = []
181 | self._cur_time = 0
182 | return
183 |
184 | def move(self):
185 | return
186 |
187 | def get(self):
188 | return
189 |
190 | def reset_config(self):
191 | return
192 |
193 | def dynamic_change_size(self, train_data, tree_config, cache_size, total_size):
194 | return
195 |
196 | def reset_map(self):
197 | return
198 |
199 | def reset(self):
200 | return
201 |
202 |
203 | if __name__ == '__main__':
204 | pass
205 |
--------------------------------------------------------------------------------
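The regularized leaf gain used throughout `find_best_threshold` is `(max(0, |G| - reg_alpha))^2 / (H + reg_lambda)` for a leaf with gradient sum `G` and hessian sum `H`; a split is accepted when left gain + right gain exceeds the parent's gain by at least `min_split_gain`. A worked toy example (all numbers hypothetical):

```python
def leaf_split_gain(sum_gradient, sum_hessian, l1, l2):
    # same formula as FeatureHistogram.get_leaf_split_gain above
    g = max(0.0, abs(sum_gradient) - l1)
    return g * g / (sum_hessian + l2)

# Parent statistics split into a candidate left/right pair.
G, H = -4.0, 10.0
Gl, Hl = -3.5, 4.0
l1, l2 = 1.0, 1.0

parent = leaf_split_gain(G, H, l1, l2)           # 9 / 11  ~ 0.8182
left = leaf_split_gain(Gl, Hl, l1, l2)           # 6.25 / 5 = 1.25
right = leaf_split_gain(G - Gl, H - Hl, l1, l2)  # clipped to 0 by the l1 term

print(left + right - parent)  # ~0.4318 > 0, so this split is worth taking
```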
/py_lightgbm/tree/leaf_splits.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: leaf_splits.py
8 | @time: 2017/7/10 21:30
9 | @change_time:
10 | 1.2017/7/10 21:30
11 | """
12 |
13 |
14 | class LeafSplits(object):
15 | def __init__(self, num_data):
16 | self.leaf_index = -1
17 | self.num_data_in_leaf = num_data
18 | self.num_data = num_data
19 | self.sum_gradients = 0
20 | self.sum_hessians = 0
21 | self.data_indices = range(num_data)
22 | return
23 |
24 | def __str__(self):
25 | repr_str = "leaf_index:{0}, num_data:{1}, sum_gradients:{2}, sum_hessians:{3}\nindices{4}".format(
26 | self.leaf_index,
27 | self.num_data_in_leaf,
28 | self.sum_gradients,
29 | self.sum_hessians,
30 | self.data_indices,
31 | )
32 | return repr_str
33 |
34 | def init(self, gradients, hessians):
35 | self.leaf_index = 0
36 | self.sum_gradients = sum(gradients)
37 | self.sum_hessians = sum(hessians)
38 | return
39 |
40 | def init_with_data_partition(self, leaf, data_partition, gradients, hessians):
41 | self.leaf_index = leaf
42 |
43 | self.data_indices = data_partition.get_indices_of_leaf(leaf)
44 | self.num_data_in_leaf = len(self.data_indices)
45 |
46 | self.sum_gradients = 0
47 | self.sum_hessians = 0
48 |
49 | for index in self.data_indices:
50 | self.sum_gradients += gradients[index]
51 | self.sum_hessians += hessians[index]
52 | return
53 |
54 | def reset(self):
55 | self.leaf_index = -1
56 | self.sum_gradients = 0
57 | self.sum_hessians = 0
58 | self.data_indices = []
59 | return
60 |
--------------------------------------------------------------------------------
/py_lightgbm/tree/split_info.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: split_info.py
8 | @time: 2017/7/9 22:04
9 | @change_time:
10 | 1.2017/7/9 22:04
11 | """
12 |
13 |
14 | class SplitInfo(object):
15 | def __init__(self):
16 | self.feature_index = -1
17 | self.threshold_bin = -1
18 | self.threshold = float("inf")
19 |
20 | self.left_output = 0
21 | self.right_output = 0
22 | self.gain = 0
23 |
24 | self.left_count = 0
25 | self.right_count = 0
26 |
27 | self.left_sum_gradients = 0
28 | self.left_sum_hessians = 0
29 | self.right_sum_gradients = 0
30 | self.right_sum_hessians = 0
31 | return
32 |
33 | def __str__(self):
34 | repr_str = "index:{0};bin:{1};gain:{2};lc:{3};rc:{4}, thres:{5}".format(
35 | self.feature_index,
36 | self.threshold_bin,
37 | self.gain,
38 | self.left_count,
39 | self.right_count,
40 | self.threshold,
41 | )
42 | return repr_str
43 |
44 | def __repr__(self):
45 | return self.__str__()
46 |
47 | def reset(self):
48 | self.feature_index = -1
49 | self.threshold_bin = -1
50 | self.threshold = float("inf")
51 |
52 | self.left_output = 0
53 | self.right_output = 0
54 | self.gain = 0
55 |
56 | self.left_count = 0
57 | self.right_count = 0
58 |
59 | self.left_sum_gradients = 0
60 | self.left_sum_hessians = 0
61 | self.right_sum_gradients = 0
62 | self.right_sum_hessians = 0
63 | return
64 |
65 |
66 | if __name__ == '__main__':
67 | pass
68 |
--------------------------------------------------------------------------------
/py_lightgbm/tree/tree.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: tree.py
8 | @time: 2017/7/2 15:56
9 | @change_time:
10 | 1.2017/7/2 15:56
11 | """
12 | import numpy as np
13 |
14 |
15 | class Tree(object):
16 |
17 | def __init__(self, max_num_leaves):
18 | self.max_leaves = max_num_leaves
19 |
20 | self.num_leaves = 0
21 |
22 | # used for non-leaf nodes
23 | self.left_child = [-1] * (self.max_leaves - 1)
24 | self.right_child = [-1] * (self.max_leaves - 1)
25 | self.split_feature_index = [-1] * (self.max_leaves - 1)
26 | self.threshold_in_bin = [-1] * (self.max_leaves - 1) # threshold in bin
27 | self.threshold = [-1] * (self.max_leaves - 1)
28 | self.split_gain = [0] * (self.max_leaves - 1)
29 |
30 | self.internal_values = [0] * (self.max_leaves - 1)
31 | self.internal_counts = [0] * (self.max_leaves - 1)
32 |
33 | # used for leaf node
34 | self.leaf_parent = [-1] * self.max_leaves
35 | self.leaf_values = [0] * self.max_leaves
36 | self.leaf_counts = [0] * self.max_leaves
37 |
38 | self.leaf_depth = [0] * self.max_leaves # depth for leaves
39 |
40 | # the root is at depth 0
41 | self.leaf_depth[0] = 0
42 | self.num_leaves = 1
43 | self.leaf_parent[0] = -1
44 | return
45 |
46 | def counts_of_leaf(self, leaf):
47 | return self.leaf_counts[leaf]
48 |
49 | def output_of_leaf(self, leaf):
50 | return self.leaf_values[leaf]
51 |
52 | def split(self, leaf, split_info):
53 | """
54 | :param leaf: the index of the leaf to split
55 | :param split_info: the SplitInfo describing the chosen split
56 | :return: the index of the newly created right leaf
57 | """
58 | new_node_idx = self.num_leaves - 1
59 |
60 | parent = self.leaf_parent[leaf]
61 | if parent >= 0:
62 | if self.left_child[parent] == -leaf - 1:
63 | self.left_child[parent] = new_node_idx
64 | else:
65 | self.right_child[parent] = new_node_idx
66 |
67 | self.split_feature_index[new_node_idx] = split_info.feature_index
68 | self.threshold_in_bin[new_node_idx] = split_info.threshold_bin
69 | self.threshold[new_node_idx] = split_info.threshold
70 | self.split_gain[new_node_idx] = split_info.gain
71 |
72 | self.left_child[new_node_idx] = -leaf - 1
73 | self.right_child[new_node_idx] = -self.num_leaves - 1
74 |
75 | self.leaf_parent[leaf] = new_node_idx
76 | self.leaf_parent[self.num_leaves] = new_node_idx
77 |
78 | self.internal_values[new_node_idx] = self.leaf_values[leaf]
79 | self.internal_counts[new_node_idx] = self.leaf_counts[leaf]
80 |
81 | self.leaf_values[leaf] = split_info.left_output
82 | self.leaf_values[self.num_leaves] = split_info.right_output
83 |
84 | self.leaf_counts[leaf] = split_info.left_count
85 | self.leaf_counts[self.num_leaves] = split_info.right_count
86 |
87 | self.leaf_depth[self.num_leaves] = self.leaf_depth[leaf] + 1
88 | self.leaf_depth[leaf] += 1
89 |
90 | self.num_leaves += 1
91 | return self.num_leaves - 1
92 |
93 | def depth_of_leaf(self, leaf):
94 | return self.leaf_depth[leaf]
95 |
96 | def predict_prob(self, X):
97 | num_data, num_feature = X.shape
98 |
99 | predict_y = np.zeros((num_data,))
100 |
101 | for num_idx in xrange(num_data):
102 | predict_y[num_idx] = self.predict(X[num_idx, :])
103 |
104 | return predict_y
105 |
106 | def predict(self, feature_values):
107 | """
108 | Walk the tree from the root: at each internal node, route the sample left or right using self.split_feature_index and self.threshold,
109 | until a negative child index is reached (a leaf node); return that leaf's value.
110 | :return:
111 | """
112 | score = 0.0
113 | current_node = 0
114 |
115 | while current_node >= 0:
116 | left_child_node = self.left_child[current_node]
117 | right_child_node = self.right_child[current_node]
118 |
119 | feature_index = self.split_feature_index[current_node]
120 | if feature_index < 0:
121 | break
122 |
123 | feature_value = feature_values[feature_index]
124 | threshold = self.threshold[current_node]
125 |
126 | if feature_value <= threshold:
127 | current_node = left_child_node
128 | else:
129 | current_node = right_child_node
130 |
131 | if current_node < 0:
132 | score = self.leaf_values[~current_node]
133 | break
134 |
135 | return score
136 |
137 | def show(self):
138 | print "start-------------------"
139 | print "left_child", self.left_child
140 | print "right_child", self.right_child
141 | print "leaf_parent", self.leaf_parent
142 | print "split_feature_index", self.split_feature_index
143 | print "threshold_in_bin", self.threshold_in_bin
144 | print "threshold", self.threshold
145 | print "split_gain", self.split_gain
146 | print "interval_values", self.internal_values
147 | print "internal_counts", self.internal_counts
148 | print "leaf_counts", self.leaf_counts
149 | print "leaf_values", self.leaf_values
150 | print "leaf_depth", self.leaf_depth
151 | print "end----------------------"
152 | return
153 |
154 |
155 | if __name__ == '__main__':
156 | pass
157 |
--------------------------------------------------------------------------------
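Tree stores children with a sign convention: a non-negative child index points at another internal node, while a negative one encodes a leaf as `-leaf - 1`, recovered with Python's `~node` (since `~x == -x - 1`). A standalone two-split tree shows the traversal (the structure and values are hypothetical, not repo output):

```python
# node 0 splits on feature 0 at 5.0; its left child is node 1, its right
# child is leaf 0. Node 1 splits on feature 1 at 2.0 into leaves 1 and 2.
left_child = [1, -2]
right_child = [-1, -3]
split_feature_index = [0, 1]
threshold = [5.0, 2.0]
leaf_values = [0.9, -0.4, 0.2]

def predict_one(x):
    node = 0
    while node >= 0:
        f, t = split_feature_index[node], threshold[node]
        node = left_child[node] if x[f] <= t else right_child[node]
    return leaf_values[~node]   # ~node == -node - 1 recovers the leaf index

print(predict_one([3.0, 7.0]))  # left at root, right at node 1 -> leaf 2 -> 0.2
print(predict_one([9.0, 0.0]))  # right at root -> leaf 0 -> 0.9
```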
/py_lightgbm/tree/tree_learner.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: tree_learner.py
8 | @time: 2017/7/1 11:26
9 | @change_time:
10 | 1.2017/7/1 11:26
11 | """
12 |
13 | import copy
14 |
15 |
16 | from py_lightgbm.tree.tree import Tree
17 | from py_lightgbm.tree.split_info import SplitInfo
18 | from py_lightgbm.tree.leaf_splits import LeafSplits
19 | from py_lightgbm.tree.data_partition import DataPartition
20 | from py_lightgbm.utils import const
21 |
22 | from py_lightgbm.logmanager import logger
23 | from py_lightgbm.utils import conf
24 |
25 |
26 | _LOGGER = logger.get_logger("TreeLearner")
27 |
28 |
29 | class TreeLearner(object):
30 | def __init__(self, tree_config, train_data):
31 | self._gradients = None
32 | self._hessians = None
33 | self._histogram_pool = None
34 | self._train_data = train_data
35 | self._max_cache_size = None
36 | self._tree_config = tree_config
37 | self._num_leaves = tree_config.num_leaves
38 | self._num_features = self._train_data.num_features
39 | self._num_data = self._train_data.num_data
40 |
41 | self._smaller_leaf_histogram_array = []
42 | self._larger_leaf_histogram_array = []
43 | self._best_split_per_leaf = None # store all the split info
44 |
45 | self._smaller_leaf_split = LeafSplits(self._num_data) # store the best splits for this leaf at smaller leaf
46 | self._larger_leaf_split = LeafSplits(self._num_data)
47 | self._data_partition = None
48 | self.init()
49 | return
50 |
51 | def init(self):
52 | # self._histogram_pool.DynamicChangeSize(self._train_data, self._tree_config,
53 | # self._max_cache_size, self._tree_config.num_leaves)
54 |
55 | self._best_split_per_leaf = [SplitInfo() for _ in xrange(self._num_leaves)]
56 | self._data_partition = DataPartition(self._num_data, self._num_leaves)
57 | return
58 |
59 | def train(self, gradients, hessians):
60 | self._gradients = gradients
61 | self._hessians = hessians
62 |
63 | self.before_train()
64 |
65 | new_tree = Tree(self._num_leaves)
66 |
67 | left_leaf = 0
68 | right_leaf = -1
69 | cur_depth = 1
70 |
71 | # log key information around every split
72 | for split in xrange(self._num_leaves - 1):
73 | # print "current split num_leave:", split
74 |
75 | if self.before_find_best_leave(new_tree, left_leaf, right_leaf): # sanity-check the candidate leaves
76 | self.log_before_split()
77 | self.find_best_splits()
78 |
79 | best_leaf = self.get_max_gain()
80 | if best_leaf is None:
81 | break
82 |
83 | self.log_split()
84 | left_leaf, right_leaf = self.split(new_tree, best_leaf)
85 | self.log_after_split()
86 |
87 | cur_depth = max(cur_depth, new_tree.depth_of_leaf(left_leaf))
88 | return new_tree
89 |
90 | def get_indices_of_leaf(self, idx):
91 | return self._data_partition.get_indices_of_leaf(idx)
92 |
93 | def get_max_gain(self):
94 | best_leaf = None
95 | current_gain = 0.0
96 | for leaf, split_info in enumerate(self._best_split_per_leaf):
97 | if split_info.gain > current_gain:
98 | best_leaf = leaf
99 | current_gain = split_info.gain
100 |
101 | return best_leaf
102 |
103 | def before_train(self):
104 | # self._histogram_pool.resetMap()
105 |
106 | # reset data_partition and related state
107 | self._data_partition.init()
108 |
109 | # initialize smaller_leaf_split and the per-leaf best splits
110 | for i in xrange(self._num_leaves):
111 | self._best_split_per_leaf[i].reset()
112 |
113 | self._smaller_leaf_split.init(self._gradients, self._hessians)
114 | self._larger_leaf_split.reset()
115 |
116 | return
117 |
118 | def find_best_splits(self):
119 | # find the best split point across the currently used features
120 | is_feature_used = [True] * self._num_features
121 |
122 | self.construct_histograms(is_feature_used)
123 | self.find_best_split_from_histograms(is_feature_used)
124 | return
125 |
126 | def construct_histograms(self, is_feature_used):
127 | # construct smaller leaf
128 | self._smaller_leaf_histogram_array = self._train_data.construct_histograms(
129 | is_feature_used,
130 | self._smaller_leaf_split.data_indices,
131 | self._smaller_leaf_split.leaf_index,
132 | self._gradients,
133 | self._hessians,
134 | )
135 |
136 | # construct larger leaf
137 | self._larger_leaf_histogram_array = self._train_data.construct_histograms(
138 | is_feature_used,
139 | self._larger_leaf_split.data_indices,
140 | self._larger_leaf_split.leaf_index,
141 | self._gradients,
142 | self._hessians,
143 | )
144 | return
145 |
146 | def find_best_split_from_histograms(self, is_feature_used):
147 |
148 | smaller_best = SplitInfo()
149 | larger_best = SplitInfo()
150 |
151 | for feature_index in xrange(self._num_features):
152 | if not is_feature_used[feature_index]:
153 | continue
154 |
155 | # self._train_data.fix_histograms()
156 | if self._smaller_leaf_histogram_array:
157 | smaller_split = self._smaller_leaf_histogram_array[feature_index].find_best_threshold(
158 | self._smaller_leaf_split.sum_gradients,
159 | self._smaller_leaf_split.sum_hessians,
160 | self._smaller_leaf_split.num_data_in_leaf,
161 | )
162 |
163 | if smaller_split.gain > smaller_best.gain:
164 | smaller_best = copy.deepcopy(smaller_split)
165 |
166 | if self._larger_leaf_histogram_array:
167 | larger_split = self._larger_leaf_histogram_array[feature_index].find_best_threshold(
168 | self._larger_leaf_split.sum_gradients,
169 | self._larger_leaf_split.sum_hessians,
170 | self._larger_leaf_split.num_data_in_leaf,
171 | )
172 |
173 | if larger_split.gain > larger_best.gain:
174 | larger_best = copy.deepcopy(larger_split)
175 |
176 | if self._smaller_leaf_split.leaf_index >= 0:
177 | leaf = self._smaller_leaf_split.leaf_index
178 | self._best_split_per_leaf[leaf] = smaller_best
179 |
180 | if self._larger_leaf_split.leaf_index >= 0:
181 | leaf = self._larger_leaf_split.leaf_index
182 | self._best_split_per_leaf[leaf] = larger_best
183 | return
184 |
185 | def before_find_best_leaf(self, new_tree, left_leaf, right_leaf):
186 | # max_depth
187 | if new_tree.depth_of_leaf(left_leaf) >= self._tree_config.max_depth:
188 | self._best_split_per_leaf[left_leaf].gain = const.MIN_SCORE
189 | if right_leaf >= 0:
190 | self._best_split_per_leaf[right_leaf].gain = const.MIN_SCORE
191 |
192 | return False
193 |
194 | # min_child_samples
195 | if self._data_partition.counts_of_leaf(left_leaf) < self._tree_config.min_child_samples:
196 | self._best_split_per_leaf[left_leaf].gain = const.MIN_SCORE
197 | if right_leaf >= 0:
198 | self._best_split_per_leaf[right_leaf].gain = const.MIN_SCORE
199 |
200 | return False
201 |
202 | # TODO: histogram pool
203 |
204 | return True
205 |
206 | def split(self, new_tree, best_leaf):
207 | left_leaf = best_leaf
208 | best_split_info = self._best_split_per_leaf[best_leaf]
209 |
210 | right_leaf = new_tree.split(
211 | best_leaf,
212 | best_split_info,
213 | )
214 |
215 | # data_partition
216 | self._data_partition.split(
217 | left_leaf,
218 | self._train_data,
219 | best_split_info.feature_index,
220 | best_split_info.threshold_bin,
221 | right_leaf
222 | )
223 |
224 | # initialize the leaves used in the next iteration; the smaller leaf drives the next histogram build
225 | if best_split_info.left_count > best_split_info.right_count:
226 | self._smaller_leaf_split.init_with_data_partition(
227 | left_leaf,
228 | self._data_partition,
229 | self._gradients,
230 | self._hessians,
231 | )
232 |
233 | self._larger_leaf_split.init_with_data_partition(
234 | right_leaf,
235 | self._data_partition,
236 | self._gradients,
237 | self._hessians,
238 | )
239 | else:
240 | self._larger_leaf_split.init_with_data_partition(
241 | left_leaf,
242 | self._data_partition,
243 | self._gradients,
244 | self._hessians,
245 | )
246 |
247 | self._smaller_leaf_split.init_with_data_partition(
248 | right_leaf,
249 | self._data_partition,
250 | self._gradients,
251 | self._hessians,
252 | )
253 | return left_leaf, right_leaf
254 |
255 | def can_log(self):
256 | return conf.CONFIG_DEV  # verbose diagnostics only in the development environment
257 |
258 | def log_before_split(self):
259 | """
260 | Log the state before the split.
261 | """
262 | if not self.can_log():
263 | return
264 |
265 | # 1. current data partition
266 | _LOGGER.info("log_before_split---------------------------------------------")
267 | _LOGGER.info("_data_partition:{0}".format(self._data_partition))
268 |
269 |
270 | # 2. smaller/larger leaf split state
271 | _LOGGER.info("smaller_leaf_split:{0}\nlarger_leaf_split:{1}\n".format(
272 | self._smaller_leaf_split,
273 | self._larger_leaf_split,
274 | ))
275 | # 3. best split recorded per leaf
276 | _LOGGER.info("best_split:{0}".format(self._best_split_per_leaf))
277 |
278 | return
279 |
280 | def log_split(self):
281 | """
282 | Log the chosen split.
283 | """
284 | if not self.can_log():
285 | return
286 |
287 | _LOGGER.info("log_split---------------------------------------------")
288 | # 1. smaller/larger leaf split state
289 | _LOGGER.info("smaller_leaf_split:{0}\nlarger_leaf_split{1}\n".format(
290 | self._smaller_leaf_split,
291 | self._larger_leaf_split
292 | ))
293 |
294 | # 2. leaf histograms
295 | _LOGGER.info("smaller leaf histogram:{0}".format(self._smaller_leaf_histogram_array))
296 | _LOGGER.info("larger leaf histogram:{0}".format(self._larger_leaf_histogram_array))
297 |
298 | _LOGGER.info("best_split:{0}".format(self._best_split_per_leaf))
299 | _LOGGER.info("---------------------------------------------")
300 | return
301 |
302 | def log_after_split(self):
303 | """
304 | Log the state after the split.
305 | """
306 | if not self.can_log():
307 | return
308 |
309 | _LOGGER.info("log_after_split---------------------------------------------")
310 | _LOGGER.info("{0}".format(self._data_partition))
311 |
312 | _LOGGER.info("---------------------------------------------")
313 | return
314 |
315 |
316 | if __name__ == '__main__':
317 | pass
318 |
--------------------------------------------------------------------------------
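The tail of `py_lightgbm/tree/tree_learner.py` above implements leaf-wise (best-first) growth: each iteration scores every open leaf, `get_max_gain` picks the leaf whose best candidate split has the highest positive gain (returning `None` ends training early), and `split` records which child is smaller so the next histogram pass works on the cheaper side. A minimal sketch of that control flow, with `find_best_split`/`apply_split` as toy stand-ins rather than the repo's API:

```python
# Leaf-wise (best-first) growth in miniature, assuming toy helper callables:
#   find_best_split(leaf) -> object with a .gain attribute
#   apply_split(leaf, split) -> (left_leaf_id, right_leaf_id)
def grow_leaf_wise(find_best_split, apply_split, num_leaves):
    best = {0: find_best_split(0)}          # the root starts as leaf 0
    for _ in range(num_leaves - 1):
        # the leaf whose best candidate split has the highest gain
        leaf = max(best, key=lambda l: best[l].gain)
        if best[leaf].gain <= 0.0:          # mirrors get_max_gain returning None
            break
        left, right = apply_split(leaf, best.pop(leaf))
        # only the two fresh children need new split searches
        best[left] = find_best_split(left)
        best[right] = find_best_split(right)
    return best
```

The real learner additionally forces a pruned leaf's gain to `const.MIN_SCORE` when `max_depth` or `min_child_samples` would be violated, which is exactly what lets `get_max_gain` come up empty and stop the loop.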
/py_lightgbm/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: __init__.py
8 | @time: 2017/7/10 01:04
9 | @change_time:
10 | 1.2017/7/10 01:04
11 | """
12 |
13 | if __name__ == '__main__':
14 | pass
15 |
--------------------------------------------------------------------------------
/py_lightgbm/utils/conf.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 | # encoding: utf-8
3 |
4 |
5 | """
6 | @version: 1.0
7 | @author: xiaoqiangkx
8 | @file: conf.py
9 | @time: 2017/7/14 16:24
10 | @change_time:
11 | 1.2017/7/14 16:24
12 | """
13 |
14 | CONFIG_DEV = True  # development environment
--------------------------------------------------------------------------------
/py_lightgbm/utils/const.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: const.py
8 | @time: 2017/7/10 01:04
9 | @change_time:
10 | 1.2017/7/10 01:04
11 | """
12 |
13 | MIN_SCORE = -float("inf")
14 |
15 | Epsion = 1e-15
16 |
17 | TYPE_CATEGORY = 1
18 | TYPE_NUMERICAL = 2
19 |
20 | CATEGORY_PERCENTAGE = 0.98  # take the top 98% proportion
21 |
22 | FEATURE_NAME_DEFAULT = "auto"
23 |
24 | DEFAULT_NUM_LEAVES = 50
25 | DEFAULT_MAX_DEPTH = 10
26 | DEFAULT_LEARNING_RATE = 0.1
27 | DEFAULT_NUM_ESTIMATORS = 10
28 | DEFAULT_MAX_BIN = 50
29 | DEFAULT_MIN_SPLIT_GAIN = 1e-3
30 | DEFAULT_REG_ALPHA = 0
31 | DEFAULT_REG_LAMBDA = 0
32 | DEFAULT_MIN_CHILD_SAMPLES = 20
33 |
34 |
35 | THREAD_NUM = 10
36 |
--------------------------------------------------------------------------------
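`DEFAULT_REG_LAMBDA` and `DEFAULT_MIN_SPLIT_GAIN` feed the usual second-order split gain of gradient boosting. The textbook form below is a reference point only; the exact expression used by `find_best_threshold` lives in `py_lightgbm/tree/feature_histogram.py` and may differ in detail:

```latex
\mathrm{gain} = \frac{1}{2}\left[
    \frac{G_L^2}{H_L + \lambda}
  + \frac{G_R^2}{H_R + \lambda}
  - \frac{(G_L + G_R)^2}{H_L + H_R + \lambda}
\right],
\qquad
G = \sum_i g_i,\;\; H = \sum_i h_i
```

Here \(G\) and \(H\) are the `sum_gradients` / `sum_hessians` accumulated per leaf in `tree_learner.py`, \(\lambda\) is `reg_lambda`, and a split is only worth keeping when the gain exceeds `min_split_gain`.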
/testes/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/env python
3 | # encoding: utf-8
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
10 | @file: __init__.py
10 | @time: 2017/3/13 21:53
11 | @change_time:
12 | 1.2017/3/13 21:53
13 | """
14 |
15 | if __name__ == '__main__':
16 | pass
17 |
--------------------------------------------------------------------------------
/testes/decision_tree_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: decision_tree_test.py
8 | @time: 2017/3/18 16:22
9 | @change_time:
10 | 1.2017/3/18 16:22
11 | """
12 | from algorithm import DecisionTree as DT
13 | import random
14 | import numpy as np
15 |
16 |
17 | def make_dataset(num, num_feature, category=3):
18 | """
19 | 1. build a dataset with `num` samples and `num_feature` features,
20 | with feature values drawn uniformly from [1, 5)
21 | 2. each sample gets a random category label in [1, category)
22 | """
23 | data = np.random.randint(1, 5, size=(num, num_feature))
24 | target = np.random.randint(1, category, size=(1, num))[0]
25 | return data, target
26 |
27 |
28 | if __name__ == '__main__':
29 | # Make a simple Decision Tree and plot it
30 | num = 100
31 | num_feature = 10
32 | data, target = make_dataset(num, num_feature, category=10)
33 | # print target
34 | decision_tree = DT.DecisionTree()
35 | decision_tree.make_tree(data, target, choose_func=DT.CHOOSE_INFO_ENTROPY)
36 | decision_tree.show()
37 | dot_tree = decision_tree.save("test.dot")
38 | # print dot_tree.source
39 |
--------------------------------------------------------------------------------
/testes/logistic_test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/env python
3 | # encoding: utf-8
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
9 | @file: logistic_test.py
10 | @time: 2017/3/13 21:54
11 | @change_time:
12 | 1.2017/3/13 21:54
13 | """
14 | import unittest
15 | from utils import formula
16 | import numpy as np
17 |
18 |
19 | class TestCase(unittest.TestCase):
20 | def test_sigmoid(self):
21 | result = formula.sigmoid(np.array([[0, 0, 0]]).T, np.array([[0, 0, 0]]).T)
22 | self.assertAlmostEqual(result[0, 0], 0.5)
23 | return
24 |
25 |
26 | if __name__ == '__main__':
27 | unittest.main()
28 |
--------------------------------------------------------------------------------
/tree/DecisionTreeClassifier.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: DecisionTreeClassifier.py
8 | @time: 2017/3/18 10:51
9 | @change_time:
10 | 1.2017/3/18 10:51
11 | """
12 | from algorithm import DecisionTree as DT
13 | import numpy as np
14 |
15 |
16 | class DecisionTreeClassifier(object):
17 |
18 | def __init__(self, depth=5):
19 | self.tree = DT.DecisionTree(depth=depth)
20 | return
21 |
22 | def fit(self, data, target, choose_func=DT.CHOOSE_INFO_ENTROPY):
23 | self.tree.make_tree(data, target, choose_func)
24 | return
25 |
26 | def show(self):
27 | self.tree.show()
28 | return
29 |
30 | def save(self, filename):
31 | self.tree.save(filename)
32 | return
33 |
34 | def predict(self, data):
35 | m, n = data.shape
36 | target = np.zeros(m)
37 | for idx, item in enumerate(data):
38 | result = self.tree.decide(item)
39 | target[idx] = result
40 |
41 | return target
42 |
--------------------------------------------------------------------------------
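A quick usage sketch for the wrapper above, on the same kind of random integer data that `testes/decision_tree_test.py` generates (the data and the accuracy line are illustrative only):

```python
import numpy as np
from algorithm import DecisionTree as DT
from tree.DecisionTreeClassifier import DecisionTreeClassifier

# toy data: 100 samples, 10 integer features, labels in [1, 5)
data = np.random.randint(1, 5, size=(100, 10))
target = np.random.randint(1, 5, size=100)

clf = DecisionTreeClassifier(depth=5)
clf.fit(data, target, choose_func=DT.CHOOSE_INFO_ENTROPY)

predicted = clf.predict(data)
print "training accuracy:", np.mean(predicted == target)
```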
/tree/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/env python
3 | # encoding: utf-8
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
10 | @file: __init__.py
10 | @time: 2017/3/18 10:50
11 | @change_time:
12 | 1.2017/3/18 10:50
13 | """
14 |
15 | if __name__ == '__main__':
16 | pass
17 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/env python
3 | # encoding: utf-8
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
10 | @file: __init__.py
10 | @time: 2017/3/13 21:55
11 | @change_time:
12 | 1.2017/3/13 21:55
13 | """
14 |
15 | if __name__ == '__main__':
16 | pass
17 |
--------------------------------------------------------------------------------
/utils/cmd_table.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 |
4 | """
5 | @version: 1.0
6 | @author: clark
7 | @file: cmd_table.py
8 | @time: 2017/4/30 13:12
9 | @change_time:
10 | 1.2017/4/30 13:12
11 | """
12 |
13 | data = """
14 | PassengerId Survived Pclass Name Sex Age SibSp Parch Ticket Fare Cabin Embarked
15 | 0 1 0 3 Braund,Mr. Owen Harris male 22.0 1 0 A/5 21171 7.2500 NaN S
16 | 1 2 1 1 Cumings,Mrs. John Bradley (Florence Briggs Th...) female 38.0 1 0 PC 17599 71.2833 C85 C
17 | """
18 |
19 |
20 | if __name__ == '__main__':
21 | pass
22 |
--------------------------------------------------------------------------------
/utils/const.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/env python
3 | # encoding: utf-8
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
9 | @file: const.py
10 | @time: 2017/3/13 22:46
11 | @change_time:
12 | 1.2017/3/13 22:46
13 | """
14 |
15 | AXIS_COLUMN = 0  # first dimension: matrix columns
16 | AXIS_ROW = 1  # second dimension: matrix rows
17 |
--------------------------------------------------------------------------------
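The naming follows numpy's aggregation semantics: reducing along axis 0 leaves one value per column, along axis 1 one value per row. A two-line check:

```python
import numpy as np
from utils import const

X = np.arange(6).reshape(2, 3)           # [[0, 1, 2], [3, 4, 5]]
print np.sum(X, axis=const.AXIS_COLUMN)  # [3 5 7]  -> one sum per column
print np.sum(X, axis=const.AXIS_ROW)     # [3 12]   -> one sum per row
```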
/utils/file_utils.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/env python
3 | # encoding: utf-8
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
9 | @file: file_utils.py
10 | @time: 2017/3/13 22:00
11 | @change_time:
12 | 1.2017/3/13 22:00
13 | """
14 |
15 | import pandas as pd
16 | import numpy as np
17 |
18 |
19 | def load_water_melon_data(filename, header='infer'):
20 | df = pd.read_csv(filename, header=header)
21 | m, n = df.shape
22 | X = df.values[:, :n-1]
23 | Y = df.values[:, n-1:n]
24 | return X.T, Y.T
25 |
26 |
27 | def load_iris_data(filename, header='infer'):
28 | df = pd.read_csv(filename, header=header)
29 | m, n = df.shape
30 | X = df.values[:, 1:n-1]
31 |
32 | Y = df.values[:, n-1:]
33 | for i in xrange(m):
34 | if Y[i, 0] == 'Iris-versicolor':
35 | Y[i, 0] = 1
36 | elif Y[i, 0] == 'Iris-virginica':
37 | Y[i, 0] = 0
38 | return X.T, Y.T
39 |
40 |
41 | def load_mnist_data(train_filename, test_filename, header="infer"):
42 | train_df = pd.read_csv(train_filename, header=header)
43 | test_df = pd.read_csv(test_filename, header=header)
44 |
45 | m, n = train_df.shape
46 | training_data = train_df.values[:, 1:n]
47 | training_target = train_df.values[:, 0].T
48 |
49 | m, n = test_df.shape
50 | test_data = test_df.values[:, 1:n]
51 | test_target = test_df.values[:, 0].T
52 |
53 | return training_data, training_target, test_data, test_target
54 |
55 |
56 | if __name__ == '__main__':
57 |
58 | filename = "../data/Iris.csv"
59 | X, Y = load_iris_data(filename)
60 | print X, Y
61 |
--------------------------------------------------------------------------------
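Note that every loader above returns `X.T`/`Y.T`, so features are rows and samples are columns; `utils/sample.py` relies on this when it slices folds by column. A small sanity check, where the shapes assume the standard 150-row Iris CSV with an `Id` column:

```python
from utils import file_utils as FU

X, Y = FU.load_iris_data("../data/Iris.csv")
print X.shape   # (4, 150): four feature rows, one column per sample
print Y.shape   # (1, 150): one label per column
print X[:, 0]   # a single sample is a column slice
```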
/utils/formula.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/env python
3 | # encoding: utf-8
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
9 | @file: formula.py
10 | @time: 2017/3/13 21:55
11 | @change_time:
12 | 1.2017/3/13 21:55
13 | """
14 | import numpy as np
15 | from utils import logger
16 | import operator
17 | from collections import Counter
18 |
19 |
20 | def sigmoid(x, w):
21 | return 1.0 / (1.0 + np.exp(np.dot(-w.T, x).astype(float)))
22 |
23 |
24 | def cal(result, test_target_data):
25 | data = (result == test_target_data)
26 | count = 0
27 | for elem in data[0]:
28 | if elem:
29 | count += 1
30 | amount = len(test_target_data[0])
31 | precision = float(count) / amount
32 | logger.info("count({0})/num({1}), precision:{2}".format(count, amount, precision))
33 | return precision
34 |
35 |
36 | def cal_new(result, test_target_data):
37 | data = (result == test_target_data)
38 | count = 0
39 | for elem in data:
40 | if elem:
41 | count += 1
42 | amount = len(test_target_data)
43 | precision = float(count) / amount
44 | logger.info("count({0})/num({1}), precision:{2}".format(count, amount, precision))
45 | return precision
46 |
47 |
48 | def plus_one(X):
49 | m, n = X.shape
50 | X = np.row_stack((X, np.ones(n)))  # append a bias row of ones (samples are columns)
51 | return X
52 |
53 |
54 | def calculate_entropy(target, index_list):
55 | cnt = Counter()
56 | for index in index_list:
57 | cnt[target[index]] += 1
58 |
59 | result = 0
60 | for num in cnt.itervalues():
61 | prob = float(num)/len(index_list)
62 | result += prob * np.log(prob)
63 | return -result
64 |
65 | def calculate_gini(target, index_list):
66 | cnt = Counter()
67 | for index in index_list:
68 | cnt[target[index]] += 1
69 |
70 | result = 0
71 | for num in cnt.itervalues():
72 | prob = float(num)/len(index_list)
73 | result += prob * prob
74 | return 1 - result
75 |
76 |
77 | def calculate_iv(target_dict):
78 | result = 0
79 | total = sum([len(x) for x in target_dict.itervalues()])
80 | for index_list in target_dict.itervalues():
81 | prob = len(index_list) / float(total)
82 | result += prob * np.log(prob)
83 |
84 | return max(1, -result)  # clamp below at 1 to keep a later division well-behaved
85 |
86 |
87 | # Wrap some common operator helpers, STL-style, to make future extension easier #
88 | class MyOerator(object):
89 | def __init__(self):
90 | pass
91 |
92 | class lte(MyOerator):
93 | def __init__(self, a):
94 | super(lte, self).__init__()
95 | self.a = a
96 | return
97 |
98 | def __call__(self, *args, **kwargs):
99 | arg_list = list(args)
100 | arg_list.append(self.a)
101 | return operator.le(*arg_list, **kwargs)
102 |
103 | def __repr__(self):
104 | return "lte({0})".format(self.a)
105 |
106 |
107 | class equal(MyOerator):
108 | def __init__(self, a):
109 | super(equal, self).__init__()
110 | self.a = a
111 | return
112 |
113 | def __call__(self, *args, **kwargs):
114 | arg_list = list(args)
115 | arg_list.append(self.a)
116 | return operator.eq(*arg_list, **kwargs)
117 |
118 | def __repr__(self):
119 | return "equal({0})".format(self.a)
120 |
121 |
122 | if __name__ == '__main__':
123 | a = {1: [2, 3], 2:[3, 4]}
124 | result = calculate_iv(a)
125 | print result
126 |
--------------------------------------------------------------------------------
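For reference, the helpers above compute the standard quantities, with \(p_k\) the empirical class-\(k\) frequency over the given index list and \(D_v\) the per-value index lists in `target_dict` (`np.log` is the natural log):

```latex
\sigma(x; w) = \frac{1}{1 + e^{-w^{\top} x}}
\qquad
H(D) = -\sum_{k} p_k \ln p_k
\qquad
\mathrm{Gini}(D) = 1 - \sum_{k} p_k^{2}
\qquad
\mathrm{IV}(a) = -\sum_{v} \frac{|D_v|}{|D|} \ln \frac{|D_v|}{|D|}
```

`calculate_iv` clamps its result below at 1, so any later division by it can only shrink a value, never inflate it.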
/utils/logger.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/env python
3 | # encoding: utf-8
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
9 | @file: logger.py
10 | @time: 2017/3/15 23:22
11 | @change_time:
12 | 1.2017/3/15 23:22
13 | """
14 |
15 | import logging
16 | logging.basicConfig()
17 |
18 | _LOGGER = logging.getLogger("toyplay")
19 |
20 | LEVEL_NORMAL = logging.WARNING
21 | LEVEL_DEBUG = logging.INFO
22 |
23 | _LOGGER.setLevel(LEVEL_DEBUG)
24 |
25 |
26 | def info(msg):
27 | _LOGGER.info(msg)
28 | return
29 |
30 |
31 | def error(msg):
32 | _LOGGER.error(msg)
33 | return
34 |
35 |
36 | def setLevel(level):
37 | _LOGGER.setLevel(level)
--------------------------------------------------------------------------------
/utils/sample.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | #!/usr/bin/env python
3 | # encoding: utf-8
4 |
5 |
6 | """
7 | @version: 1.0
8 | @author: xiaoqiangkx
9 | @file: sample.py
10 | @time: 2017/3/15 21:32
11 | @change_time:
12 | 1.2017/3/15 21:32
13 | """
14 | import random
15 | from collections import defaultdict
16 |
17 |
18 | SAMPLE_AS_TEN = 10  # hold out 10 samples per fold (10-fold style cross-validation)
19 | SAMPLE_LEAVE_ONE = 1  # leave-one-out
20 |
21 |
22 | def iter_sample_data(X, Y, method=SAMPLE_AS_TEN):
23 | """
24 | Yield train/test splits; each fold holds out `method` columns (method=1 is leave-one-out).
25 | """
26 |
27 | m, n = X.shape
28 | col_index = range(n)  # samples are columns
29 | random.shuffle(col_index)
30 |
31 | for i in xrange(0, n, method):
32 | train_index = col_index[0: i] + col_index[i+method:]
33 | test_index = col_index[i:i + method]
34 | yield X[:, train_index], Y[:, train_index], X[:, test_index], Y[:, test_index]
35 | return
36 |
37 |
38 | def transform_target(target, train, cv, test):
39 | # random and defaultdict are already imported at module level
40 | import math
41 | 
42 | classify = defaultdict(list)
43 | for index, cls in enumerate(target):
44 | classify[cls].append(index)
45 |
46 | train_index = []
47 | cv_index = []
48 | test_index = []
49 |
50 | for cls, index_list in classify.iteritems():
51 | length = len(index_list)
52 | if length <= 3:
53 | train_index.extend(index_list)
54 | continue
55 |
56 | train_num = max(1, int(math.floor(train * length)))
57 | cv_num = max(1, int(math.floor(cv * length)))
58 | test_num = length - train_num - cv_num
59 |
60 | random.shuffle(index_list)
61 | train_index.extend(index_list[0:train_num])
62 | cv_index.extend(index_list[train_num:train_num + cv_num])
63 | test_index.extend(index_list[train_num+cv_num:])
64 |
65 | return train_index, cv_index, test_index
66 |
67 |
68 | def sample_target_data(target, train=0.6, cv=0.2, test=0.2):
69 | """return index of feature"""
70 | sample_index_list = transform_target(target, train, cv, test)
71 | return sample_index_list
72 |
73 |
74 | if __name__ == '__main__':
75 | from utils import file_utils as FU
76 | filename = "../data/Iris.csv"
77 | X, Y = FU.load_iris_data(filename)
78 | for train_x, train_y, test_x, test_y in iter_sample_data(X, Y, 50):
79 | print train_x, train_y, test_x, test_y
80 |
--------------------------------------------------------------------------------
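A usage sketch for the stratified splitter above, on toy labels; with 10 samples per class, the 0.6/0.2/0.2 defaults give exactly 6/2/2 per class:

```python
import numpy as np
from utils.sample import sample_target_data

# toy labels: three classes with 10 samples each
target = np.repeat([0, 1, 2], 10)

train_idx, cv_idx, test_idx = sample_target_data(target, train=0.6, cv=0.2, test=0.2)
# every class contributes proportionally to each split
print len(train_idx), len(cv_idx), len(test_idx)   # 18 6 6
```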