├── README.md
├── algo
│   ├── LogisticCircuit.py
│   ├── LogisticRegression.py
│   └── __init__.py
├── balanced.vtree
├── learn.py
├── pretrained_models
│   └── mnist.circuit
├── structure
│   ├── AndGate.py
│   ├── CircuitNode.py
│   ├── Vtree.py
│   └── __init__.py
└── util
    ├── DataSet.py
    ├── __init__.py
    ├── generate_balanced_vtree.py
    └── mnist_data.py

/README.md:
--------------------------------------------------------------------------------
1 | # LogisticCircuit
2 | 
3 | This repo contains the code to run the experiments reported in the paper "Learning Logistic Circuits", published at AAAI 2019.
4 | 
5 | This implementation has one subtle difference from the description in the paper: the weight of a leaf node. In the paper, we define the weight of a leaf as 0, whereas here we define it as the leaf node's parameter. In terms of representation power, the two definitions are equivalent: one can simply add an OR gate on top of every original leaf node, which effectively pushes the original leaf nodes' parameters up to the wires between the new OR gates and the leaf nodes. By doing so, the leaf nodes' weights are cast to 0.
6 | 
7 | We include a help message in learn.py: to see which arguments are required and what each argument is for, execute "python3 learn.py --help" in your terminal.
8 | 
9 | Note that the default balanced.vtree is for MNIST and Fashion-MNIST. To run experiments on other datasets, a different vtree is necessary. As requested by some users, we include a small script (generate_balanced_vtree.py) in "util/" to generate balanced vtrees. The vtrees generated by this script are not optimized, and thus do not guarantee optimal performance.
10 | 
11 | We now also support direct multi-class classification through a single circuit instead of resorting to multiple one-vs-all circuits. We parameterize the same circuit structure with n sets of parameters, each corresponding to one one-vs-all binary classification.
12 | 
13 | For better reproducibility, we include the trained circuit that achieves the performance reported in our AAAI paper in the folder "pretrained_models". To achieve the optimal classification accuracy, after the structure learning process is finished, we re-learn the parameters with high l2 regularization, and pick the best set of learned parameters according to their performance on the validation set. After loading the pretrained circuit, running further iterations of parameter learning may therefore cause a drop in classification accuracy.
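A minimal sketch of how the pretrained circuit could be loaded and queried. The vtree loader name below is an assumption (the real constructor lives in structure/Vtree.py), and the data utilities in util/ may differ; see learn.py for the actual pipeline:

```python
import numpy as np

from algo.LogisticCircuit import LogisticCircuit
from structure.Vtree import Vtree

vtree = Vtree.read_vtree("balanced.vtree")  # hypothetical loader name

# Passing circuit_file makes the constructor call load() instead of
# building a fresh circuit with random parameters.
with open("pretrained_models/mnist.circuit", "r") as f:
    circuit = LogisticCircuit(vtree, num_classes=10, circuit_file=f)

# Binarized inputs: one row per sample, one column per vtree variable.
images = np.random.randint(0, 2, size=(4, vtree.var_count)).astype(np.float32)
features = circuit.calculate_features(images)
predictions = circuit.predict(features)  # one class index per sample
```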
14 | 
15 | For questions, please don't hesitate to send us an email at yliang@cs.ucla.edu
16 | 
--------------------------------------------------------------------------------
/algo/LogisticCircuit.py:
--------------------------------------------------------------------------------
1 | import copy
2 | import gc
3 | from collections import deque
4 | 
5 | import numpy as np
6 | 
7 | from algo.LogisticRegression import LogisticRegression
8 | from structure.AndGate import AndGate
9 | from structure.CircuitNode import CircuitNode, OrGate, CircuitTerminal
10 | from structure.CircuitNode import LITERAL_IS_TRUE, LITERAL_IS_FALSE
11 | from structure.Vtree import Vtree
12 | 
13 | FORMAT = """c variables (from inputs) start from 1
14 | c ids of logistic circuit nodes start from 0
15 | c nodes appear bottom-up, children before parents
16 | c the last line of the file records the bias parameter
17 | c three types of nodes:
18 | c T (terminal nodes that correspond to true literals)
19 | c F (terminal nodes that correspond to false literals)
20 | c D (OR gates)
21 | c
22 | c file syntax:
23 | c Logistic Circuit
24 | c T id-of-true-literal-node id-of-vtree variable parameters
25 | c F id-of-false-literal-node id-of-vtree variable parameters
26 | c D id-of-or-gate id-of-vtree number-of-elements (id-of-prime id-of-sub parameters)s
27 | c B bias-parameters
28 | c
29 | """
30 | 
31 | 
32 | class LogisticCircuit(object):
33 |     def __init__(self, vtree, num_classes, circuit_file=None):
34 |         self._vtree = vtree
35 |         self._num_classes = num_classes
36 |         self._largest_index = -1
37 |         self._num_variables = vtree.var_count
38 | 
39 |         self._terminal_nodes = [None] * 2 * self._num_variables
40 |         self._decision_nodes = None
41 |         self._elements = None
42 |         self._parameters = None
43 |         self._bias = np.random.random_sample(size=(num_classes,))
44 | 
45 |         if circuit_file is None:
46 |             self._generate_all_terminal_nodes(vtree)
47 |             self._root = self._new_logistic_psdd(vtree)
48 |         else:
49 |             self._root = self.load(circuit_file)
50 | 
51 |         self._serialize()
52 | 
53 |     @property
54 |     def vtree(self):
55 |         return self._vtree
56 | 
57 |     @property
58 |     def num_parameters(self):
59 |         return self._parameters.size
60 | 
61 |     @property
62 |     def parameters(self):
63 |         return self._parameters
64 | 
65 |     def _generate_all_terminal_nodes(self, vtree: Vtree):
66 |         if vtree.is_leaf():
67 |             var_index = vtree.var
68 |             self._largest_index += 1
69 |             self._terminal_nodes[var_index - 1] = CircuitTerminal(
70 |                 self._largest_index, vtree, var_index, LITERAL_IS_TRUE, np.random.random_sample(size=(self._num_classes,))
71 |             )
72 |             self._largest_index += 1
73 |             self._terminal_nodes[self._num_variables + var_index - 1] = CircuitTerminal(
74 |                 self._largest_index, vtree, var_index, LITERAL_IS_FALSE, np.random.random_sample(size=(self._num_classes,))
75 |             )
76 |         else:
77 |             self._generate_all_terminal_nodes(vtree.left)
78 |             self._generate_all_terminal_nodes(vtree.right)
79 | 
80 |     def _new_logistic_psdd(self, vtree) -> CircuitNode:
81 |         left_vtree = vtree.left
82 |         right_vtree = vtree.right
83 |         prime_variable = left_vtree.var
84 |         sub_variable = right_vtree.var
85 |         elements = list()
86 |         if left_vtree.is_leaf() and right_vtree.is_leaf():
87 |             elements.append(
88 |                 AndGate(
89 |                     self._terminal_nodes[prime_variable - 1],
90 |                     self._terminal_nodes[sub_variable - 1],
91 |                     np.random.random_sample(size=(self._num_classes,)),
92 |                 )
93 |             )
94 |             elements.append(
95 |                 AndGate(
96 |                     self._terminal_nodes[prime_variable - 1],
97 |                     self._terminal_nodes[self._num_variables + sub_variable - 1],
98 |                     np.random.random_sample(size=(self._num_classes,)),
99 |                 )
100 |             )
101 |             elements.append(
102 |                 AndGate(
103 |                     self._terminal_nodes[self._num_variables + prime_variable - 1],
104 |                     self._terminal_nodes[sub_variable - 1],
105 |                     np.random.random_sample(size=(self._num_classes,)),
106 |                 )
107 |             )
108 |             elements.append(
109 |                 AndGate(
110 |                     self._terminal_nodes[self._num_variables + prime_variable - 1],
111 |                     self._terminal_nodes[self._num_variables + sub_variable - 1],
112 |                     np.random.random_sample(size=(self._num_classes,)),
113 |                 )
114 |             )
115 |         elif left_vtree.is_leaf():
116 |             elements.append(
117 |                 AndGate(
118 |                     self._terminal_nodes[prime_variable - 1],
119 |                     self._new_logistic_psdd(right_vtree),
120 |                     np.random.random_sample(size=(self._num_classes,)),
121 |                 )
122 |             )
123 |             elements.append(
124 |                 AndGate(
125 |                     self._terminal_nodes[self._num_variables + prime_variable - 1],
126 |                     self._new_logistic_psdd(right_vtree),
127 |                     np.random.random_sample(size=(self._num_classes,)),
128 |                 )
129 |             )
130 |             for element in elements:
131 |                 element.splittable_variables = copy.deepcopy(right_vtree.variables)
132 |         elif right_vtree.is_leaf():
133 |             elements.append(
134 |                 AndGate(
135 |                     self._new_logistic_psdd(left_vtree),
136 |                     self._terminal_nodes[sub_variable - 1],
137 |                     np.random.random_sample(size=(self._num_classes,)),
138 |                 )
139 |             )
140 |             elements.append(
141 |                 AndGate(
142 |                     self._new_logistic_psdd(left_vtree),
143 |                     self._terminal_nodes[self._num_variables + sub_variable - 1],
144 |                     np.random.random_sample(size=(self._num_classes,)),
145 |                 )
146 |             )
147 |             for element in elements:
148 |                 element.splittable_variables = copy.deepcopy(left_vtree.variables)
149 |         else:
150 |             elements.append(
151 |                 AndGate(
152 |                     self._new_logistic_psdd(left_vtree),
153 |                     self._new_logistic_psdd(right_vtree),
154 |                     np.random.random_sample(size=(self._num_classes,)),
155 |                 )
156 |             )
157 |             elements[0].splittable_variables = copy.deepcopy(vtree.variables)
158 |         self._largest_index += 1
159 |         root = OrGate(self._largest_index, vtree, elements)
160 |         return root
161 | 
162 |     def _serialize(self):
163 |         """Serialize all the decision nodes in the logistic psdd.
164 |         Serialize all the elements in the logistic psdd.
165 |         """
166 |         self._decision_nodes = [self._root]
167 |         self._elements = []
168 |         decision_node_indices = set()
169 |         decision_node_indices.add(self._root.index)
170 |         unvisited = deque()
171 |         unvisited.append(self._root)
172 |         while len(unvisited) > 0:
173 |             current = unvisited.popleft()
174 |             for element in current.elements:
175 |                 self._elements.append(element)
176 |                 element.flag = False
177 |                 if isinstance(element.prime, OrGate) and element.prime.index not in decision_node_indices:
178 |                     decision_node_indices.add(element.prime.index)
179 |                     self._decision_nodes.append(element.prime)
180 |                     unvisited.append(element.prime)
181 |                 if isinstance(element.sub, OrGate) and element.sub.index not in decision_node_indices:
182 |                     decision_node_indices.add(element.sub.index)
183 |                     self._decision_nodes.append(element.sub)
184 |                     unvisited.append(element.sub)
185 |         self._parameters = self._bias.reshape(-1, 1)
186 |         for terminal_node in self._terminal_nodes:
187 |             self._parameters = np.concatenate((self._parameters, terminal_node.parameter.reshape(-1, 1)), axis=1)
188 |         for element in self._elements:
189 |             self._parameters = np.concatenate((self._parameters, element.parameter.reshape(-1, 1)), axis=1)
190 |         gc.collect()
191 | 
192 |     def _record_learned_parameters(self, parameters):
193 |         self._parameters = copy.deepcopy(parameters)
194 |         self._bias = self._parameters[:, 0]
195 |         for i in range(len(self._terminal_nodes)):
196 |             self._terminal_nodes[i].parameter = self._parameters[:, i + 1]
197 |         for i in range(len(self._elements)):
198 |             self._elements[i].parameter = self._parameters[:, i + 1 + 2 * self._num_variables]
199 |         gc.collect()
200 | 
201 |     def calculate_features(self, images: np.ndarray):
202 |         num_images = images.shape[0]
203 |         for terminal_node in self._terminal_nodes:
204 |             terminal_node.calculate_prob(images)
205 |         for decision_node in reversed(self._decision_nodes):
206 |             decision_node.calculate_prob()
207 |         self._root.feature = np.ones(shape=(num_images,), dtype=np.float32)
208 |         for decision_node in self._decision_nodes:
209 |             decision_node.calculate_feature()
210 |         # bias feature
211 |         bias_features = np.ones(shape=(num_images,), dtype=np.float32)
212 |         terminal_node_features = np.vstack([terminal_node.feature for terminal_node in self._terminal_nodes])
213 |         element_features = np.vstack([element.feature for element in self._elements])
214 |         features = np.vstack((bias_features, terminal_node_features, element_features))
215 |         for terminal_node in self._terminal_nodes:
216 |             terminal_node.feature = None
217 |             terminal_node.prob = None
218 |         for element in self._elements:
219 |             element.feature = None
220 |             element.prob = None
221 |         return features.T
222 | 
223 |     def _select_element_and_variable_to_split(self, data, num_splits):
224 |         y = self.predict_prob(data.features)
225 |         if self._num_classes == 1:
226 |             y = np.hstack(((1.0 - y).reshape(-1, 1), y.reshape(-1, 1)))
227 | 
228 |         delta = data.one_hot_labels - y
229 |         element_gradients = np.stack(
230 |             [
231 |                 (delta[:, i].reshape(-1, 1) * data.features)[:, 2 * self._num_variables + 1 :]
232 |                 for i in range(self._num_classes)
233 |             ],
234 |             axis=0,
235 |         )
236 |         element_gradient_variance = np.var(element_gradients, axis=1)
237 |         element_gradient_variance = np.average(element_gradient_variance, axis=0)
238 | 
239 |         candidates = sorted(
240 |             zip(self._elements, element_gradient_variance, data.features.T[2 * self._num_variables + 1 :]),
241 |             reverse=True,
242 |             key=lambda x: x[1],
243 |         )
244 |         selected = []
245 |         for candidate in candidates[: min(5000, len(candidates))]:
246 |             element_to_split = candidate[0]
247 |             if len(element_to_split.splittable_variables) > 0 and np.sum(candidate[2]) > 25:
248 |                 original_feature = candidate[2]
249 |                 original_variance = candidate[1]
250 |                 variable_to_split = None
251 |                 min_after_split_variance = float("inf")
252 |                 for variable in element_to_split.splittable_variables:
253 |                     left_feature = original_feature * data.images[:, variable - 1]
254 |                     right_feature = original_feature - left_feature
255 | 
256 |                     if np.sum(left_feature) > 10 and np.sum(right_feature) > 10:
257 | 
258 |                         left_gradient = (data.one_hot_labels - y) * left_feature.reshape((-1, 1))
259 |                         right_gradient = (data.one_hot_labels - y) * right_feature.reshape((-1, 1))
260 | 
261 |                         w = np.sum(data.images[:, variable - 1]) / data.num_samples
262 | 
263 |                         after_split_variance = w * np.average(np.var(left_gradient, axis=0)) + (1 - w) * np.average(
264 |                             np.var(right_gradient, axis=0)
265 |                         )
266 |                         if after_split_variance < min_after_split_variance:
267 |                             min_after_split_variance = after_split_variance
268 |                             variable_to_split = variable
269 |                 if min_after_split_variance < original_variance:
270 |                     improved_amount = min_after_split_variance - original_variance
271 |                     if len(selected) == num_splits:
272 |                         if improved_amount < selected[-1][1]:  # better than the current worst kept split
273 |                             selected = selected[:-1]
274 |                             selected.append(((element_to_split, variable_to_split), improved_amount))
275 |                             selected.sort(key=lambda x: x[1])
276 |                     else:
277 |                         selected.append(((element_to_split, variable_to_split), improved_amount))
278 |                         selected.sort(key=lambda x: x[1])
279 | 
280 |         gc.collect()
281 |         return [x[0] for x in selected]
282 | 
283 |     def _split(self, element_to_split, variable_to_split, depth):
284 |         parent = element_to_split.parent
285 |         original_element, copied_element = self._copy_and_modify_element_for_split(
286 |             element_to_split, variable_to_split, 0, depth
287 |         )
288 |         if original_element is None or copied_element is None:
289 |             raise ValueError("Split elements become invalid.")
290 |         parent.add_element(copied_element)
291 | 
292 |     def _copy_and_modify_element_for_split(self, original_element, variable, current_depth, max_depth):
293 |         original_element.flag = True
294 |         original_element.remove_splittable_variable(variable)
295 |         original_prime = original_element.prime
296 |         original_sub = original_element.sub
297 |         if current_depth >= max_depth:
298 |             if variable in original_prime.vtree.variables:
299 |                 original_prime, copied_prime = self._copy_and_modify_node_for_split(
300 |                     original_prime, variable, current_depth, max_depth
301 |                 )
302 |                 copied_sub = original_sub
303 |             elif variable in original_sub.vtree.variables:
304 |                 original_sub, copied_sub = self._copy_and_modify_node_for_split(
305 |                     original_sub, variable, current_depth, max_depth
306 |                 )
307 |                 copied_prime = original_prime
308 |             else:
309 |                 copied_prime = original_prime
310 |                 copied_sub = original_sub
311 |         else:
312 |             original_prime, copied_prime = self._copy_and_modify_node_for_split(
313 |                 original_prime, variable, current_depth, max_depth
314 |             )
315 |             original_sub, copied_sub = self._copy_and_modify_node_for_split(original_sub, variable, current_depth, max_depth)
316 |         if copied_prime is not None and copied_sub is not None:
317 |             copied_element = AndGate(copied_prime, copied_sub, copy.deepcopy(original_element.parameter))
318 |             copied_element.splittable_variables = copy.deepcopy(original_element.splittable_variables)
319 |         else:
320 |             copied_element = None
321 |         if original_prime is not None and original_sub is not None:
322 |             original_element.prime = original_prime
323 |             original_element.sub = original_sub
324 |         else:
325 |             original_element = None
326 |         return original_element, copied_element
327 | 
328 |     def _copy_and_modify_node_for_split(self, original_node, variable, current_depth, max_depth):
329 |         if original_node.num_parents == 0:
330 |             raise ValueError("Some node does not have a parent.")
331 |         original_node.decrease_num_parents_by_one()
332 |         if isinstance(original_node, CircuitTerminal):
333 |             if original_node.var_index == variable:
334 |                 if original_node.var_value == LITERAL_IS_TRUE:
335 |                     copied_node = None
336 |                 elif original_node.var_value == LITERAL_IS_FALSE:
337 |                     original_node = None
338 |                     copied_node = self._terminal_nodes[self._num_variables + variable - 1]
339 |                 else:
340 |                     raise ValueError(
341 |                         "Under the current setting, "
342 |                         "we only support terminal nodes that are either positive or negative literals."
343 |                     )
344 |             else:
345 |                 copied_node = original_node
346 |             return original_node, copied_node
347 |         else:
348 |             if original_node.num_parents > 0:
349 |                 original_node = self._deep_copy_node(original_node, variable, current_depth, max_depth)
350 |             copied_elements = []
351 |             i = 0
352 |             while i < len(original_node.elements):
353 |                 original_element, copied_element = self._copy_and_modify_element_for_split(
354 |                     original_node.elements[i], variable, current_depth + 1, max_depth
355 |                 )
356 |                 if original_element is None:
357 |                     original_node.remove_element(i)
358 |                 else:
359 |                     i += 1
360 |                 if copied_element is not None:
361 |                     copied_elements.append(copied_element)
362 |             if len(copied_elements) == 0:
363 |                 copied_node = None
364 |             else:
365 |                 self._largest_index += 1
366 |                 copied_node = OrGate(self._largest_index, original_node.vtree, copied_elements)
367 |             if len(original_node.elements) == 0:
368 |                 original_node = None
369 |             return original_node, copied_node
370 | 
371 |     def _deep_copy_node(self, node, variable, current_depth, max_depth):
372 |         if isinstance(node, CircuitTerminal):
373 |             return node
374 |         else:
375 |             if len(node.elements) == 0:
376 |                 raise ValueError("Decision nodes should have at least one element.")
377 |             copied_elements = []
378 |             for element in node.elements:
379 |                 copied_elements.append(self._deep_copy_element(element, variable, current_depth + 1, max_depth))
380 |             self._largest_index += 1
381 |             return OrGate(self._largest_index, node.vtree, copied_elements)
382 | 
383 |     def _deep_copy_element(self, element, variable, current_depth, max_depth):
384 |         if current_depth >= max_depth:
385 |             if variable in element.prime.vtree.variables:
386 |                 copied_element = AndGate(
387 |                     self._deep_copy_node(element.prime, variable, current_depth, max_depth),
388 |                     element.sub,
389 |                     copy.deepcopy(element.parameter),
390 |                 )
391 |             elif variable in element.sub.vtree.variables:
392 |                 copied_element = AndGate(
393 |                     element.prime,
394 |                     self._deep_copy_node(element.sub, variable, current_depth, max_depth),
395 |                     copy.deepcopy(element.parameter),
396 |                 )
397 |             else:
398 |                 copied_element = AndGate(element.prime, element.sub, copy.deepcopy(element.parameter))
399 |         else:
400 |             copied_element = AndGate(
401 |                 self._deep_copy_node(element.prime, variable, current_depth, max_depth),
402 |                 self._deep_copy_node(element.sub, variable, current_depth, max_depth),
403 |                 copy.deepcopy(element.parameter),
404 |             )
405 |         copied_element.splittable_variables = copy.deepcopy(element.splittable_variables)
406 |         return copied_element
407 | 
408 |     def calculate_accuracy(self, data):
409 |         """Calculate accuracy given the learned parameters on the provided data."""
410 |         y = self.predict(data.features)
411 |         accuracy = np.sum(y == data.labels) / data.num_samples
412 |         return accuracy
413 | 
414 |     def predict(self, features):
415 |         y = self.predict_prob(features)
416 |         if self._num_classes > 1:
417 |             return np.argmax(y, axis=1)
418 |         else:
419 |             return (y > 0.5).astype(int).ravel()
420 | 
421 |     def predict_prob(self, features):
422 |         """Predict the class probabilities of the given images from their corresponding features."""
423 |         y = 1.0 / (1.0 + np.exp(-np.dot(features, self._parameters.T)))
424 |         return y
425 | 
426 |     def learn_parameters(self, data, num_iterations, num_cores=-1):
427 |         """Logistic Psdd's parameter learning is reduced to logistic regression.
428 |         We use the SAGA solver to optimize the parameters."""
429 |         model = LogisticRegression(
430 |             solver="saga",
431 |             fit_intercept=False,
432 |             multi_class="ovr",
433 |             max_iter=num_iterations,
434 |             C=0.1,
435 |             warm_start=True,
436 |             tol=1e-5,
437 |             coef_=self._parameters,
438 |             n_jobs=num_cores,
439 |         )
440 |         model.fit(data.features, data.labels)
441 |         self._record_learned_parameters(model.coef_)
442 |         gc.collect()
443 | 
444 |     def change_structure(self, data, depth, num_splits):
445 |         splits = self._select_element_and_variable_to_split(data, num_splits)
446 |         for element_to_split, variable_to_split in splits:
447 |             if not element_to_split.flag:
448 |                 self._split(element_to_split, variable_to_split, depth)
449 |         self._serialize()
450 | 
451 |     def save(self, f):
452 |         self._serialize()
453 |         f.write(FORMAT)
454 |         f.write("Logistic Circuit\n")
455 |         for terminal_node in self._terminal_nodes:
456 |             terminal_node.save(f)
457 |         for decision_node in reversed(self._decision_nodes):
458 |             decision_node.save(f)
459 |         f.write("B")
460 |         for parameter in self._bias:
461 |             f.write(f" {parameter}")
462 |         f.write("\n")
463 | 
464 |     def load(self, f):
465 |         # read the format at the beginning
466 |         line = f.readline()
467 |         while line[0] == "c":
468 |             line = f.readline()
469 | 
470 |         # serialize the vtree
471 |         vtree_nodes = dict()
472 |         unvisited_vtree_nodes = deque()
473 |         unvisited_vtree_nodes.append(self._vtree)
474 |         while len(unvisited_vtree_nodes):
475 |             node = unvisited_vtree_nodes.popleft()
476 |             vtree_nodes[node.index] = node
477 |             if not node.is_leaf():
478 |                 unvisited_vtree_nodes.append(node.left)
479 |                 unvisited_vtree_nodes.append(node.right)
480 | 
481 |         # extract the saved logistic circuit
482 |         nodes = dict()
483 |         line = f.readline()
484 |         while line[0] == "T" or line[0] == "F":
485 |             line_as_list = line.strip().split(" ")
486 |             positive_literal, var = (line_as_list[0] == "T"), int(line_as_list[3])
487 |             index, vtree_index = int(line_as_list[1]), int(line_as_list[2])
488 |             parameters = []
489 |             for i in range(self._num_classes):
490 |                 parameters.append(float(line_as_list[4 + i]))
491 |             parameters = np.array(parameters, dtype=np.float32)
492 |             if positive_literal:
493 |                 nodes[index] = (CircuitTerminal(index, vtree_nodes[vtree_index], var, LITERAL_IS_TRUE, parameters), {var})
494 |             else:
495 |                 nodes[index] = (CircuitTerminal(index, vtree_nodes[vtree_index], var, LITERAL_IS_FALSE, parameters), {-var})
496 |             self._largest_index = max(self._largest_index, index)
497 |             line = f.readline()
498 | 
499 |         self._terminal_nodes = [x[0] for x in nodes.values()]
500 |         self._terminal_nodes.sort(key=lambda x: (-x.var_value, x.var_index))
501 |         if len(self._terminal_nodes) != 2 * self._num_variables:
502 |             raise ValueError(
503 |                 "Number of terminal nodes recorded in
the circuit file " 503 | "does not match 2 * number of variables in the provided vtree." 504 | ) 505 | 506 | root = None 507 | while line[0] == "D": 508 | line_as_list = line.strip().split(" ") 509 | index, vtree_index, num_elements = int(line_as_list[1]), int(line_as_list[2]), int(line_as_list[3]) 510 | elements = [] 511 | variables = set() 512 | for i in range(num_elements): 513 | prime_index = int(line_as_list[i * (self._num_classes + 2) + 4].strip("(")) 514 | sub_index = int(line_as_list[i * (self._num_classes + 2) + 5]) 515 | element_variables = nodes[prime_index][1].union(nodes[sub_index][1]) 516 | variables = variables.union(element_variables) 517 | splittable_variables = set() 518 | for variable in element_variables: 519 | if -variable in element_variables: 520 | splittable_variables.add(abs(variable)) 521 | parameters = [] 522 | for j in range(self._num_classes): 523 | parameters.append(float(line_as_list[i * (self._num_classes + 2) + 6 + j].strip(")"))) 524 | parameters = np.array(parameters, dtype=np.float32) 525 | elements.append(AndGate(nodes[prime_index][0], nodes[sub_index][0], parameters)) 526 | elements[-1].splittable_variables = splittable_variables 527 | nodes[index] = (OrGate(index, vtree_nodes[vtree_index], elements), variables) 528 | root = nodes[index][0] 529 | self._largest_index = max(self._largest_index, index) 530 | line = f.readline() 531 | 532 | if line[0] != "B": 533 | raise ValueError("The last line in a circuit file must record the bias parameters.") 534 | self._bias = np.array([float(x) for x in line.strip().split(" ")[1:]], dtype=np.float32) 535 | 536 | gc.collect() 537 | return root 538 | -------------------------------------------------------------------------------- /algo/LogisticRegression.py: -------------------------------------------------------------------------------- 1 | """ 2 | Logistic Regression 3 | """ 4 | 5 | # Author: Gael Varoquaux 6 | # Fabian Pedregosa 7 | # Alexandre Gramfort 8 | # Manoj Kumar 9 | # Lars Buitinck 10 | # Simon Wu 11 | # Arthur Mensch n_features: 163 | grad[-1] = z0.sum() 164 | return out, grad 165 | 166 | 167 | def _logistic_loss(w, X, y, alpha, sample_weight=None): 168 | """Computes the logistic loss. 169 | 170 | Parameters 171 | ---------- 172 | w : ndarray, shape (n_features,) or (n_features + 1,) 173 | Coefficient vector. 174 | 175 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 176 | Training data. 177 | 178 | y : ndarray, shape (n_samples,) 179 | Array of labels. 180 | 181 | alpha : float 182 | Regularization parameter. alpha is equal to 1 / C. 183 | 184 | sample_weight : array-like, shape (n_samples,) optional 185 | Array of weights that are assigned to individual samples. 186 | If not provided, then each sample is given unit weight. 187 | 188 | Returns 189 | ------- 190 | out : float 191 | Logistic loss. 192 | """ 193 | w, c, yz = _intercept_dot(w, X, y) 194 | 195 | if sample_weight is None: 196 | sample_weight = np.ones(y.shape[0]) 197 | 198 | # Logistic loss is the negative of the log of the logistic function. 199 | out = -np.sum(sample_weight * log_logistic(yz)) + .5 * alpha * np.dot(w, w) 200 | return out 201 | 202 | 203 | def _logistic_grad_hess(w, X, y, alpha, sample_weight=None): 204 | """Computes the gradient and the Hessian, in the case of a logistic loss. 205 | 206 | Parameters 207 | ---------- 208 | w : ndarray, shape (n_features,) or (n_features + 1,) 209 | Coefficient vector. 210 | 211 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 212 | Training data. 
213 | 214 | y : ndarray, shape (n_samples,) 215 | Array of labels. 216 | 217 | alpha : float 218 | Regularization parameter. alpha is equal to 1 / C. 219 | 220 | sample_weight : array-like, shape (n_samples,) optional 221 | Array of weights that are assigned to individual samples. 222 | If not provided, then each sample is given unit weight. 223 | 224 | Returns 225 | ------- 226 | grad : ndarray, shape (n_features,) or (n_features + 1,) 227 | Logistic gradient. 228 | 229 | Hs : callable 230 | Function that takes the gradient as a parameter and returns the 231 | matrix product of the Hessian and gradient. 232 | """ 233 | n_samples, n_features = X.shape 234 | grad = np.empty_like(w) 235 | fit_intercept = grad.shape[0] > n_features 236 | 237 | w, c, yz = _intercept_dot(w, X, y) 238 | 239 | if sample_weight is None: 240 | sample_weight = np.ones(y.shape[0]) 241 | 242 | z = expit(yz) 243 | z0 = sample_weight * (z - 1) * y 244 | 245 | grad[:n_features] = safe_sparse_dot(X.T, z0) + alpha * w 246 | 247 | # Case where we fit the intercept. 248 | if fit_intercept: 249 | grad[-1] = z0.sum() 250 | 251 | # The mat-vec product of the Hessian 252 | d = sample_weight * z * (1 - z) 253 | if sparse.issparse(X): 254 | dX = safe_sparse_dot(sparse.dia_matrix((d, 0), 255 | shape=(n_samples, n_samples)), X) 256 | else: 257 | # Precompute as much as possible 258 | dX = d[:, np.newaxis] * X 259 | 260 | if fit_intercept: 261 | # Calculate the double derivative with respect to intercept 262 | # In the case of sparse matrices this returns a matrix object. 263 | dd_intercept = np.squeeze(np.array(dX.sum(axis=0))) 264 | 265 | def Hs(s): 266 | ret = np.empty_like(s) 267 | ret[:n_features] = X.T.dot(dX.dot(s[:n_features])) 268 | ret[:n_features] += alpha * s[:n_features] 269 | 270 | # For the fit intercept case. 271 | if fit_intercept: 272 | ret[:n_features] += s[-1] * dd_intercept 273 | ret[-1] = dd_intercept.dot(s[:n_features]) 274 | ret[-1] += d.sum() * s[-1] 275 | return ret 276 | 277 | return grad, Hs 278 | 279 | 280 | def _multinomial_loss(w, X, Y, alpha, sample_weight): 281 | """Computes multinomial loss and class probabilities. 282 | 283 | Parameters 284 | ---------- 285 | w : ndarray, shape (n_classes * n_features,) or 286 | (n_classes * (n_features + 1),) 287 | Coefficient vector. 288 | 289 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 290 | Training data. 291 | 292 | Y : ndarray, shape (n_samples, n_classes) 293 | Transformed labels according to the output of LabelBinarizer. 294 | 295 | alpha : float 296 | Regularization parameter. alpha is equal to 1 / C. 297 | 298 | sample_weight : array-like, shape (n_samples,) optional 299 | Array of weights that are assigned to individual samples. 300 | If not provided, then each sample is given unit weight. 301 | 302 | Returns 303 | ------- 304 | loss : float 305 | Multinomial loss. 306 | 307 | p : ndarray, shape (n_samples, n_classes) 308 | Estimated class probabilities. 309 | 310 | w : ndarray, shape (n_classes, n_features) 311 | Reshaped param vector excluding intercept terms. 312 | 313 | Reference 314 | --------- 315 | Bishop, C. M. (2006). Pattern recognition and machine learning. 316 | Springer. 
(Chapter 4.3.4) 317 | """ 318 | n_classes = Y.shape[1] 319 | n_features = X.shape[1] 320 | fit_intercept = w.size == (n_classes * (n_features + 1)) 321 | w = w.reshape(n_classes, -1) 322 | sample_weight = sample_weight[:, np.newaxis] 323 | if fit_intercept: 324 | intercept = w[:, -1] 325 | w = w[:, :-1] 326 | else: 327 | intercept = 0 328 | p = safe_sparse_dot(X, w.T) 329 | p += intercept 330 | p -= logsumexp(p, axis=1)[:, np.newaxis] 331 | loss = -(sample_weight * Y * p).sum() 332 | loss += 0.5 * alpha * squared_norm(w) 333 | p = np.exp(p, p) 334 | return loss, p, w 335 | 336 | 337 | def _multinomial_loss_grad(w, X, Y, alpha, sample_weight): 338 | """Computes the multinomial loss, gradient and class probabilities. 339 | 340 | Parameters 341 | ---------- 342 | w : ndarray, shape (n_classes * n_features,) or 343 | (n_classes * (n_features + 1),) 344 | Coefficient vector. 345 | 346 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 347 | Training data. 348 | 349 | Y : ndarray, shape (n_samples, n_classes) 350 | Transformed labels according to the output of LabelBinarizer. 351 | 352 | alpha : float 353 | Regularization parameter. alpha is equal to 1 / C. 354 | 355 | sample_weight : array-like, shape (n_samples,) optional 356 | Array of weights that are assigned to individual samples. 357 | 358 | Returns 359 | ------- 360 | loss : float 361 | Multinomial loss. 362 | 363 | grad : ndarray, shape (n_classes * n_features,) or 364 | (n_classes * (n_features + 1),) 365 | Ravelled gradient of the multinomial loss. 366 | 367 | p : ndarray, shape (n_samples, n_classes) 368 | Estimated class probabilities 369 | 370 | Reference 371 | --------- 372 | Bishop, C. M. (2006). Pattern recognition and machine learning. 373 | Springer. (Chapter 4.3.4) 374 | """ 375 | n_classes = Y.shape[1] 376 | n_features = X.shape[1] 377 | fit_intercept = (w.size == n_classes * (n_features + 1)) 378 | grad = np.zeros((n_classes, n_features + bool(fit_intercept)), 379 | dtype=X.dtype) 380 | loss, p, w = _multinomial_loss(w, X, Y, alpha, sample_weight) 381 | sample_weight = sample_weight[:, np.newaxis] 382 | diff = sample_weight * (p - Y) 383 | grad[:, :n_features] = safe_sparse_dot(diff.T, X) 384 | grad[:, :n_features] += alpha * w 385 | if fit_intercept: 386 | grad[:, -1] = diff.sum(axis=0) 387 | return loss, grad.ravel(), p 388 | 389 | 390 | def _multinomial_grad_hess(w, X, Y, alpha, sample_weight): 391 | """ 392 | Computes the gradient and the Hessian, in the case of a multinomial loss. 393 | 394 | Parameters 395 | ---------- 396 | w : ndarray, shape (n_classes * n_features,) or 397 | (n_classes * (n_features + 1),) 398 | Coefficient vector. 399 | 400 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 401 | Training data. 402 | 403 | Y : ndarray, shape (n_samples, n_classes) 404 | Transformed labels according to the output of LabelBinarizer. 405 | 406 | alpha : float 407 | Regularization parameter. alpha is equal to 1 / C. 408 | 409 | sample_weight : array-like, shape (n_samples,) optional 410 | Array of weights that are assigned to individual samples. 411 | 412 | Returns 413 | ------- 414 | grad : array, shape (n_classes * n_features,) or 415 | (n_classes * (n_features + 1),) 416 | Ravelled gradient of the multinomial loss. 417 | 418 | hessp : callable 419 | Function that takes in a vector input of shape (n_classes * n_features) 420 | or (n_classes * (n_features + 1)) and returns matrix-vector product 421 | with hessian. 422 | 423 | References 424 | ---------- 425 | Barak A. Pearlmutter (1993). 
Fast Exact Multiplication by the Hessian. 426 | http://www.bcl.hamilton.ie/~barak/papers/nc-hessian.pdf 427 | """ 428 | n_features = X.shape[1] 429 | n_classes = Y.shape[1] 430 | fit_intercept = w.size == (n_classes * (n_features + 1)) 431 | 432 | # `loss` is unused. Refactoring to avoid computing it does not 433 | # significantly speed up the computation and decreases readability 434 | loss, grad, p = _multinomial_loss_grad(w, X, Y, alpha, sample_weight) 435 | sample_weight = sample_weight[:, np.newaxis] 436 | 437 | # Hessian-vector product derived by applying the R-operator on the gradient 438 | # of the multinomial loss function. 439 | def hessp(v): 440 | v = v.reshape(n_classes, -1) 441 | if fit_intercept: 442 | inter_terms = v[:, -1] 443 | v = v[:, :-1] 444 | else: 445 | inter_terms = 0 446 | # r_yhat holds the result of applying the R-operator on the multinomial 447 | # estimator. 448 | r_yhat = safe_sparse_dot(X, v.T) 449 | r_yhat += inter_terms 450 | r_yhat += (-p * r_yhat).sum(axis=1)[:, np.newaxis] 451 | r_yhat *= p 452 | r_yhat *= sample_weight 453 | hessProd = np.zeros((n_classes, n_features + bool(fit_intercept))) 454 | hessProd[:, :n_features] = safe_sparse_dot(r_yhat.T, X) 455 | hessProd[:, :n_features] += v * alpha 456 | if fit_intercept: 457 | hessProd[:, -1] = r_yhat.sum(axis=0) 458 | return hessProd.ravel() 459 | 460 | return grad, hessp 461 | 462 | 463 | def _check_solver(solver, penalty, dual): 464 | if solver == 'warn': 465 | solver = 'liblinear' 466 | warnings.warn("Default solver will be changed to 'lbfgs' in 0.22. " 467 | "Specify a solver to silence this warning.", 468 | FutureWarning) 469 | 470 | all_solvers = ['liblinear', 'newton-cg', 'lbfgs', 'sag', 'saga'] 471 | if solver not in all_solvers: 472 | raise ValueError("Logistic Regression supports only solvers in %s, got" 473 | " %s." % (all_solvers, solver)) 474 | 475 | all_penalties = ['l1', 'l2'] 476 | if penalty not in all_penalties: 477 | raise ValueError("Logistic Regression supports only penalties in %s," 478 | " got %s." % (all_penalties, penalty)) 479 | 480 | if solver not in ['liblinear', 'saga'] and penalty != 'l2': 481 | raise ValueError("Solver %s supports only l2 penalties, " 482 | "got %s penalty." % (solver, penalty)) 483 | if solver != 'liblinear' and dual: 484 | raise ValueError("Solver %s supports only " 485 | "dual=False, got dual=%s" % (solver, dual)) 486 | return solver 487 | 488 | 489 | def _check_multi_class(multi_class, solver, n_classes): 490 | if multi_class == 'warn': 491 | multi_class = 'ovr' 492 | if n_classes > 2: 493 | warnings.warn("Default multi_class will be changed to 'auto' in" 494 | " 0.22. Specify the multi_class option to silence " 495 | "this warning.", FutureWarning) 496 | if multi_class == 'auto': 497 | if solver == 'liblinear': 498 | multi_class = 'ovr' 499 | elif n_classes > 2: 500 | multi_class = 'multinomial' 501 | else: 502 | multi_class = 'ovr' 503 | if multi_class not in ('multinomial', 'ovr'): 504 | raise ValueError("multi_class should be 'multinomial', 'ovr' or " 505 | "'auto'. Got %s." % multi_class) 506 | if multi_class == 'multinomial' and solver == 'liblinear': 507 | raise ValueError("Solver %s does not support " 508 | "a multinomial backend." 
% solver) 509 | return multi_class 510 | 511 | 512 | def logistic_regression_path(X, y, pos_class=None, Cs=10, fit_intercept=True, 513 | max_iter=100, tol=1e-4, verbose=0, 514 | solver='lbfgs', coef=None, 515 | class_weight=None, dual=False, penalty='l2', 516 | intercept_scaling=1., multi_class='warn', 517 | random_state=None, check_input=True, 518 | max_squared_sum=None, sample_weight=None): 519 | """Compute a Logistic Regression model for a list of regularization 520 | parameters. 521 | 522 | This is an implementation that uses the result of the previous model 523 | to speed up computations along the set of solutions, making it faster 524 | than sequentially calling LogisticRegression for the different parameters. 525 | Note that there will be no speedup with liblinear solver, since it does 526 | not handle warm-starting. 527 | 528 | Read more in the :ref:`User Guide `. 529 | 530 | Parameters 531 | ---------- 532 | X : array-like or sparse matrix, shape (n_samples, n_features) 533 | Input data. 534 | 535 | y : array-like, shape (n_samples,) or (n_samples, n_targets) 536 | Input data, target values. 537 | 538 | pos_class : int, None 539 | The class with respect to which we perform a one-vs-all fit. 540 | If None, then it is assumed that the given problem is binary. 541 | 542 | Cs : int | array-like, shape (n_cs,) 543 | List of values for the regularization parameter or integer specifying 544 | the number of regularization parameters that should be used. In this 545 | case, the parameters will be chosen in a logarithmic scale between 546 | 1e-4 and 1e4. 547 | 548 | fit_intercept : bool 549 | Whether to fit an intercept for the model. In this case the shape of 550 | the returned array is (n_cs, n_features + 1). 551 | 552 | max_iter : int 553 | Maximum number of iterations for the solver. 554 | 555 | tol : float 556 | Stopping criterion. For the newton-cg and lbfgs solvers, the iteration 557 | will stop when ``max{|g_i | i = 1, ..., n} <= tol`` 558 | where ``g_i`` is the i-th component of the gradient. 559 | 560 | verbose : int 561 | For the liblinear and lbfgs solvers set verbose to any positive 562 | number for verbosity. 563 | 564 | solver : {'lbfgs', 'newton-cg', 'liblinear', 'sag', 'saga'} 565 | Numerical solver to use. 566 | 567 | coef : array-like, shape (n_features,), default None 568 | Initialization value for coefficients of logistic regression. 569 | Useless for liblinear solver. 570 | 571 | class_weight : dict or 'balanced', optional 572 | Weights associated with classes in the form ``{class_label: weight}``. 573 | If not given, all classes are supposed to have weight one. 574 | 575 | The "balanced" mode uses the values of y to automatically adjust 576 | weights inversely proportional to class frequencies in the input data 577 | as ``n_samples / (n_classes * np.bincount(y))``. 578 | 579 | Note that these weights will be multiplied with sample_weight (passed 580 | through the fit method) if sample_weight is specified. 581 | 582 | dual : bool 583 | Dual or primal formulation. Dual formulation is only implemented for 584 | l2 penalty with liblinear solver. Prefer dual=False when 585 | n_samples > n_features. 586 | 587 | penalty : str, 'l1' or 'l2' 588 | Used to specify the norm used in the penalization. The 'newton-cg', 589 | 'sag' and 'lbfgs' solvers support only l2 penalties. 590 | 591 | intercept_scaling : float, default 1. 592 | Useful only when the solver 'liblinear' is used 593 | and self.fit_intercept is set to True. 
In this case, x becomes 594 | [x, self.intercept_scaling], 595 | i.e. a "synthetic" feature with constant value equal to 596 | intercept_scaling is appended to the instance vector. 597 | The intercept becomes ``intercept_scaling * synthetic_feature_weight``. 598 | 599 | Note! the synthetic feature weight is subject to l1/l2 regularization 600 | as all other features. 601 | To lessen the effect of regularization on synthetic feature weight 602 | (and therefore on the intercept) intercept_scaling has to be increased. 603 | 604 | multi_class : str, {'ovr', 'multinomial', 'auto'}, default: 'ovr' 605 | If the option chosen is 'ovr', then a binary problem is fit for each 606 | label. For 'multinomial' the loss minimised is the multinomial loss fit 607 | across the entire probability distribution, *even when the data is 608 | binary*. 'multinomial' is unavailable when solver='liblinear'. 609 | 'auto' selects 'ovr' if the data is binary, or if solver='liblinear', 610 | and otherwise selects 'multinomial'. 611 | 612 | .. versionadded:: 0.18 613 | Stochastic Average Gradient descent solver for 'multinomial' case. 614 | .. versionchanged:: 0.20 615 | Default will change from 'ovr' to 'auto' in 0.22. 616 | 617 | random_state : int, RandomState instance or None, optional, default None 618 | The seed of the pseudo random number generator to use when shuffling 619 | the data. If int, random_state is the seed used by the random number 620 | generator; If RandomState instance, random_state is the random number 621 | generator; If None, the random number generator is the RandomState 622 | instance used by `np.random`. Used when ``solver`` == 'sag' or 623 | 'liblinear'. 624 | 625 | check_input : bool, default True 626 | If False, the input arrays X and y will not be checked. 627 | 628 | max_squared_sum : float, default None 629 | Maximum squared sum of X over samples. Used only in SAG solver. 630 | If None, it will be computed, going through all the samples. 631 | The value should be precomputed to speed up cross validation. 632 | 633 | sample_weight : array-like, shape(n_samples,) optional 634 | Array of weights that are assigned to individual samples. 635 | If not provided, then each sample is given unit weight. 636 | 637 | Returns 638 | ------- 639 | coefs : ndarray, shape (n_cs, n_features) or (n_cs, n_features + 1) 640 | List of coefficients for the Logistic Regression model. If 641 | fit_intercept is set to True then the second dimension will be 642 | n_features + 1, where the last item represents the intercept. For 643 | ``multiclass='multinomial'``, the shape is (n_classes, n_cs, 644 | n_features) or (n_classes, n_cs, n_features + 1). 645 | 646 | Cs : ndarray 647 | Grid of Cs used for cross-validation. 648 | 649 | n_iter : array, shape (n_cs,) 650 | Actual number of iteration for each Cs. 651 | 652 | Notes 653 | ----- 654 | You might get slightly different results with the solver liblinear than 655 | with the others since this uses LIBLINEAR which penalizes the intercept. 656 | 657 | .. versionchanged:: 0.19 658 | The "copy" parameter was removed. 659 | """ 660 | if isinstance(Cs, numbers.Integral): 661 | Cs = np.logspace(-4, 4, Cs) 662 | 663 | solver = _check_solver(solver, penalty, dual) 664 | 665 | # Preprocessing. 
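# The block below validates X and y, resolves the effective multi_class
# mode, and encodes the target: a +/-1 vector for OvR, or a binarized
# label matrix (LabelBinarizer, or LabelEncoder for SAG) for multinomial.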
666 | if check_input: 667 | X = check_array(X, accept_sparse='csr', dtype=np.float64, 668 | accept_large_sparse=solver != 'liblinear') 669 | y = check_array(y, ensure_2d=False, dtype=None) 670 | check_consistent_length(X, y) 671 | _, n_features = X.shape 672 | 673 | classes = np.unique(y) 674 | random_state = check_random_state(random_state) 675 | 676 | multi_class = _check_multi_class(multi_class, solver, len(classes)) 677 | if pos_class is None and multi_class != 'multinomial': 678 | if (classes.size > 2): 679 | raise ValueError('To fit OvR, use the pos_class argument') 680 | # np.unique(y) gives labels in sorted order. 681 | pos_class = classes[1] 682 | 683 | # If sample weights exist, convert them to array (support for lists) 684 | # and check length 685 | # Otherwise set them to 1 for all examples 686 | if sample_weight is not None: 687 | sample_weight = np.array(sample_weight, dtype=X.dtype, order='C') 688 | check_consistent_length(y, sample_weight) 689 | else: 690 | sample_weight = np.ones(X.shape[0], dtype=X.dtype) 691 | 692 | # If class_weights is a dict (provided by the user), the weights 693 | # are assigned to the original labels. If it is "balanced", then 694 | # the class_weights are assigned after masking the labels with a OvR. 695 | le = LabelEncoder() 696 | if isinstance(class_weight, dict) or multi_class == 'multinomial': 697 | class_weight_ = compute_class_weight(class_weight, classes, y) 698 | sample_weight *= class_weight_[le.fit_transform(y)] 699 | 700 | # For doing a ovr, we need to mask the labels first. for the 701 | # multinomial case this is not necessary. 702 | if multi_class == 'ovr': 703 | w0 = np.zeros(n_features + int(fit_intercept), dtype=X.dtype) 704 | mask_classes = np.array([-1, 1]) 705 | mask = (y == pos_class) 706 | y_bin = np.ones(y.shape, dtype=X.dtype) 707 | y_bin[~mask] = -1. 708 | # for compute_class_weight 709 | 710 | if class_weight == "balanced": 711 | class_weight_ = compute_class_weight(class_weight, mask_classes, 712 | y_bin) 713 | sample_weight *= class_weight_[le.fit_transform(y_bin)] 714 | 715 | else: 716 | if solver not in ['sag', 'saga']: 717 | lbin = LabelBinarizer() 718 | Y_multi = lbin.fit_transform(y) 719 | if Y_multi.shape[1] == 1: 720 | Y_multi = np.hstack([1 - Y_multi, Y_multi]) 721 | else: 722 | # SAG multinomial solver needs LabelEncoder, not LabelBinarizer 723 | le = LabelEncoder() 724 | Y_multi = le.fit_transform(y).astype(X.dtype, copy=False) 725 | 726 | w0 = np.zeros((classes.size, n_features + int(fit_intercept)), 727 | order='F', dtype=X.dtype) 728 | 729 | if coef is not None: 730 | # it must work both giving the bias term and not 731 | if multi_class == 'ovr': 732 | if coef.size not in (n_features, w0.size): 733 | raise ValueError( 734 | 'Initialization coef is of shape %d, expected shape ' 735 | '%d or %d' % (coef.size, n_features, w0.size)) 736 | w0[:coef.size] = coef 737 | else: 738 | # For binary problems coef.shape[0] should be 1, otherwise it 739 | # should be classes.size. 
740 | n_classes = classes.size 741 | if n_classes == 2: 742 | n_classes = 1 743 | 744 | if (coef.shape[0] != n_classes or 745 | coef.shape[1] not in (n_features, n_features + 1)): 746 | raise ValueError( 747 | 'Initialization coef is of shape (%d, %d), expected ' 748 | 'shape (%d, %d) or (%d, %d)' % ( 749 | coef.shape[0], coef.shape[1], classes.size, 750 | n_features, classes.size, n_features + 1)) 751 | 752 | if n_classes == 1: 753 | w0[0, :coef.shape[1]] = -coef 754 | w0[1, :coef.shape[1]] = coef 755 | else: 756 | w0[:, :coef.shape[1]] = coef 757 | 758 | if multi_class == 'multinomial': 759 | # fmin_l_bfgs_b and newton-cg accepts only ravelled parameters. 760 | if solver in ['lbfgs', 'newton-cg']: 761 | w0 = w0.ravel() 762 | target = Y_multi 763 | if solver == 'lbfgs': 764 | func = lambda x, *args: _multinomial_loss_grad(x, *args)[0:2] 765 | elif solver == 'newton-cg': 766 | func = lambda x, *args: _multinomial_loss(x, *args)[0] 767 | grad = lambda x, *args: _multinomial_loss_grad(x, *args)[1] 768 | hess = _multinomial_grad_hess 769 | warm_start_sag = {'coef': w0.T} 770 | else: 771 | target = y_bin 772 | if solver == 'lbfgs': 773 | func = _logistic_loss_and_grad 774 | elif solver == 'newton-cg': 775 | func = _logistic_loss 776 | grad = lambda x, *args: _logistic_loss_and_grad(x, *args)[1] 777 | hess = _logistic_grad_hess 778 | warm_start_sag = {'coef': np.expand_dims(w0, axis=1)} 779 | 780 | coefs = list() 781 | n_iter = np.zeros(len(Cs), dtype=np.int32) 782 | for i, C in enumerate(Cs): 783 | if solver == 'lbfgs': 784 | iprint = [-1, 50, 1, 100, 101][ 785 | np.searchsorted(np.array([0, 1, 2, 3]), verbose)] 786 | w0, loss, info = optimize.fmin_l_bfgs_b( 787 | func, w0, fprime=None, 788 | args=(X, target, 1. / C, sample_weight), 789 | iprint=iprint, pgtol=tol, maxiter=max_iter) 790 | if info["warnflag"] == 1: 791 | warnings.warn("lbfgs failed to converge. Increase the number " 792 | "of iterations.", ConvergenceWarning) 793 | # In scipy <= 1.0.0, nit may exceed maxiter. 794 | # See https://github.com/scipy/scipy/issues/7854. 795 | n_iter_i = min(info['nit'], max_iter) 796 | elif solver == 'newton-cg': 797 | args = (X, target, 1. / C, sample_weight) 798 | w0, n_iter_i = newton_cg(hess, func, grad, w0, args=args, 799 | maxiter=max_iter, tol=tol) 800 | elif solver == 'liblinear': 801 | coef_, intercept_, n_iter_i, = _fit_liblinear( 802 | X, target, C, fit_intercept, intercept_scaling, None, 803 | penalty, dual, verbose, max_iter, tol, random_state, 804 | sample_weight=sample_weight) 805 | if fit_intercept: 806 | w0 = np.concatenate([coef_.ravel(), intercept_]) 807 | else: 808 | w0 = coef_.ravel() 809 | 810 | elif solver in ['sag', 'saga']: 811 | if multi_class == 'multinomial': 812 | target = target.astype(np.float64) 813 | loss = 'multinomial' 814 | else: 815 | loss = 'log' 816 | if penalty == 'l1': 817 | alpha = 0. 818 | beta = 1. / C 819 | else: 820 | alpha = 1. / C 821 | beta = 0. 
822 | w0, n_iter_i, warm_start_sag = sag_solver( 823 | X, target, sample_weight, loss, alpha, 824 | beta, max_iter, tol, 825 | verbose, random_state, False, max_squared_sum, warm_start_sag, 826 | is_saga=(solver == 'saga')) 827 | 828 | else: 829 | raise ValueError("solver must be one of {'liblinear', 'lbfgs', " 830 | "'newton-cg', 'sag'}, got '%s' instead" % solver) 831 | 832 | if multi_class == 'multinomial': 833 | n_classes = max(2, classes.size) 834 | multi_w0 = np.reshape(w0, (n_classes, -1)) 835 | if n_classes == 2: 836 | multi_w0 = multi_w0[1][np.newaxis, :] 837 | coefs.append(multi_w0.copy()) 838 | else: 839 | coefs.append(w0.copy()) 840 | 841 | n_iter[i] = n_iter_i 842 | 843 | return np.array(coefs), np.array(Cs), n_iter 844 | 845 | 846 | # helper function for LogisticCV 847 | def _log_reg_scoring_path(X, y, train, test, pos_class=None, Cs=10, 848 | scoring=None, fit_intercept=False, 849 | max_iter=100, tol=1e-4, class_weight=None, 850 | verbose=0, solver='lbfgs', penalty='l2', 851 | dual=False, intercept_scaling=1., 852 | multi_class='warn', random_state=None, 853 | max_squared_sum=None, sample_weight=None): 854 | """Computes scores across logistic_regression_path 855 | 856 | Parameters 857 | ---------- 858 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 859 | Training data. 860 | 861 | y : array-like, shape (n_samples,) or (n_samples, n_targets) 862 | Target labels. 863 | 864 | train : list of indices 865 | The indices of the train set. 866 | 867 | test : list of indices 868 | The indices of the test set. 869 | 870 | pos_class : int, None 871 | The class with respect to which we perform a one-vs-all fit. 872 | If None, then it is assumed that the given problem is binary. 873 | 874 | Cs : list of floats | int 875 | Each of the values in Cs describes the inverse of 876 | regularization strength. If Cs is as an int, then a grid of Cs 877 | values are chosen in a logarithmic scale between 1e-4 and 1e4. 878 | If not provided, then a fixed set of values for Cs are used. 879 | 880 | scoring : callable or None, optional, default: None 881 | A string (see model evaluation documentation) or 882 | a scorer callable object / function with signature 883 | ``scorer(estimator, X, y)``. For a list of scoring functions 884 | that can be used, look at :mod:`sklearn.metrics`. The 885 | default scoring option used is accuracy_score. 886 | 887 | fit_intercept : bool 888 | If False, then the bias term is set to zero. Else the last 889 | term of each coef_ gives us the intercept. 890 | 891 | max_iter : int 892 | Maximum number of iterations for the solver. 893 | 894 | tol : float 895 | Tolerance for stopping criteria. 896 | 897 | class_weight : dict or 'balanced', optional 898 | Weights associated with classes in the form ``{class_label: weight}``. 899 | If not given, all classes are supposed to have weight one. 900 | 901 | The "balanced" mode uses the values of y to automatically adjust 902 | weights inversely proportional to class frequencies in the input data 903 | as ``n_samples / (n_classes * np.bincount(y))`` 904 | 905 | Note that these weights will be multiplied with sample_weight (passed 906 | through the fit method) if sample_weight is specified. 907 | 908 | verbose : int 909 | For the liblinear and lbfgs solvers set verbose to any positive 910 | number for verbosity. 911 | 912 | solver : {'lbfgs', 'newton-cg', 'liblinear', 'sag', 'saga'} 913 | Decides which solver to use. 914 | 915 | penalty : str, 'l1' or 'l2' 916 | Used to specify the norm used in the penalization. 
The 'newton-cg', 917 | 'sag' and 'lbfgs' solvers support only l2 penalties. 918 | 919 | dual : bool 920 | Dual or primal formulation. Dual formulation is only implemented for 921 | l2 penalty with liblinear solver. Prefer dual=False when 922 | n_samples > n_features. 923 | 924 | intercept_scaling : float, default 1. 925 | Useful only when the solver 'liblinear' is used 926 | and self.fit_intercept is set to True. In this case, x becomes 927 | [x, self.intercept_scaling], 928 | i.e. a "synthetic" feature with constant value equals to 929 | intercept_scaling is appended to the instance vector. 930 | The intercept becomes intercept_scaling * synthetic feature weight 931 | Note! the synthetic feature weight is subject to l1/l2 regularization 932 | as all other features. 933 | To lessen the effect of regularization on synthetic feature weight 934 | (and therefore on the intercept) intercept_scaling has to be increased. 935 | 936 | multi_class : str, {'ovr', 'multinomial'} 937 | If the option chosen is 'ovr', then a binary problem is fit for each 938 | label. For 'multinomial' the loss minimised is the multinomial loss fit 939 | across the entire probability distribution, *even when the data is 940 | binary*. 'multinomial' is unavailable when solver='liblinear'. 941 | 942 | random_state : int, RandomState instance or None, optional, default None 943 | The seed of the pseudo random number generator to use when shuffling 944 | the data. If int, random_state is the seed used by the random number 945 | generator; If RandomState instance, random_state is the random number 946 | generator; If None, the random number generator is the RandomState 947 | instance used by `np.random`. Used when ``solver`` == 'sag' and 948 | 'liblinear'. 949 | 950 | max_squared_sum : float, default None 951 | Maximum squared sum of X over samples. Used only in SAG solver. 952 | If None, it will be computed, going through all the samples. 953 | The value should be precomputed to speed up cross validation. 954 | 955 | sample_weight : array-like, shape(n_samples,) optional 956 | Array of weights that are assigned to individual samples. 957 | If not provided, then each sample is given unit weight. 958 | 959 | Returns 960 | ------- 961 | coefs : ndarray, shape (n_cs, n_features) or (n_cs, n_features + 1) 962 | List of coefficients for the Logistic Regression model. If 963 | fit_intercept is set to True then the second dimension will be 964 | n_features + 1, where the last item represents the intercept. 965 | 966 | Cs : ndarray 967 | Grid of Cs used for cross-validation. 968 | 969 | scores : ndarray, shape (n_cs,) 970 | Scores obtained for each Cs. 971 | 972 | n_iter : array, shape(n_cs,) 973 | Actual number of iteration for each Cs. 
974 | """ 975 | X_train = X[train] 976 | X_test = X[test] 977 | y_train = y[train] 978 | y_test = y[test] 979 | 980 | if sample_weight is not None: 981 | sample_weight = check_array(sample_weight, ensure_2d=False) 982 | check_consistent_length(y, sample_weight) 983 | 984 | sample_weight = sample_weight[train] 985 | 986 | coefs, Cs, n_iter = logistic_regression_path( 987 | X_train, y_train, Cs=Cs, fit_intercept=fit_intercept, 988 | solver=solver, max_iter=max_iter, class_weight=class_weight, 989 | pos_class=pos_class, multi_class=multi_class, 990 | tol=tol, verbose=verbose, dual=dual, penalty=penalty, 991 | intercept_scaling=intercept_scaling, random_state=random_state, 992 | check_input=False, max_squared_sum=max_squared_sum, 993 | sample_weight=sample_weight) 994 | 995 | log_reg = LogisticRegression(solver=solver, multi_class=multi_class) 996 | 997 | # The score method of Logistic Regression has a classes_ attribute. 998 | if multi_class == 'ovr': 999 | log_reg.classes_ = np.array([-1, 1]) 1000 | elif multi_class == 'multinomial': 1001 | log_reg.classes_ = np.unique(y_train) 1002 | else: 1003 | raise ValueError("multi_class should be either multinomial or ovr, " 1004 | "got %d" % multi_class) 1005 | 1006 | if pos_class is not None: 1007 | mask = (y_test == pos_class) 1008 | y_test = np.ones(y_test.shape, dtype=np.float64) 1009 | y_test[~mask] = -1. 1010 | 1011 | scores = list() 1012 | 1013 | if isinstance(scoring, six.string_types): 1014 | scoring = get_scorer(scoring) 1015 | for w in coefs: 1016 | if multi_class == 'ovr': 1017 | w = w[np.newaxis, :] 1018 | if fit_intercept: 1019 | log_reg.coef_ = w[:, :-1] 1020 | log_reg.intercept_ = w[:, -1] 1021 | else: 1022 | log_reg.coef_ = w 1023 | log_reg.intercept_ = 0. 1024 | 1025 | if scoring is None: 1026 | scores.append(log_reg.score(X_test, y_test)) 1027 | else: 1028 | scores.append(scoring(log_reg, X_test, y_test)) 1029 | return coefs, Cs, np.array(scores), n_iter 1030 | 1031 | 1032 | class LogisticRegression(BaseEstimator, LinearClassifierMixin, 1033 | SparseCoefMixin): 1034 | """Logistic Regression (aka logit, MaxEnt) classifier. 1035 | 1036 | In the multiclass case, the training algorithm uses the one-vs-rest (OvR) 1037 | scheme if the 'multi_class' option is set to 'ovr', and uses the cross- 1038 | entropy loss if the 'multi_class' option is set to 'multinomial'. 1039 | (Currently the 'multinomial' option is supported only by the 'lbfgs', 1040 | 'sag' and 'newton-cg' solvers.) 1041 | 1042 | This class implements regularized logistic regression using the 1043 | 'liblinear' library, 'newton-cg', 'sag' and 'lbfgs' solvers. It can handle 1044 | both dense and sparse input. Use C-ordered arrays or CSR matrices 1045 | containing 64-bit floats for optimal performance; any other input format 1046 | will be converted (and copied). 1047 | 1048 | The 'newton-cg', 'sag', and 'lbfgs' solvers support only L2 regularization 1049 | with primal formulation. The 'liblinear' solver supports both L1 and L2 1050 | regularization, with a dual formulation only for the L2 penalty. 1051 | 1052 | Read more in the :ref:`User Guide `. 1053 | 1054 | Parameters 1055 | ---------- 1056 | penalty : str, 'l1' or 'l2', default: 'l2' 1057 | Used to specify the norm used in the penalization. The 'newton-cg', 1058 | 'sag' and 'lbfgs' solvers support only l2 penalties. 1059 | 1060 | .. versionadded:: 0.19 1061 | l1 penalty with SAGA solver (allowing 'multinomial' + L1) 1062 | 1063 | dual : bool, default: False 1064 | Dual or primal formulation. 
Dual formulation is only implemented for 1065 | l2 penalty with liblinear solver. Prefer dual=False when 1066 | n_samples > n_features. 1067 | 1068 | tol : float, default: 1e-4 1069 | Tolerance for stopping criteria. 1070 | 1071 | C : float, default: 1.0 1072 | Inverse of regularization strength; must be a positive float. 1073 | Like in support vector machines, smaller values specify stronger 1074 | regularization. 1075 | 1076 | fit_intercept : bool, default: True 1077 | Specifies if a constant (a.k.a. bias or intercept) should be 1078 | added to the decision function. 1079 | 1080 | intercept_scaling : float, default 1. 1081 | Useful only when the solver 'liblinear' is used 1082 | and self.fit_intercept is set to True. In this case, x becomes 1083 | [x, self.intercept_scaling], 1084 | i.e. a "synthetic" feature with constant value equal to 1085 | intercept_scaling is appended to the instance vector. 1086 | The intercept becomes ``intercept_scaling * synthetic_feature_weight``. 1087 | 1088 | Note! the synthetic feature weight is subject to l1/l2 regularization 1089 | as all other features. 1090 | To lessen the effect of regularization on synthetic feature weight 1091 | (and therefore on the intercept) intercept_scaling has to be increased. 1092 | 1093 | class_weight : dict or 'balanced', default: None 1094 | Weights associated with classes in the form ``{class_label: weight}``. 1095 | If not given, all classes are supposed to have weight one. 1096 | 1097 | The "balanced" mode uses the values of y to automatically adjust 1098 | weights inversely proportional to class frequencies in the input data 1099 | as ``n_samples / (n_classes * np.bincount(y))``. 1100 | 1101 | Note that these weights will be multiplied with sample_weight (passed 1102 | through the fit method) if sample_weight is specified. 1103 | 1104 | .. versionadded:: 0.17 1105 | *class_weight='balanced'* 1106 | 1107 | random_state : int, RandomState instance or None, optional, default: None 1108 | The seed of the pseudo random number generator to use when shuffling 1109 | the data. If int, random_state is the seed used by the random number 1110 | generator; If RandomState instance, random_state is the random number 1111 | generator; If None, the random number generator is the RandomState 1112 | instance used by `np.random`. Used when ``solver`` == 'sag' or 1113 | 'liblinear'. 1114 | 1115 | solver : str, {'newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'}, \ 1116 | default: 'liblinear'. 1117 | 1118 | Algorithm to use in the optimization problem. 1119 | 1120 | - For small datasets, 'liblinear' is a good choice, whereas 'sag' and 1121 | 'saga' are faster for large ones. 1122 | - For multiclass problems, only 'newton-cg', 'sag', 'saga' and 'lbfgs' 1123 | handle multinomial loss; 'liblinear' is limited to one-versus-rest 1124 | schemes. 1125 | - 'newton-cg', 'lbfgs' and 'sag' only handle L2 penalty, whereas 1126 | 'liblinear' and 'saga' handle L1 penalty. 1127 | 1128 | Note that 'sag' and 'saga' fast convergence is only guaranteed on 1129 | features with approximately the same scale. You can 1130 | preprocess the data with a scaler from sklearn.preprocessing. 1131 | 1132 | .. versionadded:: 0.17 1133 | Stochastic Average Gradient descent solver. 1134 | .. versionadded:: 0.19 1135 | SAGA solver. 1136 | .. versionchanged:: 0.20 1137 | Default will change from 'liblinear' to 'lbfgs' in 0.22. 1138 | 1139 | max_iter : int, default: 100 1140 | Useful only for the newton-cg, sag and lbfgs solvers. 
1141 | Maximum number of iterations taken for the solvers to converge. 1142 | 1143 | multi_class : str, {'ovr', 'multinomial', 'auto'}, default: 'ovr' 1144 | If the option chosen is 'ovr', then a binary problem is fit for each 1145 | label. For 'multinomial' the loss minimised is the multinomial loss fit 1146 | across the entire probability distribution, *even when the data is 1147 | binary*. 'multinomial' is unavailable when solver='liblinear'. 1148 | 'auto' selects 'ovr' if the data is binary, or if solver='liblinear', 1149 | and otherwise selects 'multinomial'. 1150 | 1151 | .. versionadded:: 0.18 1152 | Stochastic Average Gradient descent solver for 'multinomial' case. 1153 | .. versionchanged:: 0.20 1154 | Default will change from 'ovr' to 'auto' in 0.22. 1155 | 1156 | verbose : int, default: 0 1157 | For the liblinear and lbfgs solvers set verbose to any positive 1158 | number for verbosity. 1159 | 1160 | warm_start : bool, default: False 1161 | When set to True, reuse the solution of the previous call to fit as 1162 | initialization, otherwise, just erase the previous solution. 1163 | Useless for liblinear solver. See :term:`the Glossary `. 1164 | 1165 | .. versionadded:: 0.17 1166 | *warm_start* to support *lbfgs*, *newton-cg*, *sag*, *saga* solvers. 1167 | 1168 | n_jobs : int or None, optional (default=None) 1169 | Number of CPU cores used when parallelizing over classes if 1170 | multi_class='ovr'". This parameter is ignored when the ``solver`` is 1171 | set to 'liblinear' regardless of whether 'multi_class' is specified or 1172 | not. ``None`` means 1 unless in a :obj:`joblib.parallel_backend` 1173 | context. ``-1`` means using all processors. 1174 | See :term:`Glossary ` for more details. 1175 | 1176 | coef_ : array-like, shape (1, n_features,), default None 1177 | Initialization value for coefficients of logistic regression. 1178 | 1179 | Attributes 1180 | ---------- 1181 | 1182 | classes_ : array, shape (n_classes, ) 1183 | A list of class labels known to the classifier. 1184 | 1185 | coef_ : array, shape (1, n_features) or (n_classes, n_features) 1186 | Coefficient of the features in the decision function. 1187 | 1188 | `coef_` is of shape (1, n_features) when the given problem is binary. 1189 | In particular, when `multi_class='multinomial'`, `coef_` corresponds 1190 | to outcome 1 (True) and `-coef_` corresponds to outcome 0 (False). 1191 | 1192 | intercept_ : array, shape (1,) or (n_classes,) 1193 | Intercept (a.k.a. bias) added to the decision function. 1194 | 1195 | If `fit_intercept` is set to False, the intercept is set to zero. 1196 | `intercept_` is of shape (1,) when the given problem is binary. 1197 | In particular, when `multi_class='multinomial'`, `intercept_` 1198 | corresponds to outcome 1 (True) and `-intercept_` corresponds to 1199 | outcome 0 (False). 1200 | 1201 | n_iter_ : array, shape (n_classes,) or (1, ) 1202 | Actual number of iterations for all classes. If binary or multinomial, 1203 | it returns only 1 element. For liblinear solver, only the maximum 1204 | number of iteration across all classes is given. 1205 | 1206 | .. versionchanged:: 0.20 1207 | 1208 | In SciPy <= 1.0.0 the number of lbfgs iterations may exceed 1209 | ``max_iter``. ``n_iter_`` will now report at most ``max_iter``. 1210 | 1211 | See also 1212 | -------- 1213 | SGDClassifier : incrementally trained logistic regression (when given 1214 | the parameter ``loss="log"``). 
1215 | LogisticRegressionCV : Logistic regression with built-in cross validation 1216 | 1217 | Notes 1218 | ----- 1219 | The underlying C implementation uses a random number generator to 1220 | select features when fitting the model. It is thus not uncommon 1221 | to have slightly different results for the same input data. If 1222 | that happens, try with a smaller tol parameter. 1223 | 1224 | Predict output may not match that of standalone liblinear in certain 1225 | cases. See :ref:`differences from liblinear <liblinear_differences>` 1226 | in the narrative documentation. 1227 | 1228 | References 1229 | ---------- 1230 | 1231 | LIBLINEAR -- A Library for Large Linear Classification 1232 | https://www.csie.ntu.edu.tw/~cjlin/liblinear/ 1233 | 1234 | SAG -- Mark Schmidt, Nicolas Le Roux, and Francis Bach 1235 | Minimizing Finite Sums with the Stochastic Average Gradient 1236 | https://hal.inria.fr/hal-00860051/document 1237 | 1238 | SAGA -- Defazio, A., Bach F. & Lacoste-Julien S. (2014). 1239 | SAGA: A Fast Incremental Gradient Method With Support 1240 | for Non-Strongly Convex Composite Objectives 1241 | https://arxiv.org/abs/1407.0202 1242 | 1243 | Hsiang-Fu Yu, Fang-Lan Huang, Chih-Jen Lin (2011). Dual coordinate descent 1244 | methods for logistic regression and maximum entropy models. 1245 | Machine Learning 85(1-2):41-75. 1246 | https://www.csie.ntu.edu.tw/~cjlin/papers/maxent_dual.pdf 1247 | """ 1248 | 1249 | def __init__(self, penalty='l2', dual=False, tol=1e-4, C=1.0, 1250 | fit_intercept=True, intercept_scaling=1, class_weight=None, 1251 | random_state=None, solver='liblinear', max_iter=100, 1252 | multi_class='ovr', verbose=0, warm_start=False, coef_=None, n_jobs=1): 1253 | 1254 | self.penalty = penalty 1255 | self.dual = dual 1256 | self.tol = tol 1257 | self.C = C 1258 | self.fit_intercept = fit_intercept 1259 | self.intercept_scaling = intercept_scaling 1260 | self.class_weight = class_weight 1261 | self.random_state = random_state 1262 | self.solver = solver 1263 | self.max_iter = max_iter 1264 | self.multi_class = multi_class 1265 | self.verbose = verbose 1266 | self.warm_start = warm_start 1267 | self.coef_ = coef_ 1268 | self.n_jobs = n_jobs 1269 | 1270 | def fit(self, X, y, sample_weight=None): 1271 | """Fit the model according to the given training data. 1272 | 1273 | Parameters 1274 | ---------- 1275 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 1276 | Training vector, where n_samples is the number of samples and 1277 | n_features is the number of features. 1278 | 1279 | y : array-like, shape (n_samples,) 1280 | Target vector relative to X. 1281 | 1282 | sample_weight : array-like, shape (n_samples,), optional 1283 | Array of weights that are assigned to individual samples. 1284 | If not provided, then each sample is given unit weight. 1285 | 1286 | .. versionadded:: 0.17 1287 | *sample_weight* support to LogisticRegression.
1288 | 1289 | Returns 1290 | ------- 1291 | self : object 1292 | """ 1293 | if not isinstance(self.C, numbers.Number) or self.C < 0: 1294 | raise ValueError("Penalty term must be positive; got (C=%r)" 1295 | % self.C) 1296 | if not isinstance(self.max_iter, numbers.Number) or self.max_iter < 0: 1297 | raise ValueError("Maximum number of iteration must be positive;" 1298 | " got (max_iter=%r)" % self.max_iter) 1299 | if not isinstance(self.tol, numbers.Number) or self.tol < 0: 1300 | raise ValueError("Tolerance for stopping criteria must be " 1301 | "positive; got (tol=%r)" % self.tol) 1302 | 1303 | solver = _check_solver(self.solver, self.penalty, self.dual) 1304 | 1305 | if solver in ['newton-cg']: 1306 | _dtype = [np.float64, np.float32] 1307 | else: 1308 | _dtype = np.float64 1309 | 1310 | X, y = check_X_y(X, y, accept_sparse='csr', dtype=_dtype, order="C") 1311 | check_classification_targets(y) 1312 | self.classes_ = np.unique(y) 1313 | n_samples, n_features = X.shape 1314 | 1315 | multi_class = _check_multi_class(self.multi_class, solver, 1316 | len(self.classes_)) 1317 | 1318 | if solver == 'liblinear': 1319 | if effective_n_jobs(self.n_jobs) != 1: 1320 | warnings.warn("'n_jobs' > 1 does not have any effect when" 1321 | " 'solver' is set to 'liblinear'. Got 'n_jobs'" 1322 | " = {}.".format(effective_n_jobs(self.n_jobs))) 1323 | self.coef_, self.intercept_, n_iter_ = _fit_liblinear( 1324 | X, y, self.C, self.fit_intercept, self.intercept_scaling, 1325 | self.class_weight, self.penalty, self.dual, self.verbose, 1326 | self.max_iter, self.tol, self.random_state, 1327 | sample_weight=sample_weight) 1328 | self.n_iter_ = np.array([n_iter_]) 1329 | return self 1330 | 1331 | if solver in ['sag', 'saga']: 1332 | max_squared_sum = row_norms(X, squared=True).max() 1333 | else: 1334 | max_squared_sum = None 1335 | 1336 | n_classes = len(self.classes_) 1337 | classes_ = self.classes_ 1338 | if n_classes < 2: 1339 | raise ValueError("This solver needs samples of at least 2 classes" 1340 | " in the data, but the data contains only one" 1341 | " class: %r" % classes_[0]) 1342 | 1343 | if len(self.classes_) == 2: 1344 | n_classes = 1 1345 | classes_ = classes_[1:] 1346 | 1347 | if self.warm_start: 1348 | warm_start_coef = getattr(self, 'coef_', None) 1349 | else: 1350 | warm_start_coef = None 1351 | if warm_start_coef is not None and self.fit_intercept: 1352 | warm_start_coef = np.append(warm_start_coef, 1353 | self.intercept_[:, np.newaxis], 1354 | axis=1) 1355 | 1356 | self.coef_ = list() 1357 | self.intercept_ = np.zeros(n_classes) 1358 | 1359 | # Hack so that we iterate only once for the multinomial case. 1360 | if multi_class == 'multinomial': 1361 | classes_ = [None] 1362 | warm_start_coef = [warm_start_coef] 1363 | if warm_start_coef is None: 1364 | warm_start_coef = [None] * n_classes 1365 | 1366 | path_func = delayed(logistic_regression_path) 1367 | 1368 | # The SAG solver releases the GIL so it's more efficient to use 1369 | # threads for this solver. 
1370 | if solver in ['sag', 'saga']: 1371 | prefer = 'threads' 1372 | else: 1373 | prefer = 'processes' 1374 | fold_coefs_ = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, 1375 | **_joblib_parallel_args(prefer=prefer))( 1376 | path_func(X, y, pos_class=class_, Cs=[self.C], 1377 | fit_intercept=self.fit_intercept, tol=self.tol, 1378 | verbose=self.verbose, solver=solver, 1379 | multi_class=multi_class, max_iter=self.max_iter, 1380 | class_weight=self.class_weight, check_input=False, 1381 | random_state=self.random_state, coef=warm_start_coef_, 1382 | penalty=self.penalty, 1383 | max_squared_sum=max_squared_sum, 1384 | sample_weight=sample_weight) 1385 | for class_, warm_start_coef_ in zip(classes_, warm_start_coef)) 1386 | 1387 | fold_coefs_, _, n_iter_ = zip(*fold_coefs_) 1388 | self.n_iter_ = np.asarray(n_iter_, dtype=np.int32)[:, 0] 1389 | 1390 | if multi_class == 'multinomial': 1391 | self.coef_ = fold_coefs_[0][0] 1392 | else: 1393 | self.coef_ = np.asarray(fold_coefs_) 1394 | self.coef_ = self.coef_.reshape(n_classes, n_features + 1395 | int(self.fit_intercept)) 1396 | 1397 | if self.fit_intercept: 1398 | self.intercept_ = self.coef_[:, -1] 1399 | self.coef_ = self.coef_[:, :-1] 1400 | 1401 | return self 1402 | 1403 | def predict_proba(self, X): 1404 | """Probability estimates. 1405 | 1406 | The returned estimates for all classes are ordered by the 1407 | label of classes. 1408 | 1409 | For a multi_class problem, if multi_class is set to be "multinomial" 1410 | the softmax function is used to find the predicted probability of 1411 | each class. 1412 | Otherwise use a one-vs-rest approach, i.e. calculate the probability 1413 | of each class assuming it to be positive using the logistic function, 1414 | and normalize these values across all the classes. 1415 | 1416 | Parameters 1417 | ---------- 1418 | X : array-like, shape = [n_samples, n_features] 1419 | 1420 | Returns 1421 | ------- 1422 | T : array-like, shape = [n_samples, n_classes] 1423 | Returns the probability of the sample for each class in the model, 1424 | where classes are ordered as they are in ``self.classes_``. 1425 | """ 1426 | if not hasattr(self, "coef_"): 1427 | raise NotFittedError("Call fit before prediction") 1428 | 1429 | ovr = (self.multi_class in ["ovr", "warn"] or 1430 | (self.multi_class == 'auto' and (self.classes_.size <= 2 or 1431 | self.solver == 'liblinear'))) 1432 | if ovr: 1433 | return super(LogisticRegression, self)._predict_proba_lr(X) 1434 | else: 1435 | decision = self.decision_function(X) 1436 | if decision.ndim == 1: 1437 | # Workaround for multi_class="multinomial" and binary outcomes 1438 | # which requires softmax prediction with only a 1D decision. 1439 | decision_2d = np.c_[-decision, decision] 1440 | else: 1441 | decision_2d = decision 1442 | return softmax(decision_2d, copy=False) 1443 | 1444 | def predict_log_proba(self, X): 1445 | """Log of probability estimates. 1446 | 1447 | The returned estimates for all classes are ordered by the 1448 | label of classes. 1449 | 1450 | Parameters 1451 | ---------- 1452 | X : array-like, shape = [n_samples, n_features] 1453 | 1454 | Returns 1455 | ------- 1456 | T : array-like, shape = [n_samples, n_classes] 1457 | Returns the log-probability of the sample for each class in the 1458 | model, where classes are ordered as they are in ``self.classes_``.
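Examples
--------
An illustrative sanity check on made-up toy data: exponentiating the
log-probabilities recovers ``predict_proba``, so each row sums to one.

>>> import numpy as np
>>> X = np.array([[0.], [1.], [2.], [3.]])  # one feature, four samples
>>> y = np.array([0, 0, 1, 1])
>>> clf = LogisticRegression().fit(X, y)
>>> np.allclose(np.exp(clf.predict_log_proba(X)).sum(axis=1), 1.0)
True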
1459 | """ 1460 | return np.log(self.predict_proba(X)) 1461 | 1462 | 1463 | class LogisticRegressionCV(LogisticRegression, BaseEstimator, 1464 | LinearClassifierMixin): 1465 | """Logistic Regression CV (aka logit, MaxEnt) classifier. 1466 | 1467 | See glossary entry for :term:`cross-validation estimator`. 1468 | 1469 | This class implements logistic regression using liblinear, newton-cg, sag 1470 | or lbfgs optimizer. The newton-cg, sag and lbfgs solvers support only L2 1471 | regularization with primal formulation. The liblinear solver supports both 1472 | L1 and L2 regularization, with a dual formulation only for the L2 penalty. 1473 | 1474 | For the grid of Cs values (that are set by default to be ten values in 1475 | a logarithmic scale between 1e-4 and 1e4), the best hyperparameter is 1476 | selected by the cross-validator StratifiedKFold, but it can be changed 1477 | using the cv parameter. In the case of newton-cg and lbfgs solvers, 1478 | we warm start along the path, i.e. guess the initial coefficients of the 1479 | present fit to be the coefficients obtained after convergence in the previous 1480 | fit, so it is supposed to be faster for high-dimensional dense data. 1481 | 1482 | For a multiclass problem, the hyperparameters for each class are computed 1483 | using the best scores obtained by doing a one-vs-rest in parallel across all 1484 | folds and classes. Hence this is not the true multinomial loss. 1485 | 1486 | Read more in the :ref:`User Guide <logistic_regression>`. 1487 | 1488 | Parameters 1489 | ---------- 1490 | Cs : list of floats | int 1491 | Each of the values in Cs describes the inverse of regularization 1492 | strength. If Cs is an int, then a grid of Cs values is chosen 1493 | in a logarithmic scale between 1e-4 and 1e4. 1494 | Like in support vector machines, smaller values specify stronger 1495 | regularization. 1496 | 1497 | fit_intercept : bool, default: True 1498 | Specifies if a constant (a.k.a. bias or intercept) should be 1499 | added to the decision function. 1500 | 1501 | cv : integer or cross-validation generator, default: None 1502 | The default cross-validation generator used is Stratified K-Folds. 1503 | If an integer is provided, then it is the number of folds used. 1504 | See the :mod:`sklearn.model_selection` module for the 1505 | list of possible cross-validation objects. 1506 | 1507 | .. versionchanged:: 0.20 1508 | ``cv`` default value if None will change from 3-fold to 5-fold 1509 | in v0.22. 1510 | 1511 | dual : bool 1512 | Dual or primal formulation. Dual formulation is only implemented for 1513 | l2 penalty with liblinear solver. Prefer dual=False when 1514 | n_samples > n_features. 1515 | 1516 | penalty : str, 'l1' or 'l2' 1517 | Used to specify the norm used in the penalization. The 'newton-cg', 1518 | 'sag' and 'lbfgs' solvers support only l2 penalties. 1519 | 1520 | scoring : string, callable, or None 1521 | A string (see model evaluation documentation) or 1522 | a scorer callable object / function with signature 1523 | ``scorer(estimator, X, y)``. For a list of scoring functions 1524 | that can be used, look at :mod:`sklearn.metrics`. The 1525 | default scoring option used is 'accuracy'. 1526 | 1527 | solver : str, {'newton-cg', 'lbfgs', 'liblinear', 'sag', 'saga'}, \ 1528 | default: 'lbfgs'. 1529 | 1530 | Algorithm to use in the optimization problem. 1531 | 1532 | - For small datasets, 'liblinear' is a good choice, whereas 'sag' and 1533 | 'saga' are faster for large ones.
1534 | - For multiclass problems, only 'newton-cg', 'sag', 'saga' and 'lbfgs' 1535 | handle multinomial loss; 'liblinear' is limited to one-versus-rest 1536 | schemes. 1537 | - 'newton-cg', 'lbfgs' and 'sag' only handle L2 penalty, whereas 1538 | 'liblinear' and 'saga' handle L1 penalty. 1539 | - 'liblinear' might be slower in LogisticRegressionCV because it does 1540 | not handle warm-starting. 1541 | 1542 | Note that 'sag' and 'saga' fast convergence is only guaranteed on 1543 | features with approximately the same scale. You can preprocess the data 1544 | with a scaler from sklearn.preprocessing. 1545 | 1546 | .. versionadded:: 0.17 1547 | Stochastic Average Gradient descent solver. 1548 | .. versionadded:: 0.19 1549 | SAGA solver. 1550 | 1551 | tol : float, optional 1552 | Tolerance for stopping criteria. 1553 | 1554 | max_iter : int, optional 1555 | Maximum number of iterations of the optimization algorithm. 1556 | 1557 | class_weight : dict or 'balanced', optional 1558 | Weights associated with classes in the form ``{class_label: weight}``. 1559 | If not given, all classes are supposed to have weight one. 1560 | 1561 | The "balanced" mode uses the values of y to automatically adjust 1562 | weights inversely proportional to class frequencies in the input data 1563 | as ``n_samples / (n_classes * np.bincount(y))``. 1564 | 1565 | Note that these weights will be multiplied with sample_weight (passed 1566 | through the fit method) if sample_weight is specified. 1567 | 1568 | .. versionadded:: 0.17 1569 | class_weight == 'balanced' 1570 | 1571 | n_jobs : int or None, optional (default=None) 1572 | Number of CPU cores used during the cross-validation loop. 1573 | ``None`` means 1 unless in a :obj:`joblib.parallel_backend` context. 1574 | ``-1`` means using all processors. See :term:`Glossary ` 1575 | for more details. 1576 | 1577 | verbose : int 1578 | For the 'liblinear', 'sag' and 'lbfgs' solvers set verbose to any 1579 | positive number for verbosity. 1580 | 1581 | refit : bool 1582 | If set to True, the scores are averaged across all folds, and the 1583 | coefs and the C that corresponds to the best score is taken, and a 1584 | final refit is done using these parameters. 1585 | Otherwise the coefs, intercepts and C that correspond to the 1586 | best scores across folds are averaged. 1587 | 1588 | intercept_scaling : float, default 1. 1589 | Useful only when the solver 'liblinear' is used 1590 | and self.fit_intercept is set to True. In this case, x becomes 1591 | [x, self.intercept_scaling], 1592 | i.e. a "synthetic" feature with constant value equal to 1593 | intercept_scaling is appended to the instance vector. 1594 | The intercept becomes ``intercept_scaling * synthetic_feature_weight``. 1595 | 1596 | Note! the synthetic feature weight is subject to l1/l2 regularization 1597 | as all other features. 1598 | To lessen the effect of regularization on synthetic feature weight 1599 | (and therefore on the intercept) intercept_scaling has to be increased. 1600 | 1601 | multi_class : str, {'ovr', 'multinomial', 'auto'}, default: 'ovr' 1602 | If the option chosen is 'ovr', then a binary problem is fit for each 1603 | label. For 'multinomial' the loss minimised is the multinomial loss fit 1604 | across the entire probability distribution, *even when the data is 1605 | binary*. 'multinomial' is unavailable when solver='liblinear'. 1606 | 'auto' selects 'ovr' if the data is binary, or if solver='liblinear', 1607 | and otherwise selects 'multinomial'. 1608 | 1609 | .. 
versionadded:: 0.18 1610 | Stochastic Average Gradient descent solver for 'multinomial' case. 1611 | .. versionchanged:: 0.20 1612 | Default will change from 'ovr' to 'auto' in 0.22. 1613 | 1614 | random_state : int, RandomState instance or None, optional, default None 1615 | If int, random_state is the seed used by the random number generator; 1616 | If RandomState instance, random_state is the random number generator; 1617 | If None, the random number generator is the RandomState instance used 1618 | by `np.random`. 1619 | 1620 | Attributes 1621 | ---------- 1622 | classes_ : array, shape (n_classes, ) 1623 | A list of class labels known to the classifier. 1624 | 1625 | coef_ : array, shape (1, n_features) or (n_classes, n_features) 1626 | Coefficient of the features in the decision function. 1627 | 1628 | `coef_` is of shape (1, n_features) when the given problem 1629 | is binary. 1630 | 1631 | intercept_ : array, shape (1,) or (n_classes,) 1632 | Intercept (a.k.a. bias) added to the decision function. 1633 | 1634 | If `fit_intercept` is set to False, the intercept is set to zero. 1635 | `intercept_` is of shape(1,) when the problem is binary. 1636 | 1637 | Cs_ : array 1638 | Array of C i.e. inverse of regularization parameter values used 1639 | for cross-validation. 1640 | 1641 | coefs_paths_ : array, shape ``(n_folds, len(Cs_), n_features)`` or \ 1642 | ``(n_folds, len(Cs_), n_features + 1)`` 1643 | dict with classes as the keys, and the path of coefficients obtained 1644 | during cross-validating across each fold and then across each Cs 1645 | after doing an OvR for the corresponding class as values. 1646 | If the 'multi_class' option is set to 'multinomial', then 1647 | the coefs_paths are the coefficients corresponding to each class. 1648 | Each dict value has shape ``(n_folds, len(Cs_), n_features)`` or 1649 | ``(n_folds, len(Cs_), n_features + 1)`` depending on whether the 1650 | intercept is fit or not. 1651 | 1652 | scores_ : dict 1653 | dict with classes as the keys, and the values as the 1654 | grid of scores obtained during cross-validating each fold, after doing 1655 | an OvR for the corresponding class. If the 'multi_class' option 1656 | given is 'multinomial' then the same scores are repeated across 1657 | all classes, since this is the multinomial class. 1658 | Each dict value has shape (n_folds, len(Cs)) 1659 | 1660 | C_ : array, shape (n_classes,) or (n_classes - 1,) 1661 | Array of C that maps to the best scores across every class. If refit is 1662 | set to False, then for each class, the best C is the average of the 1663 | C's that correspond to the best scores for each fold. 1664 | `C_` is of shape(n_classes,) when the problem is binary. 1665 | 1666 | n_iter_ : array, shape (n_classes, n_folds, n_cs) or (1, n_folds, n_cs) 1667 | Actual number of iterations for all classes, folds and Cs. 1668 | In the binary or multinomial cases, the first dimension is equal to 1. 
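Examples
--------
A minimal sketch on made-up toy data; the resulting shapes follow the
attribute descriptions above (binary problem, 2 folds, a grid of 3 Cs,
all other arguments assumed to keep their defaults):

>>> import numpy as np
>>> X = np.random.RandomState(0).rand(20, 2)  # 20 samples, 2 features
>>> y = np.array([0] * 10 + [1] * 10)  # binary labels
>>> clf = LogisticRegressionCV(Cs=3, cv=2).fit(X, y)
>>> clf.C_.shape, clf.scores_[1].shape  # (n_classes,), (n_folds, len(Cs))
((1,), (2, 3))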
1669 | 1670 | See also 1671 | -------- 1672 | LogisticRegression 1673 | 1674 | """ 1675 | def __init__(self, Cs=10, fit_intercept=True, cv='warn', dual=False, 1676 | penalty='l2', scoring=None, solver='lbfgs', tol=1e-4, 1677 | max_iter=100, class_weight=None, n_jobs=None, verbose=0, 1678 | refit=True, intercept_scaling=1., multi_class='warn', 1679 | random_state=None): 1680 | self.Cs = Cs 1681 | self.fit_intercept = fit_intercept 1682 | self.cv = cv 1683 | self.dual = dual 1684 | self.penalty = penalty 1685 | self.scoring = scoring 1686 | self.tol = tol 1687 | self.max_iter = max_iter 1688 | self.class_weight = class_weight 1689 | self.n_jobs = n_jobs 1690 | self.verbose = verbose 1691 | self.solver = solver 1692 | self.refit = refit 1693 | self.intercept_scaling = intercept_scaling 1694 | self.multi_class = multi_class 1695 | self.random_state = random_state 1696 | 1697 | def fit(self, X, y, sample_weight=None): 1698 | """Fit the model according to the given training data. 1699 | 1700 | Parameters 1701 | ---------- 1702 | X : {array-like, sparse matrix}, shape (n_samples, n_features) 1703 | Training vector, where n_samples is the number of samples and 1704 | n_features is the number of features. 1705 | 1706 | y : array-like, shape (n_samples,) 1707 | Target vector relative to X. 1708 | 1709 | sample_weight : array-like, shape (n_samples,) optional 1710 | Array of weights that are assigned to individual samples. 1711 | If not provided, then each sample is given unit weight. 1712 | 1713 | Returns 1714 | ------- 1715 | self : object 1716 | """ 1717 | solver = _check_solver(self.solver, self.penalty, self.dual) 1718 | 1719 | if not isinstance(self.max_iter, numbers.Number) or self.max_iter < 0: 1720 | raise ValueError("Maximum number of iteration must be positive;" 1721 | " got (max_iter=%r)" % self.max_iter) 1722 | if not isinstance(self.tol, numbers.Number) or self.tol < 0: 1723 | raise ValueError("Tolerance for stopping criteria must be " 1724 | "positive; got (tol=%r)" % self.tol) 1725 | 1726 | X, y = check_X_y(X, y, accept_sparse='csr', dtype=np.float64, 1727 | order="C", 1728 | accept_large_sparse=solver != 'liblinear') 1729 | check_classification_targets(y) 1730 | 1731 | class_weight = self.class_weight 1732 | 1733 | # Encode for string labels 1734 | label_encoder = LabelEncoder().fit(y) 1735 | y = label_encoder.transform(y) 1736 | if isinstance(class_weight, dict): 1737 | class_weight = dict((label_encoder.transform([cls])[0], v) 1738 | for cls, v in class_weight.items()) 1739 | 1740 | # The original class labels 1741 | classes = self.classes_ = label_encoder.classes_ 1742 | encoded_labels = label_encoder.transform(label_encoder.classes_) 1743 | 1744 | multi_class = _check_multi_class(self.multi_class, solver, 1745 | len(classes)) 1746 | 1747 | if solver in ['sag', 'saga']: 1748 | max_squared_sum = row_norms(X, squared=True).max() 1749 | else: 1750 | max_squared_sum = None 1751 | 1752 | # init cross-validation generator 1753 | cv = check_cv(self.cv, y, classifier=True) 1754 | folds = list(cv.split(X, y)) 1755 | 1756 | # Use the label encoded classes 1757 | n_classes = len(encoded_labels) 1758 | 1759 | if n_classes < 2: 1760 | raise ValueError("This solver needs samples of at least 2 classes" 1761 | " in the data, but the data contains only one" 1762 | " class: %r" % classes[0]) 1763 | 1764 | if n_classes == 2: 1765 | # OvR in case of binary problems is as good as fitting 1766 | # the higher label 1767 | n_classes = 1 1768 | encoded_labels = encoded_labels[1:] 1769 | classes = 
classes[1:] 1770 | 1771 | # We need this hack to iterate only once over labels, in the case of 1772 | # multi_class = multinomial, without changing the value of the labels. 1773 | if multi_class == 'multinomial': 1774 | iter_encoded_labels = iter_classes = [None] 1775 | else: 1776 | iter_encoded_labels = encoded_labels 1777 | iter_classes = classes 1778 | 1779 | # compute the class weights for the entire dataset y 1780 | if class_weight == "balanced": 1781 | class_weight = compute_class_weight(class_weight, 1782 | np.arange(len(self.classes_)), 1783 | y) 1784 | class_weight = dict(enumerate(class_weight)) 1785 | 1786 | path_func = delayed(_log_reg_scoring_path) 1787 | 1788 | # The SAG solver releases the GIL so it's more efficient to use 1789 | # threads for this solver. 1790 | if self.solver in ['sag', 'saga']: 1791 | prefer = 'threads' 1792 | else: 1793 | prefer = 'processes' 1794 | fold_coefs_ = Parallel(n_jobs=self.n_jobs, verbose=self.verbose, 1795 | **_joblib_parallel_args(prefer=prefer))( 1796 | path_func(X, y, train, test, pos_class=label, Cs=self.Cs, 1797 | fit_intercept=self.fit_intercept, penalty=self.penalty, 1798 | dual=self.dual, solver=solver, tol=self.tol, 1799 | max_iter=self.max_iter, verbose=self.verbose, 1800 | class_weight=class_weight, scoring=self.scoring, 1801 | multi_class=multi_class, 1802 | intercept_scaling=self.intercept_scaling, 1803 | random_state=self.random_state, 1804 | max_squared_sum=max_squared_sum, 1805 | sample_weight=sample_weight 1806 | ) 1807 | for label in iter_encoded_labels 1808 | for train, test in folds) 1809 | 1810 | if multi_class == 'multinomial': 1811 | multi_coefs_paths, Cs, multi_scores, n_iter_ = zip(*fold_coefs_) 1812 | multi_coefs_paths = np.asarray(multi_coefs_paths) 1813 | multi_scores = np.asarray(multi_scores) 1814 | 1815 | # This is just to maintain API similarity between the ovr and 1816 | # multinomial option. 1817 | # Coefs_paths in now n_folds X len(Cs) X n_classes X n_features 1818 | # we need it to be n_classes X len(Cs) X n_folds X n_features 1819 | # to be similar to "ovr". 1820 | coefs_paths = np.rollaxis(multi_coefs_paths, 2, 0) 1821 | 1822 | # Multinomial has a true score across all labels. Hence the 1823 | # shape is n_folds X len(Cs). We need to repeat this score 1824 | # across all labels for API similarity. 1825 | scores = np.tile(multi_scores, (n_classes, 1, 1)) 1826 | self.Cs_ = Cs[0] 1827 | self.n_iter_ = np.reshape(n_iter_, (1, len(folds), 1828 | len(self.Cs_))) 1829 | 1830 | else: 1831 | coefs_paths, Cs, scores, n_iter_ = zip(*fold_coefs_) 1832 | self.Cs_ = Cs[0] 1833 | coefs_paths = np.reshape(coefs_paths, (n_classes, len(folds), 1834 | len(self.Cs_), -1)) 1835 | self.n_iter_ = np.reshape(n_iter_, (n_classes, len(folds), 1836 | len(self.Cs_))) 1837 | 1838 | self.coefs_paths_ = dict(zip(classes, coefs_paths)) 1839 | scores = np.reshape(scores, (n_classes, len(folds), -1)) 1840 | self.scores_ = dict(zip(classes, scores)) 1841 | 1842 | self.C_ = list() 1843 | self.coef_ = np.empty((n_classes, X.shape[1])) 1844 | self.intercept_ = np.zeros(n_classes) 1845 | 1846 | # hack to iterate only once for multinomial case. 
1847 | if multi_class == 'multinomial': 1848 | scores = multi_scores 1849 | coefs_paths = multi_coefs_paths 1850 | 1851 | for index, (cls, encoded_label) in enumerate( 1852 | zip(iter_classes, iter_encoded_labels)): 1853 | 1854 | if multi_class == 'ovr': 1855 | # The scores_ / coefs_paths_ dict have unencoded class 1856 | # labels as their keys 1857 | scores = self.scores_[cls] 1858 | coefs_paths = self.coefs_paths_[cls] 1859 | 1860 | if self.refit: 1861 | best_index = scores.sum(axis=0).argmax() 1862 | 1863 | C_ = self.Cs_[best_index] 1864 | self.C_.append(C_) 1865 | if multi_class == 'multinomial': 1866 | coef_init = np.mean(coefs_paths[:, best_index, :, :], 1867 | axis=0) 1868 | else: 1869 | coef_init = np.mean(coefs_paths[:, best_index, :], axis=0) 1870 | 1871 | # Note that y is label encoded and hence pos_class must be 1872 | # the encoded label / None (for 'multinomial') 1873 | w, _, _ = logistic_regression_path( 1874 | X, y, pos_class=encoded_label, Cs=[C_], solver=solver, 1875 | fit_intercept=self.fit_intercept, coef=coef_init, 1876 | max_iter=self.max_iter, tol=self.tol, 1877 | penalty=self.penalty, 1878 | class_weight=class_weight, 1879 | multi_class=multi_class, 1880 | verbose=max(0, self.verbose - 1), 1881 | random_state=self.random_state, 1882 | check_input=False, max_squared_sum=max_squared_sum, 1883 | sample_weight=sample_weight) 1884 | w = w[0] 1885 | 1886 | else: 1887 | # Take the best scores across every fold and the average of all 1888 | # coefficients corresponding to the best scores. 1889 | best_indices = np.argmax(scores, axis=1) 1890 | w = np.mean([coefs_paths[i][best_indices[i]] 1891 | for i in range(len(folds))], axis=0) 1892 | self.C_.append(np.mean(self.Cs_[best_indices])) 1893 | 1894 | if multi_class == 'multinomial': 1895 | self.C_ = np.tile(self.C_, n_classes) 1896 | self.coef_ = w[:, :X.shape[1]] 1897 | if self.fit_intercept: 1898 | self.intercept_ = w[:, -1] 1899 | else: 1900 | self.coef_[index] = w[: X.shape[1]] 1901 | if self.fit_intercept: 1902 | self.intercept_[index] = w[-1] 1903 | 1904 | self.C_ = np.asarray(self.C_) 1905 | return self 1906 | 1907 | def score(self, X, y, sample_weight=None): 1908 | """Returns the score using the `scoring` option on the given 1909 | test data and labels. 1910 | 1911 | Parameters 1912 | ---------- 1913 | X : array-like, shape = (n_samples, n_features) 1914 | Test samples. 1915 | 1916 | y : array-like, shape = (n_samples,) 1917 | True labels for X. 1918 | 1919 | sample_weight : array-like, shape = [n_samples], optional 1920 | Sample weights. 1921 | 1922 | Returns 1923 | ------- 1924 | score : float 1925 | Score of self.predict(X) wrt. y. 1926 | 1927 | """ 1928 | 1929 | if self.scoring is not None: 1930 | warnings.warn("The long-standing behavior to use the " 1931 | "accuracy score has changed. The scoring " 1932 | "parameter is now used. 
" 1933 | "This warning will disappear in version 0.22.", 1934 | ChangedBehaviorWarning) 1935 | scoring = self.scoring or 'accuracy' 1936 | if isinstance(scoring, six.string_types): 1937 | scoring = get_scorer(scoring) 1938 | 1939 | return scoring(self, X, y, sample_weight=sample_weight) 1940 | -------------------------------------------------------------------------------- /algo/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCLA-StarAI/LogisticCircuit/00c95c689d81c4abbdf20f81a7a29fc2a14fc1e4/algo/__init__.py -------------------------------------------------------------------------------- /balanced.vtree: -------------------------------------------------------------------------------- 1 | c ids of vtree nodes start at 0 2 | c ids of variables start at 1 3 | c vtree nodes appear bottom-up, children before parents 4 | c 5 | c file syntax: 6 | c vtree number-of-nodes-in-vtree 7 | c L id-of-leaf-vtree-node id-of-variable 8 | c I id-of-internal-vtree-node id-of-left-child id-of-right-child 9 | c 10 | vtree 1567 11 | L 0 1 12 | L 1 2 13 | L 2 3 14 | L 3 4 15 | L 4 5 16 | L 5 6 17 | L 6 7 18 | L 7 8 19 | L 8 9 20 | L 9 10 21 | L 10 11 22 | L 11 12 23 | L 12 13 24 | L 13 14 25 | L 14 15 26 | L 15 16 27 | L 16 17 28 | L 17 18 29 | L 18 19 30 | L 19 20 31 | L 20 21 32 | L 21 22 33 | L 22 23 34 | L 23 24 35 | L 24 25 36 | L 25 26 37 | L 26 27 38 | L 27 28 39 | L 28 29 40 | L 29 30 41 | L 30 31 42 | L 31 32 43 | L 32 33 44 | L 33 34 45 | L 34 35 46 | L 35 36 47 | L 36 37 48 | L 37 38 49 | L 38 39 50 | L 39 40 51 | L 40 41 52 | L 41 42 53 | L 42 43 54 | L 43 44 55 | L 44 45 56 | L 45 46 57 | L 46 47 58 | L 47 48 59 | L 48 49 60 | L 49 50 61 | L 50 51 62 | L 51 52 63 | L 52 53 64 | L 53 54 65 | L 54 55 66 | L 55 56 67 | L 56 57 68 | L 57 58 69 | L 58 59 70 | L 59 60 71 | L 60 61 72 | L 61 62 73 | L 62 63 74 | L 63 64 75 | L 64 65 76 | L 65 66 77 | L 66 67 78 | L 67 68 79 | L 68 69 80 | L 69 70 81 | L 70 71 82 | L 71 72 83 | L 72 73 84 | L 73 74 85 | L 74 75 86 | L 75 76 87 | L 76 77 88 | L 77 78 89 | L 78 79 90 | L 79 80 91 | L 80 81 92 | L 81 82 93 | L 82 83 94 | L 83 84 95 | L 84 85 96 | L 85 86 97 | L 86 87 98 | L 87 88 99 | L 88 89 100 | L 89 90 101 | L 90 91 102 | L 91 92 103 | L 92 93 104 | L 93 94 105 | L 94 95 106 | L 95 96 107 | L 96 97 108 | L 97 98 109 | L 98 99 110 | L 99 100 111 | L 100 101 112 | L 101 102 113 | L 102 103 114 | L 103 104 115 | L 104 105 116 | L 105 106 117 | L 106 107 118 | L 107 108 119 | L 108 109 120 | L 109 110 121 | L 110 111 122 | L 111 112 123 | L 112 113 124 | L 113 114 125 | L 114 115 126 | L 115 116 127 | L 116 117 128 | L 117 118 129 | L 118 119 130 | L 119 120 131 | L 120 121 132 | L 121 122 133 | L 122 123 134 | L 123 124 135 | L 124 125 136 | L 125 126 137 | L 126 127 138 | L 127 128 139 | L 128 129 140 | L 129 130 141 | L 130 131 142 | L 131 132 143 | L 132 133 144 | L 133 134 145 | L 134 135 146 | L 135 136 147 | L 136 137 148 | L 137 138 149 | L 138 139 150 | L 139 140 151 | L 140 141 152 | L 141 142 153 | L 142 143 154 | L 143 144 155 | L 144 145 156 | L 145 146 157 | L 146 147 158 | L 147 148 159 | L 148 149 160 | L 149 150 161 | L 150 151 162 | L 151 152 163 | L 152 153 164 | L 153 154 165 | L 154 155 166 | L 155 156 167 | L 156 157 168 | L 157 158 169 | L 158 159 170 | L 159 160 171 | L 160 161 172 | L 161 162 173 | L 162 163 174 | L 163 164 175 | L 164 165 176 | L 165 166 177 | L 166 167 178 | L 167 168 179 | L 168 169 180 | L 169 170 181 | L 170 171 182 | L 171 172 183 | 
L 172 173 184 | L 173 174 185 | L 174 175 186 | L 175 176 187 | L 176 177 188 | L 177 178 189 | L 178 179 190 | L 179 180 191 | L 180 181 192 | L 181 182 193 | L 182 183 194 | L 183 184 195 | L 184 185 196 | L 185 186 197 | L 186 187 198 | L 187 188 199 | L 188 189 200 | L 189 190 201 | L 190 191 202 | L 191 192 203 | L 192 193 204 | L 193 194 205 | L 194 195 206 | L 195 196 207 | L 196 197 208 | L 197 198 209 | L 198 199 210 | L 199 200 211 | L 200 201 212 | L 201 202 213 | L 202 203 214 | L 203 204 215 | L 204 205 216 | L 205 206 217 | L 206 207 218 | L 207 208 219 | L 208 209 220 | L 209 210 221 | L 210 211 222 | L 211 212 223 | L 212 213 224 | L 213 214 225 | L 214 215 226 | L 215 216 227 | L 216 217 228 | L 217 218 229 | L 218 219 230 | L 219 220 231 | L 220 221 232 | L 221 222 233 | L 222 223 234 | L 223 224 235 | L 224 225 236 | L 225 226 237 | L 226 227 238 | L 227 228 239 | L 228 229 240 | L 229 230 241 | L 230 231 242 | L 231 232 243 | L 232 233 244 | L 233 234 245 | L 234 235 246 | L 235 236 247 | L 236 237 248 | L 237 238 249 | L 238 239 250 | L 239 240 251 | L 240 241 252 | L 241 242 253 | L 242 243 254 | L 243 244 255 | L 244 245 256 | L 245 246 257 | L 246 247 258 | L 247 248 259 | L 248 249 260 | L 249 250 261 | L 250 251 262 | L 251 252 263 | L 252 253 264 | L 253 254 265 | L 254 255 266 | L 255 256 267 | L 256 257 268 | L 257 258 269 | L 258 259 270 | L 259 260 271 | L 260 261 272 | L 261 262 273 | L 262 263 274 | L 263 264 275 | L 264 265 276 | L 265 266 277 | L 266 267 278 | L 267 268 279 | L 268 269 280 | L 269 270 281 | L 270 271 282 | L 271 272 283 | L 272 273 284 | L 273 274 285 | L 274 275 286 | L 275 276 287 | L 276 277 288 | L 277 278 289 | L 278 279 290 | L 279 280 291 | L 280 281 292 | L 281 282 293 | L 282 283 294 | L 283 284 295 | L 284 285 296 | L 285 286 297 | L 286 287 298 | L 287 288 299 | L 288 289 300 | L 289 290 301 | L 290 291 302 | L 291 292 303 | L 292 293 304 | L 293 294 305 | L 294 295 306 | L 295 296 307 | L 296 297 308 | L 297 298 309 | L 298 299 310 | L 299 300 311 | L 300 301 312 | L 301 302 313 | L 302 303 314 | L 303 304 315 | L 304 305 316 | L 305 306 317 | L 306 307 318 | L 307 308 319 | L 308 309 320 | L 309 310 321 | L 310 311 322 | L 311 312 323 | L 312 313 324 | L 313 314 325 | L 314 315 326 | L 315 316 327 | L 316 317 328 | L 317 318 329 | L 318 319 330 | L 319 320 331 | L 320 321 332 | L 321 322 333 | L 322 323 334 | L 323 324 335 | L 324 325 336 | L 325 326 337 | L 326 327 338 | L 327 328 339 | L 328 329 340 | L 329 330 341 | L 330 331 342 | L 331 332 343 | L 332 333 344 | L 333 334 345 | L 334 335 346 | L 335 336 347 | L 336 337 348 | L 337 338 349 | L 338 339 350 | L 339 340 351 | L 340 341 352 | L 341 342 353 | L 342 343 354 | L 343 344 355 | L 344 345 356 | L 345 346 357 | L 346 347 358 | L 347 348 359 | L 348 349 360 | L 349 350 361 | L 350 351 362 | L 351 352 363 | L 352 353 364 | L 353 354 365 | L 354 355 366 | L 355 356 367 | L 356 357 368 | L 357 358 369 | L 358 359 370 | L 359 360 371 | L 360 361 372 | L 361 362 373 | L 362 363 374 | L 363 364 375 | L 364 365 376 | L 365 366 377 | L 366 367 378 | L 367 368 379 | L 368 369 380 | L 369 370 381 | L 370 371 382 | L 371 372 383 | L 372 373 384 | L 373 374 385 | L 374 375 386 | L 375 376 387 | L 376 377 388 | L 377 378 389 | L 378 379 390 | L 379 380 391 | L 380 381 392 | L 381 382 393 | L 382 383 394 | L 383 384 395 | L 384 385 396 | L 385 386 397 | L 386 387 398 | L 387 388 399 | L 388 389 400 | L 389 390 401 | L 390 391 402 | L 391 392 403 | L 392 393 404 | L 393 394 405 | L 
394 395 406 | L 395 396 407 | L 396 397 408 | L 397 398 409 | L 398 399 410 | L 399 400 411 | L 400 401 412 | L 401 402 413 | L 402 403 414 | L 403 404 415 | L 404 405 416 | L 405 406 417 | L 406 407 418 | L 407 408 419 | L 408 409 420 | L 409 410 421 | L 410 411 422 | L 411 412 423 | L 412 413 424 | L 413 414 425 | L 414 415 426 | L 415 416 427 | L 416 417 428 | L 417 418 429 | L 418 419 430 | L 419 420 431 | L 420 421 432 | L 421 422 433 | L 422 423 434 | L 423 424 435 | L 424 425 436 | L 425 426 437 | L 426 427 438 | L 427 428 439 | L 428 429 440 | L 429 430 441 | L 430 431 442 | L 431 432 443 | L 432 433 444 | L 433 434 445 | L 434 435 446 | L 435 436 447 | L 436 437 448 | L 437 438 449 | L 438 439 450 | L 439 440 451 | L 440 441 452 | L 441 442 453 | L 442 443 454 | L 443 444 455 | L 444 445 456 | L 445 446 457 | L 446 447 458 | L 447 448 459 | L 448 449 460 | L 449 450 461 | L 450 451 462 | L 451 452 463 | L 452 453 464 | L 453 454 465 | L 454 455 466 | L 455 456 467 | L 456 457 468 | L 457 458 469 | L 458 459 470 | L 459 460 471 | L 460 461 472 | L 461 462 473 | L 462 463 474 | L 463 464 475 | L 464 465 476 | L 465 466 477 | L 466 467 478 | L 467 468 479 | L 468 469 480 | L 469 470 481 | L 470 471 482 | L 471 472 483 | L 472 473 484 | L 473 474 485 | L 474 475 486 | L 475 476 487 | L 476 477 488 | L 477 478 489 | L 478 479 490 | L 479 480 491 | L 480 481 492 | L 481 482 493 | L 482 483 494 | L 483 484 495 | L 484 485 496 | L 485 486 497 | L 486 487 498 | L 487 488 499 | L 488 489 500 | L 489 490 501 | L 490 491 502 | L 491 492 503 | L 492 493 504 | L 493 494 505 | L 494 495 506 | L 495 496 507 | L 496 497 508 | L 497 498 509 | L 498 499 510 | L 499 500 511 | L 500 501 512 | L 501 502 513 | L 502 503 514 | L 503 504 515 | L 504 505 516 | L 505 506 517 | L 506 507 518 | L 507 508 519 | L 508 509 520 | L 509 510 521 | L 510 511 522 | L 511 512 523 | L 512 513 524 | L 513 514 525 | L 514 515 526 | L 515 516 527 | L 516 517 528 | L 517 518 529 | L 518 519 530 | L 519 520 531 | L 520 521 532 | L 521 522 533 | L 522 523 534 | L 523 524 535 | L 524 525 536 | L 525 526 537 | L 526 527 538 | L 527 528 539 | L 528 529 540 | L 529 530 541 | L 530 531 542 | L 531 532 543 | L 532 533 544 | L 533 534 545 | L 534 535 546 | L 535 536 547 | L 536 537 548 | L 537 538 549 | L 538 539 550 | L 539 540 551 | L 540 541 552 | L 541 542 553 | L 542 543 554 | L 543 544 555 | L 544 545 556 | L 545 546 557 | L 546 547 558 | L 547 548 559 | L 548 549 560 | L 549 550 561 | L 550 551 562 | L 551 552 563 | L 552 553 564 | L 553 554 565 | L 554 555 566 | L 555 556 567 | L 556 557 568 | L 557 558 569 | L 558 559 570 | L 559 560 571 | L 560 561 572 | L 561 562 573 | L 562 563 574 | L 563 564 575 | L 564 565 576 | L 565 566 577 | L 566 567 578 | L 567 568 579 | L 568 569 580 | L 569 570 581 | L 570 571 582 | L 571 572 583 | L 572 573 584 | L 573 574 585 | L 574 575 586 | L 575 576 587 | L 576 577 588 | L 577 578 589 | L 578 579 590 | L 579 580 591 | L 580 581 592 | L 581 582 593 | L 582 583 594 | L 583 584 595 | L 584 585 596 | L 585 586 597 | L 586 587 598 | L 587 588 599 | L 588 589 600 | L 589 590 601 | L 590 591 602 | L 591 592 603 | L 592 593 604 | L 593 594 605 | L 594 595 606 | L 595 596 607 | L 596 597 608 | L 597 598 609 | L 598 599 610 | L 599 600 611 | L 600 601 612 | L 601 602 613 | L 602 603 614 | L 603 604 615 | L 604 605 616 | L 605 606 617 | L 606 607 618 | L 607 608 619 | L 608 609 620 | L 609 610 621 | L 610 611 622 | L 611 612 623 | L 612 613 624 | L 613 614 625 | L 614 615 626 | L 615 616 627 | L 
616 617 628 | L 617 618 629 | L 618 619 630 | L 619 620 631 | L 620 621 632 | L 621 622 633 | L 622 623 634 | L 623 624 635 | L 624 625 636 | L 625 626 637 | L 626 627 638 | L 627 628 639 | L 628 629 640 | L 629 630 641 | L 630 631 642 | L 631 632 643 | L 632 633 644 | L 633 634 645 | L 634 635 646 | L 635 636 647 | L 636 637 648 | L 637 638 649 | L 638 639 650 | L 639 640 651 | L 640 641 652 | L 641 642 653 | L 642 643 654 | L 643 644 655 | L 644 645 656 | L 645 646 657 | L 646 647 658 | L 647 648 659 | L 648 649 660 | L 649 650 661 | L 650 651 662 | L 651 652 663 | L 652 653 664 | L 653 654 665 | L 654 655 666 | L 655 656 667 | L 656 657 668 | L 657 658 669 | L 658 659 670 | L 659 660 671 | L 660 661 672 | L 661 662 673 | L 662 663 674 | L 663 664 675 | L 664 665 676 | L 665 666 677 | L 666 667 678 | L 667 668 679 | L 668 669 680 | L 669 670 681 | L 670 671 682 | L 671 672 683 | L 672 673 684 | L 673 674 685 | L 674 675 686 | L 675 676 687 | L 676 677 688 | L 677 678 689 | L 678 679 690 | L 679 680 691 | L 680 681 692 | L 681 682 693 | L 682 683 694 | L 683 684 695 | L 684 685 696 | L 685 686 697 | L 686 687 698 | L 687 688 699 | L 688 689 700 | L 689 690 701 | L 690 691 702 | L 691 692 703 | L 692 693 704 | L 693 694 705 | L 694 695 706 | L 695 696 707 | L 696 697 708 | L 697 698 709 | L 698 699 710 | L 699 700 711 | L 700 701 712 | L 701 702 713 | L 702 703 714 | L 703 704 715 | L 704 705 716 | L 705 706 717 | L 706 707 718 | L 707 708 719 | L 708 709 720 | L 709 710 721 | L 710 711 722 | L 711 712 723 | L 712 713 724 | L 713 714 725 | L 714 715 726 | L 715 716 727 | L 716 717 728 | L 717 718 729 | L 718 719 730 | L 719 720 731 | L 720 721 732 | L 721 722 733 | L 722 723 734 | L 723 724 735 | L 724 725 736 | L 725 726 737 | L 726 727 738 | L 727 728 739 | L 728 729 740 | L 729 730 741 | L 730 731 742 | L 731 732 743 | L 732 733 744 | L 733 734 745 | L 734 735 746 | L 735 736 747 | L 736 737 748 | L 737 738 749 | L 738 739 750 | L 739 740 751 | L 740 741 752 | L 741 742 753 | L 742 743 754 | L 743 744 755 | L 744 745 756 | L 745 746 757 | L 746 747 758 | L 747 748 759 | L 748 749 760 | L 749 750 761 | L 750 751 762 | L 751 752 763 | L 752 753 764 | L 753 754 765 | L 754 755 766 | L 755 756 767 | L 756 757 768 | L 757 758 769 | L 758 759 770 | L 759 760 771 | L 760 761 772 | L 761 762 773 | L 762 763 774 | L 763 764 775 | L 764 765 776 | L 765 766 777 | L 766 767 778 | L 767 768 779 | L 768 769 780 | L 769 770 781 | L 770 771 782 | L 771 772 783 | L 772 773 784 | L 773 774 785 | L 774 775 786 | L 775 776 787 | L 776 777 788 | L 777 778 789 | L 778 779 790 | L 779 780 791 | L 780 781 792 | L 781 782 793 | L 782 783 794 | L 783 784 795 | I 784 0 1 796 | I 785 2 3 797 | I 786 4 5 798 | I 787 6 7 799 | I 788 8 9 800 | I 789 10 11 801 | I 790 12 13 802 | I 791 14 15 803 | I 792 16 17 804 | I 793 18 19 805 | I 794 20 21 806 | I 795 22 23 807 | I 796 24 25 808 | I 797 26 27 809 | I 798 28 29 810 | I 799 30 31 811 | I 800 32 33 812 | I 801 34 35 813 | I 802 36 37 814 | I 803 38 39 815 | I 804 40 41 816 | I 805 42 43 817 | I 806 44 45 818 | I 807 46 47 819 | I 808 48 49 820 | I 809 50 51 821 | I 810 52 53 822 | I 811 54 55 823 | I 812 56 57 824 | I 813 58 59 825 | I 814 60 61 826 | I 815 62 63 827 | I 816 64 65 828 | I 817 66 67 829 | I 818 68 69 830 | I 819 70 71 831 | I 820 72 73 832 | I 821 74 75 833 | I 822 76 77 834 | I 823 78 79 835 | I 824 80 81 836 | I 825 82 83 837 | I 826 84 85 838 | I 827 86 87 839 | I 828 88 89 840 | I 829 90 91 841 | I 830 92 93 842 | I 831 94 95 843 | I 832 96 97 
844 | I 833 98 99 845 | I 834 100 101 846 | I 835 102 103 847 | I 836 104 105 848 | I 837 106 107 849 | I 838 108 109 850 | I 839 110 111 851 | I 840 112 113 852 | I 841 114 115 853 | I 842 116 117 854 | I 843 118 119 855 | I 844 120 121 856 | I 845 122 123 857 | I 846 124 125 858 | I 847 126 127 859 | I 848 128 129 860 | I 849 130 131 861 | I 850 132 133 862 | I 851 134 135 863 | I 852 136 137 864 | I 853 138 139 865 | I 854 140 141 866 | I 855 142 143 867 | I 856 144 145 868 | I 857 146 147 869 | I 858 148 149 870 | I 859 150 151 871 | I 860 152 153 872 | I 861 154 155 873 | I 862 156 157 874 | I 863 158 159 875 | I 864 160 161 876 | I 865 162 163 877 | I 866 164 165 878 | I 867 166 167 879 | I 868 168 169 880 | I 869 170 171 881 | I 870 172 173 882 | I 871 174 175 883 | I 872 176 177 884 | I 873 178 179 885 | I 874 180 181 886 | I 875 182 183 887 | I 876 184 185 888 | I 877 186 187 889 | I 878 188 189 890 | I 879 190 191 891 | I 880 192 193 892 | I 881 194 195 893 | I 882 196 197 894 | I 883 198 199 895 | I 884 200 201 896 | I 885 202 203 897 | I 886 204 205 898 | I 887 206 207 899 | I 888 208 209 900 | I 889 210 211 901 | I 890 212 213 902 | I 891 214 215 903 | I 892 216 217 904 | I 893 218 219 905 | I 894 220 221 906 | I 895 222 223 907 | I 896 224 225 908 | I 897 226 227 909 | I 898 228 229 910 | I 899 230 231 911 | I 900 232 233 912 | I 901 234 235 913 | I 902 236 237 914 | I 903 238 239 915 | I 904 240 241 916 | I 905 242 243 917 | I 906 244 245 918 | I 907 246 247 919 | I 908 248 249 920 | I 909 250 251 921 | I 910 252 253 922 | I 911 254 255 923 | I 912 256 257 924 | I 913 258 259 925 | I 914 260 261 926 | I 915 262 263 927 | I 916 264 265 928 | I 917 266 267 929 | I 918 268 269 930 | I 919 270 271 931 | I 920 272 273 932 | I 921 274 275 933 | I 922 276 277 934 | I 923 278 279 935 | I 924 280 281 936 | I 925 282 283 937 | I 926 284 285 938 | I 927 286 287 939 | I 928 288 289 940 | I 929 290 291 941 | I 930 292 293 942 | I 931 294 295 943 | I 932 296 297 944 | I 933 298 299 945 | I 934 300 301 946 | I 935 302 303 947 | I 936 304 305 948 | I 937 306 307 949 | I 938 308 309 950 | I 939 310 311 951 | I 940 312 313 952 | I 941 314 315 953 | I 942 316 317 954 | I 943 318 319 955 | I 944 320 321 956 | I 945 322 323 957 | I 946 324 325 958 | I 947 326 327 959 | I 948 328 329 960 | I 949 330 331 961 | I 950 332 333 962 | I 951 334 335 963 | I 952 336 337 964 | I 953 338 339 965 | I 954 340 341 966 | I 955 342 343 967 | I 956 344 345 968 | I 957 346 347 969 | I 958 348 349 970 | I 959 350 351 971 | I 960 352 353 972 | I 961 354 355 973 | I 962 356 357 974 | I 963 358 359 975 | I 964 360 361 976 | I 965 362 363 977 | I 966 364 365 978 | I 967 366 367 979 | I 968 368 369 980 | I 969 370 371 981 | I 970 372 373 982 | I 971 374 375 983 | I 972 376 377 984 | I 973 378 379 985 | I 974 380 381 986 | I 975 382 383 987 | I 976 384 385 988 | I 977 386 387 989 | I 978 388 389 990 | I 979 390 391 991 | I 980 392 393 992 | I 981 394 395 993 | I 982 396 397 994 | I 983 398 399 995 | I 984 400 401 996 | I 985 402 403 997 | I 986 404 405 998 | I 987 406 407 999 | I 988 408 409 1000 | I 989 410 411 1001 | I 990 412 413 1002 | I 991 414 415 1003 | I 992 416 417 1004 | I 993 418 419 1005 | I 994 420 421 1006 | I 995 422 423 1007 | I 996 424 425 1008 | I 997 426 427 1009 | I 998 428 429 1010 | I 999 430 431 1011 | I 1000 432 433 1012 | I 1001 434 435 1013 | I 1002 436 437 1014 | I 1003 438 439 1015 | I 1004 440 441 1016 | I 1005 442 443 1017 | I 1006 444 445 1018 | I 1007 446 447 1019 | I 1008 448 449 1020 | 
I 1009 450 451 1021 | I 1010 452 453 1022 | I 1011 454 455 1023 | I 1012 456 457 1024 | I 1013 458 459 1025 | I 1014 460 461 1026 | I 1015 462 463 1027 | I 1016 464 465 1028 | I 1017 466 467 1029 | I 1018 468 469 1030 | I 1019 470 471 1031 | I 1020 472 473 1032 | I 1021 474 475 1033 | I 1022 476 477 1034 | I 1023 478 479 1035 | I 1024 480 481 1036 | I 1025 482 483 1037 | I 1026 484 485 1038 | I 1027 486 487 1039 | I 1028 488 489 1040 | I 1029 490 491 1041 | I 1030 492 493 1042 | I 1031 494 495 1043 | I 1032 496 497 1044 | I 1033 498 499 1045 | I 1034 500 501 1046 | I 1035 502 503 1047 | I 1036 504 505 1048 | I 1037 506 507 1049 | I 1038 508 509 1050 | I 1039 510 511 1051 | I 1040 512 513 1052 | I 1041 514 515 1053 | I 1042 516 517 1054 | I 1043 518 519 1055 | I 1044 520 521 1056 | I 1045 522 523 1057 | I 1046 524 525 1058 | I 1047 526 527 1059 | I 1048 528 529 1060 | I 1049 530 531 1061 | I 1050 532 533 1062 | I 1051 534 535 1063 | I 1052 536 537 1064 | I 1053 538 539 1065 | I 1054 540 541 1066 | I 1055 542 543 1067 | I 1056 544 545 1068 | I 1057 546 547 1069 | I 1058 548 549 1070 | I 1059 550 551 1071 | I 1060 552 553 1072 | I 1061 554 555 1073 | I 1062 556 557 1074 | I 1063 558 559 1075 | I 1064 560 561 1076 | I 1065 562 563 1077 | I 1066 564 565 1078 | I 1067 566 567 1079 | I 1068 568 569 1080 | I 1069 570 571 1081 | I 1070 572 573 1082 | I 1071 574 575 1083 | I 1072 576 577 1084 | I 1073 578 579 1085 | I 1074 580 581 1086 | I 1075 582 583 1087 | I 1076 584 585 1088 | I 1077 586 587 1089 | I 1078 588 589 1090 | I 1079 590 591 1091 | I 1080 592 593 1092 | I 1081 594 595 1093 | I 1082 596 597 1094 | I 1083 598 599 1095 | I 1084 600 601 1096 | I 1085 602 603 1097 | I 1086 604 605 1098 | I 1087 606 607 1099 | I 1088 608 609 1100 | I 1089 610 611 1101 | I 1090 612 613 1102 | I 1091 614 615 1103 | I 1092 616 617 1104 | I 1093 618 619 1105 | I 1094 620 621 1106 | I 1095 622 623 1107 | I 1096 624 625 1108 | I 1097 626 627 1109 | I 1098 628 629 1110 | I 1099 630 631 1111 | I 1100 632 633 1112 | I 1101 634 635 1113 | I 1102 636 637 1114 | I 1103 638 639 1115 | I 1104 640 641 1116 | I 1105 642 643 1117 | I 1106 644 645 1118 | I 1107 646 647 1119 | I 1108 648 649 1120 | I 1109 650 651 1121 | I 1110 652 653 1122 | I 1111 654 655 1123 | I 1112 656 657 1124 | I 1113 658 659 1125 | I 1114 660 661 1126 | I 1115 662 663 1127 | I 1116 664 665 1128 | I 1117 666 667 1129 | I 1118 668 669 1130 | I 1119 670 671 1131 | I 1120 672 673 1132 | I 1121 674 675 1133 | I 1122 676 677 1134 | I 1123 678 679 1135 | I 1124 680 681 1136 | I 1125 682 683 1137 | I 1126 684 685 1138 | I 1127 686 687 1139 | I 1128 688 689 1140 | I 1129 690 691 1141 | I 1130 692 693 1142 | I 1131 694 695 1143 | I 1132 696 697 1144 | I 1133 698 699 1145 | I 1134 700 701 1146 | I 1135 702 703 1147 | I 1136 704 705 1148 | I 1137 706 707 1149 | I 1138 708 709 1150 | I 1139 710 711 1151 | I 1140 712 713 1152 | I 1141 714 715 1153 | I 1142 716 717 1154 | I 1143 718 719 1155 | I 1144 720 721 1156 | I 1145 722 723 1157 | I 1146 724 725 1158 | I 1147 726 727 1159 | I 1148 728 729 1160 | I 1149 730 731 1161 | I 1150 732 733 1162 | I 1151 734 735 1163 | I 1152 736 737 1164 | I 1153 738 739 1165 | I 1154 740 741 1166 | I 1155 742 743 1167 | I 1156 744 745 1168 | I 1157 746 747 1169 | I 1158 748 749 1170 | I 1159 750 751 1171 | I 1160 752 753 1172 | I 1161 754 755 1173 | I 1162 756 757 1174 | I 1163 758 759 1175 | I 1164 760 761 1176 | I 1165 762 763 1177 | I 1166 764 765 1178 | I 1167 766 767 1179 | I 1168 768 769 1180 | I 1169 770 771 1181 | I 1170 772 
773 1182 | I 1171 774 775 1183 | I 1172 776 777 1184 | I 1173 778 779 1185 | I 1174 780 781 1186 | I 1175 782 783 1187 | I 1176 784 798 1188 | I 1177 785 799 1189 | I 1178 786 800 1190 | I 1179 787 801 1191 | I 1180 788 802 1192 | I 1181 789 803 1193 | I 1182 790 804 1194 | I 1183 791 805 1195 | I 1184 792 806 1196 | I 1185 793 807 1197 | I 1186 794 808 1198 | I 1187 795 809 1199 | I 1188 796 810 1200 | I 1189 797 811 1201 | I 1190 812 826 1202 | I 1191 813 827 1203 | I 1192 814 828 1204 | I 1193 815 829 1205 | I 1194 816 830 1206 | I 1195 817 831 1207 | I 1196 818 832 1208 | I 1197 819 833 1209 | I 1198 820 834 1210 | I 1199 821 835 1211 | I 1200 822 836 1212 | I 1201 823 837 1213 | I 1202 824 838 1214 | I 1203 825 839 1215 | I 1204 840 854 1216 | I 1205 841 855 1217 | I 1206 842 856 1218 | I 1207 843 857 1219 | I 1208 844 858 1220 | I 1209 845 859 1221 | I 1210 846 860 1222 | I 1211 847 861 1223 | I 1212 848 862 1224 | I 1213 849 863 1225 | I 1214 850 864 1226 | I 1215 851 865 1227 | I 1216 852 866 1228 | I 1217 853 867 1229 | I 1218 868 882 1230 | I 1219 869 883 1231 | I 1220 870 884 1232 | I 1221 871 885 1233 | I 1222 872 886 1234 | I 1223 873 887 1235 | I 1224 874 888 1236 | I 1225 875 889 1237 | I 1226 876 890 1238 | I 1227 877 891 1239 | I 1228 878 892 1240 | I 1229 879 893 1241 | I 1230 880 894 1242 | I 1231 881 895 1243 | I 1232 896 910 1244 | I 1233 897 911 1245 | I 1234 898 912 1246 | I 1235 899 913 1247 | I 1236 900 914 1248 | I 1237 901 915 1249 | I 1238 902 916 1250 | I 1239 903 917 1251 | I 1240 904 918 1252 | I 1241 905 919 1253 | I 1242 906 920 1254 | I 1243 907 921 1255 | I 1244 908 922 1256 | I 1245 909 923 1257 | I 1246 924 938 1258 | I 1247 925 939 1259 | I 1248 926 940 1260 | I 1249 927 941 1261 | I 1250 928 942 1262 | I 1251 929 943 1263 | I 1252 930 944 1264 | I 1253 931 945 1265 | I 1254 932 946 1266 | I 1255 933 947 1267 | I 1256 934 948 1268 | I 1257 935 949 1269 | I 1258 936 950 1270 | I 1259 937 951 1271 | I 1260 952 966 1272 | I 1261 953 967 1273 | I 1262 954 968 1274 | I 1263 955 969 1275 | I 1264 956 970 1276 | I 1265 957 971 1277 | I 1266 958 972 1278 | I 1267 959 973 1279 | I 1268 960 974 1280 | I 1269 961 975 1281 | I 1270 962 976 1282 | I 1271 963 977 1283 | I 1272 964 978 1284 | I 1273 965 979 1285 | I 1274 980 994 1286 | I 1275 981 995 1287 | I 1276 982 996 1288 | I 1277 983 997 1289 | I 1278 984 998 1290 | I 1279 985 999 1291 | I 1280 986 1000 1292 | I 1281 987 1001 1293 | I 1282 988 1002 1294 | I 1283 989 1003 1295 | I 1284 990 1004 1296 | I 1285 991 1005 1297 | I 1286 992 1006 1298 | I 1287 993 1007 1299 | I 1288 1008 1022 1300 | I 1289 1009 1023 1301 | I 1290 1010 1024 1302 | I 1291 1011 1025 1303 | I 1292 1012 1026 1304 | I 1293 1013 1027 1305 | I 1294 1014 1028 1306 | I 1295 1015 1029 1307 | I 1296 1016 1030 1308 | I 1297 1017 1031 1309 | I 1298 1018 1032 1310 | I 1299 1019 1033 1311 | I 1300 1020 1034 1312 | I 1301 1021 1035 1313 | I 1302 1036 1050 1314 | I 1303 1037 1051 1315 | I 1304 1038 1052 1316 | I 1305 1039 1053 1317 | I 1306 1040 1054 1318 | I 1307 1041 1055 1319 | I 1308 1042 1056 1320 | I 1309 1043 1057 1321 | I 1310 1044 1058 1322 | I 1311 1045 1059 1323 | I 1312 1046 1060 1324 | I 1313 1047 1061 1325 | I 1314 1048 1062 1326 | I 1315 1049 1063 1327 | I 1316 1064 1078 1328 | I 1317 1065 1079 1329 | I 1318 1066 1080 1330 | I 1319 1067 1081 1331 | I 1320 1068 1082 1332 | I 1321 1069 1083 1333 | I 1322 1070 1084 1334 | I 1323 1071 1085 1335 | I 1324 1072 1086 1336 | I 1325 1073 1087 1337 | I 1326 1074 1088 1338 | I 1327 1075 1089 1339 | I 
1328 1076 1090 1340 | I 1329 1077 1091 1341 | I 1330 1092 1106 1342 | I 1331 1093 1107 1343 | I 1332 1094 1108 1344 | I 1333 1095 1109 1345 | I 1334 1096 1110 1346 | I 1335 1097 1111 1347 | I 1336 1098 1112 1348 | I 1337 1099 1113 1349 | I 1338 1100 1114 1350 | I 1339 1101 1115 1351 | I 1340 1102 1116 1352 | I 1341 1103 1117 1353 | I 1342 1104 1118 1354 | I 1343 1105 1119 1355 | I 1344 1120 1134 1356 | I 1345 1121 1135 1357 | I 1346 1122 1136 1358 | I 1347 1123 1137 1359 | I 1348 1124 1138 1360 | I 1349 1125 1139 1361 | I 1350 1126 1140 1362 | I 1351 1127 1141 1363 | I 1352 1128 1142 1364 | I 1353 1129 1143 1365 | I 1354 1130 1144 1366 | I 1355 1131 1145 1367 | I 1356 1132 1146 1368 | I 1357 1133 1147 1369 | I 1358 1148 1162 1370 | I 1359 1149 1163 1371 | I 1360 1150 1164 1372 | I 1361 1151 1165 1373 | I 1362 1152 1166 1374 | I 1363 1153 1167 1375 | I 1364 1154 1168 1376 | I 1365 1155 1169 1377 | I 1366 1156 1170 1378 | I 1367 1157 1171 1379 | I 1368 1158 1172 1380 | I 1369 1159 1173 1381 | I 1370 1160 1174 1382 | I 1371 1161 1175 1383 | I 1372 1176 1177 1384 | I 1373 1178 1179 1385 | I 1374 1180 1181 1386 | I 1375 1182 1183 1387 | I 1376 1184 1185 1388 | I 1377 1186 1187 1389 | I 1378 1188 1189 1390 | I 1379 1190 1191 1391 | I 1380 1192 1193 1392 | I 1381 1194 1195 1393 | I 1382 1196 1197 1394 | I 1383 1198 1199 1395 | I 1384 1200 1201 1396 | I 1385 1202 1203 1397 | I 1386 1204 1205 1398 | I 1387 1206 1207 1399 | I 1388 1208 1209 1400 | I 1389 1210 1211 1401 | I 1390 1212 1213 1402 | I 1391 1214 1215 1403 | I 1392 1216 1217 1404 | I 1393 1218 1219 1405 | I 1394 1220 1221 1406 | I 1395 1222 1223 1407 | I 1396 1224 1225 1408 | I 1397 1226 1227 1409 | I 1398 1228 1229 1410 | I 1399 1230 1231 1411 | I 1400 1232 1233 1412 | I 1401 1234 1235 1413 | I 1402 1236 1237 1414 | I 1403 1238 1239 1415 | I 1404 1240 1241 1416 | I 1405 1242 1243 1417 | I 1406 1244 1245 1418 | I 1407 1246 1247 1419 | I 1408 1248 1249 1420 | I 1409 1250 1251 1421 | I 1410 1252 1253 1422 | I 1411 1254 1255 1423 | I 1412 1256 1257 1424 | I 1413 1258 1259 1425 | I 1414 1260 1261 1426 | I 1415 1262 1263 1427 | I 1416 1264 1265 1428 | I 1417 1266 1267 1429 | I 1418 1268 1269 1430 | I 1419 1270 1271 1431 | I 1420 1272 1273 1432 | I 1421 1274 1275 1433 | I 1422 1276 1277 1434 | I 1423 1278 1279 1435 | I 1424 1280 1281 1436 | I 1425 1282 1283 1437 | I 1426 1284 1285 1438 | I 1427 1286 1287 1439 | I 1428 1288 1289 1440 | I 1429 1290 1291 1441 | I 1430 1292 1293 1442 | I 1431 1294 1295 1443 | I 1432 1296 1297 1444 | I 1433 1298 1299 1445 | I 1434 1300 1301 1446 | I 1435 1302 1303 1447 | I 1436 1304 1305 1448 | I 1437 1306 1307 1449 | I 1438 1308 1309 1450 | I 1439 1310 1311 1451 | I 1440 1312 1313 1452 | I 1441 1314 1315 1453 | I 1442 1316 1317 1454 | I 1443 1318 1319 1455 | I 1444 1320 1321 1456 | I 1445 1322 1323 1457 | I 1446 1324 1325 1458 | I 1447 1326 1327 1459 | I 1448 1328 1329 1460 | I 1449 1330 1331 1461 | I 1450 1332 1333 1462 | I 1451 1334 1335 1463 | I 1452 1336 1337 1464 | I 1453 1338 1339 1465 | I 1454 1340 1341 1466 | I 1455 1342 1343 1467 | I 1456 1344 1345 1468 | I 1457 1346 1347 1469 | I 1458 1348 1349 1470 | I 1459 1350 1351 1471 | I 1460 1352 1353 1472 | I 1461 1354 1355 1473 | I 1462 1356 1357 1474 | I 1463 1358 1359 1475 | I 1464 1360 1361 1476 | I 1465 1362 1363 1477 | I 1466 1364 1365 1478 | I 1467 1366 1367 1479 | I 1468 1368 1369 1480 | I 1469 1370 1371 1481 | I 1470 1372 1379 1482 | I 1471 1373 1380 1483 | I 1472 1374 1381 1484 | I 1473 1375 1382 1485 | I 1474 1376 1383 1486 | I 1475 1377 1384 1487 | I 
1476 1378 1385 1488 | I 1477 1386 1393 1489 | I 1478 1387 1394 1490 | I 1479 1388 1395 1491 | I 1480 1389 1396 1492 | I 1481 1390 1397 1493 | I 1482 1391 1398 1494 | I 1483 1392 1399 1495 | I 1484 1400 1407 1496 | I 1485 1401 1408 1497 | I 1486 1402 1409 1498 | I 1487 1403 1410 1499 | I 1488 1404 1411 1500 | I 1489 1405 1412 1501 | I 1490 1406 1413 1502 | I 1491 1414 1421 1503 | I 1492 1415 1422 1504 | I 1493 1416 1423 1505 | I 1494 1417 1424 1506 | I 1495 1418 1425 1507 | I 1496 1419 1426 1508 | I 1497 1420 1427 1509 | I 1498 1428 1435 1510 | I 1499 1429 1436 1511 | I 1500 1430 1437 1512 | I 1501 1431 1438 1513 | I 1502 1432 1439 1514 | I 1503 1433 1440 1515 | I 1504 1434 1441 1516 | I 1505 1442 1449 1517 | I 1506 1443 1450 1518 | I 1507 1444 1451 1519 | I 1508 1445 1452 1520 | I 1509 1446 1453 1521 | I 1510 1447 1454 1522 | I 1511 1448 1455 1523 | I 1512 1456 1463 1524 | I 1513 1457 1464 1525 | I 1514 1458 1465 1526 | I 1515 1459 1466 1527 | I 1516 1460 1467 1528 | I 1517 1461 1468 1529 | I 1518 1462 1469 1530 | I 1519 1470 1471 1531 | I 1520 1472 1473 1532 | I 1521 1474 1475 1533 | I 1522 1477 1478 1534 | I 1523 1479 1480 1535 | I 1524 1481 1482 1536 | I 1525 1484 1485 1537 | I 1526 1486 1487 1538 | I 1527 1488 1489 1539 | I 1528 1491 1492 1540 | I 1529 1493 1494 1541 | I 1530 1495 1496 1542 | I 1531 1498 1499 1543 | I 1532 1500 1501 1544 | I 1533 1502 1503 1545 | I 1534 1505 1506 1546 | I 1535 1507 1508 1547 | I 1536 1509 1510 1548 | I 1537 1512 1513 1549 | I 1538 1514 1515 1550 | I 1539 1516 1517 1551 | I 1540 1519 1522 1552 | I 1541 1520 1523 1553 | I 1542 1521 1524 1554 | I 1543 1476 1483 1555 | I 1544 1525 1528 1556 | I 1545 1526 1529 1557 | I 1546 1527 1530 1558 | I 1547 1490 1497 1559 | I 1548 1531 1534 1560 | I 1549 1532 1535 1561 | I 1550 1533 1536 1562 | I 1551 1504 1511 1563 | I 1552 1540 1541 1564 | I 1553 1542 1543 1565 | I 1554 1544 1545 1566 | I 1555 1546 1547 1567 | I 1556 1548 1549 1568 | I 1557 1550 1551 1569 | I 1558 1537 1538 1570 | I 1559 1539 1518 1571 | I 1560 1552 1554 1572 | I 1561 1553 1555 1573 | I 1562 1556 1558 1574 | I 1563 1557 1559 1575 | I 1564 1560 1561 1576 | I 1565 1562 1563 1577 | I 1566 1564 1565 1578 | -------------------------------------------------------------------------------- /learn.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import time 3 | import warnings 4 | 5 | from algo.LogisticCircuit import LogisticCircuit 6 | from structure.Vtree import Vtree 7 | from util.mnist_data import read_data_sets 8 | 9 | FLAGS = None 10 | 11 | 12 | def main(): 13 | # read dataset and vtree 14 | data = read_data_sets(FLAGS.data_path, FLAGS.percentage) 15 | vtree = Vtree.read(FLAGS.vtree) 16 | 17 | # create a logistic circuit 18 | if FLAGS.circuit != "": 19 | with open(FLAGS.circuit, "r") as circuit_file: 20 | circuit = LogisticCircuit(vtree, FLAGS.num_classes, circuit_file=circuit_file) 21 | print("The saved circuit is successfully loaded.") 22 | data.train.features = circuit.calculate_features(data.train.images) 23 | else: 24 | circuit = LogisticCircuit(vtree, FLAGS.num_classes) 25 | data.train.features = circuit.calculate_features(data.train.images) 26 | circuit.learn_parameters(data.train, 50) 27 | 28 | print(f"The starting circuit has {circuit.num_parameters} parameters.") 29 | data.valid.features = circuit.calculate_features(data.valid.images) 30 | data.test.features = circuit.calculate_features(data.test.images) 31 | valid_accuracy = circuit.calculate_accuracy(data.valid) 32 | print( 33 | 
f"Its performance is as follows. " 34 | f"Training accuracy: {circuit.calculate_accuracy(data.train):.5f}\t" 35 | f"Valid accuracy: {valid_accuracy:.5f}\t" 36 | f"Test accuracy: {circuit.calculate_accuracy(data.test):.5f}" 37 | ) 38 | 39 | print("Start structure learning.") 40 | 41 | best_accuracy = valid_accuracy 42 | for i in range(FLAGS.num_structure_learning_iterations): 43 | cur_time = time.time() 44 | 45 | circuit.change_structure(data.train, FLAGS.depth, FLAGS.num_splits) 46 | 47 | data.train.features = circuit.calculate_features(data.train.images) 48 | data.valid.features = circuit.calculate_features(data.valid.images) 49 | data.test.features = circuit.calculate_features(data.test.images) 50 | 51 | circuit.learn_parameters(data.train, FLAGS.num_parameter_learning_iterations) 52 | 53 | valid_accuracy = circuit.calculate_accuracy(data.valid) 54 | print( 55 | f"Training accuracy: {circuit.calculate_accuracy(data.train):.5f}\t" 56 | f"Valid accuracy: {valid_accuracy:.5f}\t" 57 | f"Test accuracy: {circuit.calculate_accuracy(data.test):.5f}" 58 | ) 59 | print(f"Num parameters: {circuit.num_parameters}\tTime spent: {(time.time() - cur_time):.2f}") 60 | 61 | if FLAGS.save_path != "" and (valid_accuracy > best_accuracy): 62 | best_accuracy = valid_accuracy 63 | print("Obtained a logistic circuit with higher classification accuracy. Start saving.") 64 | with open(FLAGS.save_path, "w") as circuit_file: 65 | circuit.save(circuit_file) 66 | print("Logistic circuit saved.") 67 | 68 | 69 | if __name__ == "__main__": 70 | parser = argparse.ArgumentParser() 71 | parser.add_argument("--data_path", type=str, help="Directory for the stored input data.") 72 | parser.add_argument("--num_classes", type=int, help="Number of classes in the classification task.") 73 | parser.add_argument("--vtree", type=str, default="balanced.vtree", help="Path for vtree.") 74 | parser.add_argument( 75 | "--circuit", 76 | type=str, 77 | default="", 78 | help="[Optional] File path for the saved logistic circuit to load. " 79 | "Note this circuit has to be based on the same vtree as provided in --vtree.", 80 | ) 81 | parser.add_argument( 82 | "--num_structure_learning_iterations", 83 | type=int, 84 | default=5000, 85 | help="[Optional] Num of iterations for structure learning. Its default value is 5000.", 86 | ) 87 | parser.add_argument( 88 | "--num_parameter_learning_iterations", 89 | type=int, 90 | default=15, 91 | help="[Optional] Number of iterations for parameter learning after the structure is changed." 92 | "Its default value is 15.", 93 | ) 94 | parser.add_argument("--depth", type=int, default=2, help="[Optional] The depth of every split. Its default value is 2.") 95 | parser.add_argument( 96 | "--num_splits", 97 | type=int, 98 | default=3, 99 | help="[Optional] The number of splits in one iteration of structure learning." "It default value is 3.", 100 | ) 101 | parser.add_argument( 102 | "--percentage", 103 | type=float, 104 | default=1.0, 105 | help="[Optional] The percentage of the training dataset that will be used. " "Its default value is 100%%.", 106 | ) 107 | parser.add_argument("--save_path", type=str, default="", help="[Optional] File path to save the best-performing circuit.") 108 | FLAGS = parser.parse_args() 109 | if FLAGS.num_classes == 2: 110 | FLAGS.num_classes = 1 111 | message = ( 112 | "It is essentially a binary classification task when num_classes is set to 2, " 113 | + "and hence we automatically modify it to be 1 to be better compatible with sklearn." 
-------------------------------------------------------------------------------- /structure/AndGate.py: -------------------------------------------------------------------------------- 1 | from structure.CircuitNode import CircuitNode 2 | 3 | 4 | class AndGate(object): 5 | """ 6 | AND gate. 7 | We also refer to AND gates as elements. 8 | In this implementation, we assume every AND gate is the child of exactly one PSDD decision node (OR gate). 9 | In other words, AND gates are not shared between different PSDD decision nodes. 10 | """ 11 | 12 | def __init__(self, prime: CircuitNode, sub: CircuitNode, parameter=None): 13 | self._prime = prime 14 | self._sub = sub 15 | self._prime.increase_num_parents_by_one() 16 | self._sub.increase_num_parents_by_one() 17 | # difference between prob and feature: 18 | # prob is calculated in a bottom-up pass and only considers values of the variables the element has 19 | # feature is calculated in a top-down pass using probs; it equals the WMC of reaching that element 20 | self._feature = None 21 | self._prob = None 22 | self._parameter = parameter 23 | self._parent = None 24 | self._splittable_variables = set() 25 | self._flag = False 26 | 27 | @property 28 | def prime(self): 29 | return self._prime 30 | 31 | @prime.setter 32 | def prime(self, value): 33 | self._prime = value 34 | if self._prime is not None: 35 | self._prime.increase_num_parents_by_one() 36 | 37 | @property 38 | def sub(self): 39 | return self._sub 40 | 41 | @sub.setter 42 | def sub(self, value): 43 | self._sub = value 44 | if self._sub is not None: 45 | self._sub.increase_num_parents_by_one() 46 | 47 | @property 48 | def feature(self): 49 | return self._feature 50 | 51 | @feature.setter 52 | def feature(self, value): 53 | self._feature = value 54 | 55 | @property 56 | def prob(self): 57 | return self._prob 58 | 59 | @prob.setter 60 | def prob(self, value): 61 | self._prob = value 62 | 63 | def calculate_prob(self): 64 | self._prob = self._prime.prob + self._sub.prob 65 | 66 | @property 67 | def parameter(self): 68 | return self._parameter 69 | 70 | @parameter.setter 71 | def parameter(self, value): 72 | self._parameter = value 73 | 74 | @property 75 | def parent(self): 76 | return self._parent 77 | 78 | @parent.setter 79 | def parent(self, value): 80 | self._parent = value 81 | 82 | @property 83 | def splittable_variables(self): 84 | return self._splittable_variables 85 | 86 | @splittable_variables.setter 87 | def splittable_variables(self, value): 88 | self._splittable_variables = value 89 | 90 | def remove_splittable_variable(self, variable_to_remove): 91 | self._splittable_variables.discard(variable_to_remove) 92 | 93 | @property 94 | def flag(self): 95 | return self._flag 96 | 97 | @flag.setter 98 | def flag(self, value): 99 | self._flag = value 100 |
-------------------------------------------------------------------------------- /structure/CircuitNode.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | class CircuitNode(object): 5 | 6 | def __init__(self, index, vtree): 7 | self._index = index 8 | self._vtree = vtree 9 | self._num_parents = 0 10 | # difference between prob and feature: 11 | # prob is calculated in a bottom-up pass and only considers values of the variables the node has 12 | # feature is calculated in a top-down pass using probs; it equals the WMC of reaching that node 13 | self._prob = None 14 | self._feature = None 15 | 16 | @property 17 | def index(self): 18 | return self._index 19 | 20 | @property 21 | def vtree(self): 22 | return self._vtree 23 | 24 | @property 25 | def num_parents(self): 26 | return self._num_parents 27 | 28 | def increase_num_parents_by_one(self): 29 | self._num_parents += 1 30 | 31 | def decrease_num_parents_by_one(self): 32 | self._num_parents -= 1 33 | 34 | @property 35 | def feature(self): 36 | return self._feature 37 | 38 | @feature.setter 39 | def feature(self, value): 40 | self._feature = value 41 | 42 | @property 43 | def prob(self): 44 | return self._prob 45 | 46 | @prob.setter 47 | def prob(self, value): 48 | self._prob = value 49 | 50 | 51 | class OrGate(CircuitNode): 52 | """OR gate. 53 | OR gates are also referred to as decision nodes.""" 54 | 55 | def __init__(self, index, vtree, elements): 56 | super().__init__(index, vtree) 57 | self._elements = elements 58 | for element in elements: 59 | element.parent = self 60 | 61 | @property 62 | def elements(self): 63 | return self._elements 64 | 65 | def add_element(self, element): 66 | self._elements.append(element) 67 | element.parent = self 68 | 69 | def remove_element(self, index): 70 | del self._elements[index] 71 | 72 | def calculate_prob(self): 73 | if len(self._elements) == 0: 74 | raise ValueError("Decision nodes should have at least one element.") 75 | for element in self._elements: 76 | element.calculate_prob() 77 | self._prob = np.sum([np.exp(element.prob) for element in self._elements], axis=0) 78 | self._prob = np.where(self._prob < 1e-5, 1e-5, self._prob) 79 | self._prob = np.log(self._prob) 80 | for element in self._elements: 81 | element.prob -= self._prob 82 | self._prob = np.where(self._prob > 0.0, 0.0, self._prob) 83 | self._feature = np.zeros(shape=self._prob.shape, dtype=np.float32) 84 | 85 | def calculate_feature(self): 86 | feature = np.log(self._feature) 87 | for element in self._elements: 88 | element.feature = np.exp(feature + element.prob) 89 | element.prime.feature += element.feature 90 | element.sub.feature += element.feature 91 | 92 | def save(self, f): 93 | f.write(f'D {self._index} {self._vtree.index} {len(self._elements)}') 94 | for element in self._elements: 95 | f.write(f' ({element.prime.index} {element.sub.index}') 96 | for parameter in element.parameter: 97 | f.write(f' {parameter}') 98 | f.write(f')') 99 | f.write('\n') 100 | 101 | 102 | LITERAL_IS_TRUE = 1 103 | LITERAL_IS_FALSE = 0 104 | 105 | 106 | class CircuitTerminal(CircuitNode): 107 | """Terminal (leaf) node.""" 108 | 109 | def __init__(self, index, vtree, var_index, var_value, parameter=None): 110 | super().__init__(index, vtree) 111 | self._var_index = var_index 112 | self._var_value = var_value 113 | self._parameter = parameter 114 | 115 | @property 116 | def var_index(self): 117 | return self._var_index 118 | 119 | @var_index.setter 120 | def var_index(self, value): 121 | self._var_index = value 122 | 123 | @property 124 | def var_value(self): 125 | return self._var_value 126 | 127 | @var_value.setter 128 | def var_value(self, value): 129 | self._var_value = value 130 | 131 | @property 132 | def parameter(self): 133 | return self._parameter 134 | 135 | @parameter.setter 136 | def parameter(self, value): 137 | self._parameter = value 138 | 139 | def calculate_prob(self, samples: np.array): 140 | if self._var_value == LITERAL_IS_TRUE: 141 | self._prob = np.log(samples[:, self._var_index - 1]) 142 | elif self._var_value == LITERAL_IS_FALSE: 143 | self._prob = np.log(1.0 - samples[:, self._var_index - 1]) 144 | else: 145 | raise ValueError('Terminal nodes should either be positive literals or negative literals.') 146 | self._feature = np.zeros(shape=self._prob.shape, dtype=np.float32) 147 | 148 | def save(self, f): 149 | if self._var_value == LITERAL_IS_TRUE: 150 | f.write(f'T {self._index} {self._vtree.index} {self._var_index}') 151 | elif self._var_value == LITERAL_IS_FALSE: 152 | f.write(f'F {self._index} {self._vtree.index} {self._var_index}') 153 | else: 154 | raise ValueError('Currently we only support terminal nodes that are either positive or negative literals.') 155 | for parameter in self._parameter: 156 | f.write(f' {parameter}') 157 | f.write('\n') 158 |
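To make the bottom-up pass concrete, here is a minimal standalone sketch (not part of the repository; the element probabilities are made up) of what OrGate.calculate_prob computes for a single sample flowing through an OR gate with two elements:

import numpy as np

# each element's log-prob is the sum of its prime's and sub's log-probs (AndGate.calculate_prob)
element_log_probs = np.array([np.log(0.3) + np.log(0.5), np.log(0.7) + np.log(0.5)])

node_prob = np.sum(np.exp(element_log_probs))    # back to linear space: 0.15 + 0.35 = 0.5
node_prob = max(node_prob, 1e-5)                 # floor, so the log below never sees 0
node_log_prob = np.log(node_prob)                # the OR gate's log-prob
element_log_probs -= node_log_prob               # condition each element on reaching this gate
node_log_prob = min(node_log_prob, 0.0)          # clamp numerical overshoot above probability 1
print(node_log_prob, np.exp(element_log_probs))  # approximately -0.693 and [0.3 0.7]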
-------------------------------------------------------------------------------- /structure/Vtree.py: -------------------------------------------------------------------------------- 1 | from abc import ABC, abstractmethod 2 | 3 | 4 | class Vtree(ABC): 5 | 6 | def __init__(self, index): 7 | self._index = index 8 | 9 | @property 10 | def index(self): 11 | return self._index 12 | 13 | @property 14 | def var_count(self): 15 | return self._var_count 16 | 17 | @abstractmethod 18 | def is_leaf(self): 19 | pass 20 | 21 | @staticmethod 22 | def read(file): 23 | with open(file, 'r') as vtree_file: 24 | line = 'c' 25 | while line[0] == 'c': 26 | line = vtree_file.readline() 27 | if line.strip().split(' ')[0] != 'vtree': 28 | raise ValueError('The number of vtree nodes is not specified.') 29 | num_nodes = int(line.strip().split(' ')[1]) 30 | nodes = [None] * num_nodes 31 | root = None 32 | for line in vtree_file.readlines(): 33 | line_as_list = line.strip().split(' ') 34 | if line_as_list[0] == 'L': 35 | root = VtreeLeaf(int(line_as_list[1]), int(line_as_list[2])) 36 | nodes[int(line_as_list[1])] = root 37 | elif line_as_list[0] == 'I': 38 | root = VtreeIntermediate(int(line_as_list[1]), 39 | nodes[int(line_as_list[2])], nodes[int(line_as_list[3])]) 40 | nodes[int(line_as_list[1])] = root 41 | else: 42 | raise ValueError('A vtree node can only be of type L or I.') 43 | return root 44 | 45 | 46 | class VtreeLeaf(Vtree): 47 | 48 | def __init__(self, index, variable): 49 | super(VtreeLeaf, self).__init__(index) 50 | self._var = variable 51 | self._var_count = 1 52 | 53 | def is_leaf(self): 54 | return True 55 | 56 | @property 57 | def var(self): 58 | return self._var 59 | 60 | @property 61 | def variables(self): 62 | return set([self._var]) 63 | 64 | 65 | class VtreeIntermediate(Vtree): 66 | 67 | def __init__(self, index, left, right): 68 | super(VtreeIntermediate, self).__init__(index) 69 | self._left = left 70 | self._right = right 71 | self._variables = set() 72 | self._var_count = self._left.var_count + self._right.var_count 73 | self._variables.update(self._left.variables) 74 | self._variables.update(self._right.variables) 75 | 76 | def is_leaf(self): 77 | return False 78 | 79 | @property 80 | def left(self): 81 | return self._left 82 | 83 | @property 84 | def right(self): 85 | return self._right 86 | 87 | @property 88 | def variables(self): 89 | return self._variables 90 | 91 | @property 92 | def var(self): 93 | return 0 94 | -------------------------------------------------------------------------------- /structure/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCLA-StarAI/LogisticCircuit/00c95c689d81c4abbdf20f81a7a29fc2a14fc1e4/structure/__init__.py -------------------------------------------------------------------------------- /util/DataSet.py: -------------------------------------------------------------------------------- 1 | 
import numpy as np 2 | 3 | 4 | class DataSet(object): 5 | 6 | def __init__(self, images, labels): 7 | self._images = images 8 | self._labels = labels 9 | self._one_hot_labels = to_one_hot_encoding(labels) 10 | self._features = None 11 | self._num_samples = self._images.shape[0] 12 | self._num_epochs = 0 13 | self._index = 0 14 | 15 | @property 16 | def images(self): 17 | return self._images 18 | 19 | @property 20 | def labels(self): 21 | return self._labels 22 | 23 | @property 24 | def one_hot_labels(self): 25 | return self._one_hot_labels 26 | 27 | @property 28 | def features(self): 29 | return self._features 30 | 31 | @features.setter 32 | def features(self, value): 33 | self._features = value 34 | 35 | @property 36 | def num_samples(self): 37 | return self._num_samples 38 | 39 | @property 40 | def num_epochs(self): 41 | return self._num_epochs 42 | 43 | def next_batch(self, batch_size): 44 | """Return the next `batch_size` examples, features and labels from this data set.""" 45 | assert batch_size <= self._num_samples 46 | 47 | if self._index + batch_size >= self._num_samples: 48 | perm = np.arange(self._num_samples) 49 | np.random.shuffle(perm) 50 | self._images = self._images[perm] 51 | self._labels = self._labels[perm] 52 | self._features = self._features[perm] 53 | self._index = 0 54 | self._num_epochs += 1 55 | 56 | images = self._images[self._index: self._index + batch_size] 57 | labels = self._labels[self._index: self._index + batch_size] 58 | features = self._features[self._index: self._index + batch_size] 59 | self._index += batch_size 60 | return images, features, labels, to_one_hot_encoding(labels) 61 | 62 | 63 | class DataSets(object): 64 | 65 | def __init__(self, train, valid, test): 66 | self._train = train 67 | self._test = test 68 | self._valid = valid 69 | 70 | @property 71 | def train(self): 72 | return self._train 73 | 74 | @property 75 | def valid(self): 76 | return self._valid 77 | 78 | @property 79 | def test(self): 80 | return self._test 81 | 82 | 83 | def to_one_hot_encoding(labels): 84 | num_classes = np.max(labels) + 1 85 | one_hot_labels = np.zeros(shape=(len(labels), num_classes), dtype=np.float32) 86 | for i in range(len(labels)): 87 | one_hot_labels[i][labels[i]] = 1.0 88 | return one_hot_labels 89 | -------------------------------------------------------------------------------- /util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/UCLA-StarAI/LogisticCircuit/00c95c689d81c4abbdf20f81a7a29fc2a14fc1e4/util/__init__.py -------------------------------------------------------------------------------- /util/generate_balanced_vtree.py: -------------------------------------------------------------------------------- 1 | ''' 2 | Generate a balanced vtree. 3 | It takes two arguments: 1) the number of input variables in your dataset 2) the file name to store the vtree. 
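For example (the output file name is arbitrary):
    python3 util/generate_balanced_vtree.py --num_variables 3 --vtree_file tiny.vtree
writes, after the comment header:
    vtree 5
    L 0 1
    L 1 2
    L 2 3
    I 3 0 1
    I 4 2 3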
4 | ''' 5 | 6 | import argparse 7 | 8 | 9 | VTREE_FORMAT = """c ids of vtree nodes start at 0 10 | c ids of variables start at 1 11 | c vtree nodes appear bottom-up, children before parents 12 | c 13 | c file syntax: 14 | c vtree number-of-nodes-in-vtree 15 | c L id-of-leaf-vtree-node id-of-variable 16 | c I id-of-internal-vtree-node id-of-left-child id-of-right-child 17 | c 18 | """ 19 | 20 | 21 | FLAGS = None 22 | 23 | 24 | def main(): 25 | with open(FLAGS.vtree_file, 'w') as f_out: 26 | f_out.write(VTREE_FORMAT) 27 | 28 | num_nodes = FLAGS.num_variables 29 | num_to_be_paired_nodes = FLAGS.num_variables 30 | while num_to_be_paired_nodes > 1: 31 | num_nodes += num_to_be_paired_nodes // 2 32 | num_to_be_paired_nodes -= num_to_be_paired_nodes // 2 33 | f_out.write(f'vtree {num_nodes}\n') 34 | 35 | to_be_paired_nodes = [] 36 | for i in range(FLAGS.num_variables): 37 | f_out.write(f'L {i} {i+1}\n') 38 | to_be_paired_nodes.append(i) 39 | index = FLAGS.num_variables 40 | while len(to_be_paired_nodes) > 1: 41 | f_out.write(f'I {index} {to_be_paired_nodes[0]} {to_be_paired_nodes[1]}\n') 42 | to_be_paired_nodes = to_be_paired_nodes[2:] 43 | to_be_paired_nodes.append(index) 44 | index += 1 45 | 46 | 47 | if __name__ == '__main__': 48 | parser = argparse.ArgumentParser() 49 | parser.add_argument('--num_variables', type=int, 50 | help='The number of input variables in your dataset') 51 | parser.add_argument('--vtree_file', type=str, 52 | help='The file name to store the generated vtree') 53 | FLAGS = parser.parse_args() 54 | main() 55 | -------------------------------------------------------------------------------- /util/mnist_data.py: -------------------------------------------------------------------------------- 1 | import gzip 2 | import os 3 | 4 | import numpy as np 5 | 6 | from util.DataSet import DataSet, DataSets 7 | 8 | 9 | def _read32(bytestream): 10 | dt = np.dtype(np.uint32).newbyteorder('>') 11 | return np.frombuffer(bytestream.read(4), dtype=dt)[0] 12 | 13 | 14 | def extract_images_and_labels(image_file, label_file, percentage=1.0): 15 | """Extract the images into a 2D float32 numpy array [index, pixel] and the labels into a 1D int32 numpy array.""" 16 | print('Extracting', image_file, label_file) 17 | with gzip.open(image_file) as image_bytestream, gzip.open(label_file) as label_bytestream: 18 | magic = _read32(image_bytestream) 19 | if magic != 2051: 20 | raise ValueError( 21 | 'Invalid magic number %d in image file: %s' % 22 | (magic, image_file)) 23 | magic = _read32(label_bytestream) 24 | if magic != 2049: 25 | raise ValueError( 26 | 'Invalid magic number %d in label file: %s' % 27 | (magic, label_file)) 28 | num_images = _read32(image_bytestream) 29 | rows = _read32(image_bytestream) 30 | cols = _read32(image_bytestream) 31 | num_labels = _read32(label_bytestream) 32 | if num_images != num_labels: 33 | raise ValueError( 34 | 'Num images does not match num labels. 
Image file : %s; label file: %s' % 35 | (image_file, label_file)) 36 | images = [] 37 | labels = [] 38 | num_images = int(num_images * percentage) 39 | for _ in range(num_images): 40 | image_buf = image_bytestream.read(rows * cols) 41 | image = np.frombuffer(image_buf, dtype=np.uint8) 42 | image = np.multiply(image.astype(np.float32), 1.0 / 255.0) 43 | image[np.where(image == 0.0)[0]] = 1e-5 44 | image[np.where(image == 1.0)[0]] -= 1e-5 45 | label = np.frombuffer(label_bytestream.read(1), dtype=np.uint8) 46 | images.append(image) 47 | labels.append(label) 48 | images = np.array(images, dtype=np.float32) 49 | labels = np.array(labels, dtype=np.int32).squeeze() 50 | return images, labels 51 | 52 | 53 | def crop_augment(images, target_side_length=26): 54 | images = np.reshape(images, (-1, 28, 28)) 55 | augmented_images_shape = list(images.shape) 56 | augmented_images_shape[0] *= 2 57 | augmented_images = np.zeros(shape=augmented_images_shape, dtype=np.float32) + 1e-5 58 | 59 | diff = (28 - target_side_length) // 2 60 | for i in range(len(images)): 61 | images_center = images[i][diff:-diff, diff:-diff] 62 | augmented_images[2*i] = images[i] 63 | choice = np.random.random() 64 | if choice < 0.25: 65 | augmented_images[2*i+1][:target_side_length, :target_side_length] = images_center 66 | elif choice < 0.5: 67 | augmented_images[2*i+1][:target_side_length, -target_side_length:] = images_center 68 | elif choice < 0.75: 69 | augmented_images[2*i+1][-target_side_length:, :target_side_length] = images_center 70 | else: 71 | augmented_images[2*i+1][-target_side_length:, -target_side_length:] = images_center 72 | 73 | augmented_images = np.reshape(augmented_images, (-1, 784)) 74 | return augmented_images 75 | 76 | 77 | def read_data_sets(dir, percentage=1.0): 78 | train_image_file = 'train-images-idx3-ubyte.gz' 79 | train_label_file = 'train-labels-idx1-ubyte.gz' 80 | test_image_file = 't10k-images-idx3-ubyte.gz' 81 | test_label_file = 't10k-labels-idx1-ubyte.gz' 82 | 83 | train_image_file = os.path.join(dir, train_image_file) 84 | train_label_file = os.path.join(dir, train_label_file) 85 | train_images, train_labels = extract_images_and_labels(train_image_file, train_label_file, percentage) 86 | 87 | perm = np.arange(len(train_images)) 88 | np.random.shuffle(perm) 89 | valid_images = train_images[perm[:len(train_images)//10]] 90 | valid_labels = train_labels[perm[:len(train_labels)//10]] 91 | train_images = train_images[perm[len(train_images)//10:]] 92 | train_labels = train_labels[perm[len(train_labels)//10:]] 93 | 94 | #train_images = crop_augment(train_images) 95 | #train_labels = np.repeat(train_labels, 2) 96 | 97 | test_image_file = os.path.join(dir, test_image_file) 98 | test_label_file = os.path.join(dir, test_label_file) 99 | test_images, test_labels = extract_images_and_labels(test_image_file, test_label_file) 100 | 101 | train = DataSet(train_images, train_labels) 102 | valid = DataSet(valid_images, valid_labels) 103 | test = DataSet(test_images, test_labels) 104 | return DataSets(train, valid, test) 105 | --------------------------------------------------------------------------------
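The data encoding in util/mnist_data.py is tied to CircuitTerminal.calculate_prob in structure/CircuitNode.py: pixel intensities are rescaled to [0, 1] and then nudged away from exactly 0.0 and 1.0, so that the terminals' np.log(x) and np.log(1.0 - x) always stay finite. A minimal standalone sketch (illustrative values, not repository code):

import numpy as np

pixels = np.array([0.0, 0.5, 1.0], dtype=np.float32)
pixels[pixels == 0.0] = 1e-5        # the same nudges applied in extract_images_and_labels
pixels[pixels == 1.0] -= 1e-5
log_true = np.log(pixels)           # log-prob used by positive-literal terminal nodes
log_false = np.log(1.0 - pixels)    # log-prob used by negative-literal terminal nodes
# both arrays are now finite; with exact 0/1 pixels one entry of each would be -inf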