├── shots └── .keep ├── .gitignore ├── tile.png ├── tiles ├── t.png ├── bridge.png ├── corner.png ├── dskew.png ├── skew.png ├── track.png ├── turn.png ├── viad.png ├── vias.png ├── wire.png ├── component.png ├── connection.png ├── substrate.png ├── transition.png └── data.xml ├── main.py ├── gpWFC ├── runners.py ├── models.py ├── observers.py ├── previews.py └── propagators.py ├── circuit.py └── README.md /shots/.keep: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | __pycache__ 2 | shots/*.png 3 | shots/*.mp4 4 | -------------------------------------------------------------------------------- /tile.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tile.png -------------------------------------------------------------------------------- /tiles/t.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/t.png -------------------------------------------------------------------------------- /tiles/bridge.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/bridge.png -------------------------------------------------------------------------------- /tiles/corner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/corner.png -------------------------------------------------------------------------------- /tiles/dskew.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/dskew.png 
-------------------------------------------------------------------------------- /tiles/skew.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/skew.png -------------------------------------------------------------------------------- /tiles/track.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/track.png -------------------------------------------------------------------------------- /tiles/turn.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/turn.png -------------------------------------------------------------------------------- /tiles/viad.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/viad.png -------------------------------------------------------------------------------- /tiles/vias.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/vias.png -------------------------------------------------------------------------------- /tiles/wire.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/wire.png -------------------------------------------------------------------------------- /tiles/component.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/component.png -------------------------------------------------------------------------------- /tiles/connection.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/connection.png 
-------------------------------------------------------------------------------- /tiles/substrate.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/substrate.png -------------------------------------------------------------------------------- /tiles/transition.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/s-ol/gpWFC/HEAD/tiles/transition.png -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | from gpWFC.models import Model2d, Model3d, Tile 4 | from gpWFC.observers import CLObserver 5 | from gpWFC.propagators import CPUPropagator, CL1Propagator 6 | from gpWFC.previews import PreviewWindow, PreviewWindow3d 7 | from gpWFC.runners import BacktrackingRunner 8 | 9 | if __name__ == '__main__': 10 | import sys 11 | 12 | if '3d' in sys.argv[1:]: 13 | model = Model3d((4, 4, 2)) 14 | model.add(Tile((0, 1, 1, 0, 1, 0))) # all green 15 | model.add(Tile((2, 0, 0, 2, 0, 1))) # all green 16 | else: 17 | model = Model2d((8, 8)) 18 | adjs = [0, 1, 2] 19 | for adj in np.stack(np.meshgrid(adjs, adjs, adjs, adjs), -1).reshape(-1, 4): 20 | bins = np.bincount(adj, minlength=3) 21 | if bins[0] % 2 == 1: 22 | continue 23 | if bins[1] % 2 == 1: 24 | continue 25 | # if bins[2] % 2 == 1: 26 | # continue 27 | model.add(Tile(adj)) 28 | 29 | print('{} tiles:'.format(len(model.tiles))) 30 | 31 | Propagator = CL1Propagator 32 | if 'cpu' in sys.argv[1:]: 33 | Propagator = CPUPropagator 34 | 35 | runner = BacktrackingRunner(model, Observer=CLObserver, Propagator=Propagator) 36 | 37 | if 'silent' in sys.argv[1:]: 38 | from timeit import default_timer 39 | 40 | start = default_timer() 41 | status = runner.finish() 42 | print('{} after {}s'.format(status, default_timer() - start)) 43 | else: 
44 | Preview = PreviewWindow 45 | if '3d' in sys.argv[1:]: 46 | Preview = PreviewWindow3d 47 | 48 | preview = Preview(runner) 49 | if 'render' in sys.argv[1:]: 50 | preview.render() 51 | else: 52 | preview.launch() 53 | -------------------------------------------------------------------------------- /gpWFC/runners.py: -------------------------------------------------------------------------------- 1 | from pyopencl import create_some_context, CommandQueue 2 | from pyopencl.array import to_device 3 | from .observers import CLObserver 4 | from .propagators import CL1Propagator 5 | 6 | class Runner(object): 7 | def __init__(self, model, Observer=CLObserver, Propagator=CL1Propagator, ctx=None): 8 | if not ctx: 9 | ctx = create_some_context() 10 | self.model = model 11 | 12 | self.grid_array = self.model.build_grid() 13 | with CommandQueue(ctx) as queue: 14 | self.grid = to_device(queue, self.grid_array) 15 | self.observer = Observer(model, ctx=ctx) 16 | self.propagator = Propagator(model, ctx=ctx) 17 | 18 | self.candidate = self.observer.observe(self.grid)[1:] 19 | self.done = False 20 | 21 | def step(self): 22 | index, collapsed = self.candidate 23 | self.propagator.propagate(self.grid, index, collapsed) 24 | status = self.observer.observe(self.grid) 25 | 26 | if status[0] == 'continue': 27 | self.candidate = status[1:] 28 | elif status[0] == 'error': 29 | self.done = True 30 | else: 31 | self.grid.get(ary=self.grid_array) 32 | 33 | return status[0] 34 | 35 | def run(self): 36 | while not self.done: 37 | yield self.step() 38 | 39 | def finish(self): 40 | status = None 41 | for status in self.run(): 42 | pass 43 | return status 44 | 45 | class BacktrackingRunner(Runner): 46 | def __init__(self, *args, snapshot_every=4, **kwargs): 47 | super().__init__(*args, **kwargs) 48 | self.snapshot = self.grid_array.copy() 49 | self.snapshot_every = snapshot_every 50 | self.snapshot_age = 0 51 | 52 | def step(self): 53 | index, collapsed = self.candidate 54 | 
self.propagator.propagate(self.grid, index, collapsed) 55 | status = self.observer.observe(self.grid) 56 | self.grid.get(ary=self.grid_array) 57 | 58 | if status[0] == 'continue': 59 | self.candidate = status[1:] 60 | self.snapshot_age += 1 61 | if self.snapshot_age >= self.snapshot_every: 62 | self.snapshot = self.grid_array.copy() 63 | self.snapshot_age = 0 64 | elif status[0] == 'error': 65 | if not self.snapshot is None: 66 | print('backtracking {} rounds'.format(self.snapshot_age)) 67 | self.grid.set(self.snapshot) 68 | self.candidate = self.observer.observe(self.grid)[1:] 69 | self.snapshot = None 70 | self.snapshot_age = self.snapshot_every 71 | return self.step() 72 | else: 73 | print('cannot backtrack anymore') 74 | self.done = True 75 | else: 76 | self.done = True 77 | 78 | return status[0] 79 | -------------------------------------------------------------------------------- /circuit.py: -------------------------------------------------------------------------------- 1 | from gpWFC.models import Model2d, SpriteTile 2 | from gpWFC.previews import SpritePreviewWindow 3 | from gpWFC.runners import BacktrackingRunner 4 | from pyglet import app, image, clock 5 | import sys 6 | 7 | model = Model2d((16, 16)) 8 | # 0: empty pcb 9 | # 1: masked track 10 | # 2: bridge 11 | # 3/4: component_edge 12 | # 7: component_center 13 | 14 | # component tiles 15 | model.add(SpriteTile('tiles/component.png', (7, 7, 7, 7), weight=20)) 16 | # model.add_rotations(SpriteTile('tiles/corner.png', (3, 0, 0, 3), weight=10), [0, 1, 2, 3]) 17 | # model.add_rotations(SpriteTile('tiles/connection.png', (3, 1, 3, 4), weight=10), [0, 1, 2, 3]) 18 | model.add(SpriteTile('tiles/corner.png', (3, 0, 0, 4), weight=10, rotation=0)) 19 | model.add(SpriteTile('tiles/corner.png', (3, 4, 0, 0), weight=10, rotation=1)) 20 | model.add(SpriteTile('tiles/corner.png', (0, 4, 3, 0), weight=10, rotation=2)) 21 | model.add(SpriteTile('tiles/corner.png', (0, 0, 3, 4), weight=10, rotation=3)) 22 | 
model.add(SpriteTile('tiles/connection.png', (3, 1, 3, 7), weight=10, rotation=0)) 23 | model.add(SpriteTile('tiles/connection.png', (7, 4, 1, 4), weight=10, rotation=1)) 24 | model.add(SpriteTile('tiles/connection.png', (3, 7, 3, 1), weight=10, rotation=2)) 25 | model.add(SpriteTile('tiles/connection.png', (1, 4, 7, 4), weight=10, rotation=3)) 26 | 27 | # bridge tiles 28 | # model.add_rotations(SpriteTile('tiles/bridge.png', (2, 1, 2, 1), weight=1), [0, 1]) 29 | # model.add_rotations(SpriteTile('tiles/wire.png', (2, 0, 2, 0), weight=0.5), [0, 1]) 30 | # model.add_rotations(SpriteTile('tiles/transition.png', (0, 2, 0, 1), weight=0.4), [0, 1, 2, 3]) 31 | 32 | # track tiles 33 | # model.add_rotations(SpriteTile('tiles/t.png', (1, 0, 1, 1), weight=1.3), [0, 1, 2, 3]) 34 | # model.add_rotations(SpriteTile('tiles/viad.png', (1, 0, 1, 0), weight=0.1), [0, 1]) 35 | model.add_rotations(SpriteTile('tiles/track.png', (0, 1, 0, 1), weight=10.0), [0, 1]) 36 | # model.add_rotations(SpriteTile('tiles/turn.png', (0, 1, 1, 0), weight=1), [0, 1, 2, 3]) 37 | model.add_rotations(SpriteTile('tiles/skew.png', (0, 1, 1, 0), weight=2), [0, 1, 2, 3]) 38 | model.add_rotations(SpriteTile('tiles/dskew.png', (1, 1, 1, 1), weight=2), [0, 1]) 39 | 40 | # model.add_rotations(SpriteTile('tiles/vias.png', (0, 1, 0, 0), weight=0.3), [0, 1, 2, 3]) 41 | 42 | model.add(SpriteTile('tiles/substrate.png', (0, 0, 0, 0), weight=2)) 43 | 44 | runner = BacktrackingRunner(model) 45 | preview = SpritePreviewWindow(runner, 14) 46 | 47 | if 'render' in sys.argv[1:]: 48 | preview.render() 49 | else: 50 | preview.launch() 51 | -------------------------------------------------------------------------------- /gpWFC/models.py: -------------------------------------------------------------------------------- 1 | import pyopencl as cl 2 | import pyopencl.cltypes 3 | import numpy as np 4 | import pyglet 5 | 6 | class Tile(object): 7 | def __init__(self, adj, weight=1): 8 | self.adj = adj 9 | self.weight = weight 10 | 11 
| def rotated(self, rot): 12 | adj = self.adj[-rot:] + self.adj[:-rot] 13 | return Tile(adj, self.weight) 14 | 15 | def register(self, index): 16 | self.index = index 17 | self.flag = np.uint64(1 << self.index) 18 | 19 | def compatible(self, other, direction): 20 | l = len(self.adj) 21 | return self.adj[direction] == other.adj[(direction+l//2) % l] 22 | 23 | class SpriteTile(Tile): 24 | def __init__(self, image, adj, weight=1, rotation=0): 25 | super().__init__(adj, weight) 26 | if not isinstance(image, pyglet.image.AbstractImage): 27 | image = pyglet.resource.image(image) 28 | self.image = image 29 | self.rotation = rotation 30 | 31 | def rotated(self, rotation): 32 | adj = self.adj[-rotation:] + self.adj[:-rotation] 33 | return SpriteTile(self.image, adj, weight=self.weight, rotation=rotation) 34 | 35 | class Model(object): 36 | def __init__(self, world_shape): 37 | self.tiles = [] 38 | self.world_shape = world_shape 39 | 40 | def add(self, tile): 41 | tile.register(len(self.tiles)) 42 | self.tiles.append(tile) 43 | 44 | def add_rotations(self, orig, rotations): 45 | for rot in rotations: 46 | tile = orig.rotated(rot) 47 | tile.register(len(self.tiles)) 48 | self.tiles.append(tile) 49 | 50 | def build_grid(self): 51 | all_tiles = sum(tile.flag for tile in self.tiles) 52 | print('filling grid with {}'.format(all_tiles)) 53 | return np.full(self.world_shape, all_tiles, dtype=cl.cltypes.ulong) 54 | 55 | def get_allowed_tiles(self, bits): 56 | return [tile for tile in self.tiles if tile.flag & bits] 57 | 58 | class Model2d(Model): 59 | adjacent = 4 60 | 61 | def __init__(self, world_shape): 62 | assert len(world_shape) == 2 63 | super().__init__(world_shape) 64 | 65 | def get_neighbours(self, pos): 66 | w, h = self.world_shape 67 | x, y = pos 68 | yield (x-1)%w, y 69 | yield x, (y-1)%h 70 | yield (x+1)%w, y 71 | yield x, (y+1)%h 72 | 73 | class Model3d(Model): 74 | adjacent = 6 75 | 76 | def __init__(self, world_shape): 77 | assert len(world_shape) == 3 78 | 
super().__init__(world_shape) 79 | 80 | def get_neighbours(self, pos): 81 | w, h, d = self.world_shape 82 | x, y, z = pos 83 | yield (x-1)%w, y, z 84 | yield x, (y-1)%h, z 85 | yield x, y, (z-1)%d 86 | yield (x+1)%w, y, z 87 | yield x, (y+1)%h, z 88 | yield x, y, (z+1)%d 89 | -------------------------------------------------------------------------------- /gpWFC/observers.py: -------------------------------------------------------------------------------- 1 | import pyopencl as cl 2 | import pyopencl.array 3 | import pyopencl.clrandom 4 | import pyopencl.tools 5 | import pyopencl.reduction 6 | import numpy as np 7 | import numpy.random 8 | 9 | class CLObserver(object): 10 | def __init__(self, model, ctx=None): 11 | self.model = model 12 | with cl.CommandQueue(ctx) as queue: 13 | self.rnd = pyopencl.clrandom.PhiloxGenerator(ctx) 14 | self.bias = cl.array.to_device(queue, np.zeros(self.model.world_shape, dtype=cl.cltypes.float)) 15 | 16 | alloc = cl.tools.ImmediateAllocator(queue, cl.mem_flags.READ_ONLY) 17 | self.weights_array = np.array(list(tile.weight for tile in self.model.tiles), dtype=cl.cltypes.float) 18 | self.weights = cl.array.to_device(queue, self.weights_array, alloc) 19 | 20 | min_collector = np.dtype([ 21 | ('entropy', cl.cltypes.float), 22 | ('index', cl.cltypes.uint), 23 | ]) 24 | min_collector, min_collector_def = cl.tools.match_dtype_to_c_struct(ctx.devices[0], 'min_collector', min_collector) 25 | min_collector = cl.tools.get_or_register_dtype('min_collector', min_collector) 26 | 27 | preamble = ''' 28 | #define STATES {} 29 | '''.format(len(self.model.tiles)) 30 | 31 | self.find_lowest_entropy = cl.reduction.ReductionKernel(ctx, 32 | arguments='__global ulong* grid, __global float* bias, __global float* weights', 33 | neutral='neutral()', 34 | dtype_out=min_collector, 35 | map_expr='get_entropy(i, grid[i], bias[i], weights)', 36 | reduce_expr='reduce(a, b)', 37 | preamble=min_collector_def + preamble + r'''//CL// 38 | 39 | /* start with an 
imaginary solved tile */ 40 | min_collector neutral() { 41 | min_collector res; 42 | res.entropy = -1.0; 43 | res.index = 0; 44 | return res; 45 | } 46 | 47 | /* get entropy of tile (> 0) 48 | * -1: solved 49 | * 0: overconstrained */ 50 | min_collector get_entropy(uint i, ulong bitfield, float bias, __global float* weights) { 51 | min_collector res; 52 | res.entropy = 0.0f; 53 | res.index = i; 54 | 55 | uint remaining_states = 0; 56 | for (uint state = 0; state < STATES; state++) { 57 | if (bitfield & ((ulong)1 << state)) { 58 | remaining_states++; 59 | res.entropy += weights[state]; 60 | } 61 | } 62 | 63 | if (remaining_states == 1) res.entropy = -1.0; 64 | else if (remaining_states > 1) res.entropy += bias * 0.5; 65 | return res; 66 | } 67 | 68 | /* if one is solved try the other 69 | * otherwise reduce to minimum entropy */ 70 | min_collector reduce(min_collector a, min_collector b) { 71 | if (a.entropy < 0.0) return b; 72 | if (b.entropy < 0.0) return a; 73 | return b.entropy < a.entropy ? 
b : a; 74 | } 75 | ''' 76 | ) 77 | 78 | def collapse(self, bits): 79 | p = self.weights_array.copy() 80 | bits = int(bits.get()) 81 | for i in range(len(p)): 82 | p[i] *= not not bits & (1 << i) 83 | p = p / np.sum(p) 84 | tile = np.random.choice(self.model.tiles, p=p) 85 | print('collapsing from {} to {}'.format(bits, tile.flag)) 86 | return tile.flag 87 | 88 | def observe(self, grid): 89 | # random tie-breaking bias for each tile 90 | self.rnd.fill_uniform(self.bias) 91 | 92 | tile = self.find_lowest_entropy(grid, self.bias, self.weights).get() 93 | entropy, index = tile['entropy'].item(), tile['index'].item() 94 | 95 | t_index = np.unravel_index(index, self.model.world_shape) 96 | if entropy < 0: 97 | print('solved!') 98 | return ('done',) 99 | elif entropy == 0: 100 | print('tile {} overconstrained!'.format(t_index)) 101 | return ('error',) 102 | 103 | print('selected tile {} with entropy {}'.format(t_index, entropy)) 104 | return ('continue', index, self.collapse(grid[t_index])) 105 | -------------------------------------------------------------------------------- /gpWFC/previews.py: -------------------------------------------------------------------------------- 1 | from numpy import ndenumerate 2 | from pyglet.app import run 3 | from pyglet.window import Window, key 4 | from pyglet.resource import image 5 | from pyglet.image import get_buffer_manager 6 | from pyglet.text import Label 7 | from pyglet.sprite import Sprite 8 | 9 | class BasePreview(Window): 10 | def __init__(self, runner, width=512, height=512): 11 | super().__init__(width=width, height=height) 12 | self.runner = runner 13 | self.debug = False 14 | 15 | def on_draw(self): 16 | self.clear() 17 | 18 | for pos, bits in ndenumerate(self.runner.grid_array): 19 | self.draw_tiles(pos, bits) 20 | 21 | def on_key_press(self, symbol, modifiers): 22 | if symbol == key.ESCAPE: 23 | self.close() 24 | elif symbol == key.SPACE: 25 | self.runner.step() 26 | elif symbol == key.R: 27 | self.runner.finish() 28 | 
elif symbol == key.D: 29 | self.debug = not self.debug 30 | 31 | def screenshot(self, name='shots/snapshot.png'): 32 | get_buffer_manager().get_color_buffer().save(name) 33 | 34 | def render(self): 35 | iteration = 0 36 | for i in range(2): 37 | self.dispatch_events() 38 | self.dispatch_event('on_draw') 39 | self.flip() 40 | self.screenshot('shots/{:04}.png'.format(iteration)) 41 | for i in self.runner.run(): 42 | self.dispatch_events() 43 | self.dispatch_event('on_draw') 44 | self.flip() 45 | self.screenshot('shots/{:04}.png'.format(iteration)) 46 | iteration += 1 47 | 48 | def launch(self): 49 | run() 50 | 51 | class PreviewWindow(BasePreview): 52 | colors = ( (0, 0, 255), (255, 0, 0), (0, 255, 0), (0, 0, 255) ) 53 | rotations = [0, 90, 180, 270] 54 | def __init__(self, runner): 55 | super().__init__(runner, width=512, height=512) 56 | 57 | tile = image('tile.png') 58 | tile.anchor_x = 32 59 | tile.anchor_y = 32 60 | self.sprite = Sprite(img=tile, x=0, y=0) 61 | 62 | def draw_tiles(self, pos, bits): 63 | if bits == 0: 64 | return 65 | 66 | x, y = pos[-2:] 67 | self.sprite.x = x * 64 + 32 68 | self.sprite.y = self.height - y * 64 - 32 69 | 70 | tiles = self.runner.model.get_allowed_tiles(bits) 71 | self.sprite.opacity = 255 / len(tiles) 72 | 73 | for tile in tiles: 74 | for direction, adj in enumerate(tile.adj): 75 | if adj < 1 or self.rotations[direction] == None: 76 | continue 77 | self.sprite.color = self.colors[adj] 78 | self.sprite.rotation = self.rotations[direction] 79 | self.sprite.draw() 80 | 81 | if self.debug: 82 | Label(str(bits), x=self.sprite.x, y=self.sprite.y).draw() 83 | 84 | class PreviewWindow3d(PreviewWindow): 85 | rotations = [0, 90, None, 180, 270, None] 86 | 87 | def __init__(self, *args): 88 | super().__init__(*args) 89 | self.slice = 0 90 | 91 | def on_draw(self): 92 | self.clear() 93 | 94 | for pos, bits in ndenumerate(self.runner.grid_array[...,self.slice]): 95 | self.draw_tiles(pos, bits) 96 | 97 | def on_key_press(self, symbol, 
modifiers): 98 | if symbol == key.UP: 99 | self.slice += 1 100 | elif symbol == key.DOWN: 101 | self.slice += 1 102 | else: 103 | super().on_key_press(symbol, modifiers) 104 | return 105 | self.slice = self.slice % self.runner.model.world_shape[-1] 106 | print(self.slice) 107 | 108 | class SpritePreviewWindow(BasePreview): 109 | def __init__(self, runner, tile_size): 110 | width = runner.model.world_shape[0] * tile_size 111 | height = runner.model.world_shape[1] * tile_size 112 | super().__init__(runner, width=width, height=height) 113 | 114 | self.sprite = Sprite(img=runner.model.tiles[0].image) 115 | self.tile_size = tile_size 116 | 117 | def draw_tiles(self, pos, bits): 118 | if bits == 0: 119 | return 120 | 121 | x, y = pos[-2:] 122 | self.sprite.x = x * self.tile_size + self.tile_size/2 123 | self.sprite.y = self.height - y * self.tile_size - self.tile_size/2 124 | 125 | tiles = self.runner.model.get_allowed_tiles(bits) 126 | self.sprite.opacity = 255 / len(tiles) 127 | 128 | for tile in tiles: 129 | tile.image.anchor_x = self.tile_size/2 130 | tile.image.anchor_y = self.tile_size/2 131 | self.sprite.image = tile.image 132 | self.sprite.rotation = tile.rotation * 90 133 | self.sprite.draw() 134 | 135 | if self.debug: 136 | Label(str(bits), x=self.sprite.x, y=self.sprite.y).draw() 137 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | gpWFC 2 | ===== 3 | 4 | Implementation of the [Wave Function Collapse][WFC] procedural content generation algorithm, 5 | using [(py)OpenCL][pyopencl] for GPU acceleration. 6 | 7 | ![circuit example](https://thumbs.gfycat.com/MinorFewBlackmamba-max-1mb.gif) ![simple example](https://thumbs.gfycat.com/FinishedFlamboyantHylaeosaurus-max-1mb.gif) 8 | 9 | Getting Started 10 | --------------- 11 | 12 | make sure you have the python packages pyopencl, numpy and pyglet installed. 
13 | 14 | You can then run a basic example using 15 | 16 | python main.py 17 | 18 | In the preview window, the following keybindings are set: 19 | 20 | - `escape`: close 21 | - `space`: do one observation/propagation cycle and render 22 | - `r`: cycle until stable, then render again 23 | - `d`: debug view (overlay decimal display of bitmask for each tile) 24 | 25 | There is also a more interesting sprite-based example that you can run using 26 | 27 | python circuit.py [render] 28 | 29 | but as you can see I didn't set up the model constraints properly. Maybe you want to fix that? 30 | 31 | `main.py` can take a few options that are just passed as strings on the command line, in any order. 32 | They might not all be compatible with each other; in any case, `main.py` is only a 33 | starting point for writing your own setup code with a more serious model. 34 | 35 | ### `cpu` 36 | 37 | propagate using a simplistic CPU algorithm. 38 | 39 | ### `3d` 40 | 41 | work in a 3d space (4x4x2 by default), with a *very* rudimentary preview. 42 | more of a proof of concept, but totally workable. 43 | 44 | In the 3d preview, the up and down keys can be used to cycle through slices of the Z axis. 45 | 46 | ### `silent` 47 | 48 | don't open a preview or render, just measure the execution time. 49 | 50 | ### `render` 51 | 52 | automatically step execution forward and save a screenshot of each frame to `shots/0000.png`, `shots/0001.png`, etc. 53 | You can use e.g. ffmpeg to turn the png frames into an animation. 54 | 55 | Programmatic Usage 56 | ------------------ 57 | 58 | `gpWFC` is set up to follow a 'mix and match' modular architecture as far as possible. 59 | It is therefore divided into a couple of components that need to be used to run a simulation: 60 | 61 | - the Tiles (`Tile` and `SpriteTile` from `models.py`): 62 | - `tile.weight` (float): the relative probability of occurrence 63 | - `tile.compatible(other, direction_id)` (bool): constraint information 64 | - additional information for the Preview, e.g.
`tile.image` and `tile.rotation` for `SpriteTile` 65 | - the Model (`Model2d` and `Model3d` from `models.py`): 66 | - information about the *world*: 67 | - `model.world_shape` (tuple): dimensions of the world (any number of axes) 68 | - `model.get_neighbours(pos)` (generator): tile adjacency information 69 | - information about the *tiles*: 70 | - `model.tiles` (list): the tiles to be used 71 | - `model.get_allowed_tiles(bitmask)` (list): a way to resolve the opaque bitmask 72 | - the Runner (`Runner` and `BacktrackingRunner` from `runners.py`): 73 | - `runner.step()` (string): execute a single observation/propagation cycle 74 | - `runner.finish()` (string): run the simulation until it either fails or stabilizes 75 | - `runner.run()` (generator): iterate over `runner.step()` 76 | - all of these return/yield status strings, which are one of: 77 | - `'done'` - fully collapsed 78 | - `'error'` - overconstrained / stuck 79 | - `'continue'` - step successful but uncollapsed tiles remain 80 | - the Preview (`PreviewWindow*` from `previews.py`): 81 | - `preview.draw_tiles(pos, bits)`: draw the tiles at `pos` (tuple) 82 | - `preview.launch()`: enter interactive preview mode 83 | - `preview.render()`: enter non-interactive render loop 84 | - the Observer and Propagator (`observers.py` and `propagators.py`): 85 | - you probably don't need to touch these 86 | 87 | You can find a straightforward example of the basic setup steps in `circuit.py`; it follows this flow: 88 | 89 | - instantiate a Model 90 | - instantiate Tiles and register them with the Model 91 | - instantiate a Runner and pass it the Model 92 | - instantiate a Preview and pass it the Runner 93 | - launch the Preview 94 | 95 | GPU-only rendering 96 | ------------------ 97 | 98 | There is a terribly broken `glsl-render` branch that tries never to copy the buffer back to CPU memory during propagation, 99 | while still rendering the world in a GLSL shader.
100 | Unfortunately I could never get it to work properly with pyOpenCL to date, and due to some other constraints 101 | I also cannot test or bring the current version back to the best state it was in, 102 | so it will remain in a messy test state for now. 103 | 104 | If anyone is brave enough to touch it though, when working, it should give some incredible performance gains as the rendering 105 | and memory transfer / gpu blocking are by far the biggest slow-downs at the moment. 106 | There is also some hope since a new version of pyOpenCL is [apparently on the way][opencl-fix]. 107 | 108 | [WFC]: https://github.com/mxgmn/WaveFunctionCollapse 109 | [pyopencl]: https://documen.tician.de/pyopencl 110 | [opencl-fix]: https://github.com/inducer/pyopencl/issues/235#issuecomment-431644685 111 | -------------------------------------------------------------------------------- /gpWFC/propagators.py: -------------------------------------------------------------------------------- 1 | import pyopencl as cl 2 | import pyopencl.array 3 | import pyopencl.cltypes 4 | import pyopencl.tools 5 | import pyopencl.reduction 6 | import numpy as np 7 | 8 | class BasePropagator(object): 9 | def __init__(self, model): 10 | self.model = model 11 | 12 | def get_neighbours(self, pad_to=None): 13 | if not pad_to: 14 | pad_to = self.model.adjacent 15 | diff = pad_to - self.model.adjacent 16 | 17 | neighbours = np.zeros(self.model.world_shape + (pad_to,), dtype=cl.cltypes.uint) 18 | for pos, _ in np.ndenumerate(neighbours[...,0]): 19 | neighbours[pos] = [np.ravel_multi_index(neighbour, self.model.world_shape) for neighbour in self.model.get_neighbours(pos)] + [0] * diff 20 | return neighbours 21 | 22 | def get_allows(self, pad_to=None, flipped=False): 23 | if not pad_to: 24 | pad_to = self.model.adjacent 25 | 26 | def add_allows(i, direction): 27 | if direction >= self.model.adjacent: 28 | return 0 29 | ret = np.uint64(0) 30 | tile = self.model.tiles[i] 31 | for other in self.model.tiles: 32 | 
if tile.compatible(other, direction): 33 | ret |= other.flag 34 | return ret 35 | 36 | if flipped: 37 | def add_allows(i, direction): 38 | if direction >= self.model.adjacent: 39 | return 0 40 | 41 | ret = np.uint64(0) 42 | tile = self.model.tiles[i] 43 | for other in self.model.tiles: 44 | if other.compatible(tile, direction): 45 | ret |= other.flag 46 | return ret 47 | 48 | return np.fromfunction( 49 | np.vectorize(add_allows), 50 | (len(self.model.tiles), pad_to), 51 | dtype=int 52 | ).astype(cl.cltypes.ulong) 53 | 54 | def get_config(self): 55 | adj = self.model.adjacent 56 | adjacent_bits = (adj - 1).bit_length() 57 | adjacent_pow = 1 << adjacent_bits 58 | 59 | config = { 60 | 'adj': adj, 61 | 'adj_pow': adjacent_pow, 62 | 'adj_uint': 'uint' + str(adjacent_pow), 63 | 'adj_ulong': 'ulong' + str(adjacent_pow), 64 | 'states': len(self.model.tiles), 65 | 'forNeighbour': lambda tpl, join='\n': join.join([tpl.format(i=i) for i in range(adj)]), 66 | } 67 | 68 | config['preamble'] = ''' 69 | #define ADJ {adj} 70 | #define ADJUINT {adj_uint} 71 | #define ADJULONG {adj_ulong} 72 | #define STATES {states} 73 | '''.format(**config) 74 | 75 | return config 76 | 77 | class CPUPropagator(BasePropagator): 78 | def __init__(self, model, ctx=None): 79 | super().__init__(model) 80 | self.allows = self.get_allows() 81 | 82 | def reduce_to_allowed(self, i, allowmap, grid): 83 | old = grid[i] 84 | new = old & allowmap 85 | # print('tile {}: {} & {} = {}, delta: {}'.format(i, old, allowmap, new, diff)) 86 | if old == new or not new: 87 | return 88 | grid[i] = new 89 | 90 | allowmaps = np.zeros((4,), dtype=np.uint64) 91 | for tile in self.model.tiles: 92 | if new & tile.flag: 93 | # print('delta bit {}, propagate allows {}'.format(tile.index, self.allows[tile.index])) 94 | allowmaps |= self.allows[tile.index] 95 | # print('neighbour allows: {}'.format(allowmaps)) 96 | 97 | for n, neighbour in enumerate(self.model.get_neighbours(i)): 98 | self.reduce_to_allowed( 99 | neighbour, 100 | 
allowmaps[n], 101 | grid 102 | ) 103 | 104 | def propagate(self, grid, index, collapsed): 105 | self.reduce_to_allowed(np.unravel_index(index, self.model.world_shape), np.uint64(collapsed), grid) 106 | 107 | class CL2Propagator(BasePropagator): 108 | def __init__(self, model, ctx=None): 109 | super().__init__(model) 110 | 111 | self.ctx = ctx 112 | 113 | with cl.CommandQueue(ctx) as queue: 114 | alloc = cl.tools.ImmediateAllocator(queue, cl.mem_flags.READ_ONLY) 115 | self.allows_buf = cl.array.to_device(queue, self.get_allows(), alloc) 116 | self.neighbours_buf = cl.array.to_device(queue, self.get_neighbours(), alloc) 117 | 118 | config = self.get_config() 119 | self.program = cl.Program(ctx, config['preamble'] + ''' 120 | __kernel void reduce_to_allowed( 121 | const uint i, const ulong allowmap, 122 | __global ulong* grid, __global ADJULONG* allows, __global ADJUINT* neighbours 123 | ) { 124 | ulong old_bits = grid[i]; 125 | ulong new_bits = old_bits & allowmap; 126 | grid[i] = new_bits; 127 | ulong diff = old_bits ^ new_bits; 128 | if (!diff) return; 129 | 130 | ADJULONG allowmaps; 131 | for (int bit = 0; bit < STATES; bit++) { 132 | if (new & (1 << i)) 133 | allowmaps |= allows[bit]; 134 | } 135 | 136 | // change in bit, trigger neighbours 137 | enqueue_kernel( 138 | get_default_queue(), 139 | CLK_ENQUEUE_FLAGS_WAIT_KERNEL, 140 | ndrange_1D(ADJ), 141 | ^{ 142 | uint neighbour = get_global_id(0); 143 | ulong allow = allowmaps[neighbour]; 144 | reduce_to_allowed( 145 | neighbours[i][neighbour], allow, 146 | grid, allows, neighbours 147 | ); 148 | } 149 | ); 150 | } 151 | ''').build() 152 | 153 | def propagate(self, grid, index, collapsed): 154 | with cl.CommandQueue(self.ctx) as queue: 155 | self.program.reduce_to_allowed( 156 | queue, (1,), None, 157 | index, collapsed, 158 | grid, self.allows_buf, self.neighbours_buf 159 | ) 160 | 161 | class CL1Propagator(BasePropagator): 162 | def __init__(self, model, ctx=None): 163 | super().__init__(model) 164 | 165 | 
config = self.get_config() 166 | 167 | with cl.CommandQueue(ctx) as queue: 168 | alloc = cl.tools.ImmediateAllocator(queue, cl.mem_flags.READ_ONLY) 169 | self.allows_buf = cl.array.to_device(queue, self.get_allows(pad_to=config['adj_pow'], flipped=True), alloc) 170 | self.neighbours_buf = cl.array.to_device(queue, self.get_neighbours(pad_to=config['adj_pow']), alloc) 171 | 172 | fN = config['forNeighbour'] 173 | 174 | self.update_grid = cl.reduction.ReductionKernel(ctx, 175 | arguments='__global ulong* grid, __global {adj_ulong}* allows, __global {adj_uint}* neighbours'.format(**config), 176 | neutral='0', 177 | dtype_out=cl.cltypes.uint, 178 | map_expr='update_tile(i, grid, allows, neighbours)', 179 | reduce_expr='a + b', 180 | preamble=config['preamble'] + ''' 181 | uint update_tile(uint i, __global ulong* grid, __global ADJULONG* allows, __global ADJUINT* neighbours) { 182 | ulong old_bits = grid[i]; 183 | 184 | ADJUINT next = neighbours[i]; 185 | ''' + 186 | fN(''' 187 | ulong grid_{i} = grid[next.s{i}]; 188 | ulong mask_{i} = 0; 189 | ''') + ''' 190 | 191 | for (uint tile = 0; tile < STATES; tile++) { 192 | ulong flag = 1 << tile; 193 | ADJULONG tile_allows = allows[tile]; 194 | ''' + fN(''' 195 | if (flag & grid_{i}) mask_{i} |= tile_allows.s{i}; 196 | ''') + ''' 197 | } 198 | 199 | ulong new_bits = old_bits ''' + fN('& mask_{i}', '') + '''; 200 | grid[i] = new_bits; 201 | return old_bits != new_bits; 202 | } 203 | ''' 204 | ) 205 | 206 | def propagate(self, grid, index, collapsed): 207 | grid[np.unravel_index(index, self.model.world_shape)] = collapsed 208 | turn, changes = 0, 1 209 | while changes > 0: 210 | changes = self.update_grid(grid, self.allows_buf, self.neighbours_buf).get() 211 | turn += 1 212 | print('propagated in {} turns'.format(turn)) 213 | -------------------------------------------------------------------------------- /tiles/data.xml: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 
| 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | 38 | 39 | 40 | 41 | 42 | 43 | 44 | 45 | 46 | 47 | 48 | 49 | 50 | 51 | 52 | 53 | 54 | 55 | 56 | 57 | 58 | 59 | 60 | 61 | 62 | 63 | 64 | 65 | 66 | 67 | 68 | 69 | 70 | 71 | 72 | 73 | 74 | 75 | 76 | 77 | 78 | 79 | 80 | 81 | 82 | 83 | 84 | 85 | 86 | 87 | 88 | 89 | 90 | 91 | 92 | 93 | 94 | 95 | 96 | 97 | 98 | 99 | 100 | 101 | 102 | 103 | 104 | 105 | 106 | 107 | 108 | 109 | 110 | 111 | 112 | 113 | 114 | 115 | 116 | 117 | 118 | 119 | 120 | 121 | 122 | 123 | 124 | 125 | 126 | 127 | 128 | 129 | 130 | 131 | 132 | 133 | 134 | 135 | 136 | 137 | 138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | 146 | 147 | 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | 156 | 157 | 158 | 159 | 160 | 161 | 162 | 163 | 164 | 165 | 166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | 174 | 175 | 176 | 177 | 178 | 179 | 180 | 181 | 182 | --------------------------------------------------------------------------------