├── LICENSE.md
├── README.md
├── assets
│   └── anbncn.png
├── configuration.py
├── corpora.py
├── dfa.py
├── genetic_algorithm.py
├── island.py
├── main.py
├── manual_nets.py
├── mdlrnn_torch.py
├── network.py
├── requirements.txt
├── simulations.py
├── stats
│   ├── tacl_stats.json
│   └── tacl_stats.py
├── tests
│   ├── test_corpus.py
│   ├── test_genetic_algorithm.py
│   └── test_network.py
├── torch_conversion.py
├── utils.py
└── vanilla_rnn.py

/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2023 Nur Lan, Emmanuel Chemla, Roni Katzir
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Minimum Description Length Recurrent Neural Networks
2 | 
3 | ![python](https://img.shields.io/badge/python-3.7_|_3.8_|_3.9-blue)
4 | ![license](https://img.shields.io/badge/license-MIT-green)
5 | ![code style](https://img.shields.io/badge/code_style-Black-black)
6 | [![arXiv](https://img.shields.io/badge/arXiv-2111.00600-b31b1b.svg)](https://arxiv.org/abs/2111.00600)
7 | 
8 | Code for [Minimum Description Length Recurrent Neural Networks](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00489/112499/Minimum-Description-Length-Recurrent-Neural) by Nur Lan, Michal Geyer, Emmanuel Chemla, and Roni Katzir.
9 | 
10 | Paper: https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00489/
11 | 
12 | 
13 | 
14 | ## Getting started
15 | 1. Install Python >= 3.7
16 | 2. `pip install -r requirements.txt`
17 | 
18 | ### On Ubuntu, install:
19 | ```
20 | $ apt-get install libsm6 libxext6 libxrender1 libffi-dev libopenmpi-dev
21 | ```
22 | 
23 | ## Running simulations
24 | 
25 | ```
26 | $ python main.py --simulation <simulation_name> -n <number_of_islands>
27 | ```
28 | 
29 | For example, to run the `aⁿbⁿcⁿ` task using 16 island processes:
30 | ```
31 | $ python main.py --simulation an_bn_cn -n 16
32 | ```
33 | 
34 | * All simulations are available in `simulations.py`
35 | 
36 | * Final and intermediate solutions are saved to the `networks` sub-directory, both as `pickle` and in visual `dot` format.
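A saved solution can also be inspected directly. The sketch below is illustrative (it is not part of the original docs): it loads a pickled network and re-computes its MDL score on a freshly generated `aⁿbⁿcⁿ` corpus using `corpora.make_ain_bjn_ckn_dtn`, `corpora.optimize_for_feeding`, and `network.calculate_fitness` from this repository, together with the module-level test configuration `corpora._DEFAULT_CONFIG`; the pickle path, `batch_size`, and `prior` values are placeholders.

```
import pickle

import corpora
import network

# Load a solution saved by the genetic algorithm.
with open("networks/net.pickle", "rb") as f:
    net = pickle.load(f)

# a^n b^n c^n corresponds to multipliers (i, j, k, t) = (1, 1, 1, 0).
corpus = corpora.optimize_for_feeding(
    corpora.make_ain_bjn_ckn_dtn(batch_size=100, prior=0.3, multipliers=(1, 1, 1, 0))
)

# Attach a fitness object (grammar + data encoding lengths) and print the MDL score.
net = network.calculate_fitness(net, corpus, config=corpora._DEFAULT_CONFIG)
print(net.fitness.mdl)
```

`net.fitness.data_encoding_length` exposes the data-encoding term on its own, which is useful for comparing against a corpus's `optimal_d_given_g`.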
37 | 
38 | 
39 | ## PyTorch conversion
40 | 
41 | Converting a network trained using the genetic algorithm to a PyTorch module:
42 | 
43 | ```
44 | import pickle
45 | import torch_conversion
46 | 
47 | with open("networks/net.pickle", "rb") as f:
48 |     net = pickle.load(f)
49 | 
50 | torch_net = torch_conversion.mdlnn_to_torch(net)
51 | ```
52 | 
53 | Then fine-tune and evaluate using [MDLRNN-torch](https://github.com/0xnurl/mdlrnn-torch).
54 | 
55 | ## Parallelization
56 | 
57 | Native Python multiprocessing is used by default. To use MPI, change `migration_channel` to `mpi` in `simulations.py`.
58 | 
59 | ## Citing this work
60 | 
61 | ```
62 | @article{Lan-Geyer-Chemla-Katzir-MDLRNN-2022,
63 |   title = {Minimum Description Length Recurrent Neural Networks},
64 |   author = {Lan, Nur and Geyer, Michal and Chemla, Emmanuel and Katzir, Roni},
65 |   year = {2022},
66 |   month = jul,
67 |   journal = {Transactions of the Association for Computational Linguistics},
68 |   volume = {10},
69 |   pages = {785--799},
70 |   issn = {2307-387X},
71 |   doi = {10.1162/tacl_a_00489},
72 | }
73 | ```
74 | 
--------------------------------------------------------------------------------
/assets/anbncn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taucompling/mdlrnn/4161677520f9200e8263fd5016f1843455991814/assets/anbncn.png
--------------------------------------------------------------------------------
/configuration.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import Optional, Text, Tuple
3 | 
4 | 
5 | @dataclasses.dataclass(frozen=True)
6 | class SimulationConfig:
7 |     simulation_id: Text
8 |     num_islands: int
9 |     migration_interval_seconds: int
10 |     migration_interval_generations: int
11 |     migration_ratio: float
12 | 
13 |     num_generations: int
14 |     population_size: int
15 |     elite_ratio: float
16 |     mutation_probab: float
17 |     allowed_activations: Tuple[int, ...]
18 |     allowed_unit_types: Tuple[int, ...]
19 | start_smooth: bool 20 | 21 | max_network_units: int 22 | tournament_size: int 23 | 24 | grammar_multiplier: int 25 | data_given_grammar_multiplier: int 26 | 27 | compress_grammar_encoding: bool 28 | softmax_outputs: bool 29 | truncate_large_values: bool 30 | bias_connections: bool 31 | recurrent_connections: bool 32 | generation_dump_interval: int 33 | 34 | seed: int 35 | corpus_seed: int 36 | 37 | mini_batch_size: Optional[int] = None 38 | resumed_from_simulation_id: Optional[Text] = None 39 | comment: Optional[Text] = None 40 | parallelize: bool = True 41 | migration_channel: Text = "file" # {'file', 'mpi'} 42 | -------------------------------------------------------------------------------- /corpora.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import csv 3 | import dataclasses 4 | import itertools 5 | import logging 6 | import math 7 | import random 8 | from typing import Dict, FrozenSet, List, Optional, Text, Tuple, Union 9 | 10 | import configuration 11 | import dfa 12 | import numpy as np 13 | import utils 14 | 15 | _DEFAULT_CONFIG = configuration.SimulationConfig( 16 | simulation_id="test", 17 | num_islands=1, 18 | migration_ratio=0.1, 19 | migration_interval_seconds=20, 20 | migration_interval_generations=1000, 21 | num_generations=1000, 22 | population_size=20, 23 | elite_ratio=0.05, 24 | allowed_activations=(0, 1, 2, 3, 4, 5, 6,), 25 | start_smooth=False, 26 | allowed_unit_types=(0, 1,), 27 | tournament_size=4, 28 | mutation_probab=1.0, 29 | grammar_multiplier=1, 30 | data_given_grammar_multiplier=1, 31 | compress_grammar_encoding=False, 32 | max_network_units=1024, 33 | softmax_outputs=False, 34 | truncate_large_values=True, 35 | bias_connections=True, 36 | recurrent_connections=True, 37 | seed=1, 38 | corpus_seed=100, 39 | generation_dump_interval=1, 40 | parallelize=False, 41 | migration_channel="file", 42 | mini_batch_size=None, 43 | ) 44 | 45 | 46 | MASK_VALUE = np.nan 47 | 48 | _Vocabulary = Dict[int, Text] 49 | 50 | 51 | def is_masked(x: Union[np.ndarray, float]) -> Union[np.ndarray, bool]: 52 | return np.isnan(x) 53 | 54 | 55 | @dataclasses.dataclass(frozen=True) 56 | class Corpus: 57 | name: Text 58 | input_sequence: np.ndarray 59 | target_sequence: np.ndarray 60 | 61 | optimal_d_given_g: Optional[float] = None 62 | vocabulary: Optional[_Vocabulary] = None 63 | deterministic_steps_mask: Optional[np.ndarray] = None 64 | 65 | # Precomputed values for feeding efficiency. 
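    # input_mask, targets_mask and input_values_per_time_step are left as None on a freshly
    # built Corpus and are filled in by optimize_for_feeding() further down in this module.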
66 | input_mask: Optional[np.ndarray] = None 67 | targets_mask: Optional[np.ndarray] = None 68 | input_values_per_time_step: Optional[Dict[int, List[np.ndarray]]] = None 69 | sample_weights: Optional[Tuple[int, ...]] = None 70 | 71 | test_corpus: Optional["Corpus"] = None 72 | 73 | 74 | def precompute_mask_idxs(corpus: Corpus) -> Corpus: 75 | masked = is_masked(corpus.input_sequence) 76 | input_mask = np.array( 77 | [ 78 | ~np.all(masked[i, j]) 79 | for (i, j) in np.ndindex(corpus.input_sequence.shape[:2]) 80 | ], 81 | dtype=np.bool, 82 | ).reshape(corpus.input_sequence.shape[:2]) 83 | return dataclasses.replace(corpus, input_mask=input_mask) 84 | 85 | 86 | def _precompute_input_unit_values(corpus: Corpus) -> Corpus: 87 | unit_to_timestep_val = {} 88 | for unit in range(corpus.input_sequence.shape[-1]): 89 | unit_to_timestep_val[unit] = [] 90 | for time_step in range(corpus.input_sequence.shape[1]): 91 | time_step_input = np.ascontiguousarray( 92 | corpus.input_sequence[:, time_step, unit] 93 | ) 94 | time_step_input.flags.writeable = False 95 | unit_to_timestep_val[unit].append(time_step_input) 96 | unit_to_timestep_val[unit] = tuple(unit_to_timestep_val[unit]) 97 | return dataclasses.replace(corpus, input_values_per_time_step=unit_to_timestep_val) 98 | 99 | 100 | def _precompute_targets_mask(corpus: Corpus) -> Corpus: 101 | if corpus.target_sequence.shape[-1] == 1: 102 | targets_mask = corpus.target_sequence == 1 103 | else: 104 | targets_mask = np.zeros_like(corpus.target_sequence, dtype=np.bool) 105 | target_classes = corpus.target_sequence.argmax(axis=-1).flatten() 106 | batch_idxs, time_idxs = tuple( 107 | zip(*np.ndindex(corpus.target_sequence.shape[:2])) 108 | ) 109 | targets_mask[batch_idxs, time_idxs, target_classes] = True 110 | return dataclasses.replace(corpus, targets_mask=targets_mask) 111 | 112 | 113 | def _make_inputs_read_only(corpus: Corpus) -> Corpus: 114 | corpus.input_sequence.flags.writeable = False 115 | return corpus 116 | 117 | 118 | def optimize_for_feeding(corpus: Corpus) -> Corpus: 119 | logging.info(f"Optimizing corpus for feeding...") 120 | corpus = _make_inputs_read_only(corpus) 121 | corpus = _precompute_targets_mask(corpus) 122 | corpus = precompute_mask_idxs(corpus) 123 | corpus = _precompute_input_unit_values(corpus) 124 | return corpus 125 | 126 | 127 | def make_random_binary(sequence_length: int = 100, batch_size: int = 1,) -> Corpus: 128 | return Corpus( 129 | "random_binary", 130 | input_sequence=np.random.randint(0, 2, size=(batch_size, sequence_length, 1)), 131 | target_sequence=np.random.randint(0, 2, size=(batch_size, sequence_length, 1)), 132 | ) 133 | 134 | 135 | def make_random_one_hot( 136 | num_input_classes: int, 137 | num_target_classes: int, 138 | sequence_length: int = 100, 139 | batch_size: int = 1, 140 | ) -> Corpus: 141 | input_classes = np.random.randint( 142 | 0, num_input_classes, size=(batch_size, sequence_length) 143 | ) 144 | target_classes = np.random.randint( 145 | 0, num_target_classes, size=(batch_size, sequence_length) 146 | ) 147 | return make_one_hot_corpus( 148 | "random_one_hot", 149 | input_classes=input_classes, 150 | target_classes=target_classes, 151 | num_input_classes=num_input_classes, 152 | num_target_classes=num_target_classes, 153 | ) 154 | 155 | 156 | def make_one_hot_corpus( 157 | name: Text, 158 | input_classes: Union[List, np.ndarray], 159 | target_classes: Union[List, np.ndarray], 160 | num_input_classes: int, 161 | num_target_classes: int, 162 | weights: Optional[Tuple[int, ...]] = None, 163 | 
vocabulary: Optional[_Vocabulary] = None, 164 | ) -> Corpus: 165 | return Corpus( 166 | name, 167 | input_sequence=_make_one_hot_sequence( 168 | np.array(input_classes), num_input_classes 169 | ), 170 | target_sequence=_make_one_hot_sequence( 171 | np.array(target_classes), num_target_classes 172 | ), 173 | sample_weights=weights, 174 | vocabulary=vocabulary, 175 | ) 176 | 177 | 178 | def _force_batch_dimension(arr: np.ndarray) -> np.ndarray: 179 | if arr.ndim == 1: 180 | return np.expand_dims(arr, axis=0) 181 | return arr 182 | 183 | 184 | def _make_one_hot_sequence(classes: np.ndarray, num_classes: int) -> np.ndarray: 185 | classes = _force_batch_dimension(classes) 186 | batch_size = classes.shape[0] 187 | sequence_length = classes.shape[1] 188 | 189 | one_hot = np.zeros( 190 | (batch_size, sequence_length, num_classes), dtype=utils.FLOAT_DTYPE, order="C" 191 | ) 192 | 193 | for b in range(batch_size): 194 | for s in range(sequence_length): 195 | c = classes[b, s] 196 | if is_masked(c): 197 | one_hot[b, s] = MASK_VALUE 198 | else: 199 | one_hot[b, s, int(c)] = 1.0 200 | return one_hot 201 | 202 | 203 | def make_between_dfa(start: int, end: int) -> dfa.DFA: 204 | final_state = end + 1 205 | transitions = {} 206 | for i in range(start): 207 | transitions[i] = {"0": i, "1": i + 1} 208 | for i in range(start, end): 209 | transitions[i] = {"0": i, "1": i + 1, dfa.END_OF_SEQUENCE: final_state} 210 | transitions[end] = {"0": end, dfa.END_OF_SEQUENCE: final_state} 211 | return dfa.DFA(transitions=transitions, accepting_states={final_state}) 212 | 213 | 214 | def make_at_least_dfa(n: int) -> dfa.DFA: 215 | transitions = {} 216 | for i in range(n): 217 | transitions[i] = {"0": i, "1": i + 1} 218 | transitions[n] = {"0": n, "1": n, dfa.END_OF_SEQUENCE: n + 1} 219 | return dfa.DFA(transitions=transitions, accepting_states={n + 1}) 220 | 221 | 222 | def _make_at_most_dfa(n: int) -> dfa.DFA: 223 | transitions = {} 224 | accepting_state = n + 1 225 | for i in range(n): 226 | transitions[i] = {"0": i, "1": i + 1, dfa.END_OF_SEQUENCE: accepting_state} 227 | transitions[n] = {"0": n, dfa.END_OF_SEQUENCE: accepting_state} 228 | return dfa.DFA(transitions=transitions, accepting_states={accepting_state}) 229 | 230 | 231 | def _dfa_to_inputs( 232 | dfa_: dfa.DFA, batch_size: int, end_of_sequence_char: int, max_sequence_length: int, 233 | ) -> np.ndarray: 234 | batch = np.empty((batch_size, max_sequence_length)) 235 | batch.fill(MASK_VALUE) 236 | 237 | for b in range(batch_size): 238 | input_idx = 0 239 | while True: 240 | string = dfa_.generate_string() 241 | if len(string) > max_sequence_length: 242 | continue 243 | if input_idx + len(string) > max_sequence_length: 244 | break 245 | for i, char in enumerate(string): 246 | if char == dfa.END_OF_SEQUENCE: 247 | char_int = end_of_sequence_char 248 | else: 249 | char_int = int(char) 250 | batch[b, input_idx + i] = char_int 251 | input_idx += len(string) 252 | 253 | return batch 254 | 255 | 256 | def make_identity( 257 | sequence_length: int = 1000, batch_size: int = 10, num_classes: int = 2 258 | ) -> Corpus: 259 | sequence = np.random.randint( 260 | num_classes, size=(batch_size, sequence_length) 261 | ).astype(utils.FLOAT_DTYPE) 262 | return make_one_hot_corpus("identity", sequence, sequence, num_classes, num_classes) 263 | 264 | 265 | def make_identity_binary(sequence_length: int, batch_size: int) -> Corpus: 266 | sequence = np.random.randint(2, size=(batch_size, sequence_length, 1)).astype( 267 | utils.FLOAT_DTYPE 268 | ) 269 | input_sequence = 
np.copy(sequence) 270 | input_sequence[sequence == 0] = -1 271 | return Corpus("identity", input_sequence, sequence) 272 | 273 | 274 | def _make_repetition_sequence( 275 | input_sequence: np.ndarray, offset: int, padding: float = 0.0 276 | ) -> np.ndarray: 277 | """[a,b,c,d, ..., y,z] -> [,a,b,c, ..., y] """ 278 | assert input_sequence.ndim == 2 279 | batch_size = input_sequence.shape[0] 280 | padded_arr = np.empty((batch_size, offset), dtype=utils.FLOAT_DTYPE) 281 | padded_arr.fill(padding) 282 | return np.concatenate((padded_arr, input_sequence[:, :-offset],), axis=-1,) 283 | 284 | 285 | def _make_prediction_sequence( 286 | input_sequence: np.ndarray, lookahead: int = 1, padding: float = 0.0 287 | ) -> np.ndarray: 288 | """[a,b,c,d, ..., y,z] -> [b,c,d,e, ..., z,] """ 289 | input_sequence = _force_batch_dimension(input_sequence) 290 | assert input_sequence.ndim == 2 291 | batch_size = input_sequence.shape[0] 292 | padded_arr = np.empty((batch_size, lookahead), dtype=utils.FLOAT_DTYPE) 293 | padded_arr.fill(padding) 294 | return np.concatenate((input_sequence[:, lookahead:], padded_arr), axis=-1) 295 | 296 | 297 | def make_prev_char_repetition( 298 | sequence_length: int = 1000, 299 | batch_size: int = 10, 300 | repetition_offset: int = 1, 301 | num_classes: int = 2, 302 | ) -> Corpus: 303 | input_sequence = np.random.randint(num_classes, size=(batch_size, sequence_length)) 304 | target_sequence = _make_repetition_sequence(input_sequence, repetition_offset) 305 | 306 | return make_one_hot_corpus( 307 | f"repeat_prev_{repetition_offset}_char", 308 | input_sequence, 309 | target_sequence, 310 | num_classes, 311 | num_classes, 312 | ) 313 | 314 | 315 | def make_prev_char_repetition_binary( 316 | sequence_length: int, batch_size: int, repetition_offset: int, 317 | ) -> Corpus: 318 | input_sequence = np.random.randint(2, size=(batch_size, sequence_length)).astype( 319 | utils.FLOAT_DTYPE 320 | ) 321 | target_sequence = _make_repetition_sequence(input_sequence, repetition_offset) 322 | return Corpus( 323 | f"repeat_prev_{repetition_offset}_char_binary", 324 | np.expand_dims(input_sequence, -1), 325 | np.expand_dims(target_sequence, -1), 326 | ) 327 | 328 | 329 | def make_elman_xor_binary(sequence_length: int = 3000, batch_size: int = 1) -> Corpus: 330 | assert sequence_length % 3 == 0 331 | 332 | input_batch = [] 333 | target_batch = [] 334 | for b in range(batch_size): 335 | sequence = [] 336 | for pair_idx in range(sequence_length // 3): 337 | a, b = random.choice([0, 1]), random.choice([0, 1]) 338 | sequence += [a, b, a ^ b] 339 | input_batch.append(sequence) 340 | # Target output is the next character of input 341 | target_sequence = sequence[1:] + [0] 342 | target_batch.append(target_sequence) 343 | input_batch = np.expand_dims( 344 | np.array(input_batch, dtype=utils.FLOAT_DTYPE), axis=-1 345 | ) 346 | target_batch = np.expand_dims( 347 | np.array(target_batch, dtype=utils.FLOAT_DTYPE), axis=-1 348 | ) 349 | return Corpus("elman_xor_binary", input_batch, target_batch) 350 | 351 | 352 | def make_elman_xor_one_hot(sequence_length: int = 3000, batch_size: int = 1) -> Corpus: 353 | binary_corpus = make_elman_xor_binary(sequence_length, batch_size) 354 | 355 | return make_one_hot_corpus( 356 | "elman_xor_one_hot", 357 | binary_corpus.input_sequence, 358 | binary_corpus.target_sequence, 359 | num_input_classes=2, 360 | num_target_classes=2, 361 | ) 362 | 363 | 364 | def make_semi_random_corpus(sequence_length: int = 100, batch_size: int = 10) -> Corpus: 365 | """ One random bit, one identical 
bit, e.g.: [0,0,0,0,1,1,0,0,1,1,0,0, ...] """ 366 | assert sequence_length % 2 == 0 367 | input_batch = [] 368 | target_batch = [] 369 | for _ in range(batch_size): 370 | sequence = [] 371 | for s in range(sequence_length // 2): 372 | sequence += [random.randrange(2)] * 2 373 | input_batch.append(sequence) 374 | target_sequence = sequence[1:] + [0] 375 | target_batch.append(target_sequence) 376 | 377 | input_batch = np.expand_dims(np.array(input_batch), axis=-1) 378 | target_batch = np.expand_dims(np.array(target_batch), axis=-1) 379 | return Corpus("semi_random_pairs", input_batch, target_batch) 380 | 381 | 382 | def make_elman_badigu(num_consonants: int = 1000) -> Corpus: 383 | # ba, dii, guuu 384 | feature_table = { 385 | # cons, vowel, int, high, back, voiced 386 | "b": [1, 0, 1, 0, 0, 1], # b 387 | "d": [1, 0, 1, 1, 0, 1], # d 388 | "g": [1, 0, 1, 0, 1, 1], # g 389 | "a": [0, 1, 0, 0, 1, 1], # a 390 | "i": [0, 1, 0, 1, 0, 1], # i 391 | "u": [0, 1, 0, 1, 1, 1], # u 392 | } 393 | segments = list("bdgaiu") 394 | num_classes = len(segments) 395 | segment_to_idx = {x: i for i, x in enumerate(segments)} 396 | 397 | consonant_to_sequence = {"b": list("ba"), "d": list("dii"), "g": list("guuu")} 398 | consonant_sequence = np.random.choice(["b", "d", "g"], size=num_consonants) 399 | 400 | letters_sequence = list( 401 | itertools.chain(*(consonant_to_sequence[c] for c in consonant_sequence)) 402 | ) 403 | input_sequence = [segment_to_idx[x] for x in letters_sequence] 404 | target_sequence = input_sequence[1:] + [0] 405 | 406 | logging.info(f"Elman badigu sequence: {letters_sequence}") 407 | consonants = tuple("bdg") 408 | 409 | consonant_percentage = len([x for x in letters_sequence if x in consonants]) / len( 410 | letters_sequence 411 | ) 412 | 413 | logging.info(f"Max accuracy for task: {1-consonant_percentage:.2f}") 414 | 415 | return make_one_hot_corpus( 416 | "elman_badigu", input_sequence, target_sequence, num_classes, num_classes 417 | ) 418 | 419 | 420 | def _make_0_1_pattern_binary(sequence_length: int, batch_size: int) -> Corpus: 421 | assert sequence_length % 2 == 0 422 | 423 | input_seq = np.array([[0, 1] * (sequence_length // 2)], dtype=utils.FLOAT_DTYPE) 424 | target_seq = _make_prediction_sequence(input_seq, lookahead=1, padding=0.0) 425 | 426 | input_seq = np.expand_dims(input_seq, axis=2) 427 | target_seq = np.expand_dims(target_seq, axis=2) 428 | 429 | return Corpus( 430 | name=f"0_1_pattern_binary_length_{sequence_length}_batch_{batch_size}", 431 | input_sequence=input_seq, 432 | target_sequence=target_seq, 433 | optimal_d_given_g=0.0, 434 | vocabulary={0: "1", 1: "1"}, 435 | sample_weights=(batch_size,) if batch_size > 1 else None, 436 | ) 437 | 438 | 439 | def make_0_1_pattern_binary(sequence_length: int, batch_size: int) -> Corpus: 440 | train_corpus = _make_0_1_pattern_binary(sequence_length, batch_size) 441 | test_corpus = _make_0_1_pattern_binary( 442 | sequence_length=sequence_length * 50_000, batch_size=1 443 | ) 444 | return dataclasses.replace(train_corpus, test_corpus=test_corpus) 445 | 446 | 447 | def _make_0_1_pattern_one_hot( 448 | sequence_length: int, add_end_of_sequence: bool, batch_size: int 449 | ) -> Corpus: 450 | assert sequence_length % 2 == 0 451 | 452 | input_classes = [0, 1] * (sequence_length // 2) 453 | num_classes = 2 454 | vocabulary = {0: "0", 1: "1"} 455 | 456 | if add_end_of_sequence: 457 | num_classes = 3 458 | input_classes += [2] 459 | vocabulary[2] = "#" 460 | 461 | vocabulary.update({x + len(vocabulary): vocabulary[x] for x in vocabulary}) 
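    # The update above mirrors _make_identical_input_output_vocabulary(): every symbol is
    # duplicated under an index offset by the vocabulary size, presumably so that output-class
    # indices resolve to the same symbols as the corresponding input classes.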
462 | 463 | input_classes_arr = np.array([input_classes]) 464 | target_classes_arr = _make_prediction_sequence( 465 | input_classes_arr, lookahead=1, padding=0.0 466 | ) 467 | 468 | corpus = make_one_hot_corpus( 469 | name=f"0_1_pattern_one_hot_length_{sequence_length}_batch_{batch_size}{'_eos' if add_end_of_sequence else ''}", 470 | input_classes=input_classes_arr, 471 | target_classes=target_classes_arr, 472 | num_input_classes=num_classes, 473 | num_target_classes=num_classes, 474 | weights=(batch_size,) if batch_size > 1 else None, 475 | vocabulary=vocabulary, 476 | ) 477 | return dataclasses.replace( 478 | # TODO: calculate optimal D 479 | corpus, 480 | optimal_d_given_g=0.0, 481 | ) 482 | 483 | 484 | def make_0_1_pattern_one_hot( 485 | sequence_length: int, add_end_of_sequence: bool, batch_size: int 486 | ) -> Corpus: 487 | train_corpus = _make_0_1_pattern_one_hot( 488 | sequence_length, add_end_of_sequence, batch_size 489 | ) 490 | test_corpus = _make_0_1_pattern_one_hot( 491 | sequence_length=sequence_length * 50_000, 492 | add_end_of_sequence=add_end_of_sequence, 493 | batch_size=1, 494 | ) 495 | return dataclasses.replace(train_corpus, test_corpus=test_corpus) 496 | 497 | 498 | def make_123_n_pattern_corpus( 499 | base_sequence_length: int = 3, sequence_length: int = 100 500 | ): 501 | # [0,1,2, ..., n-1] repeated 502 | assert sequence_length % base_sequence_length == 0 503 | input_sequence = np.array( 504 | list(range(base_sequence_length)) * (sequence_length // base_sequence_length) 505 | ) 506 | target_sequence = _make_prediction_sequence(input_sequence, lookahead=1) 507 | return make_one_hot_corpus( 508 | f"1_to_{base_sequence_length}_pattern", 509 | input_sequence, 510 | target_sequence, 511 | num_input_classes=base_sequence_length, 512 | num_target_classes=base_sequence_length, 513 | ) 514 | 515 | 516 | def make_between_quantifier( 517 | start: int, end: int, sequence_length: int = 100, batch_size: int = 1 518 | ) -> Corpus: 519 | assert end <= sequence_length 520 | between_dfa = make_between_dfa(start, end) 521 | between_dfa.visualize(f"between_{start}_{end}_dfa") 522 | input_batch = _dfa_to_inputs( 523 | between_dfa, 524 | batch_size=batch_size, 525 | end_of_sequence_char=2, 526 | max_sequence_length=sequence_length, 527 | ) 528 | target_batch = _make_prediction_sequence(input_batch, lookahead=1) 529 | if start == end: 530 | name = f"exactly_{start}" 531 | else: 532 | name = f"between_{start}_{end}" 533 | num_classes = 3 534 | return make_one_hot_corpus( 535 | name, input_batch, target_batch, num_classes, num_classes 536 | ) 537 | 538 | 539 | def make_exactly_n_quantifier( 540 | n: int = 1, sequence_length: int = 100, batch_size: int = 1 541 | ) -> Corpus: 542 | return make_between_quantifier( 543 | start=n, end=n, sequence_length=sequence_length, batch_size=batch_size 544 | ) 545 | 546 | 547 | def make_at_least_quantifier( 548 | n: int = 1, sequence_length: int = 100, batch_size: int = 1 549 | ) -> Corpus: 550 | name = f"at_least_{n}" 551 | at_least_dfa = make_at_least_dfa(n) 552 | at_least_dfa.visualize(name) 553 | input_batch = _dfa_to_inputs( 554 | at_least_dfa, 555 | batch_size=batch_size, 556 | end_of_sequence_char=2, 557 | max_sequence_length=sequence_length, 558 | ) 559 | target_batch = _make_prediction_sequence(input_batch, lookahead=1) 560 | num_classes = 3 561 | return make_one_hot_corpus( 562 | name, input_batch, target_batch, num_classes, num_classes 563 | ) 564 | 565 | 566 | def make_at_most_quantifier( 567 | n: int = 1, sequence_length: int = 100, 
batch_size: int = 1 568 | ) -> Corpus: 569 | name = f"at_most_{n}" 570 | at_most_dfa = _make_at_most_dfa(n) 571 | at_most_dfa.visualize(name) 572 | input_batch = _dfa_to_inputs( 573 | at_most_dfa, 574 | batch_size=batch_size, 575 | end_of_sequence_char=2, 576 | max_sequence_length=sequence_length, 577 | ) 578 | target_batch = _make_prediction_sequence(input_batch, lookahead=1) 579 | num_classes = 3 580 | return make_one_hot_corpus( 581 | name, input_batch, target_batch, num_classes, num_classes 582 | ) 583 | 584 | 585 | def make_every_quantifier(sequence_length: int = 100, batch_size: int = 1) -> Corpus: 586 | input_batch = np.ones((batch_size, sequence_length)) 587 | return make_one_hot_corpus(f"every_quantifier", input_batch, input_batch, 2, 2) 588 | 589 | 590 | def _int_to_classes(n: int) -> List[int]: 591 | return list(reversed(list(map(int, list(str(n)))))) 592 | 593 | 594 | def make_count_corpus(max_int: 100, batch_size: int = 100) -> Corpus: 595 | # Predict n+1 from n in a language-model setting. 596 | sequence_length = int(np.floor(np.log10(max_int))) + 1 597 | input_classes = np.zeros((batch_size, sequence_length)) 598 | target_classes = np.zeros((batch_size, sequence_length)) 599 | 600 | for b in range(batch_size): 601 | n = random.randrange(max_int) 602 | input_ = _int_to_classes(n) 603 | target = _int_to_classes(n + 1) 604 | input_classes[b, : len(input_)] = input_ 605 | target_classes[b, : len(target)] = target 606 | 607 | return make_one_hot_corpus( 608 | f"count_to_{max_int}", 609 | input_classes, 610 | target_classes, 611 | num_input_classes=10, 612 | num_target_classes=10, 613 | ) 614 | 615 | 616 | def base10_to_binary_vector(n: int, sequence_length=None) -> np.ndarray: 617 | """8 -> [0,0,0,1], 7 -> [1,1,1] """ 618 | if n == 0: 619 | return np.zeros(sequence_length) 620 | 621 | powers = [] 622 | while n: 623 | power = int(np.floor(np.log2(n))) 624 | powers.append(power) 625 | n -= 2 ** power 626 | 627 | rightmost_one_position = int(max(powers)) 628 | if sequence_length is None: 629 | sequence_length = rightmost_one_position + 1 630 | binary = np.zeros(sequence_length) 631 | binary[powers] = 1.0 632 | # TODO: mask redundant positions? 
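    # Worked example: n=6 yields powers [2, 1], so the result is [0., 1., 1.] --
    # index i holds the 2**i bit, i.e. the least-significant bit comes first.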
633 | return binary 634 | 635 | 636 | def _make_binary_addition_corpus(min_n: int, max_n: int): 637 | all_summands = tuple(itertools.product(range(min_n, max_n), repeat=2)) 638 | summands = [] 639 | for b, (n1, n2) in enumerate(all_summands): 640 | summands.append([n1, n2]) 641 | summands = np.array(summands) 642 | sums = np.sum(summands, axis=1) 643 | sequence_length = int(np.ceil(np.log2(np.max(sums)))) + 1 644 | 645 | summand_binaries = [] 646 | sum_binaries = [] 647 | for (n1, n2), sum_ in zip(summands, sums): 648 | summand_binaries.append( 649 | [ 650 | base10_to_binary_vector(n1, sequence_length), 651 | base10_to_binary_vector(n2, sequence_length), 652 | ] 653 | ) 654 | sum_binaries.append(base10_to_binary_vector(sum_, sequence_length)) 655 | 656 | summand_inputs = np.array( 657 | [np.stack(summands, axis=1,) for summands in summand_binaries] 658 | ) 659 | 660 | sum_outputs = np.expand_dims(np.stack(sum_binaries), axis=-1) 661 | 662 | return dataclasses.replace( 663 | Corpus( 664 | name=f"binary_addition_{min_n}_to_{max_n}", 665 | input_sequence=summand_inputs, 666 | target_sequence=sum_outputs, 667 | ), 668 | optimal_d_given_g=0.0, 669 | ) 670 | 671 | 672 | def make_binary_addition(min_n: int, max_n: int) -> Corpus: 673 | training_corpus = _make_binary_addition_corpus(min_n=min_n, max_n=max_n) 674 | test_corpus = _make_binary_addition_corpus(min_n=max_n + 1, max_n=max_n + 251) 675 | training_corpus = dataclasses.replace(training_corpus, test_corpus=test_corpus) 676 | return training_corpus 677 | 678 | 679 | def an_bn_handmade_net(input_seq: np.ndarray, prior: float): 680 | # Optimal network according to Schimdhuber (2001). 681 | outputs = np.zeros_like(input_seq) 682 | for b in range(input_seq.shape[0]): 683 | num_seen_a = 0 684 | for t in range(input_seq.shape[1]): 685 | input_vec = input_seq[b, t] 686 | input_class = input_vec.argmax() 687 | if input_class == 0: 688 | # Start of sequence symbol, always predict "a" (no empty string in current corpus). 689 | outputs[b, t] = [0.0, 1.0, 0.0] 690 | elif input_class == 1: 691 | # "a". 692 | num_seen_a += 1 693 | outputs[b, t] = [0.0, 1 - prior, prior] 694 | elif input_class == 2: 695 | # "b". 
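                # Match one of the seen a's: while unmatched a's remain, the next symbol is
                # deterministically another "b"; otherwise the sequence must end ("#", class 0).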
696 | num_seen_a -= 1 697 | if num_seen_a > 0: 698 | outputs[b, t] = [0.0, 0.0, 1.0] 699 | else: 700 | outputs[b, t] = [1.0, 0.0, 0.0] 701 | 702 | return outputs 703 | 704 | 705 | def make_english_onset_phonotactics(split_ratio: Optional[float] = None): 706 | PHONOTACTIC_COUNTS = { 707 | "k": 2764, 708 | "r": 2752, 709 | "d": 2526, 710 | "s": 2215, 711 | "m": 1965, 712 | "p": 1881, 713 | "b": 1544, 714 | "l": 1225, 715 | "f": 1222, 716 | "h": 1153, 717 | "t": 1146, 718 | "pr": 1046, 719 | "w": 780, 720 | "n": 716, 721 | "v": 615, 722 | "g": 537, 723 | "dÇ": 524, 724 | "st": 521, 725 | "tr": 515, 726 | "kr": 387, 727 | "+": 379, 728 | "gr": 331, 729 | "t+": 329, 730 | "br": 319, 731 | "sp": 313, 732 | "fl": 290, 733 | "kl": 285, 734 | "sk": 278, 735 | "j": 268, 736 | "fr": 254, 737 | "pl": 238, 738 | "bl": 213, 739 | "sl": 213, 740 | "dr": 211, 741 | "kw": 201, 742 | "str": 183, 743 | "‡": 173, 744 | "sw": 153, 745 | "gl": 131, 746 | "hw": 111, 747 | "sn": 109, 748 | "skr": 93, 749 | "z": 83, 750 | "sm": 82, 751 | "‡r": 73, 752 | "skw": 69, 753 | "tw": 55, 754 | "spr": 51, 755 | "+r": 40, 756 | "spl": 27, 757 | "L": 19, 758 | "dw": 17, 759 | "gw": 11, 760 | "‡w": 4, 761 | "skl": 1, 762 | } 763 | name = "english_onset_phonotactics" 764 | inputs, targets, num_classes, weights = _make_phonotactic_corpus( 765 | phonotactic_counts=PHONOTACTIC_COUNTS 766 | ) 767 | if not split_ratio: 768 | return make_one_hot_corpus( 769 | name, inputs, targets, num_classes, num_classes, weights 770 | ) 771 | train_inputs, test_inputs = split_train_and_test(inputs, split_ratio) 772 | train_targets, test_targets = split_train_and_test(targets, split_ratio) 773 | return Corpus( 774 | name=f"{name}_{split_ratio}_split", 775 | input_sequence=_make_one_hot_sequence(train_inputs, num_classes), 776 | target_sequence=_make_one_hot_sequence(train_targets, num_classes), 777 | test_input_sequence=_make_one_hot_sequence(test_inputs, num_classes), 778 | test_target_sequence=_make_one_hot_sequence(test_targets, num_classes), 779 | sample_weights=weights, 780 | ) 781 | 782 | 783 | def _load_futrell_phonotactic_data(filename: Text) -> Dict[Text, int]: 784 | word_to_count = {} 785 | with open(filename, "r") as f: 786 | reader = csv.DictReader(f) 787 | for row in reader: 788 | word_to_count[row["Phonology"]] = int(row["LemmaFrequency"]) 789 | return word_to_count 790 | 791 | 792 | def make_futrell_german_phonotactic_corpus(split_ratio: float): 793 | inputs, targets, num_classes, weights = _make_phonotactic_corpus( 794 | _load_futrell_phonotactic_data("German.Dict.CSV") 795 | ) 796 | train_inputs, test_inputs = split_train_and_test(inputs, split_ratio) 797 | train_targets, test_targets = split_train_and_test(targets, split_ratio) 798 | return Corpus( 799 | name=f"german_phonotactics_{split_ratio}_split", 800 | input_sequence=_make_one_hot_sequence(train_inputs, num_classes), 801 | target_sequence=_make_one_hot_sequence(train_targets, num_classes), 802 | test_input_sequence=_make_one_hot_sequence(test_inputs, num_classes), 803 | test_target_sequence=_make_one_hot_sequence(test_targets, num_classes), 804 | sample_weights=weights, 805 | ) 806 | 807 | 808 | def _make_phonotactic_corpus(phonotactic_counts: Dict[Text, int]): 809 | segments = sorted(list(set("".join(phonotactic_counts)))) 810 | segment_to_index = {segment: i for i, segment in enumerate(segments)} 811 | sequences = list(phonotactic_counts.keys()) 812 | max_sequence_len = max([len(x) for x in sequences]) 813 | segments = list(segment_to_index) 814 | weights = 
tuple(phonotactic_counts[seq] for seq in sequences) 815 | inputs = np.empty((len(sequences), max_sequence_len + 1)) 816 | targets = np.empty_like(inputs) 817 | inputs.fill(MASK_VALUE) 818 | targets.fill(MASK_VALUE) 819 | for i in range(len(sequences)): 820 | # Start-of-sequence symbol 821 | inputs[i, 0] = len(segments) 822 | for j in range(len(sequences[i])): 823 | inputs[i, j + 1] = segment_to_index[sequences[i][j]] 824 | targets[i, j] = segment_to_index[sequences[i][j]] 825 | # End-of-sequence symbol 826 | targets[i, len(sequences[i])] = len(segments) 827 | return inputs, targets, len(segments) + 1, weights 828 | 829 | 830 | class _DerivationTooLong(Exception): 831 | pass 832 | 833 | 834 | def _generate_string_from_pcfg( 835 | pcfg: Dict, max_length: Optional[int] = None 836 | ) -> Tuple[Text, ...]: 837 | """Stops when all generated characters are terminals. 838 | To stop without adding an epsilon terminal, use the empty string '', i.e. add a rule `S->''`. """ 839 | stack = ["S"] 840 | terminals = [] 841 | while stack: 842 | node = stack[0] 843 | stack = stack[1:] 844 | 845 | if node not in pcfg: 846 | terminals.append(node) 847 | if max_length is not None and len(terminals) > max_length: 848 | raise _DerivationTooLong 849 | continue 850 | 851 | rules, probabs = list(zip(*pcfg[node])) 852 | rule_idx = np.random.choice(len(rules), p=probabs) 853 | rule = rules[rule_idx] 854 | 855 | stack = list(rule) + stack 856 | 857 | return tuple(terminals) 858 | 859 | 860 | def _make_corpus_from_pcfg( 861 | name: Text, 862 | pcfg: Dict, 863 | batch_size: int, 864 | max_derivation_length: Optional[int] = None, 865 | sort_by_length: bool = False, 866 | ) -> Corpus: 867 | sequences = [] 868 | while len(sequences) < batch_size: 869 | try: 870 | sequence = _generate_string_from_pcfg( 871 | pcfg, max_length=max_derivation_length 872 | ) 873 | except _DerivationTooLong: 874 | continue 875 | sequences.append(sequence) 876 | 877 | if sort_by_length: 878 | sequences = sorted(sequences, key=len, reverse=True) 879 | 880 | lengths = list(map(len, sequences)) 881 | 882 | sequence_counts = collections.Counter(sequences) 883 | unique_sequences, weights = tuple(zip(*sequence_counts.items())) 884 | 885 | logging.info(f"PCFG sum of sequence lengths: {sum(lengths)}") 886 | logging.info(f"PCFG max sequence length: {max(lengths)}") 887 | logging.info(f"PCFG mean sequence length: {np.mean(lengths)}") 888 | logging.info( 889 | f"PCFG unique sequences: {len(unique_sequences)}/{len(sequences)} ({len(unique_sequences)/len(sequences):.2f})" 890 | ) 891 | 892 | alphabet = set() 893 | for rules in pcfg.values(): 894 | alphabet |= set(itertools.chain(*set(map(lambda x: x[0], rules)))) 895 | alphabet -= set(pcfg.keys()) 896 | alphabet = ("#",) + tuple(sorted(alphabet)) 897 | 898 | symbol_to_idx = {x: i for i, x in enumerate(alphabet)} 899 | 900 | max_seq_length = max(map(len, unique_sequences)) 901 | input_classes = np.empty((len(unique_sequences), max_seq_length + 1)) 902 | target_classes = np.empty_like(input_classes) 903 | input_classes.fill(MASK_VALUE) 904 | target_classes.fill(MASK_VALUE) 905 | 906 | for i, sequence in enumerate(unique_sequences): 907 | sequence_classes = [symbol_to_idx[symbol] for symbol in sequence] 908 | input_row = [symbol_to_idx["#"]] + sequence_classes 909 | target_row = sequence_classes + [symbol_to_idx["#"]] 910 | input_classes[i, : len(sequence_classes) + 1] = input_row 911 | target_classes[i, : len(sequence_classes) + 1] = target_row 912 | 913 | inputs = _make_one_hot_sequence(input_classes, 
num_classes=len(alphabet)) 914 | targets = _make_one_hot_sequence(target_classes, num_classes=len(alphabet)) 915 | 916 | vocabulary = _make_identical_input_output_vocabulary(alphabet) 917 | 918 | return Corpus( 919 | name=name, 920 | input_sequence=inputs, 921 | target_sequence=targets, 922 | sample_weights=weights, 923 | vocabulary=vocabulary, 924 | ) 925 | 926 | 927 | def make_center_embedding( 928 | batch_size: int, embedding_depth_probab: float, dependency_distance_probab: float 929 | ) -> Corpus: 930 | pcfg = { 931 | "S": ( 932 | (("NP_s", "VP_s"), (1 - embedding_depth_probab) / 2), 933 | (("NP_p", "VP_p"), (1 - embedding_depth_probab) / 2), 934 | (("NP_s", "S", "VP_s"), embedding_depth_probab / 2), 935 | (("NP_p", "S", "VP_p"), embedding_depth_probab / 2), 936 | ), 937 | "NP_s": ( 938 | (("N_s",), 1 - dependency_distance_probab), 939 | # (("A", "NP_s"), dependency_distance_probab), 940 | ), 941 | "NP_p": ( 942 | (("N_p",), 1 - dependency_distance_probab), 943 | # (("A", "NP_p"), dependency_distance_probab), 944 | ), 945 | "VP_s": ( 946 | (("V_s",), 1 - dependency_distance_probab), 947 | # (("A", "VP_s"), dependency_distance_probab), 948 | ), 949 | "VP_p": ( 950 | (("V_p",), 1 - dependency_distance_probab), 951 | # (("A", "VP_p"), dependency_distance_probab), 952 | ), 953 | "N_s": ( 954 | (("cat",), 1.0), 955 | # (("dog",), 1.0), 956 | # (("horse",), 0.2), 957 | # (("rat",), 0.2), 958 | # (("flower",), 0.2), 959 | ), 960 | "N_p": ( 961 | (("cats",), 1.0), 962 | # (("dogs",), 1.0), 963 | # (("horses",), 0.2), 964 | # (("rats",), 0.2), 965 | # (("flowers",), 0.2), 966 | ), 967 | "V_s": ( 968 | (("runs",), 1.0), 969 | # (("talks",), 1.0), 970 | # (("dances",), 0.2), 971 | # (("eats",), 0.2), 972 | # (("drinks",), 0.2), 973 | ), 974 | "V_p": ( 975 | (("run",), 1.0), 976 | # (("talk",), 1.0), 977 | # (("dance",), 0.2), 978 | # (("eat",), 0.2), 979 | # (("drink",), 0.2), 980 | ), 981 | # "A": ( 982 | # (("good",), 0.5), 983 | # (("bad",), 0.5), 984 | # (("nice",), 0.2), 985 | # (("smart",), 0.2), 986 | # (("funny",), 0.2), 987 | # ), 988 | } 989 | corpus = _make_corpus_from_pcfg( 990 | f"center_embedding_pcfg_embedding_{embedding_depth_probab}_distance_{dependency_distance_probab}", 991 | pcfg=pcfg, 992 | batch_size=batch_size, 993 | ) 994 | input_classes = np.argmax(corpus.input_sequence, axis=-1) 995 | deterministic_steps_mask = (~np.all(is_masked(corpus.input_sequence), axis=-1)) & ( 996 | # "cat/s" 997 | (input_classes == 3) 998 | | (input_classes == 4) 999 | ) 1000 | return dataclasses.replace( 1001 | corpus, deterministic_steps_mask=deterministic_steps_mask 1002 | ) 1003 | 1004 | 1005 | def make_palindrome_with_middle_marker_distinct(batch_size: int, nesting_probab: float): 1006 | pcfg = { 1007 | "S": ( 1008 | (("0", "S", "0"), nesting_probab / 2), 1009 | (("1", "S", "1"), nesting_probab / 2), 1010 | (("@",), 1 - nesting_probab), 1011 | ) 1012 | } 1013 | return _make_corpus_from_pcfg( 1014 | name=f"palindrome_middle_marker__batch_{batch_size}__p_{nesting_probab}", 1015 | pcfg=pcfg, 1016 | batch_size=batch_size, 1017 | ) 1018 | 1019 | 1020 | def _optimal_d_g_for_fixed_palindrome(corpus) -> float: 1021 | sequence_length = corpus.input_sequence.shape[1] 1022 | batch_size = sum(corpus.sample_weights) 1023 | deterministic_length = sequence_length // 2 1024 | return batch_size * deterministic_length 1025 | 1026 | 1027 | def make_binary_palindrome_fixed_length( 1028 | batch_size: int, sequence_length: int, train_set_ratio: float 1029 | ) -> Corpus: 1030 | assert sequence_length % 2 == 0 1031 
| prefixes_non_unique = np.random.randint( 1032 | 2, size=(batch_size, sequence_length // 2) 1033 | ).astype(utils.FLOAT_DTYPE) 1034 | 1035 | sequence_counts = collections.Counter(list(map(tuple, prefixes_non_unique))) 1036 | unique_prefixes, weights = list(zip(*sequence_counts.items())) 1037 | 1038 | prefixes = np.array(unique_prefixes) 1039 | suffixes = np.flip(prefixes, axis=1) 1040 | sequences = np.concatenate([prefixes, suffixes], axis=1) 1041 | targets = _make_prediction_sequence(input_sequence=sequences, lookahead=1) 1042 | 1043 | input_sequences = np.expand_dims(sequences, axis=2) 1044 | target_sequences = np.expand_dims(targets, axis=2) 1045 | 1046 | logging.info( 1047 | f"Fixed palindrome: {len(unique_prefixes)}/{len(prefixes_non_unique)} unique sequences" 1048 | ) 1049 | 1050 | full_corpus = optimize_for_feeding( 1051 | Corpus( 1052 | name=f"palindrome_binary_fixed_length_batch_{batch_size}_length_{sequence_length}", 1053 | input_sequence=input_sequences, 1054 | target_sequence=target_sequences, 1055 | sample_weights=weights, 1056 | ) 1057 | ) 1058 | 1059 | train, test = split_train_test(full_corpus, train_ratio=train_set_ratio) 1060 | logging.info( 1061 | f"Train size: {train.input_sequence.shape[0]}, test size: {test.input_sequence.shape[0]}" 1062 | ) 1063 | test = dataclasses.replace( 1064 | test, optimal_d_given_g=_optimal_d_g_for_fixed_palindrome(test) 1065 | ) 1066 | return dataclasses.replace( 1067 | train, 1068 | test_corpus=test, 1069 | optimal_d_given_g=_optimal_d_g_for_fixed_palindrome(train), 1070 | ) 1071 | 1072 | 1073 | def _make_an_bn_square_corpus(n_values: Tuple[int, ...], prior: float): 1074 | start_end_of_sequence_symbol = 0 1075 | max_n = max(n_values) 1076 | max_sequence_length = max_n + (max_n ** 2) + 1 1077 | 1078 | n_values_counts = collections.Counter(n_values) 1079 | unique_n_values, n_values_weights = tuple(zip(*n_values_counts.items())) 1080 | 1081 | inputs = np.empty((len(unique_n_values), max_sequence_length)) 1082 | targets = np.empty_like(inputs) 1083 | inputs.fill(MASK_VALUE) 1084 | targets.fill(MASK_VALUE) 1085 | 1086 | for b, n in enumerate(unique_n_values): 1087 | input_seq = [start_end_of_sequence_symbol] + ([1] * n) + ([2] * n ** 2) 1088 | target_seq = input_seq[1:] + [start_end_of_sequence_symbol] 1089 | inputs[b, : len(input_seq)] = input_seq 1090 | targets[b, : len(input_seq)] = target_seq 1091 | 1092 | corpus = make_one_hot_corpus( 1093 | f"an_bn_square_batch_{len(n_values)}_p_{prior}", 1094 | inputs, 1095 | targets, 1096 | num_input_classes=3, 1097 | num_target_classes=3, 1098 | vocabulary=_make_identical_input_output_vocabulary(alphabet=("#", "a", "b")), 1099 | weights=n_values_weights, 1100 | ) 1101 | return dataclasses.replace( 1102 | corpus, 1103 | optimal_d_given_g=_get_ain_bjn_ckn_dtn_optimal_d_given_g(prior, n_values), 1104 | ) 1105 | 1106 | 1107 | def make_an_bn_square(batch_size: int, prior: float) -> Corpus: 1108 | training_n_values = tuple(np.random.geometric(p=prior, size=batch_size)) 1109 | training_corpus = _make_an_bn_square_corpus(training_n_values, prior) 1110 | 1111 | max_training_n = max(training_n_values) 1112 | test_n_values = tuple(range(max_training_n + 1, max_training_n + 11)) 1113 | test_corpus = _make_an_bn_square_corpus(test_n_values, prior) 1114 | return dataclasses.replace(training_corpus, test_corpus=test_corpus) 1115 | 1116 | 1117 | def _make_identical_input_output_vocabulary(alphabet: Tuple[Text, ...]) -> _Vocabulary: 1118 | # Create class index to symbol mapping, assuming inputs and outputs are identical 
and ordered identically. 1119 | class_to_symbol = {idx: alphabet[idx] for idx in range(len(alphabet))} 1120 | class_to_symbol.update( 1121 | {idx + len(alphabet): symbol for idx, symbol in class_to_symbol.items()} 1122 | ) 1123 | return class_to_symbol 1124 | 1125 | 1126 | def _get_ain_bjn_ckn_dtn_optimal_d_given_g(prior, n_values) -> float: 1127 | return -np.sum( 1128 | [(n - 1) * (np.log2(1 - prior)) + np.log2(prior) for n in n_values] 1129 | ).item() 1130 | 1131 | 1132 | def get_num_chars_in_corpus(corpus: Corpus) -> int: 1133 | non_masked = ~np.all(is_masked(corpus.input_sequence), axis=-1) 1134 | num_chars_per_row = np.sum(non_masked, axis=1) 1135 | if corpus.sample_weights: 1136 | total_chars = np.dot(num_chars_per_row, corpus.sample_weights) 1137 | else: 1138 | total_chars = np.sum(num_chars_per_row) 1139 | return total_chars.item() 1140 | 1141 | 1142 | def make_inputs_counter(num_inputs: int, num_ones: int, batch_size: int): 1143 | # From Schmidhuber (1997) -- network's goal is to output the number of ones in the input. 1144 | # The optimal solution is to set all weights to 1. 1145 | inputs = np.zeros((batch_size, 1, num_inputs), dtype=utils.FLOAT_DTYPE) 1146 | for b in range(batch_size): 1147 | idxs = random.choices(range(num_inputs), k=num_ones) 1148 | inputs[b, idxs] = 1.0 1149 | 1150 | targets = np.ones((batch_size, 1, 1)) * num_ones 1151 | 1152 | return Corpus( 1153 | name=f"inputs_counter", input_sequence=inputs, target_sequence=targets, 1154 | ) 1155 | 1156 | pass 1157 | 1158 | 1159 | def _make_ain_bjn_ckn_dtn_corpus( 1160 | n_values: Tuple[int, ...], 1161 | multipliers: Tuple[int, ...], 1162 | prior: float, 1163 | sort_by_length: bool, 1164 | ) -> Corpus: 1165 | # Create a corpus of a^in, b^jn, c^kn, d^tn, multipliers = [i,j,k,n]. 1166 | max_n = max(n_values) 1167 | max_sequence_length = (max_n * sum(multipliers)) + 1 1168 | 1169 | start_end_of_sequence_symbol = 0 # Using same symbol for start/end of sequence, as in Schmidhuber et al. (2001). 
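    # Example: multipliers (i, j, k, t) = (1, 1, 1, 0) and n = 2 give the input classes
    # [0, 1, 1, 2, 2, 3, 3], i.e. "# a a b b c c"; the targets are the same sequence
    # shifted one step left, ending in "#".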
1170 | 1171 | n_values_counts = collections.Counter(n_values) 1172 | n_value_counts_items = tuple(n_values_counts.items()) 1173 | if sort_by_length: 1174 | n_value_counts_items = sorted(n_value_counts_items, reverse=True) 1175 | 1176 | unique_n_values, n_values_weights = tuple(zip(*n_value_counts_items)) 1177 | 1178 | inputs = np.empty((len(unique_n_values), max_sequence_length)) 1179 | targets = np.empty_like(inputs) 1180 | inputs.fill(MASK_VALUE) 1181 | targets.fill(MASK_VALUE) 1182 | 1183 | for b, n in enumerate(unique_n_values): 1184 | input_seq = ( 1185 | [start_end_of_sequence_symbol] 1186 | + ([1] * n * multipliers[0]) 1187 | + ([2] * n * multipliers[1]) 1188 | + ([3] * n * multipliers[2]) 1189 | + ([4] * n * multipliers[3]) 1190 | ) 1191 | target_seq = input_seq[1:] + [start_end_of_sequence_symbol] 1192 | inputs[b, : len(input_seq)] = input_seq 1193 | targets[b, : len(input_seq)] = target_seq 1194 | 1195 | name = f"a{multipliers[0]}n_b{multipliers[1]}n_c{multipliers[2]}n_d{multipliers[3]}n__p_{prior}__batch_{len(n_values)}" 1196 | num_input_classes = sum([1 for x in multipliers if x != 0]) + 1 1197 | 1198 | alphabet = ("#", "a", "b", "c", "d")[:num_input_classes] 1199 | vocabulary = _make_identical_input_output_vocabulary(alphabet) 1200 | 1201 | corpus = make_one_hot_corpus( 1202 | name, 1203 | inputs, 1204 | targets, 1205 | num_input_classes=num_input_classes, 1206 | num_target_classes=num_input_classes, 1207 | weights=n_values_weights, 1208 | vocabulary=vocabulary, 1209 | ) 1210 | return dataclasses.replace( 1211 | corpus, 1212 | optimal_d_given_g=_get_ain_bjn_ckn_dtn_optimal_d_given_g(prior, n_values), 1213 | # TODO: this assumes no empty sequences in corpus. 1214 | deterministic_steps_mask=(~is_masked(inputs)) & (inputs != 1), 1215 | ) 1216 | 1217 | 1218 | def make_ain_bjn_ckn_dtn( 1219 | batch_size: int, 1220 | prior: float, 1221 | multipliers: Tuple[int, ...], 1222 | sort_by_length: bool = False, 1223 | ) -> Corpus: 1224 | training_n_values = tuple(np.random.geometric(p=prior, size=batch_size)) 1225 | training_corpus = _make_ain_bjn_ckn_dtn_corpus( 1226 | training_n_values, multipliers, prior, sort_by_length 1227 | ) 1228 | 1229 | max_training_n = max(training_n_values) 1230 | test_n_values = tuple(range(max_training_n + 1, max_training_n + 1001)) 1231 | test_corpus = _make_ain_bjn_ckn_dtn_corpus( 1232 | test_n_values, multipliers, prior, sort_by_length 1233 | ) 1234 | test_corpus = dataclasses.replace(test_corpus, name=f"{test_corpus.name}_test") 1235 | 1236 | logging.info(f"Created corpus {training_corpus.name}") 1237 | logging.info(f"Max n in training set: {max_training_n}") 1238 | logging.info(f"Optimal training set D:G: {training_corpus.optimal_d_given_g:,.2f}") 1239 | logging.info(f"Optimal test set D:G: {test_corpus.optimal_d_given_g:,.2f}") 1240 | 1241 | return dataclasses.replace(training_corpus, test_corpus=test_corpus) 1242 | 1243 | 1244 | def _get_an_bm_cn_plus_m_corpus_optimal_d_given_g( 1245 | n_plus_m_values: Tuple[int, ...], prior: float 1246 | ) -> float: 1247 | return -np.sum( 1248 | [ 1249 | ((n_plus_m - 2) * np.log2(1 - prior)) + (2 * np.log2(prior)) 1250 | for n_plus_m in n_plus_m_values 1251 | ] 1252 | ).item() 1253 | 1254 | 1255 | def _make_an_bm_cn_plus_m_corpus( 1256 | n_values: Tuple[int, ...], 1257 | m_values: Tuple[int, ...], 1258 | prior: float, 1259 | sort_by_length: bool, 1260 | ) -> Corpus: 1261 | sum_values = tuple(np.add(n_values, m_values)) 1262 | start_end_of_sequence_symbol = 0 1263 | max_sequence_length = 2 * max(sum_values) + 1 1264 | 
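    # Example: n = 2, m = 1 give the input classes [0, 1, 1, 2, 3, 3, 3], i.e. "# a a b c c c";
    # each sequence has length 2 * (n + m) + 1, hence the max_sequence_length bound above.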
1265 | n_m_values_counts = collections.Counter(zip(n_values, m_values)) 1266 | n_m_values_counts_items = tuple(n_m_values_counts.items()) 1267 | if sort_by_length: 1268 | n_m_values_counts_items = sorted( 1269 | n_m_values_counts_items, key=lambda x: sum(x[0]), reverse=True 1270 | ) 1271 | 1272 | unique_n_m_values, n_m_values_weights = tuple(zip(*n_m_values_counts_items)) 1273 | 1274 | inputs = np.empty((len(unique_n_m_values), max_sequence_length)) 1275 | targets = np.empty_like(inputs) 1276 | inputs.fill(MASK_VALUE) 1277 | targets.fill(MASK_VALUE) 1278 | 1279 | for b in range(len(unique_n_m_values)): 1280 | n, m = unique_n_m_values[b] 1281 | input_seq = ( 1282 | [start_end_of_sequence_symbol] + ([1] * n) + ([2] * m) + ([3] * (n + m)) 1283 | ) 1284 | target_seq = input_seq[1:] + [start_end_of_sequence_symbol] 1285 | inputs[b, : len(input_seq)] = input_seq 1286 | targets[b, : len(input_seq)] = target_seq 1287 | 1288 | vocabulary = _make_identical_input_output_vocabulary(alphabet=("#", "a", "b", "c")) 1289 | 1290 | corpus = make_one_hot_corpus( 1291 | f"an_bm_cn_plus_m__batch_{len(n_values)}_p_{prior}", 1292 | inputs, 1293 | targets, 1294 | num_input_classes=4, 1295 | num_target_classes=4, 1296 | weights=n_m_values_weights, 1297 | vocabulary=vocabulary, 1298 | ) 1299 | 1300 | return dataclasses.replace( 1301 | corpus, 1302 | optimal_d_given_g=_get_an_bm_cn_plus_m_corpus_optimal_d_given_g( 1303 | sum_values, prior 1304 | ), 1305 | # TODO: this assumes no empty sequences in corpus. 1306 | deterministic_steps_mask=(~is_masked(inputs)) & (inputs != 1) & (inputs != 2), 1307 | ) 1308 | 1309 | 1310 | def make_an_bm_cn_plus_m( 1311 | batch_size: int, prior: float, sort_by_length: bool = False, 1312 | ) -> Corpus: 1313 | training_n_values = tuple(np.random.geometric(p=prior, size=batch_size)) 1314 | training_m_values = tuple(np.random.geometric(p=prior, size=batch_size)) 1315 | 1316 | training_corpus = _make_an_bm_cn_plus_m_corpus( 1317 | training_n_values, training_m_values, prior, sort_by_length 1318 | ) 1319 | max_n = max(training_n_values) 1320 | max_m = max(training_m_values) 1321 | max_training_n_or_m = max(max_n, max_m) 1322 | 1323 | test_n_values, test_m_values = zip( 1324 | *itertools.product( 1325 | range(max_training_n_or_m + 1, max_training_n_or_m + 50), repeat=2 1326 | ) 1327 | ) 1328 | 1329 | test_corpus = _make_an_bm_cn_plus_m_corpus( 1330 | test_n_values, test_m_values, prior, sort_by_length 1331 | ) 1332 | test_corpus = dataclasses.replace(test_corpus, name=f"{test_corpus.name}_test") 1333 | 1334 | logging.info(f"Created corpus {training_corpus.name}") 1335 | logging.info(f"Max n in training: {max_n}") 1336 | logging.info(f"Max m in training: {max_m}") 1337 | logging.info(f"Optimal training set D:G: {training_corpus.optimal_d_given_g:,.2f}") 1338 | logging.info(f"Optimal test set D:G: {test_corpus.optimal_d_given_g:,.2f}") 1339 | logging.info(f"Training set dimensions: {training_corpus.input_sequence.shape}") 1340 | logging.info(f"Test set dimensions: {test_corpus.input_sequence.shape}") 1341 | 1342 | return dataclasses.replace(training_corpus, test_corpus=test_corpus) 1343 | 1344 | 1345 | def _calculate_nesting_depths(corpus): 1346 | input_classes = corpus.input_sequence.argmax(axis=-1) 1347 | opening_classes = {1, 3} 1348 | closing_classes = {2, 4} 1349 | depths = [] 1350 | for b in range(input_classes.shape[0]): 1351 | depth = 0 1352 | max_depth = 0 1353 | for i in range(input_classes[b].shape[0]): 1354 | if np.all(is_masked(corpus.input_sequence[b, i])): 1355 | break 1356 | if 
input_classes[b, i] in opening_classes: 1357 | depth += 1 1358 | max_depth = max(max_depth, depth) 1359 | elif input_classes[b, i] in closing_classes: 1360 | depths.append(depth) 1361 | depth -= 1 1362 | 1363 | depth_counts = dict(collections.Counter(depths).most_common()) 1364 | max_depth = max(depths) 1365 | logging.info(f"Max depth in corpus: {max_depth}") 1366 | logging.info(f"Depth counts: {depth_counts}") 1367 | 1368 | 1369 | def _get_dyck_1_symbol_counts(corpus) -> Tuple[int, int, int]: 1370 | # Masks (nans) become 0 here after argmax, but we ignore it since we count only 1's and 2's. 1371 | input_classes = corpus.input_sequence.argmax(axis=-1) 1372 | 1373 | num_ends_of_sequence = np.sum(corpus.sample_weights).item() 1374 | num_opening_brackets = np.dot( 1375 | corpus.sample_weights, np.sum(input_classes == 1, axis=1) 1376 | ).item() 1377 | num_closing_brackets = np.dot( 1378 | corpus.sample_weights, np.sum(input_classes == 2, axis=1) 1379 | ).item() 1380 | 1381 | return num_ends_of_sequence, num_opening_brackets, num_closing_brackets 1382 | 1383 | 1384 | def get_dyck_1_target_probabs(corpus, prior) -> np.ndarray: 1385 | input_classes = np.argmax(corpus.input_sequence, axis=-1).astype(np.float64) 1386 | input_classes[~corpus.input_mask] = np.nan 1387 | 1388 | target_probabs = np.zeros_like(corpus.target_sequence) 1389 | target_probabs[~corpus.input_mask] = np.nan 1390 | 1391 | for b in range(input_classes.shape[0]): 1392 | open_brackets = 0 1393 | for i in range(input_classes.shape[1]): 1394 | if np.isnan(input_classes[b, i]): 1395 | break 1396 | if input_classes[b, i] == 0: 1397 | target_probabs[b, i, 0] = 1 - prior 1398 | target_probabs[b, i, 1] = prior 1399 | elif input_classes[b, i] == 1: 1400 | open_brackets += 1 1401 | target_probabs[b, i, 2] = 1 - prior 1402 | target_probabs[b, i, 1] = prior 1403 | elif input_classes[b, i] == 2: 1404 | open_brackets -= 1 1405 | if open_brackets == 0: 1406 | target_probabs[b, i, 0] = 1 - prior 1407 | target_probabs[b, i, 1] = prior 1408 | else: 1409 | target_probabs[b, i, 1] = prior 1410 | target_probabs[b, i, 2] = 1 - prior 1411 | 1412 | return target_probabs 1413 | 1414 | 1415 | def _make_dyck_n_corpus( 1416 | batch_size: int, 1417 | nesting_probab: float, 1418 | n: int, 1419 | max_sequence_length: Optional[int] = None, 1420 | sort_by_length: bool = False, 1421 | ): 1422 | bracket_pairs = ( 1423 | ("[", "]"), 1424 | ("(", ")"), 1425 | ("{", "}"), 1426 | ("<", ">"), 1427 | ("⟦", "⟧"), 1428 | ("〔", " 〕"), 1429 | ) 1430 | single_nesting_probab = nesting_probab / n 1431 | 1432 | bracket_derivations = [] 1433 | for i in range(n): 1434 | bracket_derivations.append( 1435 | # e.g. `S -> ("[", S, "]", S)`. 
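            # Each of the n bracket pairs is expanded with probability nesting_probab / n;
            # the empty string closes the derivation with the remaining 1 - nesting_probab.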
1436 | ( 1437 | (bracket_pairs[i][0], "S", bracket_pairs[i][1], "S",), 1438 | single_nesting_probab, 1439 | ) 1440 | ) 1441 | 1442 | pcfg = {"S": tuple(bracket_derivations) + (("", 1 - nesting_probab),)} 1443 | corpus = _make_corpus_from_pcfg( 1444 | name=f"dyck_{n}__batch_{batch_size}__p_{nesting_probab}", 1445 | pcfg=pcfg, 1446 | batch_size=batch_size, 1447 | max_derivation_length=max_sequence_length, 1448 | sort_by_length=sort_by_length, 1449 | ) 1450 | 1451 | if n == 1: 1452 | ( 1453 | num_ends_of_sequence, 1454 | num_opening_brackets, 1455 | num_closing_brackets, 1456 | ) = _get_dyck_1_symbol_counts(corpus) 1457 | 1458 | optimal_d_given_g = ( 1459 | -1 1460 | * ( 1461 | (num_ends_of_sequence * np.log2(1 - nesting_probab)) 1462 | + (num_opening_brackets * np.log2(nesting_probab)) 1463 | + (num_closing_brackets * np.log2(1 - nesting_probab)) 1464 | ).item() 1465 | ) 1466 | corpus = dataclasses.replace(corpus, optimal_d_given_g=optimal_d_given_g) 1467 | 1468 | if n == 2: 1469 | import manual_nets 1470 | import network 1471 | 1472 | stack_net = manual_nets.make_emmanuel_dyck_2_network( 1473 | nesting_probab=nesting_probab 1474 | ) 1475 | stack_net = network.calculate_fitness( 1476 | stack_net, optimize_for_feeding(corpus), config=_DEFAULT_CONFIG, 1477 | ) 1478 | corpus = dataclasses.replace( 1479 | corpus, optimal_d_given_g=stack_net.fitness.data_encoding_length 1480 | ) 1481 | logging.info(f"Optimal |D:G|: {corpus.optimal_d_given_g:,.2f}") 1482 | 1483 | _calculate_nesting_depths(corpus) 1484 | return corpus 1485 | 1486 | 1487 | def _get_sequence_strings(corpus) -> FrozenSet[Text]: 1488 | unique_sequences = set() 1489 | for b in range(corpus.input_sequence.shape[0]): 1490 | seq = corpus.input_sequence[b] 1491 | seq_str = str(np.argmax(seq, axis=-1).tolist()) 1492 | unique_sequences.add(seq_str) 1493 | return frozenset(unique_sequences) 1494 | 1495 | 1496 | def make_dyck_n( 1497 | batch_size: int, 1498 | nesting_probab: float, 1499 | n: int, 1500 | max_sequence_length: Optional[int] = None, 1501 | sort_by_length: bool = False, 1502 | ) -> Corpus: 1503 | training_corpus = _make_dyck_n_corpus( 1504 | batch_size=batch_size, 1505 | nesting_probab=nesting_probab, 1506 | n=n, 1507 | max_sequence_length=max_sequence_length, 1508 | sort_by_length=sort_by_length, 1509 | ) 1510 | test_corpus = _make_dyck_n_corpus( 1511 | batch_size=50_000, 1512 | nesting_probab=nesting_probab, 1513 | n=n, 1514 | max_sequence_length=max_sequence_length, 1515 | sort_by_length=sort_by_length, 1516 | ) 1517 | 1518 | training_sequences = _get_sequence_strings(training_corpus) 1519 | test_sequences = _get_sequence_strings(test_corpus) 1520 | shared = training_sequences & test_sequences 1521 | 1522 | logging.info( 1523 | f"Dyck-{n} Sequences shared between train and test: {len(shared)} ({len(shared)/len(test_sequences):.2f} of test)" 1524 | ) 1525 | 1526 | return dataclasses.replace(training_corpus, test_corpus=test_corpus) 1527 | 1528 | 1529 | def split_train_test(corpus: Corpus, train_ratio: float) -> Tuple[Corpus, Corpus]: 1530 | batch_size = corpus.input_sequence.shape[0] 1531 | train_size = math.floor(train_ratio * batch_size) 1532 | shuffled_idxs = np.random.permutation(batch_size) 1533 | train_idxs = sorted(shuffled_idxs[:train_size]) 1534 | test_idxs = sorted(shuffled_idxs[train_size:]) 1535 | return _split_corpus(corpus, batch_idxs_per_corpus=(train_idxs, test_idxs)) 1536 | 1537 | 1538 | def _split_corpus( 1539 | corpus: Corpus, batch_idxs_per_corpus: Tuple[np.ndarray, ...] 
1540 | ) -> Tuple[Corpus, ...]: 1541 | new_corpora = [] 1542 | for batch_idxs in batch_idxs_per_corpus: 1543 | max_sample_length = max(np.where(corpus.input_mask[batch_idxs])[-1]) + 1 1544 | 1545 | inputs = corpus.input_sequence[batch_idxs, :, :max_sample_length] 1546 | targets = corpus.target_sequence[batch_idxs, :, :max_sample_length] 1547 | target_mask = corpus.targets_mask[batch_idxs, :, :max_sample_length] 1548 | input_mask = corpus.input_mask[batch_idxs, :max_sample_length] 1549 | 1550 | # TODO: recalculating indices for this is hard, need to change the data format. 1551 | inputs_per_time_step = None 1552 | 1553 | if corpus.sample_weights: 1554 | sample_weights = tuple(corpus.sample_weights[i] for i in batch_idxs) 1555 | else: 1556 | sample_weights = None 1557 | 1558 | new_corpora.append( 1559 | Corpus( 1560 | name=corpus.name, 1561 | input_sequence=inputs, 1562 | target_sequence=targets, 1563 | input_values_per_time_step=inputs_per_time_step, 1564 | input_mask=input_mask, 1565 | targets_mask=target_mask, 1566 | sample_weights=sample_weights, 1567 | ) 1568 | ) 1569 | 1570 | return tuple(new_corpora) 1571 | 1572 | 1573 | def split_to_mini_batches( 1574 | corpus: Corpus, mini_batch_size: Optional[int] 1575 | ) -> Tuple[Corpus, ...]: 1576 | if mini_batch_size is None: 1577 | return (corpus,) 1578 | 1579 | num_samples = corpus.input_sequence.shape[0] 1580 | mini_batch_idxs = tuple( 1581 | np.split(np.arange(num_samples), np.arange(0, num_samples, mini_batch_size)[1:]) 1582 | ) 1583 | 1584 | return _split_corpus(corpus, mini_batch_idxs) 1585 | -------------------------------------------------------------------------------- /dfa.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | import pathlib 3 | import random 4 | from typing import Dict, Set, Text 5 | 6 | import numpy as np 7 | 8 | _State = int 9 | _Label = Text 10 | _StateTransitions = Dict[_Label, _State] 11 | 12 | END_OF_SEQUENCE = "#" 13 | 14 | 15 | class DFA: 16 | def __init__( 17 | self, 18 | transitions: Dict[_State, _StateTransitions], 19 | accepting_states=Set[_State], 20 | ): 21 | self._states: Set[_State] = set(transitions.keys()) | set( 22 | itertools.chain(*[tuple(x.values()) for x in transitions.values()]) 23 | ) 24 | self._transitions: Dict[_State, _StateTransitions] = transitions 25 | self._accepting_states: Set[_State] = accepting_states 26 | 27 | def generate_string(self): 28 | curr_state = 0 29 | string = "" 30 | while curr_state not in self._accepting_states: 31 | char, curr_state = random.choice( 32 | tuple(self._transitions[curr_state].items()) 33 | ) 34 | string += char 35 | return string 36 | 37 | def visualize(self, name: Text): 38 | dot = "digraph G {\n" "colorscheme=X11\n" 39 | # inputs and outputs 40 | for state in sorted(self._states): 41 | if state in self._accepting_states: 42 | style = "peripheries=2" 43 | else: 44 | style = "" 45 | description = f'[label="q{state}" {style}]' 46 | 47 | dot += f"{state} {description}\n" 48 | 49 | for state, transitions in self._transitions.items(): 50 | for label, neighbor in transitions.items(): 51 | dot += f'{state} -> {neighbor} [ label="{label}" ];\n' 52 | 53 | dot += "}" 54 | 55 | path = pathlib.Path(f"dot_files/dfa_{name}.dot") 56 | path.parent.mkdir(parents=True, exist_ok=True) 57 | with path.open("w") as f: 58 | f.write(dot) 59 | 60 | def get_optimal_data_given_grammar_for_dfa( 61 | self, input_sequence: np.ndarray 62 | ) -> float: 63 | total_d_g = 0.0 64 | curr_state = 0 65 | 66 | for b in 
range(input_sequence.shape[0]): 67 | for i in range(input_sequence.shape[1]): 68 | curr_vec = input_sequence[b, i] 69 | if np.all(np.isnan(curr_vec)): 70 | # Sequence is masked until its end. 71 | break 72 | if curr_vec.shape[0] == 1: 73 | curr_val = curr_vec[0] 74 | else: 75 | curr_val = curr_vec.argmax() 76 | 77 | curr_transitions = self._transitions[curr_state] 78 | total_d_g += -np.log2(1 / len(curr_transitions)) 79 | 80 | curr_char = {0: "0", 1: "1", 2: END_OF_SEQUENCE}[curr_val] 81 | curr_state = curr_transitions[curr_char] 82 | 83 | if curr_state in self._accepting_states: 84 | curr_state = 0 85 | 86 | return total_d_g 87 | -------------------------------------------------------------------------------- /genetic_algorithm.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import datetime 3 | import itertools 4 | import logging 5 | import math 6 | import multiprocessing 7 | import operator 8 | import os 9 | import pathlib 10 | import pickle 11 | import queue 12 | import random 13 | import shutil 14 | import uuid 15 | from typing import FrozenSet, Iterator, List, Optional, Text, Tuple 16 | 17 | import cachetools 18 | import numpy as np 19 | from mpi4py import MPI 20 | 21 | import configuration 22 | import corpora 23 | import network 24 | import utils 25 | 26 | _DEBUG_MODE = False 27 | 28 | _FITNESS_CACHE = {} 29 | _ISLAND_MIGRATIONS_PATH = pathlib.Path("/tmp/mdlnn/island_migrations/") 30 | 31 | _GET_NET_MDL = operator.attrgetter("fitness.mdl") 32 | 33 | _Population = List[network.Network] 34 | 35 | 36 | _NETWORKS_CACHE = cachetools.LRUCache(maxsize=1_000_000) 37 | 38 | _MPI_COMMUNICATOR = MPI.COMM_WORLD 39 | _MPI_RANK = _MPI_COMMUNICATOR.Get_rank() 40 | _MPI_MIGRANTS_BUFFER_SIZE = 10_000_000 41 | 42 | 43 | @dataclasses.dataclass(frozen=True) 44 | class _Tournament: 45 | winner_idx: int 46 | loser_idx: int 47 | 48 | 49 | @cachetools.cached(_NETWORKS_CACHE, key=lambda net, corpus, config: hash(net)) 50 | def _evaluate_net_cached( 51 | net: network.Network, 52 | corpus: corpora.Corpus, 53 | config: configuration.SimulationConfig, 54 | ) -> network.Fitness: 55 | return network.calculate_fitness(net, corpus, config).fitness 56 | 57 | 58 | def get_migration_path(simulation_id) -> pathlib.Path: 59 | return _ISLAND_MIGRATIONS_PATH.joinpath(simulation_id) 60 | 61 | 62 | def _make_migration_target_island_generator( 63 | island_num: int, total_islands: int, 64 | ): 65 | yield from itertools.cycle( 66 | itertools.chain(range(island_num + 1, total_islands), range(island_num)) 67 | ) 68 | 69 | 70 | def get_migrants_through_mpi() -> Optional[_Population]: 71 | migrant_batches = [] 72 | 73 | while True: 74 | has_awaiting_migrants = _MPI_COMMUNICATOR.iprobe(tag=utils.MPI_MIGRANTS_TAG) 75 | if not has_awaiting_migrants: 76 | break 77 | migrants = _MPI_COMMUNICATOR.recv(bytearray(_MPI_MIGRANTS_BUFFER_SIZE)) 78 | migrant_batches.append(migrants) 79 | 80 | if not migrant_batches: 81 | return None 82 | 83 | return min(migrant_batches, key=_mean_population_fitness) 84 | 85 | 86 | def _get_migrants_from_file( 87 | simulation_id: Text, island_num: int, file_lock: multiprocessing.Lock 88 | ) -> Optional[_Population]: 89 | migrants_filename = get_target_island_filename(island_num) 90 | incoming_migrants_path = get_migration_path(simulation_id).joinpath( 91 | migrants_filename 92 | ) 93 | 94 | lock_start = datetime.datetime.now() 95 | 96 | with file_lock: 97 | if not incoming_migrants_path.exists(): 98 | return None 99 | 100 | with 
incoming_migrants_path.open("rb") as f: 101 | incoming_migrants = pickle.load(f) 102 | incoming_migrants_path.unlink() 103 | 104 | lock_end = datetime.datetime.now() 105 | lock_delta = lock_end - lock_start 106 | logging.info( 107 | f"Incoming lock took {lock_delta.seconds}.{str(lock_delta.microseconds)[:2]} seconds" 108 | ) 109 | return incoming_migrants 110 | 111 | 112 | def get_target_island_filename(target_island: int) -> Text: 113 | return f"island_{target_island}_incoming_migrants" 114 | 115 | 116 | def _make_random_population(config, input_size, output_size) -> _Population: 117 | return [ 118 | network.make_random_net( 119 | input_size=input_size, 120 | output_size=output_size, 121 | allowed_activations=config.allowed_activations, 122 | start_smooth=config.start_smooth, 123 | ) 124 | for _ in range(config.population_size) 125 | ] 126 | 127 | 128 | def _mean_population_fitness(population: _Population) -> float: 129 | return np.mean([x.fitness.mdl for x in population]).item() 130 | 131 | 132 | def _should_migrate( 133 | outgoing_migrants: _Population, awaiting_migrants_at_target: _Population 134 | ) -> bool: 135 | mean_awaiting_fitness = _mean_population_fitness(awaiting_migrants_at_target) 136 | mean_outgoing_fitness = _mean_population_fitness(outgoing_migrants) 137 | logging.info(f"Awaiting mean fitness: {mean_awaiting_fitness:.2f}") 138 | logging.info(f"Outgoing mean fitness: {mean_outgoing_fitness:.2f}") 139 | return mean_outgoing_fitness < mean_awaiting_fitness 140 | 141 | 142 | def _send_migrants_through_mpi(migrants: _Population, target_island: int) -> bool: 143 | # We can't use _should_migrate() here because in MPI we can't override the target's buffer, so filtering the best migrants is done at the receiving side. 144 | _MPI_COMMUNICATOR.send(migrants, dest=target_island, tag=utils.MPI_MIGRANTS_TAG) 145 | return True 146 | 147 | 148 | def _send_migrants_through_file( 149 | migrants: _Population, 150 | target_island: int, 151 | file_lock: multiprocessing.Lock, 152 | simulation_id: Text, 153 | ) -> bool: 154 | lock_start = datetime.datetime.now() 155 | target_island_filename = get_target_island_filename(target_island) 156 | with file_lock: 157 | outgoing_path = get_migration_path(simulation_id).joinpath( 158 | target_island_filename 159 | ) 160 | 161 | if outgoing_path.exists(): 162 | with outgoing_path.open("rb") as f: 163 | awaiting_migrants_at_target: _Population = pickle.load(f) 164 | should_migrate = _should_migrate( 165 | outgoing_migrants=migrants, 166 | awaiting_migrants_at_target=awaiting_migrants_at_target, 167 | ) 168 | else: 169 | should_migrate = True 170 | 171 | if should_migrate: 172 | with outgoing_path.open("wb") as f: 173 | pickle.dump(migrants, f) 174 | 175 | lock_end = datetime.datetime.now() 176 | lock_delta = lock_end - lock_start 177 | logging.info( 178 | f"Outgoing file lock took {lock_delta.seconds}.{str(lock_delta.microseconds)[:2]} seconds" 179 | ) 180 | 181 | return should_migrate 182 | 183 | 184 | def _simulation_exists(simulation_id: Text) -> bool: 185 | return get_migration_path(simulation_id).exists() 186 | 187 | 188 | def remove_simulation_directory(simulation_id: Text): 189 | shutil.rmtree(get_migration_path(simulation_id), ignore_errors=True) 190 | logging.info(f"Removed directory {simulation_id} from local storage.") 191 | 192 | 193 | def verify_existing_simulation_override(simulation_id: Text): 194 | if _simulation_exists(simulation_id): 195 | logging.info( 196 | f"Directory for simulation {simulation_id} already exists. 
Re-run using `--override` flag to delete the previous run.\n" 197 | ) 198 | exit() 199 | 200 | 201 | def _select_best_no_repetition_by_arch_uniqueness( 202 | population: _Population, k: int 203 | ) -> _Population: 204 | return sorted(set(population), key=_GET_NET_MDL)[:k] 205 | 206 | 207 | def _select_best_no_repetition(population: _Population, k: int) -> _Population: 208 | population_fitness = set() 209 | individuals_no_repetition = [] 210 | for net in population: 211 | fitness = net.fitness.mdl 212 | if fitness not in population_fitness: 213 | individuals_no_repetition.append(net) 214 | population_fitness.add(fitness) 215 | return sorted(individuals_no_repetition, key=_GET_NET_MDL)[:k] 216 | 217 | 218 | def _get_worst_individuals_idxs(population: _Population, n: int) -> List[int]: 219 | fitness = [x.fitness.mdl for x in population] 220 | return np.argsort(fitness)[-n:].tolist() 221 | 222 | 223 | def _get_elite_idxs(population: _Population, elite_ratio: float) -> FrozenSet[int]: 224 | elite_size = math.ceil(len(population) * elite_ratio) 225 | fitness = [x.fitness.mdl for x in population] 226 | argsort = np.argsort(fitness) 227 | seen = set() 228 | best_idxs = set() 229 | for i in argsort: 230 | if len(best_idxs) == elite_size: 231 | break 232 | net = population[i] 233 | if net.fitness.mdl in seen or np.isinf(net.fitness.mdl): 234 | continue 235 | seen.add(net.fitness.mdl) 236 | best_idxs.add(i) 237 | return frozenset(best_idxs) 238 | 239 | 240 | def _get_elite(elite_ratio: float, population: _Population) -> _Population: 241 | elite_size = math.ceil(len(population) * elite_ratio) 242 | return _select_best_no_repetition(population, elite_size) 243 | 244 | 245 | def _tournament_selection(population: _Population, tournament_size: int) -> _Tournament: 246 | # Returns (winner index, loser index). 247 | tournament_idxs = random.sample(range(len(population)), tournament_size) 248 | tournament_nets = tuple(population[i] for i in tournament_idxs) 249 | nets_and_idxs = tuple(zip(tournament_nets, tournament_idxs)) 250 | 251 | if len(set(x.fitness.mdl for x in tournament_nets)) == 1: 252 | # MDL Tie. 253 | if np.isinf(tournament_nets[0].fitness.mdl): 254 | # Break |D:G| infinity ties using |G|. 
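# All MDL scores in this tournament are infinite, so rank by grammar
# encoding length alone: despite its name, `argsort_by_d_g` sorts the nets
# by |G|, so the smallest grammar wins and the largest one loses.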
255 | argsort_by_d_g = tuple( 256 | np.argsort([x.fitness.grammar_encoding_length for x in tournament_nets]) 257 | ) 258 | return _Tournament( 259 | winner_idx=argsort_by_d_g[0], loser_idx=argsort_by_d_g[-1] 260 | ) 261 | return _Tournament(*tuple(random.sample(tournament_idxs, k=2))) 262 | 263 | sorted_tournament = sorted(nets_and_idxs, key=lambda x: x[0].fitness.mdl) 264 | return _Tournament( 265 | winner_idx=sorted_tournament[0][1], loser_idx=sorted_tournament[-1][1] 266 | ) 267 | 268 | 269 | def _get_population_incoming_degrees( 270 | population: _Population, edge_type: int 271 | ) -> np.ndarray: 272 | degrees = [] 273 | for net in population: 274 | (_, reverse_connections, _,) = network.get_connections_and_weights_by_edge_type( 275 | net, edge_type 276 | ) 277 | degrees += list(map(len, reverse_connections.values())) 278 | return np.array(degrees) 279 | 280 | 281 | def _initialize_population_and_generation_from_existing_simulation( 282 | config, island_num 283 | ) -> Tuple[_Population, int]: 284 | if config.migration_channel in {"mpi", "file"}: 285 | with open( 286 | f"./generations/{config.resumed_from_simulation_id}_latest_generation_island_{island_num}.pickle", 287 | "rb", 288 | ) as f: 289 | latest_generation_data = pickle.load(f) 290 | else: 291 | raise ValueError(config.migration_channel) 292 | generation = latest_generation_data["generation"] 293 | population = latest_generation_data["population"] 294 | logging.info(f"Loaded population island {island_num} from generation {generation}") 295 | return population, generation 296 | 297 | 298 | def _initialize_population_and_generation( 299 | config: configuration.SimulationConfig, 300 | island_num: int, 301 | input_size: int, 302 | output_size: int, 303 | ) -> Tuple[_Population, int]: 304 | if config.resumed_from_simulation_id is not None: 305 | return _initialize_population_and_generation_from_existing_simulation( 306 | config, island_num 307 | ) 308 | 309 | generation = 1 310 | population = _make_random_population( 311 | config=config, input_size=input_size, output_size=output_size, 312 | ) 313 | logging.debug(f"Initialized random population size {config.population_size}") 314 | 315 | if _DEBUG_MODE: 316 | [ 317 | network.visualize( 318 | population[i], f"random_initial_net_{i}__island_{island_num}", 319 | ) 320 | for i in random.sample(range(len(population)), 10) 321 | ] 322 | 323 | return population, generation 324 | 325 | 326 | def _evaluate_population( 327 | population: _Population, 328 | corpus: corpora.Corpus, 329 | config: configuration.SimulationConfig, 330 | ) -> _Population: 331 | return [ 332 | dataclasses.replace( 333 | net, fitness=_evaluate_net_cached(net=net, corpus=corpus, config=config), 334 | ) 335 | for net in population 336 | ] 337 | 338 | 339 | def _make_single_reproduction( 340 | population: _Population, 341 | elite_idxs: FrozenSet[int], 342 | corpus: corpora.Corpus, 343 | config: configuration.SimulationConfig, 344 | ) -> _Population: 345 | # Select parent(s) using tournament selection, create an offspring, replace tournament loser with offspring. 
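# With probability `mutation_probab` the tournament winner is mutated to
# produce the offspring; otherwise the winner is copied unchanged. The
# tournament loser is then overwritten by the offspring, unless the loser is
# an elite member and the offspring is not strictly better (checked below).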
346 | p = random.random() 347 | if p < config.mutation_probab: 348 | tournament = _tournament_selection(population, config.tournament_size) 349 | parent_idx = tournament.winner_idx 350 | killed_idx = tournament.loser_idx 351 | offspring = network.mutate(population[parent_idx], config=config) 352 | else: 353 | tournament = _tournament_selection(population, config.tournament_size) 354 | offspring = population[tournament.winner_idx] 355 | killed_idx = tournament.loser_idx 356 | 357 | offspring_fitness = _evaluate_net_cached(offspring, corpus, config) 358 | offspring = dataclasses.replace(offspring, fitness=offspring_fitness) 359 | 360 | if ( 361 | killed_idx in elite_idxs 362 | and offspring.fitness.mdl >= population[killed_idx].fitness.mdl 363 | ): 364 | # Only kill a losing elite if the offspring is better. 365 | return population 366 | 367 | population[killed_idx] = offspring 368 | return population 369 | 370 | 371 | def _make_generation( 372 | population: _Population, 373 | corpus: corpora.Corpus, 374 | config: configuration.SimulationConfig, 375 | ) -> _Population: 376 | # Calculate elite once per generation for performance. 377 | elite_idxs = _get_elite_idxs(population, config.elite_ratio) 378 | for _ in range(len(population)): 379 | population = _make_single_reproduction( 380 | population=population, elite_idxs=elite_idxs, corpus=corpus, config=config 381 | ) 382 | return population 383 | 384 | 385 | def _save_generation( 386 | generation: int, 387 | population: _Population, 388 | island_num: int, 389 | config: configuration.SimulationConfig, 390 | cloud_upload_queue: queue.Queue, 391 | ): 392 | data = { 393 | "generation": generation, 394 | "population": population, 395 | "island": island_num, 396 | } 397 | if config.migration_channel in {"mpi", "file"}: 398 | path = pathlib.Path( 399 | f"./generations/{config.simulation_id}_latest_generation_island_{island_num}.pickle" 400 | ) 401 | path.parent.mkdir(parents=True, exist_ok=True) 402 | with open(path, "wb") as f: 403 | pickle.dump(data, f) 404 | 405 | 406 | def _log_generation_to_logging_process( 407 | island_num: int, 408 | generation: int, 409 | best_net: network.Network, 410 | corpus: corpora.Corpus, 411 | config: configuration.SimulationConfig, 412 | logging_queue: multiprocessing.Queue, 413 | ): 414 | if config.migration_channel == "mpi": 415 | _MPI_COMMUNICATOR.Send( 416 | np.array( 417 | [ 418 | island_num, 419 | generation, 420 | best_net.fitness.mdl, 421 | best_net.fitness.grammar_encoding_length, 422 | best_net.fitness.data_encoding_length, 423 | network.get_num_units(best_net), 424 | network.get_total_connections(best_net, include_biases=True), 425 | ] 426 | ), 427 | dest=config.num_islands, 428 | tag=utils.MPI_LOGGING_TAG, 429 | ) 430 | return 431 | 432 | stats = { 433 | "island": island_num, 434 | "generation": generation, 435 | "mdl": best_net.fitness.mdl, 436 | "|g|": best_net.fitness.grammar_encoding_length, 437 | "|d:g|": best_net.fitness.data_encoding_length, 438 | "units": network.get_num_units(best_net), 439 | "connections": network.get_total_connections(best_net, include_biases=True), 440 | "accuracy": best_net.fitness.accuracy, 441 | } 442 | 443 | logging_queue.put({"best_net": best_net, "stats": stats}) 444 | 445 | 446 | def _log_generation( 447 | population: _Population, 448 | corpus: corpora.Corpus, 449 | config: configuration.SimulationConfig, 450 | generation: int, 451 | island_num: int, 452 | generation_time_delta: datetime.timedelta, 453 | logging_queue: multiprocessing.Queue, 454 | cloud_upload_queue: 
queue.Queue, 455 | ): 456 | all_fitness = [x.fitness.mdl for x in population] 457 | valid_population = [ 458 | population[i] for i in range(len(population)) if not np.isinf(all_fitness[i]) 459 | ] 460 | 461 | if valid_population: 462 | best_net_idx = int(np.argmin(all_fitness)) 463 | else: 464 | best_net_idx = 0 465 | valid_population = [population[0]] 466 | 467 | best_net = population[best_net_idx] 468 | best_fitness = all_fitness[best_net_idx] 469 | 470 | valid_fitnesses = [x.fitness.mdl for x in valid_population] 471 | mean_fitness = np.mean(valid_fitnesses) 472 | fitness_std = np.std(valid_fitnesses) 473 | 474 | num_connections = [ 475 | network.get_total_connections(x, include_biases=True) for x in population 476 | ] 477 | num_connections_mean = np.mean(num_connections) 478 | num_connections_std = np.std(num_connections) 479 | num_connections_max = np.max(num_connections) 480 | 481 | num_units = [network.get_num_units(x) for x in population] 482 | num_units_mean = np.mean(num_units) 483 | num_units_std = np.std(num_units) 484 | num_units_max = np.max(num_units) 485 | 486 | incoming_forward_degrees = _get_population_incoming_degrees( 487 | population=population, edge_type=network.FORWARD_CONNECTION 488 | ) 489 | multiple_inputs_forward_degrees = incoming_forward_degrees[ 490 | incoming_forward_degrees > 1 491 | ] 492 | incoming_recurrent_degrees = _get_population_incoming_degrees( 493 | population=population, edge_type=network.RECURRENT_CONNECTION 494 | ) 495 | multiple_inputs_recurrent_degrees = incoming_recurrent_degrees[ 496 | incoming_recurrent_degrees > 1 497 | ] 498 | 499 | g_s = [x.fitness.grammar_encoding_length for x in population] 500 | mean_g = np.mean(g_s) 501 | std_g = np.std(g_s) 502 | max_g = np.max(g_s) 503 | d_g_s = [x.fitness.data_encoding_length for x in valid_population] 504 | mean_d_g = np.mean(d_g_s) 505 | std_d_g = np.std(d_g_s) 506 | max_d_g = np.max(d_g_s) 507 | mean_accuracy = np.mean([x.fitness.accuracy for x in valid_population]) 508 | 509 | all_weights = [] 510 | for x in population: 511 | all_weights.append(network.get_forward_weights(x)) 512 | all_weights.append(network.get_recurrent_weights(x)) 513 | 514 | all_weights = np.concatenate(all_weights) 515 | mean_weight = np.mean(all_weights) 516 | max_weight = np.max(all_weights) 517 | 518 | num_invalid = len(list(filter(np.isinf, [x.fitness.mdl for x in population],))) 519 | invalid_ratio = num_invalid / len(population) 520 | 521 | unique_ratio = len(set(population)) / len(population) 522 | 523 | logging.info( 524 | f"\nIsland {island_num} (pid {os.getpid()}) Generation {generation}" 525 | f"\n\tGeneration took {generation_time_delta.seconds}.{str(generation_time_delta.microseconds)[:2]} seconds" 526 | f"\n\tMean fitness: {mean_fitness:.2f} (±{fitness_std:.2f}, worst valid {np.max(valid_fitnesses):.2f}) \tBest fitness: {best_fitness:.2f}" 527 | f"\n\tMean num nodes: {num_units_mean:.2f} (±{num_units_std:.2f}, max {num_units_max}) \tMean num connections: {num_connections_mean:.2f} (±{num_connections_std:.2f}, max {num_connections_max}) \tMean G: {mean_g:.2f} (±{std_g:.2f}, max {max_g:.2f})\tMean D:G: {mean_d_g:.2f} (±{std_d_g:.2f}, max {max_d_g:.2f})" 528 | f"\n\tMean forward in degree: {np.mean(incoming_forward_degrees):.2f} (±{np.std(incoming_forward_degrees):.2f}, max {np.max(incoming_forward_degrees)}) \tMean recurrent in degree: {np.mean(incoming_recurrent_degrees):.2f} (±{np.std(incoming_recurrent_degrees):.2f}, max {np.max(incoming_recurrent_degrees) if incoming_recurrent_degrees.size else '-'})" 529 
| f"\n\tMean forward in degree>1: {np.mean(multiple_inputs_forward_degrees):.2f} (±{np.std(multiple_inputs_forward_degrees):.2f}) \tMean recurrent in degree>1: {np.mean(multiple_inputs_recurrent_degrees):.2f} (±{np.std(multiple_inputs_recurrent_degrees):.2f})" 530 | f"\n\tMean weight: {mean_weight:.2f} (max {max_weight})\tMean accuracy: {mean_accuracy:.2f}\tInvalid: {invalid_ratio*100:.1f}%\tUnique: {unique_ratio*100:.1f}%" 531 | f"\n\tBest network:\n\t{network.to_string(best_net)}\n\n" 532 | ) 533 | 534 | if generation == 1 or generation % 100 == 0: 535 | network_filename = f"{config.simulation_id}__island_{island_num}__best_network" 536 | network.visualize(best_net, network_filename, class_to_label=corpus.vocabulary) 537 | network.save(best_net, network_filename) 538 | 539 | _log_generation_to_logging_process( 540 | island_num=island_num, 541 | generation=generation, 542 | best_net=best_net, 543 | corpus=corpus, 544 | config=config, 545 | logging_queue=logging_queue, 546 | ) 547 | 548 | if generation == 1 or generation % config.generation_dump_interval == 0: 549 | _save_generation( 550 | generation=generation, 551 | population=population, 552 | island_num=island_num, 553 | config=config, 554 | cloud_upload_queue=cloud_upload_queue, 555 | ) 556 | 557 | if _DEBUG_MODE and generation > 0 and generation % 5 == 0: 558 | [ 559 | network.visualize( 560 | population[x], 561 | f"random_gen_{generation}__island_{island_num}__{str(uuid.uuid1())}", 562 | ) 563 | for x in random.sample(range(len(population)), 5) 564 | ] 565 | 566 | 567 | def _send_migrants( 568 | population: _Population, 569 | island_num: int, 570 | config: configuration.SimulationConfig, 571 | target_island: int, 572 | target_process_lock: multiprocessing.Lock, 573 | cloud_upload_queue: queue.Queue, 574 | ): 575 | num_migrants = math.floor(config.migration_ratio * config.population_size) 576 | migrants = list( 577 | set( 578 | population[ 579 | _tournament_selection( 580 | population, tournament_size=config.tournament_size 581 | ).winner_idx 582 | ] 583 | for _ in range(num_migrants) 584 | ) 585 | ) 586 | if config.migration_channel == "file": 587 | did_send = _send_migrants_through_file( 588 | migrants=migrants, 589 | target_island=target_island, 590 | simulation_id=config.simulation_id, 591 | file_lock=target_process_lock, 592 | ) 593 | elif config.migration_channel == "mpi": 594 | did_send = _send_migrants_through_mpi( 595 | migrants=migrants, target_island=target_island, 596 | ) 597 | else: 598 | raise ValueError(config.migration_channel) 599 | 600 | if did_send: 601 | best_sent = min(migrants, key=_GET_NET_MDL) 602 | logging.info( 603 | f"Island {island_num} sent {len(migrants)} migrants to island {target_island} through {config.migration_channel}. Best sent: {best_sent.fitness.mdl:,.2f}" 604 | ) 605 | 606 | 607 | def _integrate_migrants( 608 | incoming_migrants: _Population, 609 | population: _Population, 610 | config: configuration.SimulationConfig, 611 | island_num: int, 612 | ) -> _Population: 613 | losing_idxs = tuple( 614 | _tournament_selection(population, config.tournament_size).loser_idx 615 | for _ in range(len(incoming_migrants)) 616 | ) 617 | prev_best_fitness = min([x.fitness.mdl for x in population]) 618 | for migrant_idx, local_idx in enumerate(losing_idxs): 619 | population[local_idx] = incoming_migrants[migrant_idx] 620 | new_best_fitness = min([x.fitness.mdl for x in population]) 621 | logging.info( 622 | f"Island {island_num} got {len(incoming_migrants)} incoming migrants. 
Previous best fitness: {prev_best_fitness:.2f}, new best: {new_best_fitness:.2f}" 623 | ) 624 | return population 625 | 626 | 627 | def _receive_and_integrate_migrants( 628 | population, config, island_num, process_lock 629 | ) -> _Population: 630 | if config.migration_channel == "file": 631 | incoming_migrants = _get_migrants_from_file( 632 | simulation_id=config.simulation_id, 633 | island_num=island_num, 634 | file_lock=process_lock, 635 | ) 636 | elif config.migration_channel == "mpi": 637 | incoming_migrants = get_migrants_through_mpi() 638 | else: 639 | raise ValueError(config.migration_channel) 640 | 641 | if incoming_migrants is None: 642 | logging.info(f"Island {island_num} has no incoming migrants waiting.") 643 | return population 644 | 645 | return _integrate_migrants( 646 | incoming_migrants=incoming_migrants, 647 | population=population, 648 | config=config, 649 | island_num=island_num, 650 | ) 651 | 652 | 653 | def _make_migration( 654 | population: _Population, 655 | island_num: int, 656 | config: configuration.SimulationConfig, 657 | migration_target_generator: Iterator[int], 658 | process_locks: Tuple[multiprocessing.Lock, ...], 659 | cloud_upload_queue: queue.Queue, 660 | ) -> _Population: 661 | target_island = next(migration_target_generator) 662 | _send_migrants( 663 | population=population, 664 | island_num=island_num, 665 | config=config, 666 | target_island=target_island, 667 | target_process_lock=process_locks[target_island] if process_locks else None, 668 | cloud_upload_queue=cloud_upload_queue, 669 | ) 670 | return _receive_and_integrate_migrants( 671 | population, 672 | config, 673 | island_num, 674 | process_locks[island_num] if process_locks else None, 675 | ) 676 | 677 | 678 | def run( 679 | island_num: int, 680 | corpus: corpora.Corpus, 681 | config: configuration.SimulationConfig, 682 | process_locks: Tuple[multiprocessing.Lock, ...], 683 | logging_queue: multiprocessing.Queue, 684 | ): 685 | seed = config.seed + island_num 686 | utils.seed(seed) 687 | logging.info(f"Island {island_num}, seed {seed}") 688 | 689 | population, generation = _initialize_population_and_generation( 690 | config, 691 | island_num, 692 | input_size=corpus.input_sequence.shape[-1], 693 | output_size=corpus.target_sequence.shape[-1], 694 | ) 695 | population = _evaluate_population(population, corpus, config) 696 | migration_target_generator = _make_migration_target_island_generator( 697 | island_num=island_num, total_islands=config.num_islands 698 | ) 699 | 700 | cloud_upload_queue = queue.Queue() 701 | 702 | stopwatch_start = datetime.datetime.now() 703 | while generation <= config.num_generations: 704 | generation_start_time = datetime.datetime.now() 705 | population = _make_generation(population, corpus, config) 706 | generation_time_delta = datetime.datetime.now() - generation_start_time 707 | 708 | time_delta = datetime.datetime.now() - stopwatch_start 709 | if config.num_islands > 1 and ( 710 | time_delta.total_seconds() >= config.migration_interval_seconds 711 | or generation % config.migration_interval_generations == 0 712 | ): 713 | logging.info( 714 | f"Island {island_num} performing migration, time passed {time_delta.total_seconds()} seconds." 
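# Migration is triggered either when `migration_interval_seconds` of wall
# time have elapsed or every `migration_interval_generations` generations;
# the stopwatch is reset right after the migrants are exchanged below.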
715 | ) 716 | population = _make_migration( 717 | population=population, 718 | island_num=island_num, 719 | config=config, 720 | migration_target_generator=migration_target_generator, 721 | process_locks=process_locks, 722 | cloud_upload_queue=cloud_upload_queue, 723 | ) 724 | stopwatch_start = datetime.datetime.now() 725 | 726 | _log_generation( 727 | population=population, 728 | corpus=corpus, 729 | config=config, 730 | generation=generation, 731 | island_num=island_num, 732 | generation_time_delta=generation_time_delta, 733 | logging_queue=logging_queue, 734 | cloud_upload_queue=cloud_upload_queue, 735 | ) 736 | generation += 1 737 | 738 | population = _make_migration( 739 | population=population, 740 | island_num=island_num, 741 | config=config, 742 | migration_target_generator=migration_target_generator, 743 | process_locks=process_locks, 744 | cloud_upload_queue=cloud_upload_queue, 745 | ) 746 | 747 | best_network = min(population, key=_GET_NET_MDL) 748 | return best_network 749 | -------------------------------------------------------------------------------- /island.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import json 3 | import logging 4 | import multiprocessing 5 | import pickle 6 | import queue 7 | from datetime import datetime 8 | from typing import Dict, Optional, Text, Tuple 9 | 10 | import numpy as np 11 | from mpi4py import MPI 12 | 13 | import configuration 14 | import corpora 15 | import genetic_algorithm 16 | import network 17 | import utils 18 | 19 | 20 | _MPI_COMMUNICATOR = MPI.COMM_WORLD 21 | _MPI_RANK = _MPI_COMMUNICATOR.Get_rank() 22 | 23 | _MPI_LOGGING_BUFFER_SIZE = 5_000_000 24 | 25 | 26 | def _island_process( 27 | island_num: int, 28 | config: configuration.SimulationConfig, 29 | corpus: corpora.Corpus, 30 | result_queue: multiprocessing.Queue, 31 | logging_queue: multiprocessing.Queue, 32 | process_locks: Tuple[multiprocessing.Lock, ...], 33 | ): 34 | try: 35 | result_network = genetic_algorithm.run( 36 | island_num=island_num, 37 | corpus=corpus, 38 | config=config, 39 | logging_queue=logging_queue, 40 | process_locks=process_locks, 41 | ) 42 | result_queue.put((island_num, result_network)) 43 | except: 44 | logging.exception(f"Exception in island {island_num}:") 45 | 46 | 47 | class _DummyProcess: 48 | # Used for profiling without multiprocessing. 49 | def __init__(self, target, kwargs): 50 | self._target = target 51 | self._kwargs = kwargs 52 | self.__setattr__ = lambda x: x 53 | 54 | def start(self): 55 | self._target(**self._kwargs) 56 | 57 | 58 | def _log_net_stats( 59 | stats: Dict, 60 | config: configuration.SimulationConfig, 61 | train_corpus: corpora.Corpus, 62 | test_corpus: Optional[corpora.Corpus], 63 | best_net: network.Network, 64 | status: Text, 65 | ): 66 | stats.update( 67 | { 68 | "simulation": config.simulation_id, 69 | "simulation id": config.simulation_id, 70 | "configuration": json.dumps(config.__dict__), 71 | "params": config.comment, 72 | "status": status, 73 | "best island": stats["island"], 74 | "last update": datetime.now().isoformat(), 75 | } 76 | ) 77 | 78 | log_text = ( 79 | f"Current best net, island {stats['island']}, generation {stats['generation']}: " 80 | f"MDL = {stats['mdl']:,.2f}, |G| = {stats['|g|']}, |D:G| = {stats['|d:g|']:,.2f}, units: {stats['units']}, connections: {stats['connections']}." 
81 | ) 82 | 83 | train_num_chars = corpora.get_num_chars_in_corpus(train_corpus) 84 | average_train_dg_per_char = best_net.fitness.data_encoding_length / train_num_chars 85 | stats.update( 86 | { 87 | "average train d:g per character": f"{average_train_dg_per_char:.2f}", 88 | "training set num. chars": f"{train_num_chars}", 89 | } 90 | ) 91 | 92 | if train_corpus.deterministic_steps_mask is not None: 93 | train_deterministic_accuracy = network.calculate_deterministic_accuracy( 94 | best_net, train_corpus, config 95 | ) 96 | stats.update({"deterministic accuracy": train_deterministic_accuracy}) 97 | log_text += ( 98 | f" Training set deterministic accuracy: {train_deterministic_accuracy:,.2f}" 99 | ) 100 | 101 | if test_corpus is not None: 102 | test_net = network.invalidate_fitness(best_net) 103 | test_net = network.calculate_fitness( 104 | test_net, corpus=test_corpus, config=config 105 | ) 106 | 107 | test_num_chars = corpora.get_num_chars_in_corpus(test_corpus) 108 | average_test_dg_per_char = ( 109 | test_net.fitness.data_encoding_length / test_num_chars 110 | ) 111 | 112 | stats.update( 113 | { 114 | "test set d:g": f"{test_net.fitness.data_encoding_length:.2f}", 115 | "test set accuracy": f"{test_net.fitness.accuracy}", 116 | "average test d:g per character": f"{average_test_dg_per_char:.2f}", 117 | } 118 | ) 119 | log_text += f" Test set |D:G|: {test_net.fitness.data_encoding_length:,.2f}." 120 | 121 | if test_corpus.deterministic_steps_mask is not None: 122 | test_deterministic_accuracy = network.calculate_deterministic_accuracy( 123 | test_net, test_corpus, config 124 | ) 125 | stats.update( 126 | {"test set deterministic accuracy": test_deterministic_accuracy} 127 | ) 128 | log_text += ( 129 | f" Test set deterministic accuracy: {test_deterministic_accuracy:,.2f}" 130 | ) 131 | 132 | logging.info(log_text) 133 | 134 | current_best_net_filename = f"{config.simulation_id}__current_best" 135 | network.save(best_net, current_best_net_filename) 136 | network.visualize( 137 | best_net, 138 | current_best_net_filename, 139 | class_to_label=test_corpus.vocabulary if test_corpus else None, 140 | ) 141 | 142 | 143 | def _mpi_logging_worker( 144 | config: configuration.SimulationConfig, 145 | train_corpus: corpora.Corpus, 146 | test_corpus: Optional[corpora.Corpus], 147 | ): 148 | best_mdl = float("inf") 149 | buffer = np.empty(7) 150 | logging.info(f"Started MPI logging worker, rank {_MPI_RANK}") 151 | while True: 152 | _MPI_COMMUNICATOR.Recv(buffer, tag=utils.MPI_LOGGING_TAG) 153 | ( 154 | island_num, 155 | generation, 156 | mdl, 157 | grammar_encoding_length, 158 | data_encoding_length, 159 | num_units, 160 | connections, 161 | ) = buffer 162 | island_num = int(island_num) 163 | if mdl < best_mdl: 164 | best_mdl = mdl 165 | 166 | with open( 167 | f"./networks/{config.simulation_id}__island_{island_num}__best_network.pickle", 168 | "rb", 169 | ) as f: 170 | best_net = pickle.load(f) 171 | 172 | best_island_stats = { 173 | "island": island_num, 174 | "generation": int(generation), 175 | "mdl": mdl, 176 | "|g|": grammar_encoding_length, 177 | "|d:g|": data_encoding_length, 178 | "units": int(num_units), 179 | "connections": int(connections), 180 | "accuracy": best_net.fitness.accuracy, 181 | } 182 | 183 | _log_net_stats( 184 | stats=best_island_stats, 185 | config=config, 186 | train_corpus=train_corpus, 187 | test_corpus=test_corpus, 188 | best_net=best_net, 189 | status="Running", 190 | ) 191 | 192 | 193 | def _queue_logging_worker( 194 | logging_queue: multiprocessing.Queue, 195 | 
config: configuration.SimulationConfig, 196 | train_corpus: corpora.Corpus, 197 | test_corpus: Optional[corpora.Corpus], 198 | ): 199 | best_net_fitness = float("inf") 200 | 201 | while True: 202 | island_num_to_data = {} 203 | start_time = datetime.now() 204 | while (datetime.now() - start_time).total_seconds() < 60: 205 | try: 206 | data = logging_queue.get(timeout=5) 207 | island_num = data["stats"]["island"] 208 | island_num_to_data[island_num] = data 209 | except queue.Empty: 210 | pass 211 | 212 | best_island_data = None 213 | for data in island_num_to_data.values(): 214 | stats = data["stats"] 215 | if stats["mdl"] <= best_net_fitness: 216 | best_net_fitness = stats["mdl"] 217 | best_island_data = data 218 | 219 | if best_island_data is None: 220 | continue 221 | 222 | _log_net_stats( 223 | stats=best_island_data["stats"], 224 | config=config, 225 | train_corpus=train_corpus, 226 | test_corpus=test_corpus, 227 | best_net=best_island_data["best_net"], 228 | status="Running", 229 | ) 230 | 231 | 232 | def _create_csv_and_spreadsheet_entry( 233 | config: configuration.SimulationConfig, train_corpus, test_corpus 234 | ): 235 | data = { 236 | "simulation": config.simulation_id, 237 | "simulation id": config.simulation_id, 238 | "params": config.comment, 239 | "configuration": json.dumps(config.__dict__), 240 | "status": "Running", 241 | "started": datetime.now().isoformat(), 242 | "seed": config.seed, 243 | } 244 | 245 | train_num_chars = corpora.get_num_chars_in_corpus(train_corpus) 246 | data.update( 247 | {"training set num. chars": f"{train_num_chars}",} 248 | ) 249 | if train_corpus.optimal_d_given_g is not None: 250 | data.update( 251 | { 252 | "training set optimal d:g": f"{train_corpus.optimal_d_given_g:.2f}", 253 | "optimal average train d:g per character": f"{train_corpus.optimal_d_given_g / train_num_chars:.2f}", 254 | } 255 | ) 256 | 257 | if test_corpus is not None: 258 | test_num_chars = corpora.get_num_chars_in_corpus(test_corpus) 259 | data.update( 260 | { 261 | "test set params": f"Input shape: {test_corpus.input_sequence.shape}. Output shape: {test_corpus.target_sequence.shape}", 262 | "test set num. 
chars": f"{test_num_chars}", 263 | } 264 | ) 265 | if test_corpus.optimal_d_given_g is not None: 266 | data.update( 267 | { 268 | "test set optimal d:g": f"{test_corpus.optimal_d_given_g:.2f}", 269 | "optimal average test d:g per character": f"{test_corpus.optimal_d_given_g / test_num_chars:.2f}", 270 | } 271 | ) 272 | 273 | 274 | def _init_islands( 275 | first_island: int, 276 | last_island: int, 277 | config: configuration.SimulationConfig, 278 | train_corpus, 279 | result_queue: multiprocessing.Queue, 280 | logging_queue: multiprocessing.Queue, 281 | process_locks: Tuple[multiprocessing.Lock, ...], 282 | ) -> Tuple[multiprocessing.Process, ...]: 283 | processes = [] 284 | for i in range(first_island, last_island + 1): 285 | if config.parallelize: 286 | process_class = multiprocessing.Process 287 | else: 288 | process_class = _DummyProcess 289 | 290 | p = process_class( 291 | target=_island_process, 292 | kwargs={ 293 | "island_num": i, 294 | "config": config, 295 | "corpus": train_corpus, 296 | "result_queue": result_queue, 297 | "logging_queue": logging_queue, 298 | "process_locks": process_locks, 299 | }, 300 | ) 301 | p.daemon = True 302 | processes.append(p) 303 | 304 | return tuple(processes) 305 | 306 | 307 | def _get_island_results_from_queue_or_mpi(result_queue: multiprocessing.Queue): 308 | if result_queue is not None: 309 | return result_queue.get() 310 | return _MPI_COMMUNICATOR.recv(tag=utils.MPI_RESULTS_TAG) 311 | 312 | 313 | def _collect_results( 314 | num_islands_to_collect: int, 315 | result_queue: multiprocessing.Queue, 316 | config: configuration.SimulationConfig, 317 | train_corpus: corpora.Corpus, 318 | test_corpus: Optional[corpora.Corpus], 319 | ): 320 | result_network_per_island = {} 321 | while len(result_network_per_island) < num_islands_to_collect: 322 | island_num, result_net = _get_island_results_from_queue_or_mpi( 323 | result_queue=result_queue 324 | ) 325 | result_network_per_island[island_num] = result_net 326 | logging.info(f"Island {island_num} best network:\n{str(result_net)}") 327 | logging.info( 328 | f"{len(result_network_per_island)}/{num_islands_to_collect} (global {config.num_islands}) islands done." 
329 | ) 330 | 331 | best_island, best_net = min( 332 | result_network_per_island.items(), key=lambda x: x[1].fitness.mdl 333 | ) 334 | network_filename = f"{config.simulation_id}__best_network" 335 | network.visualize( 336 | best_net, network_filename, class_to_label=train_corpus.vocabulary 337 | ) 338 | network.save(best_net, network_filename) 339 | 340 | logging.info(f"Best network of all islands:\n{str(best_net)}") 341 | 342 | csv_data = { 343 | "simulation id": config.simulation_id, 344 | "island": best_island, 345 | "mdl": best_net.fitness.mdl, 346 | "|g|": best_net.fitness.grammar_encoding_length, 347 | "|d:g|": best_net.fitness.data_encoding_length, 348 | "units": network.get_num_units(best_net), 349 | "connections": network.get_total_connections(best_net, include_biases=True), 350 | "accuracy": best_net.fitness.accuracy, 351 | "generation": config.num_generations, 352 | "finished": datetime.now().isoformat(), 353 | } 354 | _log_net_stats( 355 | stats=csv_data, 356 | config=config, 357 | train_corpus=train_corpus, 358 | test_corpus=test_corpus, 359 | best_net=best_net, 360 | status="Done", 361 | ) 362 | 363 | 364 | def _run_mpi_island(island_num, corpus, config): 365 | result_network = genetic_algorithm.run( 366 | island_num=island_num, 367 | corpus=corpus, 368 | config=config, 369 | logging_queue=None, 370 | process_locks=None, 371 | ) 372 | _MPI_COMMUNICATOR.send( 373 | (island_num, result_network), 374 | dest=config.num_islands + 1, 375 | tag=utils.MPI_RESULTS_TAG, 376 | ) 377 | while True: 378 | # TODO: need to keep running this so buffer won't fill. need to find a better solution. 379 | genetic_algorithm.get_migrants_through_mpi() 380 | 381 | 382 | def run( 383 | corpus: corpora.Corpus, 384 | config: configuration.SimulationConfig, 385 | first_island: int, 386 | last_island: int, 387 | ): 388 | train_corpus = corpora.optimize_for_feeding( 389 | dataclasses.replace(corpus, test_corpus=None) 390 | ) 391 | test_corpus = ( 392 | corpora.optimize_for_feeding(corpus.test_corpus) if corpus.test_corpus else None 393 | ) 394 | 395 | genetic_algorithm.verify_existing_simulation_override( 396 | simulation_id=config.simulation_id, 397 | ) 398 | 399 | logging.info(f"Starting simulation {config.simulation_id}\n") 400 | logging.info(f"Config: {config}\n") 401 | logging.info( 402 | f"Running islands {first_island}-{last_island} ({last_island-first_island+1}/{config.num_islands})" 403 | ) 404 | 405 | if config.migration_channel == "file": 406 | genetic_algorithm.get_migration_path(config.simulation_id).mkdir( 407 | parents=True, exist_ok=True 408 | ) 409 | 410 | if first_island == 0: 411 | _create_csv_and_spreadsheet_entry(config, train_corpus, test_corpus) 412 | 413 | result_queue = multiprocessing.Queue() 414 | logging_queue = multiprocessing.Queue() 415 | process_locks = tuple(multiprocessing.Lock() for _ in range(config.num_islands)) 416 | 417 | island_processes = _init_islands( 418 | first_island=first_island, 419 | last_island=last_island, 420 | config=config, 421 | train_corpus=train_corpus, 422 | result_queue=result_queue, 423 | logging_queue=logging_queue, 424 | process_locks=process_locks, 425 | ) 426 | 427 | for p in island_processes: 428 | p.start() 429 | 430 | logging_process = multiprocessing.Process( 431 | target=_queue_logging_worker, 432 | args=(logging_queue, config, train_corpus, test_corpus), 433 | ) 434 | logging_process.daemon = True 435 | logging_process.start() 436 | 437 | _collect_results( 438 | num_islands_to_collect=len(island_processes), 439 | 
result_queue=result_queue, 440 | config=config, 441 | train_corpus=train_corpus, 442 | test_corpus=test_corpus, 443 | ) 444 | 445 | elif config.migration_channel == "mpi": 446 | if _MPI_RANK < config.num_islands: 447 | # Ranks [0, num_islands - 1] = islands 448 | _run_mpi_island( 449 | island_num=_MPI_RANK, corpus=train_corpus, config=config, 450 | ) 451 | elif _MPI_RANK == config.num_islands: 452 | # Rank num_islands = logger 453 | _create_csv_and_spreadsheet_entry(config, train_corpus, test_corpus) 454 | _mpi_logging_worker(config, train_corpus, test_corpus) 455 | elif _MPI_RANK == config.num_islands + 1: 456 | # Rank num_island + 1 = results collector 457 | logging.info(f"Starting results collection, rank {_MPI_RANK}") 458 | _collect_results( 459 | num_islands_to_collect=config.num_islands, 460 | result_queue=None, 461 | config=config, 462 | train_corpus=train_corpus, 463 | test_corpus=test_corpus, 464 | ) 465 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import json 3 | 4 | import configuration 5 | import genetic_algorithm 6 | import island 7 | import simulations 8 | import utils 9 | 10 | utils.setup_logging() 11 | 12 | 13 | _NUM_REPRODUCTIONS = 1 14 | 15 | 16 | def _make_corpus(factory, args, corpus_seed): 17 | utils.seed(corpus_seed) 18 | return factory(**args) 19 | 20 | 21 | def run(): 22 | arg_parser = utils.make_cli_arguments() 23 | arguments = arg_parser.parse_args() 24 | 25 | simulation = simulations.SIMULATIONS[arguments.simulation_name] 26 | 27 | for i in range(_NUM_REPRODUCTIONS): 28 | simulation_seed = arguments.base_seed + i 29 | 30 | base_config = configuration.SimulationConfig( 31 | seed=simulation_seed, 32 | simulation_id=arguments.simulation_name, 33 | num_islands=arguments.total_islands, 34 | **{**simulations.DEFAULT_CONFIG, **simulation.get("config", {})}, 35 | ) 36 | 37 | corpus_args = simulation["corpus"]["args"] 38 | if arguments.corpus_args is not None: 39 | corpus_args.update(json.loads(arguments.corpus_args)) 40 | 41 | corpus = _make_corpus( 42 | factory=simulation["corpus"]["factory"], 43 | args=corpus_args, 44 | corpus_seed=base_config.corpus_seed, 45 | ) 46 | 47 | simulation_config = dataclasses.replace( 48 | base_config, 49 | comment=f"Corpus params: {json.dumps(simulation['corpus']['args'])}, input shape: {corpus.input_sequence.shape}. 
Output shape: {corpus.target_sequence.shape}", 50 | simulation_id=corpus.name, 51 | resumed_from_simulation_id=arguments.resumed_simulation_id, 52 | ) 53 | simulation_config = utils.add_hash_to_simulation_id(simulation_config, corpus) 54 | 55 | if arguments.override_existing: 56 | genetic_algorithm.remove_simulation_directory( 57 | simulation_id=simulation_config.simulation_id, 58 | ) 59 | 60 | utils.seed(simulation_seed) 61 | 62 | island.run( 63 | corpus=corpus, 64 | config=simulation_config, 65 | first_island=arguments.first_island 66 | if arguments.first_island is not None 67 | else 0, 68 | last_island=( 69 | arguments.last_island 70 | if arguments.last_island is not None 71 | else simulation_config.num_islands - 1 72 | ), 73 | ) 74 | 75 | 76 | if __name__ == "__main__": 77 | run() 78 | -------------------------------------------------------------------------------- /manual_nets.py: -------------------------------------------------------------------------------- 1 | import network 2 | 3 | 4 | def make_emmanuel_dyck_2_network(nesting_probab: float): 5 | opening_bracket_output_bias = nesting_probab / (2 * (1 - nesting_probab)) 6 | return network.make_custom_network( 7 | input_size=5, 8 | output_size=5, 9 | num_units=17, 10 | forward_weights={ 11 | 1: ((3, 1, 2, 1),), 12 | 2: ((12, -1, 1, 1), (13, 1, 1, 1)), 13 | 3: ((10, 1, 1, 1),), 14 | 4: ((2, 1, 1, 1),), 15 | 5: ((9, -1, 1, 1),), 16 | 6: ((8, 1, 1, 1),), 17 | 7: ((9, -1, 1, 1),), 18 | 10: ((11, 1, 1, 1),), 19 | 11: ((15, 1, 1, 1),), 20 | 12: ((11, 1, 1, 1),), 21 | 13: ((15, 1, 1, 1),), 22 | 14: ((13, 1, 1, 1),), 23 | 15: ((16, 1, 1, 1),), 24 | 16: ((5, -1, 1, 1), (7, 1, 1, 1)), 25 | }, 26 | recurrent_weights={15: ((10, 1, 3, 1), (14, 1, 1, 3))}, 27 | unit_types={11: network.MULTIPLICATION_UNIT, 13: network.MULTIPLICATION_UNIT,}, 28 | biases={5: 1, 6: opening_bracket_output_bias, 7: -1, 9: 1, 12: 1}, 29 | activations={ 30 | 5: network.UNSIGNED_STEP, 31 | 7: network.UNSIGNED_STEP, 32 | 14: network.FLOOR, 33 | 16: network.MODULO_3, 34 | }, 35 | ) 36 | 37 | 38 | def make_emmanuel_dyck_2_network_io_protection(nesting_probab: float): 39 | opening_bracket_output_bias = nesting_probab / (2 * (1 - nesting_probab)) 40 | return network.make_custom_network( 41 | input_size=5, 42 | output_size=5, 43 | num_units=23, 44 | forward_weights={ 45 | 1: ((18, 1, 2, 1),), 46 | 2: ((17, 1, 1, 1),), 47 | 3: ((18, 1, 1, 1),), 48 | 4: ((17, 1, 1, 1),), 49 | 10: ((11, 1, 1, 1),), 50 | 11: ((15, 1, 1, 1),), 51 | 12: ((11, 1, 1, 1),), 52 | 13: ((15, 1, 1, 1),), 53 | 14: ((13, 1, 1, 1),), 54 | 15: ((16, 1, 1, 1),), 55 | 16: ((19, -1, 1, 1), (20, 1, 1, 1)), 56 | 17: ((12, -1, 1, 1), (13, 1, 1, 1)), 57 | 18: ((10, 1, 1, 1),), 58 | 19: ((21, -1, 1, 1), (20, 1, 1, 1), (5, 1, 1, 1)), 59 | 20: ((7, 1, 1, 1), (21, -1, 1, 1)), 60 | 21: ((9, 1, 1, 1),), 61 | 22: ((6, 1, 1, 1), (8, 1, 1, 1)), 62 | }, 63 | recurrent_weights={15: ((10, 1, 3, 1), (14, 1, 1, 3))}, 64 | unit_types={11: network.MULTIPLICATION_UNIT, 13: network.MULTIPLICATION_UNIT,}, 65 | biases={12: 1, 19: 1, 20: -1, 21: 1, 22: opening_bracket_output_bias}, 66 | activations={ 67 | 14: network.FLOOR, 68 | 16: network.MODULO_3, 69 | 19: network.UNSIGNED_STEP, 70 | 20: network.UNSIGNED_STEP, 71 | }, 72 | ) 73 | 74 | 75 | def make_emmanuel_triplet_xor_network(): 76 | net = network.make_custom_network( 77 | input_size=1, 78 | output_size=1, 79 | num_units=7, 80 | forward_weights={ 81 | 0: ((2, 1, 1, 1),), 82 | 2: ((5, -1, 1, 1), (6, 1, 1, 1),), 83 | 3: ((6, 1, 1, 1), (5, 1, 1, 1)), 84 | 4: ((3, 1, 1, 1),), 85 | 5: ((1, 
-1, 1, 2),), 86 | 6: ((1, 1, 1, 2),), 87 | }, 88 | recurrent_weights={ 89 | 0: ((2, -1, 1, 1),), 90 | 3: ((4, -1, 3, 1),), 91 | 4: ((4, 1, 1, 1),), 92 | }, 93 | biases={1: 0.5, 3: -1, 4: 1, 6: -1}, 94 | activations={ 95 | 2: network.SQUARE, 96 | 3: network.RELU, 97 | 5: network.RELU, 98 | 6: network.RELU, 99 | }, 100 | ) 101 | return net 102 | -------------------------------------------------------------------------------- /mdlrnn_torch.py: -------------------------------------------------------------------------------- 1 | import itertools 2 | 3 | import torch 4 | from torch import nn 5 | 6 | 7 | class MDLRNN(nn.Module): 8 | def __init__( 9 | self, 10 | computation_graph: dict[int, tuple[int, nn.Linear]], 11 | layer_to_memory_weights: dict[int, nn.Linear], 12 | memory_to_layer_weights: dict[int, nn.Linear], 13 | layer_to_activation_to_units: dict[int, dict[int, frozenset[int]]], 14 | ): 15 | super(MDLRNN, self).__init__() 16 | self._computation_graph = computation_graph 17 | self._layer_to_activation_to_units = layer_to_activation_to_units 18 | self._memory_to_layer_weights = memory_to_layer_weights 19 | self.layer_to_memory_weights = layer_to_memory_weights 20 | 21 | self._memory_size = memory_to_layer_weights[ 22 | min(memory_to_layer_weights) 23 | ].in_features 24 | 25 | self.module_list = nn.ModuleList( 26 | [x[1] for x in list(itertools.chain(*self._computation_graph.values()))] 27 | + list(self.layer_to_memory_weights.values()) 28 | + list(self._memory_to_layer_weights.values()) 29 | ) 30 | 31 | def forward(self, inputs, memory=None, output_layer=None): 32 | """ 33 | :param inputs: batched tensor of shape `(batch_size, sequence_length, num_input_classes)`. 34 | :param memory: batch tensor of shape `(batch_size, memory_size)`. 35 | :param output_layer: function to apply to outputs: `None` for raw logits, `"softmax"`, or `"normalize"` for simple normalization. 36 | :return: tensor of shape `(batch_size, sequence_length, num_output_classes)`. 37 | """ 38 | 39 | def recurrence(inputs_inner, memory_inner): 40 | input_layer_num = min(self._computation_graph) 41 | layer_to_vals = {input_layer_num: inputs_inner} 42 | 43 | memory_out = torch.zeros( 44 | ( 45 | inputs_inner.shape[0], 46 | self._memory_size, 47 | ) 48 | ) 49 | 50 | for source_layer in sorted(self._computation_graph): 51 | # Add memory. 52 | memory_weights = self._memory_to_layer_weights[source_layer] 53 | incoming_memory = memory_weights(memory_inner) 54 | layer_to_vals[source_layer] = ( 55 | layer_to_vals[source_layer] + incoming_memory 56 | ) 57 | 58 | # Apply activations. 59 | source_layer_activations_to_unit = self._layer_to_activation_to_units[ 60 | source_layer 61 | ] 62 | activation_vals = self._apply_activations( 63 | source_layer_activations_to_unit, layer_to_vals[source_layer] 64 | ) 65 | layer_to_vals[source_layer] = activation_vals 66 | 67 | # Feed-forward. 68 | for target_layer, current_weights in self._computation_graph[ 69 | source_layer 70 | ]: 71 | source_layer_val = layer_to_vals[source_layer] 72 | target_layer_val = current_weights(source_layer_val) 73 | 74 | if target_layer in layer_to_vals: 75 | layer_to_vals[target_layer] = ( 76 | layer_to_vals[target_layer] + target_layer_val 77 | ) 78 | else: 79 | layer_to_vals[target_layer] = target_layer_val 80 | 81 | # Write to memory. 
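# Each layer's post-activation values are projected through that layer's
# layer-to-memory weights and accumulated into `memory_out`, which is
# returned as the recurrent state and fed back in at the next time step.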
82 | memory_out = memory_out + self.layer_to_memory_weights[source_layer]( 83 | layer_to_vals[source_layer] 84 | ) 85 | 86 | y_out = layer_to_vals[max(layer_to_vals)] 87 | return y_out, memory_out 88 | 89 | if memory is None: 90 | memory = torch.zeros( 91 | (inputs.shape[0], self._memory_size), 92 | ) 93 | 94 | outputs = [] 95 | for step in range(inputs.shape[1]): 96 | y_t, memory = recurrence(inputs[:, step], memory) 97 | outputs.append(y_t) 98 | 99 | outputs = torch.stack(outputs, dim=1) 100 | 101 | if output_layer is not None: 102 | if output_layer == "softmax": 103 | outputs = torch.softmax(outputs, dim=-1) 104 | elif output_layer == "normalize": 105 | outputs = torch.clamp(outputs, min=0, max=None) 106 | outputs = nn.functional.normalize(outputs, p=1, dim=-1) 107 | else: 108 | raise ValueError(output_layer) 109 | 110 | return outputs, memory 111 | 112 | @staticmethod 113 | def _apply_activations(activation_to_unit, layer_vals) -> torch.Tensor: 114 | for activation_id in activation_to_unit: 115 | activation_unit_idxs = activation_to_unit[activation_id] 116 | if activation_id == 0: # Identity. 117 | continue 118 | activation_func = { 119 | 1: torch.relu, 120 | 3: torch.tanh, 121 | 4: torch.square, 122 | 2: torch.sigmoid, 123 | }[activation_id] 124 | layer_vals[:, activation_unit_idxs] = activation_func( 125 | layer_vals[:, activation_unit_idxs] 126 | ) 127 | return layer_vals 128 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | cachetools~=4.1.0 2 | mpi4py==3.0.3 3 | numba==0.53.1 4 | numpy==1.18.2 5 | pytest==6.2.2 6 | scikit-learn~=0.22.2.post1 7 | scipy~=1.4.1 8 | torch~=1.5.0 -------------------------------------------------------------------------------- /simulations.py: -------------------------------------------------------------------------------- 1 | import corpora 2 | import network 3 | 4 | DEFAULT_ACTIVATIONS = ( 5 | network.SIGMOID, 6 | network.LINEAR, 7 | network.RELU, 8 | network.SQUARE, 9 | network.UNSIGNED_STEP, 10 | network.FLOOR, 11 | ) 12 | 13 | EXTENDED_ACTIVATIONS = DEFAULT_ACTIVATIONS + (network.MODULO_2, network.MODULO_3,) 14 | 15 | DEFAULT_UNIT_TYPES = (network.SUMMATION_UNIT,) 16 | 17 | DEFAULT_CONFIG = { 18 | "migration_ratio": 0.01, 19 | "migration_interval_seconds": 1800, 20 | "migration_interval_generations": 1000, 21 | "num_generations": 25_000, 22 | "population_size": 500, 23 | "elite_ratio": 0.001, 24 | "allowed_activations": DEFAULT_ACTIVATIONS, 25 | "allowed_unit_types": DEFAULT_UNIT_TYPES, 26 | "start_smooth": False, 27 | "compress_grammar_encoding": False, 28 | "tournament_size": 2, 29 | "mutation_probab": 1.0, 30 | "mini_batch_size": None, 31 | "grammar_multiplier": 1, 32 | "data_given_grammar_multiplier": 1, 33 | "max_network_units": 1024, 34 | "softmax_outputs": False, 35 | "truncate_large_values": True, 36 | "bias_connections": True, 37 | "recurrent_connections": True, 38 | "corpus_seed": 100, 39 | "parallelize": True, 40 | "migration_channel": "file", 41 | "generation_dump_interval": 250, 42 | } 43 | 44 | SIMULATIONS = { 45 | "identity": { 46 | "corpus": { 47 | "factory": corpora.make_identity_binary, 48 | "args": {"sequence_length": 100, "batch_size": 10}, 49 | } 50 | }, 51 | "repeat_last_char": { 52 | "corpus": { 53 | "factory": corpora.make_prev_char_repetition_binary, 54 | "args": {"sequence_length": 100, "batch_size": 10, "repetition_offset": 1,}, 55 | } 56 | }, 57 | "binary_addition": { 58 | "corpus": { 
59 | "factory": corpora.make_binary_addition, 60 | "args": {"min_n": 0, "max_n": 20}, 61 | }, 62 | }, 63 | "dyck_1": { 64 | "corpus": { 65 | "factory": corpora.make_dyck_n, 66 | "args": { 67 | "n": 1, 68 | "batch_size": 100, 69 | "nesting_probab": 0.3, 70 | "max_sequence_length": 200, 71 | }, 72 | }, 73 | }, 74 | "dyck_2": { 75 | "corpus": { 76 | "factory": corpora.make_dyck_n, 77 | "args": { 78 | "batch_size": 20_000, 79 | "nesting_probab": 0.3, 80 | "n": 2, 81 | "max_sequence_length": 200, 82 | }, 83 | }, 84 | "config": { 85 | "allowed_activations": EXTENDED_ACTIVATIONS, 86 | "allowed_unit_types": (network.SUMMATION_UNIT, network.MULTIPLICATION_UNIT), 87 | }, 88 | }, 89 | "an_bn": { 90 | "corpus": { 91 | "factory": corpora.make_ain_bjn_ckn_dtn, 92 | "args": {"batch_size": 100, "prior": 0.3, "multipliers": (1, 1, 0, 0)}, 93 | } 94 | }, 95 | "an_bn_cn": { 96 | "corpus": { 97 | "factory": corpora.make_ain_bjn_ckn_dtn, 98 | "args": {"batch_size": 100, "prior": 0.3, "multipliers": (1, 1, 1, 0)}, 99 | } 100 | }, 101 | "an_bn_cn_dn": { 102 | "corpus": { 103 | "factory": corpora.make_ain_bjn_ckn_dtn, 104 | "args": {"batch_size": 100, "prior": 0.3, "multipliers": (1, 1, 1, 1)}, 105 | } 106 | }, 107 | "an_b2n": { 108 | "corpus": { 109 | "factory": corpora.make_ain_bjn_ckn_dtn, 110 | "args": {"batch_size": 100, "prior": 0.3, "multipliers": (1, 2, 0, 0)}, 111 | } 112 | }, 113 | "an_bn_square": { 114 | "corpus": { 115 | "factory": corpora.make_an_bn_square, 116 | "args": {"batch_size": 1000, "prior": 0.5}, 117 | } 118 | }, 119 | "palindrome_fixed_length": { 120 | "corpus": { 121 | "factory": corpora.make_binary_palindrome_fixed_length, 122 | "args": { 123 | "batch_size": 1000, 124 | "sequence_length": 50, 125 | "train_set_ratio": 0.7, 126 | }, 127 | } 128 | }, 129 | "an_bm_cn_plus_m": { 130 | "corpus": { 131 | "factory": corpora.make_an_bm_cn_plus_m, 132 | "args": {"batch_size": 100, "prior": 0.3}, 133 | } 134 | }, 135 | "center_embedding": { 136 | "corpus": { 137 | "factory": corpora.make_center_embedding, 138 | "args": { 139 | "batch_size": 20_000, 140 | "embedding_depth_probab": 0.3, 141 | "dependency_distance_probab": 0.0, 142 | }, 143 | }, 144 | "config": { 145 | "allowed_activations": DEFAULT_ACTIVATIONS + (network.MODULO_2,), 146 | "allowed_unit_types": (network.SUMMATION_UNIT, network.MULTIPLICATION_UNIT), 147 | }, 148 | }, 149 | "0_1_pattern_binary": { 150 | "corpus": { 151 | "factory": corpora.make_0_1_pattern_binary, 152 | "args": {"sequence_length": 20, "batch_size": 1}, 153 | } 154 | }, 155 | "0_1_pattern_one_hot_no_eos": { 156 | "corpus": { 157 | "factory": corpora.make_0_1_pattern_one_hot, 158 | "args": { 159 | "add_end_of_sequence": False, 160 | "sequence_length": 50, 161 | "batch_size": 1, 162 | }, 163 | } 164 | }, 165 | "0_1_pattern_one_hot_with_eos": { 166 | "corpus": { 167 | "factory": corpora.make_0_1_pattern_one_hot, 168 | "args": { 169 | "add_end_of_sequence": True, 170 | "sequence_length": 50, 171 | "batch_size": 1, 172 | }, 173 | } 174 | }, 175 | } 176 | -------------------------------------------------------------------------------- /stats/tacl_stats.py: -------------------------------------------------------------------------------- 1 | import json 2 | import operator 3 | import re 4 | from typing import Dict 5 | 6 | import numpy as np 7 | 8 | 9 | def _params_key(params_dict): 10 | return str(tuple(sorted(params_dict.items()))) 11 | 12 | 13 | def _parse_mdlnn_params(params_string) -> Dict: 14 | return json.loads(re.findall(r"\{.+\}", params_string)[0]) 15 | 16 | 17 | def 
_parse_rnn_params(params_string) -> Dict: 18 | return json.loads(params_string) 19 | 20 | 21 | def _dict_subset(parent_dict, child_dict): 22 | t = all(child_key in parent_dict for child_key in child_dict.keys()) 23 | return t 24 | 25 | 26 | def _float_or_none(val): 27 | return float(val) if val is not None else -1 28 | 29 | 30 | def _get_value_from_sim_stats(sims_list, key): 31 | return np.array([_float_or_none(x[key]) for x in sims_list]) 32 | 33 | 34 | def _percent_format(x): 35 | return f"{x*100:.1f}%" 36 | 37 | 38 | def _ln_to_log2(ln_loss): 39 | return ln_loss / np.log(2) 40 | 41 | 42 | def _compare_mdl_and_rnn_winners_by_stat(mdl_stats, rnn_stats, key): 43 | mdl_best_val = _float_or_none(mdl_stats[key]) 44 | rnn_best_val = _float_or_none(rnn_stats[key]) 45 | 46 | if "accuracy" in key.lower(): 47 | # Accuracy, higher is better. 48 | op = operator.gt 49 | mdl_best_val_str = _percent_format(mdl_best_val) 50 | rnn_best_val_str = _percent_format(rnn_best_val) 51 | optimal_str = "" 52 | else: 53 | # Cross-entropy, lower is better. 54 | op = operator.lt 55 | rnn_best_val = _ln_to_log2(rnn_best_val) 56 | mdl_best_val_str = f"{mdl_best_val*100:.1f}" 57 | rnn_best_val_str = f"{rnn_best_val*100:.1f}" 58 | optimal_str = f"& ({_float_or_none(rnn_stats['Optimal average test CE per character'])*100:.1f})" 59 | 60 | print(f"\t{key}") 61 | print(f"\tMDLRNN & RNN") 62 | print(f"\t{mdl_best_val_str} & {rnn_best_val_str} {optimal_str}") 63 | 64 | if mdl_best_val == rnn_best_val: 65 | print("\t\tTie") 66 | elif op(mdl_best_val, rnn_best_val): 67 | print("\t\tMDL wins") 68 | else: 69 | print(f"\t\tRNN wins") 70 | 71 | 72 | def compare_mdlnn_rnn_stats(select_by: str): 73 | print(f"Comparing based on {select_by}\n") 74 | 75 | with open("./tacl_stats.json", "r") as f: 76 | data = json.load(f) 77 | 78 | mdlnn_rnn_comparisons = ( 79 | "Average test CE per character", 80 | "Test accuracy", 81 | "Test deterministic accuracy", 82 | ) 83 | 84 | test_best_rnn_ids = [] 85 | test_best_mdlnn_ids = [] 86 | 87 | for task_dict in data.values(): 88 | corpus_name = task_dict["corpus_name"] 89 | try: 90 | mdl_sims = task_dict["mdl"] 91 | except KeyError: 92 | print(f"No MDL model for {corpus_name}\n\n") 93 | continue 94 | rnn_sims = task_dict["rnn"] 95 | 96 | print(f"* {corpus_name}") 97 | 98 | train_winning_mdl_sim_idx = np.argmin( 99 | _get_value_from_sim_stats(mdl_sims, "MDL") 100 | ) 101 | train_winning_mdl_sim_stats = mdl_sims[train_winning_mdl_sim_idx] 102 | 103 | train_mdl_g_scores = _get_value_from_sim_stats(mdl_sims, "|G|") 104 | test_mdl_scores = train_mdl_g_scores + _get_value_from_sim_stats( 105 | mdl_sims, "Test set D:G" 106 | ) 107 | 108 | test_winning_mdl_sim_idx = np.argmin(test_mdl_scores) 109 | test_winning_mdl_sim_stats = mdl_sims[test_winning_mdl_sim_idx] 110 | 111 | train_winning_rnn_sim_idx = np.argmin( 112 | _get_value_from_sim_stats(rnn_sims, "Average train CE per character") 113 | ) 114 | train_winning_rnn_sim_stats = rnn_sims[train_winning_rnn_sim_idx] 115 | 116 | print(f"\tTrain-best MDL:\t{train_winning_mdl_sim_stats['Simulation id']}") 117 | print(f"\tTrain-best RNN:\t{train_winning_rnn_sim_stats['Simulation id']}\n") 118 | 119 | test_winning_rnn_sim_idx = np.argmin( 120 | # Take winner based on actual CE performance, without regularization ('loss' stat includes regularization term). 
121 | _get_value_from_sim_stats(rnn_sims, "Average test CE per character") 122 | ) 123 | test_winning_rnn_sim_stats = rnn_sims[test_winning_rnn_sim_idx] 124 | test_rnn_args = json.loads(test_winning_rnn_sim_stats["Params"]) 125 | 126 | print( 127 | f"\tTest-best MDL:\t{test_winning_mdl_sim_stats['Simulation id']}.\n{test_winning_mdl_sim_stats['Units']} units, {test_winning_mdl_sim_stats['Connections']} connections" 128 | ) 129 | print( 130 | f"\tTest-best RNN:\t{test_winning_rnn_sim_stats['Simulation id']}\n{test_rnn_args['network_type']}, {test_rnn_args['num_hidden_units']} units, {test_rnn_args.get('regularization', 'no')} regularization, {test_rnn_args.get('regularization_lambda', '-')} lambda.\n" 131 | ) 132 | 133 | test_best_mdlnn_ids.append(test_winning_mdl_sim_stats["Simulation id"]) 134 | test_best_rnn_ids.append(test_winning_rnn_sim_stats["Simulation id"]) 135 | 136 | if select_by == "train": 137 | for comparison_key in mdlnn_rnn_comparisons: 138 | _compare_mdl_and_rnn_winners_by_stat( 139 | mdl_stats=train_winning_mdl_sim_stats, 140 | rnn_stats=train_winning_rnn_sim_stats, 141 | key=comparison_key, 142 | ) 143 | else: 144 | for comparison_key in mdlnn_rnn_comparisons: 145 | _compare_mdl_and_rnn_winners_by_stat( 146 | mdl_stats=test_winning_mdl_sim_stats, 147 | rnn_stats=test_winning_rnn_sim_stats, 148 | key=comparison_key, 149 | ) 150 | 151 | 152 | if __name__ == "__main__": 153 | compare_mdlnn_rnn_stats(select_by="train") 154 | compare_mdlnn_rnn_stats(select_by="test") 155 | -------------------------------------------------------------------------------- /tests/test_corpus.py: -------------------------------------------------------------------------------- 1 | import unittest 2 | 3 | import numpy as np 4 | 5 | import configuration 6 | import corpora 7 | import manual_nets 8 | import network 9 | import simulations 10 | import utils 11 | 12 | utils.setup_logging() 13 | 14 | 15 | def _test_xor_correctness(input_sequence, target_sequence): 16 | memory = {"step_t_minus_1": 0, "step_t_minus_2": 0} 17 | 18 | def predictor(x): 19 | memory["step_t_minus_2"] = memory["step_t_minus_1"] 20 | memory["step_t_minus_1"] = x 21 | return memory["step_t_minus_1"] ^ memory["step_t_minus_2"] 22 | 23 | correct = 0 24 | for in_, target in zip(input_sequence[0], target_sequence[0]): 25 | prediction = predictor(in_) 26 | if np.all(prediction == target): 27 | correct += 1 28 | 29 | accuracy = correct / input_sequence.shape[1] 30 | print(f"Accuracy: {accuracy:.2f}") 31 | 32 | # A perfect network predicts 33% of the values deterministically, and another 66% x 0.5 by chance. 
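# That is, an optimal predictor's expected accuracy is 1/3 * 1.0 + 2/3 * 0.5 = 2/3 ≈ 0.66,
# which is the threshold asserted below.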
33 | assert accuracy >= 0.66 34 | 35 | 36 | class TestCorpus(unittest.TestCase): 37 | def test_train_test_shuffle(self): 38 | train_ratio = 0.7 39 | batch_size = 1000 40 | train_corpus = corpora.make_binary_palindrome_fixed_length( 41 | batch_size=batch_size, sequence_length=10, train_set_ratio=train_ratio 42 | ) 43 | test_corpus = train_corpus.test_corpus 44 | assert ( 45 | sum(train_corpus.sample_weights) + sum(test_corpus.sample_weights) 46 | == batch_size 47 | ) 48 | 49 | assert train_corpus.input_sequence.shape[0] == 22 50 | assert test_corpus.input_sequence.shape[0] == 10 51 | 52 | def test_mini_batches(self): 53 | mini_batch_size = 100 54 | 55 | net = manual_nets.make_emmanuel_triplet_xor_network() 56 | full_corpus = corpora.optimize_for_feeding( 57 | corpora.make_elman_xor_binary(sequence_length=300, batch_size=1000) 58 | ) 59 | 60 | mini_batches = corpora.split_to_mini_batches( 61 | full_corpus, mini_batch_size=mini_batch_size 62 | ) 63 | assert len(mini_batches) == 10 64 | 65 | config = configuration.SimulationConfig( 66 | **{**simulations.DEFAULT_CONFIG, "mini_batch_size": mini_batch_size}, 67 | simulation_id="", 68 | num_islands=1, 69 | seed=1, 70 | ) 71 | net = network.calculate_fitness(net, full_corpus, config) 72 | print(net) 73 | assert net.fitness.data_encoding_length == 200000 74 | 75 | def test_0_1_pattern_binary(self): 76 | seq_length = 10 77 | batch_size = 10 78 | corpus = corpora.make_0_1_pattern_binary( 79 | sequence_length=seq_length, batch_size=batch_size 80 | ) 81 | assert corpus.sample_weights == (batch_size,) 82 | assert corpus.test_corpus.input_sequence.shape == (1, seq_length * 50_000, 1) 83 | 84 | seq_length = 10 85 | batch_size = 1 86 | corpus = corpora.make_0_1_pattern_binary( 87 | sequence_length=seq_length, batch_size=batch_size 88 | ) 89 | assert corpus.sample_weights is None 90 | 91 | def test_xor_corpus_correctness(self): 92 | sequence_length = 9999 93 | xor_corpus_binary = corpora.make_elman_xor_binary(sequence_length, batch_size=1) 94 | xor_corpus_one_hot = corpora.make_elman_xor_one_hot( 95 | sequence_length, batch_size=1 96 | ) 97 | _test_xor_correctness( 98 | xor_corpus_binary.input_sequence.astype(np.int), 99 | xor_corpus_binary.target_sequence.astype(np.int), 100 | ) 101 | _test_xor_correctness( 102 | xor_corpus_one_hot.input_sequence.argmax(axis=-1), 103 | xor_corpus_one_hot.target_sequence.argmax(axis=-1), 104 | ) 105 | 106 | def test_binary_addition_corpus_correctness(self): 107 | addition_corpus = corpora.make_binary_addition(min_n=0, max_n=100) 108 | 109 | for input_sequence, target_sequence in zip( 110 | addition_corpus.input_sequence, addition_corpus.target_sequence 111 | ): 112 | n1 = 0 113 | n2 = 0 114 | sum_ = 0 115 | for i, bin_digit in enumerate(input_sequence): 116 | n1_binary_digit = bin_digit[0] 117 | n2_binary_digit = bin_digit[1] 118 | current_exp = 2 ** i 119 | if n1_binary_digit == 1: 120 | n1 += current_exp 121 | if n2_binary_digit == 1: 122 | n2 += current_exp 123 | target_binary_digit = target_sequence[i] 124 | if target_binary_digit == 1: 125 | sum_ += current_exp 126 | 127 | assert n1 + n2 == sum_, (n1, n2, sum_) 128 | 129 | def test_an_bn_corpus(self): 130 | n_values = tuple(range(50)) 131 | an_bn_corpus = corpora.optimize_for_feeding( 132 | corpora._make_ain_bjn_ckn_dtn_corpus( 133 | n_values, multipliers=(1, 1, 0, 0), prior=0.1, sort_by_length=True 134 | ) 135 | ) 136 | 137 | for n in n_values: 138 | row = 49 - n # Sequences are sorted by decreasing length. 
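# The assertions below check the one-hot layout of each a^n b^n sequence: class 0 marks the
# start (in the input) and end (in the target) of the sequence, class 1 covers the n 'a'
# symbols, class 2 covers the n 'b' symbols, and every step past `seq_len` is masked.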
139 | input_seq = an_bn_corpus.input_sequence[row] 140 | target_seq = an_bn_corpus.target_sequence[row] 141 | seq_len = 1 + (2 * n) 142 | 143 | zeros_start = 1 144 | ones_start = n + 1 145 | 146 | assert not np.all(corpora.is_masked(input_seq[:seq_len])) 147 | assert np.all(corpora.is_masked(input_seq[seq_len:])) 148 | 149 | input_classes = np.argmax(input_seq, axis=-1)[:seq_len] 150 | target_classes = np.argmax(target_seq, axis=-1)[:seq_len] 151 | 152 | assert np.sum(input_classes == 1) == n 153 | assert np.sum(input_classes == 2) == n 154 | 155 | assert input_classes[0] == 0 # Start of sequence. 156 | assert np.all(input_classes[zeros_start:ones_start] == 1) 157 | assert np.all(input_classes[ones_start:seq_len] == 2) 158 | 159 | assert target_classes[seq_len - 1] == 0 # End of sequence. 160 | assert np.all(target_classes[zeros_start - 1 : ones_start - 1] == 1) 161 | assert np.all(target_classes[ones_start - 1 : seq_len - 1] == 2) 162 | 163 | def test_dfa_baseline_d_g(self): 164 | dfa_ = corpora.make_between_dfa(start=4, end=4) 165 | 166 | corpus = corpora.make_exactly_n_quantifier(4, sequence_length=50, batch_size=10) 167 | 168 | optimal_d_g = dfa_.get_optimal_data_given_grammar_for_dfa(corpus.input_sequence) 169 | 170 | num_non_masked_steps = np.sum(np.all(~np.isnan(corpus.input_sequence), axis=-1)) 171 | assert optimal_d_g == num_non_masked_steps 172 | -------------------------------------------------------------------------------- /tests/test_genetic_algorithm.py: -------------------------------------------------------------------------------- 1 | import dataclasses 2 | import random 3 | import unittest 4 | 5 | import numpy as np 6 | 7 | from mdlrnn import configuration, corpora, genetic_algorithm, network, utils 8 | 9 | utils.setup_logging() 10 | _TEST_CONFIG = configuration.SimulationConfig( 11 | simulation_id="test", 12 | num_islands=1, 13 | migration_ratio=0.1, 14 | migration_interval_seconds=20, 15 | migration_interval_generations=1000, 16 | num_generations=1000, 17 | population_size=20, 18 | elite_ratio=0.05, 19 | allowed_activations=( 20 | network.SIGMOID, 21 | network.LINEAR, 22 | network.RELU, 23 | network.SQUARE, 24 | ), 25 | start_smooth=False, 26 | allowed_unit_types=(network.SUMMATION_UNIT, network.MULTIPLICATION_UNIT), 27 | tournament_size=4, 28 | mutation_probab=0.9, 29 | grammar_multiplier=1, 30 | data_given_grammar_multiplier=1, 31 | compress_grammar_encoding=False, 32 | max_network_units=1024, 33 | softmax_outputs=False, 34 | truncate_large_values=False, 35 | bias_connections=True, 36 | recurrent_connections=True, 37 | seed=1, 38 | corpus_seed=100, 39 | generation_dump_interval=1, 40 | ) 41 | 42 | 43 | def _num_non_masked_inputs(x): 44 | return np.sum(~np.all(corpora.is_masked(x), axis=-1)) 45 | 46 | 47 | def _make_random_population(size): 48 | population = [] 49 | for _ in range(size): 50 | net = network.make_random_net( 51 | input_size=3, 52 | output_size=3, 53 | allowed_activations=_TEST_CONFIG.allowed_activations, 54 | start_smooth=False, 55 | ) 56 | p = random.random() 57 | if p < 0.3: 58 | grammar_encoding_length = np.inf 59 | else: 60 | grammar_encoding_length = np.random.randint(100) 61 | data_encoding_length = np.random.randint(100) 62 | net = dataclasses.replace( 63 | net, 64 | fitness=network.Fitness( 65 | mdl=grammar_encoding_length + data_encoding_length, 66 | grammar_encoding_length=grammar_encoding_length, 67 | data_encoding_length=data_encoding_length, 68 | accuracy=1.0, 69 | ), 70 | ) 71 | population.append(net) 72 | return population 73 | 74 | 75 | 
def _test_elite(population, elite): 76 | population_fitnesses = tuple(map(genetic_algorithm._GET_NET_MDL, population)) 77 | elite_fitnesses = tuple(map(genetic_algorithm._GET_NET_MDL, elite)) 78 | 79 | population_fitness_without_elite = frozenset(population_fitnesses) - frozenset( 80 | elite_fitnesses 81 | ) 82 | 83 | worst_elite_fitness = max(elite_fitnesses) 84 | best_non_elite_fitness = min(population_fitness_without_elite) 85 | 86 | assert worst_elite_fitness <= best_non_elite_fitness 87 | 88 | 89 | class TestGeneticAlgorithm(unittest.TestCase): 90 | def test_get_elite_idxs(self): 91 | population = _make_random_population(size=1000) 92 | best_idxs = genetic_algorithm._get_elite_idxs(population, elite_ratio=0.01) 93 | assert len(best_idxs) == 10, len(best_idxs) 94 | 95 | elite = [population[i] for i in best_idxs] 96 | _test_elite(population, elite) 97 | 98 | def test_get_elite(self): 99 | population = _make_random_population(size=1000) 100 | elite = genetic_algorithm._get_elite(elite_ratio=0.01, population=population) 101 | assert len(elite) == 10 102 | _test_elite(population, elite) 103 | 104 | def test_get_migration_target_island(self): 105 | island_num = 3 106 | migration_interval = 20 107 | total_islands = 16 108 | targets = [] 109 | generator = genetic_algorithm._make_migration_target_island_generator( 110 | island_num, total_islands 111 | ) 112 | 113 | for generation in range( 114 | 1, (total_islands + 1) * migration_interval, migration_interval 115 | ): 116 | targets.append(next(generator)) 117 | 118 | assert island_num not in targets 119 | assert tuple(targets) == ( 120 | 4, 121 | 5, 122 | 6, 123 | 7, 124 | 8, 125 | 9, 126 | 10, 127 | 11, 128 | 12, 129 | 13, 130 | 14, 131 | 15, 132 | 0, 133 | 1, 134 | 2, 135 | 4, 136 | 5, 137 | ) 138 | 139 | def test_tournament_selection_break_infinity_ties(self): 140 | population = _make_random_population(size=10000) 141 | inf_population = [x for x in population if np.isinf(x.fitness.mdl)] 142 | for _ in range(100): 143 | tournament = genetic_algorithm._tournament_selection( 144 | population=inf_population, tournament_size=2 145 | ) 146 | winner_idx = tournament.winner_idx 147 | loser_idx = tournament.loser_idx 148 | assert ( 149 | inf_population[winner_idx].fitness.grammar_encoding_length 150 | <= inf_population[loser_idx].fitness.grammar_encoding_length 151 | ) 152 | 153 | def test_tournament_selection(self): 154 | population = _make_random_population(size=1000) 155 | tournament = genetic_algorithm._tournament_selection( 156 | population=population, tournament_size=2 157 | ) 158 | winner_idx = tournament.winner_idx 159 | loser_idx = tournament.loser_idx 160 | assert population[winner_idx].fitness.mdl < population[loser_idx].fitness.mdl 161 | 162 | tournament = genetic_algorithm._tournament_selection( 163 | population=population, tournament_size=len(population), 164 | ) 165 | absolute_best_offspring_idx = tournament.winner_idx 166 | absolute_worst_offspring_idx = tournament.loser_idx 167 | 168 | assert ( 169 | population[absolute_best_offspring_idx].fitness.mdl 170 | == min(population, key=genetic_algorithm._GET_NET_MDL).fitness.mdl 171 | ) 172 | 173 | assert ( 174 | population[absolute_worst_offspring_idx].fitness.mdl 175 | == max(population, key=genetic_algorithm._GET_NET_MDL).fitness.mdl 176 | ) 177 | 178 | identical_fitness_population = [population[0], population[0]] 179 | tournament = genetic_algorithm._tournament_selection( 180 | identical_fitness_population, tournament_size=2 181 | ) 182 | winner_idx = tournament.winner_idx 183 | 
loser_idx = tournament.loser_idx 184 | assert winner_idx != loser_idx 185 | -------------------------------------------------------------------------------- /torch_conversion.py: -------------------------------------------------------------------------------- 1 | import collections 2 | import itertools 3 | from typing import Optional 4 | 5 | import mdlrnn_torch 6 | import numpy as np 7 | import torch 8 | from torch import nn 9 | 10 | import corpora 11 | import network 12 | import utils 13 | 14 | logger = utils.setup_logging() 15 | 16 | 17 | def _get_forward_mesh_layers(net) -> dict[int, frozenset[int]]: 18 | forward_mesh_layers = network.bfs_layers( 19 | forward_connections=net.forward_connections, 20 | reverse_forward_connections=net.reverse_forward_connections, 21 | units=network.get_units(net), 22 | input_units_range=net.input_units_range, 23 | output_units_range=net.output_units_range, 24 | ) 25 | 26 | # Move inputs to own layer. 27 | input_units = frozenset(net.input_units_range) 28 | for depth in range(2): # Inputs are at most depth 1. 29 | forward_mesh_layers[depth] = forward_mesh_layers[depth] - input_units 30 | if len(forward_mesh_layers[0]) == 0: 31 | input_level = 0 32 | else: 33 | input_level = -1 34 | forward_mesh_layers[input_level] = input_units 35 | 36 | # Move outputs to own layer. 37 | output_units = frozenset(net.output_units_range) 38 | for depth in forward_mesh_layers: 39 | forward_mesh_layers[depth] = forward_mesh_layers[depth] - output_units 40 | forward_mesh_layers = {k: v for k, v in forward_mesh_layers.items() if len(v)} 41 | max_depth = max(forward_mesh_layers) 42 | forward_mesh_layers[max_depth + 1] = output_units 43 | 44 | return forward_mesh_layers 45 | 46 | 47 | def _make_linear_weights(weights: torch.Tensor, bias: Optional[torch.Tensor] = None): 48 | linear = nn.Linear( 49 | in_features=weights.shape[1], 50 | out_features=weights.shape[0], 51 | bias=bias is not None, 52 | ) 53 | with torch.no_grad(): 54 | linear.weight.copy_(weights) 55 | 56 | if bias is not None: 57 | linear.bias.copy_(bias) 58 | 59 | return linear 60 | 61 | 62 | def _get_unit_to_bfs_layer( 63 | bfs_layer_to_units: dict[int, frozenset[int]] 64 | ) -> dict[int, int]: 65 | unit_to_layer = {} 66 | for source_layer, source_layer_units in bfs_layer_to_units.items(): 67 | for unit in source_layer_units: 68 | unit_to_layer[unit] = source_layer 69 | return unit_to_layer 70 | 71 | 72 | def _build_memory_layers( 73 | bfs_layer_to_units: dict[int, frozenset[int]], 74 | net: network.Network, 75 | ): 76 | memory_units = frozenset( 77 | unit 78 | for unit in net.recurrent_connections 79 | if len(net.recurrent_connections[unit]) 80 | ) 81 | memory_size = len(memory_units) 82 | 83 | unit_to_memory_idx = {unit: i for i, unit in enumerate(sorted(memory_units))} 84 | 85 | layer_to_memory_weights = {} 86 | memory_to_layer_weights = {} 87 | 88 | for layer, layer_units in bfs_layer_to_units.items(): 89 | to_memory_weights = torch.zeros((memory_size, len(layer_units))) 90 | from_memory_weight = torch.zeros((len(layer_units), memory_size)) 91 | 92 | for i, unit in enumerate(sorted(layer_units)): 93 | if unit in memory_units: 94 | to_memory_weights[unit_to_memory_idx[unit], i] = 1 95 | 96 | if unit in net.reverse_recurrent_connections: 97 | for source_memory_unit in net.reverse_recurrent_connections[unit]: 98 | incoming_weight = network.weight_to_float( 99 | net.recurrent_weights[(source_memory_unit, unit)] 100 | ) 101 | incoming_memory_idx = unit_to_memory_idx[source_memory_unit] 102 | from_memory_weight[i, 
incoming_memory_idx] = incoming_weight 103 | 104 | layer_to_memory_weights[layer] = _make_linear_weights( 105 | to_memory_weights, bias=torch.zeros((memory_size,)) 106 | ) 107 | memory_to_layer_weights[layer] = _make_linear_weights( 108 | from_memory_weight, bias=torch.zeros((from_memory_weight.shape[0],)) 109 | ) 110 | 111 | return layer_to_memory_weights, memory_to_layer_weights 112 | 113 | 114 | def _freeze_defaultdict(dd) -> dict: 115 | if "lambda" in str(dd.default_factory): 116 | return {key: _freeze_defaultdict(val) for key, val in dd.items()} 117 | frozen_dict = {} 118 | for key, val in dd.items(): 119 | if type(val) == set: 120 | val = frozenset(val) 121 | elif type(val) == list: 122 | val = tuple(val) 123 | frozen_dict[key] = val 124 | return frozen_dict 125 | 126 | 127 | def _build_computation_graph(net: network.Network, bfs_layer_to_units): 128 | unit_to_layer = _get_unit_to_bfs_layer(bfs_layer_to_units) 129 | 130 | layer_to_outgoing_layers = collections.defaultdict(set) 131 | for layer in bfs_layer_to_units: 132 | for unit in bfs_layer_to_units[layer]: 133 | for target_unit in net.forward_connections.get(unit, set()): 134 | layer_to_outgoing_layers[layer].add(unit_to_layer[target_unit]) 135 | 136 | layer_to_outgoing_layers = _freeze_defaultdict(layer_to_outgoing_layers) 137 | 138 | units_with_biases = set() 139 | 140 | computation_graph = collections.defaultdict(list) 141 | 142 | for source_layer in sorted(bfs_layer_to_units): 143 | source_layer_units = bfs_layer_to_units[source_layer] 144 | source_layer_size = len(source_layer_units) 145 | source_to_idx = dict((x, i) for i, x in enumerate(sorted(source_layer_units))) 146 | 147 | for target_layer in layer_to_outgoing_layers.get(source_layer, frozenset()): 148 | target_layer_units = bfs_layer_to_units[target_layer] 149 | target_to_idx = dict( 150 | (x, i) for i, x in enumerate(sorted(target_layer_units)) 151 | ) 152 | target_layer_size = len(target_layer_units) 153 | 154 | weights = torch.zeros((target_layer_size, source_layer_size)) 155 | 156 | for source in source_layer_units: 157 | source_idx = source_to_idx[source] 158 | source_unit_targets = ( 159 | net.forward_connections.get(source, frozenset()) 160 | & target_layer_units 161 | ) 162 | for target in source_unit_targets: 163 | target_idx = target_to_idx[target] 164 | 165 | weight = network.weight_to_float( 166 | net.forward_weights[(source, target)] 167 | ) 168 | weights[target_idx, source_idx] = weight 169 | 170 | bias = torch.zeros((target_layer_size,)) 171 | 172 | for target in target_layer_units: 173 | if target in net.biases and target not in units_with_biases: 174 | units_with_biases.add(target) 175 | bias[target_to_idx[target]] = network.weight_to_float( 176 | net.biases[target] 177 | ) 178 | 179 | weights = _make_linear_weights(weights=weights, bias=bias) 180 | 181 | computation_graph[source_layer].append((target_layer, weights)) 182 | 183 | # Connect inputs to layers that have no inputs with weights = 0. 
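# A layer is "floating" if it appears as a source in the computation graph but is never the
# target of any forward connection (e.g. its units only receive values through memory
# connections or biases). The zero-weight link from the input layer ensures that
# `layer_to_vals` already holds an entry for such a layer when `MDLRNN.forward` visits it
# as a source, without changing the computed values.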
184 | floating_layers = set(computation_graph) - set( 185 | [x[0] for x in list(itertools.chain(*computation_graph.values()))] 186 | ) 187 | input_layer = min(bfs_layer_to_units) 188 | input_size = len(bfs_layer_to_units[input_layer]) 189 | for floating_layer in floating_layers: 190 | layer_size = len(bfs_layer_to_units[floating_layer]) 191 | floating_layer_weights = _make_linear_weights( 192 | weights=torch.zeros((layer_size, input_size)), bias=torch.zeros((1,)) 193 | ) 194 | computation_graph[input_layer].append((floating_layer, floating_layer_weights)) 195 | computation_graph = _freeze_defaultdict(computation_graph) 196 | 197 | return computation_graph 198 | 199 | 200 | def _build_activation_function_vectors(net: network.Network, bfs_layer_to_units): 201 | layer_to_activation_to_units = collections.defaultdict( 202 | lambda: collections.defaultdict(list) 203 | ) 204 | for source_layer, source_layer_units in bfs_layer_to_units.items(): 205 | for unit_idx, unit in enumerate(sorted(source_layer_units)): 206 | activation = net.activations[unit] 207 | layer_to_activation_to_units[source_layer][activation].append(unit_idx) 208 | 209 | layer_to_activation_to_units = _freeze_defaultdict(layer_to_activation_to_units) 210 | return layer_to_activation_to_units 211 | 212 | 213 | def mdlnn_to_torch(net: network.Network) -> mdlrnn_torch.MDLRNN: 214 | bfs_layer_to_units = _get_forward_mesh_layers(net) 215 | 216 | computation_graph = _build_computation_graph(net, bfs_layer_to_units) 217 | layer_to_activation_to_units = _build_activation_function_vectors( 218 | net, bfs_layer_to_units 219 | ) 220 | layer_to_memory_weights, memory_to_layer_weights = _build_memory_layers( 221 | bfs_layer_to_units=bfs_layer_to_units, net=net 222 | ) 223 | 224 | return mdlrnn_torch.MDLRNN( 225 | computation_graph=computation_graph, 226 | layer_to_memory_weights=layer_to_memory_weights, 227 | memory_to_layer_weights=memory_to_layer_weights, 228 | layer_to_activation_to_units=layer_to_activation_to_units, 229 | ) 230 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import dataclasses 3 | import hashlib 4 | import itertools 5 | import logging 6 | import os 7 | import pathlib 8 | import pickle 9 | import random 10 | from typing import Any, Dict, Iterable, Optional, Text, Tuple 11 | 12 | import numpy as np 13 | from numba import types 14 | 15 | import corpora 16 | 17 | BASE_SEED = 100 18 | 19 | MPI_LOGGING_TAG = 1 20 | MPI_RESULTS_TAG = 2 21 | MPI_MIGRANTS_TAG = 3 22 | 23 | FLOAT_DTYPE = np.float64 24 | NUMBA_FLOAT_DTYPE = types.float64 25 | 26 | 27 | def kwargs_from_param_grid(param_grid: Dict[Text, Iterable[Any]]) -> Iterable[Dict]: 28 | arg_names = list(param_grid.keys()) 29 | arg_products = list(itertools.product(*param_grid.values())) 30 | for arg_product in arg_products: 31 | yield {arg: val for arg, val in zip(arg_names, arg_product)} 32 | 33 | 34 | def setup_logging(): 35 | logging.basicConfig( 36 | format="%(asctime)s.%(msecs)03d %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s", 37 | datefmt="%d-%m-%Y:%H:%M:%S", 38 | level=logging.INFO, 39 | ) 40 | 41 | 42 | def load_network_from_zoo(name, subdir=None): 43 | print(name) 44 | path = pathlib.Path(f"./network_zoo/") 45 | if subdir: 46 | path = path.joinpath(subdir.strip("/")) 47 | path = path.joinpath(f"{name}.pickle") 48 | print(path) 49 | with path.open("rb") as f: 50 | return pickle.load(f) 51 | 52 | 53 | def 
seed(n): 54 | random.seed(n) 55 | np.random.seed(n) 56 | 57 | 58 | def dict_and_corpus_hash(dict_, corpus) -> Text: 59 | s = corpus.name 60 | for key in sorted(dict_.keys()): 61 | s += f"{key} {dict_[key]}" 62 | # TODO: ugly but works. 63 | s += corpus.name + " " 64 | s += str(corpus.input_sequence) + " " 65 | s += str(corpus.target_sequence) 66 | hash = hashlib.sha1() 67 | hash.update(s.encode()) 68 | return hash.hexdigest() 69 | 70 | 71 | def add_hash_to_simulation_id(simulation_config, corpus): 72 | config_hash = dict_and_corpus_hash(simulation_config.__dict__, corpus) 73 | simulation_id = f"{corpus.name}_{config_hash}" 74 | return dataclasses.replace(simulation_config, simulation_id=simulation_id) 75 | 76 | 77 | def make_cli_arguments(): 78 | arg_parser = argparse.ArgumentParser() 79 | arg_parser.add_argument( 80 | "-s", 81 | "--simulation", 82 | dest="simulation_name", 83 | required=True, 84 | help=f"Simulation name.", 85 | ) 86 | 87 | arg_parser.add_argument( 88 | "-n", 89 | "--total-islands", 90 | type=int, 91 | dest="total_islands", 92 | default=os.cpu_count(), 93 | help=f"Total number of islands in entire simulation (including other machines). Default: number of local cores ({os.cpu_count()}).", 94 | ) 95 | 96 | arg_parser.add_argument( 97 | "--first-island", 98 | type=int, 99 | default=None, 100 | dest="first_island", 101 | help="First island index on this machine. Default: 0.", 102 | ) 103 | 104 | arg_parser.add_argument( 105 | "--last-island", 106 | type=int, 107 | default=None, 108 | dest="last_island", 109 | help="Last island index on this machine. Default: number of islands minus 1.", 110 | ) 111 | 112 | arg_parser.add_argument( 113 | "--seed", 114 | type=int, 115 | default=BASE_SEED, 116 | dest="base_seed", 117 | help=f"Base seed value. Default: {BASE_SEED}. 
For the i-th reproduction (0-based), the seed will be {BASE_SEED} + i.", 118 | ) 119 | 120 | arg_parser.add_argument( 121 | "--override", 122 | action="store_true", 123 | dest="override_existing", 124 | help="Override an existing simulation that has the same hash.", 125 | ) 126 | 127 | arg_parser.add_argument( 128 | "--resume", 129 | dest="resumed_simulation_id", 130 | default=None, 131 | help="Resume simulation from latest generations.", 132 | ) 133 | 134 | arg_parser.add_argument( 135 | "--corpus-args", 136 | default=None, 137 | dest="corpus_args", 138 | help="json to override default corpus arguments.", 139 | ) 140 | 141 | return arg_parser 142 | 143 | 144 | def calculate_symbolic_accuracy( 145 | predicted_probabs: np.ndarray, 146 | target_probabs: np.ndarray, 147 | input_mask: Optional[np.ndarray], 148 | sample_weights: Tuple[int], 149 | plots: bool, 150 | epsilon: float = 0.0, 151 | ) -> Tuple[float, Tuple[int, ...]]: 152 | zero_target_probabs = target_probabs == 0.0 153 | 154 | zero_predicted_probabs = predicted_probabs <= epsilon 155 | 156 | prediction_matches = np.all( 157 | np.equal(zero_predicted_probabs, zero_target_probabs), axis=-1 158 | ) 159 | 160 | prediction_matches[~input_mask] = True 161 | 162 | sequence_idxs_with_errors = tuple(np.where(np.any(~prediction_matches, axis=1))[0]) 163 | logging.info(f"Sequence idxs with mismatches: {sequence_idxs_with_errors}") 164 | 165 | incorrect_predictions_per_time_step = np.sum(~prediction_matches, axis=0) 166 | 167 | if plots: 168 | from matplotlib import pyplot as plt 169 | 170 | fig, ax = plt.subplots() 171 | ax.set_title("Num prediction mismatches by time step") 172 | ax.bar( 173 | np.arange(len(incorrect_predictions_per_time_step)), 174 | incorrect_predictions_per_time_step, 175 | ) 176 | plt.show() 177 | 178 | prediction_matches_without_masked = prediction_matches[input_mask] 179 | 180 | w = np.array(sample_weights).reshape((-1, 1)) 181 | weights_repeated = np.matmul(w, np.ones((1, predicted_probabs.shape[1]))) 182 | weights_masked = weights_repeated[input_mask] 183 | 184 | prediction_matches_weighted = np.multiply( 185 | prediction_matches_without_masked, weights_masked 186 | ) 187 | 188 | symbolic_accuracy = np.sum(prediction_matches_weighted) / np.sum(weights_masked) 189 | return symbolic_accuracy, sequence_idxs_with_errors 190 | 191 | 192 | def plot_probabs(probabs: np.ndarray, input_classes, class_to_label=None): 193 | from matplotlib import _color_data as matploit_color_data 194 | from matplotlib import pyplot as plt 195 | 196 | if probabs.shape[-1] == 1: 197 | # Binary outputs, output is P(1). 
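# Expand the single P(1) column into explicit [P(0), P(1)] columns so that the stacked
# bar plot below can treat binary and one-hot outputs uniformly.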
198 | probabs_ = np.zeros((probabs.shape[0], 2)) 199 | probabs_[:, 0] = (1 - probabs).squeeze() 200 | probabs_[:, 1] = probabs.squeeze() 201 | probabs = probabs_ 202 | 203 | masked_timesteps = np.where(corpora.is_masked(probabs))[0] 204 | if len(masked_timesteps): 205 | first_mask_step = masked_timesteps[0] 206 | probabs = probabs[:first_mask_step] 207 | 208 | plt.rc("grid", color="w", linestyle="solid") 209 | if class_to_label is None: 210 | class_to_label = {i: str(i) for i in range(len(input_classes))} 211 | fig, ax = plt.subplots(figsize=(9, 5), dpi=150, facecolor="white") 212 | x = np.arange(probabs.shape[0]) 213 | num_classes = probabs.shape[1] 214 | width = 0.8 215 | colors = ( 216 | list(matploit_color_data.TABLEAU_COLORS) + list(matploit_color_data.XKCD_COLORS) 217 | )[:num_classes] 218 | for c in range(num_classes): 219 | ax.bar( 220 | x, 221 | probabs[:, c], 222 | label=f"P({class_to_label[c]})" if num_classes > 1 else "P(1)", 223 | color=colors[c], 224 | width=width, 225 | bottom=np.sum(probabs[:, :c], axis=-1), 226 | ) 227 | ax.set_facecolor("white") 228 | ax.set_xticks(x) 229 | ax.set_xticklabels([class_to_label[x] for x in input_classes], fontsize=13) 230 | 231 | ax.set_xlabel("Input characters", fontsize=15) 232 | ax.set_ylabel("Next character probability", fontsize=15) 233 | 234 | ax.grid(b=True, color="#bcbcbc") 235 | # plt.title("Next step prediction probabilities", fontsize=22) 236 | plt.legend(loc="upper left", fontsize=15) 237 | 238 | # fig.savefig("test.png") 239 | fig.subplots_adjust(bottom=0.1) 240 | 241 | plt.show() 242 | fig.savefig( 243 | f"./figures/net_probabs_{random.randint(0,10_000)}.pdf", 244 | dpi=300, 245 | facecolor="white", 246 | ) 247 | -------------------------------------------------------------------------------- /vanilla_rnn.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from typing import List, Optional, Text, Tuple 3 | 4 | import numpy as np 5 | import torch 6 | from torch import nn, optim 7 | from torch.nn.utils import rnn 8 | 9 | import corpora 10 | import utils 11 | 12 | utils.setup_logging() 13 | 14 | 15 | def _ensure_torch(x) -> torch.Tensor: 16 | if isinstance(x, np.ndarray): 17 | x = torch.from_numpy(x).float() 18 | return x 19 | 20 | 21 | def _get_masked_sequence_lengths(x) -> List[int]: 22 | # Ugly and slow but works. 23 | sequence_lengths = [] 24 | for b in range(x.shape[0]): 25 | seq = x[b] 26 | any_masked = False 27 | for i in range(seq.shape[0]): 28 | if np.all(corpora.is_masked(seq[i])): 29 | any_masked = True 30 | sequence_lengths.append(i) 31 | break 32 | if not any_masked: 33 | sequence_lengths.append(seq.shape[0]) 34 | return sequence_lengths 35 | 36 | 37 | def _get_packed_sample_weights( 38 | sample_weights, max_sequence_length, masked_sequence_lengths 39 | ): 40 | if sample_weights is None: 41 | return None 42 | w = np.array(sample_weights).reshape((-1, 1)) 43 | weights_repeated = np.matmul(w, torch.ones((1, max_sequence_length))) 44 | return _pack_masked(weights_repeated, masked_sequence_lengths).data 45 | 46 | 47 | def _pack_masked(sequences: np.ndarray, sequence_lengths): 48 | return rnn.pack_padded_sequence( 49 | _ensure_torch(sequences), 50 | lengths=sequence_lengths, 51 | batch_first=True, 52 | enforce_sorted=True, 53 | ) 54 | 55 | 56 | def _unpack( 57 | packed_tensor, 58 | packed_sequence_batch_sizes, 59 | batch_size, 60 | max_sequence_length, 61 | num_classes, 62 | ): 63 | # Assumes sorted indices. 
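# Undo the `pack_padded_sequence` packing: scatter the flat packed values back into a
# `(batch_size, max_sequence_length, num_classes)` array, leaving time steps beyond each
# sequence's true length filled with `corpora.MASK_VALUE`.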
64 | unpacked = np.empty((batch_size, max_sequence_length, num_classes)) 65 | unpacked.fill(corpora.MASK_VALUE) 66 | i = 0 67 | for t, batch_size in enumerate(packed_sequence_batch_sizes): 68 | unpacked[:batch_size, t] = packed_tensor[i : i + batch_size] 69 | i += batch_size 70 | return unpacked 71 | 72 | 73 | class VanillaRNN(nn.Module): 74 | def __init__( 75 | self, 76 | num_hidden_units: int, 77 | corpus: corpora.Corpus, 78 | network_type: Text, 79 | num_epochs: int, 80 | regularization: Optional[Text], 81 | regularization_lambda: Optional[float], 82 | ): 83 | super(VanillaRNN, self).__init__() 84 | 85 | self._num_hidden_units = num_hidden_units 86 | self._num_epochs = num_epochs 87 | self._train_corpus = corpus 88 | self._regularization = regularization 89 | self._regularization_lambda = regularization_lambda 90 | 91 | self._input_size = self._train_corpus.input_sequence.shape[-1] 92 | self._output_size = self._train_corpus.target_sequence.shape[-1] 93 | 94 | rnn_kwargs = { 95 | "input_size": self._input_size, 96 | "hidden_size": self._num_hidden_units, 97 | "batch_first": True, 98 | } 99 | rnn_type_to_layer = {"elman": nn.RNN, "lstm": nn.LSTM, "gru": nn.GRU} 100 | rnn_layer = rnn_type_to_layer[network_type] 101 | self._rnn_layer = rnn_layer(**rnn_kwargs) 102 | 103 | layers = [ 104 | nn.Linear( 105 | in_features=self._num_hidden_units, out_features=self._output_size, 106 | ), 107 | ] 108 | if self._output_size == 1: 109 | layers.append(nn.Sigmoid()) 110 | else: 111 | layers.append(nn.LogSoftmax(dim=-1)) 112 | self._layers = nn.Sequential(*layers) 113 | 114 | def _forward(self, x): 115 | rnn_packed_outputs, _ = self._rnn_layer(x) 116 | rnn_outputs = rnn_packed_outputs.data 117 | return self._layers(rnn_outputs) 118 | 119 | def fit(self): 120 | train_max_sequence_length = self._train_corpus.input_sequence.shape[1] 121 | train_sequence_lengths = _get_masked_sequence_lengths( 122 | self._train_corpus.input_sequence 123 | ) 124 | train_inputs_packed = _pack_masked( 125 | self._train_corpus.input_sequence, sequence_lengths=train_sequence_lengths 126 | ) 127 | train_targets_packed = _pack_masked( 128 | self._train_corpus.target_sequence, sequence_lengths=train_sequence_lengths 129 | ).data 130 | 131 | train_sample_weights_packed = _get_packed_sample_weights( 132 | self._train_corpus.sample_weights, 133 | max_sequence_length=train_max_sequence_length, 134 | masked_sequence_lengths=train_sequence_lengths, 135 | ) 136 | 137 | optimizer = optim.Adam(self.parameters(), lr=0.001) 138 | 139 | for epoch in range(self._num_epochs): 140 | optimizer.zero_grad() 141 | output = self._forward(train_inputs_packed) 142 | cross_entropy_loss, _ = _calculate_loss( 143 | net=self, 144 | outputs_packed=output, 145 | targets_packed=train_targets_packed, 146 | sample_weights=train_sample_weights_packed, 147 | regularization=self._regularization, 148 | regularization_lambda=self._regularization_lambda, 149 | ) 150 | cross_entropy_loss.backward() 151 | optimizer.step() 152 | 153 | if epoch % 10 == 0: 154 | logging.info( 155 | f"Epoch {epoch} training loss: {cross_entropy_loss.item():.3e}" 156 | ) 157 | 158 | def feed_sequence(self, input_sequence): 159 | with torch.no_grad(): 160 | return self._forward(input_sequence) 161 | 162 | 163 | def _calculate_loss( 164 | net: VanillaRNN, 165 | outputs_packed, 166 | targets_packed, 167 | sample_weights, 168 | regularization, 169 | regularization_lambda, 170 | ): 171 | if targets_packed.shape[-1] == 1: 172 | loss_func = nn.BCELoss 173 | target_classes = targets_packed 174 | 
else: 175 | # Not using `CrossEntropyLoss` because network outputs are already log-softmaxed. 176 | loss_func = nn.NLLLoss 177 | target_classes = targets_packed.argmax(axis=-1) 178 | 179 | non_reduced_loss = loss_func(reduction="none")(outputs_packed, target_classes) 180 | 181 | if sample_weights is not None: 182 | weighted_losses = torch.mul(non_reduced_loss, sample_weights) 183 | weighted_losses_sum = weighted_losses.sum() 184 | total_chars_in_input = sample_weights.sum() 185 | average_loss = weighted_losses_sum / total_chars_in_input 186 | else: 187 | average_loss = non_reduced_loss.mean() 188 | weighted_losses_sum = non_reduced_loss.sum() 189 | 190 | regularized_loss = 0 191 | if regularization == "L1": 192 | for p in net._rnn_layer.parameters(): 193 | regularized_loss += torch.sum(torch.abs(p)) 194 | for p in net._layers.parameters(): 195 | regularized_loss += torch.sum(torch.abs(p)) 196 | 197 | elif regularization == "L2": 198 | for p in net._rnn_layer.parameters(): 199 | regularized_loss += torch.sum(torch.square(p)) 200 | for p in net._layers.parameters(): 201 | regularized_loss += torch.sum(torch.square(p)) 202 | 203 | average_loss = average_loss + (regularization_lambda * regularized_loss) 204 | 205 | return average_loss, weighted_losses_sum 206 | 207 | 208 | def calculate_symbolic_accuracy( 209 | found_net: VanillaRNN, 210 | inputs: np.ndarray, 211 | target_probabs: np.ndarray, 212 | sample_weights: Tuple[int, ...], 213 | input_mask: np.ndarray, 214 | epsilon: float, 215 | plots: bool = False, 216 | ): 217 | sequence_lengths = _get_masked_sequence_lengths(inputs) 218 | 219 | inputs_packed = _pack_masked(inputs, sequence_lengths) 220 | predicted_probabs_packed = found_net.feed_sequence(inputs_packed) 221 | predicted_probabs = _unpack( 222 | packed_tensor=predicted_probabs_packed, 223 | packed_sequence_batch_sizes=inputs_packed.batch_sizes, 224 | batch_size=inputs.shape[0], 225 | max_sequence_length=inputs.shape[1], 226 | num_classes=target_probabs.shape[-1], 227 | ) 228 | predicted_probabs = np.exp(predicted_probabs) 229 | 230 | return utils.calculate_symbolic_accuracy( 231 | predicted_probabs=predicted_probabs, 232 | target_probabs=target_probabs, 233 | input_mask=input_mask, 234 | plots=plots, 235 | sample_weights=sample_weights, 236 | epsilon=epsilon, 237 | ) 238 | 239 | 240 | def evaluate( 241 | net, 242 | inputs, 243 | targets, 244 | sample_weights, 245 | deterministic_steps_mask, 246 | regularization, 247 | regularization_lambda, 248 | ): 249 | sequence_lengths = _get_masked_sequence_lengths(inputs) 250 | 251 | inputs_packed = _pack_masked(inputs, sequence_lengths) 252 | targets_packed = _pack_masked(targets, sequence_lengths).data 253 | 254 | sample_weights_packed = _get_packed_sample_weights( 255 | sample_weights, 256 | max_sequence_length=inputs.shape[1], 257 | masked_sequence_lengths=sequence_lengths, 258 | ) 259 | 260 | y_pred = net.feed_sequence(inputs_packed) 261 | 262 | cross_entropy_loss, cross_entropy_sum = _calculate_loss( 263 | net, 264 | y_pred, 265 | targets_packed, 266 | sample_weights_packed, 267 | regularization, 268 | regularization_lambda, 269 | ) 270 | 271 | if targets.shape[-1] == 1: 272 | target_classes = targets_packed.flatten() 273 | predicted_classes = (y_pred > 0.5).flatten().long() 274 | else: 275 | target_classes = targets_packed.argmax(dim=-1).flatten() 276 | predicted_classes = y_pred.argmax(dim=-1).flatten() 277 | 278 | correct = torch.sum(torch.eq(predicted_classes, target_classes)).item() 279 | accuracy = correct / len(target_classes) 280 | 
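# Deterministic accuracy: accuracy computed only over the time steps flagged in
# `deterministic_steps_mask` (i.e. steps whose target symbol is predictable with certainty),
# weighted by the per-sequence sample weights.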
281 | if deterministic_steps_mask is not None: 282 | deterministic_mask_packed = _pack_masked( 283 | deterministic_steps_mask, sequence_lengths 284 | ).data.bool() 285 | det_target_classes = target_classes[deterministic_mask_packed] 286 | det_correct = torch.eq( 287 | predicted_classes[deterministic_mask_packed], det_target_classes 288 | ) 289 | det_flat_sample_weights = sample_weights_packed[deterministic_mask_packed] 290 | det_correct_weighted = torch.mul(det_correct.int(), det_flat_sample_weights) 291 | det_accuracy = ( 292 | f"{det_correct_weighted.sum() / det_flat_sample_weights.sum().item():.5f}" 293 | ) 294 | else: 295 | det_accuracy = None 296 | 297 | logging.info( 298 | f"Accuracy: {accuracy:.5f} (Correct: {correct} / {len(target_classes)})\n" 299 | f"Deterministic accuracy: {det_accuracy}\n" 300 | f"Cross-entropy loss: {cross_entropy_loss:.2f}\n" 301 | f"Cross-entropy sum: {cross_entropy_sum:.2f}" 302 | ) 303 | return ( 304 | accuracy, 305 | cross_entropy_loss.item(), 306 | cross_entropy_sum.item(), 307 | det_accuracy, 308 | ) 309 | --------------------------------------------------------------------------------