├── .gitignore ├── LICENSE ├── README.md ├── group_split.py ├── k-fold.ipynb ├── k_fold_split.py ├── requirements.txt └── tools ├── __init__.py ├── functions.py └── history.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | *.py,cover 51 | .hypothesis/ 52 | .pytest_cache/ 53 | 54 | # Translations 55 | *.mo 56 | *.pot 57 | 58 | # Django stuff: 59 | *.log 60 | local_settings.py 61 | db.sqlite3 62 | db.sqlite3-journal 63 | 64 | # Flask stuff: 65 | instance/ 66 | .webassets-cache 67 | 68 | # Scrapy stuff: 69 | .scrapy 70 | 71 | # Sphinx documentation 72 | docs/_build/ 73 | 74 | # PyBuilder 75 | target/ 76 | 77 | # Jupyter Notebook 78 | .ipynb_checkpoints 79 | 80 | # IPython 81 | profile_default/ 82 | ipython_config.py 83 | 84 | # pyenv 85 | .python-version 86 | 87 | # pipenv 88 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
89 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 90 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 91 | # install all needed dependencies. 92 | #Pipfile.lock 93 | 94 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow 95 | __pypackages__/ 96 | 97 | # Celery stuff 98 | celerybeat-schedule 99 | celerybeat.pid 100 | 101 | # SageMath parsed files 102 | *.sage.py 103 | 104 | # Environments 105 | .env 106 | .venv 107 | env/ 108 | venv/ 109 | ENV/ 110 | env.bak/ 111 | venv.bak/ 112 | 113 | # Spyder project settings 114 | .spyderproject 115 | .spyproject 116 | 117 | # Rope project settings 118 | .ropeproject 119 | 120 | # mkdocs documentation 121 | /site 122 | 123 | # mypy 124 | .mypy_cache/ 125 | .dmypy.json 126 | dmypy.json 127 | 128 | # Pyre type checker 129 | .pyre/ 130 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 João Paulo Figueira 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # strat-group-split 2 | This repository contains code to perform stratified splitting 3 | of grouped datasets into train/validation sets or K-folds 4 | using optimization. 5 | 6 | ## Summary 7 | Given a labeled and grouped dataset, we want to split it into 8 | training and validation sets (or equally sized K folds) 9 | while keeping the label 10 | distribution as close as possible across the resulting sets while preserving group integrity. 11 | After breaking the data into the two datasets, the groups must 12 | maintain their integrity: each group is assigned entirely to one set and never split 13 | among them. Furthermore, the splitting process should closely 14 | respect the imposed splitting proportion and label 15 | stratification. 16 | 17 | The expected result for this problem is, given an input dataset, 18 | the list of groups assigned to each dataset, ensuring that both 19 | the train/validation split and the stratification are as close 20 | as possible to the specified values. 21 | 22 | ## Using the Code 23 | ### Train/Validation Split 24 | All the code is contained in the `group_split.py` file. 25 | The `main` function runs a benchmark between the two 26 | optimization algorithms. It generates a problem matrix using 27 | the `generate_counts` function and then submits it to both 28 | algorithms, outputting the time taken, final cost value and 29 | the approximations to both the desired split and the 30 | stratification. 31 | 32 | Please note that the code is at a proof-of-concept stage.
In 33 | the future I plan to create an independent Python package 34 | with these ideas. 35 | 36 | ### K-Fold Split 37 | All the code is contained in the `k_fold_split.py` file. You can 38 | alternatively use the `k-fold.ipynb` Jupyter notebook. 39 | 40 | ## Medium Articles 41 | [Stratified Splitting of Grouped Datasets Using Optimization](https://towardsdatascience.com/stratified-splitting-of-grouped-datasets-using-optimization-bdc12fb6e691) 42 | 43 | [Stratified K-Fold Cross-Validation on Grouped Datasets](https://towardsdatascience.com/stratified-k-fold-cross-validation-on-grouped-datasets-b3bca8f0f53e) -------------------------------------------------------------------------------- /group_split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import time 3 | 4 | from numpy.random import default_rng 5 | from numba import jit 6 | from collections import namedtuple 7 | from tools.functions import index_to_str 8 | 9 | RANDOM_SEED = 23 10 | Solution = namedtuple("Solution", "cost index") 11 | 12 | 13 | def create_initial_solution(sample_counts, p): 14 | """ 15 | Creates the initial solution using a random shuffle 16 | @param sample_counts: The problem array. 17 | @param p: The validation set proportion. 18 | @return: A solution array. 19 | """ 20 | rng = default_rng(seed=RANDOM_SEED) 21 | group_count = sample_counts.shape[0] 22 | idx = np.zeros(group_count, dtype=bool) 23 | rnd_idx = rng.permutation(group_count) 24 | start_count = 0 25 | sample_size = sample_counts[:, 0].sum() 26 | for i in range(group_count): 27 | start_count += sample_counts[rnd_idx[i], 0] 28 | idx[rnd_idx[i]] = True 29 | if start_count > sample_size * p: 30 | break 31 | return idx 32 | 33 | 34 | @jit(nopython=True) 35 | def calculate_cost(sample_counts, idx, p): 36 | """ 37 | Calculates the cost of a given solution 38 | @param sample_counts: The problem array. 39 | @param idx: The solution array to evaluate. 
40 | @param p: The train/validation split proportion. 41 | @return: The cost value. 42 | """ 43 | train_count = sample_counts[~idx, 0].sum() 44 | valid_count = sample_counts[idx, 0].sum() 45 | total_count = train_count + valid_count 46 | 47 | cost = (valid_count - total_count * p) ** 2 48 | 49 | for i in range(1, sample_counts.shape[1]): 50 | r = sample_counts[:, i].sum() / total_count 51 | cost += (sample_counts[idx, i].sum() - r * sample_counts[idx, 0].sum()) ** 2 52 | return cost / 2.0 53 | 54 | 55 | def calculate_cost_grad(sample_counts, idx, p): 56 | """ 57 | Calculates the cost gradient of a given solution 58 | @param sample_counts: The problem array. 59 | @param idx: The solution array to evaluate. 60 | @param p: The train/validation split proportion. 61 | @return: The cost value. 62 | """ 63 | grad = np.zeros(sample_counts.shape[1]) 64 | 65 | total_count = sample_counts[:, 0].sum() 66 | valid_count = sample_counts[idx, 0].sum() 67 | 68 | grad[0] = total_count * p - valid_count 69 | 70 | for i in range(1, sample_counts.shape[1]): 71 | r = sample_counts[:, i].sum() / total_count 72 | grad[i] = r * sample_counts[idx, 0].sum() - sample_counts[idx, i].sum() 73 | return grad 74 | 75 | 76 | def cosine_similarity(sample_counts, idx, cost_grad): 77 | """ 78 | Calculates the cosine similarity vector between the problem array 79 | and the cost gradient vector 80 | @param sample_counts: The problem array. 81 | @param idx: The solution vector. 82 | @param cost_grad: The cost gradient vector. 83 | @return: The cosine similarity vector. 
84 | """ 85 | c = np.copy(sample_counts) 86 | c[idx] = -c[idx] # Reverse direction of validation vectors 87 | a = np.inner(c, cost_grad) 88 | b = np.multiply(np.linalg.norm(c, axis=1), np.linalg.norm(cost_grad)) 89 | return np.divide(a, b) 90 | 91 | 92 | def euclidean_similarity(sample_counts, idx, cost_grad): 93 | c = np.copy(sample_counts) 94 | c[idx] = -c[idx] 95 | return np.linalg.norm(c - cost_grad, axis=1) 96 | 97 | 98 | def generate_cosine_move(sample_counts, idx, p, expanded_set, intensify): 99 | """ 100 | Generates a new move using the cosine similarity. 101 | @param sample_counts: The problem array. 102 | @param idx: The solution vector. 103 | @param p: The validation set proportion. 104 | @param expanded_set: The set of expanded solutions. 105 | @param intensify: Intensification / diversification flag. 106 | @return: The new solution vector. 107 | """ 108 | cost_grad = calculate_cost_grad(sample_counts, idx, p) 109 | similarity = cosine_similarity(sample_counts, idx, cost_grad) 110 | sorted_ixs = np.argsort(similarity) 111 | if intensify: 112 | sorted_ixs = np.flip(sorted_ixs) 113 | for i in sorted_ixs: 114 | move = np.copy(idx) 115 | move[i] = not move[i] 116 | if index_to_str(move) not in expanded_set: 117 | return move 118 | return None 119 | 120 | 121 | def generate_euclidean_move(sample_counts, idx, p, expanded_set, get_min): 122 | cost_grad = calculate_cost_grad(sample_counts, idx, p) 123 | similarity = euclidean_similarity(sample_counts, idx, cost_grad) 124 | sorted_ixs = np.argsort(similarity) 125 | if not get_min: 126 | sorted_ixs = np.flip(sorted_ixs) 127 | for i in sorted_ixs: 128 | move = np.copy(idx) 129 | move[i] = not move[i] 130 | if index_to_str(move) not in expanded_set: 131 | return move 132 | return None 133 | 134 | 135 | def generate_moves(idx, expanded_set): 136 | """ 137 | Generator for all acceptable moves from a previous solution. 138 | @param idx: The solution vector. 139 | @param expanded_set: The set of expanded solutions. 
140 | """ 141 | for i in range(idx.shape[0]): 142 | move = np.copy(idx) 143 | move[i] = not move[i] 144 | if index_to_str(move) not in expanded_set: 145 | yield move 146 | 147 | 148 | def generate_counts(num_groups, num_classes, 149 | min_group_size, max_group_size, 150 | max_group_percent): 151 | """ 152 | Generates a problem matrix from the given parameters. 153 | @param num_groups: The number of data groups. 154 | @param num_classes: The number of classes. 155 | @param min_group_size: The minimum group size. 156 | @param max_group_size: The maximum group size. 157 | @param max_group_percent: The maximum class percent. 158 | @return: The problem matrix. 159 | """ 160 | rng = default_rng(seed=RANDOM_SEED) 161 | sample_cnt = np.zeros((num_groups, num_classes), dtype=int) 162 | sample_cnt[:, 0] = rng.integers(low=min_group_size, high=max_group_size, size=num_groups) 163 | 164 | for i in range(1, num_groups): 165 | for j in range(1, num_classes): 166 | sample_cnt[i, j] = rng.integers(low=0, 167 | high=max_group_percent * sample_cnt[i, 0] - sample_cnt[i, 1:j+1].sum()) 168 | return sample_cnt 169 | 170 | 171 | class BaseSolver(object): 172 | 173 | def __init__(self, problem, candidate, p): 174 | self.problem = problem 175 | self.p = p 176 | self.incumbent = Solution(calculate_cost(problem, candidate, p), candidate) 177 | 178 | 179 | class SearchSolver(BaseSolver): 180 | 181 | def __init__(self, problem, candidate, p): 182 | super(SearchSolver, self).__init__(problem, candidate, p) 183 | 184 | def solve(self, min_cost, max_empty_iterations, 185 | max_intensity_iterations, verbose=True): 186 | """ 187 | Uses the search solver to calculate the best split. 188 | @param min_cost: Minimum cost criterion. 189 | @param max_empty_iterations: Maximum number of non-improving iterations. 190 | @param max_intensity_iterations: Maximum number of intensity iterations. 191 | @param verbose: Verbose flag. 192 | @return: The incumbent solution. 
193 | """ 194 | terminated = False 195 | intensify = True 196 | expanded_set = set() 197 | solution = self.incumbent 198 | n = 0 199 | n_intensity = 0 200 | 201 | while not terminated: 202 | move_list = generate_moves(solution.index, expanded_set) 203 | cost_list = [Solution(calculate_cost(self.problem, move, self.p), move) 204 | for move in move_list] 205 | cost_list.sort(key=lambda t: t[0], reverse=not intensify) 206 | intensify = True 207 | 208 | if len(cost_list) > 0: 209 | solution = cost_list[0] 210 | expanded_set.add(index_to_str(solution.index)) 211 | if solution.cost < self.incumbent.cost: 212 | self.incumbent = solution 213 | n = 0 214 | n_intensity = 0 215 | intensify = True 216 | if verbose: 217 | print(self.incumbent.cost) 218 | else: 219 | # Diversify? 220 | if n_intensity > max_intensity_iterations: 221 | intensify = False 222 | n_intensity = 0 223 | else: 224 | terminated = True 225 | 226 | n += 1 227 | n_intensity += 1 228 | if n > max_empty_iterations or self.incumbent.cost < min_cost: 229 | terminated = True 230 | return self.incumbent 231 | 232 | 233 | class GradientSolver(BaseSolver): 234 | 235 | def __init__(self, problem, candidate, p): 236 | super(GradientSolver, self).__init__(problem, candidate, p) 237 | 238 | def solve(self, min_cost, max_empty_iterations, 239 | max_intensity_iterations, verbose=True): 240 | """ 241 | This function uses the gradient solver to calculate the best split. 242 | @param min_cost: Minimum cost criterion. 243 | @param max_empty_iterations: Maximum number of non-improving iterations. 244 | @param max_intensity_iterations: Maximum number of intensity iterations. 245 | @param verbose: Verbose flag. 246 | @return: The incumbent solution. 
247 | """ 248 | terminated = False 249 | intensify = True 250 | expanded_set = set() 251 | solution = self.incumbent 252 | n = 0 253 | n_intensity = 0 254 | 255 | while not terminated: 256 | move = generate_cosine_move(self.problem, solution.index, self.p, 257 | expanded_set, intensify) 258 | intensify = True 259 | 260 | if move is not None: 261 | solution = Solution(calculate_cost(self.problem, move, self.p), move) 262 | expanded_set.add(index_to_str(solution.index)) 263 | if solution.cost < self.incumbent.cost: 264 | self.incumbent = solution 265 | n = 0 266 | n_intensity = 0 267 | 268 | if verbose: 269 | print(self.incumbent.cost) 270 | else: 271 | if n_intensity > max_intensity_iterations: 272 | intensify = False 273 | n_intensity = 0 274 | else: 275 | terminated = True 276 | 277 | n += 1 278 | n_intensity += 1 279 | if n > max_empty_iterations or self.incumbent.cost < min_cost: 280 | terminated = True 281 | 282 | return self.incumbent 283 | 284 | 285 | def print_solution(problem, solution, p): 286 | idx = solution.index 287 | valid_count = problem[idx, 0].sum() 288 | train_count = problem[~idx, 0].sum() 289 | total_count = train_count + valid_count 290 | 291 | print(solution.cost) 292 | print(p, valid_count / total_count) 293 | for i in range(1, problem.shape[1]): 294 | r = problem[:, i].sum() / total_count 295 | print(r, problem[idx, i].sum() / problem[idx, 0].sum()) 296 | 297 | 298 | def main(): 299 | num_groups = 250 # Number of groups to simulate 300 | num_classes = 2 # Number of classes 301 | max_group_size = 10000 # Maximum group size 302 | max_group_perc = 0.4 # Maximum proportion for each class 303 | p = 0.3 # Validation split proportion 304 | 305 | max_empty_iterations = 100 306 | max_intensity_iterations = 10 307 | min_cost = 10000 308 | 309 | sample_cnt = generate_counts(num_groups, num_classes, 310 | min_group_size=10, 311 | max_group_size=max_group_size, 312 | max_group_percent=max_group_perc) 313 | 314 | solution_arr = 
create_initial_solution(sample_cnt, p) 315 | 316 | s_solver = SearchSolver(sample_cnt, solution_arr, p) 317 | g_solver = GradientSolver(sample_cnt, solution_arr, p) 318 | 319 | start = time.time() 320 | solution = s_solver.solve(min_cost, max_empty_iterations, max_intensity_iterations, verbose=False) 321 | print(time.time() - start) 322 | 323 | print_solution(sample_cnt, solution, p) 324 | print() 325 | 326 | start = time.time() 327 | solution = g_solver.solve(min_cost, max_empty_iterations, max_intensity_iterations, verbose=False) 328 | print(time.time() - start) 329 | 330 | print_solution(sample_cnt, solution, p) 331 | 332 | 333 | if __name__ == "__main__": 334 | main() 335 | -------------------------------------------------------------------------------- /k-fold.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "code", 5 | "execution_count": 1, 6 | "metadata": { 7 | "tags": [] 8 | }, 9 | "outputs": [], 10 | "source": [ 11 | "import numpy as np\n", 12 | "\n", 13 | "from numpy.random import default_rng\n", 14 | "from numba import njit\n", 15 | "from typing import Set, Tuple" 16 | ] 17 | }, 18 | { 19 | "cell_type": "code", 20 | "execution_count": 1, 21 | "metadata": { 22 | "collapsed": false, 23 | "jupyter": { 24 | "outputs_hidden": false 25 | } 26 | }, 27 | "outputs": [], 28 | "source": [] 29 | }, 30 | { 31 | "cell_type": "code", 32 | "execution_count": 2, 33 | "metadata": { 34 | "collapsed": false, 35 | "jupyter": { 36 | "outputs_hidden": false 37 | } 38 | }, 39 | "outputs": [], 40 | "source": [ 41 | "def generate_problem(num_groups: int,\n", 42 | " num_classes: int,\n", 43 | " min_group_size: int,\n", 44 | " max_group_size: int,\n", 45 | " class_percent: np.array) -> np.ndarray:\n", 46 | "\n", 47 | " problem = np.zeros((num_groups, num_classes), dtype=int)\n", 48 | "\n", 49 | " rng = default_rng()\n", 50 | " group_sizes = rng.integers(low=min_group_size,\n", 51 | " 
high=max_group_size,\n", 52 | " size=num_groups)\n", 53 | "\n", 54 | " for i in range(num_groups):\n", 55 | " # Sample noisy class proportions around class_percent (10% relative std. dev.)\n", 56 | " proportions = rng.normal(class_percent, class_percent / 10)\n", 57 | "\n", 58 | " problem[i, :] = proportions * group_sizes[i]\n", 59 | " return problem" 60 | ] 61 | }, 62 | { 63 | "cell_type": "code", 64 | "execution_count": 3, 65 | "metadata": { 66 | "collapsed": false, 67 | "jupyter": { 68 | "outputs_hidden": false 69 | } 70 | }, 71 | "outputs": [], 72 | "source": [ 73 | "@njit\n", 74 | "def calculate_cost(problem: np.ndarray,\n", 75 | " solution: np.ndarray,\n", 76 | " k: int) -> float:\n", 77 | " cost = 0.0\n", 78 | " total = np.sum(problem)\n", 79 | " class_sums = np.sum(problem, axis=0)\n", 80 | " num_classes = problem.shape[1]\n", 81 | "\n", 82 | " for i in range(k):\n", 83 | " idx = solution == i\n", 84 | " fold_sum = np.sum(problem[idx, :])\n", 85 | "\n", 86 | " if total > 0.0 and fold_sum > 0.0:\n", 87 | " # Start by calculating the fold imbalance cost\n", 88 | " cost += (fold_sum / total - 1.0 / k) ** 2\n", 89 | "\n", 90 | " # Now calculate the cost associated with the class imbalances\n", 91 | " for j in range(num_classes):\n", 92 | " cost += (np.sum(problem[idx, j]) / fold_sum - class_sums[j] / total) ** 2\n", 93 | " else:\n", 94 | " cost += 1.0\n", 95 | " return cost" 96 | ] 97 | }, 98 | { 99 | "cell_type": "code", 100 | "execution_count": 4, 101 | "metadata": { 102 | "collapsed": false, 103 | "jupyter": { 104 | "outputs_hidden": false 105 | } 106 | }, 107 | "outputs": [], 108 | "source": [ 109 | "@njit\n", 110 | "def generate_search_space(problem: np.ndarray,\n", 111 | " solution: np.ndarray,\n", 112 | " k: int) -> np.ndarray:\n", 113 | " num_groups = problem.shape[0]\n", 114 | "\n", 115 | " space = np.zeros((num_groups, k))\n", 116 | " sol = solution.copy()\n", 117 | "\n", 118 | " for i in range(num_groups):\n", 119 | " for j in range(k):\n", 120 | " if solution[i] == j:\n", 121 | " space[i,j] = np.inf\n",
122 | " else:\n", 123 | " sol[i] = j\n", 124 | " space[i,j] = calculate_cost(problem, sol, k)\n", 125 | " sol[i] = solution[i]\n", 126 | " return space" 127 | ] 128 | }, 129 | { 130 | "cell_type": "code", 131 | "execution_count": 5, 132 | "metadata": {}, 133 | "outputs": [], 134 | "source": [ 135 | "@njit\n", 136 | "def solution_to_str(solution: np.ndarray) -> str:\n", 137 | " return \"\".join([str(n) for n in solution])" 138 | ] 139 | }, 140 | { 141 | "cell_type": "code", 142 | "execution_count": 6, 143 | "outputs": [], 144 | "source": [ 145 | "def select_move(decision: np.ndarray,\n", 146 | " solution: np.ndarray,\n", 147 | " history: Set) -> Tuple:\n", 148 | " candidates = np.argsort(decision, axis=None)\n", 149 | "\n", 150 | " for c in candidates:\n", 151 | " grp, cls = np.unravel_index(c, decision.shape)\n", 152 | " s = solution.copy()\n", 153 | " s[grp] = cls\n", 154 | " sol_str = solution_to_str(s)\n", 155 | "\n", 156 | " if sol_str not in history:\n", 157 | " return grp, cls\n", 158 | " return -1, -1 # No move found!" 
159 | ], 160 | "metadata": { 161 | "collapsed": false 162 | } 163 | }, 164 | { 165 | "cell_type": "code", 166 | "execution_count": 6, 167 | "metadata": {}, 168 | "outputs": [], 169 | "source": [] 170 | }, 171 | { 172 | "cell_type": "code", 173 | "execution_count": 7, 174 | "metadata": { 175 | "collapsed": false, 176 | "jupyter": { 177 | "outputs_hidden": false 178 | } 179 | }, 180 | "outputs": [], 181 | "source": [ 182 | "prb = generate_problem(num_groups=500,\n", 183 | " num_classes=4,\n", 184 | " min_group_size=400,\n", 185 | " max_group_size=2000,\n", 186 | " class_percent=np.array([0.4, 0.3, 0.2, 0.1]))" 187 | ] 188 | }, 189 | { 190 | "cell_type": "code", 191 | "execution_count": 7, 192 | "outputs": [], 193 | "source": [], 194 | "metadata": { 195 | "collapsed": false 196 | } 197 | }, 198 | { 199 | "cell_type": "code", 200 | "execution_count": 8, 201 | "outputs": [], 202 | "source": [ 203 | "def generate_initial_solution(problem: np.ndarray,\n", 204 | " k: int,\n", 205 | " algo: str=\"k-bound\") -> np.ndarray:\n", 206 | " num_groups = problem.shape[0]\n", 207 | " if algo == \"k-bound\":\n", 208 | " rng = default_rng()\n", 209 | " total = np.sum(problem)\n", 210 | " indices = rng.permutation(problem.shape[0])\n", 211 | "\n", 212 | " solution = np.zeros(num_groups, dtype=int)\n", 213 | " c = 0\n", 214 | " fold_total = 0\n", 215 | " for i in indices:\n", 216 | " group = np.sum(problem[i, :])\n", 217 | " if fold_total + group < total / k:\n", 218 | " fold_total += group\n", 219 | " else:\n", 220 | " c = (c + 1) % k\n", 221 | " fold_total = group\n", 222 | " solution[i] = c\n", 223 | " elif algo == \"random\":\n", 224 | " rng = default_rng()\n", 225 | " solution = rng.integers(low=0, high=k, size=num_groups)\n", 226 | " elif algo == \"zeros\":\n", 227 | " solution = np.zeros(num_groups, dtype=int)\n", 228 | " else:\n", 229 | " raise Exception(\"Invalid algorithm name\")\n", 230 | " return solution" 231 | ], 232 | "metadata": { 233 | "collapsed": false 234 | } 235 | 
}, 236 | { 237 | "cell_type": "code", 238 | "execution_count": 8, 239 | "metadata": { 240 | "collapsed": false, 241 | "jupyter": { 242 | "outputs_hidden": false 243 | } 244 | }, 245 | "outputs": [], 246 | "source": [] 247 | }, 248 | { 249 | "cell_type": "code", 250 | "execution_count": 9, 251 | "outputs": [], 252 | "source": [ 253 | "def solve(problem: np.ndarray,\n", 254 | " k=5,\n", 255 | " min_cost=1e-5,\n", 256 | " max_retry=100,\n", 257 | " verbose=False) -> np.ndarray:\n", 258 | " hist = set()\n", 259 | " retry = 0\n", 260 | "\n", 261 | " solution = generate_initial_solution(problem, k)\n", 262 | " incumbent = solution.copy()\n", 263 | " low_cost = calculate_cost(problem, solution, k)\n", 264 | " cost = 1.0\n", 265 | " while retry < max_retry and cost > min_cost:\n", 266 | " decision = generate_search_space(problem, solution, k)\n", 267 | " grp, cls = select_move(decision, solution, hist)\n", 268 | "\n", 269 | " if grp != -1:\n", 270 | " solution[grp] = cls\n", 271 | " cost = calculate_cost(problem, solution, k)\n", 272 | " if cost < low_cost:\n", 273 | " low_cost = cost\n", 274 | " incumbent = solution.copy()\n", 275 | " retry = 0\n", 276 | " if verbose:\n", 277 | " print(cost)\n", 278 | " else:\n", 279 | " retry += 1\n", 280 | " hist.add(solution_to_str(solution))\n", 281 | " return incumbent" 282 | ], 283 | "metadata": { 284 | "collapsed": false 285 | } 286 | }, 287 | { 288 | "cell_type": "code", 289 | "execution_count": 10, 290 | "metadata": {}, 291 | "outputs": [ 292 | { 293 | "name": "stdout", 294 | "output_type": "stream", 295 | "text": [ 296 | "8.028158540955434e-05\n", 297 | "6.502376139707922e-05\n", 298 | "5.464347782944837e-05\n", 299 | "4.301962138717457e-05\n", 300 | "3.6589581706198305e-05\n", 301 | "2.7849518341838307e-05\n", 302 | "2.610685042347762e-05\n", 303 | "2.3610338623673816e-05\n", 304 | "2.008385337007756e-05\n", 305 | "1.6837909344540007e-05\n", 306 | "1.498558025272575e-05\n", 307 | "1.3517929416647237e-05\n", 308 |
"1.2074721967291692e-05\n", 309 | "1.1421630300944058e-05\n", 310 | "1.0721054483775257e-05\n", 311 | "9.524765715259523e-06\n" 312 | ] 313 | }, 314 | { 315 | "data": { 316 | "text/plain": "array([4, 0, 4, 2, 2, 1, 0, 4, 2, 1, 1, 3, 3, 0, 0, 0, 3, 1, 3, 2, 4, 2,\n 4, 0, 3, 1, 4, 2, 0, 3, 4, 1, 2, 2, 3, 4, 0, 3, 0, 0, 4, 3, 4, 4,\n 0, 4, 3, 0, 3, 3, 3, 2, 2, 2, 3, 1, 3, 3, 0, 0, 4, 4, 2, 1, 2, 2,\n 4, 2, 1, 4, 3, 2, 4, 3, 3, 1, 3, 0, 1, 1, 4, 2, 3, 1, 0, 0, 1, 3,\n 2, 4, 0, 1, 2, 4, 4, 2, 4, 1, 4, 3, 3, 1, 1, 2, 3, 1, 2, 2, 2, 3,\n 1, 0, 2, 3, 2, 3, 4, 0, 2, 1, 4, 1, 0, 2, 0, 2, 1, 3, 4, 3, 4, 1,\n 3, 3, 1, 4, 2, 0, 2, 2, 0, 1, 3, 4, 2, 3, 0, 1, 2, 0, 2, 1, 4, 0,\n 4, 3, 2, 3, 4, 4, 0, 4, 0, 3, 3, 2, 1, 4, 3, 1, 1, 0, 2, 1, 2, 3,\n 2, 0, 1, 2, 0, 0, 0, 0, 0, 1, 3, 2, 3, 1, 1, 0, 3, 1, 4, 2, 0, 3,\n 1, 2, 0, 4, 4, 2, 0, 1, 2, 4, 1, 1, 1, 2, 1, 0, 4, 0, 4, 1, 3, 4,\n 3, 4, 2, 4, 3, 4, 0, 0, 0, 0, 3, 0, 4, 0, 2, 1, 0, 3, 1, 0, 1, 4,\n 3, 3, 0, 2, 3, 3, 4, 1, 1, 3, 0, 0, 1, 0, 0, 4, 4, 4, 2, 1, 2, 1,\n 2, 0, 1, 2, 1, 2, 0, 3, 3, 0, 1, 3, 0, 1, 3, 2, 2, 3, 1, 0, 4, 2,\n 2, 1, 2, 1, 4, 1, 1, 4, 2, 3, 0, 0, 1, 1, 2, 1, 2, 3, 0, 2, 4, 4,\n 2, 2, 3, 0, 3, 4, 2, 4, 2, 0, 1, 2, 1, 2, 3, 2, 3, 3, 2, 3, 2, 2,\n 1, 4, 1, 2, 2, 3, 0, 3, 4, 4, 1, 0, 0, 3, 1, 2, 4, 4, 4, 0, 4, 3,\n 4, 1, 4, 0, 0, 3, 4, 3, 0, 3, 3, 4, 3, 1, 2, 4, 4, 3, 3, 1, 3, 3,\n 3, 4, 0, 1, 0, 0, 3, 3, 2, 3, 2, 4, 2, 4, 0, 4, 1, 3, 0, 3, 4, 2,\n 2, 3, 0, 1, 0, 2, 2, 1, 1, 3, 2, 4, 3, 0, 4, 1, 0, 2, 1, 0, 1, 3,\n 4, 0, 0, 2, 2, 4, 3, 4, 2, 4, 4, 2, 3, 0, 4, 3, 1, 3, 4, 3, 0, 4,\n 1, 1, 0, 0, 1, 0, 2, 3, 4, 2, 4, 4, 4, 4, 2, 3, 4, 1, 2, 0, 2, 3,\n 1, 1, 0, 2, 0, 1, 2, 0, 1, 0, 2, 4, 2, 1, 0, 2, 4, 1, 1, 2, 4, 2,\n 0, 2, 1, 3, 4, 1, 0, 0, 1, 2, 4, 1, 3, 1, 2, 1])" 317 | }, 318 | "execution_count": 10, 319 | "metadata": {}, 320 | "output_type": "execute_result" 321 | } 322 | ], 323 | "source": [ 324 | "solution = solve(prb, min_cost=1e-5, k=5, verbose=True)\n", 325 | "solution" 326 | ] 327 | }, 328 | { 329 | 
"cell_type": "code", 330 | "execution_count": 10, 331 | "metadata": { 332 | "collapsed": false, 333 | "jupyter": { 334 | "outputs_hidden": false 335 | } 336 | }, 337 | "outputs": [], 338 | "source": [] 339 | }, 340 | { 341 | "cell_type": "code", 342 | "execution_count": 11, 343 | "outputs": [ 344 | { 345 | "data": { 346 | "text/plain": "array([0.40122698, 0.29987707, 0.19984937, 0.09904658])" 347 | }, 348 | "execution_count": 11, 349 | "metadata": {}, 350 | "output_type": "execute_result" 351 | } 352 | ], 353 | "source": [ 354 | "np.sum(prb, axis=0) / np.sum(prb)" 355 | ], 356 | "metadata": { 357 | "collapsed": false 358 | } 359 | }, 360 | { 361 | "cell_type": "code", 362 | "execution_count": 12, 363 | "outputs": [ 364 | { 365 | "data": { 366 | "text/plain": "array([[0.40213076, 0.29938555, 0.19929537, 0.09918832],\n [0.40074922, 0.30057671, 0.19912732, 0.09954675],\n [0.4010002 , 0.29856801, 0.20065612, 0.09977567],\n [0.40156828, 0.30083997, 0.19946032, 0.09813143],\n [0.40068528, 0.30002025, 0.20070553, 0.09858893]])" 367 | }, 368 | "execution_count": 12, 369 | "metadata": {}, 370 | "output_type": "execute_result" 371 | } 372 | ], 373 | "source": [ 374 | "folds = [prb[solution==i] for i in range(5)]\n", 375 | "fold_percents = np.array([np.sum(folds[i], axis=0) / np.sum(folds[i]) for i in range(5)])\n", 376 | "fold_percents" 377 | ], 378 | "metadata": { 379 | "collapsed": false 380 | } 381 | }, 382 | { 383 | "cell_type": "code", 384 | "execution_count": 13, 385 | "outputs": [ 386 | { 387 | "data": { 388 | "text/plain": "[0.20034752594623986,\n 0.19969401505608037,\n 0.20023438581796935,\n 0.19963153468673694,\n 0.20009253849297348]" 389 | }, 390 | "execution_count": 13, 391 | "metadata": {}, 392 | "output_type": "execute_result" 393 | } 394 | ], 395 | "source": [ 396 | "[np.sum(folds[i]) / np.sum(prb) for i in range(5)]" 397 | ], 398 | "metadata": { 399 | "collapsed": false 400 | } 401 | }, 402 | { 403 | "cell_type": "code", 404 | "execution_count": 13, 405 | 
"outputs": [], 406 | "source": [], 407 | "metadata": { 408 | "collapsed": false 409 | } 410 | }, 411 | { 412 | "cell_type": "code", 413 | "execution_count": 14, 414 | "outputs": [], 415 | "source": [ 416 | "# import pandas as pd" 417 | ], 418 | "metadata": { 419 | "collapsed": false 420 | } 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": 15, 425 | "outputs": [], 426 | "source": [ 427 | "# df = pd.DataFrame(data=prb, columns=['Class 0', 'Class 1', 'Class 2', 'Class 3'])\n", 428 | "# df['Solution'] = solution" 429 | ], 430 | "metadata": { 431 | "collapsed": false 432 | } 433 | }, 434 | { 435 | "cell_type": "code", 436 | "execution_count": 16, 437 | "outputs": [], 438 | "source": [ 439 | "# df" 440 | ], 441 | "metadata": { 442 | "collapsed": false 443 | } 444 | }, 445 | { 446 | "cell_type": "code", 447 | "execution_count": 17, 448 | "outputs": [], 449 | "source": [ 450 | "# df.to_clipboard(excel=True)" 451 | ], 452 | "metadata": { 453 | "collapsed": false 454 | } 455 | }, 456 | { 457 | "cell_type": "code", 458 | "execution_count": 18, 459 | "outputs": [], 460 | "source": [ 461 | "# decision_df = pd.DataFrame(data=generate_search_space(prb, solution, k=5), columns=['Fold 0', 'Fold 1', 'Fold 2', 'Fold 3', 'Fold 4'])\n", 462 | "# decision_df" 463 | ], 464 | "metadata": { 465 | "collapsed": false 466 | } 467 | }, 468 | { 469 | "cell_type": "code", 470 | "execution_count": 18, 471 | "outputs": [], 472 | "source": [], 473 | "metadata": { 474 | "collapsed": false 475 | } 476 | }, 477 | { 478 | "cell_type": "code", 479 | "execution_count": 19, 480 | "outputs": [], 481 | "source": [ 482 | "# decision_df.to_clipboard(excel=True)" 483 | ], 484 | "metadata": { 485 | "collapsed": false 486 | } 487 | }, 488 | { 489 | "cell_type": "code", 490 | "execution_count": 19, 491 | "outputs": [], 492 | "source": [], 493 | "metadata": { 494 | "collapsed": false 495 | } 496 | } 497 | ], 498 | "metadata": { 499 | "kernelspec": { 500 | "display_name": "Python 3 (ipykernel)", 
501 | "language": "python", 502 | "name": "python3" 503 | }, 504 | "language_info": { 505 | "codemirror_mode": { 506 | "name": "ipython", 507 | "version": 3 508 | }, 509 | "file_extension": ".py", 510 | "mimetype": "text/x-python", 511 | "name": "python", 512 | "nbconvert_exporter": "python", 513 | "pygments_lexer": "ipython3", 514 | "version": "3.10.4" 515 | } 516 | }, 517 | "nbformat": 4, 518 | "nbformat_minor": 4 519 | } 520 | -------------------------------------------------------------------------------- /k_fold_split.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from numpy.random import default_rng 4 | from numba import njit 5 | from typing import Set, Tuple 6 | 7 | 8 | def generate_problem(num_groups: int, 9 | num_classes: int, 10 | min_group_size: int, 11 | max_group_size: int, 12 | class_percent: np.array) -> np.ndarray: 13 | 14 | problem = np.zeros((num_groups, num_classes), dtype=int) 15 | 16 | rng = default_rng() 17 | group_sizes = rng.integers(low=min_group_size, 18 | high=max_group_size, 19 | size=num_groups) 20 | 21 | for i in range(num_groups): 22 | # Calculate the 23 | proportions = np.random.normal(class_percent, class_percent / 10) 24 | 25 | problem[i, :] = proportions * group_sizes[i] 26 | return problem 27 | 28 | 29 | @njit 30 | def calculate_cost(problem: np.ndarray, 31 | solution: np.ndarray, 32 | k: int) -> float: 33 | cost = 0.0 34 | total = np.sum(problem) 35 | class_sums = np.sum(problem, axis=0) 36 | num_classes = problem.shape[1] 37 | 38 | for i in range(k): 39 | idx = solution == i 40 | fold_sum = np.sum(problem[idx, :]) 41 | 42 | # Start by calculating the fold imbalance cost 43 | cost += (fold_sum / total - 1.0 / k) ** 2 44 | 45 | # Now calculate the cost associated with the class imbalances 46 | for j in range(num_classes): 47 | cost += (np.sum(problem[idx, j]) / fold_sum - class_sums[j] / total) ** 2 48 | return cost 49 | 50 | 51 | @njit 52 | def 
generate_search_space(problem: np.ndarray, 53 | solution: np.ndarray, 54 | k: int) -> np.ndarray: 55 | num_groups = problem.shape[0] 56 | 57 | space = np.zeros((num_groups, k)) 58 | sol = solution.copy() 59 | 60 | for i in range(num_groups): 61 | for j in range(k): 62 | if solution[i] == j: 63 | space[i,j] = np.infty 64 | else: 65 | sol[i] = j 66 | space[i, j] = calculate_cost(problem, sol, k) 67 | sol[i] = solution[i] 68 | return space 69 | 70 | 71 | @njit 72 | def solution_to_str(solution: np.ndarray) -> str: 73 | return "".join([str(n) for n in solution]) 74 | 75 | 76 | def generate_initial_solution(problem: np.ndarray, 77 | k: int, 78 | algo: str="k-bound") -> np.ndarray: 79 | num_groups = problem.shape[0] 80 | if algo == "k-bound": 81 | rng = default_rng() 82 | total = np.sum(problem) 83 | indices = rng.permutation(problem.shape[0]) 84 | 85 | solution = np.zeros(num_groups, dtype=int) 86 | c = 0 87 | fold_total = 0 88 | for i in indices: 89 | group = np.sum(problem[i, :]) 90 | if fold_total + group < total / k: 91 | fold_total += group 92 | else: 93 | c = (c + 1) % k 94 | fold_total = group 95 | solution[i] = c 96 | elif algo == "random": 97 | rng = default_rng() 98 | solution = rng.integers(low=0, high=k, size=num_groups) 99 | elif algo == "zeros": 100 | solution = np.zeros(num_groups, dtype=int) 101 | else: 102 | raise Exception("Invalid algorithm name") 103 | return solution 104 | 105 | 106 | def solve(problem: np.ndarray, 107 | k=5, 108 | min_cost=1e-5, 109 | max_retry=100, 110 | verbose=False) -> np.ndarray: 111 | hist = set() 112 | retry = 0 113 | 114 | solution = generate_initial_solution(problem, k) 115 | incumbent = solution.copy() 116 | low_cost = calculate_cost(problem, solution, k) 117 | cost = 1.0 118 | while retry < max_retry and cost > min_cost: 119 | decision = generate_search_space(problem, solution, k=5) 120 | grp, cls = select_move(decision, solution, hist) 121 | 122 | if grp != -1: 123 | solution[grp] = cls 124 | cost = 
calculate_cost(problem, solution, k=5) 125 | if cost < low_cost: 126 | low_cost = cost 127 | incumbent = solution.copy() 128 | retry = 0 129 | if verbose: 130 | print(cost) 131 | else: 132 | retry += 1 133 | hist.add(solution_to_str(solution)) 134 | return incumbent 135 | 136 | 137 | def select_move(decision: np.ndarray, 138 | solution: np.ndarray, 139 | history: Set) -> Tuple: 140 | candidates = np.argsort(decision, axis=None) 141 | 142 | for c in candidates: 143 | p = np.unravel_index(c, decision.shape) 144 | s = solution.copy() 145 | s[p[0]] = p[1] 146 | sol_str = solution_to_str(s) 147 | 148 | if sol_str not in history: 149 | return p 150 | return -1, -1 # No move found! 151 | 152 | 153 | def main(): 154 | problem = generate_problem(num_groups=500, 155 | num_classes=4, 156 | min_group_size=400, 157 | max_group_size=2000, 158 | class_percent=np.array([0.4, 0.3, 0.2, 0.1])) 159 | solution = solve(problem, k=5, verbose=True) 160 | 161 | print(np.sum(problem, axis=0) / np.sum(problem)) 162 | print() 163 | 164 | folds = [problem[solution == i] for i in range(5)] 165 | fold_percents = np.array([np.sum(folds[i], axis=0) / np.sum(folds[i]) for i in range(5)]) 166 | print(fold_percents) 167 | print() 168 | print([np.sum(folds[i]) / np.sum(problem) for i in range(5)]) 169 | 170 | 171 | if __name__ == "__main__": 172 | main() 173 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | anyio==3.6.1 2 | appnope==0.1.3 3 | argon2-cffi==21.3.0 4 | argon2-cffi-bindings==21.2.0 5 | asttokens==2.0.8 6 | attrs==22.1.0 7 | Babel==2.10.3 8 | backcall==0.2.0 9 | beautifulsoup4==4.11.1 10 | bleach==5.0.1 11 | certifi==2022.9.24 12 | cffi==1.15.1 13 | charset-normalizer==2.1.1 14 | debugpy==1.6.3 15 | decorator==5.1.1 16 | defusedxml==0.7.1 17 | entrypoints==0.4 18 | executing==1.1.0 19 | fastjsonschema==2.16.2 20 | idna==3.4 21 | ipykernel==6.16.0 22 | 
ipython==8.5.0 23 | ipython-genutils==0.2.0 24 | jedi==0.18.1 25 | Jinja2==3.1.2 26 | json5==0.9.10 27 | jsonschema==4.16.0 28 | jupyter-core==4.11.1 29 | jupyter-server==1.19.1 30 | jupyter_client==7.3.5 31 | jupyterlab==3.4.8 32 | jupyterlab-pygments==0.2.2 33 | jupyterlab_server==2.15.2 34 | llvmlite==0.39.1 35 | lxml==4.9.1 36 | MarkupSafe==2.1.1 37 | matplotlib-inline==0.1.6 38 | mistune==2.0.4 39 | nbclassic==0.4.4 40 | nbclient==0.6.8 41 | nbconvert==7.1.0 42 | nbformat==5.6.1 43 | nest-asyncio==1.5.6 44 | notebook==6.4.12 45 | notebook-shim==0.1.0 46 | numba==0.56.2 47 | numpy==1.23.3 48 | packaging==21.3 49 | pandocfilters==1.5.0 50 | parso==0.8.3 51 | pexpect==4.8.0 52 | pickleshare==0.7.5 53 | prometheus-client==0.14.1 54 | prompt-toolkit==3.0.31 55 | psutil==5.9.2 56 | ptyprocess==0.7.0 57 | pure-eval==0.2.2 58 | pycparser==2.21 59 | Pygments==2.13.0 60 | pyparsing==3.0.9 61 | pyrsistent==0.18.1 62 | python-dateutil==2.8.2 63 | pytz==2022.4 64 | pyzmq==24.0.1 65 | requests==2.28.1 66 | scipy==1.9.1 67 | Send2Trash==1.8.0 68 | six==1.16.0 69 | sniffio==1.3.0 70 | soupsieve==2.3.2.post1 71 | stack-data==0.5.1 72 | terminado==0.16.0 73 | tinycss2==1.1.1 74 | tomli==2.0.1 75 | tornado==6.2 76 | traitlets==5.4.0 77 | urllib3==1.26.12 78 | wcwidth==0.2.5 79 | webencodings==0.5.1 80 | websocket-client==1.4.1 81 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- 1 | from .history import History 2 | from .functions import get_sort_index, hash_solution 3 | from .functions import calculate_cost, calculate_costs, calculate_cost_gradients, cosine_similarity 4 | from .functions import get_similarities 5 | -------------------------------------------------------------------------------- /tools/functions.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from numba import jit 
4 | 5 | 6 | @jit(nopython=True) 7 | def index_to_str(idx): 8 | """ 9 | Generates a string representation from an index array. 10 | 11 | :param idx: The NumPy boolean index array. 12 | :return: The string representation of the array. 13 | """ 14 | num_chars = int(idx.shape[0] / 6 + 0.5) 15 | s = "" 16 | for i in range(num_chars): 17 | b = i * 6 18 | six = idx[b:b+6] 19 | c = 0 20 | for j in range(six.shape[0]): 21 | c = c * 2 + int(six[j]) 22 | s = s + chr(c+32) 23 | return s 24 | 25 | 26 | @jit(nopython=True) 27 | def hash_solution(solution: np.ndarray, k: int) -> str: 28 | """ 29 | Calculates a string hash for the solution. 30 | :param solution: The solution vector. 31 | :param k: The number of folds. 32 | :return: The string hash. 33 | """ 34 | s = "" 35 | for i in range(k-1): 36 | s = s + index_to_str(solution == i) 37 | return s 38 | 39 | 40 | @jit(nopython=True) 41 | def is_in(element, test_elements): 42 | """ 43 | Predicate to test the inclusion of items in the first array on the second 44 | 45 | :param element: Array whose elements we want to test the inclusion for 46 | :param test_elements: Target array 47 | :return: Boolean array of the same size as `element` with the element-wise inclusion test results 48 | """ 49 | unique = set(test_elements) 50 | result = np.zeros_like(element, dtype=np.bool_) 51 | for i in range(element.shape[0]): 52 | result[i] = element[i] in unique 53 | return result 54 | 55 | 56 | @jit(nopython=True) 57 | def get_sort_index(solution: np.ndarray, cs: np.ndarray, k: int, arr_ix: np.ndarray) -> np.ndarray: 58 | """ 59 | 60 | :param solution: The solution vector. 61 | :param cs: The cosine similarity matrix. 62 | :param k: The selected fold index [0..K) 63 | :param arr_ix: A pre-calculated integer range [0..N). 
64 | :return: 65 | """ 66 | sort_ix = np.zeros((0,), dtype=np.int_) 67 | solution_indices_for_k = solution == k 68 | n = solution_indices_for_k.sum() 69 | 70 | # Check if there are any indexes for fold k 71 | if n > 0: 72 | # Get the descending sort indices for the similarities of fold k. 73 | # Lower similarities mean larger differences. 74 | sort_ix = np.flip(np.argsort(cs[k])) 75 | # Filter the solution indices that belong to fold k. 76 | sort_ix = sort_ix[is_in(sort_ix, arr_ix[solution_indices_for_k])] 77 | return sort_ix 78 | 79 | 80 | @jit(nopython=True) 81 | def cosine_similarity(problem: np.ndarray, cost_grad: np.ndarray) -> np.ndarray: 82 | """ 83 | Calculates the cosine similarity vector between the problem array 84 | and the cost gradient vector. 85 | :param problem: The problem array. 86 | :param cost_grad: The cost gradient matrix. 87 | :return: The cosine similarity vector. 88 | """ 89 | k = cost_grad.shape[0] 90 | s = np.zeros((k, problem.shape[0])) 91 | c = problem 92 | norm_c = np.zeros(problem.shape[0]) 93 | for i in range(problem.shape[0]): 94 | norm_c[i] = np.linalg.norm(c[i]) 95 | for i in range(k): 96 | g = cost_grad[i] 97 | a = np.dot(c, g) 98 | b = np.multiply(norm_c, np.linalg.norm(g)) 99 | s[i, :] = np.divide(a, b) 100 | return s 101 | 102 | 103 | @jit(nopython=True) 104 | def vector_similarity(v0: np.ndarray, v1: np.ndarray) -> float: 105 | """ 106 | Calculates the cosine similarity between two vectors. 107 | 108 | :param v0: Vector 109 | :param v1: Vector 110 | :return: Similarity scalar. 111 | """ 112 | a = np.dot(v0, v1) 113 | b = np.linalg.norm(v0) * np.linalg.norm(v1) 114 | return a / b * np.linalg.norm(v0 - v1) 115 | 116 | 117 | @jit(nopython=True) 118 | def get_lowest_similarity(cost_grads): 119 | """ 120 | Calculates the fold index pair with the lowest cost gradient 121 | cosine similarity. 122 | :param cost_grads: K-dimensional cost gradient vector. 123 | :return: Fold index pair with the lowest cosine similarity. 
124 | """ 125 | sim = 1.0 126 | p = (-1, -1) 127 | n = cost_grads.shape[0] 128 | for i in range(n): 129 | for j in range(i + 1, n): 130 | s = vector_similarity(cost_grads[i], cost_grads[j]) 131 | if s < sim: 132 | sim = s 133 | p = (i, j) 134 | return p 135 | 136 | 137 | @jit(nopython=True) 138 | def get_similarities(cost_grads): 139 | """ 140 | Calculates the similarity array between all pairs of cost gradients 141 | :param cost_grads: The cost gradient matrix 142 | :return: The sorted similarity array (K,3) containing rows of 143 | [similarity, i, j] with i != j 144 | """ 145 | n = cost_grads.shape[0] 146 | k_count = int(n * (n - 1) / 2) 147 | sims = np.zeros((k_count, 3)) 148 | k = 0 149 | for i in range(n - 1): 150 | for j in range(i + 1, n): 151 | s = vector_similarity(cost_grads[i], cost_grads[j]) 152 | sims[k, 0] = s 153 | sims[k, 1] = i 154 | sims[k, 2] = j 155 | k += 1 156 | return sims[sims[:, 0].argsort()] 157 | 158 | 159 | @jit(nopython=True) 160 | def calculate_costs(problem: np.ndarray, solution: np.ndarray, k: int) -> np.ndarray: 161 | """ 162 | Calculates the cost vector for the given solution. 163 | 164 | :param problem: The problem matrix. 165 | :param solution: The solution vector. 166 | :param k: The number of folds. 167 | :return: The K-dimensional cost vector. 168 | """ 169 | c = problem.shape[1] 170 | costs = np.zeros(k) 171 | total_count = problem[:, 0].sum() 172 | 173 | for i in range(k): 174 | index = solution == i 175 | costs[i] = 0.5 * (problem[index, 0].sum() - total_count / k) ** 2 176 | stratum_sum = problem[index, 0].sum() 177 | for j in range(1, c): 178 | r = problem[:, j].sum() / total_count 179 | costs[i] += 0.5 * (problem[index, j].sum() - r * stratum_sum) ** 2 180 | return costs 181 | 182 | 183 | @jit(nopython=True) 184 | def calculate_cost(problem: np.ndarray, solution: np.ndarray, k: int) -> float: 185 | """ 186 | Calculates the overall cost as the L2 norm of the cost vector. 187 | 188 | :param problem: The problem matrix. 
189 | :param solution: The solution vector. 190 | :param k: The number of folds. 191 | :return: The scalar cost. 192 | """ 193 | return np.linalg.norm(calculate_costs(problem, solution, k)) 194 | 195 | 196 | @jit(nopython=True) 197 | def calculate_cost_gradients(problem: np.ndarray, solution: np.ndarray, k: int) -> np.ndarray: 198 | """ 199 | Computes the K cost gradients. 200 | 201 | :param problem: The problem matrix. 202 | :param solution: The solution vector. 203 | :param k: The number of folds. 204 | :return: The (K,C) gradient matrix. 205 | """ 206 | c = problem.shape[1] 207 | gradients = np.zeros((k, c)) 208 | total_count = problem[:, 0].sum() 209 | 210 | for i in range(k): 211 | index = solution == i 212 | gradients[i, 0] = problem[index, 0].sum() - total_count / k 213 | stratum_sum = problem[index, 0].sum() 214 | for j in range(1, c): 215 | r = problem[:, j].sum() / total_count 216 | gradients[i, j] = problem[index, j].sum() - r * stratum_sum 217 | return gradients 218 | -------------------------------------------------------------------------------- /tools/history.py: -------------------------------------------------------------------------------- 1 | 2 | class History(object): 3 | 4 | def __init__(self, intensity=2): 5 | self.hist = [] 6 | self.read_count = 0 7 | self.intensity = intensity 8 | 9 | def __len__(self): 10 | return len(self.hist) 11 | 12 | def add(self, solution): 13 | self.hist.append(solution) 14 | 15 | def get(self): 16 | self.read_count += 1 17 | if self.read_count % self.intensity != 0: 18 | res = self.hist[-1] 19 | self.hist = self.hist[:-1] 20 | else: 21 | res = self.hist[0] 22 | self.hist = self.hist[1:] 23 | return res 24 | --------------------------------------------------------------------------------