├── .idea └── .gitignore ├── LICENSE ├── README.md ├── common ├── __init__.py ├── dataset.py ├── models.py └── train_utils.py ├── diffusion └── train.py ├── diffusion_coords ├── models.py └── train.py ├── notebook.ipynb ├── requirements.txt └── transformer_move ├── models.py └── train.py /.idea/.gitignore: -------------------------------------------------------------------------------- 1 | # Default ignored files 2 | /shelf/ 3 | /workspace.xml 4 | # GitHub Copilot persisted chat sessions 5 | /copilot/chatSessions 6 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # MazeSolver 2 | 3 | A couple of fun trainers using diffusion and autoregressive models to solve mazes. 
Explainer blog post [here](https://sweet-hall-e72.notion.site/Diffusion-and-Autoregressive-Models-for-Learning-to-Solve-Mazes-c3bc4bcdfa304ecd9531ee5445a4da66?pvs=4) 4 | 5 | 6 | ## Requirements 7 | `torch`, `diffusers`, `transformers`, and `accelerate` should cover it; a `requirements.txt` is also included. 8 | 9 | 10 | ## How To Use 11 | Three trainers are provided: `diffusion`, `diffusion_coords`, and `transformer_move`. Each is configured by editing the arguments near the top of its script (e.g. the `default_arguments` dict) and then running `python train.py` from inside that trainer's directory. -------------------------------------------------------------------------------- /common/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethansmith2000/MazeSolver/43cacf58bf50e9216d67c714156dc7add7d7907b/common/__init__.py -------------------------------------------------------------------------------- /common/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import DataLoader, Dataset 3 | import numpy as np 4 | import matplotlib.pyplot as plt 5 | from queue import PriorityQueue 6 | from PIL import Image 7 | import torch 8 | 9 | def generate_maze(size): 10 | # 0 = empty, 1 = walls 11 | maze = np.ones((size * 2 + 1, size * 2 + 1), dtype=np.int8) 12 | 13 | def remove_wall(pos1, pos2): 14 | maze[pos1[0] + (pos2[0] - pos1[0]) // 2, pos1[1] + (pos2[1] - pos1[1]) // 2] = 0 15 | 16 | 17 | stack = [((1, 1), None)] 18 | while stack: 19 | (cx, cy), prev = stack.pop() 20 | if maze[cy, cx] == 1: 21 | maze[cy, cx] = 0 22 | if prev: 23 | remove_wall(prev, (cx, cy)) 24 | 25 | neighbors = [(cx - 2, cy), (cx + 2, cy), (cx, cy - 2), (cx, cy + 2)] 26 | np.random.shuffle(neighbors) 27 | for nx, ny in neighbors: 28 | if 1 <= nx < size * 2 and 1 <= ny < size * 2: 29 | stack.append(((nx, ny), (cx, cy))) 30 | 31 | # have to create entry and exit points, random for variation 32 | walls_for_entry_exit = [(i, 0) for i in range(1, size * 2, 2)] + \ 33 | [(i, size * 2) for i in range(1, size * 2, 2)] + \ 34 | [(0, i) for i in range(1, size * 2, 2)] + \ 35 | [(size * 2, i) for i in range(1, size * 2, 2)] 36 | entry_exit = np.random.choice(len(walls_for_entry_exit), 2, replace=False) 37 | entry, exit_pt = walls_for_entry_exit[entry_exit[0]], walls_for_entry_exit[entry_exit[1]] 38 | 39 | maze[entry] = 0 40 | maze[exit_pt] = 0 41 | 42 | return maze, entry, exit_pt 43 | 44 | 45 | 46 | def solve_maze_dfs(maze): 47 | # find entry and exit points, they can be in either direction 48 | h, w = maze.shape 49 | temp_maze = maze.copy() 50 | temp_maze[1:-1, 1:-1] = 1 51 | 52 | locs = np.where(temp_maze == 0) 53 | entry_idx = np.random.randint(0, len(locs[0])) 54 | exit_idx = (entry_idx + 1) % len(locs[0]) 55 | 56 | start = (locs[1][entry_idx], locs[0][entry_idx]) 57 | goal = (locs[1][exit_idx], locs[0][exit_idx]) 58 | path = [] 59 | visited = set() 60 | 61 | def dfs(position): 62 | if position == goal: 63 | path.append(position) 64 | return True 65 | x, y = position 66 | 67 | visited.add(position) 68 | 69 | for dx, dy in [(0, 1), (0, -1), (1, 0), (-1, 0)]: 70 | nx, ny = x + dx, y + dy 71 | if not (0 <= nx < w and 0 <= ny < h): # Check boundaries 72 | continue 73 | if maze[ny, nx] == 0 and (nx, ny) not in visited: 74 | if dfs((nx, ny)): 75 | path.append(position) 76 | return True 77 | return False 78 | 79 | dfs(start) 80 | return path[::-1] 81 | 82 | 83 | def interpolate_colormap(value, colors): 84 | # attempt to recreate the matplotlib colormap interpolation 85 | index = value * (len(colors) - 1) 86 |
lower_index = int(index) 87 | upper_index = min(lower_index + 1, len(colors) - 1) 88 | interpolation = index - lower_index 89 | 90 | lower_color = np.array(colors[lower_index]) 91 | upper_color = np.array(colors[upper_index]) 92 | color = lower_color + (upper_color - lower_color) * interpolation 93 | return color.astype(np.uint8) 94 | 95 | def apply_complex_heatmap_effect(data, colors= [(102, 0, 181), (0, 255, 0), (255, 255, 0)]): 96 | h, w = data.shape 97 | 98 | normalized_data = (data - np.min(data)) / (np.max(data) - np.min(data)) 99 | 100 | heatmap_data = np.zeros((*data.shape, 3), dtype=np.uint8) 101 | for i in range(data.shape[0]): 102 | for j in range(data.shape[1]): 103 | heatmap_data[i, j] = interpolate_colormap(normalized_data[i, j], colors) 104 | 105 | heatmap_image = Image.fromarray(heatmap_data).resize((w*10, h*10), Image.NEAREST) 106 | return heatmap_image 107 | 108 | def visualize_maze(maze, path=None): 109 | 110 | path_maze = np.copy(maze) 111 | if path is not None: 112 | if isinstance(path, list): 113 | try: 114 | for x, y in path: 115 | path_maze[y, x] = 2 116 | except: 117 | pass 118 | else: 119 | path_maze = path_maze + path * 2 120 | 121 | # cmap = plt.cm.jet 122 | # cmap.set_under('black') 123 | # cmap.set_over('gold') 124 | 125 | # plt.figure(figsize=(10, 10)) 126 | # plt.imshow(path_maze, cmap=cmap, vmin=0.5, vmax=2.5) 127 | # plt.xticks([]), plt.yticks([]) 128 | # plt.show() 129 | 130 | # save in a temp file 131 | # plt.imsave('maze.png', path_maze)# cmap=cmap, vmin=0.5, vmax=2.5) 132 | # img = Image.open('maze.png') 133 | 134 | img = apply_complex_heatmap_effect(path_maze) 135 | 136 | return img 137 | 138 | 139 | def heuristic(a, b): 140 | """Calculate the Manhattan distance between two points a and b""" 141 | return abs(a[0] - b[0]) + abs(a[1] - b[1]) 142 | 143 | 144 | def solve_maze_a_star(maze): 145 | start = (1, 1) 146 | goal = (maze.shape[0] - 2, maze.shape[1] - 2) 147 | 148 | # Priority queue for nodes to explore stores tuples of (cost, position) 149 | frontier = PriorityQueue() 150 | frontier.put((0, start)) 151 | 152 | came_from = {} 153 | cost_so_far = {} 154 | came_from[start] = None 155 | cost_so_far[start] = 0 156 | 157 | while not frontier.empty(): 158 | current = frontier.get()[1] 159 | 160 | if current == goal: 161 | break 162 | 163 | for dx, dy in [(0, 1), (0, -1), (1, 0), (-1, 0)]: 164 | next = (current[0] + dx, current[1] + dy) 165 | if 0 <= next[0] < maze.shape[1] and 0 <= next[1] < maze.shape[0] and maze[next[1], next[0]] == 0: 166 | new_cost = cost_so_far[current] + 1 # each step costs 1 167 | if next not in cost_so_far or new_cost < cost_so_far[next]: 168 | cost_so_far[next] = new_cost 169 | priority = new_cost + heuristic(next, goal) 170 | frontier.put((priority, next)) 171 | came_from[next] = current 172 | 173 | # Reconstruct path 174 | current = goal 175 | path = [] 176 | while current != start: 177 | path.append(current) 178 | current = came_from[current] 179 | path.append(start) # optional 180 | path.reverse() # optional 181 | 182 | return path 183 | 184 | 185 | def get_movements_from_path(path): 186 | 187 | mapping = { 188 | (0, 1): [0,'Down'], 189 | (0, -1): [1,'Up'], 190 | (1, 0): [2,'Right'], 191 | (-1, 0): [3,'Left'] 192 | } 193 | 194 | movements = [] 195 | movement_ids = [] 196 | for i in range(1, len(path)): 197 | dx = path[i][0] - path[i - 1][0] 198 | dy = path[i][1] - path[i - 1][1] 199 | movements.append((dx, dy)) 200 | 201 | movement_ids = [mapping[m][0] for m in movements] 202 | movements = [mapping[m][1] for m in movements] 
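# movement_ids are the integer codes (0=Down, 1=Up, 2=Right, 3=Left) and movements are the corresponding readable direction names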
203 | 204 | return movement_ids, movements 205 | 206 | 207 | def get_path_from_movements(movements, maze): 208 | #find start pos 209 | start_y = torch.where(maze == 2)[0].item() 210 | start_x = torch.where(maze == 2)[1].item() 211 | start = (start_x, start_y) 212 | # end_y = torch.where(maze == 3)[0].item() 213 | # end_x = torch.where(maze == 3)[1].item() 214 | # end = (end_x, end_y) 215 | 216 | mapping = { 217 | 0: (0, 1), 218 | 1: (0, -1), 219 | 2: (1, 0), 220 | 3: (-1, 0) 221 | } 222 | 223 | path = [start] 224 | for m in movements: 225 | try: 226 | dx, dy = mapping[m.item()] 227 | except: 228 | dx, dy = 0, 0 229 | path.append((path[-1][0] + dx, path[-1][1] + dy)) 230 | 231 | return path 232 | 233 | 234 | def get_maze_path_grid(maze, path): 235 | path_maze = np.zeros_like(maze) 236 | for x, y in path: 237 | path_maze[y, x] = 1 # Mark the path with a distinct value 238 | return path_maze 239 | 240 | 241 | 242 | def default_collate_fn(examples): 243 | batch = {} 244 | for k in examples[0].keys(): 245 | if isinstance(examples[0][k], torch.Tensor): 246 | batch[k] = torch.stack([example[k] for example in examples]) 247 | else: 248 | batch[k] = [example[k] for example in examples] 249 | 250 | return batch 251 | 252 | 253 | class MazeDataset(Dataset): 254 | 255 | def __init__(self, maze_size): 256 | self.maze_size = maze_size 257 | 258 | def __len__(self): 259 | return 100_000 260 | 261 | def __getitem__(self, idx): 262 | maze, entry, exit_pt = generate_maze(self.maze_size) 263 | path = solve_maze_dfs(maze) 264 | path_grid = get_maze_path_grid(maze, path) 265 | maze_labeled = maze.copy() 266 | maze_labeled[entry] = 2 267 | maze_labeled[exit_pt] = 3 268 | 269 | maze = torch.from_numpy(maze).float().unsqueeze(0) 270 | maze_labeled = torch.from_numpy(maze_labeled).float().unsqueeze(0) 271 | # path = torch.from_numpy(path).float() 272 | path_grid = torch.from_numpy(path_grid).float().unsqueeze(0) 273 | 274 | maze = maze * 2 - 1 275 | path_grid = path_grid * 2 - 1 276 | 277 | example = { 278 | "maze": maze, 279 | "path_grid": path_grid, 280 | 281 | "path": path, 282 | "maze_labeled": maze_labeled, 283 | # "movements": get_movements_from_path(path) 284 | } 285 | 286 | return example 287 | 288 | 289 | 290 | 291 | 292 | 293 | 294 | 295 | 296 | -------------------------------------------------------------------------------- /common/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import math 4 | import torch.nn.functional as F 5 | 6 | 7 | def modulate(x, shift, scale): 8 | return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) 9 | 10 | 11 | class DiTBlock(nn.Module): 12 | """ 13 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
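The conditioning vector c (e.g. a timestep embedding of width time_dim) is projected to per-branch shift, scale, and gate parameters that modulate the attention and MLP paths. Minimal usage sketch (shapes here are illustrative assumptions, not fixed by this repo): block = DiTBlock(hidden_size=512, time_dim=512, num_heads=8); x = torch.randn(2, 16, 512); c = torch.randn(2, 512); y = block(x, c)  # y keeps the shape of x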
14 | """ 15 | def __init__(self, hidden_size, time_dim, num_heads, mlp_ratio=4.0): 16 | super().__init__() 17 | self.norm_attn1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 18 | self.attn1 = Attention(hidden_size, heads=num_heads) 19 | self.norm_mlp = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) 20 | self.mlp = FeedForward(hidden_size, mult=mlp_ratio) # approx_gelu = lambda: nn.GELU(approximate="tanh") 21 | self.adaLN_modulation = nn.Sequential( 22 | nn.SiLU(), 23 | nn.Linear(time_dim, 6 * hidden_size, bias=True) 24 | ) 25 | 26 | def forward(self, x, c): 27 | shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1) 28 | x = x + gate_msa.unsqueeze(1) * self.attn1(modulate(self.norm_attn1(x), shift_msa, scale_msa)) 29 | x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm_mlp(x), shift_mlp, scale_mlp)) 30 | return x 31 | 32 | 33 | class DiTBlockNoAda(nn.Module): 34 | """ 35 | A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 36 | """ 37 | def __init__(self, hidden_size, num_heads, mlp_ratio=4.0): 38 | super().__init__() 39 | self.norm_attn1 = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) 40 | self.attn1 = Attention(hidden_size, heads=num_heads) 41 | self.norm_mlp = nn.LayerNorm(hidden_size, elementwise_affine=True, eps=1e-6) 42 | self.mlp = FeedForward(hidden_size, mult=mlp_ratio) # approx_gelu = lambda: nn.GELU(approximate="tanh") 43 | 44 | def forward(self, x): 45 | x = x + self.attn1(self.norm_attn1(x)) 46 | x = x + self.mlp(self.norm_mlp(x)) 47 | return x 48 | 49 | 50 | class AttentionResampler(nn.Module): 51 | 52 | def __init__(self, dim, num_queries=1, heads=1): 53 | super().__init__() 54 | self.q = nn.Parameter(torch.randn(1, num_queries, dim)) 55 | self.kv = nn.Linear(dim, dim * 2, bias=False) 56 | self.h = heads 57 | self.dh = dim // heads 58 | self.norm = nn.LayerNorm(dim, elementwise_affine=True) 59 | 60 | def forward(self, x): 61 | b, s, dim = x.shape 62 | norm_x = self.norm(x) 63 | q = self.q.expand(b, -1, -1) 64 | k, v = self.kv(norm_x).chunk(2, dim=-1) 65 | q, k, v = map(lambda t: t.view(b, -1, self.h, self.dh).transpose(1, 2), (q, k, v)) 66 | 67 | attn_output = F.scaled_dot_product_attention(q, k, v).transpose(1, 2).reshape(b, -1, dim).to(q.dtype) 68 | 69 | return attn_output 70 | 71 | 72 | class FeedForward(nn.Module): 73 | def __init__(self, dim, mult=4, dropout=0.0, bias=True): 74 | super().__init__() 75 | self.net = nn.Sequential( 76 | nn.Linear(dim, int(dim * mult), bias=bias), 77 | nn.GELU(), 78 | nn.Dropout(dropout), 79 | nn.Linear(int(dim * mult), dim, bias=bias), 80 | nn.Dropout(dropout), 81 | ) 82 | 83 | def forward(self, x): 84 | return self.net(x) 85 | 86 | class Attention(nn.Module): 87 | def __init__(self, 88 | dim=768, 89 | heads=8, 90 | ): 91 | super().__init__() 92 | self.qkv = nn.Linear(dim, dim * 3, bias=False) 93 | self.h = heads 94 | self.dh = dim // heads 95 | 96 | def forward(self, x): 97 | b, s, dim = x.shape 98 | q, k, v = map(lambda t: t.view(b, -1, self.h, self.dh).transpose(1, 2), self.qkv(x).chunk(3, dim=-1)) 99 | attn_output = F.scaled_dot_product_attention(q, k, v) 100 | 101 | return attn_output.transpose(1, 2).reshape(b, -1, dim).to(q.dtype) 102 | 103 | 104 | def get_timestep_embedding( 105 | timesteps: torch.Tensor, 106 | embedding_dim: int, 107 | downscale_freq_shift: float = 1, 108 | max_period: int = 10000, 109 | ): 110 | half_dim = embedding_dim // 2 111 | exponent = -math.log(max_period) * torch.arange(start=0, 
end=half_dim, dtype=torch.float32, device=timesteps.device) 112 | exponent = exponent / (half_dim - downscale_freq_shift) 113 | 114 | emb = timesteps[:, None].float() * torch.exp(exponent)[None, :] 115 | emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) 116 | 117 | # zero pad 118 | if embedding_dim % 2 == 1: 119 | emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) 120 | return emb 121 | 122 | 123 | class TimestepEmbedding(nn.Module): 124 | def __init__( 125 | self, 126 | in_channels: int, 127 | time_embed_dim: int, 128 | bias=True, 129 | ): 130 | super().__init__() 131 | self.linear_1 = nn.Linear(in_channels, time_embed_dim, bias=bias) 132 | self.act = nn.SiLU() 133 | self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim, bias=bias) 134 | 135 | 136 | def forward(self, timestep): 137 | timestep = get_timestep_embedding(timestep, self.linear_1.in_features).to(self.linear_1.weight.device).to(self.linear_1.weight.dtype) 138 | timestep = self.linear_1(timestep) 139 | timestep = self.act(timestep) 140 | timestep = self.linear_2(timestep) 141 | return timestep 142 | 143 | 144 | class AdaNorm(nn.Module): 145 | 146 | def __init__(self, in_dim, ada_dim): 147 | super().__init__() 148 | self.ada_proj = nn.Linear(ada_dim, 2 * in_dim) 149 | self.norm = nn.LayerNorm(in_dim, elementwise_affine=False) 150 | 151 | def forward(self, hidden_states, ada_embed): 152 | hidden_states = self.norm(hidden_states) 153 | ada_embed = self.ada_proj(ada_embed) 154 | scale, shift = ada_embed.chunk(2, dim=1) 155 | hidden_states = hidden_states * (1 + scale) + shift 156 | return hidden_states 157 | 158 | 159 | class ChunkFanIn(torch.nn.Module): 160 | 161 | def __init__(self, in_dim, out_dim, chunks=1): 162 | super().__init__() 163 | assert in_dim % chunks == 0 164 | self.projs = nn.ModuleList([nn.Linear(in_dim // chunks, out_dim) for _ in range(chunks)]) 165 | self.in_dim = in_dim 166 | self.chunk_dim = in_dim // chunks 167 | 168 | def forward(self, x): 169 | return torch.stack([proj(x[..., (i * self.chunk_dim) : ((i+1) * self.chunk_dim)]) for i, proj in enumerate(self.projs)], dim=1).sum(dim=1) 170 | 171 | 172 | class ChunkFanOut(torch.nn.Module): 173 | 174 | def __init__(self, in_dim, out_dim, chunks=1): 175 | super().__init__() 176 | assert out_dim % chunks == 0 177 | self.projs = nn.ModuleList([nn.Linear(in_dim, out_dim // chunks) for _ in range(chunks)]) 178 | 179 | def forward(self, x): 180 | return torch.cat([proj(x) for proj in self.projs], dim=1) -------------------------------------------------------------------------------- /common/train_utils.py: -------------------------------------------------------------------------------- 1 | 2 | from pathlib import Path 3 | import os 4 | import numpy as np 5 | import torch 6 | import torch.utils.checkpoint 7 | import transformers 8 | from accelerate import Accelerator 9 | from PIL import Image 10 | from torch.utils.data import Dataset 11 | from torchvision import transforms 12 | from accelerate.utils import ProjectConfiguration, set_seed 13 | import diffusers 14 | from diffusers.utils.torch_utils import is_compiled_module 15 | import wandb 16 | import logging 17 | import math 18 | import random 19 | from diffusers import DPMSolverMultistepScheduler 20 | from diffusers.optimization import get_scheduler 21 | from transformers import AutoTokenizer, PretrainedConfig 22 | 23 | import diffusers 24 | from diffusers import ( 25 | AutoencoderKL, 26 | DDPMScheduler, 27 | DiffusionPipeline, 28 | UNet2DConditionModel, 29 | ) 30 | import copy 31 | from tqdm import tqdm 32 | 
from .dataset import MazeDataset, default_collate_fn 33 | 34 | 35 | def save_model(model, save_path, logger): 36 | torch.save(model.state_dict(), save_path) 37 | logger.info(f"Saved state to {save_path}") 38 | 39 | 40 | def unwrap_model(accelerator, model): 41 | model = accelerator.unwrap_model(model) 42 | model = model._orig_mod if is_compiled_module(model) else model 43 | return model 44 | 45 | 46 | def init_train_basics(args, logger): 47 | logging_dir = Path(args.output_dir, args.logging_dir) 48 | accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) 49 | accelerator = Accelerator( 50 | gradient_accumulation_steps=args.gradient_accumulation_steps, 51 | mixed_precision=args.mixed_precision, 52 | log_with=args.report_to, 53 | project_config=accelerator_project_config, 54 | ) 55 | 56 | # Make one log on every process with the configuration for debugging. 57 | logging.basicConfig( 58 | format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", 59 | datefmt="%m/%d/%Y %H:%M:%S", 60 | level=logging.INFO, 61 | ) 62 | logger.info(accelerator.state, main_process_only=False) 63 | if accelerator.is_local_main_process: 64 | transformers.utils.logging.set_verbosity_warning() 65 | diffusers.utils.logging.set_verbosity_info() 66 | else: 67 | transformers.utils.logging.set_verbosity_error() 68 | diffusers.utils.logging.set_verbosity_error() 69 | 70 | # If passed along, set the training seed now. 71 | if args.seed is not None: 72 | set_seed(args.seed) 73 | 74 | # Handle the repository creation 75 | if accelerator.is_main_process: 76 | if args.output_dir is not None: 77 | os.makedirs(args.output_dir, exist_ok=True) 78 | 79 | weight_dtype = torch.float32 80 | if accelerator.mixed_precision == "fp16": 81 | weight_dtype = torch.float16 82 | elif accelerator.mixed_precision == "bf16": 83 | weight_dtype = torch.bfloat16 84 | 85 | # Enable TF32 for faster training on Ampere GPUs, 86 | if args.allow_tf32: 87 | torch.backends.cuda.matmul.allow_tf32 = True 88 | 89 | return accelerator, weight_dtype 90 | 91 | 92 | def get_optimizer(args, params_to_optimize, accelerator): 93 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs 94 | optimizer_class = torch.optim.AdamW 95 | if args.use_8bit_adam: 96 | import bitsandbytes as bnb 97 | optimizer_class = bnb.optim.AdamW8bit 98 | 99 | optimizer = optimizer_class( 100 | params_to_optimize, 101 | lr=args.learning_rate, 102 | betas=(args.adam_beta1, args.adam_beta2), 103 | weight_decay=args.adam_weight_decay, 104 | eps=args.adam_epsilon, 105 | ) 106 | 107 | lr_scheduler = get_scheduler( 108 | args.lr_scheduler, 109 | optimizer=optimizer, 110 | num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, 111 | num_training_steps=args.max_train_steps * accelerator.num_processes, 112 | num_cycles=args.lr_num_cycles, 113 | power=args.lr_power, 114 | ) 115 | 116 | return optimizer, lr_scheduler 117 | 118 | 119 | def get_dataset(args): 120 | train_dataset = MazeDataset(maze_size=args.maze_size) 121 | train_dataloader = torch.utils.data.DataLoader( 122 | train_dataset, 123 | batch_size=args.train_batch_size, 124 | shuffle=True, 125 | collate_fn=default_collate_fn, 126 | num_workers=args.dataloader_num_workers, 127 | ) 128 | 129 | overrode_max_train_steps = False 130 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 131 | 132 | return train_dataset, train_dataloader, num_update_steps_per_epoch 133 | 134 | 135 | def resume_model(model, args, 
accelerator, num_update_steps_per_epoch): 136 | accelerator.print(f"Resuming from checkpoint {args.resume_from_checkpoint}") 137 | global_step = int(args.resume_from_checkpoint.split("-")[-1]) 138 | state_dict = torch.load(args.resume_from_checkpoint, map_location="cpu") 139 | 140 | if not isinstance(model, list): 141 | model = [model] 142 | for m in model: 143 | missing, unexpected = m.load_state_dict(state_dict, strict=False) 144 | 145 | initial_global_step = global_step 146 | first_epoch = global_step // num_update_steps_per_epoch 147 | 148 | return global_step, initial_global_step, first_epoch 149 | 150 | 151 | def more_init(model, accelerator, args, train_dataloader, train_dataset, logger, num_update_steps_per_epoch, wandb_name="diffusion_lora"): 152 | if accelerator.is_main_process: 153 | tracker_config = vars(copy.deepcopy(args)) 154 | accelerator.init_trackers(wandb_name, config=tracker_config) 155 | 156 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. 157 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) 158 | args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) 159 | 160 | # Train! 161 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 162 | 163 | logger.info("***** Running training *****") 164 | logger.info(f" Num examples = {len(train_dataset)}") 165 | logger.info(f" Num batches each epoch = {len(train_dataloader)}") 166 | logger.info(f" Num Epochs = {args.num_train_epochs}") 167 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 168 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 169 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 170 | logger.info(f" Total optimization steps = {args.max_train_steps}") 171 | global_step = 0 172 | first_epoch = 0 173 | 174 | # Potentially load in the weights and states from a previous save 175 | initial_global_step = 0 176 | if args.resume_from_checkpoint: 177 | global_step, initial_global_step, first_epoch = resume_model(model, args, accelerator, num_update_steps_per_epoch) 178 | 179 | progress_bar = tqdm( 180 | range(0, args.max_train_steps), 181 | initial=initial_global_step, 182 | desc="Steps", 183 | disable=not accelerator.is_local_main_process, 184 | ) 185 | 186 | return global_step, first_epoch, progress_bar -------------------------------------------------------------------------------- /diffusion/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | 16 | import copy 17 | import math 18 | import os 19 | import torch 20 | import torch.nn.functional as F 21 | import torch.utils.checkpoint 22 | from accelerate.logging import get_logger 23 | from tqdm.auto import tqdm 24 | 25 | import sys 26 | sys.path.append('..') 27 | 28 | import common.train_utils 29 | from common.train_utils import ( 30 | init_train_basics, 31 | save_model, 32 | get_optimizer, 33 | get_dataset, 34 | more_init 35 | ) 36 | from common.dataset import visualize_maze 37 | from types import SimpleNamespace 38 | import diffusers 39 | import wandb 40 | from pathlib import Path 41 | from PIL import Image 42 | 43 | default_arguments = dict( 44 | model_path="runwayml/stable-diffusion-v1-5", 45 | output_dir="maze-output", 46 | seed=None, 47 | maze_size=13, 48 | train_batch_size=64, 49 | max_train_steps=50_000, 50 | validation_steps=1000, 51 | checkpointing_steps=1000, 52 | resume_from_checkpoint=None, 53 | gradient_accumulation_steps=1, 54 | gradient_checkpointing=True, 55 | learning_rate=1.0e-4, 56 | lr_scheduler="linear", 57 | lr_warmup_steps=50, 58 | lr_num_cycles=1, 59 | lr_power=1.0, 60 | dataloader_num_workers=4, 61 | use_8bit_adam=False, 62 | adam_beta1=0.9, 63 | adam_beta2=0.98, 64 | adam_weight_decay=1e-2, 65 | adam_epsilon=1e-08, 66 | max_grad_norm=1.0, 67 | report_to="wandb", 68 | mixed_precision="bf16", 69 | allow_tf32=True, 70 | logging_dir="logs", 71 | local_rank=-1, 72 | num_processes=1, 73 | guidance_scale=1.2, 74 | num_timesteps=35 75 | ) 76 | 77 | 78 | class Attn2Patch(torch.nn.Module): 79 | 80 | def __init__(self): 81 | super().__init__() 82 | 83 | def __call__(self, *args, **kwargs): 84 | return torch.zeros_like(args[0]) 85 | 86 | 87 | @torch.no_grad() 88 | def gen_samples(model, noise_scheduler, dataloader, out_dir, guidance_scale=1.0, num_timesteps=25): 89 | if not os.path.exists(out_dir): 90 | os.makedirs(out_dir) 91 | batch = next(iter(dataloader)) 92 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 93 | 94 | latents = torch.randn_like(batch["maze"], device=device, dtype=dtype) 95 | noise_scheduler.set_timesteps(num_timesteps, device=device) 96 | all_mazes = torch.cat([torch.zeros_like(batch["maze"]), batch["maze"]], dim=0).to(device).to(dtype) 97 | 98 | for i, t in tqdm(enumerate(noise_scheduler.timesteps), total=num_timesteps): 99 | # expand the latents if we are doing classifier free guidance 100 | latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents 101 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t) 102 | # channel wise concat condition 103 | latent_model_input = torch.cat([latent_model_input, all_mazes], dim=1) 104 | 105 | noise_pred = model( 106 | latent_model_input, 107 | t, 108 | encoder_hidden_states=torch.zeros(latent_model_input.shape[0],1,1).to(latent_model_input.device), 109 | return_dict=False, 110 | )[0] 111 | 112 | if guidance_scale > 1.0: 113 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 114 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 115 | latents = noise_scheduler.step(noise_pred, t, latents, return_dict=False)[0] 116 | 117 | latents = (latents.clamp(-1, 1) / 2 + 0.5).float().cpu().numpy() 118 | images = [] 119 | for i in range(latents.shape[0]): 120 | maze = batch["maze"][i,0]/2+0.5 121 | solved_maze = visualize_maze(maze.cpu().numpy(), latents[i,0]) 122 | filename = f"{f'{i}'.zfill(3)}.png" 123 | 
solved_maze.save(Path(out_dir) / filename) 124 | images.append(wandb.Image(solved_maze)) 125 | 126 | wandb.log({"validation_images": images}) 127 | 128 | 129 | def train(args): 130 | logger = get_logger(__name__) 131 | args = SimpleNamespace(**args) 132 | accelerator, weight_dtype = init_train_basics(args, logger) 133 | noise_scheduler = diffusers.DDIMScheduler.from_config(args.model_path, subfolder="scheduler") 134 | model = diffusers.UNet2DConditionModel.from_config(args.model_path, 135 | up_block_types=["CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"], 136 | down_block_types=["DownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D","CrossAttnDownBlock2D"], 137 | attention_head_dim= [5,10,20,20], 138 | in_channels=2, 139 | out_channels=1, 140 | subfolder="unet",).to(accelerator.device, dtype=weight_dtype) 141 | # we dont need cross attention 142 | for name, module in model.named_modules(): 143 | if hasattr(module, "attn2"): 144 | module.attn2 = Attn2Patch() 145 | 146 | if args.gradient_checkpointing: 147 | model.enable_gradient_checkpointing() 148 | 149 | optimizer, lr_scheduler = get_optimizer(args, model.parameters(), accelerator) 150 | train_dataset, train_dataloader, num_update_steps_per_epoch = get_dataset(args) 151 | 152 | # Prepare everything with our `accelerator`. 153 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 154 | model, optimizer, train_dataloader, lr_scheduler 155 | ) 156 | 157 | global_step, first_epoch, progress_bar = more_init(model, accelerator, args, train_dataloader, 158 | train_dataset, logger, num_update_steps_per_epoch, wandb_name="diffusion_maze") 159 | 160 | for epoch in range(first_epoch, args.num_train_epochs): 161 | model.train() 162 | for step, batch in enumerate(train_dataloader): 163 | with accelerator.accumulate(model): 164 | maze = batch["maze"].to(accelerator.device).to(weight_dtype) 165 | path_grid = batch["path_grid"].to(accelerator.device).to(weight_dtype) 166 | noise = torch.randn_like(path_grid) 167 | timesteps = torch.randint( 168 | 0, noise_scheduler.config.num_train_timesteps, (noise.shape[0],), device=accelerator.device 169 | ).long() 170 | noisy_model_input = noise_scheduler.add_noise(path_grid, noise, timesteps) 171 | noisy_model_input = torch.cat([noisy_model_input, maze], dim=1) 172 | 173 | # Predict the noise residual 174 | model_pred = model( 175 | noisy_model_input, 176 | timesteps, 177 | encoder_hidden_states=torch.zeros(noisy_model_input.shape[0],1,1).to(accelerator.device), 178 | return_dict=False, 179 | )[0] 180 | 181 | loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean") 182 | 183 | accelerator.backward(loss) 184 | if accelerator.sync_gradients: 185 | grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm) 186 | optimizer.step() 187 | lr_scheduler.step() 188 | optimizer.zero_grad(set_to_none=True) 189 | 190 | # Checks if the accelerator has performed an optimization step behind the scenes 191 | if accelerator.sync_gradients: 192 | progress_bar.update(1) 193 | global_step += 1 194 | 195 | if accelerator.is_main_process: 196 | if global_step % args.checkpointing_steps == 0: 197 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 198 | save_model(model, save_path, logger) 199 | 200 | 201 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 202 | progress_bar.set_postfix(**logs) 203 | accelerator.log(logs, step=global_step) 204 | 205 | if global_step >= args.max_train_steps: 206 | break 207 
| 208 | if accelerator.is_main_process: 209 | if global_step % args.validation_steps == 0 and global_step > 0: 210 | save_path = os.path.join(args.output_dir, f"samples/checkpoint-{global_step}") 211 | gen_samples(model, noise_scheduler, train_dataloader, save_path, guidance_scale=args.guidance_scale, num_timesteps=args.num_timesteps) 212 | 213 | # Save the lora layers 214 | accelerator.wait_for_everyone() 215 | if accelerator.is_main_process: 216 | save_path = os.path.join(args.output_dir, f"checkpoint-final-{global_step}") 217 | save_model(model, save_path, logger) 218 | 219 | accelerator.end_training() 220 | 221 | 222 | if __name__ == "__main__": 223 | train(default_arguments) -------------------------------------------------------------------------------- /diffusion_coords/models.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | from diffusers.models.embeddings import ( 7 | TimestepEmbedding, 8 | Timesteps, 9 | ) 10 | 11 | 12 | class SelfAttention(nn.Module): 13 | def __init__(self, dim, heads=8, dropout=0.0, bias=False): 14 | super().__init__() 15 | self.to_qkv = nn.Linear(dim, dim * 3, bias=bias) 16 | self.h = heads 17 | self.dh = dim // heads 18 | 19 | def forward(self, x, mask=None): 20 | q, k, v = map(lambda t: t.reshape(*t.shape[:-1], self.h, self.dh).transpose(1, 2), (self.to_qkv(x).chunk(3, dim=-1))) 21 | attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask).transpose(1, 2).reshape(q.shape[0], -1, self.h * self.dh) 22 | return attn_output 23 | 24 | 25 | class CrossAttention(nn.Module): 26 | def __init__(self, dim, context_dim, heads=8, dropout=0.0, bias=False): 27 | super().__init__() 28 | self.to_q = nn.Linear(dim, dim, bias=bias) 29 | self.to_kv = nn.Linear(context_dim, dim * 2, bias=bias) 30 | self.h = heads 31 | self.dh = dim // heads 32 | 33 | def forward(self, x, context, mask=None): 34 | q = self.to_q(x).reshape(x.shape[0], -1, self.h, self.dh).transpose(1, 2) 35 | k, v = map(lambda t: t.reshape(t.shape[0], -1, self.h, self.dh).transpose(1, 2), (self.to_kv(context).chunk(2, dim=-1))) 36 | attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask).transpose(1, 2).reshape(q.shape[0], -1, self.h * self.dh) 37 | return attn_output 38 | 39 | 40 | class FeedForward(nn.Module): 41 | def __init__(self, dim, mult=4, dropout=0.0, bias=True, act_fn=nn.GELU): 42 | super().__init__() 43 | self.net = nn.Sequential( 44 | nn.Linear(dim, dim * mult, bias=bias), 45 | act_fn(), 46 | nn.Dropout(dropout), 47 | nn.Linear(dim * mult, dim, bias=bias), 48 | nn.Dropout(dropout), 49 | ) 50 | 51 | def forward(self, x): 52 | return self.net(x) 53 | 54 | 55 | class AdaNorm(torch.nn.Module): 56 | def __init__(self, dim, bias=True): 57 | super().__init__() 58 | self.act_fn = nn.SiLU() 59 | self.linear = nn.Linear(dim, 2 * dim, bias=bias) 60 | self.norm = nn.LayerNorm(dim) 61 | 62 | def forward(self, x, time_emb): 63 | time_emb = self.act_fn(time_emb) 64 | scale, shift = self.linear(time_emb).chunk(2, dim=-1) 65 | return self.norm(x) * scale[:, None, :] + shift[:, None, :] 66 | 67 | def identity(x, *args, **kwargs): 68 | return x 69 | 70 | class TransformerLayer(nn.Module): 71 | 72 | def __init__(self, 73 | query_dim=768, 74 | context_dim=1024, 75 | heads=8, 76 | dropout=0.0, 77 | ff_mult=4, 78 | use_cross_attn=True, 79 | ): 80 | 81 | super().__init__() 82 | self.self_attn = SelfAttention(query_dim, heads=heads, dropout=dropout) 83 | self.self_norm = 
AdaNorm(query_dim) 84 | 85 | self.cross_attn = CrossAttention(query_dim, context_dim, heads=heads, dropout=dropout) if use_cross_attn else identity 86 | self.cross_norm = AdaNorm(query_dim) if use_cross_attn else identity 87 | 88 | self.ff = FeedForward(query_dim, mult=ff_mult, dropout=dropout) 89 | self.ff_norm = AdaNorm(query_dim) 90 | 91 | self.gradient_checkpointing = False 92 | 93 | def forward(self, x, context, ada_emb=None, attn_mask=None, cross_attn_mask=None): 94 | if self.gradient_checkpointing: 95 | x = torch.utils.checkpoint.checkpoint(self.self_attn, self.self_norm(x, ada_emb), attn_mask) + x 96 | x = torch.utils.checkpoint.checkpoint(self.cross_attn, self.cross_norm(x, ada_emb), context, cross_attn_mask) + x 97 | x = torch.utils.checkpoint.checkpoint(self.ff, self.ff_norm(x, ada_emb)) + x 98 | else: 99 | x = self.self_attn(self.self_norm(x, ada_emb), attn_mask) + x 100 | x = self.cross_attn(self.cross_norm(x, ada_emb), context, cross_attn_mask) + x 101 | x = self.ff(self.ff_norm(x, ada_emb)) + x 102 | 103 | return x 104 | 105 | 106 | class FourierEmbedder: 107 | def __init__(self, num_freqs, temperature): 108 | self.num_freqs = num_freqs 109 | self.temperature = temperature 110 | self.freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs) 111 | 112 | @torch.no_grad() 113 | def __call__(self, x, cat_dim=-1): 114 | out = [] 115 | for freq in self.freq_bands: 116 | out.append(torch.sin(freq * x)) 117 | out.append(torch.cos(freq * x)) 118 | return torch.cat(out, cat_dim) 119 | 120 | 121 | 122 | class DiffusionTransformer(nn.Module): 123 | 124 | def __init__(self, in_channels=3, 125 | out_channels=3, 126 | num_layers_encoder=6, 127 | num_layers_decoder=8, 128 | dim=512, 129 | heads=8, 130 | ff_mult=4, 131 | maze_size=20, 132 | act_fn="silu", 133 | num_freqs=64 134 | ): 135 | 136 | super().__init__() 137 | # all coords normalized -1, 1 138 | self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, temperature=60) 139 | self.embed_path = nn.Linear(num_freqs * 2 * in_channels, dim) 140 | 141 | # 0=empty, 1=wall, 2=start, 3=end 142 | self.embed_maze = nn.Embedding(4, dim) 143 | self.maze_one_side = maze_size * 2 + 1 144 | self.total_pixs = self.maze_one_side * self.maze_one_side 145 | 146 | self.pos_embs_encoder = nn.Parameter(torch.randn(1, self.total_pixs, dim) * 0.01) 147 | self.pos_embs_decoder = nn.Parameter(torch.randn(1, self.total_pixs, dim) * 0.01) 148 | 149 | self.time_proj = Timesteps(dim//2, flip_sin_to_cos=False, downscale_freq_shift=0.0) 150 | self.time_embedding = TimestepEmbedding( 151 | dim//2, 152 | dim, 153 | act_fn=act_fn, 154 | post_act_fn=None, 155 | cond_proj_dim=None, 156 | ) 157 | 158 | self.encoder = nn.ModuleList([TransformerLayer(query_dim=dim, 159 | context_dim=dim, 160 | heads=heads, 161 | dropout=0.0, 162 | ff_mult=ff_mult, 163 | use_cross_attn=False, 164 | ) for _ in range(num_layers_encoder)]) 165 | 166 | 167 | self.decoder = nn.ModuleList([TransformerLayer(query_dim=dim, 168 | context_dim=dim, 169 | heads=heads, 170 | dropout=0.0, 171 | ff_mult=ff_mult, 172 | use_cross_attn=True, 173 | ) for _ in range(num_layers_decoder)]) 174 | 175 | self.final_layer_norm = nn.LayerNorm(dim) 176 | self.out_proj = nn.Linear(dim, out_channels) 177 | 178 | 179 | def enable_gradient_checkpointing(self): 180 | for layer in self.encoder: 181 | layer.gradient_checkpointing = True 182 | for layer in self.decoder: 183 | layer.gradient_checkpointing = True 184 | 185 | 186 | def forward(self, path, maze, timesteps, attn_mask=None, dropout_mask=None): 187 | 
#timestep 188 | if len(timesteps.shape) == 0: 189 | timesteps = timesteps[None].to(path.device) 190 | timesteps = timesteps.expand(path.shape[0]) 191 | 192 | t_emb = self.time_proj(timesteps).to(dtype=path.dtype) 193 | emb = self.time_embedding(t_emb, None) 194 | 195 | # maze goes through encoder 196 | b, _, h, w = maze.shape 197 | maze = maze.squeeze(1).reshape(b, -1) 198 | maze_embs = self.embed_maze(maze) 199 | maze_embs = maze_embs + self.pos_embs_encoder.repeat(b, 1, 1) 200 | 201 | for layer in self.encoder: 202 | maze_embs = layer(maze_embs, None, emb) 203 | 204 | if dropout_mask is not None: 205 | maze_embs = maze_embs * dropout_mask[:, None, None] 206 | 207 | # path should already be flattened and padded to fit maze 208 | b, s, xy = path.shape #b, s, 3 209 | path = self.fourier_embedder(path) 210 | path = self.embed_path(path) 211 | path = path + self.pos_embs_decoder.repeat(b, 1, 1) 212 | for layer in self.decoder: 213 | path = layer(path, maze_embs, emb, attn_mask=attn_mask) 214 | path = self.final_layer_norm(path) 215 | path = self.out_proj(path) 216 | 217 | return path -------------------------------------------------------------------------------- /diffusion_coords/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | 16 | import copy 17 | import math 18 | import os 19 | import torch 20 | import torch.nn.functional as F 21 | import torch.utils.checkpoint 22 | from accelerate.logging import get_logger 23 | from tqdm.auto import tqdm 24 | 25 | import sys 26 | sys.path.append('..') 27 | 28 | import common.train_utils 29 | from common.train_utils import ( 30 | init_train_basics, 31 | save_model, 32 | get_optimizer, 33 | get_dataset, 34 | more_init 35 | ) 36 | from common.dataset import visualize_maze 37 | from types import SimpleNamespace 38 | import diffusers 39 | import wandb 40 | from pathlib import Path 41 | from PIL import Image 42 | from diffusion_coords.models import DiffusionTransformer 43 | 44 | default_arguments = dict( 45 | model_path="runwayml/stable-diffusion-v1-5", 46 | output_dir="maze-output", 47 | seed=None, 48 | maze_size=13, 49 | train_batch_size=64, 50 | max_train_steps=60_000, 51 | validation_steps=1000, 52 | checkpointing_steps=1000, 53 | resume_from_checkpoint="/home/ubuntu/MazeSolver/diffusion_coords/maze-output/checkpoint-4000", 54 | gradient_accumulation_steps=1, 55 | gradient_checkpointing=True, 56 | learning_rate=1.0e-4, 57 | lr_scheduler="linear", 58 | lr_warmup_steps=50, 59 | lr_num_cycles=1, 60 | lr_power=1.0, 61 | dataloader_num_workers=4, 62 | use_8bit_adam=False, 63 | adam_beta1=0.9, 64 | adam_beta2=0.98, 65 | adam_weight_decay=1e-2, 66 | adam_epsilon=1e-08, 67 | max_grad_norm=1.0, 68 | report_to="wandb", 69 | mixed_precision="bf16", 70 | allow_tf32=True, 71 | logging_dir="logs", 72 | local_rank=-1, 73 | num_processes=1, 74 | guidance_scale=1.2, 75 | num_timesteps=35, 76 | 77 | dropout=0.05, 78 | 79 | encoder_layers=6, 80 | decoder_layers=8, 81 | dim=512, 82 | heads=8, 83 | ff_mult=3, 84 | ) 85 | 86 | 87 | @torch.no_grad() 88 | def gen_samples(model, noise_scheduler, dataloader, out_dir, guidance_scale=1.0, num_timesteps=25): 89 | if not os.path.exists(out_dir): 90 | os.makedirs(out_dir) 91 | batch = next(iter(dataloader)) 92 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 93 | 94 | mazes = batch["maze_labeled"].to(device).long() 95 | 96 | latents = torch.randn(mazes.shape[0], mazes.shape[2] * mazes.shape[3], 3, device=device).to(dtype) 97 | noise_scheduler.set_timesteps(num_timesteps, device=device) 98 | all_mazes = torch.cat([mazes, mazes], dim=0).to(device) # will use mask to dropout neg 99 | mask = torch.ones(all_mazes.shape[0], device=device).to(dtype) 100 | mask[:mask.shape[0] // 2] = 0 101 | 102 | for i, t in tqdm(enumerate(noise_scheduler.timesteps), total=num_timesteps): 103 | # expand the latents if we are doing classifier free guidance 104 | latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1.0 else latents 105 | latent_model_input = noise_scheduler.scale_model_input(latent_model_input, t) 106 | 107 | noise_pred = model( 108 | latent_model_input, 109 | all_mazes, 110 | t, 111 | dropout_mask=mask, 112 | ) 113 | 114 | if guidance_scale > 1.0: 115 | noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) 116 | noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 117 | latents = noise_scheduler.step(noise_pred, t, latents, return_dict=False)[0] 118 | 119 | # unnormalize coordinates 120 | masks = latents[:, :, 2:3] 121 | coords = latents[:, :, :2] 122 | coords = coords / 3 123 | coords = coords / 2 + 0.5 124 | coords = coords * mazes.shape[-1] # scale relative back to actual maze dim 125 | # round to 
nearest int 126 | coords = coords.round().long() 127 | 128 | images = [] 129 | for i in range(coords.shape[0]): 130 | maze = mazes[i].squeeze() 131 | for j in range(coords.shape[1]): 132 | try: 133 | if masks[i, j, 0] > 0.5: 134 | maze[coords[i, j, 1], coords[i, j, 0]] = 5 # mark path 135 | except: 136 | pass 137 | 138 | solved_maze = visualize_maze(maze.float().cpu().numpy()) 139 | images.append(solved_maze) 140 | 141 | wandb_images = [] 142 | for i in range(len(images)): 143 | filename = f"{f'{i}'.zfill(3)}.png" 144 | images[i].save(Path(out_dir) / filename) 145 | wandb_images.append(wandb.Image(images[i])) 146 | 147 | wandb.log({"validation_images": wandb_images}) 148 | 149 | 150 | def train(args): 151 | logger = get_logger(__name__) 152 | args = SimpleNamespace(**args) 153 | accelerator, weight_dtype = init_train_basics(args, logger) 154 | noise_scheduler = diffusers.DDIMScheduler.from_config(args.model_path, subfolder="scheduler") 155 | 156 | model = DiffusionTransformer(num_layers_encoder=args.encoder_layers, 157 | num_layers_decoder=args.decoder_layers, 158 | dim=args.dim, 159 | heads=args.heads, 160 | ff_mult=args.ff_mult, 161 | maze_size=args.maze_size, 162 | ) 163 | 164 | if args.gradient_checkpointing: 165 | model.enable_gradient_checkpointing() 166 | 167 | optimizer, lr_scheduler = get_optimizer(args, model.parameters(), accelerator) 168 | train_dataset, train_dataloader, num_update_steps_per_epoch = get_dataset(args) 169 | 170 | # Prepare everything with our `accelerator`. 171 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 172 | model, optimizer, train_dataloader, lr_scheduler 173 | ) 174 | 175 | global_step, first_epoch, progress_bar = more_init(model, accelerator, args, train_dataloader, 176 | train_dataset, logger, num_update_steps_per_epoch, wandb_name="diffusion_coords_maze") 177 | 178 | for epoch in range(first_epoch, args.num_train_epochs): 179 | model.train() 180 | for step, batch in enumerate(train_dataloader): 181 | with accelerator.accumulate(model): 182 | maze = batch["maze_labeled"].to(accelerator.device).long() 183 | path = batch["path"] 184 | 185 | def pad_to_size(x): 186 | return F.pad(torch.tensor(x), (0, 0, 0, model.total_pixs - len(x)), value=-10) # placeholder number 187 | path = [pad_to_size(x) for x in path] # pad to maximum possible path length 188 | path = torch.stack(path, dim=0).to(accelerator.device).to(weight_dtype) 189 | mask = (path[:,:,0:1] != -10).float() # mask out areas after path has finished 190 | path = torch.where(path == -10, torch.zeros_like(path), path) # replace placeholder with 0s 191 | path = (path / model.maze_one_side) * 2 - 1 # [-1, 1] 192 | path = path * 3 # path std naturally tends to be around 0.36 193 | path = torch.cat([path, mask], dim=-1) # b, s, 3, -> [x, y, mask] 194 | 195 | noise = torch.randn_like(path) 196 | timesteps = torch.randint( 197 | 0, noise_scheduler.config.num_train_timesteps, (path.shape[0],), device=accelerator.device 198 | ).long() 199 | noisy_model_input = noise_scheduler.add_noise(path, noise, timesteps) 200 | 201 | # Predict actual coordinates and mask will tell us where to stop 202 | mask = torch.rand(noise.shape[0], device=accelerator.device) > args.dropout 203 | noise_pred = model(noisy_model_input, maze, timesteps, attn_mask=None, dropout_mask=mask) 204 | noise_pred_coords = noise_pred[:, :, :2] 205 | noise_pred_masks = noise_pred[:, :, 2:3] 206 | noise_coords = noise[:, :, :2] 207 | noise_masks = noise[:, :, 2:3] 208 | 209 | loss_coords = F.mse_loss(noise_pred_coords, 
noise_coords, reduction="mean") 210 | loss_masks = F.mse_loss(noise_pred_masks, noise_masks, reduction="mean") 211 | loss = loss_coords + loss_masks 212 | 213 | accelerator.backward(loss) 214 | if accelerator.sync_gradients: 215 | grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm) 216 | optimizer.step() 217 | lr_scheduler.step() 218 | optimizer.zero_grad(set_to_none=True) 219 | 220 | # Checks if the accelerator has performed an optimization step behind the scenes 221 | if accelerator.sync_gradients: 222 | progress_bar.update(1) 223 | global_step += 1 224 | 225 | if accelerator.is_main_process: 226 | if global_step % args.checkpointing_steps == 0: 227 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 228 | save_model(model, save_path, logger) 229 | 230 | 231 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 232 | progress_bar.set_postfix(**logs) 233 | accelerator.log(logs, step=global_step) 234 | 235 | if global_step >= args.max_train_steps: 236 | break 237 | 238 | if accelerator.is_main_process: 239 | if global_step % args.validation_steps == 0 and global_step > 0: 240 | save_path = os.path.join(args.output_dir, f"samples/checkpoint-{global_step}") 241 | gen_samples(model, noise_scheduler, train_dataloader, save_path, guidance_scale=args.guidance_scale, num_timesteps=args.num_timesteps) 242 | 243 | # Save the lora layers 244 | accelerator.wait_for_everyone() 245 | if accelerator.is_main_process: 246 | save_path = os.path.join(args.output_dir, f"checkpoint-final-{global_step}") 247 | save_model(model, save_path, logger) 248 | 249 | accelerator.end_training() 250 | 251 | 252 | if __name__ == "__main__": 253 | train(default_arguments) -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/ethansmith2000/MazeSolver/43cacf58bf50e9216d67c714156dc7add7d7907b/requirements.txt -------------------------------------------------------------------------------- /transformer_move/models.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | 6 | 7 | class SelfAttention(nn.Module): 8 | def __init__(self, dim, heads=8, dropout=0.0, bias=False): 9 | super().__init__() 10 | self.to_qkv = nn.Linear(dim, dim * 3, bias=bias) 11 | self.h = heads 12 | self.dh = dim // heads 13 | 14 | def forward(self, x, mask=None): 15 | q, k, v = map(lambda t: t.reshape(*t.shape[:-1], self.h, self.dh).transpose(1, 2), (self.to_qkv(x).chunk(3, dim=-1))) 16 | attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, is_causal=True).transpose(1, 2).reshape(q.shape[0], -1, self.h * self.dh) 17 | return attn_output 18 | 19 | 20 | class CrossAttention(nn.Module): 21 | def __init__(self, dim, context_dim, heads=8, dropout=0.0, bias=False): 22 | super().__init__() 23 | self.to_q = nn.Linear(dim, dim, bias=bias) 24 | self.to_kv = nn.Linear(context_dim, dim * 2, bias=bias) 25 | self.h = heads 26 | self.dh = dim // heads 27 | 28 | def forward(self, x, context, mask=None): 29 | q = self.to_q(x).reshape(x.shape[0], -1, self.h, self.dh).transpose(1, 2) 30 | k, v = map(lambda t: t.reshape(t.shape[0], -1, self.h, self.dh).transpose(1, 2), (self.to_kv(context).chunk(2, dim=-1))) 31 | attn_output = F.scaled_dot_product_attention(q, k, v, attn_mask=mask).transpose(1, 
2).reshape(q.shape[0], -1, self.h * self.dh) 32 | return attn_output 33 | 34 | 35 | class FeedForward(nn.Module): 36 | def __init__(self, dim, mult=4, dropout=0.0, bias=True, act_fn=nn.GELU): 37 | super().__init__() 38 | self.net = nn.Sequential( 39 | nn.Linear(dim, dim * mult, bias=bias), 40 | act_fn(), 41 | nn.Dropout(dropout), 42 | nn.Linear(dim * mult, dim, bias=bias), 43 | nn.Dropout(dropout), 44 | ) 45 | 46 | def forward(self, x): 47 | return self.net(x) 48 | 49 | 50 | class AdaNorm(torch.nn.Module): 51 | def __init__(self, dim, bias=True): 52 | super().__init__() 53 | self.act_fn = nn.SiLU() 54 | self.linear = nn.Linear(dim, 2 * dim, bias=bias) 55 | self.norm = nn.LayerNorm(dim) 56 | 57 | def forward(self, x, time_emb): 58 | time_emb = self.act_fn(time_emb) 59 | scale, shift = self.linear(time_emb).chunk(2, dim=-1) 60 | return self.norm(x) * scale[:, None, :] + shift[:, None, :] 61 | 62 | def identity(x, *args, **kwargs): 63 | return x 64 | 65 | class TransformerLayer(nn.Module): 66 | 67 | def __init__(self, 68 | query_dim=768, 69 | context_dim=1024, 70 | heads=8, 71 | dropout=0.0, 72 | ff_mult=4, 73 | use_cross_attn=True, 74 | ada_norm=False 75 | ): 76 | 77 | super().__init__() 78 | norm_class = AdaNorm if ada_norm else nn.LayerNorm 79 | self.self_attn = SelfAttention(query_dim, heads=heads, dropout=dropout) 80 | self.self_norm = norm_class(query_dim) 81 | 82 | self.cross_attn = CrossAttention(query_dim, context_dim, heads=heads, dropout=dropout) if use_cross_attn else identity 83 | self.cross_norm = norm_class(query_dim) if use_cross_attn else identity 84 | 85 | self.ff = FeedForward(query_dim, mult=ff_mult, dropout=dropout) 86 | self.ff_norm = norm_class(query_dim) 87 | 88 | self.gradient_checkpointing = False 89 | 90 | def forward(self, x, context, ada_emb=None, attn_mask=None, cross_attn_mask=None): 91 | if self.gradient_checkpointing: 92 | x = torch.utils.checkpoint.checkpoint(self.self_attn, self.self_norm(x), attn_mask) + x 93 | x = torch.utils.checkpoint.checkpoint(self.cross_attn, self.cross_norm(x), context, cross_attn_mask) + x 94 | x = torch.utils.checkpoint.checkpoint(self.ff, self.ff_norm(x)) + x 95 | else: 96 | x = self.self_attn(self.self_norm(x), attn_mask) + x 97 | x = self.cross_attn(self.cross_norm(x), context, cross_attn_mask) + x 98 | x = self.ff(self.ff_norm(x)) + x 99 | 100 | return x 101 | 102 | 103 | class Transformer(nn.Module): 104 | 105 | def __init__(self, 106 | num_layers_encoder=6, 107 | num_layers_decoder=8, 108 | dim=512, 109 | heads=8, 110 | ff_mult=4, 111 | maze_size=20, 112 | movements=True 113 | ): 114 | 115 | super().__init__() 116 | self.maze_one_side = maze_size * 2 + 1 117 | self.total_pixs = self.maze_one_side * self.maze_one_side 118 | 119 | self.num_options = 6 if movements else self.total_pixs 120 | 121 | # 0=down, 1=up, 2=right, 3=left 122 | self.embed_path = nn.Embedding(self.num_options, dim) 123 | # 0=empty, 1=wall, 2=start, 3=end 124 | self.embed_maze = nn.Embedding(4, dim) 125 | 126 | 127 | self.pos_embs_encoder = nn.Parameter(torch.randn(1, self.total_pixs, dim) * 0.01) 128 | self.pos_embs_decoder = nn.Parameter(torch.randn(1, self.total_pixs, dim) * 0.01) 129 | 130 | self.encoder = nn.ModuleList([TransformerLayer(query_dim=dim, 131 | context_dim=dim, 132 | heads=heads, 133 | dropout=0.0, 134 | ff_mult=ff_mult, 135 | use_cross_attn=False, 136 | ) for _ in range(num_layers_encoder)]) 137 | 138 | 139 | self.decoder = nn.ModuleList([TransformerLayer(query_dim=dim, 140 | context_dim=dim, 141 | heads=heads, 142 | dropout=0.0, 143 | 
ff_mult=ff_mult, 144 | use_cross_attn=True, 145 | ) for _ in range(num_layers_decoder)]) 146 | 147 | self.final_layer_norm = nn.LayerNorm(dim) 148 | self.out_proj = nn.Linear(dim, self.num_options) 149 | 150 | 151 | def enable_gradient_checkpointing(self): 152 | for layer in self.encoder: 153 | layer.gradient_checkpointing = True 154 | for layer in self.decoder: 155 | layer.gradient_checkpointing = True 156 | 157 | 158 | def forward(self, path, maze, attn_mask=None): 159 | # maze goes through encoder 160 | b, _, h, w = maze.shape 161 | maze = maze.squeeze(1).reshape(b, -1) 162 | maze_embs = self.embed_maze(maze) 163 | maze_embs = maze_embs + self.pos_embs_encoder.repeat(b, 1, 1) 164 | 165 | for layer in self.encoder: 166 | maze_embs = layer(maze_embs, None) 167 | 168 | # path should already be flattened and padded to fit maze 169 | b, s = path.shape #b, s 170 | path = self.embed_path(path) 171 | path = path + self.pos_embs_decoder[:, :s, :].repeat(b, 1, 1) 172 | for layer in self.decoder: 173 | path = layer(path, maze_embs, attn_mask=attn_mask) 174 | path = self.final_layer_norm(path) 175 | path = self.out_proj(path) 176 | 177 | return path -------------------------------------------------------------------------------- /transformer_move/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding=utf-8 3 | # Copyright 2024 The HuggingFace Inc. team. All rights reserved. 4 | # 5 | # Licensed under the Apache License, Version 2.0 (the "License"); 6 | # you may not use this file except in compliance with the License. 7 | # You may obtain a copy of the License at 8 | # 9 | # http://www.apache.org/licenses/LICENSE-2.0 10 | # 11 | # Unless required by applicable law or agreed to in writing, software 12 | # distributed under the License is distributed on an "AS IS" BASIS, 13 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
14 | # See the License for the specific language governing permissions and 15 | 16 | import copy 17 | import math 18 | import os 19 | import torch 20 | import torch.nn.functional as F 21 | import torch.utils.checkpoint 22 | from accelerate.logging import get_logger 23 | from tqdm.auto import tqdm 24 | 25 | import sys 26 | sys.path.append('..') 27 | 28 | import common.train_utils 29 | from common.train_utils import ( 30 | init_train_basics, 31 | save_model, 32 | get_optimizer, 33 | get_dataset, 34 | more_init 35 | ) 36 | from common.dataset import visualize_maze, get_movements_from_path, get_path_from_movements 37 | from types import SimpleNamespace 38 | import diffusers 39 | import wandb 40 | from pathlib import Path 41 | from PIL import Image 42 | from transformer_move.models import Transformer 43 | 44 | default_arguments = dict( 45 | model_path="runwayml/stable-diffusion-v1-5", 46 | output_dir="maze-output", 47 | seed=None, 48 | maze_size=13, 49 | train_batch_size=64, 50 | max_train_steps=40_000, 51 | validation_steps=1000, 52 | checkpointing_steps=1000, 53 | resume_from_checkpoint="/home/ubuntu/MazeSolver/transformer_move/maze-output/checkpoint-15000", 54 | gradient_accumulation_steps=1, 55 | gradient_checkpointing=True, 56 | learning_rate=5.0e-5, 57 | lr_scheduler="linear", 58 | lr_warmup_steps=50, 59 | lr_num_cycles=1, 60 | lr_power=1.0, 61 | dataloader_num_workers=4, 62 | use_8bit_adam=False, 63 | adam_beta1=0.9, 64 | adam_beta2=0.98, 65 | adam_weight_decay=1e-2, 66 | adam_epsilon=1e-08, 67 | max_grad_norm=1.0, 68 | report_to="wandb", 69 | mixed_precision="bf16", 70 | allow_tf32=True, 71 | logging_dir="logs", 72 | local_rank=-1, 73 | num_processes=1, 74 | 75 | encoder_layers=6, 76 | decoder_layers=8, 77 | dim=512, 78 | heads=8, 79 | ff_mult=3, 80 | ) 81 | 82 | 83 | 84 | @torch.no_grad() 85 | def gen_samples(model, dataloader, out_dir): 86 | if not os.path.exists(out_dir): 87 | os.makedirs(out_dir) 88 | batch = next(iter(dataloader)) 89 | device, dtype = next(model.parameters()).device, next(model.parameters()).dtype 90 | 91 | mazes = batch["maze_labeled"].to(device).long()[:32] 92 | start_token = torch.tensor([4], device=device).long()[None,:].repeat(mazes.shape[0], 1) 93 | sequence = start_token.clone() 94 | 95 | for i in tqdm(range(model.total_pixs//2)): 96 | preds = model(sequence, mazes, attn_mask=None,) 97 | if i == 20: 98 | with torch.no_grad(): 99 | print(torch.nn.functional.softmax(preds[0, :20], dim=-1)) 100 | print(sequence[0, :20]) 101 | 102 | # keep going until all paths have a stop token in them or we reach max_len 103 | preds = preds.argmax(dim=-1) 104 | sequence = torch.cat([sequence, preds[:, -1:]], dim=1) 105 | 106 | if len(torch.where(preds == 5)[0]) == mazes.shape[0]: 107 | break 108 | 109 | paths = [] 110 | # truncate each path to the first stop token 111 | for i in range(sequence.shape[0]): 112 | moves = sequence[i] 113 | end_pos = torch.where(moves == 5)[0] 114 | if len(end_pos) > 0: 115 | moves = moves[:end_pos[0] + 1] 116 | path = get_path_from_movements(moves, mazes[i,0]) 117 | paths.append(path) 118 | 119 | images = [] 120 | for i in range(len(paths)): 121 | maze = mazes[i,0] 122 | solved_maze = visualize_maze(maze.float().cpu().numpy(), paths[i]) 123 | filename = f"{f'{i}'.zfill(3)}.png" 124 | solved_maze.save(Path(out_dir) / filename) 125 | images.append(wandb.Image(solved_maze)) 126 | 127 | wandb.log({"validation_images": images}) 128 | 129 | 130 | def train(args): 131 | logger = get_logger(__name__) 132 | args = SimpleNamespace(**args) 133 | accelerator, 
weight_dtype = init_train_basics(args, logger) 134 | 135 | model = Transformer(num_layers_encoder=args.encoder_layers, 136 | num_layers_decoder=args.decoder_layers, 137 | dim=args.dim, 138 | heads=args.heads, 139 | ff_mult=args.ff_mult, 140 | maze_size=args.maze_size, 141 | ) 142 | 143 | if args.gradient_checkpointing: 144 | model.enable_gradient_checkpointing() 145 | 146 | optimizer, lr_scheduler = get_optimizer(args, model.parameters(), accelerator) 147 | train_dataset, train_dataloader, num_update_steps_per_epoch = get_dataset(args) 148 | 149 | # Prepare everything with our `accelerator`. 150 | model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( 151 | model, optimizer, train_dataloader, lr_scheduler 152 | ) 153 | 154 | global_step, first_epoch, progress_bar = more_init(model, accelerator, args, train_dataloader, 155 | train_dataset, logger, num_update_steps_per_epoch, wandb_name="transformer_maze") 156 | 157 | for epoch in range(first_epoch, args.num_train_epochs): 158 | model.train() 159 | for step, batch in enumerate(train_dataloader): 160 | with accelerator.accumulate(model): 161 | maze = batch["maze_labeled"].to(accelerator.device).long() 162 | path = batch["path"] 163 | 164 | # turn path into series of movements 165 | new_paths = [] 166 | for p in path: 167 | movement_ids, _ = get_movements_from_path(p) 168 | # add 4 to the start and 5 to the end to indicate start and end 169 | movement_ids = [4] + movement_ids + [5] 170 | movement_ids = torch.tensor(movement_ids, device=accelerator.device).long() 171 | new_paths.append(movement_ids) 172 | 173 | def pad_to_size(x): 174 | return F.pad(torch.tensor(x), (0, model.total_pixs // 2 - len(x)), value=5) 175 | 176 | path = [pad_to_size(x) for x in new_paths] 177 | path = torch.stack(path, dim=0).to(accelerator.device).long() 178 | 179 | # trim to longest path 180 | longest_pos = path.argmax(dim=-1).max() 181 | path = path[:, :longest_pos+1] 182 | 183 | input_path = path[:, :-1] 184 | target_path = path[:, 1:] 185 | 186 | preds = model(input_path, maze, attn_mask=None) 187 | loss = F.cross_entropy(preds.reshape(-1, preds.shape[-1]).float(), target_path.reshape(-1), reduction="none") 188 | loss = loss.mean() 189 | 190 | accelerator.backward(loss) 191 | if accelerator.sync_gradients: 192 | grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm) 193 | optimizer.step() 194 | lr_scheduler.step() 195 | optimizer.zero_grad(set_to_none=True) 196 | 197 | # Checks if the accelerator has performed an optimization step behind the scenes 198 | if accelerator.sync_gradients: 199 | progress_bar.update(1) 200 | global_step += 1 201 | 202 | if accelerator.is_main_process: 203 | if global_step % args.checkpointing_steps == 0: 204 | save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") 205 | save_model(model, save_path, logger) 206 | 207 | 208 | logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} 209 | progress_bar.set_postfix(**logs) 210 | accelerator.log(logs, step=global_step) 211 | 212 | if global_step >= args.max_train_steps: 213 | break 214 | 215 | if accelerator.is_main_process: 216 | if global_step % args.validation_steps == 0 and global_step > 0: 217 | save_path = os.path.join(args.output_dir, f"samples/checkpoint-{global_step}") 218 | gen_samples(model, train_dataloader, save_path) 219 | 220 | # Save the lora layers 221 | accelerator.wait_for_everyone() 222 | if accelerator.is_main_process: 223 | save_path = os.path.join(args.output_dir, 
f"checkpoint-final-{global_step}") 224 | save_model(model, save_path, logger) 225 | 226 | accelerator.end_training() 227 | 228 | 229 | if __name__ == "__main__": 230 | train(default_arguments) --------------------------------------------------------------------------------