├── LICENSE.md
├── README.md
├── assets
│   └── anbncn.png
├── configuration.py
├── corpora.py
├── dfa.py
├── genetic_algorithm.py
├── island.py
├── main.py
├── manual_nets.py
├── mdlrnn_torch.py
├── network.py
├── requirements.txt
├── simulations.py
├── stats
│   ├── tacl_stats.json
│   └── tacl_stats.py
├── tests
│   ├── test_corpus.py
│   ├── test_genetic_algorithm.py
│   └── test_network.py
├── torch_conversion.py
├── utils.py
└── vanilla_rnn.py
/LICENSE.md:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Nur Lan, Emmanuel Chemla, Roni Katzir
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Minimum Description Length Recurrent Neural Networks
2 |
3 | 
4 | 
5 | 
6 | [](https://arxiv.org/abs/2111.00600)
7 |
8 | Code for [Minimum Description Length Recurrent Neural Networks](https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00489/112499/Minimum-Description-Length-Recurrent-Neural) by Nur Lan, Michal Geyer, Emmanuel Chemla, and Roni Katzir.
9 |
10 | Paper: https://direct.mit.edu/tacl/article/doi/10.1162/tacl_a_00489/
11 |
12 |
13 |
14 | ## Getting started
15 | 1. Install Python >= 3.7
16 | 2. `pip install -r requirements.txt`
17 |
18 | ### On Ubuntu, install:
19 | ```
20 | $ apt-get install libsm6 libxext6 libxrender1 libffi-dev libopenmpi-dev
21 | ```
22 |
23 | ## Running simulations
24 |
25 | ```
26 | $ python main.py --simulation <simulation_name> -n <number_of_islands>
27 | ```
28 |
29 | For example, to run the `aⁿbⁿcⁿ` task using 16 island processes:
30 | ```
31 | $ python main.py --simulation an_bn_cn -n 16
32 | ```
33 |
34 | * All simulations are available in `simulations.py`
35 |
36 | * Final and intermediate solutions are saved to the `networks` sub-directory, both as `pickle` and in visual `dot` format (see the rendering example below).
37 |
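38 | To render a saved `dot` file as an image (assuming [Graphviz](https://graphviz.org) is installed; the filename here is only an example):
39 |
40 | ```
41 | $ dot -Tpng networks/net.dot -o networks/net.png
42 | ```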
38 |
39 | ## PyTorch conversion
40 |
41 | Converting a network trained using the genetic algorithm to a PyTorch module:
42 |
43 | ```
44 | import pickle
45 | import torch_conversion
46 | with open("networks/net.pickle", "rb") as f:
47 |     net = pickle.load(f)
48 |
49 | torch_net = torch_conversion.mdlnn_to_torch(net)
50 | ```
51 |
52 | Then fine-tune and evaluate using [MDLRNN-torch](https://github.com/0xnurl/mdlrnn-torch).
53 |
54 | ## Parallelization
55 |
56 | Native Python multiprocessing is used by default. To use MPI, change `migration_channel` to `mpi` in `simulations.py`.
57 |
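58 | `migration_channel` is a field of `SimulationConfig` (see `configuration.py`). A minimal sketch of switching a config to MPI, assuming the simulations in `simulations.py` expose `SimulationConfig` instances:
59 |
60 | ```
61 | import dataclasses
62 |
63 | # `config` is a placeholder for one of the SimulationConfig objects defined in simulations.py.
64 | mpi_config = dataclasses.replace(config, migration_channel="mpi")
65 | ```
66 |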
58 | ## Citing this work
59 |
60 | ```
61 | @article{Lan-Geyer-Chemla-Katzir-MDLRNN-2022,
62 |   title = {Minimum Description Length Recurrent Neural Networks},
63 |   author = {Lan, Nur and Geyer, Michal and Chemla, Emmanuel and Katzir, Roni},
64 |   year = {2022},
65 |   month = jul,
66 |   journal = {Transactions of the Association for Computational Linguistics},
67 |   volume = {10},
68 |   pages = {785--799},
69 |   issn = {2307-387X},
70 |   doi = {10.1162/tacl_a_00489},
71 | }
72 | ```
73 |
--------------------------------------------------------------------------------
/assets/anbncn.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/taucompling/mdlrnn/4161677520f9200e8263fd5016f1843455991814/assets/anbncn.png
--------------------------------------------------------------------------------
/configuration.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | from typing import Optional, Text, Tuple
3 |
4 |
5 | @dataclasses.dataclass(frozen=True)
6 | class SimulationConfig:
7 | simulation_id: Text
8 | num_islands: int
9 | migration_interval_seconds: int
10 | migration_interval_generations: int
11 | migration_ratio: float
12 |
13 | num_generations: int
14 | population_size: int
15 | elite_ratio: float
16 | mutation_probab: float
17 | allowed_activations: Tuple[int, ...]
18 | allowed_unit_types: Tuple[int, ...]
19 | start_smooth: bool
20 |
21 | max_network_units: int
22 | tournament_size: int
23 |
24 | grammar_multiplier: int
25 | data_given_grammar_multiplier: int
26 |
27 | compress_grammar_encoding: bool
28 | softmax_outputs: bool
29 | truncate_large_values: bool
30 | bias_connections: bool
31 | recurrent_connections: bool
32 | generation_dump_interval: int
33 |
34 | seed: int
35 | corpus_seed: int
36 |
37 | mini_batch_size: Optional[int] = None
38 | resumed_from_simulation_id: Optional[Text] = None
39 | comment: Optional[Text] = None
40 | parallelize: bool = True
41 | migration_channel: Text = "file" # {'file', 'mpi'}
42 |
--------------------------------------------------------------------------------
/corpora.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import csv
3 | import dataclasses
4 | import itertools
5 | import logging
6 | import math
7 | import random
8 | from typing import Dict, FrozenSet, List, Optional, Text, Tuple, Union
9 |
10 | import configuration
11 | import dfa
12 | import numpy as np
13 | import utils
14 |
15 | _DEFAULT_CONFIG = configuration.SimulationConfig(
16 | simulation_id="test",
17 | num_islands=1,
18 | migration_ratio=0.1,
19 | migration_interval_seconds=20,
20 | migration_interval_generations=1000,
21 | num_generations=1000,
22 | population_size=20,
23 | elite_ratio=0.05,
24 | allowed_activations=(0, 1, 2, 3, 4, 5, 6,),
25 | start_smooth=False,
26 | allowed_unit_types=(0, 1,),
27 | tournament_size=4,
28 | mutation_probab=1.0,
29 | grammar_multiplier=1,
30 | data_given_grammar_multiplier=1,
31 | compress_grammar_encoding=False,
32 | max_network_units=1024,
33 | softmax_outputs=False,
34 | truncate_large_values=True,
35 | bias_connections=True,
36 | recurrent_connections=True,
37 | seed=1,
38 | corpus_seed=100,
39 | generation_dump_interval=1,
40 | parallelize=False,
41 | migration_channel="file",
42 | mini_batch_size=None,
43 | )
44 |
45 |
46 | MASK_VALUE = np.nan
47 |
48 | _Vocabulary = Dict[int, Text]
49 |
50 |
51 | def is_masked(x: Union[np.ndarray, float]) -> Union[np.ndarray, bool]:
52 | return np.isnan(x)
53 |
54 |
55 | @dataclasses.dataclass(frozen=True)
56 | class Corpus:
57 | name: Text
58 | input_sequence: np.ndarray
59 | target_sequence: np.ndarray
60 |
61 | optimal_d_given_g: Optional[float] = None
62 | vocabulary: Optional[_Vocabulary] = None
63 | deterministic_steps_mask: Optional[np.ndarray] = None
64 |
65 | # Precomputed values for feeding efficiency.
66 | input_mask: Optional[np.ndarray] = None
67 | targets_mask: Optional[np.ndarray] = None
68 | input_values_per_time_step: Optional[Dict[int, List[np.ndarray]]] = None
69 | sample_weights: Optional[Tuple[int, ...]] = None
70 |
71 | test_corpus: Optional["Corpus"] = None
72 |
73 |
74 | def precompute_mask_idxs(corpus: Corpus) -> Corpus:
75 | masked = is_masked(corpus.input_sequence)
76 | input_mask = np.array(
77 | [
78 | ~np.all(masked[i, j])
79 | for (i, j) in np.ndindex(corpus.input_sequence.shape[:2])
80 | ],
81 | dtype=bool,
82 | ).reshape(corpus.input_sequence.shape[:2])
83 | return dataclasses.replace(corpus, input_mask=input_mask)
84 |
85 |
86 | def _precompute_input_unit_values(corpus: Corpus) -> Corpus:
87 | unit_to_timestep_val = {}
88 | for unit in range(corpus.input_sequence.shape[-1]):
89 | unit_to_timestep_val[unit] = []
90 | for time_step in range(corpus.input_sequence.shape[1]):
91 | time_step_input = np.ascontiguousarray(
92 | corpus.input_sequence[:, time_step, unit]
93 | )
94 | time_step_input.flags.writeable = False
95 | unit_to_timestep_val[unit].append(time_step_input)
96 | unit_to_timestep_val[unit] = tuple(unit_to_timestep_val[unit])
97 | return dataclasses.replace(corpus, input_values_per_time_step=unit_to_timestep_val)
98 |
99 |
100 | def _precompute_targets_mask(corpus: Corpus) -> Corpus:
101 | if corpus.target_sequence.shape[-1] == 1:
102 | targets_mask = corpus.target_sequence == 1
103 | else:
104 | targets_mask = np.zeros_like(corpus.target_sequence, dtype=bool)
105 | target_classes = corpus.target_sequence.argmax(axis=-1).flatten()
106 | batch_idxs, time_idxs = tuple(
107 | zip(*np.ndindex(corpus.target_sequence.shape[:2]))
108 | )
109 | targets_mask[batch_idxs, time_idxs, target_classes] = True
110 | return dataclasses.replace(corpus, targets_mask=targets_mask)
111 |
112 |
113 | def _make_inputs_read_only(corpus: Corpus) -> Corpus:
114 | corpus.input_sequence.flags.writeable = False
115 | return corpus
116 |
117 |
118 | def optimize_for_feeding(corpus: Corpus) -> Corpus:
119 | logging.info(f"Optimizing corpus for feeding...")
120 | corpus = _make_inputs_read_only(corpus)
121 | corpus = _precompute_targets_mask(corpus)
122 | corpus = precompute_mask_idxs(corpus)
123 | corpus = _precompute_input_unit_values(corpus)
124 | return corpus
125 |
126 |
127 | def make_random_binary(sequence_length: int = 100, batch_size: int = 1,) -> Corpus:
128 | return Corpus(
129 | "random_binary",
130 | input_sequence=np.random.randint(0, 2, size=(batch_size, sequence_length, 1)),
131 | target_sequence=np.random.randint(0, 2, size=(batch_size, sequence_length, 1)),
132 | )
133 |
134 |
135 | def make_random_one_hot(
136 | num_input_classes: int,
137 | num_target_classes: int,
138 | sequence_length: int = 100,
139 | batch_size: int = 1,
140 | ) -> Corpus:
141 | input_classes = np.random.randint(
142 | 0, num_input_classes, size=(batch_size, sequence_length)
143 | )
144 | target_classes = np.random.randint(
145 | 0, num_target_classes, size=(batch_size, sequence_length)
146 | )
147 | return make_one_hot_corpus(
148 | "random_one_hot",
149 | input_classes=input_classes,
150 | target_classes=target_classes,
151 | num_input_classes=num_input_classes,
152 | num_target_classes=num_target_classes,
153 | )
154 |
155 |
156 | def make_one_hot_corpus(
157 | name: Text,
158 | input_classes: Union[List, np.ndarray],
159 | target_classes: Union[List, np.ndarray],
160 | num_input_classes: int,
161 | num_target_classes: int,
162 | weights: Optional[Tuple[int, ...]] = None,
163 | vocabulary: Optional[_Vocabulary] = None,
164 | ) -> Corpus:
165 | return Corpus(
166 | name,
167 | input_sequence=_make_one_hot_sequence(
168 | np.array(input_classes), num_input_classes
169 | ),
170 | target_sequence=_make_one_hot_sequence(
171 | np.array(target_classes), num_target_classes
172 | ),
173 | sample_weights=weights,
174 | vocabulary=vocabulary,
175 | )
176 |
177 |
178 | def _force_batch_dimension(arr: np.ndarray) -> np.ndarray:
179 | if arr.ndim == 1:
180 | return np.expand_dims(arr, axis=0)
181 | return arr
182 |
183 |
184 | def _make_one_hot_sequence(classes: np.ndarray, num_classes: int) -> np.ndarray:
185 | classes = _force_batch_dimension(classes)
186 | batch_size = classes.shape[0]
187 | sequence_length = classes.shape[1]
188 |
189 | one_hot = np.zeros(
190 | (batch_size, sequence_length, num_classes), dtype=utils.FLOAT_DTYPE, order="C"
191 | )
192 |
193 | for b in range(batch_size):
194 | for s in range(sequence_length):
195 | c = classes[b, s]
196 | if is_masked(c):
197 | one_hot[b, s] = MASK_VALUE
198 | else:
199 | one_hot[b, s, int(c)] = 1.0
200 | return one_hot
201 |
202 |
203 | def make_between_dfa(start: int, end: int) -> dfa.DFA:
204 | final_state = end + 1
205 | transitions = {}
206 | for i in range(start):
207 | transitions[i] = {"0": i, "1": i + 1}
208 | for i in range(start, end):
209 | transitions[i] = {"0": i, "1": i + 1, dfa.END_OF_SEQUENCE: final_state}
210 | transitions[end] = {"0": end, dfa.END_OF_SEQUENCE: final_state}
211 | return dfa.DFA(transitions=transitions, accepting_states={final_state})
212 |
213 |
214 | def make_at_least_dfa(n: int) -> dfa.DFA:
215 | transitions = {}
216 | for i in range(n):
217 | transitions[i] = {"0": i, "1": i + 1}
218 | transitions[n] = {"0": n, "1": n, dfa.END_OF_SEQUENCE: n + 1}
219 | return dfa.DFA(transitions=transitions, accepting_states={n + 1})
220 |
221 |
222 | def _make_at_most_dfa(n: int) -> dfa.DFA:
223 | transitions = {}
224 | accepting_state = n + 1
225 | for i in range(n):
226 | transitions[i] = {"0": i, "1": i + 1, dfa.END_OF_SEQUENCE: accepting_state}
227 | transitions[n] = {"0": n, dfa.END_OF_SEQUENCE: accepting_state}
228 | return dfa.DFA(transitions=transitions, accepting_states={accepting_state})
229 |
230 |
231 | def _dfa_to_inputs(
232 | dfa_: dfa.DFA, batch_size: int, end_of_sequence_char: int, max_sequence_length: int,
233 | ) -> np.ndarray:
234 | batch = np.empty((batch_size, max_sequence_length))
235 | batch.fill(MASK_VALUE)
236 |
237 | for b in range(batch_size):
238 | input_idx = 0
239 | while True:
240 | string = dfa_.generate_string()
241 | if len(string) > max_sequence_length:
242 | continue
243 | if input_idx + len(string) > max_sequence_length:
244 | break
245 | for i, char in enumerate(string):
246 | if char == dfa.END_OF_SEQUENCE:
247 | char_int = end_of_sequence_char
248 | else:
249 | char_int = int(char)
250 | batch[b, input_idx + i] = char_int
251 | input_idx += len(string)
252 |
253 | return batch
254 |
255 |
256 | def make_identity(
257 | sequence_length: int = 1000, batch_size: int = 10, num_classes: int = 2
258 | ) -> Corpus:
259 | sequence = np.random.randint(
260 | num_classes, size=(batch_size, sequence_length)
261 | ).astype(utils.FLOAT_DTYPE)
262 | return make_one_hot_corpus("identity", sequence, sequence, num_classes, num_classes)
263 |
264 |
265 | def make_identity_binary(sequence_length: int, batch_size: int) -> Corpus:
266 | sequence = np.random.randint(2, size=(batch_size, sequence_length, 1)).astype(
267 | utils.FLOAT_DTYPE
268 | )
269 | input_sequence = np.copy(sequence)
270 | input_sequence[sequence == 0] = -1
271 | return Corpus("identity", input_sequence, sequence)
272 |
273 |
274 | def _make_repetition_sequence(
275 | input_sequence: np.ndarray, offset: int, padding: float = 0.0
276 | ) -> np.ndarray:
277 | """[a,b,c,d, ..., y,z] -> [,a,b,c, ..., y] """
278 | assert input_sequence.ndim == 2
279 | batch_size = input_sequence.shape[0]
280 | padded_arr = np.empty((batch_size, offset), dtype=utils.FLOAT_DTYPE)
281 | padded_arr.fill(padding)
282 | return np.concatenate((padded_arr, input_sequence[:, :-offset],), axis=-1,)
283 |
284 |
285 | def _make_prediction_sequence(
286 | input_sequence: np.ndarray, lookahead: int = 1, padding: float = 0.0
287 | ) -> np.ndarray:
288 | """[a,b,c,d, ..., y,z] -> [b,c,d,e, ..., z,] """
289 | input_sequence = _force_batch_dimension(input_sequence)
290 | assert input_sequence.ndim == 2
291 | batch_size = input_sequence.shape[0]
292 | padded_arr = np.empty((batch_size, lookahead), dtype=utils.FLOAT_DTYPE)
293 | padded_arr.fill(padding)
294 | return np.concatenate((input_sequence[:, lookahead:], padded_arr), axis=-1)
295 |
296 |
297 | def make_prev_char_repetition(
298 | sequence_length: int = 1000,
299 | batch_size: int = 10,
300 | repetition_offset: int = 1,
301 | num_classes: int = 2,
302 | ) -> Corpus:
303 | input_sequence = np.random.randint(num_classes, size=(batch_size, sequence_length))
304 | target_sequence = _make_repetition_sequence(input_sequence, repetition_offset)
305 |
306 | return make_one_hot_corpus(
307 | f"repeat_prev_{repetition_offset}_char",
308 | input_sequence,
309 | target_sequence,
310 | num_classes,
311 | num_classes,
312 | )
313 |
314 |
315 | def make_prev_char_repetition_binary(
316 | sequence_length: int, batch_size: int, repetition_offset: int,
317 | ) -> Corpus:
318 | input_sequence = np.random.randint(2, size=(batch_size, sequence_length)).astype(
319 | utils.FLOAT_DTYPE
320 | )
321 | target_sequence = _make_repetition_sequence(input_sequence, repetition_offset)
322 | return Corpus(
323 | f"repeat_prev_{repetition_offset}_char_binary",
324 | np.expand_dims(input_sequence, -1),
325 | np.expand_dims(target_sequence, -1),
326 | )
327 |
328 |
329 | def make_elman_xor_binary(sequence_length: int = 3000, batch_size: int = 1) -> Corpus:
330 | assert sequence_length % 3 == 0
331 |
332 | input_batch = []
333 | target_batch = []
334 | for b in range(batch_size):
335 | sequence = []
336 | for pair_idx in range(sequence_length // 3):
337 | a, b = random.choice([0, 1]), random.choice([0, 1])
338 | sequence += [a, b, a ^ b]
339 | input_batch.append(sequence)
340 | # Target output is the next character of input
341 | target_sequence = sequence[1:] + [0]
342 | target_batch.append(target_sequence)
343 | input_batch = np.expand_dims(
344 | np.array(input_batch, dtype=utils.FLOAT_DTYPE), axis=-1
345 | )
346 | target_batch = np.expand_dims(
347 | np.array(target_batch, dtype=utils.FLOAT_DTYPE), axis=-1
348 | )
349 | return Corpus("elman_xor_binary", input_batch, target_batch)
350 |
351 |
352 | def make_elman_xor_one_hot(sequence_length: int = 3000, batch_size: int = 1) -> Corpus:
353 | binary_corpus = make_elman_xor_binary(sequence_length, batch_size)
354 |
355 | return make_one_hot_corpus(
356 | "elman_xor_one_hot",
357 | binary_corpus.input_sequence,
358 | binary_corpus.target_sequence,
359 | num_input_classes=2,
360 | num_target_classes=2,
361 | )
362 |
363 |
364 | def make_semi_random_corpus(sequence_length: int = 100, batch_size: int = 10) -> Corpus:
365 | """ One random bit, one identical bit, e.g.: [0,0,0,0,1,1,0,0,1,1,0,0, ...] """
366 | assert sequence_length % 2 == 0
367 | input_batch = []
368 | target_batch = []
369 | for _ in range(batch_size):
370 | sequence = []
371 | for s in range(sequence_length // 2):
372 | sequence += [random.randrange(2)] * 2
373 | input_batch.append(sequence)
374 | target_sequence = sequence[1:] + [0]
375 | target_batch.append(target_sequence)
376 |
377 | input_batch = np.expand_dims(np.array(input_batch), axis=-1)
378 | target_batch = np.expand_dims(np.array(target_batch), axis=-1)
379 | return Corpus("semi_random_pairs", input_batch, target_batch)
380 |
381 |
382 | def make_elman_badigu(num_consonants: int = 1000) -> Corpus:
383 | # ba, dii, guuu
384 | feature_table = {
385 | # cons, vowel, int, high, back, voiced
386 | "b": [1, 0, 1, 0, 0, 1], # b
387 | "d": [1, 0, 1, 1, 0, 1], # d
388 | "g": [1, 0, 1, 0, 1, 1], # g
389 | "a": [0, 1, 0, 0, 1, 1], # a
390 | "i": [0, 1, 0, 1, 0, 1], # i
391 | "u": [0, 1, 0, 1, 1, 1], # u
392 | }
393 | segments = list("bdgaiu")
394 | num_classes = len(segments)
395 | segment_to_idx = {x: i for i, x in enumerate(segments)}
396 |
397 | consonant_to_sequence = {"b": list("ba"), "d": list("dii"), "g": list("guuu")}
398 | consonant_sequence = np.random.choice(["b", "d", "g"], size=num_consonants)
399 |
400 | letters_sequence = list(
401 | itertools.chain(*(consonant_to_sequence[c] for c in consonant_sequence))
402 | )
403 | input_sequence = [segment_to_idx[x] for x in letters_sequence]
404 | target_sequence = input_sequence[1:] + [0]
405 |
406 | logging.info(f"Elman badigu sequence: {letters_sequence}")
407 | consonants = tuple("bdg")
408 |
409 | consonant_percentage = len([x for x in letters_sequence if x in consonants]) / len(
410 | letters_sequence
411 | )
412 |
413 | logging.info(f"Max accuracy for task: {1-consonant_percentage:.2f}")
414 |
415 | return make_one_hot_corpus(
416 | "elman_badigu", input_sequence, target_sequence, num_classes, num_classes
417 | )
418 |
419 |
420 | def _make_0_1_pattern_binary(sequence_length: int, batch_size: int) -> Corpus:
421 | assert sequence_length % 2 == 0
422 |
423 | input_seq = np.array([[0, 1] * (sequence_length // 2)], dtype=utils.FLOAT_DTYPE)
424 | target_seq = _make_prediction_sequence(input_seq, lookahead=1, padding=0.0)
425 |
426 | input_seq = np.expand_dims(input_seq, axis=2)
427 | target_seq = np.expand_dims(target_seq, axis=2)
428 |
429 | return Corpus(
430 | name=f"0_1_pattern_binary_length_{sequence_length}_batch_{batch_size}",
431 | input_sequence=input_seq,
432 | target_sequence=target_seq,
433 | optimal_d_given_g=0.0,
434 | vocabulary={0: "0", 1: "1"},
435 | sample_weights=(batch_size,) if batch_size > 1 else None,
436 | )
437 |
438 |
439 | def make_0_1_pattern_binary(sequence_length: int, batch_size: int) -> Corpus:
440 | train_corpus = _make_0_1_pattern_binary(sequence_length, batch_size)
441 | test_corpus = _make_0_1_pattern_binary(
442 | sequence_length=sequence_length * 50_000, batch_size=1
443 | )
444 | return dataclasses.replace(train_corpus, test_corpus=test_corpus)
445 |
446 |
447 | def _make_0_1_pattern_one_hot(
448 | sequence_length: int, add_end_of_sequence: bool, batch_size: int
449 | ) -> Corpus:
450 | assert sequence_length % 2 == 0
451 |
452 | input_classes = [0, 1] * (sequence_length // 2)
453 | num_classes = 2
454 | vocabulary = {0: "0", 1: "1"}
455 |
456 | if add_end_of_sequence:
457 | num_classes = 3
458 | input_classes += [2]
459 | vocabulary[2] = "#"
460 |
461 | vocabulary.update({x + len(vocabulary): vocabulary[x] for x in vocabulary})
462 |
463 | input_classes_arr = np.array([input_classes])
464 | target_classes_arr = _make_prediction_sequence(
465 | input_classes_arr, lookahead=1, padding=0.0
466 | )
467 |
468 | corpus = make_one_hot_corpus(
469 | name=f"0_1_pattern_one_hot_length_{sequence_length}_batch_{batch_size}{'_eos' if add_end_of_sequence else ''}",
470 | input_classes=input_classes_arr,
471 | target_classes=target_classes_arr,
472 | num_input_classes=num_classes,
473 | num_target_classes=num_classes,
474 | weights=(batch_size,) if batch_size > 1 else None,
475 | vocabulary=vocabulary,
476 | )
477 | return dataclasses.replace(
478 | # TODO: calculate optimal D
479 | corpus,
480 | optimal_d_given_g=0.0,
481 | )
482 |
483 |
484 | def make_0_1_pattern_one_hot(
485 | sequence_length: int, add_end_of_sequence: bool, batch_size: int
486 | ) -> Corpus:
487 | train_corpus = _make_0_1_pattern_one_hot(
488 | sequence_length, add_end_of_sequence, batch_size
489 | )
490 | test_corpus = _make_0_1_pattern_one_hot(
491 | sequence_length=sequence_length * 50_000,
492 | add_end_of_sequence=add_end_of_sequence,
493 | batch_size=1,
494 | )
495 | return dataclasses.replace(train_corpus, test_corpus=test_corpus)
496 |
497 |
498 | def make_123_n_pattern_corpus(
499 | base_sequence_length: int = 3, sequence_length: int = 100
500 | ):
501 | # [0,1,2, ..., n-1] repeated
502 | assert sequence_length % base_sequence_length == 0
503 | input_sequence = np.array(
504 | list(range(base_sequence_length)) * (sequence_length // base_sequence_length)
505 | )
506 | target_sequence = _make_prediction_sequence(input_sequence, lookahead=1)
507 | return make_one_hot_corpus(
508 | f"1_to_{base_sequence_length}_pattern",
509 | input_sequence,
510 | target_sequence,
511 | num_input_classes=base_sequence_length,
512 | num_target_classes=base_sequence_length,
513 | )
514 |
515 |
516 | def make_between_quantifier(
517 | start: int, end: int, sequence_length: int = 100, batch_size: int = 1
518 | ) -> Corpus:
519 | assert end <= sequence_length
520 | between_dfa = make_between_dfa(start, end)
521 | between_dfa.visualize(f"between_{start}_{end}_dfa")
522 | input_batch = _dfa_to_inputs(
523 | between_dfa,
524 | batch_size=batch_size,
525 | end_of_sequence_char=2,
526 | max_sequence_length=sequence_length,
527 | )
528 | target_batch = _make_prediction_sequence(input_batch, lookahead=1)
529 | if start == end:
530 | name = f"exactly_{start}"
531 | else:
532 | name = f"between_{start}_{end}"
533 | num_classes = 3
534 | return make_one_hot_corpus(
535 | name, input_batch, target_batch, num_classes, num_classes
536 | )
537 |
538 |
539 | def make_exactly_n_quantifier(
540 | n: int = 1, sequence_length: int = 100, batch_size: int = 1
541 | ) -> Corpus:
542 | return make_between_quantifier(
543 | start=n, end=n, sequence_length=sequence_length, batch_size=batch_size
544 | )
545 |
546 |
547 | def make_at_least_quantifier(
548 | n: int = 1, sequence_length: int = 100, batch_size: int = 1
549 | ) -> Corpus:
550 | name = f"at_least_{n}"
551 | at_least_dfa = make_at_least_dfa(n)
552 | at_least_dfa.visualize(name)
553 | input_batch = _dfa_to_inputs(
554 | at_least_dfa,
555 | batch_size=batch_size,
556 | end_of_sequence_char=2,
557 | max_sequence_length=sequence_length,
558 | )
559 | target_batch = _make_prediction_sequence(input_batch, lookahead=1)
560 | num_classes = 3
561 | return make_one_hot_corpus(
562 | name, input_batch, target_batch, num_classes, num_classes
563 | )
564 |
565 |
566 | def make_at_most_quantifier(
567 | n: int = 1, sequence_length: int = 100, batch_size: int = 1
568 | ) -> Corpus:
569 | name = f"at_most_{n}"
570 | at_most_dfa = _make_at_most_dfa(n)
571 | at_most_dfa.visualize(name)
572 | input_batch = _dfa_to_inputs(
573 | at_most_dfa,
574 | batch_size=batch_size,
575 | end_of_sequence_char=2,
576 | max_sequence_length=sequence_length,
577 | )
578 | target_batch = _make_prediction_sequence(input_batch, lookahead=1)
579 | num_classes = 3
580 | return make_one_hot_corpus(
581 | name, input_batch, target_batch, num_classes, num_classes
582 | )
583 |
584 |
585 | def make_every_quantifier(sequence_length: int = 100, batch_size: int = 1) -> Corpus:
586 | input_batch = np.ones((batch_size, sequence_length))
587 | return make_one_hot_corpus(f"every_quantifier", input_batch, input_batch, 2, 2)
588 |
589 |
590 | def _int_to_classes(n: int) -> List[int]:
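591 | # E.g. 123 -> [3, 2, 1]: decimal digits as classes, least-significant digit first.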
591 | return list(reversed(list(map(int, list(str(n))))))
592 |
593 |
594 | def make_count_corpus(max_int: int = 100, batch_size: int = 100) -> Corpus:
595 | # Predict n+1 from n in a language-model setting.
596 | sequence_length = int(np.floor(np.log10(max_int))) + 1
597 | input_classes = np.zeros((batch_size, sequence_length))
598 | target_classes = np.zeros((batch_size, sequence_length))
599 |
600 | for b in range(batch_size):
601 | n = random.randrange(max_int)
602 | input_ = _int_to_classes(n)
603 | target = _int_to_classes(n + 1)
604 | input_classes[b, : len(input_)] = input_
605 | target_classes[b, : len(target)] = target
606 |
607 | return make_one_hot_corpus(
608 | f"count_to_{max_int}",
609 | input_classes,
610 | target_classes,
611 | num_input_classes=10,
612 | num_target_classes=10,
613 | )
614 |
615 |
616 | def base10_to_binary_vector(n: int, sequence_length=None) -> np.ndarray:
617 | """8 -> [0,0,0,1], 7 -> [1,1,1] """
618 | if n == 0:
619 | return np.zeros(sequence_length)
620 |
621 | powers = []
622 | while n:
623 | power = int(np.floor(np.log2(n)))
624 | powers.append(power)
625 | n -= 2 ** power
626 |
627 | rightmost_one_position = int(max(powers))
628 | if sequence_length is None:
629 | sequence_length = rightmost_one_position + 1
630 | binary = np.zeros(sequence_length)
631 | binary[powers] = 1.0
632 | # TODO: mask redundant positions?
633 | return binary
634 |
635 |
636 | def _make_binary_addition_corpus(min_n: int, max_n: int):
637 | all_summands = tuple(itertools.product(range(min_n, max_n), repeat=2))
638 | summands = []
639 | for b, (n1, n2) in enumerate(all_summands):
640 | summands.append([n1, n2])
641 | summands = np.array(summands)
642 | sums = np.sum(summands, axis=1)
643 | sequence_length = int(np.ceil(np.log2(np.max(sums)))) + 1
644 |
645 | summand_binaries = []
646 | sum_binaries = []
647 | for (n1, n2), sum_ in zip(summands, sums):
648 | summand_binaries.append(
649 | [
650 | base10_to_binary_vector(n1, sequence_length),
651 | base10_to_binary_vector(n2, sequence_length),
652 | ]
653 | )
654 | sum_binaries.append(base10_to_binary_vector(sum_, sequence_length))
655 |
656 | summand_inputs = np.array(
657 | [np.stack(summands, axis=1,) for summands in summand_binaries]
658 | )
659 |
660 | sum_outputs = np.expand_dims(np.stack(sum_binaries), axis=-1)
661 |
662 | return dataclasses.replace(
663 | Corpus(
664 | name=f"binary_addition_{min_n}_to_{max_n}",
665 | input_sequence=summand_inputs,
666 | target_sequence=sum_outputs,
667 | ),
668 | optimal_d_given_g=0.0,
669 | )
670 |
671 |
672 | def make_binary_addition(min_n: int, max_n: int) -> Corpus:
673 | training_corpus = _make_binary_addition_corpus(min_n=min_n, max_n=max_n)
674 | test_corpus = _make_binary_addition_corpus(min_n=max_n + 1, max_n=max_n + 251)
675 | training_corpus = dataclasses.replace(training_corpus, test_corpus=test_corpus)
676 | return training_corpus
677 |
678 |
679 | def an_bn_handmade_net(input_seq: np.ndarray, prior: float):
680 | # Optimal network according to Schmidhuber (2001).
681 | outputs = np.zeros_like(input_seq)
682 | for b in range(input_seq.shape[0]):
683 | num_seen_a = 0
684 | for t in range(input_seq.shape[1]):
685 | input_vec = input_seq[b, t]
686 | input_class = input_vec.argmax()
687 | if input_class == 0:
688 | # Start of sequence symbol, always predict "a" (no empty string in current corpus).
689 | outputs[b, t] = [0.0, 1.0, 0.0]
690 | elif input_class == 1:
691 | # "a".
692 | num_seen_a += 1
693 | outputs[b, t] = [0.0, 1 - prior, prior]
694 | elif input_class == 2:
695 | # "b".
696 | num_seen_a -= 1
697 | if num_seen_a > 0:
698 | outputs[b, t] = [0.0, 0.0, 1.0]
699 | else:
700 | outputs[b, t] = [1.0, 0.0, 0.0]
701 |
702 | return outputs
703 |
704 |
705 | def make_english_onset_phonotactics(split_ratio: Optional[float] = None):
706 | PHONOTACTIC_COUNTS = {
707 | "k": 2764,
708 | "r": 2752,
709 | "d": 2526,
710 | "s": 2215,
711 | "m": 1965,
712 | "p": 1881,
713 | "b": 1544,
714 | "l": 1225,
715 | "f": 1222,
716 | "h": 1153,
717 | "t": 1146,
718 | "pr": 1046,
719 | "w": 780,
720 | "n": 716,
721 | "v": 615,
722 | "g": 537,
723 | "dÇ": 524,
724 | "st": 521,
725 | "tr": 515,
726 | "kr": 387,
727 | "+": 379,
728 | "gr": 331,
729 | "t+": 329,
730 | "br": 319,
731 | "sp": 313,
732 | "fl": 290,
733 | "kl": 285,
734 | "sk": 278,
735 | "j": 268,
736 | "fr": 254,
737 | "pl": 238,
738 | "bl": 213,
739 | "sl": 213,
740 | "dr": 211,
741 | "kw": 201,
742 | "str": 183,
743 | "‡": 173,
744 | "sw": 153,
745 | "gl": 131,
746 | "hw": 111,
747 | "sn": 109,
748 | "skr": 93,
749 | "z": 83,
750 | "sm": 82,
751 | "‡r": 73,
752 | "skw": 69,
753 | "tw": 55,
754 | "spr": 51,
755 | "+r": 40,
756 | "spl": 27,
757 | "L": 19,
758 | "dw": 17,
759 | "gw": 11,
760 | "‡w": 4,
761 | "skl": 1,
762 | }
763 | name = "english_onset_phonotactics"
764 | inputs, targets, num_classes, weights = _make_phonotactic_corpus(
765 | phonotactic_counts=PHONOTACTIC_COUNTS
766 | )
767 | if not split_ratio:
768 | return make_one_hot_corpus(
769 | name, inputs, targets, num_classes, num_classes, weights
770 | )
771 | train_inputs, test_inputs = split_train_and_test(inputs, split_ratio)
772 | train_targets, test_targets = split_train_and_test(targets, split_ratio)
773 | return Corpus(
774 | name=f"{name}_{split_ratio}_split",
775 | input_sequence=_make_one_hot_sequence(train_inputs, num_classes),
776 | target_sequence=_make_one_hot_sequence(train_targets, num_classes),
777 | sample_weights=weights,
778 | test_corpus=Corpus(
779 | name=f"{name}_{split_ratio}_split_test",
780 | input_sequence=_make_one_hot_sequence(test_inputs, num_classes),
781 | target_sequence=_make_one_hot_sequence(test_targets, num_classes),
782 | ),
780 | )
781 |
782 |
783 | def _load_futrell_phonotactic_data(filename: Text) -> Dict[Text, int]:
784 | word_to_count = {}
785 | with open(filename, "r") as f:
786 | reader = csv.DictReader(f)
787 | for row in reader:
788 | word_to_count[row["Phonology"]] = int(row["LemmaFrequency"])
789 | return word_to_count
790 |
791 |
792 | def make_futrell_german_phonotactic_corpus(split_ratio: float):
793 | inputs, targets, num_classes, weights = _make_phonotactic_corpus(
794 | _load_futrell_phonotactic_data("German.Dict.CSV")
795 | )
796 | train_inputs, test_inputs = split_train_and_test(inputs, split_ratio)
797 | train_targets, test_targets = split_train_and_test(targets, split_ratio)
798 | return Corpus(
799 | name=f"german_phonotactics_{split_ratio}_split",
800 | input_sequence=_make_one_hot_sequence(train_inputs, num_classes),
801 | target_sequence=_make_one_hot_sequence(train_targets, num_classes),
802 | sample_weights=weights,
803 | test_corpus=Corpus(
804 | name=f"german_phonotactics_{split_ratio}_split_test",
805 | input_sequence=_make_one_hot_sequence(test_inputs, num_classes),
806 | target_sequence=_make_one_hot_sequence(test_targets, num_classes),
807 | ),
805 | )
806 |
807 |
808 | def _make_phonotactic_corpus(phonotactic_counts: Dict[Text, int]):
809 | segments = sorted(list(set("".join(phonotactic_counts))))
810 | segment_to_index = {segment: i for i, segment in enumerate(segments)}
811 | sequences = list(phonotactic_counts.keys())
812 | max_sequence_len = max([len(x) for x in sequences])
813 | segments = list(segment_to_index)
814 | weights = tuple(phonotactic_counts[seq] for seq in sequences)
815 | inputs = np.empty((len(sequences), max_sequence_len + 1))
816 | targets = np.empty_like(inputs)
817 | inputs.fill(MASK_VALUE)
818 | targets.fill(MASK_VALUE)
819 | for i in range(len(sequences)):
820 | # Start-of-sequence symbol
821 | inputs[i, 0] = len(segments)
822 | for j in range(len(sequences[i])):
823 | inputs[i, j + 1] = segment_to_index[sequences[i][j]]
824 | targets[i, j] = segment_to_index[sequences[i][j]]
825 | # End-of-sequence symbol
826 | targets[i, len(sequences[i])] = len(segments)
827 | return inputs, targets, len(segments) + 1, weights
828 |
829 |
830 | class _DerivationTooLong(Exception):
831 | pass
832 |
833 |
834 | def _generate_string_from_pcfg(
835 | pcfg: Dict, max_length: Optional[int] = None
836 | ) -> Tuple[Text, ...]:
837 | """Stops when all generated characters are terminals.
838 | To stop without adding an epsilon terminal, use the empty string '', i.e. add a rule `S->''`. """
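839 | # A `pcfg` maps each non-terminal to a tuple of (expansion, probability) pairs whose
840 | # probabilities sum to 1; an expansion is a tuple of symbols (or "" for the empty string),
841 | # e.g. Dyck-1 below uses {"S": ((("[", "S", "]", "S"), p), ("", 1 - p))}.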
839 | stack = ["S"]
840 | terminals = []
841 | while stack:
842 | node = stack[0]
843 | stack = stack[1:]
844 |
845 | if node not in pcfg:
846 | terminals.append(node)
847 | if max_length is not None and len(terminals) > max_length:
848 | raise _DerivationTooLong
849 | continue
850 |
851 | rules, probabs = list(zip(*pcfg[node]))
852 | rule_idx = np.random.choice(len(rules), p=probabs)
853 | rule = rules[rule_idx]
854 |
855 | stack = list(rule) + stack
856 |
857 | return tuple(terminals)
858 |
859 |
860 | def _make_corpus_from_pcfg(
861 | name: Text,
862 | pcfg: Dict,
863 | batch_size: int,
864 | max_derivation_length: Optional[int] = None,
865 | sort_by_length: bool = False,
866 | ) -> Corpus:
867 | sequences = []
868 | while len(sequences) < batch_size:
869 | try:
870 | sequence = _generate_string_from_pcfg(
871 | pcfg, max_length=max_derivation_length
872 | )
873 | except _DerivationTooLong:
874 | continue
875 | sequences.append(sequence)
876 |
877 | if sort_by_length:
878 | sequences = sorted(sequences, key=len, reverse=True)
879 |
880 | lengths = list(map(len, sequences))
881 |
882 | sequence_counts = collections.Counter(sequences)
883 | unique_sequences, weights = tuple(zip(*sequence_counts.items()))
884 |
885 | logging.info(f"PCFG sum of sequence lengths: {sum(lengths)}")
886 | logging.info(f"PCFG max sequence length: {max(lengths)}")
887 | logging.info(f"PCFG mean sequence length: {np.mean(lengths)}")
888 | logging.info(
889 | f"PCFG unique sequences: {len(unique_sequences)}/{len(sequences)} ({len(unique_sequences)/len(sequences):.2f})"
890 | )
891 |
892 | alphabet = set()
893 | for rules in pcfg.values():
894 | alphabet |= set(itertools.chain(*set(map(lambda x: x[0], rules))))
895 | alphabet -= set(pcfg.keys())
896 | alphabet = ("#",) + tuple(sorted(alphabet))
897 |
898 | symbol_to_idx = {x: i for i, x in enumerate(alphabet)}
899 |
900 | max_seq_length = max(map(len, unique_sequences))
901 | input_classes = np.empty((len(unique_sequences), max_seq_length + 1))
902 | target_classes = np.empty_like(input_classes)
903 | input_classes.fill(MASK_VALUE)
904 | target_classes.fill(MASK_VALUE)
905 |
906 | for i, sequence in enumerate(unique_sequences):
907 | sequence_classes = [symbol_to_idx[symbol] for symbol in sequence]
908 | input_row = [symbol_to_idx["#"]] + sequence_classes
909 | target_row = sequence_classes + [symbol_to_idx["#"]]
910 | input_classes[i, : len(sequence_classes) + 1] = input_row
911 | target_classes[i, : len(sequence_classes) + 1] = target_row
912 |
913 | inputs = _make_one_hot_sequence(input_classes, num_classes=len(alphabet))
914 | targets = _make_one_hot_sequence(target_classes, num_classes=len(alphabet))
915 |
916 | vocabulary = _make_identical_input_output_vocabulary(alphabet)
917 |
918 | return Corpus(
919 | name=name,
920 | input_sequence=inputs,
921 | target_sequence=targets,
922 | sample_weights=weights,
923 | vocabulary=vocabulary,
924 | )
925 |
926 |
927 | def make_center_embedding(
928 | batch_size: int, embedding_depth_probab: float, dependency_distance_probab: float
929 | ) -> Corpus:
930 | pcfg = {
931 | "S": (
932 | (("NP_s", "VP_s"), (1 - embedding_depth_probab) / 2),
933 | (("NP_p", "VP_p"), (1 - embedding_depth_probab) / 2),
934 | (("NP_s", "S", "VP_s"), embedding_depth_probab / 2),
935 | (("NP_p", "S", "VP_p"), embedding_depth_probab / 2),
936 | ),
937 | "NP_s": (
938 | (("N_s",), 1 - dependency_distance_probab),
939 | # (("A", "NP_s"), dependency_distance_probab),
940 | ),
941 | "NP_p": (
942 | (("N_p",), 1 - dependency_distance_probab),
943 | # (("A", "NP_p"), dependency_distance_probab),
944 | ),
945 | "VP_s": (
946 | (("V_s",), 1 - dependency_distance_probab),
947 | # (("A", "VP_s"), dependency_distance_probab),
948 | ),
949 | "VP_p": (
950 | (("V_p",), 1 - dependency_distance_probab),
951 | # (("A", "VP_p"), dependency_distance_probab),
952 | ),
953 | "N_s": (
954 | (("cat",), 1.0),
955 | # (("dog",), 1.0),
956 | # (("horse",), 0.2),
957 | # (("rat",), 0.2),
958 | # (("flower",), 0.2),
959 | ),
960 | "N_p": (
961 | (("cats",), 1.0),
962 | # (("dogs",), 1.0),
963 | # (("horses",), 0.2),
964 | # (("rats",), 0.2),
965 | # (("flowers",), 0.2),
966 | ),
967 | "V_s": (
968 | (("runs",), 1.0),
969 | # (("talks",), 1.0),
970 | # (("dances",), 0.2),
971 | # (("eats",), 0.2),
972 | # (("drinks",), 0.2),
973 | ),
974 | "V_p": (
975 | (("run",), 1.0),
976 | # (("talk",), 1.0),
977 | # (("dance",), 0.2),
978 | # (("eat",), 0.2),
979 | # (("drink",), 0.2),
980 | ),
981 | # "A": (
982 | # (("good",), 0.5),
983 | # (("bad",), 0.5),
984 | # (("nice",), 0.2),
985 | # (("smart",), 0.2),
986 | # (("funny",), 0.2),
987 | # ),
988 | }
989 | corpus = _make_corpus_from_pcfg(
990 | f"center_embedding_pcfg_embedding_{embedding_depth_probab}_distance_{dependency_distance_probab}",
991 | pcfg=pcfg,
992 | batch_size=batch_size,
993 | )
994 | input_classes = np.argmax(corpus.input_sequence, axis=-1)
995 | deterministic_steps_mask = (~np.all(is_masked(corpus.input_sequence), axis=-1)) & (
996 | # "cat/s"
997 | (input_classes == 3)
998 | | (input_classes == 4)
999 | )
1000 | return dataclasses.replace(
1001 | corpus, deterministic_steps_mask=deterministic_steps_mask
1002 | )
1003 |
1004 |
1005 | def make_palindrome_with_middle_marker_distinct(batch_size: int, nesting_probab: float):
1006 | pcfg = {
1007 | "S": (
1008 | (("0", "S", "0"), nesting_probab / 2),
1009 | (("1", "S", "1"), nesting_probab / 2),
1010 | (("@",), 1 - nesting_probab),
1011 | )
1012 | }
1013 | return _make_corpus_from_pcfg(
1014 | name=f"palindrome_middle_marker__batch_{batch_size}__p_{nesting_probab}",
1015 | pcfg=pcfg,
1016 | batch_size=batch_size,
1017 | )
1018 |
1019 |
1020 | def _optimal_d_g_for_fixed_palindrome(corpus) -> float:
1021 | sequence_length = corpus.input_sequence.shape[1]
1022 | batch_size = sum(corpus.sample_weights)
1023 | deterministic_length = sequence_length // 2
1024 | return batch_size * deterministic_length
1025 |
1026 |
1027 | def make_binary_palindrome_fixed_length(
1028 | batch_size: int, sequence_length: int, train_set_ratio: float
1029 | ) -> Corpus:
1030 | assert sequence_length % 2 == 0
1031 | prefixes_non_unique = np.random.randint(
1032 | 2, size=(batch_size, sequence_length // 2)
1033 | ).astype(utils.FLOAT_DTYPE)
1034 |
1035 | sequence_counts = collections.Counter(list(map(tuple, prefixes_non_unique)))
1036 | unique_prefixes, weights = list(zip(*sequence_counts.items()))
1037 |
1038 | prefixes = np.array(unique_prefixes)
1039 | suffixes = np.flip(prefixes, axis=1)
1040 | sequences = np.concatenate([prefixes, suffixes], axis=1)
1041 | targets = _make_prediction_sequence(input_sequence=sequences, lookahead=1)
1042 |
1043 | input_sequences = np.expand_dims(sequences, axis=2)
1044 | target_sequences = np.expand_dims(targets, axis=2)
1045 |
1046 | logging.info(
1047 | f"Fixed palindrome: {len(unique_prefixes)}/{len(prefixes_non_unique)} unique sequences"
1048 | )
1049 |
1050 | full_corpus = optimize_for_feeding(
1051 | Corpus(
1052 | name=f"palindrome_binary_fixed_length_batch_{batch_size}_length_{sequence_length}",
1053 | input_sequence=input_sequences,
1054 | target_sequence=target_sequences,
1055 | sample_weights=weights,
1056 | )
1057 | )
1058 |
1059 | train, test = split_train_test(full_corpus, train_ratio=train_set_ratio)
1060 | logging.info(
1061 | f"Train size: {train.input_sequence.shape[0]}, test size: {test.input_sequence.shape[0]}"
1062 | )
1063 | test = dataclasses.replace(
1064 | test, optimal_d_given_g=_optimal_d_g_for_fixed_palindrome(test)
1065 | )
1066 | return dataclasses.replace(
1067 | train,
1068 | test_corpus=test,
1069 | optimal_d_given_g=_optimal_d_g_for_fixed_palindrome(train),
1070 | )
1071 |
1072 |
1073 | def _make_an_bn_square_corpus(n_values: Tuple[int, ...], prior: float):
1074 | start_end_of_sequence_symbol = 0
1075 | max_n = max(n_values)
1076 | max_sequence_length = max_n + (max_n ** 2) + 1
1077 |
1078 | n_values_counts = collections.Counter(n_values)
1079 | unique_n_values, n_values_weights = tuple(zip(*n_values_counts.items()))
1080 |
1081 | inputs = np.empty((len(unique_n_values), max_sequence_length))
1082 | targets = np.empty_like(inputs)
1083 | inputs.fill(MASK_VALUE)
1084 | targets.fill(MASK_VALUE)
1085 |
1086 | for b, n in enumerate(unique_n_values):
1087 | input_seq = [start_end_of_sequence_symbol] + ([1] * n) + ([2] * n ** 2)
1088 | target_seq = input_seq[1:] + [start_end_of_sequence_symbol]
1089 | inputs[b, : len(input_seq)] = input_seq
1090 | targets[b, : len(input_seq)] = target_seq
1091 |
1092 | corpus = make_one_hot_corpus(
1093 | f"an_bn_square_batch_{len(n_values)}_p_{prior}",
1094 | inputs,
1095 | targets,
1096 | num_input_classes=3,
1097 | num_target_classes=3,
1098 | vocabulary=_make_identical_input_output_vocabulary(alphabet=("#", "a", "b")),
1099 | weights=n_values_weights,
1100 | )
1101 | return dataclasses.replace(
1102 | corpus,
1103 | optimal_d_given_g=_get_ain_bjn_ckn_dtn_optimal_d_given_g(prior, n_values),
1104 | )
1105 |
1106 |
1107 | def make_an_bn_square(batch_size: int, prior: float) -> Corpus:
1108 | training_n_values = tuple(np.random.geometric(p=prior, size=batch_size))
1109 | training_corpus = _make_an_bn_square_corpus(training_n_values, prior)
1110 |
1111 | max_training_n = max(training_n_values)
1112 | test_n_values = tuple(range(max_training_n + 1, max_training_n + 11))
1113 | test_corpus = _make_an_bn_square_corpus(test_n_values, prior)
1114 | return dataclasses.replace(training_corpus, test_corpus=test_corpus)
1115 |
1116 |
1117 | def _make_identical_input_output_vocabulary(alphabet: Tuple[Text, ...]) -> _Vocabulary:
1118 | # Create class index to symbol mapping, assuming inputs and outputs are identical and ordered identically.
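1119 | # Output classes are offset by len(alphabet), e.g. ("#", "a", "b") ->
1120 | # {0: "#", 1: "a", 2: "b", 3: "#", 4: "a", 5: "b"}.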
1119 | class_to_symbol = {idx: alphabet[idx] for idx in range(len(alphabet))}
1120 | class_to_symbol.update(
1121 | {idx + len(alphabet): symbol for idx, symbol in class_to_symbol.items()}
1122 | )
1123 | return class_to_symbol
1124 |
1125 |
1126 | def _get_ain_bjn_ckn_dtn_optimal_d_given_g(prior, n_values) -> float:
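1127 | # n is drawn from a geometric distribution, P(n) = (1-prior)^(n-1) * prior; all other
1128 | # symbols are deterministic given n, so the optimal data encoding length is
1129 | # sum_n -log2 P(n) = -sum_n ((n-1)*log2(1-prior) + log2(prior)).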
1127 | return -np.sum(
1128 | [(n - 1) * (np.log2(1 - prior)) + np.log2(prior) for n in n_values]
1129 | ).item()
1130 |
1131 |
1132 | def get_num_chars_in_corpus(corpus: Corpus) -> int:
1133 | non_masked = ~np.all(is_masked(corpus.input_sequence), axis=-1)
1134 | num_chars_per_row = np.sum(non_masked, axis=1)
1135 | if corpus.sample_weights:
1136 | total_chars = np.dot(num_chars_per_row, corpus.sample_weights)
1137 | else:
1138 | total_chars = np.sum(num_chars_per_row)
1139 | return total_chars.item()
1140 |
1141 |
1142 | def make_inputs_counter(num_inputs: int, num_ones: int, batch_size: int):
1143 | # From Schmidhuber (1997) -- network's goal is to output the number of ones in the input.
1144 | # The optimal solution is to set all weights to 1.
1145 | inputs = np.zeros((batch_size, 1, num_inputs), dtype=utils.FLOAT_DTYPE)
1146 | for b in range(batch_size):
1147 | idxs = random.sample(range(num_inputs), k=num_ones)
1148 | inputs[b, 0, idxs] = 1.0
1149 |
1150 | targets = np.ones((batch_size, 1, 1)) * num_ones
1151 |
1152 | return Corpus(
1153 | name=f"inputs_counter", input_sequence=inputs, target_sequence=targets,
1154 | )
1157 |
1158 |
1159 | def _make_ain_bjn_ckn_dtn_corpus(
1160 | n_values: Tuple[int, ...],
1161 | multipliers: Tuple[int, ...],
1162 | prior: float,
1163 | sort_by_length: bool,
1164 | ) -> Corpus:
1165 | # Create a corpus of a^in, b^jn, c^kn, d^tn, multipliers = [i, j, k, t].
1166 | max_n = max(n_values)
1167 | max_sequence_length = (max_n * sum(multipliers)) + 1
1168 |
1169 | start_end_of_sequence_symbol = 0 # Using same symbol for start/end of sequence, as in Schmidhuber et al. (2001).
1170 |
1171 | n_values_counts = collections.Counter(n_values)
1172 | n_value_counts_items = tuple(n_values_counts.items())
1173 | if sort_by_length:
1174 | n_value_counts_items = sorted(n_value_counts_items, reverse=True)
1175 |
1176 | unique_n_values, n_values_weights = tuple(zip(*n_value_counts_items))
1177 |
1178 | inputs = np.empty((len(unique_n_values), max_sequence_length))
1179 | targets = np.empty_like(inputs)
1180 | inputs.fill(MASK_VALUE)
1181 | targets.fill(MASK_VALUE)
1182 |
1183 | for b, n in enumerate(unique_n_values):
1184 | input_seq = (
1185 | [start_end_of_sequence_symbol]
1186 | + ([1] * n * multipliers[0])
1187 | + ([2] * n * multipliers[1])
1188 | + ([3] * n * multipliers[2])
1189 | + ([4] * n * multipliers[3])
1190 | )
1191 | target_seq = input_seq[1:] + [start_end_of_sequence_symbol]
1192 | inputs[b, : len(input_seq)] = input_seq
1193 | targets[b, : len(input_seq)] = target_seq
1194 |
1195 | name = f"a{multipliers[0]}n_b{multipliers[1]}n_c{multipliers[2]}n_d{multipliers[3]}n__p_{prior}__batch_{len(n_values)}"
1196 | num_input_classes = sum([1 for x in multipliers if x != 0]) + 1
1197 |
1198 | alphabet = ("#", "a", "b", "c", "d")[:num_input_classes]
1199 | vocabulary = _make_identical_input_output_vocabulary(alphabet)
1200 |
1201 | corpus = make_one_hot_corpus(
1202 | name,
1203 | inputs,
1204 | targets,
1205 | num_input_classes=num_input_classes,
1206 | num_target_classes=num_input_classes,
1207 | weights=n_values_weights,
1208 | vocabulary=vocabulary,
1209 | )
1210 | return dataclasses.replace(
1211 | corpus,
1212 | optimal_d_given_g=_get_ain_bjn_ckn_dtn_optimal_d_given_g(prior, n_values),
1213 | # TODO: this assumes no empty sequences in corpus.
1214 | deterministic_steps_mask=(~is_masked(inputs)) & (inputs != 1),
1215 | )
1216 |
1217 |
1218 | def make_ain_bjn_ckn_dtn(
1219 | batch_size: int,
1220 | prior: float,
1221 | multipliers: Tuple[int, ...],
1222 | sort_by_length: bool = False,
1223 | ) -> Corpus:
1224 | training_n_values = tuple(np.random.geometric(p=prior, size=batch_size))
1225 | training_corpus = _make_ain_bjn_ckn_dtn_corpus(
1226 | training_n_values, multipliers, prior, sort_by_length
1227 | )
1228 |
1229 | max_training_n = max(training_n_values)
1230 | test_n_values = tuple(range(max_training_n + 1, max_training_n + 1001))
1231 | test_corpus = _make_ain_bjn_ckn_dtn_corpus(
1232 | test_n_values, multipliers, prior, sort_by_length
1233 | )
1234 | test_corpus = dataclasses.replace(test_corpus, name=f"{test_corpus.name}_test")
1235 |
1236 | logging.info(f"Created corpus {training_corpus.name}")
1237 | logging.info(f"Max n in training set: {max_training_n}")
1238 | logging.info(f"Optimal training set D:G: {training_corpus.optimal_d_given_g:,.2f}")
1239 | logging.info(f"Optimal test set D:G: {test_corpus.optimal_d_given_g:,.2f}")
1240 |
1241 | return dataclasses.replace(training_corpus, test_corpus=test_corpus)
1242 |
1243 |
1244 | def _get_an_bm_cn_plus_m_corpus_optimal_d_given_g(
1245 | n_plus_m_values: Tuple[int, ...], prior: float
1246 | ) -> float:
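1247 | # n and m are independent geometric draws, so each sequence costs
1248 | # -log2 P(n) - log2 P(m) = -((n+m-2)*log2(1-prior) + 2*log2(prior)) bits under the optimal model.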
1247 | return -np.sum(
1248 | [
1249 | ((n_plus_m - 2) * np.log2(1 - prior)) + (2 * np.log2(prior))
1250 | for n_plus_m in n_plus_m_values
1251 | ]
1252 | ).item()
1253 |
1254 |
1255 | def _make_an_bm_cn_plus_m_corpus(
1256 | n_values: Tuple[int, ...],
1257 | m_values: Tuple[int, ...],
1258 | prior: float,
1259 | sort_by_length: bool,
1260 | ) -> Corpus:
1261 | sum_values = tuple(np.add(n_values, m_values))
1262 | start_end_of_sequence_symbol = 0
1263 | max_sequence_length = 2 * max(sum_values) + 1
1264 |
1265 | n_m_values_counts = collections.Counter(zip(n_values, m_values))
1266 | n_m_values_counts_items = tuple(n_m_values_counts.items())
1267 | if sort_by_length:
1268 | n_m_values_counts_items = sorted(
1269 | n_m_values_counts_items, key=lambda x: sum(x[0]), reverse=True
1270 | )
1271 |
1272 | unique_n_m_values, n_m_values_weights = tuple(zip(*n_m_values_counts_items))
1273 |
1274 | inputs = np.empty((len(unique_n_m_values), max_sequence_length))
1275 | targets = np.empty_like(inputs)
1276 | inputs.fill(MASK_VALUE)
1277 | targets.fill(MASK_VALUE)
1278 |
1279 | for b in range(len(unique_n_m_values)):
1280 | n, m = unique_n_m_values[b]
1281 | input_seq = (
1282 | [start_end_of_sequence_symbol] + ([1] * n) + ([2] * m) + ([3] * (n + m))
1283 | )
1284 | target_seq = input_seq[1:] + [start_end_of_sequence_symbol]
1285 | inputs[b, : len(input_seq)] = input_seq
1286 | targets[b, : len(input_seq)] = target_seq
1287 |
1288 | vocabulary = _make_identical_input_output_vocabulary(alphabet=("#", "a", "b", "c"))
1289 |
1290 | corpus = make_one_hot_corpus(
1291 | f"an_bm_cn_plus_m__batch_{len(n_values)}_p_{prior}",
1292 | inputs,
1293 | targets,
1294 | num_input_classes=4,
1295 | num_target_classes=4,
1296 | weights=n_m_values_weights,
1297 | vocabulary=vocabulary,
1298 | )
1299 |
1300 | return dataclasses.replace(
1301 | corpus,
1302 | optimal_d_given_g=_get_an_bm_cn_plus_m_corpus_optimal_d_given_g(
1303 | sum_values, prior
1304 | ),
1305 | # TODO: this assumes no empty sequences in corpus.
1306 | deterministic_steps_mask=(~is_masked(inputs)) & (inputs != 1) & (inputs != 2),
1307 | )
1308 |
1309 |
1310 | def make_an_bm_cn_plus_m(
1311 | batch_size: int, prior: float, sort_by_length: bool = False,
1312 | ) -> Corpus:
1313 | training_n_values = tuple(np.random.geometric(p=prior, size=batch_size))
1314 | training_m_values = tuple(np.random.geometric(p=prior, size=batch_size))
1315 |
1316 | training_corpus = _make_an_bm_cn_plus_m_corpus(
1317 | training_n_values, training_m_values, prior, sort_by_length
1318 | )
1319 | max_n = max(training_n_values)
1320 | max_m = max(training_m_values)
1321 | max_training_n_or_m = max(max_n, max_m)
1322 |
1323 | test_n_values, test_m_values = zip(
1324 | *itertools.product(
1325 | range(max_training_n_or_m + 1, max_training_n_or_m + 50), repeat=2
1326 | )
1327 | )
1328 |
1329 | test_corpus = _make_an_bm_cn_plus_m_corpus(
1330 | test_n_values, test_m_values, prior, sort_by_length
1331 | )
1332 | test_corpus = dataclasses.replace(test_corpus, name=f"{test_corpus.name}_test")
1333 |
1334 | logging.info(f"Created corpus {training_corpus.name}")
1335 | logging.info(f"Max n in training: {max_n}")
1336 | logging.info(f"Max m in training: {max_m}")
1337 | logging.info(f"Optimal training set D:G: {training_corpus.optimal_d_given_g:,.2f}")
1338 | logging.info(f"Optimal test set D:G: {test_corpus.optimal_d_given_g:,.2f}")
1339 | logging.info(f"Training set dimensions: {training_corpus.input_sequence.shape}")
1340 | logging.info(f"Test set dimensions: {test_corpus.input_sequence.shape}")
1341 |
1342 | return dataclasses.replace(training_corpus, test_corpus=test_corpus)
1343 |
1344 |
1345 | def _calculate_nesting_depths(corpus):
1346 | input_classes = corpus.input_sequence.argmax(axis=-1)
1347 | opening_classes = {1, 3}
1348 | closing_classes = {2, 4}
1349 | depths = []
1350 | for b in range(input_classes.shape[0]):
1351 | depth = 0
1352 | max_depth = 0
1353 | for i in range(input_classes[b].shape[0]):
1354 | if np.all(is_masked(corpus.input_sequence[b, i])):
1355 | break
1356 | if input_classes[b, i] in opening_classes:
1357 | depth += 1
1358 | max_depth = max(max_depth, depth)
1359 | elif input_classes[b, i] in closing_classes:
1360 | depths.append(depth)
1361 | depth -= 1
1362 |
1363 | depth_counts = dict(collections.Counter(depths).most_common())
1364 | max_depth = max(depths)
1365 | logging.info(f"Max depth in corpus: {max_depth}")
1366 | logging.info(f"Depth counts: {depth_counts}")
1367 |
1368 |
1369 | def _get_dyck_1_symbol_counts(corpus) -> Tuple[int, int, int]:
1370 | # Masks (nans) become 0 here after argmax, but we ignore it since we count only 1's and 2's.
1371 | input_classes = corpus.input_sequence.argmax(axis=-1)
1372 |
1373 | num_ends_of_sequence = np.sum(corpus.sample_weights).item()
1374 | num_opening_brackets = np.dot(
1375 | corpus.sample_weights, np.sum(input_classes == 1, axis=1)
1376 | ).item()
1377 | num_closing_brackets = np.dot(
1378 | corpus.sample_weights, np.sum(input_classes == 2, axis=1)
1379 | ).item()
1380 |
1381 | return num_ends_of_sequence, num_opening_brackets, num_closing_brackets
1382 |
1383 |
1384 | def get_dyck_1_target_probabs(corpus, prior) -> np.ndarray:
1385 | input_classes = np.argmax(corpus.input_sequence, axis=-1).astype(np.float64)
1386 | input_classes[~corpus.input_mask] = np.nan
1387 |
1388 | target_probabs = np.zeros_like(corpus.target_sequence)
1389 | target_probabs[~corpus.input_mask] = np.nan
1390 |
1391 | for b in range(input_classes.shape[0]):
1392 | open_brackets = 0
1393 | for i in range(input_classes.shape[1]):
1394 | if np.isnan(input_classes[b, i]):
1395 | break
1396 | if input_classes[b, i] == 0:
1397 | target_probabs[b, i, 0] = 1 - prior
1398 | target_probabs[b, i, 1] = prior
1399 | elif input_classes[b, i] == 1:
1400 | open_brackets += 1
1401 | target_probabs[b, i, 2] = 1 - prior
1402 | target_probabs[b, i, 1] = prior
1403 | elif input_classes[b, i] == 2:
1404 | open_brackets -= 1
1405 | if open_brackets == 0:
1406 | target_probabs[b, i, 0] = 1 - prior
1407 | target_probabs[b, i, 1] = prior
1408 | else:
1409 | target_probabs[b, i, 1] = prior
1410 | target_probabs[b, i, 2] = 1 - prior
1411 |
1412 | return target_probabs
1413 |
1414 |
1415 | def _make_dyck_n_corpus(
1416 | batch_size: int,
1417 | nesting_probab: float,
1418 | n: int,
1419 | max_sequence_length: Optional[int] = None,
1420 | sort_by_length: bool = False,
1421 | ):
1422 | bracket_pairs = (
1423 | ("[", "]"),
1424 | ("(", ")"),
1425 | ("{", "}"),
1426 | ("<", ">"),
1427 | ("⟦", "⟧"),
1428 | ("〔", " 〕"),
1429 | )
1430 | single_nesting_probab = nesting_probab / n
1431 |
1432 | bracket_derivations = []
1433 | for i in range(n):
1434 | bracket_derivations.append(
1435 | # e.g. `S -> ("[", S, "]", S)`.
1436 | (
1437 | (bracket_pairs[i][0], "S", bracket_pairs[i][1], "S",),
1438 | single_nesting_probab,
1439 | )
1440 | )
1441 |
1442 | pcfg = {"S": tuple(bracket_derivations) + (("", 1 - nesting_probab),)}
1443 | corpus = _make_corpus_from_pcfg(
1444 | name=f"dyck_{n}__batch_{batch_size}__p_{nesting_probab}",
1445 | pcfg=pcfg,
1446 | batch_size=batch_size,
1447 | max_derivation_length=max_sequence_length,
1448 | sort_by_length=sort_by_length,
1449 | )
1450 |
1451 | if n == 1:
1452 | (
1453 | num_ends_of_sequence,
1454 | num_opening_brackets,
1455 | num_closing_brackets,
1456 | ) = _get_dyck_1_symbol_counts(corpus)
1457 |
1458 | optimal_d_given_g = (
1459 | -1
1460 | * (
1461 | (num_ends_of_sequence * np.log2(1 - nesting_probab))
1462 | + (num_opening_brackets * np.log2(nesting_probab))
1463 | + (num_closing_brackets * np.log2(1 - nesting_probab))
1464 | ).item()
1465 | )
1466 | corpus = dataclasses.replace(corpus, optimal_d_given_g=optimal_d_given_g)
1467 |
1468 | if n == 2:
1469 | import manual_nets
1470 | import network
1471 |
1472 | stack_net = manual_nets.make_emmanuel_dyck_2_network(
1473 | nesting_probab=nesting_probab
1474 | )
1475 | stack_net = network.calculate_fitness(
1476 | stack_net, optimize_for_feeding(corpus), config=_DEFAULT_CONFIG,
1477 | )
1478 | corpus = dataclasses.replace(
1479 | corpus, optimal_d_given_g=stack_net.fitness.data_encoding_length
1480 | )
1481 | logging.info(f"Optimal |D:G|: {corpus.optimal_d_given_g:,.2f}")
1482 |
1483 | _calculate_nesting_depths(corpus)
1484 | return corpus
1485 |
1486 |
1487 | def _get_sequence_strings(corpus) -> FrozenSet[Text]:
1488 | unique_sequences = set()
1489 | for b in range(corpus.input_sequence.shape[0]):
1490 | seq = corpus.input_sequence[b]
1491 | seq_str = str(np.argmax(seq, axis=-1).tolist())
1492 | unique_sequences.add(seq_str)
1493 | return frozenset(unique_sequences)
1494 |
1495 |
1496 | def make_dyck_n(
1497 | batch_size: int,
1498 | nesting_probab: float,
1499 | n: int,
1500 | max_sequence_length: Optional[int] = None,
1501 | sort_by_length: bool = False,
1502 | ) -> Corpus:
1503 | training_corpus = _make_dyck_n_corpus(
1504 | batch_size=batch_size,
1505 | nesting_probab=nesting_probab,
1506 | n=n,
1507 | max_sequence_length=max_sequence_length,
1508 | sort_by_length=sort_by_length,
1509 | )
1510 | test_corpus = _make_dyck_n_corpus(
1511 | batch_size=50_000,
1512 | nesting_probab=nesting_probab,
1513 | n=n,
1514 | max_sequence_length=max_sequence_length,
1515 | sort_by_length=sort_by_length,
1516 | )
1517 |
1518 | training_sequences = _get_sequence_strings(training_corpus)
1519 | test_sequences = _get_sequence_strings(test_corpus)
1520 | shared = training_sequences & test_sequences
1521 |
1522 | logging.info(
1523 | f"Dyck-{n} Sequences shared between train and test: {len(shared)} ({len(shared)/len(test_sequences):.2f} of test)"
1524 | )
1525 |
1526 | return dataclasses.replace(training_corpus, test_corpus=test_corpus)
1527 |
1528 |
1529 | def split_train_test(corpus: Corpus, train_ratio: float) -> Tuple[Corpus, Corpus]:
1530 | batch_size = corpus.input_sequence.shape[0]
1531 | train_size = math.floor(train_ratio * batch_size)
1532 | shuffled_idxs = np.random.permutation(batch_size)
1533 | train_idxs = sorted(shuffled_idxs[:train_size])
1534 | test_idxs = sorted(shuffled_idxs[train_size:])
1535 | return _split_corpus(corpus, batch_idxs_per_corpus=(train_idxs, test_idxs))
1536 |
1537 |
1538 | def _split_corpus(
1539 | corpus: Corpus, batch_idxs_per_corpus: Tuple[np.ndarray, ...]
1540 | ) -> Tuple[Corpus, ...]:
1541 | new_corpora = []
1542 | for batch_idxs in batch_idxs_per_corpus:
1543 | max_sample_length = max(np.where(corpus.input_mask[batch_idxs])[-1]) + 1
1544 |
1545 | inputs = corpus.input_sequence[batch_idxs, :max_sample_length]
1546 | targets = corpus.target_sequence[batch_idxs, :max_sample_length]
1547 | target_mask = corpus.targets_mask[batch_idxs, :max_sample_length]
1548 | input_mask = corpus.input_mask[batch_idxs, :max_sample_length]
1549 |
1550 | # TODO: recalculating indices for this is hard, need to change the data format.
1551 | inputs_per_time_step = None
1552 |
1553 | if corpus.sample_weights is not None:
1554 | sample_weights = tuple(corpus.sample_weights[i] for i in batch_idxs)
1555 | else:
1556 | sample_weights = None
1557 |
1558 | new_corpora.append(
1559 | Corpus(
1560 | name=corpus.name,
1561 | input_sequence=inputs,
1562 | target_sequence=targets,
1563 | input_values_per_time_step=inputs_per_time_step,
1564 | input_mask=input_mask,
1565 | targets_mask=target_mask,
1566 | sample_weights=sample_weights,
1567 | )
1568 | )
1569 |
1570 | return tuple(new_corpora)
1571 |
1572 |
1573 | def split_to_mini_batches(
1574 | corpus: Corpus, mini_batch_size: Optional[int]
1575 | ) -> Tuple[Corpus, ...]:
1576 | if mini_batch_size is None:
1577 | return (corpus,)
1578 |
1579 | num_samples = corpus.input_sequence.shape[0]
1580 | mini_batch_idxs = tuple(
1581 | np.split(np.arange(num_samples), np.arange(0, num_samples, mini_batch_size)[1:])
1582 | )
1583 |
1584 | return _split_corpus(corpus, mini_batch_idxs)
1585 |
--------------------------------------------------------------------------------
/dfa.py:
--------------------------------------------------------------------------------
1 | import itertools
2 | import pathlib
3 | import random
4 | from typing import Dict, Set, Text
5 |
6 | import numpy as np
7 |
8 | _State = int
9 | _Label = Text
10 | _StateTransitions = Dict[_Label, _State]
11 |
12 | END_OF_SEQUENCE = "#"
13 |
14 |
15 | class DFA:
16 | def __init__(
17 | self,
18 | transitions: Dict[_State, _StateTransitions],
19 | accepting_states: Set[_State],
20 | ):
21 | self._states: Set[_State] = set(transitions.keys()) | set(
22 | itertools.chain(*[tuple(x.values()) for x in transitions.values()])
23 | )
24 | self._transitions: Dict[_State, _StateTransitions] = transitions
25 | self._accepting_states: Set[_State] = accepting_states
26 |
27 | def generate_string(self):
28 | curr_state = 0
29 | string = ""
30 | while curr_state not in self._accepting_states:
31 | char, curr_state = random.choice(
32 | tuple(self._transitions[curr_state].items())
33 | )
34 | string += char
35 | return string
36 |
37 | def visualize(self, name: Text):
38 | dot = "digraph G {\n" "colorscheme=X11\n"
39 | # inputs and outputs
40 | for state in sorted(self._states):
41 | if state in self._accepting_states:
42 | style = "peripheries=2"
43 | else:
44 | style = ""
45 | description = f'[label="q{state}" {style}]'
46 |
47 | dot += f"{state} {description}\n"
48 |
49 | for state, transitions in self._transitions.items():
50 | for label, neighbor in transitions.items():
51 | dot += f'{state} -> {neighbor} [ label="{label}" ];\n'
52 |
53 | dot += "}"
54 |
55 | path = pathlib.Path(f"dot_files/dfa_{name}.dot")
56 | path.parent.mkdir(parents=True, exist_ok=True)
57 | with path.open("w") as f:
58 | f.write(dot)
59 |
60 | def get_optimal_data_given_grammar_for_dfa(
61 | self, input_sequence: np.ndarray
62 | ) -> float:
63 | total_d_g = 0.0
64 | curr_state = 0
65 |
66 | for b in range(input_sequence.shape[0]):
67 | for i in range(input_sequence.shape[1]):
68 | curr_vec = input_sequence[b, i]
69 | if np.all(np.isnan(curr_vec)):
70 | # Sequence is masked until its end.
71 | break
72 | if curr_vec.shape[0] == 1:
73 | curr_val = curr_vec[0]
74 | else:
75 | curr_val = curr_vec.argmax()
76 |
77 | curr_transitions = self._transitions[curr_state]
78 | total_d_g += -np.log2(1 / len(curr_transitions))
79 |
80 | curr_char = {0: "0", 1: "1", 2: END_OF_SEQUENCE}[curr_val]
81 | curr_state = curr_transitions[curr_char]
82 |
83 | if curr_state in self._accepting_states:
84 | curr_state = 0
85 |
86 | return total_d_g
87 |
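def _example_dfa_usage():
    # Illustrative sketch only: a three-state DFA that generates strings of the
    # form (01)*#, e.g. "#", "01#", "0101#". State 2 is the accepting state, so
    # generate_string() stops right after emitting the end-of-sequence symbol.
    dfa = DFA(
        transitions={0: {"0": 1, END_OF_SEQUENCE: 2}, 1: {"1": 0}},
        accepting_states={2},
    )
    return dfa.generate_string()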
--------------------------------------------------------------------------------
/genetic_algorithm.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import datetime
3 | import itertools
4 | import logging
5 | import math
6 | import multiprocessing
7 | import operator
8 | import os
9 | import pathlib
10 | import pickle
11 | import queue
12 | import random
13 | import shutil
14 | import uuid
15 | from typing import FrozenSet, Iterator, List, Optional, Text, Tuple
16 |
17 | import cachetools
18 | import numpy as np
19 | from mpi4py import MPI
20 |
21 | import configuration
22 | import corpora
23 | import network
24 | import utils
25 |
26 | _DEBUG_MODE = False
27 |
28 | _FITNESS_CACHE = {}
29 | _ISLAND_MIGRATIONS_PATH = pathlib.Path("/tmp/mdlnn/island_migrations/")
30 |
31 | _GET_NET_MDL = operator.attrgetter("fitness.mdl")
32 |
33 | _Population = List[network.Network]
34 |
35 |
36 | _NETWORKS_CACHE = cachetools.LRUCache(maxsize=1_000_000)
37 |
38 | _MPI_COMMUNICATOR = MPI.COMM_WORLD
39 | _MPI_RANK = _MPI_COMMUNICATOR.Get_rank()
40 | _MPI_MIGRANTS_BUFFER_SIZE = 10_000_000
41 |
42 |
43 | @dataclasses.dataclass(frozen=True)
44 | class _Tournament:
45 | winner_idx: int
46 | loser_idx: int
47 |
48 |
49 | @cachetools.cached(_NETWORKS_CACHE, key=lambda net, corpus, config: hash(net))
50 | def _evaluate_net_cached(
51 | net: network.Network,
52 | corpus: corpora.Corpus,
53 | config: configuration.SimulationConfig,
54 | ) -> network.Fitness:
55 | return network.calculate_fitness(net, corpus, config).fitness
56 |
57 |
58 | def get_migration_path(simulation_id) -> pathlib.Path:
59 | return _ISLAND_MIGRATIONS_PATH.joinpath(simulation_id)
60 |
61 |
62 | def _make_migration_target_island_generator(
63 | island_num: int, total_islands: int,
64 | ):
65 | yield from itertools.cycle(
66 | itertools.chain(range(island_num + 1, total_islands), range(island_num))
67 | )
68 |
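def _example_migration_targets():
    # Illustrative sketch only: with 4 islands, island 1 sends its migrants in a
    # round-robin over the other islands, i.e. 2, 3, 0, 2, 3, 0, ...
    target_generator = _make_migration_target_island_generator(
        island_num=1, total_islands=4
    )
    return [next(target_generator) for _ in range(6)]  # [2, 3, 0, 2, 3, 0]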
69 |
70 | def get_migrants_through_mpi() -> Optional[_Population]:
71 | migrant_batches = []
72 |
73 | while True:
74 | has_awaiting_migrants = _MPI_COMMUNICATOR.iprobe(tag=utils.MPI_MIGRANTS_TAG)
75 | if not has_awaiting_migrants:
76 | break
77 | migrants = _MPI_COMMUNICATOR.recv(
78 | bytearray(_MPI_MIGRANTS_BUFFER_SIZE), tag=utils.MPI_MIGRANTS_TAG
79 | )
78 | migrant_batches.append(migrants)
79 |
80 | if not migrant_batches:
81 | return None
82 |
83 | return min(migrant_batches, key=_mean_population_fitness)
84 |
85 |
86 | def _get_migrants_from_file(
87 | simulation_id: Text, island_num: int, file_lock: multiprocessing.Lock
88 | ) -> Optional[_Population]:
89 | migrants_filename = get_target_island_filename(island_num)
90 | incoming_migrants_path = get_migration_path(simulation_id).joinpath(
91 | migrants_filename
92 | )
93 |
94 | lock_start = datetime.datetime.now()
95 |
96 | with file_lock:
97 | if not incoming_migrants_path.exists():
98 | return None
99 |
100 | with incoming_migrants_path.open("rb") as f:
101 | incoming_migrants = pickle.load(f)
102 | incoming_migrants_path.unlink()
103 |
104 | lock_end = datetime.datetime.now()
105 | lock_delta = lock_end - lock_start
106 | logging.info(
107 | f"Incoming lock took {lock_delta.seconds}.{str(lock_delta.microseconds)[:2]} seconds"
108 | )
109 | return incoming_migrants
110 |
111 |
112 | def get_target_island_filename(target_island: int) -> Text:
113 | return f"island_{target_island}_incoming_migrants"
114 |
115 |
116 | def _make_random_population(config, input_size, output_size) -> _Population:
117 | return [
118 | network.make_random_net(
119 | input_size=input_size,
120 | output_size=output_size,
121 | allowed_activations=config.allowed_activations,
122 | start_smooth=config.start_smooth,
123 | )
124 | for _ in range(config.population_size)
125 | ]
126 |
127 |
128 | def _mean_population_fitness(population: _Population) -> float:
129 | return np.mean([x.fitness.mdl for x in population]).item()
130 |
131 |
132 | def _should_migrate(
133 | outgoing_migrants: _Population, awaiting_migrants_at_target: _Population
134 | ) -> bool:
135 | mean_awaiting_fitness = _mean_population_fitness(awaiting_migrants_at_target)
136 | mean_outgoing_fitness = _mean_population_fitness(outgoing_migrants)
137 | logging.info(f"Awaiting mean fitness: {mean_awaiting_fitness:.2f}")
138 | logging.info(f"Outgoing mean fitness: {mean_outgoing_fitness:.2f}")
139 | return mean_outgoing_fitness < mean_awaiting_fitness
140 |
141 |
142 | def _send_migrants_through_mpi(migrants: _Population, target_island: int) -> bool:
143 | # We can't use _should_migrate() here because in MPI we can't override the target's buffer, so filtering the best migrants is done at the receiving side.
144 | _MPI_COMMUNICATOR.send(migrants, dest=target_island, tag=utils.MPI_MIGRANTS_TAG)
145 | return True
146 |
147 |
148 | def _send_migrants_through_file(
149 | migrants: _Population,
150 | target_island: int,
151 | file_lock: multiprocessing.Lock,
152 | simulation_id: Text,
153 | ) -> bool:
154 | lock_start = datetime.datetime.now()
155 | target_island_filename = get_target_island_filename(target_island)
156 | with file_lock:
157 | outgoing_path = get_migration_path(simulation_id).joinpath(
158 | target_island_filename
159 | )
160 |
161 | if outgoing_path.exists():
162 | with outgoing_path.open("rb") as f:
163 | awaiting_migrants_at_target: _Population = pickle.load(f)
164 | should_migrate = _should_migrate(
165 | outgoing_migrants=migrants,
166 | awaiting_migrants_at_target=awaiting_migrants_at_target,
167 | )
168 | else:
169 | should_migrate = True
170 |
171 | if should_migrate:
172 | with outgoing_path.open("wb") as f:
173 | pickle.dump(migrants, f)
174 |
175 | lock_end = datetime.datetime.now()
176 | lock_delta = lock_end - lock_start
177 | logging.info(
178 | f"Outgoing file lock took {lock_delta.seconds}.{str(lock_delta.microseconds)[:2]} seconds"
179 | )
180 |
181 | return should_migrate
182 |
183 |
184 | def _simulation_exists(simulation_id: Text) -> bool:
185 | return get_migration_path(simulation_id).exists()
186 |
187 |
188 | def remove_simulation_directory(simulation_id: Text):
189 | shutil.rmtree(get_migration_path(simulation_id), ignore_errors=True)
190 | logging.info(f"Removed directory {simulation_id} from local storage.")
191 |
192 |
193 | def verify_existing_simulation_override(simulation_id: Text):
194 | if _simulation_exists(simulation_id):
195 | logging.info(
196 | f"Ddirectory for simulation {simulation_id} already exists. Re-run using `--override` flag to delete the previous run.\n"
197 | )
198 | exit()
199 |
200 |
201 | def _select_best_no_repetition_by_arch_uniqueness(
202 | population: _Population, k: int
203 | ) -> _Population:
204 | return sorted(set(population), key=_GET_NET_MDL)[:k]
205 |
206 |
207 | def _select_best_no_repetition(population: _Population, k: int) -> _Population:
208 | population_fitness = set()
209 | individuals_no_repetition = []
210 | for net in population:
211 | fitness = net.fitness.mdl
212 | if fitness not in population_fitness:
213 | individuals_no_repetition.append(net)
214 | population_fitness.add(fitness)
215 | return sorted(individuals_no_repetition, key=_GET_NET_MDL)[:k]
216 |
217 |
218 | def _get_worst_individuals_idxs(population: _Population, n: int) -> List[int]:
219 | fitness = [x.fitness.mdl for x in population]
220 | return np.argsort(fitness)[-n:].tolist()
221 |
222 |
223 | def _get_elite_idxs(population: _Population, elite_ratio: float) -> FrozenSet[int]:
224 | elite_size = math.ceil(len(population) * elite_ratio)
225 | fitness = [x.fitness.mdl for x in population]
226 | argsort = np.argsort(fitness)
227 | seen = set()
228 | best_idxs = set()
229 | for i in argsort:
230 | if len(best_idxs) == elite_size:
231 | break
232 | net = population[i]
233 | if net.fitness.mdl in seen or np.isinf(net.fitness.mdl):
234 | continue
235 | seen.add(net.fitness.mdl)
236 | best_idxs.add(i)
237 | return frozenset(best_idxs)
238 |
239 |
240 | def _get_elite(elite_ratio: float, population: _Population) -> _Population:
241 | elite_size = math.ceil(len(population) * elite_ratio)
242 | return _select_best_no_repetition(population, elite_size)
243 |
244 |
245 | def _tournament_selection(population: _Population, tournament_size: int) -> _Tournament:
246 | # Returns (winner index, loser index).
247 | tournament_idxs = random.sample(range(len(population)), tournament_size)
248 | tournament_nets = tuple(population[i] for i in tournament_idxs)
249 | nets_and_idxs = tuple(zip(tournament_nets, tournament_idxs))
250 |
251 | if len(set(x.fitness.mdl for x in tournament_nets)) == 1:
252 | # MDL Tie.
253 | if np.isinf(tournament_nets[0].fitness.mdl):
254 | # Break |D:G| infinity ties using |G|.
255 | argsort_by_g = tuple(
256 | np.argsort([x.fitness.grammar_encoding_length for x in tournament_nets])
257 | )
258 | return _Tournament(
259 | winner_idx=tournament_idxs[argsort_by_g[0]], loser_idx=tournament_idxs[argsort_by_g[-1]],
260 | )
261 | return _Tournament(*tuple(random.sample(tournament_idxs, k=2)))
262 |
263 | sorted_tournament = sorted(nets_and_idxs, key=lambda x: x[0].fitness.mdl)
264 | return _Tournament(
265 | winner_idx=sorted_tournament[0][1], loser_idx=sorted_tournament[-1][1]
266 | )
267 |
268 |
269 | def _get_population_incoming_degrees(
270 | population: _Population, edge_type: int
271 | ) -> np.ndarray:
272 | degrees = []
273 | for net in population:
274 | (_, reverse_connections, _,) = network.get_connections_and_weights_by_edge_type(
275 | net, edge_type
276 | )
277 | degrees += list(map(len, reverse_connections.values()))
278 | return np.array(degrees)
279 |
280 |
281 | def _initialize_population_and_generation_from_existing_simulation(
282 | config, island_num
283 | ) -> Tuple[_Population, int]:
284 | if config.migration_channel in {"mpi", "file"}:
285 | with open(
286 | f"./generations/{config.resumed_from_simulation_id}_latest_generation_island_{island_num}.pickle",
287 | "rb",
288 | ) as f:
289 | latest_generation_data = pickle.load(f)
290 | else:
291 | raise ValueError(config.migration_channel)
292 | generation = latest_generation_data["generation"]
293 | population = latest_generation_data["population"]
294 | logging.info(f"Loaded population island {island_num} from generation {generation}")
295 | return population, generation
296 |
297 |
298 | def _initialize_population_and_generation(
299 | config: configuration.SimulationConfig,
300 | island_num: int,
301 | input_size: int,
302 | output_size: int,
303 | ) -> Tuple[_Population, int]:
304 | if config.resumed_from_simulation_id is not None:
305 | return _initialize_population_and_generation_from_existing_simulation(
306 | config, island_num
307 | )
308 |
309 | generation = 1
310 | population = _make_random_population(
311 | config=config, input_size=input_size, output_size=output_size,
312 | )
313 | logging.debug(f"Initialized random population size {config.population_size}")
314 |
315 | if _DEBUG_MODE:
316 | [
317 | network.visualize(
318 | population[i], f"random_initial_net_{i}__island_{island_num}",
319 | )
320 | for i in random.sample(range(len(population)), 10)
321 | ]
322 |
323 | return population, generation
324 |
325 |
326 | def _evaluate_population(
327 | population: _Population,
328 | corpus: corpora.Corpus,
329 | config: configuration.SimulationConfig,
330 | ) -> _Population:
331 | return [
332 | dataclasses.replace(
333 | net, fitness=_evaluate_net_cached(net=net, corpus=corpus, config=config),
334 | )
335 | for net in population
336 | ]
337 |
338 |
339 | def _make_single_reproduction(
340 | population: _Population,
341 | elite_idxs: FrozenSet[int],
342 | corpus: corpora.Corpus,
343 | config: configuration.SimulationConfig,
344 | ) -> _Population:
345 | # Select parent(s) using tournament selection, create an offspring, replace tournament loser with offspring.
346 | p = random.random()
347 | if p < config.mutation_probab:
348 | tournament = _tournament_selection(population, config.tournament_size)
349 | parent_idx = tournament.winner_idx
350 | killed_idx = tournament.loser_idx
351 | offspring = network.mutate(population[parent_idx], config=config)
352 | else:
353 | tournament = _tournament_selection(population, config.tournament_size)
354 | offspring = population[tournament.winner_idx]
355 | killed_idx = tournament.loser_idx
356 |
357 | offspring_fitness = _evaluate_net_cached(offspring, corpus, config)
358 | offspring = dataclasses.replace(offspring, fitness=offspring_fitness)
359 |
360 | if (
361 | killed_idx in elite_idxs
362 | and offspring.fitness.mdl >= population[killed_idx].fitness.mdl
363 | ):
364 | # Only kill a losing elite if the offspring is better.
365 | return population
366 |
367 | population[killed_idx] = offspring
368 | return population
369 |
370 |
371 | def _make_generation(
372 | population: _Population,
373 | corpus: corpora.Corpus,
374 | config: configuration.SimulationConfig,
375 | ) -> _Population:
376 | # Calculate elite once per generation for performance.
377 | elite_idxs = _get_elite_idxs(population, config.elite_ratio)
378 | for _ in range(len(population)):
379 | population = _make_single_reproduction(
380 | population=population, elite_idxs=elite_idxs, corpus=corpus, config=config
381 | )
382 | return population
383 |
384 |
385 | def _save_generation(
386 | generation: int,
387 | population: _Population,
388 | island_num: int,
389 | config: configuration.SimulationConfig,
390 | cloud_upload_queue: queue.Queue,
391 | ):
392 | data = {
393 | "generation": generation,
394 | "population": population,
395 | "island": island_num,
396 | }
397 | if config.migration_channel in {"mpi", "file"}:
398 | path = pathlib.Path(
399 | f"./generations/{config.simulation_id}_latest_generation_island_{island_num}.pickle"
400 | )
401 | path.parent.mkdir(parents=True, exist_ok=True)
402 | with open(path, "wb") as f:
403 | pickle.dump(data, f)
404 |
405 |
406 | def _log_generation_to_logging_process(
407 | island_num: int,
408 | generation: int,
409 | best_net: network.Network,
410 | corpus: corpora.Corpus,
411 | config: configuration.SimulationConfig,
412 | logging_queue: multiprocessing.Queue,
413 | ):
414 | if config.migration_channel == "mpi":
415 | _MPI_COMMUNICATOR.Send(
416 | np.array(
417 | [
418 | island_num,
419 | generation,
420 | best_net.fitness.mdl,
421 | best_net.fitness.grammar_encoding_length,
422 | best_net.fitness.data_encoding_length,
423 | network.get_num_units(best_net),
424 | network.get_total_connections(best_net, include_biases=True),
425 | ]
426 | ),
427 | dest=config.num_islands,
428 | tag=utils.MPI_LOGGING_TAG,
429 | )
430 | return
431 |
432 | stats = {
433 | "island": island_num,
434 | "generation": generation,
435 | "mdl": best_net.fitness.mdl,
436 | "|g|": best_net.fitness.grammar_encoding_length,
437 | "|d:g|": best_net.fitness.data_encoding_length,
438 | "units": network.get_num_units(best_net),
439 | "connections": network.get_total_connections(best_net, include_biases=True),
440 | "accuracy": best_net.fitness.accuracy,
441 | }
442 |
443 | logging_queue.put({"best_net": best_net, "stats": stats})
444 |
445 |
446 | def _log_generation(
447 | population: _Population,
448 | corpus: corpora.Corpus,
449 | config: configuration.SimulationConfig,
450 | generation: int,
451 | island_num: int,
452 | generation_time_delta: datetime.timedelta,
453 | logging_queue: multiprocessing.Queue,
454 | cloud_upload_queue: queue.Queue,
455 | ):
456 | all_fitness = [x.fitness.mdl for x in population]
457 | valid_population = [
458 | population[i] for i in range(len(population)) if not np.isinf(all_fitness[i])
459 | ]
460 |
461 | if valid_population:
462 | best_net_idx = int(np.argmin(all_fitness))
463 | else:
464 | best_net_idx = 0
465 | valid_population = [population[0]]
466 |
467 | best_net = population[best_net_idx]
468 | best_fitness = all_fitness[best_net_idx]
469 |
470 | valid_fitnesses = [x.fitness.mdl for x in valid_population]
471 | mean_fitness = np.mean(valid_fitnesses)
472 | fitness_std = np.std(valid_fitnesses)
473 |
474 | num_connections = [
475 | network.get_total_connections(x, include_biases=True) for x in population
476 | ]
477 | num_connections_mean = np.mean(num_connections)
478 | num_connections_std = np.std(num_connections)
479 | num_connections_max = np.max(num_connections)
480 |
481 | num_units = [network.get_num_units(x) for x in population]
482 | num_units_mean = np.mean(num_units)
483 | num_units_std = np.std(num_units)
484 | num_units_max = np.max(num_units)
485 |
486 | incoming_forward_degrees = _get_population_incoming_degrees(
487 | population=population, edge_type=network.FORWARD_CONNECTION
488 | )
489 | multiple_inputs_forward_degrees = incoming_forward_degrees[
490 | incoming_forward_degrees > 1
491 | ]
492 | incoming_recurrent_degrees = _get_population_incoming_degrees(
493 | population=population, edge_type=network.RECURRENT_CONNECTION
494 | )
495 | multiple_inputs_recurrent_degrees = incoming_recurrent_degrees[
496 | incoming_recurrent_degrees > 1
497 | ]
498 |
499 | g_s = [x.fitness.grammar_encoding_length for x in population]
500 | mean_g = np.mean(g_s)
501 | std_g = np.std(g_s)
502 | max_g = np.max(g_s)
503 | d_g_s = [x.fitness.data_encoding_length for x in valid_population]
504 | mean_d_g = np.mean(d_g_s)
505 | std_d_g = np.std(d_g_s)
506 | max_d_g = np.max(d_g_s)
507 | mean_accuracy = np.mean([x.fitness.accuracy for x in valid_population])
508 |
509 | all_weights = []
510 | for x in population:
511 | all_weights.append(network.get_forward_weights(x))
512 | all_weights.append(network.get_recurrent_weights(x))
513 |
514 | all_weights = np.concatenate(all_weights)
515 | mean_weight = np.mean(all_weights)
516 | max_weight = np.max(all_weights)
517 |
518 | num_invalid = len(list(filter(np.isinf, [x.fitness.mdl for x in population],)))
519 | invalid_ratio = num_invalid / len(population)
520 |
521 | unique_ratio = len(set(population)) / len(population)
522 |
523 | logging.info(
524 | f"\nIsland {island_num} (pid {os.getpid()}) Generation {generation}"
525 | f"\n\tGeneration took {generation_time_delta.seconds}.{str(generation_time_delta.microseconds)[:2]} seconds"
526 | f"\n\tMean fitness: {mean_fitness:.2f} (±{fitness_std:.2f}, worst valid {np.max(valid_fitnesses):.2f}) \tBest fitness: {best_fitness:.2f}"
527 | f"\n\tMean num nodes: {num_units_mean:.2f} (±{num_units_std:.2f}, max {num_units_max}) \tMean num connections: {num_connections_mean:.2f} (±{num_connections_std:.2f}, max {num_connections_max}) \tMean G: {mean_g:.2f} (±{std_g:.2f}, max {max_g:.2f})\tMean D:G: {mean_d_g:.2f} (±{std_d_g:.2f}, max {max_d_g:.2f})"
528 | f"\n\tMean forward in degree: {np.mean(incoming_forward_degrees):.2f} (±{np.std(incoming_forward_degrees):.2f}, max {np.max(incoming_forward_degrees)}) \tMean recurrent in degree: {np.mean(incoming_recurrent_degrees):.2f} (±{np.std(incoming_recurrent_degrees):.2f}, max {np.max(incoming_recurrent_degrees) if incoming_recurrent_degrees.size else '-'})"
529 | f"\n\tMean forward in degree>1: {np.mean(multiple_inputs_forward_degrees):.2f} (±{np.std(multiple_inputs_forward_degrees):.2f}) \tMean recurrent in degree>1: {np.mean(multiple_inputs_recurrent_degrees):.2f} (±{np.std(multiple_inputs_recurrent_degrees):.2f})"
530 | f"\n\tMean weight: {mean_weight:.2f} (max {max_weight})\tMean accuracy: {mean_accuracy:.2f}\tInvalid: {invalid_ratio*100:.1f}%\tUnique: {unique_ratio*100:.1f}%"
531 | f"\n\tBest network:\n\t{network.to_string(best_net)}\n\n"
532 | )
533 |
534 | if generation == 1 or generation % 100 == 0:
535 | network_filename = f"{config.simulation_id}__island_{island_num}__best_network"
536 | network.visualize(best_net, network_filename, class_to_label=corpus.vocabulary)
537 | network.save(best_net, network_filename)
538 |
539 | _log_generation_to_logging_process(
540 | island_num=island_num,
541 | generation=generation,
542 | best_net=best_net,
543 | corpus=corpus,
544 | config=config,
545 | logging_queue=logging_queue,
546 | )
547 |
548 | if generation == 1 or generation % config.generation_dump_interval == 0:
549 | _save_generation(
550 | generation=generation,
551 | population=population,
552 | island_num=island_num,
553 | config=config,
554 | cloud_upload_queue=cloud_upload_queue,
555 | )
556 |
557 | if _DEBUG_MODE and generation > 0 and generation % 5 == 0:
558 | [
559 | network.visualize(
560 | population[x],
561 | f"random_gen_{generation}__island_{island_num}__{str(uuid.uuid1())}",
562 | )
563 | for x in random.sample(range(len(population)), 5)
564 | ]
565 |
566 |
567 | def _send_migrants(
568 | population: _Population,
569 | island_num: int,
570 | config: configuration.SimulationConfig,
571 | target_island: int,
572 | target_process_lock: multiprocessing.Lock,
573 | cloud_upload_queue: queue.Queue,
574 | ):
575 | num_migrants = math.floor(config.migration_ratio * config.population_size)
576 | migrants = list(
577 | set(
578 | population[
579 | _tournament_selection(
580 | population, tournament_size=config.tournament_size
581 | ).winner_idx
582 | ]
583 | for _ in range(num_migrants)
584 | )
585 | )
586 | if config.migration_channel == "file":
587 | did_send = _send_migrants_through_file(
588 | migrants=migrants,
589 | target_island=target_island,
590 | simulation_id=config.simulation_id,
591 | file_lock=target_process_lock,
592 | )
593 | elif config.migration_channel == "mpi":
594 | did_send = _send_migrants_through_mpi(
595 | migrants=migrants, target_island=target_island,
596 | )
597 | else:
598 | raise ValueError(config.migration_channel)
599 |
600 | if did_send:
601 | best_sent = min(migrants, key=_GET_NET_MDL)
602 | logging.info(
603 | f"Island {island_num} sent {len(migrants)} migrants to island {target_island} through {config.migration_channel}. Best sent: {best_sent.fitness.mdl:,.2f}"
604 | )
605 |
606 |
607 | def _integrate_migrants(
608 | incoming_migrants: _Population,
609 | population: _Population,
610 | config: configuration.SimulationConfig,
611 | island_num: int,
612 | ) -> _Population:
613 | losing_idxs = tuple(
614 | _tournament_selection(population, config.tournament_size).loser_idx
615 | for _ in range(len(incoming_migrants))
616 | )
617 | prev_best_fitness = min([x.fitness.mdl for x in population])
618 | for migrant_idx, local_idx in enumerate(losing_idxs):
619 | population[local_idx] = incoming_migrants[migrant_idx]
620 | new_best_fitness = min([x.fitness.mdl for x in population])
621 | logging.info(
622 | f"Island {island_num} got {len(incoming_migrants)} incoming migrants. Previous best fitness: {prev_best_fitness:.2f}, new best: {new_best_fitness:.2f}"
623 | )
624 | return population
625 |
626 |
627 | def _receive_and_integrate_migrants(
628 | population, config, island_num, process_lock
629 | ) -> _Population:
630 | if config.migration_channel == "file":
631 | incoming_migrants = _get_migrants_from_file(
632 | simulation_id=config.simulation_id,
633 | island_num=island_num,
634 | file_lock=process_lock,
635 | )
636 | elif config.migration_channel == "mpi":
637 | incoming_migrants = get_migrants_through_mpi()
638 | else:
639 | raise ValueError(config.migration_channel)
640 |
641 | if incoming_migrants is None:
642 | logging.info(f"Island {island_num} has no incoming migrants waiting.")
643 | return population
644 |
645 | return _integrate_migrants(
646 | incoming_migrants=incoming_migrants,
647 | population=population,
648 | config=config,
649 | island_num=island_num,
650 | )
651 |
652 |
653 | def _make_migration(
654 | population: _Population,
655 | island_num: int,
656 | config: configuration.SimulationConfig,
657 | migration_target_generator: Iterator[int],
658 | process_locks: Tuple[multiprocessing.Lock, ...],
659 | cloud_upload_queue: queue.Queue,
660 | ) -> _Population:
661 | target_island = next(migration_target_generator)
662 | _send_migrants(
663 | population=population,
664 | island_num=island_num,
665 | config=config,
666 | target_island=target_island,
667 | target_process_lock=process_locks[target_island] if process_locks else None,
668 | cloud_upload_queue=cloud_upload_queue,
669 | )
670 | return _receive_and_integrate_migrants(
671 | population,
672 | config,
673 | island_num,
674 | process_locks[island_num] if process_locks else None,
675 | )
676 |
677 |
678 | def run(
679 | island_num: int,
680 | corpus: corpora.Corpus,
681 | config: configuration.SimulationConfig,
682 | process_locks: Tuple[multiprocessing.Lock, ...],
683 | logging_queue: multiprocessing.Queue,
684 | ):
685 | seed = config.seed + island_num
686 | utils.seed(seed)
687 | logging.info(f"Island {island_num}, seed {seed}")
688 |
689 | population, generation = _initialize_population_and_generation(
690 | config,
691 | island_num,
692 | input_size=corpus.input_sequence.shape[-1],
693 | output_size=corpus.target_sequence.shape[-1],
694 | )
695 | population = _evaluate_population(population, corpus, config)
696 | migration_target_generator = _make_migration_target_island_generator(
697 | island_num=island_num, total_islands=config.num_islands
698 | )
699 |
700 | cloud_upload_queue = queue.Queue()
701 |
702 | stopwatch_start = datetime.datetime.now()
703 | while generation <= config.num_generations:
704 | generation_start_time = datetime.datetime.now()
705 | population = _make_generation(population, corpus, config)
706 | generation_time_delta = datetime.datetime.now() - generation_start_time
707 |
708 | time_delta = datetime.datetime.now() - stopwatch_start
709 | if config.num_islands > 1 and (
710 | time_delta.total_seconds() >= config.migration_interval_seconds
711 | or generation % config.migration_interval_generations == 0
712 | ):
713 | logging.info(
714 | f"Island {island_num} performing migration, time passed {time_delta.total_seconds()} seconds."
715 | )
716 | population = _make_migration(
717 | population=population,
718 | island_num=island_num,
719 | config=config,
720 | migration_target_generator=migration_target_generator,
721 | process_locks=process_locks,
722 | cloud_upload_queue=cloud_upload_queue,
723 | )
724 | stopwatch_start = datetime.datetime.now()
725 |
726 | _log_generation(
727 | population=population,
728 | corpus=corpus,
729 | config=config,
730 | generation=generation,
731 | island_num=island_num,
732 | generation_time_delta=generation_time_delta,
733 | logging_queue=logging_queue,
734 | cloud_upload_queue=cloud_upload_queue,
735 | )
736 | generation += 1
737 |
738 | population = _make_migration(
739 | population=population,
740 | island_num=island_num,
741 | config=config,
742 | migration_target_generator=migration_target_generator,
743 | process_locks=process_locks,
744 | cloud_upload_queue=cloud_upload_queue,
745 | )
746 |
747 | best_network = min(population, key=_GET_NET_MDL)
748 | return best_network
749 |
--------------------------------------------------------------------------------
/island.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import json
3 | import logging
4 | import multiprocessing
5 | import pickle
6 | import queue
7 | from datetime import datetime
8 | from typing import Dict, Optional, Text, Tuple
9 |
10 | import numpy as np
11 | from mpi4py import MPI
12 |
13 | import configuration
14 | import corpora
15 | import genetic_algorithm
16 | import network
17 | import utils
18 |
19 |
20 | _MPI_COMMUNICATOR = MPI.COMM_WORLD
21 | _MPI_RANK = _MPI_COMMUNICATOR.Get_rank()
22 |
23 | _MPI_LOGGING_BUFFER_SIZE = 5_000_000
24 |
25 |
26 | def _island_process(
27 | island_num: int,
28 | config: configuration.SimulationConfig,
29 | corpus: corpora.Corpus,
30 | result_queue: multiprocessing.Queue,
31 | logging_queue: multiprocessing.Queue,
32 | process_locks: Tuple[multiprocessing.Lock, ...],
33 | ):
34 | try:
35 | result_network = genetic_algorithm.run(
36 | island_num=island_num,
37 | corpus=corpus,
38 | config=config,
39 | logging_queue=logging_queue,
40 | process_locks=process_locks,
41 | )
42 | result_queue.put((island_num, result_network))
43 | except:
44 | logging.exception(f"Exception in island {island_num}:")
45 |
46 |
47 | class _DummyProcess:
48 | # Used for profiling without multiprocessing.
49 | def __init__(self, target, kwargs):
50 | self._target = target
51 | self._kwargs = kwargs
52 | self.__setattr__ = lambda x: x
53 |
54 | def start(self):
55 | self._target(**self._kwargs)
56 |
57 |
58 | def _log_net_stats(
59 | stats: Dict,
60 | config: configuration.SimulationConfig,
61 | train_corpus: corpora.Corpus,
62 | test_corpus: Optional[corpora.Corpus],
63 | best_net: network.Network,
64 | status: Text,
65 | ):
66 | stats.update(
67 | {
68 | "simulation": config.simulation_id,
69 | "simulation id": config.simulation_id,
70 | "configuration": json.dumps(config.__dict__),
71 | "params": config.comment,
72 | "status": status,
73 | "best island": stats["island"],
74 | "last update": datetime.now().isoformat(),
75 | }
76 | )
77 |
78 | log_text = (
79 | f"Current best net, island {stats['island']}, generation {stats['generation']}: "
80 | f"MDL = {stats['mdl']:,.2f}, |G| = {stats['|g|']}, |D:G| = {stats['|d:g|']:,.2f}, units: {stats['units']}, connections: {stats['connections']}."
81 | )
82 |
83 | train_num_chars = corpora.get_num_chars_in_corpus(train_corpus)
84 | average_train_dg_per_char = best_net.fitness.data_encoding_length / train_num_chars
85 | stats.update(
86 | {
87 | "average train d:g per character": f"{average_train_dg_per_char:.2f}",
88 | "training set num. chars": f"{train_num_chars}",
89 | }
90 | )
91 |
92 | if train_corpus.deterministic_steps_mask is not None:
93 | train_deterministic_accuracy = network.calculate_deterministic_accuracy(
94 | best_net, train_corpus, config
95 | )
96 | stats.update({"deterministic accuracy": train_deterministic_accuracy})
97 | log_text += (
98 | f" Training set deterministic accuracy: {train_deterministic_accuracy:,.2f}"
99 | )
100 |
101 | if test_corpus is not None:
102 | test_net = network.invalidate_fitness(best_net)
103 | test_net = network.calculate_fitness(
104 | test_net, corpus=test_corpus, config=config
105 | )
106 |
107 | test_num_chars = corpora.get_num_chars_in_corpus(test_corpus)
108 | average_test_dg_per_char = (
109 | test_net.fitness.data_encoding_length / test_num_chars
110 | )
111 |
112 | stats.update(
113 | {
114 | "test set d:g": f"{test_net.fitness.data_encoding_length:.2f}",
115 | "test set accuracy": f"{test_net.fitness.accuracy}",
116 | "average test d:g per character": f"{average_test_dg_per_char:.2f}",
117 | }
118 | )
119 | log_text += f" Test set |D:G|: {test_net.fitness.data_encoding_length:,.2f}."
120 |
121 | if test_corpus.deterministic_steps_mask is not None:
122 | test_deterministic_accuracy = network.calculate_deterministic_accuracy(
123 | test_net, test_corpus, config
124 | )
125 | stats.update(
126 | {"test set deterministic accuracy": test_deterministic_accuracy}
127 | )
128 | log_text += (
129 | f" Test set deterministic accuracy: {test_deterministic_accuracy:,.2f}"
130 | )
131 |
132 | logging.info(log_text)
133 |
134 | current_best_net_filename = f"{config.simulation_id}__current_best"
135 | network.save(best_net, current_best_net_filename)
136 | network.visualize(
137 | best_net,
138 | current_best_net_filename,
139 | class_to_label=test_corpus.vocabulary if test_corpus else None,
140 | )
141 |
142 |
143 | def _mpi_logging_worker(
144 | config: configuration.SimulationConfig,
145 | train_corpus: corpora.Corpus,
146 | test_corpus: Optional[corpora.Corpus],
147 | ):
148 | best_mdl = float("inf")
149 | buffer = np.empty(7)
150 | logging.info(f"Started MPI logging worker, rank {_MPI_RANK}")
151 | while True:
152 | _MPI_COMMUNICATOR.Recv(buffer, tag=utils.MPI_LOGGING_TAG)
153 | (
154 | island_num,
155 | generation,
156 | mdl,
157 | grammar_encoding_length,
158 | data_encoding_length,
159 | num_units,
160 | connections,
161 | ) = buffer
162 | island_num = int(island_num)
163 | if mdl < best_mdl:
164 | best_mdl = mdl
165 |
166 | with open(
167 | f"./networks/{config.simulation_id}__island_{island_num}__best_network.pickle",
168 | "rb",
169 | ) as f:
170 | best_net = pickle.load(f)
171 |
172 | best_island_stats = {
173 | "island": island_num,
174 | "generation": int(generation),
175 | "mdl": mdl,
176 | "|g|": grammar_encoding_length,
177 | "|d:g|": data_encoding_length,
178 | "units": int(num_units),
179 | "connections": int(connections),
180 | "accuracy": best_net.fitness.accuracy,
181 | }
182 |
183 | _log_net_stats(
184 | stats=best_island_stats,
185 | config=config,
186 | train_corpus=train_corpus,
187 | test_corpus=test_corpus,
188 | best_net=best_net,
189 | status="Running",
190 | )
191 |
192 |
193 | def _queue_logging_worker(
194 | logging_queue: multiprocessing.Queue,
195 | config: configuration.SimulationConfig,
196 | train_corpus: corpora.Corpus,
197 | test_corpus: Optional[corpora.Corpus],
198 | ):
199 | best_net_fitness = float("inf")
200 |
201 | while True:
202 | island_num_to_data = {}
203 | start_time = datetime.now()
204 | while (datetime.now() - start_time).total_seconds() < 60:
205 | try:
206 | data = logging_queue.get(timeout=5)
207 | island_num = data["stats"]["island"]
208 | island_num_to_data[island_num] = data
209 | except queue.Empty:
210 | pass
211 |
212 | best_island_data = None
213 | for data in island_num_to_data.values():
214 | stats = data["stats"]
215 | if stats["mdl"] <= best_net_fitness:
216 | best_net_fitness = stats["mdl"]
217 | best_island_data = data
218 |
219 | if best_island_data is None:
220 | continue
221 |
222 | _log_net_stats(
223 | stats=best_island_data["stats"],
224 | config=config,
225 | train_corpus=train_corpus,
226 | test_corpus=test_corpus,
227 | best_net=best_island_data["best_net"],
228 | status="Running",
229 | )
230 |
231 |
232 | def _create_csv_and_spreadsheet_entry(
233 | config: configuration.SimulationConfig, train_corpus, test_corpus
234 | ):
235 | data = {
236 | "simulation": config.simulation_id,
237 | "simulation id": config.simulation_id,
238 | "params": config.comment,
239 | "configuration": json.dumps(config.__dict__),
240 | "status": "Running",
241 | "started": datetime.now().isoformat(),
242 | "seed": config.seed,
243 | }
244 |
245 | train_num_chars = corpora.get_num_chars_in_corpus(train_corpus)
246 | data.update(
247 | {"training set num. chars": f"{train_num_chars}",}
248 | )
249 | if train_corpus.optimal_d_given_g is not None:
250 | data.update(
251 | {
252 | "training set optimal d:g": f"{train_corpus.optimal_d_given_g:.2f}",
253 | "optimal average train d:g per character": f"{train_corpus.optimal_d_given_g / train_num_chars:.2f}",
254 | }
255 | )
256 |
257 | if test_corpus is not None:
258 | test_num_chars = corpora.get_num_chars_in_corpus(test_corpus)
259 | data.update(
260 | {
261 | "test set params": f"Input shape: {test_corpus.input_sequence.shape}. Output shape: {test_corpus.target_sequence.shape}",
262 | "test set num. chars": f"{test_num_chars}",
263 | }
264 | )
265 | if test_corpus.optimal_d_given_g is not None:
266 | data.update(
267 | {
268 | "test set optimal d:g": f"{test_corpus.optimal_d_given_g:.2f}",
269 | "optimal average test d:g per character": f"{test_corpus.optimal_d_given_g / test_num_chars:.2f}",
270 | }
271 | )
272 |
273 |
274 | def _init_islands(
275 | first_island: int,
276 | last_island: int,
277 | config: configuration.SimulationConfig,
278 | train_corpus,
279 | result_queue: multiprocessing.Queue,
280 | logging_queue: multiprocessing.Queue,
281 | process_locks: Tuple[multiprocessing.Lock, ...],
282 | ) -> Tuple[multiprocessing.Process, ...]:
283 | processes = []
284 | for i in range(first_island, last_island + 1):
285 | if config.parallelize:
286 | process_class = multiprocessing.Process
287 | else:
288 | process_class = _DummyProcess
289 |
290 | p = process_class(
291 | target=_island_process,
292 | kwargs={
293 | "island_num": i,
294 | "config": config,
295 | "corpus": train_corpus,
296 | "result_queue": result_queue,
297 | "logging_queue": logging_queue,
298 | "process_locks": process_locks,
299 | },
300 | )
301 | p.daemon = True
302 | processes.append(p)
303 |
304 | return tuple(processes)
305 |
306 |
307 | def _get_island_results_from_queue_or_mpi(result_queue: multiprocessing.Queue):
308 | if result_queue is not None:
309 | return result_queue.get()
310 | return _MPI_COMMUNICATOR.recv(tag=utils.MPI_RESULTS_TAG)
311 |
312 |
313 | def _collect_results(
314 | num_islands_to_collect: int,
315 | result_queue: multiprocessing.Queue,
316 | config: configuration.SimulationConfig,
317 | train_corpus: corpora.Corpus,
318 | test_corpus: Optional[corpora.Corpus],
319 | ):
320 | result_network_per_island = {}
321 | while len(result_network_per_island) < num_islands_to_collect:
322 | island_num, result_net = _get_island_results_from_queue_or_mpi(
323 | result_queue=result_queue
324 | )
325 | result_network_per_island[island_num] = result_net
326 | logging.info(f"Island {island_num} best network:\n{str(result_net)}")
327 | logging.info(
328 | f"{len(result_network_per_island)}/{num_islands_to_collect} (global {config.num_islands}) islands done."
329 | )
330 |
331 | best_island, best_net = min(
332 | result_network_per_island.items(), key=lambda x: x[1].fitness.mdl
333 | )
334 | network_filename = f"{config.simulation_id}__best_network"
335 | network.visualize(
336 | best_net, network_filename, class_to_label=train_corpus.vocabulary
337 | )
338 | network.save(best_net, network_filename)
339 |
340 | logging.info(f"Best network of all islands:\n{str(best_net)}")
341 |
342 | csv_data = {
343 | "simulation id": config.simulation_id,
344 | "island": best_island,
345 | "mdl": best_net.fitness.mdl,
346 | "|g|": best_net.fitness.grammar_encoding_length,
347 | "|d:g|": best_net.fitness.data_encoding_length,
348 | "units": network.get_num_units(best_net),
349 | "connections": network.get_total_connections(best_net, include_biases=True),
350 | "accuracy": best_net.fitness.accuracy,
351 | "generation": config.num_generations,
352 | "finished": datetime.now().isoformat(),
353 | }
354 | _log_net_stats(
355 | stats=csv_data,
356 | config=config,
357 | train_corpus=train_corpus,
358 | test_corpus=test_corpus,
359 | best_net=best_net,
360 | status="Done",
361 | )
362 |
363 |
364 | def _run_mpi_island(island_num, corpus, config):
365 | result_network = genetic_algorithm.run(
366 | island_num=island_num,
367 | corpus=corpus,
368 | config=config,
369 | logging_queue=None,
370 | process_locks=None,
371 | )
372 | _MPI_COMMUNICATOR.send(
373 | (island_num, result_network),
374 | dest=config.num_islands + 1,
375 | tag=utils.MPI_RESULTS_TAG,
376 | )
377 | while True:
378 | # TODO: keep draining incoming migrants so the MPI receive buffer doesn't fill up; find a better solution.
379 | genetic_algorithm.get_migrants_through_mpi()
380 |
381 |
382 | def run(
383 | corpus: corpora.Corpus,
384 | config: configuration.SimulationConfig,
385 | first_island: int,
386 | last_island: int,
387 | ):
388 | train_corpus = corpora.optimize_for_feeding(
389 | dataclasses.replace(corpus, test_corpus=None)
390 | )
391 | test_corpus = (
392 | corpora.optimize_for_feeding(corpus.test_corpus) if corpus.test_corpus else None
393 | )
394 |
395 | genetic_algorithm.verify_existing_simulation_override(
396 | simulation_id=config.simulation_id,
397 | )
398 |
399 | logging.info(f"Starting simulation {config.simulation_id}\n")
400 | logging.info(f"Config: {config}\n")
401 | logging.info(
402 | f"Running islands {first_island}-{last_island} ({last_island-first_island+1}/{config.num_islands})"
403 | )
404 |
405 | if config.migration_channel == "file":
406 | genetic_algorithm.get_migration_path(config.simulation_id).mkdir(
407 | parents=True, exist_ok=True
408 | )
409 |
410 | if first_island == 0:
411 | _create_csv_and_spreadsheet_entry(config, train_corpus, test_corpus)
412 |
413 | result_queue = multiprocessing.Queue()
414 | logging_queue = multiprocessing.Queue()
415 | process_locks = tuple(multiprocessing.Lock() for _ in range(config.num_islands))
416 |
417 | island_processes = _init_islands(
418 | first_island=first_island,
419 | last_island=last_island,
420 | config=config,
421 | train_corpus=train_corpus,
422 | result_queue=result_queue,
423 | logging_queue=logging_queue,
424 | process_locks=process_locks,
425 | )
426 |
427 | for p in island_processes:
428 | p.start()
429 |
430 | logging_process = multiprocessing.Process(
431 | target=_queue_logging_worker,
432 | args=(logging_queue, config, train_corpus, test_corpus),
433 | )
434 | logging_process.daemon = True
435 | logging_process.start()
436 |
437 | _collect_results(
438 | num_islands_to_collect=len(island_processes),
439 | result_queue=result_queue,
440 | config=config,
441 | train_corpus=train_corpus,
442 | test_corpus=test_corpus,
443 | )
444 |
445 | elif config.migration_channel == "mpi":
446 | if _MPI_RANK < config.num_islands:
447 | # Ranks [0, num_islands - 1] = islands
448 | _run_mpi_island(
449 | island_num=_MPI_RANK, corpus=train_corpus, config=config,
450 | )
451 | elif _MPI_RANK == config.num_islands:
452 | # Rank num_islands = logger
453 | _create_csv_and_spreadsheet_entry(config, train_corpus, test_corpus)
454 | _mpi_logging_worker(config, train_corpus, test_corpus)
455 | elif _MPI_RANK == config.num_islands + 1:
456 | # Rank num_island + 1 = results collector
457 | logging.info(f"Starting results collection, rank {_MPI_RANK}")
458 | _collect_results(
459 | num_islands_to_collect=config.num_islands,
460 | result_queue=None,
461 | config=config,
462 | train_corpus=train_corpus,
463 | test_corpus=test_corpus,
464 | )
465 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import json
3 |
4 | import configuration
5 | import genetic_algorithm
6 | import island
7 | import simulations
8 | import utils
9 |
10 | utils.setup_logging()
11 |
12 |
13 | _NUM_REPRODUCTIONS = 1
14 |
15 |
16 | def _make_corpus(factory, args, corpus_seed):
17 | utils.seed(corpus_seed)
18 | return factory(**args)
19 |
20 |
21 | def run():
22 | arg_parser = utils.make_cli_arguments()
23 | arguments = arg_parser.parse_args()
24 |
25 | simulation = simulations.SIMULATIONS[arguments.simulation_name]
26 |
27 | for i in range(_NUM_REPRODUCTIONS):
28 | simulation_seed = arguments.base_seed + i
29 |
30 | base_config = configuration.SimulationConfig(
31 | seed=simulation_seed,
32 | simulation_id=arguments.simulation_name,
33 | num_islands=arguments.total_islands,
34 | **{**simulations.DEFAULT_CONFIG, **simulation.get("config", {})},
35 | )
36 |
37 | corpus_args = simulation["corpus"]["args"]
38 | if arguments.corpus_args is not None:
39 | corpus_args.update(json.loads(arguments.corpus_args))
40 |
41 | corpus = _make_corpus(
42 | factory=simulation["corpus"]["factory"],
43 | args=corpus_args,
44 | corpus_seed=base_config.corpus_seed,
45 | )
46 |
47 | simulation_config = dataclasses.replace(
48 | base_config,
49 | comment=f"Corpus params: {json.dumps(simulation['corpus']['args'])}, input shape: {corpus.input_sequence.shape}. Output shape: {corpus.target_sequence.shape}",
50 | simulation_id=corpus.name,
51 | resumed_from_simulation_id=arguments.resumed_simulation_id,
52 | )
53 | simulation_config = utils.add_hash_to_simulation_id(simulation_config, corpus)
54 |
55 | if arguments.override_existing:
56 | genetic_algorithm.remove_simulation_directory(
57 | simulation_id=simulation_config.simulation_id,
58 | )
59 |
60 | utils.seed(simulation_seed)
61 |
62 | island.run(
63 | corpus=corpus,
64 | config=simulation_config,
65 | first_island=arguments.first_island
66 | if arguments.first_island is not None
67 | else 0,
68 | last_island=(
69 | arguments.last_island
70 | if arguments.last_island is not None
71 | else simulation_config.num_islands - 1
72 | ),
73 | )
74 |
75 |
76 | if __name__ == "__main__":
77 | run()
78 |
--------------------------------------------------------------------------------
/manual_nets.py:
--------------------------------------------------------------------------------
1 | import network
2 |
3 |
4 | def make_emmanuel_dyck_2_network(nesting_probab: float):
5 | opening_bracket_output_bias = nesting_probab / (2 * (1 - nesting_probab))
6 | return network.make_custom_network(
7 | input_size=5,
8 | output_size=5,
9 | num_units=17,
10 | forward_weights={
11 | 1: ((3, 1, 2, 1),),
12 | 2: ((12, -1, 1, 1), (13, 1, 1, 1)),
13 | 3: ((10, 1, 1, 1),),
14 | 4: ((2, 1, 1, 1),),
15 | 5: ((9, -1, 1, 1),),
16 | 6: ((8, 1, 1, 1),),
17 | 7: ((9, -1, 1, 1),),
18 | 10: ((11, 1, 1, 1),),
19 | 11: ((15, 1, 1, 1),),
20 | 12: ((11, 1, 1, 1),),
21 | 13: ((15, 1, 1, 1),),
22 | 14: ((13, 1, 1, 1),),
23 | 15: ((16, 1, 1, 1),),
24 | 16: ((5, -1, 1, 1), (7, 1, 1, 1)),
25 | },
26 | recurrent_weights={15: ((10, 1, 3, 1), (14, 1, 1, 3))},
27 | unit_types={11: network.MULTIPLICATION_UNIT, 13: network.MULTIPLICATION_UNIT,},
28 | biases={5: 1, 6: opening_bracket_output_bias, 7: -1, 9: 1, 12: 1},
29 | activations={
30 | 5: network.UNSIGNED_STEP,
31 | 7: network.UNSIGNED_STEP,
32 | 14: network.FLOOR,
33 | 16: network.MODULO_3,
34 | },
35 | )
36 |
37 |
38 | def make_emmanuel_dyck_2_network_io_protection(nesting_probab: float):
39 | opening_bracket_output_bias = nesting_probab / (2 * (1 - nesting_probab))
40 | return network.make_custom_network(
41 | input_size=5,
42 | output_size=5,
43 | num_units=23,
44 | forward_weights={
45 | 1: ((18, 1, 2, 1),),
46 | 2: ((17, 1, 1, 1),),
47 | 3: ((18, 1, 1, 1),),
48 | 4: ((17, 1, 1, 1),),
49 | 10: ((11, 1, 1, 1),),
50 | 11: ((15, 1, 1, 1),),
51 | 12: ((11, 1, 1, 1),),
52 | 13: ((15, 1, 1, 1),),
53 | 14: ((13, 1, 1, 1),),
54 | 15: ((16, 1, 1, 1),),
55 | 16: ((19, -1, 1, 1), (20, 1, 1, 1)),
56 | 17: ((12, -1, 1, 1), (13, 1, 1, 1)),
57 | 18: ((10, 1, 1, 1),),
58 | 19: ((21, -1, 1, 1), (20, 1, 1, 1), (5, 1, 1, 1)),
59 | 20: ((7, 1, 1, 1), (21, -1, 1, 1)),
60 | 21: ((9, 1, 1, 1),),
61 | 22: ((6, 1, 1, 1), (8, 1, 1, 1)),
62 | },
63 | recurrent_weights={15: ((10, 1, 3, 1), (14, 1, 1, 3))},
64 | unit_types={11: network.MULTIPLICATION_UNIT, 13: network.MULTIPLICATION_UNIT,},
65 | biases={12: 1, 19: 1, 20: -1, 21: 1, 22: opening_bracket_output_bias},
66 | activations={
67 | 14: network.FLOOR,
68 | 16: network.MODULO_3,
69 | 19: network.UNSIGNED_STEP,
70 | 20: network.UNSIGNED_STEP,
71 | },
72 | )
73 |
74 |
75 | def make_emmanuel_triplet_xor_network():
76 | net = network.make_custom_network(
77 | input_size=1,
78 | output_size=1,
79 | num_units=7,
80 | forward_weights={
81 | 0: ((2, 1, 1, 1),),
82 | 2: ((5, -1, 1, 1), (6, 1, 1, 1),),
83 | 3: ((6, 1, 1, 1), (5, 1, 1, 1)),
84 | 4: ((3, 1, 1, 1),),
85 | 5: ((1, -1, 1, 2),),
86 | 6: ((1, 1, 1, 2),),
87 | },
88 | recurrent_weights={
89 | 0: ((2, -1, 1, 1),),
90 | 3: ((4, -1, 3, 1),),
91 | 4: ((4, 1, 1, 1),),
92 | },
93 | biases={1: 0.5, 3: -1, 4: 1, 6: -1},
94 | activations={
95 | 2: network.SQUARE,
96 | 3: network.RELU,
97 | 5: network.RELU,
98 | 6: network.RELU,
99 | },
100 | )
101 | return net
102 |
--------------------------------------------------------------------------------
/mdlrnn_torch.py:
--------------------------------------------------------------------------------
1 | import itertools
2 |
3 | import torch
4 | from torch import nn
5 |
6 |
7 | class MDLRNN(nn.Module):
8 | def __init__(
9 | self,
10 | computation_graph: dict[int, tuple[tuple[int, nn.Linear], ...]],
11 | layer_to_memory_weights: dict[int, nn.Linear],
12 | memory_to_layer_weights: dict[int, nn.Linear],
13 | layer_to_activation_to_units: dict[int, dict[int, frozenset[int]]],
14 | ):
15 | super(MDLRNN, self).__init__()
16 | self._computation_graph = computation_graph
17 | self._layer_to_activation_to_units = layer_to_activation_to_units
18 | self._memory_to_layer_weights = memory_to_layer_weights
19 | self.layer_to_memory_weights = layer_to_memory_weights
20 |
21 | self._memory_size = memory_to_layer_weights[
22 | min(memory_to_layer_weights)
23 | ].in_features
24 |
25 | self.module_list = nn.ModuleList(
26 | [x[1] for x in list(itertools.chain(*self._computation_graph.values()))]
27 | + list(self.layer_to_memory_weights.values())
28 | + list(self._memory_to_layer_weights.values())
29 | )
30 |
31 | def forward(self, inputs, memory=None, output_layer=None):
32 | """
33 | :param inputs: batched tensor of shape `(batch_size, sequence_length, num_input_classes)`.
34 | :param memory: batch tensor of shape `(batch_size, memory_size)`.
35 | :param output_layer: function to apply to outputs: `None` for raw logits, `"softmax"`, or `"normalize"` for simple normalization.
36 | :return: tensor of shape `(batch_size, sequence_length, num_output_classes)`.
37 | """
38 |
39 | def recurrence(inputs_inner, memory_inner):
40 | input_layer_num = min(self._computation_graph)
41 | layer_to_vals = {input_layer_num: inputs_inner}
42 |
43 | memory_out = torch.zeros(
44 | (
45 | inputs_inner.shape[0],
46 | self._memory_size,
47 | )
48 | )
49 |
50 | for source_layer in sorted(self._computation_graph):
51 | # Add memory.
52 | memory_weights = self._memory_to_layer_weights[source_layer]
53 | incoming_memory = memory_weights(memory_inner)
54 | layer_to_vals[source_layer] = (
55 | layer_to_vals[source_layer] + incoming_memory
56 | )
57 |
58 | # Apply activations.
59 | source_layer_activations_to_unit = self._layer_to_activation_to_units[
60 | source_layer
61 | ]
62 | activation_vals = self._apply_activations(
63 | source_layer_activations_to_unit, layer_to_vals[source_layer]
64 | )
65 | layer_to_vals[source_layer] = activation_vals
66 |
67 | # Feed-forward.
68 | for target_layer, current_weights in self._computation_graph[
69 | source_layer
70 | ]:
71 | source_layer_val = layer_to_vals[source_layer]
72 | target_layer_val = current_weights(source_layer_val)
73 |
74 | if target_layer in layer_to_vals:
75 | layer_to_vals[target_layer] = (
76 | layer_to_vals[target_layer] + target_layer_val
77 | )
78 | else:
79 | layer_to_vals[target_layer] = target_layer_val
80 |
81 | # Write to memory.
82 | memory_out = memory_out + self.layer_to_memory_weights[source_layer](
83 | layer_to_vals[source_layer]
84 | )
85 |
86 | y_out = layer_to_vals[max(layer_to_vals)]
87 | return y_out, memory_out
88 |
89 | if memory is None:
90 | memory = torch.zeros(
91 | (inputs.shape[0], self._memory_size),
92 | )
93 |
94 | outputs = []
95 | for step in range(inputs.shape[1]):
96 | y_t, memory = recurrence(inputs[:, step], memory)
97 | outputs.append(y_t)
98 |
99 | outputs = torch.stack(outputs, dim=1)
100 |
101 | if output_layer is not None:
102 | if output_layer == "softmax":
103 | outputs = torch.softmax(outputs, dim=-1)
104 | elif output_layer == "normalize":
105 | outputs = torch.clamp(outputs, min=0, max=None)
106 | outputs = nn.functional.normalize(outputs, p=1, dim=-1)
107 | else:
108 | raise ValueError(output_layer)
109 |
110 | return outputs, memory
111 |
112 | @staticmethod
113 | def _apply_activations(activation_to_unit, layer_vals) -> torch.Tensor:
114 | for activation_id in activation_to_unit:
115 | activation_unit_idxs = activation_to_unit[activation_id]
116 | if activation_id == 0: # Identity.
117 | continue
118 | activation_func = {
119 | 1: torch.relu,
120 | 3: torch.tanh,
121 | 4: torch.square,
122 | 2: torch.sigmoid,
123 | }[activation_id]
124 | layer_vals[:, activation_unit_idxs] = activation_func(
125 | layer_vals[:, activation_unit_idxs]
126 | )
127 | return layer_vals
128 |
--------------------------------------------------------------------------------
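The class above is the module returned by `torch_conversion.mdlnn_to_torch`. A minimal usage sketch (illustrative only; assumes `torch_net` is such a converted module for a task with 3 one-hot input classes):

```
import torch

batch_size, sequence_length, num_classes = 4, 10, 3
inputs = torch.zeros((batch_size, sequence_length, num_classes))
inputs[:, :, 0] = 1.0  # Dummy one-hot sequences.

# Raw logits plus the final memory state; memory defaults to zeros.
logits, memory = torch_net(inputs)

# Next-symbol probabilities via a softmax over the output classes.
probabs, _ = torch_net(inputs, output_layer="softmax")
```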
/requirements.txt:
--------------------------------------------------------------------------------
1 | cachetools~=4.1.0
2 | mpi4py==3.0.3
3 | numba==0.53.1
4 | numpy==1.18.2
5 | pytest==6.2.2
6 | scikit-learn~=0.22.2.post1
7 | scipy~=1.4.1
8 | torch~=1.5.0
--------------------------------------------------------------------------------
/simulations.py:
--------------------------------------------------------------------------------
1 | import corpora
2 | import network
3 |
4 | DEFAULT_ACTIVATIONS = (
5 | network.SIGMOID,
6 | network.LINEAR,
7 | network.RELU,
8 | network.SQUARE,
9 | network.UNSIGNED_STEP,
10 | network.FLOOR,
11 | )
12 |
13 | EXTENDED_ACTIVATIONS = DEFAULT_ACTIVATIONS + (network.MODULO_2, network.MODULO_3,)
14 |
15 | DEFAULT_UNIT_TYPES = (network.SUMMATION_UNIT,)
16 |
17 | DEFAULT_CONFIG = {
18 | "migration_ratio": 0.01,
19 | "migration_interval_seconds": 1800,
20 | "migration_interval_generations": 1000,
21 | "num_generations": 25_000,
22 | "population_size": 500,
23 | "elite_ratio": 0.001,
24 | "allowed_activations": DEFAULT_ACTIVATIONS,
25 | "allowed_unit_types": DEFAULT_UNIT_TYPES,
26 | "start_smooth": False,
27 | "compress_grammar_encoding": False,
28 | "tournament_size": 2,
29 | "mutation_probab": 1.0,
30 | "mini_batch_size": None,
31 | "grammar_multiplier": 1,
32 | "data_given_grammar_multiplier": 1,
33 | "max_network_units": 1024,
34 | "softmax_outputs": False,
35 | "truncate_large_values": True,
36 | "bias_connections": True,
37 | "recurrent_connections": True,
38 | "corpus_seed": 100,
39 | "parallelize": True,
40 | "migration_channel": "file",
41 | "generation_dump_interval": 250,
42 | }
43 |
44 | SIMULATIONS = {
45 | "identity": {
46 | "corpus": {
47 | "factory": corpora.make_identity_binary,
48 | "args": {"sequence_length": 100, "batch_size": 10},
49 | }
50 | },
51 | "repeat_last_char": {
52 | "corpus": {
53 | "factory": corpora.make_prev_char_repetition_binary,
54 | "args": {"sequence_length": 100, "batch_size": 10, "repetition_offset": 1,},
55 | }
56 | },
57 | "binary_addition": {
58 | "corpus": {
59 | "factory": corpora.make_binary_addition,
60 | "args": {"min_n": 0, "max_n": 20},
61 | },
62 | },
63 | "dyck_1": {
64 | "corpus": {
65 | "factory": corpora.make_dyck_n,
66 | "args": {
67 | "n": 1,
68 | "batch_size": 100,
69 | "nesting_probab": 0.3,
70 | "max_sequence_length": 200,
71 | },
72 | },
73 | },
74 | "dyck_2": {
75 | "corpus": {
76 | "factory": corpora.make_dyck_n,
77 | "args": {
78 | "batch_size": 20_000,
79 | "nesting_probab": 0.3,
80 | "n": 2,
81 | "max_sequence_length": 200,
82 | },
83 | },
84 | "config": {
85 | "allowed_activations": EXTENDED_ACTIVATIONS,
86 | "allowed_unit_types": (network.SUMMATION_UNIT, network.MULTIPLICATION_UNIT),
87 | },
88 | },
89 | "an_bn": {
90 | "corpus": {
91 | "factory": corpora.make_ain_bjn_ckn_dtn,
92 | "args": {"batch_size": 100, "prior": 0.3, "multipliers": (1, 1, 0, 0)},
93 | }
94 | },
95 | "an_bn_cn": {
96 | "corpus": {
97 | "factory": corpora.make_ain_bjn_ckn_dtn,
98 | "args": {"batch_size": 100, "prior": 0.3, "multipliers": (1, 1, 1, 0)},
99 | }
100 | },
101 | "an_bn_cn_dn": {
102 | "corpus": {
103 | "factory": corpora.make_ain_bjn_ckn_dtn,
104 | "args": {"batch_size": 100, "prior": 0.3, "multipliers": (1, 1, 1, 1)},
105 | }
106 | },
107 | "an_b2n": {
108 | "corpus": {
109 | "factory": corpora.make_ain_bjn_ckn_dtn,
110 | "args": {"batch_size": 100, "prior": 0.3, "multipliers": (1, 2, 0, 0)},
111 | }
112 | },
113 | "an_bn_square": {
114 | "corpus": {
115 | "factory": corpora.make_an_bn_square,
116 | "args": {"batch_size": 1000, "prior": 0.5},
117 | }
118 | },
119 | "palindrome_fixed_length": {
120 | "corpus": {
121 | "factory": corpora.make_binary_palindrome_fixed_length,
122 | "args": {
123 | "batch_size": 1000,
124 | "sequence_length": 50,
125 | "train_set_ratio": 0.7,
126 | },
127 | }
128 | },
129 | "an_bm_cn_plus_m": {
130 | "corpus": {
131 | "factory": corpora.make_an_bm_cn_plus_m,
132 | "args": {"batch_size": 100, "prior": 0.3},
133 | }
134 | },
135 | "center_embedding": {
136 | "corpus": {
137 | "factory": corpora.make_center_embedding,
138 | "args": {
139 | "batch_size": 20_000,
140 | "embedding_depth_probab": 0.3,
141 | "dependency_distance_probab": 0.0,
142 | },
143 | },
144 | "config": {
145 | "allowed_activations": DEFAULT_ACTIVATIONS + (network.MODULO_2,),
146 | "allowed_unit_types": (network.SUMMATION_UNIT, network.MULTIPLICATION_UNIT),
147 | },
148 | },
149 | "0_1_pattern_binary": {
150 | "corpus": {
151 | "factory": corpora.make_0_1_pattern_binary,
152 | "args": {"sequence_length": 20, "batch_size": 1},
153 | }
154 | },
155 | "0_1_pattern_one_hot_no_eos": {
156 | "corpus": {
157 | "factory": corpora.make_0_1_pattern_one_hot,
158 | "args": {
159 | "add_end_of_sequence": False,
160 | "sequence_length": 50,
161 | "batch_size": 1,
162 | },
163 | }
164 | },
165 | "0_1_pattern_one_hot_with_eos": {
166 | "corpus": {
167 | "factory": corpora.make_0_1_pattern_one_hot,
168 | "args": {
169 | "add_end_of_sequence": True,
170 | "sequence_length": 50,
171 | "batch_size": 1,
172 | },
173 | }
174 | },
175 | }
176 |
--------------------------------------------------------------------------------
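The entries above are data, not code: each maps a task name to a corpus factory plus its arguments, optionally with per-task overrides of `DEFAULT_CONFIG`. The exact wiring lives in `main.py` (not reproduced here); a rough sketch of how such an entry can be expanded:

```
import corpora
import simulations

sim = simulations.SIMULATIONS["dyck_2"]

# Build the training corpus from its factory and arguments.
corpus = sim["corpus"]["factory"](**sim["corpus"]["args"])

# Per-task settings override the shared defaults.
config_dict = {**simulations.DEFAULT_CONFIG, **sim.get("config", {})}
```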
/stats/tacl_stats.py:
--------------------------------------------------------------------------------
1 | import json
2 | import operator
3 | import re
4 | from typing import Dict
5 |
6 | import numpy as np
7 |
8 |
9 | def _params_key(params_dict):
10 | return str(tuple(sorted(params_dict.items())))
11 |
12 |
13 | def _parse_mdlnn_params(params_string) -> Dict:
14 | return json.loads(re.findall(r"\{.+\}", params_string)[0])
15 |
16 |
17 | def _parse_rnn_params(params_string) -> Dict:
18 | return json.loads(params_string)
19 |
20 |
21 | def _dict_subset(parent_dict, child_dict):
22 |     # Checks key containment only, not value equality.
23 |     return all(child_key in parent_dict for child_key in child_dict)
24 |
25 |
26 | def _float_or_none(val):
27 | return float(val) if val is not None else -1
28 |
29 |
30 | def _get_value_from_sim_stats(sims_list, key):
31 | return np.array([_float_or_none(x[key]) for x in sims_list])
32 |
33 |
34 | def _percent_format(x):
35 | return f"{x*100:.1f}%"
36 |
37 |
38 | def _ln_to_log2(ln_loss):
39 | return ln_loss / np.log(2)
40 |
41 |
42 | def _compare_mdl_and_rnn_winners_by_stat(mdl_stats, rnn_stats, key):
43 | mdl_best_val = _float_or_none(mdl_stats[key])
44 | rnn_best_val = _float_or_none(rnn_stats[key])
45 |
46 | if "accuracy" in key.lower():
47 | # Accuracy, higher is better.
48 | op = operator.gt
49 | mdl_best_val_str = _percent_format(mdl_best_val)
50 | rnn_best_val_str = _percent_format(rnn_best_val)
51 | optimal_str = ""
52 | else:
53 | # Cross-entropy, lower is better.
54 | op = operator.lt
55 | rnn_best_val = _ln_to_log2(rnn_best_val)
56 | mdl_best_val_str = f"{mdl_best_val*100:.1f}"
57 | rnn_best_val_str = f"{rnn_best_val*100:.1f}"
58 | optimal_str = f"& ({_float_or_none(rnn_stats['Optimal average test CE per character'])*100:.1f})"
59 |
60 | print(f"\t{key}")
61 | print(f"\tMDLRNN & RNN")
62 | print(f"\t{mdl_best_val_str} & {rnn_best_val_str} {optimal_str}")
63 |
64 | if mdl_best_val == rnn_best_val:
65 | print("\t\tTie")
66 | elif op(mdl_best_val, rnn_best_val):
67 | print("\t\tMDL wins")
68 | else:
69 | print(f"\t\tRNN wins")
70 |
71 |
72 | def compare_mdlnn_rnn_stats(select_by: str):
73 | print(f"Comparing based on {select_by}\n")
74 |
75 | with open("./tacl_stats.json", "r") as f:
76 | data = json.load(f)
77 |
78 | mdlnn_rnn_comparisons = (
79 | "Average test CE per character",
80 | "Test accuracy",
81 | "Test deterministic accuracy",
82 | )
83 |
84 | test_best_rnn_ids = []
85 | test_best_mdlnn_ids = []
86 |
87 | for task_dict in data.values():
88 | corpus_name = task_dict["corpus_name"]
89 | try:
90 | mdl_sims = task_dict["mdl"]
91 | except KeyError:
92 | print(f"No MDL model for {corpus_name}\n\n")
93 | continue
94 | rnn_sims = task_dict["rnn"]
95 |
96 | print(f"* {corpus_name}")
97 |
98 | train_winning_mdl_sim_idx = np.argmin(
99 | _get_value_from_sim_stats(mdl_sims, "MDL")
100 | )
101 | train_winning_mdl_sim_stats = mdl_sims[train_winning_mdl_sim_idx]
102 |
103 | train_mdl_g_scores = _get_value_from_sim_stats(mdl_sims, "|G|")
104 | test_mdl_scores = train_mdl_g_scores + _get_value_from_sim_stats(
105 | mdl_sims, "Test set D:G"
106 | )
107 |
108 | test_winning_mdl_sim_idx = np.argmin(test_mdl_scores)
109 | test_winning_mdl_sim_stats = mdl_sims[test_winning_mdl_sim_idx]
110 |
111 | train_winning_rnn_sim_idx = np.argmin(
112 | _get_value_from_sim_stats(rnn_sims, "Average train CE per character")
113 | )
114 | train_winning_rnn_sim_stats = rnn_sims[train_winning_rnn_sim_idx]
115 |
116 | print(f"\tTrain-best MDL:\t{train_winning_mdl_sim_stats['Simulation id']}")
117 | print(f"\tTrain-best RNN:\t{train_winning_rnn_sim_stats['Simulation id']}\n")
118 |
119 | test_winning_rnn_sim_idx = np.argmin(
120 | # Take winner based on actual CE performance, without regularization ('loss' stat includes regularization term).
121 | _get_value_from_sim_stats(rnn_sims, "Average test CE per character")
122 | )
123 | test_winning_rnn_sim_stats = rnn_sims[test_winning_rnn_sim_idx]
124 | test_rnn_args = json.loads(test_winning_rnn_sim_stats["Params"])
125 |
126 | print(
127 | f"\tTest-best MDL:\t{test_winning_mdl_sim_stats['Simulation id']}.\n{test_winning_mdl_sim_stats['Units']} units, {test_winning_mdl_sim_stats['Connections']} connections"
128 | )
129 | print(
130 | f"\tTest-best RNN:\t{test_winning_rnn_sim_stats['Simulation id']}\n{test_rnn_args['network_type']}, {test_rnn_args['num_hidden_units']} units, {test_rnn_args.get('regularization', 'no')} regularization, {test_rnn_args.get('regularization_lambda', '-')} lambda.\n"
131 | )
132 |
133 | test_best_mdlnn_ids.append(test_winning_mdl_sim_stats["Simulation id"])
134 | test_best_rnn_ids.append(test_winning_rnn_sim_stats["Simulation id"])
135 |
136 | if select_by == "train":
137 | for comparison_key in mdlnn_rnn_comparisons:
138 | _compare_mdl_and_rnn_winners_by_stat(
139 | mdl_stats=train_winning_mdl_sim_stats,
140 | rnn_stats=train_winning_rnn_sim_stats,
141 | key=comparison_key,
142 | )
143 | else:
144 | for comparison_key in mdlnn_rnn_comparisons:
145 | _compare_mdl_and_rnn_winners_by_stat(
146 | mdl_stats=test_winning_mdl_sim_stats,
147 | rnn_stats=test_winning_rnn_sim_stats,
148 | key=comparison_key,
149 | )
150 |
151 |
152 | if __name__ == "__main__":
153 | compare_mdlnn_rnn_stats(select_by="train")
154 | compare_mdlnn_rnn_stats(select_by="test")
155 |
--------------------------------------------------------------------------------
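Two details of the comparison above are easy to miss: RNN cross-entropy values are converted from nats to bits before comparison (`_ln_to_log2`), and the test-side MDL score is recomputed as `|G|` plus the test-set `D:G` term. A toy illustration (numbers are made up):

```
import numpy as np

# Nats to bits: divide by ln(2).
ce_bits = 0.693 / np.log(2)  # ~1.0

# Test-side MDL score per simulation: grammar length + test-set data-given-grammar.
mdl_sims = [{"|G|": 120.0, "Test set D:G": 950.0}, {"|G|": 300.0, "Test set D:G": 800.0}]
test_scores = np.array([s["|G|"] + s["Test set D:G"] for s in mdl_sims])
best_idx = int(np.argmin(test_scores))  # 0, since 1070.0 < 1100.0
```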
/tests/test_corpus.py:
--------------------------------------------------------------------------------
1 | import unittest
2 |
3 | import numpy as np
4 |
5 | import configuration
6 | import corpora
7 | import manual_nets
8 | import network
9 | import simulations
10 | import utils
11 |
12 | utils.setup_logging()
13 |
14 |
15 | def _test_xor_correctness(input_sequence, target_sequence):
16 | memory = {"step_t_minus_1": 0, "step_t_minus_2": 0}
17 |
18 | def predictor(x):
19 | memory["step_t_minus_2"] = memory["step_t_minus_1"]
20 | memory["step_t_minus_1"] = x
21 | return memory["step_t_minus_1"] ^ memory["step_t_minus_2"]
22 |
23 | correct = 0
24 | for in_, target in zip(input_sequence[0], target_sequence[0]):
25 | prediction = predictor(in_)
26 | if np.all(prediction == target):
27 | correct += 1
28 |
29 | accuracy = correct / input_sequence.shape[1]
30 | print(f"Accuracy: {accuracy:.2f}")
31 |
32 |     # A perfect predictor gets the deterministic 33% of values right, plus half of the remaining 66% by chance.
33 | assert accuracy >= 0.66
34 |
35 |
36 | class TestCorpus(unittest.TestCase):
37 | def test_train_test_shuffle(self):
38 | train_ratio = 0.7
39 | batch_size = 1000
40 | train_corpus = corpora.make_binary_palindrome_fixed_length(
41 | batch_size=batch_size, sequence_length=10, train_set_ratio=train_ratio
42 | )
43 | test_corpus = train_corpus.test_corpus
44 | assert (
45 | sum(train_corpus.sample_weights) + sum(test_corpus.sample_weights)
46 | == batch_size
47 | )
48 |
49 | assert train_corpus.input_sequence.shape[0] == 22
50 | assert test_corpus.input_sequence.shape[0] == 10
51 |
52 | def test_mini_batches(self):
53 | mini_batch_size = 100
54 |
55 | net = manual_nets.make_emmanuel_triplet_xor_network()
56 | full_corpus = corpora.optimize_for_feeding(
57 | corpora.make_elman_xor_binary(sequence_length=300, batch_size=1000)
58 | )
59 |
60 | mini_batches = corpora.split_to_mini_batches(
61 | full_corpus, mini_batch_size=mini_batch_size
62 | )
63 | assert len(mini_batches) == 10
64 |
65 | config = configuration.SimulationConfig(
66 | **{**simulations.DEFAULT_CONFIG, "mini_batch_size": mini_batch_size},
67 | simulation_id="",
68 | num_islands=1,
69 | seed=1,
70 | )
71 | net = network.calculate_fitness(net, full_corpus, config)
72 | print(net)
73 | assert net.fitness.data_encoding_length == 200000
74 |
75 | def test_0_1_pattern_binary(self):
76 | seq_length = 10
77 | batch_size = 10
78 | corpus = corpora.make_0_1_pattern_binary(
79 | sequence_length=seq_length, batch_size=batch_size
80 | )
81 | assert corpus.sample_weights == (batch_size,)
82 | assert corpus.test_corpus.input_sequence.shape == (1, seq_length * 50_000, 1)
83 |
84 | seq_length = 10
85 | batch_size = 1
86 | corpus = corpora.make_0_1_pattern_binary(
87 | sequence_length=seq_length, batch_size=batch_size
88 | )
89 | assert corpus.sample_weights is None
90 |
91 | def test_xor_corpus_correctness(self):
92 | sequence_length = 9999
93 | xor_corpus_binary = corpora.make_elman_xor_binary(sequence_length, batch_size=1)
94 | xor_corpus_one_hot = corpora.make_elman_xor_one_hot(
95 | sequence_length, batch_size=1
96 | )
97 | _test_xor_correctness(
98 |             xor_corpus_binary.input_sequence.astype(int),
99 |             xor_corpus_binary.target_sequence.astype(int),
100 | )
101 | _test_xor_correctness(
102 | xor_corpus_one_hot.input_sequence.argmax(axis=-1),
103 | xor_corpus_one_hot.target_sequence.argmax(axis=-1),
104 | )
105 |
106 | def test_binary_addition_corpus_correctness(self):
107 | addition_corpus = corpora.make_binary_addition(min_n=0, max_n=100)
108 |
109 | for input_sequence, target_sequence in zip(
110 | addition_corpus.input_sequence, addition_corpus.target_sequence
111 | ):
112 | n1 = 0
113 | n2 = 0
114 | sum_ = 0
115 | for i, bin_digit in enumerate(input_sequence):
116 | n1_binary_digit = bin_digit[0]
117 | n2_binary_digit = bin_digit[1]
118 | current_exp = 2 ** i
119 | if n1_binary_digit == 1:
120 | n1 += current_exp
121 | if n2_binary_digit == 1:
122 | n2 += current_exp
123 | target_binary_digit = target_sequence[i]
124 | if target_binary_digit == 1:
125 | sum_ += current_exp
126 |
127 | assert n1 + n2 == sum_, (n1, n2, sum_)
128 |
129 | def test_an_bn_corpus(self):
130 | n_values = tuple(range(50))
131 | an_bn_corpus = corpora.optimize_for_feeding(
132 | corpora._make_ain_bjn_ckn_dtn_corpus(
133 | n_values, multipliers=(1, 1, 0, 0), prior=0.1, sort_by_length=True
134 | )
135 | )
136 |
137 | for n in n_values:
138 | row = 49 - n # Sequences are sorted by decreasing length.
139 | input_seq = an_bn_corpus.input_sequence[row]
140 | target_seq = an_bn_corpus.target_sequence[row]
141 | seq_len = 1 + (2 * n)
142 |
143 | zeros_start = 1
144 | ones_start = n + 1
145 |
146 | assert not np.all(corpora.is_masked(input_seq[:seq_len]))
147 | assert np.all(corpora.is_masked(input_seq[seq_len:]))
148 |
149 | input_classes = np.argmax(input_seq, axis=-1)[:seq_len]
150 | target_classes = np.argmax(target_seq, axis=-1)[:seq_len]
151 |
152 | assert np.sum(input_classes == 1) == n
153 | assert np.sum(input_classes == 2) == n
154 |
155 | assert input_classes[0] == 0 # Start of sequence.
156 | assert np.all(input_classes[zeros_start:ones_start] == 1)
157 | assert np.all(input_classes[ones_start:seq_len] == 2)
158 |
159 | assert target_classes[seq_len - 1] == 0 # End of sequence.
160 | assert np.all(target_classes[zeros_start - 1 : ones_start - 1] == 1)
161 | assert np.all(target_classes[ones_start - 1 : seq_len - 1] == 2)
162 |
163 | def test_dfa_baseline_d_g(self):
164 | dfa_ = corpora.make_between_dfa(start=4, end=4)
165 |
166 | corpus = corpora.make_exactly_n_quantifier(4, sequence_length=50, batch_size=10)
167 |
168 | optimal_d_g = dfa_.get_optimal_data_given_grammar_for_dfa(corpus.input_sequence)
169 |
170 | num_non_masked_steps = np.sum(np.all(~np.isnan(corpus.input_sequence), axis=-1))
171 | assert optimal_d_g == num_non_masked_steps
172 |
--------------------------------------------------------------------------------
/tests/test_genetic_algorithm.py:
--------------------------------------------------------------------------------
1 | import dataclasses
2 | import random
3 | import unittest
4 |
5 | import numpy as np
6 |
7 | import configuration, corpora, genetic_algorithm, network, utils
8 |
9 | utils.setup_logging()
10 | _TEST_CONFIG = configuration.SimulationConfig(
11 | simulation_id="test",
12 | num_islands=1,
13 | migration_ratio=0.1,
14 | migration_interval_seconds=20,
15 | migration_interval_generations=1000,
16 | num_generations=1000,
17 | population_size=20,
18 | elite_ratio=0.05,
19 | allowed_activations=(
20 | network.SIGMOID,
21 | network.LINEAR,
22 | network.RELU,
23 | network.SQUARE,
24 | ),
25 | start_smooth=False,
26 | allowed_unit_types=(network.SUMMATION_UNIT, network.MULTIPLICATION_UNIT),
27 | tournament_size=4,
28 | mutation_probab=0.9,
29 | grammar_multiplier=1,
30 | data_given_grammar_multiplier=1,
31 | compress_grammar_encoding=False,
32 | max_network_units=1024,
33 | softmax_outputs=False,
34 | truncate_large_values=False,
35 | bias_connections=True,
36 | recurrent_connections=True,
37 | seed=1,
38 | corpus_seed=100,
39 | generation_dump_interval=1,
40 | )
41 |
42 |
43 | def _num_non_masked_inputs(x):
44 | return np.sum(~np.all(corpora.is_masked(x), axis=-1))
45 |
46 |
47 | def _make_random_population(size):
48 | population = []
49 | for _ in range(size):
50 | net = network.make_random_net(
51 | input_size=3,
52 | output_size=3,
53 | allowed_activations=_TEST_CONFIG.allowed_activations,
54 | start_smooth=False,
55 | )
56 | p = random.random()
57 | if p < 0.3:
58 | grammar_encoding_length = np.inf
59 | else:
60 | grammar_encoding_length = np.random.randint(100)
61 | data_encoding_length = np.random.randint(100)
62 | net = dataclasses.replace(
63 | net,
64 | fitness=network.Fitness(
65 | mdl=grammar_encoding_length + data_encoding_length,
66 | grammar_encoding_length=grammar_encoding_length,
67 | data_encoding_length=data_encoding_length,
68 | accuracy=1.0,
69 | ),
70 | )
71 | population.append(net)
72 | return population
73 |
74 |
75 | def _test_elite(population, elite):
76 | population_fitnesses = tuple(map(genetic_algorithm._GET_NET_MDL, population))
77 | elite_fitnesses = tuple(map(genetic_algorithm._GET_NET_MDL, elite))
78 |
79 | population_fitness_without_elite = frozenset(population_fitnesses) - frozenset(
80 | elite_fitnesses
81 | )
82 |
83 | worst_elite_fitness = max(elite_fitnesses)
84 | best_non_elite_fitness = min(population_fitness_without_elite)
85 |
86 | assert worst_elite_fitness <= best_non_elite_fitness
87 |
88 |
89 | class TestGeneticAlgorithm(unittest.TestCase):
90 | def test_get_elite_idxs(self):
91 | population = _make_random_population(size=1000)
92 | best_idxs = genetic_algorithm._get_elite_idxs(population, elite_ratio=0.01)
93 | assert len(best_idxs) == 10, len(best_idxs)
94 |
95 | elite = [population[i] for i in best_idxs]
96 | _test_elite(population, elite)
97 |
98 | def test_get_elite(self):
99 | population = _make_random_population(size=1000)
100 | elite = genetic_algorithm._get_elite(elite_ratio=0.01, population=population)
101 | assert len(elite) == 10
102 | _test_elite(population, elite)
103 |
104 | def test_get_migration_target_island(self):
105 | island_num = 3
106 | migration_interval = 20
107 | total_islands = 16
108 | targets = []
109 | generator = genetic_algorithm._make_migration_target_island_generator(
110 | island_num, total_islands
111 | )
112 |
113 | for generation in range(
114 | 1, (total_islands + 1) * migration_interval, migration_interval
115 | ):
116 | targets.append(next(generator))
117 |
118 | assert island_num not in targets
119 | assert tuple(targets) == (
120 | 4,
121 | 5,
122 | 6,
123 | 7,
124 | 8,
125 | 9,
126 | 10,
127 | 11,
128 | 12,
129 | 13,
130 | 14,
131 | 15,
132 | 0,
133 | 1,
134 | 2,
135 | 4,
136 | 5,
137 | )
138 |
139 | def test_tournament_selection_break_infinity_ties(self):
140 | population = _make_random_population(size=10000)
141 | inf_population = [x for x in population if np.isinf(x.fitness.mdl)]
142 | for _ in range(100):
143 | tournament = genetic_algorithm._tournament_selection(
144 | population=inf_population, tournament_size=2
145 | )
146 | winner_idx = tournament.winner_idx
147 | loser_idx = tournament.loser_idx
148 | assert (
149 | inf_population[winner_idx].fitness.grammar_encoding_length
150 | <= inf_population[loser_idx].fitness.grammar_encoding_length
151 | )
152 |
153 | def test_tournament_selection(self):
154 | population = _make_random_population(size=1000)
155 | tournament = genetic_algorithm._tournament_selection(
156 | population=population, tournament_size=2
157 | )
158 | winner_idx = tournament.winner_idx
159 | loser_idx = tournament.loser_idx
160 | assert population[winner_idx].fitness.mdl < population[loser_idx].fitness.mdl
161 |
162 | tournament = genetic_algorithm._tournament_selection(
163 | population=population, tournament_size=len(population),
164 | )
165 | absolute_best_offspring_idx = tournament.winner_idx
166 | absolute_worst_offspring_idx = tournament.loser_idx
167 |
168 | assert (
169 | population[absolute_best_offspring_idx].fitness.mdl
170 | == min(population, key=genetic_algorithm._GET_NET_MDL).fitness.mdl
171 | )
172 |
173 | assert (
174 | population[absolute_worst_offspring_idx].fitness.mdl
175 | == max(population, key=genetic_algorithm._GET_NET_MDL).fitness.mdl
176 | )
177 |
178 | identical_fitness_population = [population[0], population[0]]
179 | tournament = genetic_algorithm._tournament_selection(
180 | identical_fitness_population, tournament_size=2
181 | )
182 | winner_idx = tournament.winner_idx
183 | loser_idx = tournament.loser_idx
184 | assert winner_idx != loser_idx
185 |
--------------------------------------------------------------------------------
/torch_conversion.py:
--------------------------------------------------------------------------------
1 | import collections
2 | import itertools
3 | from typing import Optional
4 |
5 | import mdlrnn_torch
6 | import numpy as np
7 | import torch
8 | from torch import nn
9 |
10 | import corpora
11 | import network
12 | import utils
13 |
14 | logger = utils.setup_logging()
15 |
16 |
17 | def _get_forward_mesh_layers(net) -> dict[int, frozenset[int]]:
18 | forward_mesh_layers = network.bfs_layers(
19 | forward_connections=net.forward_connections,
20 | reverse_forward_connections=net.reverse_forward_connections,
21 | units=network.get_units(net),
22 | input_units_range=net.input_units_range,
23 | output_units_range=net.output_units_range,
24 | )
25 |
26 | # Move inputs to own layer.
27 | input_units = frozenset(net.input_units_range)
28 | for depth in range(2): # Inputs are at most depth 1.
29 | forward_mesh_layers[depth] = forward_mesh_layers[depth] - input_units
30 | if len(forward_mesh_layers[0]) == 0:
31 | input_level = 0
32 | else:
33 | input_level = -1
34 | forward_mesh_layers[input_level] = input_units
35 |
36 | # Move outputs to own layer.
37 | output_units = frozenset(net.output_units_range)
38 | for depth in forward_mesh_layers:
39 | forward_mesh_layers[depth] = forward_mesh_layers[depth] - output_units
40 | forward_mesh_layers = {k: v for k, v in forward_mesh_layers.items() if len(v)}
41 | max_depth = max(forward_mesh_layers)
42 | forward_mesh_layers[max_depth + 1] = output_units
43 |
44 | return forward_mesh_layers
45 |
46 |
47 | def _make_linear_weights(weights: torch.Tensor, bias: Optional[torch.Tensor] = None):
48 | linear = nn.Linear(
49 | in_features=weights.shape[1],
50 | out_features=weights.shape[0],
51 | bias=bias is not None,
52 | )
53 | with torch.no_grad():
54 | linear.weight.copy_(weights)
55 |
56 | if bias is not None:
57 | linear.bias.copy_(bias)
58 |
59 | return linear
60 |
61 |
62 | def _get_unit_to_bfs_layer(
63 | bfs_layer_to_units: dict[int, frozenset[int]]
64 | ) -> dict[int, int]:
65 | unit_to_layer = {}
66 | for source_layer, source_layer_units in bfs_layer_to_units.items():
67 | for unit in source_layer_units:
68 | unit_to_layer[unit] = source_layer
69 | return unit_to_layer
70 |
71 |
72 | def _build_memory_layers(
73 | bfs_layer_to_units: dict[int, frozenset[int]],
74 | net: network.Network,
75 | ):
76 | memory_units = frozenset(
77 | unit
78 | for unit in net.recurrent_connections
79 | if len(net.recurrent_connections[unit])
80 | )
81 | memory_size = len(memory_units)
82 |
83 | unit_to_memory_idx = {unit: i for i, unit in enumerate(sorted(memory_units))}
84 |
85 | layer_to_memory_weights = {}
86 | memory_to_layer_weights = {}
87 |
88 | for layer, layer_units in bfs_layer_to_units.items():
89 | to_memory_weights = torch.zeros((memory_size, len(layer_units)))
90 | from_memory_weight = torch.zeros((len(layer_units), memory_size))
91 |
92 | for i, unit in enumerate(sorted(layer_units)):
93 | if unit in memory_units:
94 | to_memory_weights[unit_to_memory_idx[unit], i] = 1
95 |
96 | if unit in net.reverse_recurrent_connections:
97 | for source_memory_unit in net.reverse_recurrent_connections[unit]:
98 | incoming_weight = network.weight_to_float(
99 | net.recurrent_weights[(source_memory_unit, unit)]
100 | )
101 | incoming_memory_idx = unit_to_memory_idx[source_memory_unit]
102 | from_memory_weight[i, incoming_memory_idx] = incoming_weight
103 |
104 | layer_to_memory_weights[layer] = _make_linear_weights(
105 | to_memory_weights, bias=torch.zeros((memory_size,))
106 | )
107 | memory_to_layer_weights[layer] = _make_linear_weights(
108 | from_memory_weight, bias=torch.zeros((from_memory_weight.shape[0],))
109 | )
110 |
111 | return layer_to_memory_weights, memory_to_layer_weights
112 |
113 |
114 | def _freeze_defaultdict(dd) -> dict:
115 | if "lambda" in str(dd.default_factory):
116 | return {key: _freeze_defaultdict(val) for key, val in dd.items()}
117 | frozen_dict = {}
118 | for key, val in dd.items():
119 | if type(val) == set:
120 | val = frozenset(val)
121 | elif type(val) == list:
122 | val = tuple(val)
123 | frozen_dict[key] = val
124 | return frozen_dict
125 |
126 |
127 | def _build_computation_graph(net: network.Network, bfs_layer_to_units):
128 | unit_to_layer = _get_unit_to_bfs_layer(bfs_layer_to_units)
129 |
130 | layer_to_outgoing_layers = collections.defaultdict(set)
131 | for layer in bfs_layer_to_units:
132 | for unit in bfs_layer_to_units[layer]:
133 | for target_unit in net.forward_connections.get(unit, set()):
134 | layer_to_outgoing_layers[layer].add(unit_to_layer[target_unit])
135 |
136 | layer_to_outgoing_layers = _freeze_defaultdict(layer_to_outgoing_layers)
137 |
138 | units_with_biases = set()
139 |
140 | computation_graph = collections.defaultdict(list)
141 |
142 | for source_layer in sorted(bfs_layer_to_units):
143 | source_layer_units = bfs_layer_to_units[source_layer]
144 | source_layer_size = len(source_layer_units)
145 | source_to_idx = dict((x, i) for i, x in enumerate(sorted(source_layer_units)))
146 |
147 | for target_layer in layer_to_outgoing_layers.get(source_layer, frozenset()):
148 | target_layer_units = bfs_layer_to_units[target_layer]
149 | target_to_idx = dict(
150 | (x, i) for i, x in enumerate(sorted(target_layer_units))
151 | )
152 | target_layer_size = len(target_layer_units)
153 |
154 | weights = torch.zeros((target_layer_size, source_layer_size))
155 |
156 | for source in source_layer_units:
157 | source_idx = source_to_idx[source]
158 | source_unit_targets = (
159 | net.forward_connections.get(source, frozenset())
160 | & target_layer_units
161 | )
162 | for target in source_unit_targets:
163 | target_idx = target_to_idx[target]
164 |
165 | weight = network.weight_to_float(
166 | net.forward_weights[(source, target)]
167 | )
168 | weights[target_idx, source_idx] = weight
169 |
170 | bias = torch.zeros((target_layer_size,))
171 |
172 | for target in target_layer_units:
173 | if target in net.biases and target not in units_with_biases:
174 | units_with_biases.add(target)
175 | bias[target_to_idx[target]] = network.weight_to_float(
176 | net.biases[target]
177 | )
178 |
179 | weights = _make_linear_weights(weights=weights, bias=bias)
180 |
181 | computation_graph[source_layer].append((target_layer, weights))
182 |
183 | # Connect inputs to layers that have no inputs with weights = 0.
184 | floating_layers = set(computation_graph) - set(
185 | [x[0] for x in list(itertools.chain(*computation_graph.values()))]
186 | )
187 | input_layer = min(bfs_layer_to_units)
188 | input_size = len(bfs_layer_to_units[input_layer])
189 | for floating_layer in floating_layers:
190 | layer_size = len(bfs_layer_to_units[floating_layer])
191 | floating_layer_weights = _make_linear_weights(
192 |             weights=torch.zeros((layer_size, input_size)), bias=torch.zeros((layer_size,))
193 | )
194 | computation_graph[input_layer].append((floating_layer, floating_layer_weights))
195 | computation_graph = _freeze_defaultdict(computation_graph)
196 |
197 | return computation_graph
198 |
199 |
200 | def _build_activation_function_vectors(net: network.Network, bfs_layer_to_units):
201 | layer_to_activation_to_units = collections.defaultdict(
202 | lambda: collections.defaultdict(list)
203 | )
204 | for source_layer, source_layer_units in bfs_layer_to_units.items():
205 | for unit_idx, unit in enumerate(sorted(source_layer_units)):
206 | activation = net.activations[unit]
207 | layer_to_activation_to_units[source_layer][activation].append(unit_idx)
208 |
209 | layer_to_activation_to_units = _freeze_defaultdict(layer_to_activation_to_units)
210 | return layer_to_activation_to_units
211 |
212 |
213 | def mdlnn_to_torch(net: network.Network) -> mdlrnn_torch.MDLRNN:
214 | bfs_layer_to_units = _get_forward_mesh_layers(net)
215 |
216 | computation_graph = _build_computation_graph(net, bfs_layer_to_units)
217 | layer_to_activation_to_units = _build_activation_function_vectors(
218 | net, bfs_layer_to_units
219 | )
220 | layer_to_memory_weights, memory_to_layer_weights = _build_memory_layers(
221 | bfs_layer_to_units=bfs_layer_to_units, net=net
222 | )
223 |
224 | return mdlrnn_torch.MDLRNN(
225 | computation_graph=computation_graph,
226 | layer_to_memory_weights=layer_to_memory_weights,
227 | memory_to_layer_weights=memory_to_layer_weights,
228 | layer_to_activation_to_units=layer_to_activation_to_units,
229 | )
230 |
--------------------------------------------------------------------------------
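A small check of what `_make_linear_weights` above does: it builds an `nn.Linear` whose weight matrix and bias are copied from fixed tensors, so the layer computes `x @ weights.T + bias`. Illustrative only:

```
import torch

from torch_conversion import _make_linear_weights

weights = torch.tensor([[0.5, -1.0], [2.0, 0.0]])
bias = torch.zeros((2,))

linear = _make_linear_weights(weights, bias=bias)
assert torch.allclose(linear(torch.tensor([[1.0, 1.0]])), torch.tensor([[-0.5, 2.0]]))
```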
/utils.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import dataclasses
3 | import hashlib
4 | import itertools
5 | import logging
6 | import os
7 | import pathlib
8 | import pickle
9 | import random
10 | from typing import Any, Dict, Iterable, Optional, Text, Tuple
11 |
12 | import numpy as np
13 | from numba import types
14 |
15 | import corpora
16 |
17 | BASE_SEED = 100
18 |
19 | MPI_LOGGING_TAG = 1
20 | MPI_RESULTS_TAG = 2
21 | MPI_MIGRANTS_TAG = 3
22 |
23 | FLOAT_DTYPE = np.float64
24 | NUMBA_FLOAT_DTYPE = types.float64
25 |
26 |
27 | def kwargs_from_param_grid(param_grid: Dict[Text, Iterable[Any]]) -> Iterable[Dict]:
28 | arg_names = list(param_grid.keys())
29 | arg_products = list(itertools.product(*param_grid.values()))
30 | for arg_product in arg_products:
31 | yield {arg: val for arg, val in zip(arg_names, arg_product)}
32 |
33 |
34 | def setup_logging():
35 | logging.basicConfig(
36 | format="%(asctime)s.%(msecs)03d %(levelname)-4s [%(filename)s:%(lineno)d] %(message)s",
37 | datefmt="%d-%m-%Y:%H:%M:%S",
38 | level=logging.INFO,
39 | )
40 |
41 |
42 | def load_network_from_zoo(name, subdir=None):
43 | print(name)
44 | path = pathlib.Path(f"./network_zoo/")
45 | if subdir:
46 | path = path.joinpath(subdir.strip("/"))
47 | path = path.joinpath(f"{name}.pickle")
48 | print(path)
49 | with path.open("rb") as f:
50 | return pickle.load(f)
51 |
52 |
53 | def seed(n):
54 | random.seed(n)
55 | np.random.seed(n)
56 |
57 |
58 | def dict_and_corpus_hash(dict_, corpus) -> Text:
59 | s = corpus.name
60 | for key in sorted(dict_.keys()):
61 | s += f"{key} {dict_[key]}"
62 | # TODO: ugly but works.
63 | s += corpus.name + " "
64 | s += str(corpus.input_sequence) + " "
65 | s += str(corpus.target_sequence)
66 | hash = hashlib.sha1()
67 | hash.update(s.encode())
68 | return hash.hexdigest()
69 |
70 |
71 | def add_hash_to_simulation_id(simulation_config, corpus):
72 | config_hash = dict_and_corpus_hash(simulation_config.__dict__, corpus)
73 | simulation_id = f"{corpus.name}_{config_hash}"
74 | return dataclasses.replace(simulation_config, simulation_id=simulation_id)
75 |
76 |
77 | def make_cli_arguments():
78 | arg_parser = argparse.ArgumentParser()
79 | arg_parser.add_argument(
80 | "-s",
81 | "--simulation",
82 | dest="simulation_name",
83 | required=True,
84 | help=f"Simulation name.",
85 | )
86 |
87 | arg_parser.add_argument(
88 | "-n",
89 | "--total-islands",
90 | type=int,
91 | dest="total_islands",
92 | default=os.cpu_count(),
93 | help=f"Total number of islands in entire simulation (including other machines). Default: number of local cores ({os.cpu_count()}).",
94 | )
95 |
96 | arg_parser.add_argument(
97 | "--first-island",
98 | type=int,
99 | default=None,
100 | dest="first_island",
101 | help="First island index on this machine. Default: 0.",
102 | )
103 |
104 | arg_parser.add_argument(
105 | "--last-island",
106 | type=int,
107 | default=None,
108 | dest="last_island",
109 | help="Last island index on this machine. Default: number of islands minus 1.",
110 | )
111 |
112 | arg_parser.add_argument(
113 | "--seed",
114 | type=int,
115 | default=BASE_SEED,
116 | dest="base_seed",
117 | help=f"Base seed value. Default: {BASE_SEED}. For the i-th reproduction (0-based), the seed will be {BASE_SEED} + i.",
118 | )
119 |
120 | arg_parser.add_argument(
121 | "--override",
122 | action="store_true",
123 | dest="override_existing",
124 | help="Override an existing simulation that has the same hash.",
125 | )
126 |
127 | arg_parser.add_argument(
128 | "--resume",
129 | dest="resumed_simulation_id",
130 | default=None,
131 | help="Resume simulation from latest generations.",
132 | )
133 |
134 | arg_parser.add_argument(
135 | "--corpus-args",
136 | default=None,
137 | dest="corpus_args",
138 | help="json to override default corpus arguments.",
139 | )
140 |
141 | return arg_parser
142 |
143 |
144 | def calculate_symbolic_accuracy(
145 | predicted_probabs: np.ndarray,
146 | target_probabs: np.ndarray,
147 | input_mask: Optional[np.ndarray],
148 | sample_weights: Tuple[int],
149 | plots: bool,
150 | epsilon: float = 0.0,
151 | ) -> Tuple[float, Tuple[int, ...]]:
152 | zero_target_probabs = target_probabs == 0.0
153 |
154 | zero_predicted_probabs = predicted_probabs <= epsilon
155 |
156 | prediction_matches = np.all(
157 | np.equal(zero_predicted_probabs, zero_target_probabs), axis=-1
158 | )
159 |
160 | prediction_matches[~input_mask] = True
161 |
162 | sequence_idxs_with_errors = tuple(np.where(np.any(~prediction_matches, axis=1))[0])
163 | logging.info(f"Sequence idxs with mismatches: {sequence_idxs_with_errors}")
164 |
165 | incorrect_predictions_per_time_step = np.sum(~prediction_matches, axis=0)
166 |
167 | if plots:
168 | from matplotlib import pyplot as plt
169 |
170 | fig, ax = plt.subplots()
171 | ax.set_title("Num prediction mismatches by time step")
172 | ax.bar(
173 | np.arange(len(incorrect_predictions_per_time_step)),
174 | incorrect_predictions_per_time_step,
175 | )
176 | plt.show()
177 |
178 | prediction_matches_without_masked = prediction_matches[input_mask]
179 |
180 | w = np.array(sample_weights).reshape((-1, 1))
181 | weights_repeated = np.matmul(w, np.ones((1, predicted_probabs.shape[1])))
182 | weights_masked = weights_repeated[input_mask]
183 |
184 | prediction_matches_weighted = np.multiply(
185 | prediction_matches_without_masked, weights_masked
186 | )
187 |
188 | symbolic_accuracy = np.sum(prediction_matches_weighted) / np.sum(weights_masked)
189 | return symbolic_accuracy, sequence_idxs_with_errors
190 |
191 |
192 | def plot_probabs(probabs: np.ndarray, input_classes, class_to_label=None):
193 |     from matplotlib import _color_data as matplotlib_color_data
194 | from matplotlib import pyplot as plt
195 |
196 | if probabs.shape[-1] == 1:
197 | # Binary outputs, output is P(1).
198 | probabs_ = np.zeros((probabs.shape[0], 2))
199 | probabs_[:, 0] = (1 - probabs).squeeze()
200 | probabs_[:, 1] = probabs.squeeze()
201 | probabs = probabs_
202 |
203 | masked_timesteps = np.where(corpora.is_masked(probabs))[0]
204 | if len(masked_timesteps):
205 | first_mask_step = masked_timesteps[0]
206 | probabs = probabs[:first_mask_step]
207 |
208 | plt.rc("grid", color="w", linestyle="solid")
209 | if class_to_label is None:
210 | class_to_label = {i: str(i) for i in range(len(input_classes))}
211 | fig, ax = plt.subplots(figsize=(9, 5), dpi=150, facecolor="white")
212 | x = np.arange(probabs.shape[0])
213 | num_classes = probabs.shape[1]
214 | width = 0.8
215 | colors = (
216 |         list(matplotlib_color_data.TABLEAU_COLORS) + list(matplotlib_color_data.XKCD_COLORS)
217 | )[:num_classes]
218 | for c in range(num_classes):
219 | ax.bar(
220 | x,
221 | probabs[:, c],
222 | label=f"P({class_to_label[c]})" if num_classes > 1 else "P(1)",
223 | color=colors[c],
224 | width=width,
225 | bottom=np.sum(probabs[:, :c], axis=-1),
226 | )
227 | ax.set_facecolor("white")
228 | ax.set_xticks(x)
229 | ax.set_xticklabels([class_to_label[x] for x in input_classes], fontsize=13)
230 |
231 | ax.set_xlabel("Input characters", fontsize=15)
232 | ax.set_ylabel("Next character probability", fontsize=15)
233 |
234 | ax.grid(b=True, color="#bcbcbc")
235 | # plt.title("Next step prediction probabilities", fontsize=22)
236 | plt.legend(loc="upper left", fontsize=15)
237 |
238 | # fig.savefig("test.png")
239 | fig.subplots_adjust(bottom=0.1)
240 |
241 | plt.show()
242 | fig.savefig(
243 | f"./figures/net_probabs_{random.randint(0,10_000)}.pdf",
244 | dpi=300,
245 | facecolor="white",
246 | )
247 |
--------------------------------------------------------------------------------
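`kwargs_from_param_grid` above is a small helper for sweeping hyperparameters: it yields one keyword-argument dict per point in the Cartesian product of the grid values. For example:

```
import utils

grid = {"num_hidden_units": (4, 8), "regularization": (None, "L2")}

for kwargs in utils.kwargs_from_param_grid(grid):
    print(kwargs)
# {'num_hidden_units': 4, 'regularization': None}
# {'num_hidden_units': 4, 'regularization': 'L2'}
# {'num_hidden_units': 8, 'regularization': None}
# {'num_hidden_units': 8, 'regularization': 'L2'}
```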
/vanilla_rnn.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from typing import List, Optional, Text, Tuple
3 |
4 | import numpy as np
5 | import torch
6 | from torch import nn, optim
7 | from torch.nn.utils import rnn
8 |
9 | import corpora
10 | import utils
11 |
12 | utils.setup_logging()
13 |
14 |
15 | def _ensure_torch(x) -> torch.Tensor:
16 | if isinstance(x, np.ndarray):
17 | x = torch.from_numpy(x).float()
18 | return x
19 |
20 |
21 | def _get_masked_sequence_lengths(x) -> List[int]:
22 | # Ugly and slow but works.
23 | sequence_lengths = []
24 | for b in range(x.shape[0]):
25 | seq = x[b]
26 | any_masked = False
27 | for i in range(seq.shape[0]):
28 | if np.all(corpora.is_masked(seq[i])):
29 | any_masked = True
30 | sequence_lengths.append(i)
31 | break
32 | if not any_masked:
33 | sequence_lengths.append(seq.shape[0])
34 | return sequence_lengths
35 |
36 |
37 | def _get_packed_sample_weights(
38 | sample_weights, max_sequence_length, masked_sequence_lengths
39 | ):
40 | if sample_weights is None:
41 | return None
42 | w = np.array(sample_weights).reshape((-1, 1))
43 |     weights_repeated = np.matmul(w, np.ones((1, max_sequence_length)))
44 | return _pack_masked(weights_repeated, masked_sequence_lengths).data
45 |
46 |
47 | def _pack_masked(sequences: np.ndarray, sequence_lengths):
48 | return rnn.pack_padded_sequence(
49 | _ensure_torch(sequences),
50 | lengths=sequence_lengths,
51 | batch_first=True,
52 | enforce_sorted=True,
53 | )
54 |
55 |
56 | def _unpack(
57 | packed_tensor,
58 | packed_sequence_batch_sizes,
59 | batch_size,
60 | max_sequence_length,
61 | num_classes,
62 | ):
63 | # Assumes sorted indices.
64 | unpacked = np.empty((batch_size, max_sequence_length, num_classes))
65 | unpacked.fill(corpora.MASK_VALUE)
66 | i = 0
67 |     for t, step_batch_size in enumerate(packed_sequence_batch_sizes):
68 |         unpacked[:step_batch_size, t] = packed_tensor[i : i + step_batch_size]
69 |         i += step_batch_size
70 | return unpacked
71 |
72 |
73 | class VanillaRNN(nn.Module):
74 | def __init__(
75 | self,
76 | num_hidden_units: int,
77 | corpus: corpora.Corpus,
78 | network_type: Text,
79 | num_epochs: int,
80 | regularization: Optional[Text],
81 | regularization_lambda: Optional[float],
82 | ):
83 | super(VanillaRNN, self).__init__()
84 |
85 | self._num_hidden_units = num_hidden_units
86 | self._num_epochs = num_epochs
87 | self._train_corpus = corpus
88 | self._regularization = regularization
89 | self._regularization_lambda = regularization_lambda
90 |
91 | self._input_size = self._train_corpus.input_sequence.shape[-1]
92 | self._output_size = self._train_corpus.target_sequence.shape[-1]
93 |
94 | rnn_kwargs = {
95 | "input_size": self._input_size,
96 | "hidden_size": self._num_hidden_units,
97 | "batch_first": True,
98 | }
99 | rnn_type_to_layer = {"elman": nn.RNN, "lstm": nn.LSTM, "gru": nn.GRU}
100 | rnn_layer = rnn_type_to_layer[network_type]
101 | self._rnn_layer = rnn_layer(**rnn_kwargs)
102 |
103 | layers = [
104 | nn.Linear(
105 | in_features=self._num_hidden_units, out_features=self._output_size,
106 | ),
107 | ]
108 | if self._output_size == 1:
109 | layers.append(nn.Sigmoid())
110 | else:
111 | layers.append(nn.LogSoftmax(dim=-1))
112 | self._layers = nn.Sequential(*layers)
113 |
114 | def _forward(self, x):
115 | rnn_packed_outputs, _ = self._rnn_layer(x)
116 | rnn_outputs = rnn_packed_outputs.data
117 | return self._layers(rnn_outputs)
118 |
119 | def fit(self):
120 | train_max_sequence_length = self._train_corpus.input_sequence.shape[1]
121 | train_sequence_lengths = _get_masked_sequence_lengths(
122 | self._train_corpus.input_sequence
123 | )
124 | train_inputs_packed = _pack_masked(
125 | self._train_corpus.input_sequence, sequence_lengths=train_sequence_lengths
126 | )
127 | train_targets_packed = _pack_masked(
128 | self._train_corpus.target_sequence, sequence_lengths=train_sequence_lengths
129 | ).data
130 |
131 | train_sample_weights_packed = _get_packed_sample_weights(
132 | self._train_corpus.sample_weights,
133 | max_sequence_length=train_max_sequence_length,
134 | masked_sequence_lengths=train_sequence_lengths,
135 | )
136 |
137 | optimizer = optim.Adam(self.parameters(), lr=0.001)
138 |
139 | for epoch in range(self._num_epochs):
140 | optimizer.zero_grad()
141 | output = self._forward(train_inputs_packed)
142 | cross_entropy_loss, _ = _calculate_loss(
143 | net=self,
144 | outputs_packed=output,
145 | targets_packed=train_targets_packed,
146 | sample_weights=train_sample_weights_packed,
147 | regularization=self._regularization,
148 | regularization_lambda=self._regularization_lambda,
149 | )
150 | cross_entropy_loss.backward()
151 | optimizer.step()
152 |
153 | if epoch % 10 == 0:
154 | logging.info(
155 | f"Epoch {epoch} training loss: {cross_entropy_loss.item():.3e}"
156 | )
157 |
158 | def feed_sequence(self, input_sequence):
159 | with torch.no_grad():
160 | return self._forward(input_sequence)
161 |
162 |
163 | def _calculate_loss(
164 | net: VanillaRNN,
165 | outputs_packed,
166 | targets_packed,
167 | sample_weights,
168 | regularization,
169 | regularization_lambda,
170 | ):
171 | if targets_packed.shape[-1] == 1:
172 | loss_func = nn.BCELoss
173 | target_classes = targets_packed
174 | else:
175 | # Not using `CrossEntropyLoss` because network outputs are already log-softmaxed.
176 | loss_func = nn.NLLLoss
177 | target_classes = targets_packed.argmax(axis=-1)
178 |
179 | non_reduced_loss = loss_func(reduction="none")(outputs_packed, target_classes)
180 |
181 | if sample_weights is not None:
182 | weighted_losses = torch.mul(non_reduced_loss, sample_weights)
183 | weighted_losses_sum = weighted_losses.sum()
184 | total_chars_in_input = sample_weights.sum()
185 | average_loss = weighted_losses_sum / total_chars_in_input
186 | else:
187 | average_loss = non_reduced_loss.mean()
188 | weighted_losses_sum = non_reduced_loss.sum()
189 |
190 | regularized_loss = 0
191 | if regularization == "L1":
192 | for p in net._rnn_layer.parameters():
193 | regularized_loss += torch.sum(torch.abs(p))
194 | for p in net._layers.parameters():
195 | regularized_loss += torch.sum(torch.abs(p))
196 |
197 | elif regularization == "L2":
198 | for p in net._rnn_layer.parameters():
199 | regularized_loss += torch.sum(torch.square(p))
200 | for p in net._layers.parameters():
201 | regularized_loss += torch.sum(torch.square(p))
202 |     # Apply the regularization term for both L1 and L2 (no-op when regularization is None).
203 |     average_loss = average_loss + ((regularization_lambda or 0.0) * regularized_loss)
204 |
205 | return average_loss, weighted_losses_sum
206 |
207 |
208 | def calculate_symbolic_accuracy(
209 | found_net: VanillaRNN,
210 | inputs: np.ndarray,
211 | target_probabs: np.ndarray,
212 | sample_weights: Tuple[int, ...],
213 | input_mask: np.ndarray,
214 | epsilon: float,
215 | plots: bool = False,
216 | ):
217 | sequence_lengths = _get_masked_sequence_lengths(inputs)
218 |
219 | inputs_packed = _pack_masked(inputs, sequence_lengths)
220 | predicted_probabs_packed = found_net.feed_sequence(inputs_packed)
221 | predicted_probabs = _unpack(
222 | packed_tensor=predicted_probabs_packed,
223 | packed_sequence_batch_sizes=inputs_packed.batch_sizes,
224 | batch_size=inputs.shape[0],
225 | max_sequence_length=inputs.shape[1],
226 | num_classes=target_probabs.shape[-1],
227 | )
228 | predicted_probabs = np.exp(predicted_probabs)
229 |
230 | return utils.calculate_symbolic_accuracy(
231 | predicted_probabs=predicted_probabs,
232 | target_probabs=target_probabs,
233 | input_mask=input_mask,
234 | plots=plots,
235 | sample_weights=sample_weights,
236 | epsilon=epsilon,
237 | )
238 |
239 |
240 | def evaluate(
241 | net,
242 | inputs,
243 | targets,
244 | sample_weights,
245 | deterministic_steps_mask,
246 | regularization,
247 | regularization_lambda,
248 | ):
249 | sequence_lengths = _get_masked_sequence_lengths(inputs)
250 |
251 | inputs_packed = _pack_masked(inputs, sequence_lengths)
252 | targets_packed = _pack_masked(targets, sequence_lengths).data
253 |
254 | sample_weights_packed = _get_packed_sample_weights(
255 | sample_weights,
256 | max_sequence_length=inputs.shape[1],
257 | masked_sequence_lengths=sequence_lengths,
258 | )
259 |
260 | y_pred = net.feed_sequence(inputs_packed)
261 |
262 | cross_entropy_loss, cross_entropy_sum = _calculate_loss(
263 | net,
264 | y_pred,
265 | targets_packed,
266 | sample_weights_packed,
267 | regularization,
268 | regularization_lambda,
269 | )
270 |
271 | if targets.shape[-1] == 1:
272 | target_classes = targets_packed.flatten()
273 | predicted_classes = (y_pred > 0.5).flatten().long()
274 | else:
275 | target_classes = targets_packed.argmax(dim=-1).flatten()
276 | predicted_classes = y_pred.argmax(dim=-1).flatten()
277 |
278 | correct = torch.sum(torch.eq(predicted_classes, target_classes)).item()
279 | accuracy = correct / len(target_classes)
280 |
281 | if deterministic_steps_mask is not None:
282 | deterministic_mask_packed = _pack_masked(
283 | deterministic_steps_mask, sequence_lengths
284 | ).data.bool()
285 | det_target_classes = target_classes[deterministic_mask_packed]
286 | det_correct = torch.eq(
287 | predicted_classes[deterministic_mask_packed], det_target_classes
288 | )
289 | det_flat_sample_weights = sample_weights_packed[deterministic_mask_packed]
290 | det_correct_weighted = torch.mul(det_correct.int(), det_flat_sample_weights)
291 | det_accuracy = (
292 | f"{det_correct_weighted.sum() / det_flat_sample_weights.sum().item():.5f}"
293 | )
294 | else:
295 | det_accuracy = None
296 |
297 | logging.info(
298 | f"Accuracy: {accuracy:.5f} (Correct: {correct} / {len(target_classes)})\n"
299 | f"Deterministic accuracy: {det_accuracy}\n"
300 | f"Cross-entropy loss: {cross_entropy_loss:.2f}\n"
301 | f"Cross-entropy sum: {cross_entropy_sum:.2f}"
302 | )
303 | return (
304 | accuracy,
305 | cross_entropy_loss.item(),
306 | cross_entropy_sum.item(),
307 | det_accuracy,
308 | )
309 |
--------------------------------------------------------------------------------
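A minimal sketch of training one of the baseline RNNs above on a corpus from `corpora.py` (hyperparameter values here are illustrative, not taken from the paper's grid):

```
import corpora
import vanilla_rnn

corpus = corpora.make_ain_bjn_ckn_dtn(
    batch_size=100, prior=0.3, multipliers=(1, 1, 0, 0)
)

rnn = vanilla_rnn.VanillaRNN(
    num_hidden_units=8,
    corpus=corpus,
    network_type="gru",
    num_epochs=1_000,
    regularization="L2",
    regularization_lambda=0.01,
)
rnn.fit()
```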