├── .gitignore ├── README.md ├── cell.py ├── docs ├── wfc.gif └── wfc_voxel.gif ├── examples ├── midi.py └── voxel.py ├── grid.py ├── main.py ├── pattern.py ├── propagator.py ├── requirements.txt ├── samples ├── blue.png ├── line.png ├── maze.png ├── red_maze.png └── twinkle_twinkle.mid └── wfc.py /.gitignore: -------------------------------------------------------------------------------- 1 | .idea -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # wave-function-collapse 2 | 3 | [Wave function collapse](https://github.com/mxgmn/WaveFunctionCollapse) implementation in Python. 4 | It supports 1D, 2D and 3D samples. 5 | 6 | ![wfc_example](./docs/wfc.gif) 7 | 8 | ## Installation 9 | ``` 10 | pip install -r requirements.txt 11 | ``` 12 | [py-vox-io](https://github.com/gromgull/py-vox-io) is used by `examples/voxel.py` to load MagicaVoxel files into numpy arrays; it is not listed in `requirements.txt`, so install it separately. 13 | 14 | ## Usage 15 | See `main.py` for a usage example. 
# ------------------------------------------------------------------------------
# /README.md (continued)
#
# ## Examples
#
# ### Midi file
# See `examples/midi.py`
#
# ### Voxel
# See `examples/voxel.py`
#
# ![wfc_example](./docs/wfc_voxel.gif)
# ------------------------------------------------------------------------------
# /cell.py
# ------------------------------------------------------------------------------
import numpy as np

from pattern import Pattern


class Cell:
    """
    Cell is a pixel or tile (in 2d) that stores the possible patterns.

    ``allowed_patterns`` holds the indices of the patterns still possible at
    ``position``; the cell is "stable" (collapsed) once exactly one remains.
    """

    # (z, y, x) offsets of the 3x3x3 block around a cell. Built once and shared
    # by every cell instead of being rebuilt in each __init__.
    # NOTE(review): this includes (0, 0, 0), so get_neighbors() reports the cell
    # itself. Kept for behavioural compatibility with Propagator.propagate, which
    # re-queues every neighbour returned here after a cell loses a pattern.
    OFFSETS = tuple((z, y, x) for x in range(-1, 2) for y in range(-1, 2) for z in range(-1, 2))

    def __init__(self, num_pattern, position, grid):
        """
        :param num_pattern: total number of patterns; all are allowed initially
        :param position: index tuple of this cell inside ``grid``
        :param grid: owning Grid (must expose ``size`` and ``get_cell``)
        """
        self.num_pattern = num_pattern
        self.allowed_patterns = list(range(self.num_pattern))

        self.position = position
        self.grid = grid
        self.offsets = Cell.OFFSETS

    def entropy(self):
        """Return the number of patterns still allowed in this cell."""
        return len(self.allowed_patterns)

    def choose_rnd_pattern(self):
        """Collapse the cell to one of its allowed patterns, chosen uniformly with np.random."""
        chosen_index = np.random.randint(len(self.allowed_patterns))
        self.allowed_patterns = [self.allowed_patterns[chosen_index]]

    def is_stable(self):
        """Return True when exactly one pattern is allowed."""
        return len(self.allowed_patterns) == 1

    def get_value(self):
        """
        Return the first element of the collapsed pattern's data (a color
        index), or -1 when the cell is not collapsed yet.
        """
        if self.is_stable():
            pattern = Pattern.from_index(self.allowed_patterns[0])
            return pattern.get()
        return -1

    def get_neighbors(self):
        """
        Return ``(cell, offset)`` pairs for every offset whose target position
        lies inside the grid bounds.
        """
        neighbors = []
        for offset in self.offsets:
            neighbor_pos = tuple(p + o for p, o in zip(self.position, offset))
            # Keep only positions with 0 <= coordinate < size on every axis.
            if all(0 <= d < size for d, size in zip(neighbor_pos, self.grid.size)):
                neighbors.append((self.grid.get_cell(neighbor_pos), offset))

        return neighbors
# ------------------------------------------------------------------------------
# /docs/wfc.gif
# ------------------------------------------------------------------------------
# (asset URL for /docs/wfc.gif)
#     https://raw.githubusercontent.com/Coac/wave-function-collapse/46c8ef49bf5a38399ee3308f8f5d8d301993a21b/docs/wfc.gif
# ------------------------------------------------------------------------------
# /docs/wfc_voxel.gif
#     https://raw.githubusercontent.com/Coac/wave-function-collapse/46c8ef49bf5a38399ee3308f8f5d8d301993a21b/docs/wfc_voxel.gif
# ------------------------------------------------------------------------------
# /examples/midi.py
# ------------------------------------------------------------------------------
"""
An example of using the wave function collapse with midi file.
It uses the mido library to load and export midi files
You should install mido using `pip install mido`
"""

import numpy as np
from mido import MidiFile, MidiTrack, Message, second2tick, MetaMessage
from mido.midifiles.midifiles import DEFAULT_TEMPO

from wfc import WaveFunctionCollapse

# Tempo of the loaded sample: overwritten by load_midi_sample when the file has
# a 'set_tempo' meta message, then reused by export_midi.
TEMPO = DEFAULT_TEMPO


def load_midi_sample(path):
    """
    Load the channel-0 note_on messages of a midi file as a WFC sample.

    Each kept message becomes a row made of ``msg.bytes()`` followed by the
    time (seconds, as yielded when iterating a MidiFile) elapsed since the
    previously kept note.

    :param path: path of the midi file
    :return: (sample, ticks_per_beat); sample has two leading singleton axes
             so that the notes form a 1D grid along the third axis
    :raises ValueError: if no channel-0 note_on message is found
    """
    global TEMPO

    midi_file = MidiFile(path)
    notes = []
    time = 0.0
    prev_time = 0.0
    for msg in midi_file:
        time += msg.time
        if msg.is_meta:
            if msg.type == 'set_tempo':
                TEMPO = msg.tempo
        # Some non-meta messages (e.g. sysex) have no `channel` attribute;
        # accessing it directly would raise AttributeError.
        elif getattr(msg, 'channel', None) == 0:
            # TODO note_off
            if msg.type == 'note_on':
                note = msg.bytes()
                note.append(time - prev_time)
                prev_time = time
                notes.append(note)

    if not notes:
        raise ValueError('no channel 0 note_on message found in %s' % path)

    notes = np.array(notes)

    notes = np.expand_dims(notes, axis=0)
    notes = np.expand_dims(notes, axis=0)

    return notes, midi_file.ticks_per_beat


def export_midi(notes, path, ticks_per_beat):
    """
    Write notes produced by WFC to a single-track midi file.

    :param notes: array with two leading singleton axes whose rows are
                  [byte0, byte1, byte2, delta_seconds], as built by load_midi_sample
    :param path: output path
    :param ticks_per_beat: resolution used to convert seconds to ticks
    """
    notes = np.squeeze(notes, axis=0)
    notes = np.squeeze(notes, axis=0)

    midi_file = MidiFile()
    midi_file.ticks_per_beat = ticks_per_beat
    track = MidiTrack()
    track.append(MetaMessage('set_tempo', tempo=TEMPO))
    midi_file.tracks.append(track)
    for note in notes:
        # Convert the numpy floats to plain ints (truncation, like astype(int)).
        note_bytes = [int(b) for b in note[0:3]]
        msg = Message.from_bytes(note_bytes)
        msg.time = int(second2tick(note[3], ticks_per_beat, TEMPO))
        track.append(msg)

    print(midi_file)
    midi_file.save(path)


if __name__ == '__main__':
    np.random.seed(42)

    grid_size = (1, 1, 100)
    pattern_size = (1, 1, 2)

    sample, ticks_per_beat = load_midi_sample('../samples/twinkle_twinkle.mid')

    print('sample shape:', sample.shape)

    wfc = WaveFunctionCollapse(grid_size, sample, pattern_size)
    wfc.run()

    notes = wfc.get_image()

    export_midi(notes, '../samples/output.mid', ticks_per_beat)

# ------------------------------------------------------------------------------
# /examples/voxel.py (module docstring, first line)
#     An example of using the wave function collapse with 3D voxel file.
# /examples/voxel.py (module docstring, continued)
#     It loads a magica voxel file using py-vox-io
#     You should install py-vox-io using `pip install py-vox-io`
import numpy as np
from pyvox.models import Vox
from pyvox.parser import VoxParser
from pyvox.writer import VoxWriter

from wfc import WaveFunctionCollapse


def load_voxel_sample(path):
    """
    Load a MagicaVoxel file as a WFC sample.

    :param path: path of the .vox file
    :return: the dense voxel array with a new axis inserted at position 3
             (presumably trailing for a 3D dense array — each voxel value is
             then a one-component "color")
    """
    vox_parser = VoxParser(path).parse()
    sample = vox_parser.to_dense()
    sample = np.expand_dims(sample, axis=3)
    return sample


def export_voxel(path, image):
    """
    Write a WFC result to a MagicaVoxel file.

    :param path: output path
    :param image: array with a singleton axis 3, e.g. WaveFunctionCollapse.get_image()
    """
    print("image shape:", image.shape)
    # Use the argument itself: reading a module-level `wfc` here only works
    # when this file is run as a script.
    image = np.squeeze(image, axis=3)
    image = image.astype(int)
    vox = Vox.from_dense(image)
    VoxWriter(path, vox).write()


if __name__ == '__main__':
    np.random.seed(42)

    grid_size = (6, 6, 6)
    pattern_size = (2, 2, 2)

    sample = load_voxel_sample('../samples/test.vox')

    wfc = WaveFunctionCollapse(grid_size, sample, pattern_size)

    wfc.run()

    image = wfc.get_image()

    export_voxel('../samples/output.vox', image)

# ------------------------------------------------------------------------------
# /grid.py
# ------------------------------------------------------------------------------
import numpy as np

from cell import Cell
from pattern import Pattern


class Grid:
    """
    Grid is made of Cells: ``self.grid`` is a numpy object array of shape
    ``size`` holding one Cell per position.
    """

    def __init__(self, size, num_pattern):
        self.size = size
        self.grid = np.empty(self.size, dtype=object)
        for position in np.ndindex(self.size):
            self.grid[position] = Cell(num_pattern, position, self)

    def find_lowest_entropy(self):
        """
        Pick one of the non-stable cells with the fewest allowed patterns.

        Ties are broken uniformly at random with ``np.random``.

        :return: the chosen Cell, or None when every cell is stable
        """
        # No arbitrary cap on the number of patterns (was 999999).
        min_entropy = float('inf')
        lowest_entropy_cells = []
        for cell in self.grid.flat:
            if cell.is_stable():
                continue

            entropy = cell.entropy()

            if entropy == min_entropy:
                lowest_entropy_cells.append(cell)
            elif entropy < min_entropy:
                min_entropy = entropy
                lowest_entropy_cells = [cell]

        if not lowest_entropy_cells:
            return None
        return lowest_entropy_cells[np.random.randint(len(lowest_entropy_cells))]

    def get_cell(self, index):
        """
        Returns the cell contained in the grid at the provided index
        :param index: (...z, y, x)
        :return: cell
        """
        return self.grid[index]

    def get_image(self):
        """
        Returns the grid converted from pattern indexes back to colors
        (undecided cells are rendered grey by Pattern.index_to_img).
        :return: color array of shape grid.size + (color length,)
        """
        image = np.vectorize(lambda c: c.get_value())(self.grid)
        image = Pattern.index_to_img(image)
        return image

    def check_contradiction(self):
        """Return True if some cell has no allowed pattern left."""
        return any(len(cell.allowed_patterns) == 0 for cell in self.grid.flat)

    def print_allowed_pattern_count(self):
        """Print the per-cell number of allowed patterns (debug helper)."""
        grid_allowed_patterns = np.vectorize(lambda c: len(c.allowed_patterns))(self.grid)
        print(grid_allowed_patterns)

# ------------------------------------------------------------------------------
# /main.py (module docstring, first line)
#     An example of using the wave function collapse with 2D image.
# /main.py (continued: end of the module docstring)
import matplotlib.pyplot as plt
import numpy as np

from wfc import WaveFunctionCollapse


def plot_patterns(patterns, title=''):
    """Show up to 20 pattern images in a 5x4 matplotlib figure."""
    fig = plt.figure(figsize=(8, 8))
    fig.suptitle(title, fontsize=16)
    columns = 4
    rows = 5
    for i in range(1, columns * rows + 1):
        if i > len(patterns):
            break
        fig.add_subplot(rows, columns, i)
        show(patterns[i - 1])

    plt.show()


def load_sample(path):
    """
    Load an image as a WFC sample with a leading singleton axis and at most
    3 color channels (assumes an RGB/RGBA image — TODO confirm for grayscale).
    """
    sample = plt.imread(path)
    # Expand dim to 3D
    sample = np.expand_dims(sample, axis=0)
    sample = sample[:, :, :, :3]

    return sample


def show(image):
    """
    Display an image whose first axis is a singleton; returns the AxesImage,
    or None (nothing drawn) when the first axis is larger than 1.
    """
    if image.shape[0] == 1:
        return plt.imshow(np.squeeze(image, axis=0))


if __name__ == '__main__':

    grid_size = (1, 30, 30)
    pattern_size = (1, 2, 2)

    sample = load_sample('samples/blue.png')
    show(sample)
    plt.show()

    wfc = WaveFunctionCollapse(grid_size, sample, pattern_size)
    plot_patterns(wfc.get_patterns(), 'patterns')

    # _, _, legal_patterns = wfc.propagator.legal_patterns(wfc.patterns[2], (0, 0, 1))
    # show(Pattern.from_index(2).to_image())
    # plt.show()
    # plot_patterns([Pattern.from_index(i).to_image() for i in legal_patterns])

    fig, ax = plt.subplots()
    image = wfc.get_image()
    im = show(image)
    while True:
        done = wfc.step()
        if done:
            break
        image = wfc.get_image()

        if image.shape[0] == 1:
            image = np.squeeze(image, axis=0)
        im.set_array(image)

        fig.canvas.draw()
        plt.pause(0.001)

    plt.show()

# ------------------------------------------------------------------------------
# /pattern.py
# ------------------------------------------------------------------------------
import numpy as np


class Pattern:
    """
    Pattern is a configuration of tiles from the input image.

    Class-level registries shared by all instances:
    - index_to_pattern: pattern index -> Pattern (filled by from_sample)
    - color_to_index / index_to_color: sample color tuple <-> integer id
      (filled by sample_img_to_indexes)
    """
    index_to_pattern = {}
    color_to_index = {}
    index_to_color = {}

    def __init__(self, data, index):
        self.index = index
        self.data = np.array(data)
        self.legal_patterns_index = {}  # offset -> [pattern_index]

    def get(self, index=None):
        """Return the element at ``index``, or the first element when index is None."""
        if index is None:
            return self.data.item(0)
        return self.data[index]

    def set_legal_patterns(self, offset, legal_patterns):
        """Cache the indices of patterns compatible with this one at ``offset``."""
        self.legal_patterns_index[offset] = legal_patterns

    @property
    def shape(self):
        return self.data.shape

    def is_compatible(self, candidate_pattern, offset):
        """
        Check if pattern is compatible with a candidate pattern for a given offset,
        i.e. whether both agree on every element of their overlap when the
        candidate is shifted by ``offset``.
        :param candidate_pattern: Pattern of the same shape
        :param offset: shift of the candidate relative to this pattern, per axis
        :return: True if compatible (also True when the patterns do not overlap)
        """
        assert (self.shape == candidate_pattern.shape)

        # Precomputed compatibility
        if offset in self.legal_patterns_index:
            return candidate_pattern.index in self.legal_patterns_index[offset]

        # Computing compatibility: slice the overlap region in both patterns and
        # compare them as whole arrays instead of element by element.
        self_region = []
        candidate_region = []
        for size, shift in zip(self.shape, offset):
            start = max(shift, 0)
            end = min(size + shift, size)
            if start >= end:
                # Empty overlap on this axis: nothing can conflict. (Also avoids a
                # negative end, which slicing would treat as "from the end".)
                return True
            self_region.append(slice(start, end))
            candidate_region.append(slice(start - shift, end - shift))

        return bool(np.array_equal(self.data[tuple(self_region)],
                                   candidate_pattern.data[tuple(candidate_region)]))

    def to_image(self):
        return Pattern.index_to_img(self.data)

    @staticmethod
    def from_sample(sample, pattern_size):
        """
        Compute patterns from sample: every pattern_size window fully inside the
        sample, plus flipped/rotated variants, de-duplicated and registered in
        Pattern.index_to_pattern.
        :param pattern_size: window size per axis
        :param sample: color sample (last axis holds the color components)
        :return: list of patterns, where patterns[i].index == i
        """

        sample = Pattern.sample_img_to_indexes(sample)
        # Start a fresh registry, like sample_img_to_indexes does for colors, so
        # patterns from a previous sample cannot linger.
        Pattern.index_to_pattern = {}

        shape = sample.shape
        patterns = []
        pattern_index = 0

        for index in np.ndindex(shape):
            # Skip positions where the window would leave the sample.
            if any(d > shape[i] - pattern_size[i] for i, d in enumerate(index)):
                continue

            pattern_location = [range(d, pattern_size[i] + d) for i, d in enumerate(index)]
            pattern_data = sample[np.ix_(*pattern_location)]

            datas = [pattern_data, np.fliplr(pattern_data)]
            if shape[1] > 1:  # is 2D
                datas.append(np.flipud(pattern_data))
                datas.append(np.rot90(pattern_data, axes=(1, 2)))
                datas.append(np.rot90(pattern_data, 2, axes=(1, 2)))
                datas.append(np.rot90(pattern_data, 3, axes=(1, 2)))

            if shape[0] > 1:  # is 3D
                datas.append(np.flipud(pattern_data))
                datas.append(np.rot90(pattern_data, axes=(0, 2)))
                datas.append(np.rot90(pattern_data, 2, axes=(0, 2)))
                datas.append(np.rot90(pattern_data, 3, axes=(0, 2)))

            # Checking existence. array_equal (unlike `==`) never broadcasts
            # arrays of different shapes into a spurious match.
            # TODO: more probability to multiple occurrences when observe phase
            for data in datas:
                if any(np.array_equal(p.data, data) for p in patterns):
                    continue

                pattern = Pattern(data, pattern_index)
                patterns.append(pattern)
                Pattern.index_to_pattern[pattern_index] = pattern
                pattern_index += 1

        return patterns

    @staticmethod
    def sample_img_to_indexes(sample):
        """
        Convert a rgb image to a 2D array with pixel index
        (each distinct color along the last axis gets an integer id).
        :param sample:
        :return: pixel index sample (float array, sample.shape[:-1])
        """
        Pattern.color_to_index = {}
        Pattern.index_to_color = {}
        sample_index = np.zeros(sample.shape[:-1])  # without last rgb dim
        color_number = 0
        for index in np.ndindex(sample.shape[:-1]):
            color = tuple(sample[index])
            if color not in Pattern.color_to_index:
                Pattern.color_to_index[color] = color_number
                Pattern.index_to_color[color_number] = color
                color_number += 1

            sample_index[index] = Pattern.color_to_index[color]

        print('Unique color count = ', color_number)
        return sample_index

    @staticmethod
    def index_to_img(sample):
        """Map color ids back to colors; -1 (undecided) becomes 0.5 grey."""
        color = next(iter(Pattern.index_to_color.values()))

        image = np.zeros(sample.shape + (len(color),))
        for index in np.ndindex(sample.shape):
            pattern_index = sample[index]
            if pattern_index == -1:
                image[index] = [0.5 for _ in range(len(color))]  # Grey
            else:
                image[index] = Pattern.index_to_color[pattern_index]
        return image

    @staticmethod
    def from_index(pattern_index):
        return Pattern.index_to_pattern[pattern_index]

# ------------------------------------------------------------------------------
# /propagator.py
# ------------------------------------------------------------------------------
import os
import time
from collections import deque
from multiprocessing import Pool

from pattern import Pattern


class Propagator:
    """
    Propagator that computes and stores the legal patterns relative to another
    """

    def __init__(self, patterns):
        self.patterns = patterns
        self.offsets = [(z, y, x) for x in range(-1, 2) for y in range(-1, 2) for z in range(-1, 2)]

        start_time = time.time()
        self.precompute_legal_patterns()
        print("Patterns constraints generation took %s seconds" % (time.time() - start_time))

    def precompute_legal_patterns(self):
        """
        Compute, in a process pool, the compatible patterns of every
        (pattern, offset) pair and store them on the patterns of this process.
        """
        patterns_offsets = [(pattern, offset) for pattern in self.patterns for offset in self.offsets]

        # The context manager shuts the workers down even if starmap raises.
        with Pool(os.cpu_count()) as pool:
            patterns_compatibility = pool.starmap(self.legal_patterns, patterns_offsets)

        # Worker processes updated their own copies; apply results here.
        for pattern_index, offset, legal_patterns in patterns_compatibility:
            self.patterns[pattern_index].set_legal_patterns(offset, legal_patterns)

    def legal_patterns(self, pattern, offset):
        """Return (pattern.index, offset, indices of patterns compatible at offset)."""
        legal_patt = [candidate.index for candidate in self.patterns
                      if pattern.is_compatible(candidate, offset)]
        pattern.set_legal_patterns(offset, legal_patt)

        return pattern.index, offset, legal_patt

    @staticmethod
    def propagate(cell):
        """
        Propagate constraints after ``cell`` changed.

        Cells are processed in FIFO order, starting with the neighbours of
        ``cell``. For each processed cell and each of its neighbours at
        ``offset``, a pattern stays allowed only if at least one pattern still
        allowed in that neighbour is compatible with it at that offset. A cell
        that loses patterns re-queues its neighbours (each cell is queued at
        most once at a time) until nothing changes.
        """
        to_update = deque(neighbour for neighbour, _ in cell.get_neighbors())
        queued = set(to_update)
        while to_update:
            current = to_update.popleft()
            queued.discard(current)

            changed = False
            for neighbour, offset in current.get_neighbors():
                # Snapshot first: the neighbour can be `current` itself
                # (Cell offsets include (0, 0, 0)).
                neighbour_patterns = [Pattern.from_index(i) for i in neighbour.allowed_patterns]
                still_allowed = [
                    pattern_index
                    for pattern_index in current.allowed_patterns
                    if any(Pattern.from_index(pattern_index).is_compatible(candidate, offset)
                           for candidate in neighbour_patterns)
                ]
                if len(still_allowed) != len(current.allowed_patterns):
                    # Filter into a new list instead of calling remove() while
                    # iterating, which skipped the element after each removal.
                    current.allowed_patterns[:] = still_allowed
                    changed = True

            if changed:
                for neigh, _ in current.get_neighbors():
                    if neigh not in queued:
                        queued.add(neigh)
                        to_update.append(neigh)
# ------------------------------------------------------------------------------
# /requirements.txt
#     matplotlib
#     numpy
# ------------------------------------------------------------------------------
# /samples/blue.png
#     https://raw.githubusercontent.com/Coac/wave-function-collapse/46c8ef49bf5a38399ee3308f8f5d8d301993a21b/samples/blue.png
# /samples/line.png
#     https://raw.githubusercontent.com/Coac/wave-function-collapse/46c8ef49bf5a38399ee3308f8f5d8d301993a21b/samples/line.png
# /samples/maze.png
#     https://raw.githubusercontent.com/Coac/wave-function-collapse/46c8ef49bf5a38399ee3308f8f5d8d301993a21b/samples/maze.png
# /samples/red_maze.png
#     https://raw.githubusercontent.com/Coac/wave-function-collapse/46c8ef49bf5a38399ee3308f8f5d8d301993a21b/samples/red_maze.png
# /samples/twinkle_twinkle.mid
#     https://raw.githubusercontent.com/Coac/wave-function-collapse/46c8ef49bf5a38399ee3308f8f5d8d301993a21b/samples/twinkle_twinkle.mid
# ------------------------------------------------------------------------------
# /wfc.py
# ------------------------------------------------------------------------------
"""
Implementation of WaveFunctionCollapse
Following the "WaveFunctionCollapse is Constraint Solving in the Wild" terminology
"""

import time

from grid import Grid
from pattern import Pattern
from propagator import Propagator


class WaveFunctionCollapse:
    """
    WaveFunctionCollapse encapsulates the wfc algorithm
    """

    def __init__(self, grid_size, sample, pattern_size):
        """
        :param grid_size: shape of the output grid
        :param sample: color sample (last axis holds the color components)
        :param pattern_size: pattern window size per axis
        """
        self.patterns = Pattern.from_sample(sample, pattern_size)
        self.grid = self._create_grid(grid_size)
        self.propagator = Propagator(self.patterns)

    def run(self):
        """
        Run observe/propagate steps until the grid is fully collapsed or a
        contradiction stops the algorithm.

        :return: True if every cell collapsed to a single pattern, False if the
                 run ended because some cell has no allowed pattern left
        """
        start_time = time.time()

        done = False
        while not done:
            done = self.step()

        print("WFC run took %s seconds" % (time.time() - start_time))

        # step() also reports "done" on a contradiction; tell the two apart.
        contradiction = self.grid.check_contradiction()
        if contradiction:
            print("WFC stopped on a contradiction: some cells have no allowed pattern")
        return not contradiction

    def step(self):
        """
        Perform one observe + propagate iteration (prints the per-cell pattern
        counts first).

        :return: True when there is nothing left to do (all cells stable or a
                 contradiction was found), False otherwise
        """
        self.grid.print_allowed_pattern_count()
        cell = self.observe()
        if cell is None:
            return True
        self.propagate(cell)
        return False

    def get_image(self):
        return self.grid.get_image()

    def get_patterns(self):
        return [pattern.to_image() for pattern in self.patterns]

    def observe(self):
        """
        Collapse one lowest-entropy cell to a random allowed pattern.

        :return: the collapsed Cell, or None on contradiction or when every
                 cell is already stable
        """
        if self.grid.check_contradiction():
            return None
        cell = self.grid.find_lowest_entropy()

        if cell is None:
            return None

        cell.choose_rnd_pattern()

        return cell

    def propagate(self, cell):
        self.propagator.propagate(cell)

    def _create_grid(self, grid_size):
        num_pattern = len(self.patterns)
        return Grid(grid_size, num_pattern)
# ------------------------------------------------------------------------------