├── .gitignore
├── LICENSE.txt
├── QuickDiffusionModel
│   └── QuickDiffusionModel.ipynb
├── Make Language Model from scratch like MNIST
│   └── MakeLanguageModelFromScratch.ipynb
└── TransferLearningGAN
    └── TransferLearningCycleGAN.ipynb

/.gitignore:
--------------------------------------------------------------------------------
1 | .DS_Store
2 | 
3 | 
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 | 
3 | Copyright (c) 2024 Seachaos
4 | 
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 | 
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 | 
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 | 
--------------------------------------------------------------------------------
/QuickDiffusionModel/QuickDiffusionModel.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "f59c9742-5a7f-4317-bcab-8660acc70016",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "\"\"\"\n",
11 | "For more detail:\n",
12 | "https://tree.rocks/make-diffusion-model-from-scratch-easy-way-to-implement-quick-diffusion-model-e60d18fd0f2e\n",
13 | "\"\"\"\n",
14 | "import os\n",
15 | "os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'\n",
16 | "\n",
17 | "import tensorflow as tf\n",
18 | "for gpu in tf.config.list_physical_devices('GPU'):\n",
19 | "    tf.config.experimental.set_memory_growth(gpu, True)"
20 | ]
21 | },
22 | {
23 | "cell_type": "code",
24 | "execution_count": null,
25 | "id": "379efde3-f91e-4cb5-a404-42ea52e3a4e1",
26 | "metadata": {},
27 | "outputs": [],
28 | "source": [
29 | "import numpy as np\n",
30 | "\n",
31 | "from tqdm.auto import trange, tqdm\n",
32 | "import matplotlib.pyplot as plt\n",
33 | "\n",
34 | "import tensorflow as tf\n",
35 | "from tensorflow.keras import layers"
36 | ]
37 | },
38 | {
39 | "cell_type": "code",
40 | "execution_count": null,
41 | "id": "63ecf33b-edf2-41cf-adcd-9475d492522e",
42 | "metadata": {},
43 | "outputs": [],
44 | "source": [
45 | "(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()\n",
46 | "X_train = X_train[y_train.squeeze() == 1]\n",
47 | "X_train = (X_train / 127.5) - 1.0"
48 | ]
49 | },
50 | {
51 | "cell_type": "code",
52 | "execution_count": null,
53 | "id": "9a0d2bcf-e69d-4147-96ca-e6bed75643de",
54 | "metadata": {},
55 | "outputs": [],
56 | "source": [
57 | "IMG_SIZE = 32 # input image size, CIFAR-10 is 32x32\n",
58 | "BATCH_SIZE = 128 # training batch size\n",
59 | "timesteps = 16 # how many denoising steps from pure noise to a clear image\n",
60 | "time_bar = 1 - np.linspace(0, 1.0, timesteps + 1) # noise level at each timestep (1.0 -> 0.0)"
61 | ]
62 | },
63 | {
64 | "cell_type": "code",
65 | "execution_count": null,
66 | "id": "9aaa1fcf-a588-484b-99a2-a288a1a58a56",
67 | "metadata": {},
68 | "outputs": [],
69 | "source": [
70 | "plt.plot(time_bar, label='Noise')\n",
71 | "plt.plot(1 - time_bar, label='Clarity')\n",
72 | "plt.legend()"
73 | ]
74 | },
75 | {
76 | "cell_type": "code",
77 | "execution_count": null,
78 | "id": "b7611759-27d1-418a-bd56-0436246f5863",
79 | "metadata": {},
80 | "outputs": [],
81 | "source": [
82 | "def cvtImg(img):\n",
83 | "    img = img - img.min()\n",
84 | "    img = (img / img.max())\n",
85 | "    return img.astype(np.float32)\n",
86 | "\n",
87 | "def show_examples(x):\n",
88 | "    plt.figure(figsize=(10, 10))\n",
89 | "    for i in range(25):\n",
90 | "        plt.subplot(5, 5, i+1)\n",
91 | "        img = cvtImg(x[i])\n",
92 | "        plt.imshow(img)\n",
93 | "        plt.axis('off')\n",
94 | "\n",
95 | "show_examples(X_train)"
96 | ]
97 | },
98 | {
99 | "cell_type": "code",
100 | "execution_count": null,
101 | "id": "bd8c84b3-bd06-4f55-8d06-44b2e5bcf817",
102 | "metadata": {},
103 | "outputs": [],
104 | "source": [
105 | "def forward_noise(x, t):\n",
106 | "    a = time_bar[t] # noise level at step t\n",
107 | "    b = time_bar[t + 1] # noise level at step t + 1\n",
108 | "    \n",
109 | "    noise = np.random.normal(size=x.shape) # noise mask\n",
110 | "    a = a.reshape((-1, 1, 1, 1))\n",
111 | "    b = b.reshape((-1, 1, 1, 1))\n",
112 | "    img_a = x * (1 - a) + noise * a\n",
113 | "    img_b = x * (1 - b) + noise * b\n",
114 | "    return img_a, img_b\n",
115 | "    \n",
116 | "def generate_ts(num):\n",
117 | "    return np.random.randint(0, timesteps, size=num)\n",
118 | "\n",
119 | "# t = np.full((25,), timesteps - 1) # if you want to see the clearest step\n",
120 | "# t = np.full((25,), 0) # if you want to see the noisiest step\n",
121 | "t = generate_ts(25) # random timesteps, as used during training\n",
122 | "a, b = forward_noise(X_train[:25], t)\n",
123 | "show_examples(a)"
124 | ]
125 | },
126 | {
127 | "cell_type": "code",
128 | "execution_count": null,
129 | "id": "b16f4531-2815-4345-921f-442bddebf150",
130 | "metadata": {},
131 | "outputs": [],
132 | "source": [
133 | "def block(x_img, x_ts):\n",
134 | "    x_parameter = layers.Conv2D(128, kernel_size=3, padding='same')(x_img)\n",
135 | "    x_parameter = layers.Activation('relu')(x_parameter)\n",
136 | "\n",
137 | "    time_parameter = layers.Dense(128)(x_ts)\n",
138 | "    time_parameter = layers.Activation('relu')(time_parameter)\n",
139 | "    time_parameter = layers.Reshape((1, 1, 128))(time_parameter)\n",
140 | "    x_parameter = x_parameter * time_parameter\n",
141 | "    \n",
142 | "    # -----\n",
143 | "    x_out = layers.Conv2D(128, kernel_size=3, padding='same')(x_img)\n",
144 | "    x_out = x_out + x_parameter\n",
145 | "    x_out = layers.LayerNormalization()(x_out)\n",
146 | "    x_out = layers.Activation('relu')(x_out)\n",
147 | "    \n",
148 | "    return x_out"
149 | ]
150 | },
151 | {
152 | "cell_type": "code",
153 | "execution_count": null,
154 | "id": "d6ad5f86-936a-4229-9902-98dd8366c803",
155 | "metadata": {
156 | "scrolled": true,
157 | "tags": []
158 | },
159 | "outputs": [],
160 | "source": [
161 | "def make_model():\n",
162 | "    x = x_input = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3), name='x_input')\n",
163 | "    \n",
164 | "    x_ts = x_ts_input = layers.Input(shape=(1,), name='x_ts_input')\n",
165 | "    x_ts = layers.Dense(192)(x_ts)\n",
166 | "    x_ts = layers.LayerNormalization()(x_ts)\n",
167 | "    x_ts = layers.Activation('relu')(x_ts)\n",
168 | "    \n",
169 | "    # ----- left ( down ) -----\n",
170 | "    x = x32 = block(x, x_ts)\n",
171 | "    x = layers.MaxPool2D(2)(x)\n",
172 | "    \n",
173 | "    x = x16 = block(x, x_ts)\n",
174 | "    x = layers.MaxPool2D(2)(x)\n",
175 | "    \n",
176 | "    x = x8 = block(x, x_ts)\n",
177 | "    x = layers.MaxPool2D(2)(x)\n",
178 | "    \n",
179 | "    x = x4 = block(x, x_ts)\n",
180 | "    \n",
181 | "    # ----- MLP -----\n",
182 | "    x = layers.Flatten()(x)\n",
183 | "    x = layers.Concatenate()([x, x_ts])\n",
184 | "    x = layers.Dense(128)(x)\n",
185 | "    x = layers.LayerNormalization()(x)\n",
186 | "    x = layers.Activation('relu')(x)\n",
187 | "\n",
188 | "    x = layers.Dense(4 * 4 * 32)(x)\n",
189 | "    x = layers.LayerNormalization()(x)\n",
190 | "    x = layers.Activation('relu')(x)\n",
191 | "    x = layers.Reshape((4, 4, 32))(x)\n",
192 | "    \n",
193 | "    # ----- right ( up ) -----\n",
194 | "    x = layers.Concatenate()([x, x4])\n",
195 | "    x = block(x, x_ts)\n",
196 | "    x = layers.UpSampling2D(2)(x)\n",
197 | "    \n",
198 | "    x = layers.Concatenate()([x, x8])\n",
199 | "    x = block(x, x_ts)\n",
200 | "    x = layers.UpSampling2D(2)(x)\n",
201 | "    \n",
202 | "    x = layers.Concatenate()([x, x16])\n",
203 | "    x = block(x, x_ts)\n",
204 | "    x = layers.UpSampling2D(2)(x)\n",
205 | "    \n",
206 | "    x = layers.Concatenate()([x, x32])\n",
207 | "    x = block(x, x_ts)\n",
208 | "    \n",
209 | "    # ----- output -----\n",
210 | "    x = layers.Conv2D(3, kernel_size=1, padding='same')(x)\n",
211 | "    model = tf.keras.models.Model([x_input, x_ts_input], x)\n",
212 | "    return model\n",
213 | "    \n",
214 | "\n",
215 | "model = make_model()\n",
216 | "# model.summary()"
217 | ]
218 | },
219 | {
220 | "cell_type": "code",
221 | "execution_count": null,
222 | "id": "4abb6b35-a520-4fd6-8550-dd84603e3a7d",
223 | "metadata": {},
224 | "outputs": [],
225 | "source": [
226 | "optimizer = tf.keras.optimizers.Adam(learning_rate=0.0008)\n",
227 | "loss_func = tf.keras.losses.MeanAbsoluteError()\n",
228 | "model.compile(loss=loss_func, optimizer=optimizer)"
229 | ]
230 | },
231 | {
232 | "cell_type": "code",
233 | "execution_count": null,
234 | "id": "d0040599-3ba4-4b84-ac36-d88b7456dd81",
235 | "metadata": {},
236 | "outputs": [],
237 | "source": [
238 | "def predict(x_idx=None):\n",
239 | "    x = np.random.normal(size=(32, IMG_SIZE, IMG_SIZE, 3))\n",
240 | "\n",
241 | "    for i in trange(timesteps):\n",
242 | "        t = i\n",
243 | "        x = model.predict([x, np.full((32), t)], verbose=0)\n",
244 | "    show_examples(x)\n",
245 | "\n",
246 | "predict()"
247 | ]
248 | },
249 | {
250 | "cell_type": "code",
251 | "execution_count": null,
252 | "id": "ee0efa2b-72ab-4efa-8e89-8dac57c238b8",
253 | "metadata": {},
254 | "outputs": [],
255 | "source": [
256 | "def predict_step():\n",
257 | "    xs = []\n",
258 | "    x = np.random.normal(size=(8, IMG_SIZE, IMG_SIZE, 3))\n",
259 | "\n",
260 | "    for i in trange(timesteps):\n",
261 | "        t = i\n",
262 | "        x = model.predict([x, np.full((8), t)], verbose=0)\n",
263 | "        if i % 2 == 0:\n",
264 | "            xs.append(x[0])\n",
265 | "\n",
266 | "    plt.figure(figsize=(20, 2))\n",
267 | "    for i in range(len(xs)):\n",
268 | "        plt.subplot(1, len(xs), i+1)\n",
269 | "        plt.imshow(cvtImg(xs[i]))\n",
270 | "        plt.title(f'{i}')\n",
271 | "        plt.axis('off')\n",
272 | "\n",
273 | "predict_step()"
274 | ]
275 | },
276 | {
277 | "cell_type": "code",
278 | "execution_count": null,
279 | "id": "ede500db-da1d-4b88-ab1d-07131e2a1823",
280 | "metadata": {},
281 | "outputs": [],
282 | "source": [
283 | "def train_one(x_img):\n",
284 | "    x_ts = generate_ts(len(x_img))\n",
285 | "    x_a, x_b = forward_noise(x_img, x_ts)\n",
286 | "    loss = model.train_on_batch([x_a, x_ts], x_b)\n",
287 | "    return loss"
288 | ]
289 | },
290 | {
291 | "cell_type": "code",
292 | "execution_count": null,
293 | "id": "e9ffea54-b7cb-43c7-8f6b-54ceccfbbe99",
294 | "metadata": {},
295 | "outputs": [],
296 | "source": [
297 | "def train(R=50):\n",
298 | "    bar = trange(R)\n",
299 | "    total = 100\n",
300 | "    for i in bar:\n",
301 | "        for j in range(total):\n",
302 | "            x_img = X_train[np.random.randint(len(X_train), size=BATCH_SIZE)]\n",
303 | "            loss = train_one(x_img)\n",
304 | "            pg = (j / total) * 100\n",
305 | "            if j % 5 == 0:\n",
306 | "                bar.set_description(f'loss: {loss:.5f}, p: {pg:.2f}%')"
307 | ]
308 | },
309 | {
310 | "cell_type": "code",
311 | "execution_count": null,
312 | "id": "0556441d-1498-4f32-bfae-56038503037d",
313 | "metadata": {},
314 | "outputs": [],
315 | "source": []
316 | },
317 | {
318 | "cell_type": "code",
319 | "execution_count": null,
320 | "id": "07b3b394-8312-47fe-a414-b2283ac752ab",
321 | "metadata": {},
322 | "outputs": [],
323 | "source": [
324 | "for _ in range(10):\n",
325 | "    train()\n",
326 | "    # reduce learning rate for next training\n",
327 | "    model.optimizer.learning_rate = max(0.000001, model.optimizer.learning_rate * 0.9)\n",
328 | "\n",
329 | "    # show results\n",
330 | "    predict()\n",
331 | "    predict_step()\n",
332 | "    plt.show()"
333 | ]
334 | },
335 | {
336 | "cell_type": "code",
337 | "execution_count": null,
338 | "id": "84e55243-68dc-4066-8ccc-bdf51609900b",
339 | "metadata": {},
340 | "outputs": [],
341 | "source": []
342 | },
343 | {
344 | "cell_type": "code",
345 | "execution_count": null,
346 | "id": "7376635a-846c-46a6-bbf3-f5a522c3f0b3",
347 | "metadata": {},
348 | "outputs": [],
349 | "source": []
350 | }
351 | ],
352 | "metadata": {
353 | "kernelspec": {
354 | "display_name": "Python 3 (ipykernel)",
355 | "language": "python",
356 | "name": "python3"
357 | },
358 | "language_info": {
359 | "codemirror_mode": {
360 | "name": "ipython",
361 | "version": 3
362 | },
363 | "file_extension": ".py",
364 | "mimetype": "text/x-python",
365 | "name": "python",
366 | "nbconvert_exporter": "python",
367 | "pygments_lexer": "ipython3",
368 | "version": "3.8.11"
369 | }
370 | },
371 | "nbformat": 4,
372 | "nbformat_minor": 5
373 | }
374 | 
--------------------------------------------------------------------------------
/Make Language Model from scratch like MNIST/MakeLanguageModelFromScratch.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": 24,
6 | "id": "52f6b107-1532-4fdf-ad0a-833205c09902",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "\"\"\"\n",
11 | "For more detail:\n",
12 | "https://tree.rocks/make-language-model-from-scratch-like-mnist-5ed59aeb538d\n",
13 | "\"\"\"\n",
14 | "# !pip install torch numpy einops tqdm matplotlib scikit-learn\n",
15 | "\n",
16 | "import torch\n",
17 | "import torch.nn as nn\n",
18 | "import numpy as np\n",
19 | "import einops\n",
20 | "import string\n",
21 | "import re\n",
22 | "from tqdm.auto import trange\n",
23 | "from torch.utils.data import Dataset, DataLoader\n",
24 | "from matplotlib import pyplot as plt\n",
25 | "from sklearn.decomposition import PCA"
26 | ]
27 | },
28 | {
29 | "cell_type": "code",
30 | "execution_count": null,
31 | "id": "5452f67e-9862-431c-9fc2-73a768b6271d",
32 | "metadata": {},
33 | "outputs": [],
34 | "source": [
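"# pick the best available device below: CUDA GPU first, then Apple-Silicon MPS, then CPU\n",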
"torch.set_printoptions(sci_mode=False)\n", 36 | "\n", 37 | "if torch.cuda.is_available():\n", 38 | " device = torch.device(\"cuda\")\n", 39 | "elif torch.backends.mps.is_available():\n", 40 | " device = torch.device(\"mps\")\n", 41 | "else:\n", 42 | " device = torch.device(\"cpu\")\n", 43 | "\n", 44 | "print('device:', device)" 45 | ] 46 | }, 47 | { 48 | "cell_type": "code", 49 | "execution_count": null, 50 | "id": "b7d7dea7-8e26-4bec-bec3-6e5aed26f8d3", 51 | "metadata": {}, 52 | "outputs": [], 53 | "source": [ 54 | "word2id = {}\n", 55 | "id2word = {}\n", 56 | "\n", 57 | "def format_number(num):\n", 58 | " return f\"{num:,}\"" 59 | ] 60 | }, 61 | { 62 | "cell_type": "code", 63 | "execution_count": null, 64 | "id": "8b46996e-7edd-4bf8-9ab6-2585df804bf8", 65 | "metadata": {}, 66 | "outputs": [], 67 | "source": [ 68 | "with open('./data/short_animal_texts.txt', 'r') as f:\n", 69 | " text_data = f.read()\n", 70 | "\n", 71 | "total_characters = len(text_data)\n", 72 | "print('total_characters:', format_number(total_characters))" 73 | ] 74 | }, 75 | { 76 | "cell_type": "code", 77 | "execution_count": null, 78 | "id": "f90bcbdb-afc6-4d35-a9e4-b0f97f142a1a", 79 | "metadata": {}, 80 | "outputs": [], 81 | "source": [ 82 | "def regex_tokenizer(text):\n", 83 | " return re.findall(r'\\w+|[^\\w\\s]|[\\s]+', text, re.UNICODE)\n", 84 | "\n", 85 | "print(regex_tokenizer(\"Hi, It's sunny day!\"))" 86 | ] 87 | }, 88 | { 89 | "cell_type": "code", 90 | "execution_count": null, 91 | "id": "04ac7ea9-7be1-45ab-82da-96af949f97b3", 92 | "metadata": {}, 93 | "outputs": [], 94 | "source": [ 95 | "cleaned_words = regex_tokenizer(text_data)\n", 96 | "unique_words = set(cleaned_words)\n", 97 | "print('unique_words:', len(unique_words))" 98 | ] 99 | }, 100 | { 101 | "cell_type": "code", 102 | "execution_count": null, 103 | "id": "75fe70de-dc05-4281-8469-c0a02adb250a", 104 | "metadata": {}, 105 | "outputs": [], 106 | "source": [ 107 | "sorted_unique_words = sorted(unique_words)\n", 108 | "for i, w in enumerate(sorted_unique_words):\n", 109 | " word2id[w] = i\n", 110 | " id2word[i] = w" 111 | ] 112 | }, 113 | { 114 | "cell_type": "code", 115 | "execution_count": null, 116 | "id": "fb0134dd-c3e9-415e-b2bb-8510e801bb60", 117 | "metadata": {}, 118 | "outputs": [], 119 | "source": [ 120 | "def encode(text):\n", 121 | " tokens = regex_tokenizer(text)\n", 122 | " return [word2id[w] for w in tokens]\n", 123 | "\n", 124 | "def decode(token_ids):\n", 125 | " return ''.join([id2word[i] for i in token_ids])\n", 126 | "\n", 127 | "print(encode(\"Hi, It's sunny day!\"))\n", 128 | "print(decode(encode(\"Hi, It's sunny day!\")))" 129 | ] 130 | }, 131 | { 132 | "cell_type": "code", 133 | "execution_count": null, 134 | "id": "72ca8551-2a0a-4aa7-91e1-26d5dd6b4c39", 135 | "metadata": {}, 136 | "outputs": [], 137 | "source": [ 138 | "CFG = {\n", 139 | " \"num_unique_words\": len(unique_words),\n", 140 | " \"context_length\": 384,\n", 141 | "\n", 142 | " \"emb_dim\": 128,\n", 143 | " \"head_dim\": 384,\n", 144 | "\n", 145 | " \"drop_rate\": 0.15,\n", 146 | "\n", 147 | " \"stride\": 8,\n", 148 | " \"batch_size\": 32,\n", 149 | " \"LR\": 0.0009,\n", 150 | "}" 151 | ] 152 | }, 153 | { 154 | "cell_type": "code", 155 | "execution_count": null, 156 | "id": "ae55384e-6cd4-4cc2-96b2-56015a398b7b", 157 | "metadata": {}, 158 | "outputs": [], 159 | "source": [ 160 | "class TextDataset(Dataset):\n", 161 | " def __init__(self, txt, cfg):\n", 162 | " self.x = []\n", 163 | " self.y = []\n", 164 | "\n", 165 | " token_ids = encode(txt)\n", 166 | " c = 
cfg['context_length']\n", 167 | " for i in range(0, len(token_ids) - c + 1, c // cfg['stride']):\n", 168 | " self.x.append(torch.tensor(token_ids[i:i + c]))\n", 169 | " self.y.append(torch.tensor(token_ids[i + 1:i + c + 1]))\n", 170 | "\n", 171 | " def __len__(self):\n", 172 | " return len(self.x)\n", 173 | "\n", 174 | " def __getitem__(self, idx):\n", 175 | " return self.x[idx], self.y[idx]\n", 176 | "\n", 177 | "def create_dataloader(text):\n", 178 | " ds = TextDataset(text, CFG)\n", 179 | " loader = DataLoader(\n", 180 | " ds,\n", 181 | " batch_size=CFG['batch_size'],\n", 182 | " shuffle=True,\n", 183 | " drop_last=True,\n", 184 | " )\n", 185 | " return loader\n", 186 | "\n", 187 | "train_loader = create_dataloader(text_data)\n", 188 | "\n", 189 | "x, y = next(iter(train_loader))\n", 190 | "print(x.shape, y.shape)" 191 | ] 192 | }, 193 | { 194 | "cell_type": "code", 195 | "execution_count": null, 196 | "id": "4ecc64e0-4d3b-41df-a912-8c097a817ba0", 197 | "metadata": {}, 198 | "outputs": [], 199 | "source": [ 200 | "x, y = next(iter(train_loader))\n", 201 | "print(decode(x[0].tolist()[:20]))\n", 202 | "print(decode(y[0].tolist()[:20]))" 203 | ] 204 | }, 205 | { 206 | "cell_type": "code", 207 | "execution_count": null, 208 | "id": "09ac2c05-b177-4e83-b181-129e33b7f0d6", 209 | "metadata": {}, 210 | "outputs": [], 211 | "source": [ 212 | "class Model(nn.Module):\n", 213 | " def __init__(self, cfg):\n", 214 | " super().__init__()\n", 215 | "\n", 216 | " self.cfg = cfg\n", 217 | "\n", 218 | " self.embedding = nn.Embedding(cfg['num_unique_words'], cfg['emb_dim'])\n", 219 | " self.pos_emb = nn.Embedding(cfg['context_length'], cfg['emb_dim'])\n", 220 | "\n", 221 | " self.w_q = nn.Linear(cfg['emb_dim'], cfg['head_dim'], bias=False)\n", 222 | " self.w_k = nn.Linear(cfg['emb_dim'], cfg['head_dim'], bias=False)\n", 223 | " self.w_v = nn.Linear(cfg['emb_dim'], cfg['head_dim'], bias=False)\n", 224 | "\n", 225 | " self.dropout_input = nn.Dropout(cfg['drop_rate'])\n", 226 | " self.dropout_attention = nn.Dropout(cfg['drop_rate'])\n", 227 | "\n", 228 | " self.norm = nn.LayerNorm(cfg['head_dim'])\n", 229 | " self.output = nn.Linear(cfg['head_dim'], cfg['emb_dim'], bias=False)\n", 230 | "\n", 231 | " self.register_buffer('mask', torch.triu(torch.ones(cfg['context_length'], cfg['context_length']), diagonal=1 ).bool())\n", 232 | "\n", 233 | " def forward(self, x_input):\n", 234 | " b, n = x_input.shape\n", 235 | " x_emb = self.embedding(x_input)\n", 236 | " x_pos = self.pos_emb(torch.arange(n, device=x_input.device))\n", 237 | " \n", 238 | "\n", 239 | " x = self.dropout_input(x_emb + x_pos)\n", 240 | " head_dim = self.cfg['head_dim']\n", 241 | "\n", 242 | " \n", 243 | " w_q = self.w_q(x)\n", 244 | " w_k = self.w_k(x)\n", 245 | " w_v = self.w_v(x)\n", 246 | "\n", 247 | " attention_score = (w_q @ w_k.transpose(-1, -2)) / (head_dim ** 0.5)\n", 248 | "\n", 249 | " mask = self.mask[:n,:n]\n", 250 | " attention_score = attention_score.masked_fill(mask, -torch.inf)\n", 251 | "\n", 252 | " attention_weight = torch.softmax(attention_score, dim=-1)\n", 253 | " attention_weight = self.dropout_attention(attention_weight)\n", 254 | " \n", 255 | " x = attention_weight @ w_v\n", 256 | " x = self.norm(x)\n", 257 | " x = nn.functional.gelu(x)\n", 258 | "\n", 259 | " x = self.output(x)\n", 260 | " x = x @ self.embedding.weight.T\n", 261 | " return x" 262 | ] 263 | }, 264 | { 265 | "cell_type": "code", 266 | "execution_count": null, 267 | "id": "5d8dbf36-eba2-407e-9d66-4a1507d1d14b", 268 | "metadata": {}, 269 | "outputs": [], 
270 | "source": [ 271 | "model = Model(CFG)\n", 272 | "print(model(torch.randint(0, len(unique_words), size=(5, 8))).shape)" 273 | ] 274 | }, 275 | { 276 | "cell_type": "code", 277 | "execution_count": null, 278 | "id": "73dfec83-72f4-4837-8d62-aaef4ddb33c4", 279 | "metadata": {}, 280 | "outputs": [], 281 | "source": [ 282 | "def count_trainable_parameters(model):\n", 283 | " return sum(p.numel() for p in model.parameters() if p.requires_grad)\n", 284 | "\n", 285 | "print('Model paramters:', format_number(count_trainable_parameters(model)))" 286 | ] 287 | }, 288 | { 289 | "cell_type": "code", 290 | "execution_count": null, 291 | "id": "1d6ef7c6-0d77-4eda-b341-441c15dbe80a", 292 | "metadata": {}, 293 | "outputs": [], 294 | "source": [ 295 | "optimizer = torch.optim.AdamW(model.parameters(), lr=CFG['LR'], weight_decay=0.1)\n", 296 | "model.to(device)" 297 | ] 298 | }, 299 | { 300 | "cell_type": "code", 301 | "execution_count": null, 302 | "id": "92cd3767-f1a3-44f3-8351-f0f6257e948b", 303 | "metadata": {}, 304 | "outputs": [], 305 | "source": [ 306 | "def show_embedding(sample_count=7):\n", 307 | " tags = [\n", 308 | " 'cat',\n", 309 | " 'tree',\n", 310 | " 'blue',\n", 311 | " 'Bob',\n", 312 | " 'jump',\n", 313 | " 'friendly',\n", 314 | " ]\n", 315 | "\n", 316 | " tags_i = [word2id[t] for t in tags]\n", 317 | "\n", 318 | " weights = model.embedding.weight.detach().cpu().numpy()\n", 319 | "\n", 320 | " def query(idx):\n", 321 | " sel = weights[idx].reshape(1, weights.shape[1])\n", 322 | " score = (sel @ weights.T).squeeze()\n", 323 | "\n", 324 | " score = [(s, id2word[i], i) for i, s in enumerate(score)]\n", 325 | " score = sorted(score, reverse=True)\n", 326 | " score = score[:sample_count]\n", 327 | "\n", 328 | " result = [f'{n}: {s:.3f}' for s, n, _ in score]\n", 329 | " print(id2word[idx], '->')\n", 330 | " print(', '.join(result))\n", 331 | " print('\\n')\n", 332 | " return [i for _, _, i in score]\n", 333 | "\n", 334 | "\n", 335 | " arr = []\n", 336 | " for i in tags_i:\n", 337 | " arr += query(i)\n", 338 | "\n", 339 | " pca = PCA(n_components=2)\n", 340 | " reduced = pca.fit_transform(weights[arr])\n", 341 | "\n", 342 | " plt.figure(figsize=(8, 8))\n", 343 | " plt.scatter(reduced[:, 0], reduced[:, 1], s=20, alpha=0.7)\n", 344 | " for i in range(len(reduced)):\n", 345 | " label = id2word[arr[i]]\n", 346 | " attr = {\n", 347 | " 'fontsize': 8,\n", 348 | " }\n", 349 | " if arr[i] in tags_i:\n", 350 | " attr['fontsize'] = 10\n", 351 | " attr['fontweight'] = 'bold'\n", 352 | " else:\n", 353 | " attr['alpha'] = 0.6\n", 354 | " plt.text(reduced[i, 0], reduced[i, 1], label, **attr)\n", 355 | " \n", 356 | " plt.grid(True)\n", 357 | " plt.show()\n", 358 | "\n", 359 | "\n", 360 | "show_embedding()" 361 | ] 362 | }, 363 | { 364 | "cell_type": "code", 365 | "execution_count": null, 366 | "id": "c93988a2-2e2c-4a65-a029-b727150ca337", 367 | "metadata": {}, 368 | "outputs": [], 369 | "source": [ 370 | "def predict(text, max_len=50):\n", 371 | " model.eval()\n", 372 | "\n", 373 | " token_ids = encode(text)\n", 374 | " token_ids = torch.tensor(token_ids).to(device)\n", 375 | " token_ids = token_ids.unsqueeze(0)\n", 376 | " \n", 377 | " with torch.no_grad():\n", 378 | " for _ in range(max_len):\n", 379 | " token_ids = token_ids[:, -CFG['context_length']:]\n", 380 | " y = model(token_ids)\n", 381 | " y = y[:, -1, :]\n", 382 | " y_probs = torch.softmax(y, dim=-1)\n", 383 | " y_next = torch.argmax(y_probs, dim=-1, keepdim=True)\n", 384 | " token_ids = torch.cat([token_ids, y_next], dim=-1)\n", 385 | "\n", 
386 | " token_ids = token_ids.squeeze().tolist()\n", 387 | " output_text = decode(token_ids)\n", 388 | " print(output_text)\n", 389 | " model.train()\n", 390 | "\n", 391 | "predict('In a sunny day')" 392 | ] 393 | }, 394 | { 395 | "cell_type": "code", 396 | "execution_count": null, 397 | "id": "f5ca8d07-2625-4130-adb8-6557bd523f35", 398 | "metadata": {}, 399 | "outputs": [], 400 | "source": [ 401 | "def calc_loss(x, y_true):\n", 402 | " y = model(x)\n", 403 | " return torch.nn.functional.cross_entropy(y.flatten(0, 1), y_true.flatten())\n", 404 | " \n", 405 | "def evaluate(loader):\n", 406 | " model.eval()\n", 407 | " t = min(len(loader), 30)\n", 408 | " total_loss, total_count = 0.0, 0\n", 409 | " iloader = iter(loader)\n", 410 | " with torch.no_grad():\n", 411 | " for _ in range(t):\n", 412 | " x, y_true = next(iloader)\n", 413 | " loss = calc_loss(x.to(device), y_true.to(device))\n", 414 | " total_loss += loss\n", 415 | "\n", 416 | " model.train()\n", 417 | " return total_loss / t\n", 418 | "\n", 419 | "evaluate(train_loader).item()" 420 | ] 421 | }, 422 | { 423 | "cell_type": "code", 424 | "execution_count": null, 425 | "id": "67db927e-32b3-4ae8-888f-4e36f36c30c5", 426 | "metadata": {}, 427 | "outputs": [], 428 | "source": [] 429 | }, 430 | { 431 | "cell_type": "code", 432 | "execution_count": null, 433 | "id": "b301a5e1-9d3d-4bee-8320-5cc4391b5193", 434 | "metadata": {}, 435 | "outputs": [], 436 | "source": [ 437 | "pred_text = 'In a sunny day'\n", 438 | "\n", 439 | "def train(epochs=50):\n", 440 | " bar = trange(epochs)\n", 441 | " tlen = len(train_loader)\n", 442 | " for i in bar:\n", 443 | " model.train()\n", 444 | " for j, (x, y_true) in enumerate(train_loader):\n", 445 | " optimizer.zero_grad()\n", 446 | " loss = calc_loss(x.to(device), y_true.to(device))\n", 447 | " loss.backward()\n", 448 | " optimizer.step()\n", 449 | "\n", 450 | " bar.set_description(f'Epochs: {i+1}/{epochs}, Batch: {j+1}/{tlen}, loss: {loss.item():.5f}')\n", 451 | "\n", 452 | " val_loss = evaluate(train_loader).item()\n", 453 | " print(f'val loss: {val_loss:.5f}')\n", 454 | "\n", 455 | " if i % 5 == 0:\n", 456 | " print(f'predict {i+1} >>')\n", 457 | " predict(pred_text, max_len=50)\n", 458 | " print('\\n')\n", 459 | "\n", 460 | "train()" 461 | ] 462 | }, 463 | { 464 | "cell_type": "code", 465 | "execution_count": null, 466 | "id": "2f71596f-ef79-4fea-9fad-a47a11175ce0", 467 | "metadata": {}, 468 | "outputs": [], 469 | "source": [ 470 | "show_embedding()" 471 | ] 472 | }, 473 | { 474 | "cell_type": "code", 475 | "execution_count": null, 476 | "id": "daff768c-4086-463e-9ee5-dc1290d60c56", 477 | "metadata": {}, 478 | "outputs": [], 479 | "source": [ 480 | "predict('Once upon a time', max_len=100)" 481 | ] 482 | }, 483 | { 484 | "cell_type": "code", 485 | "execution_count": null, 486 | "id": "fe8d2a83-e02c-4394-a909-a8adafc07b0b", 487 | "metadata": {}, 488 | "outputs": [], 489 | "source": [] 490 | } 491 | ], 492 | "metadata": { 493 | "kernelspec": { 494 | "display_name": "Python 3 (ipykernel)", 495 | "language": "python", 496 | "name": "python3" 497 | }, 498 | "language_info": { 499 | "codemirror_mode": { 500 | "name": "ipython", 501 | "version": 3 502 | }, 503 | "file_extension": ".py", 504 | "mimetype": "text/x-python", 505 | "name": "python", 506 | "nbconvert_exporter": "python", 507 | "pygments_lexer": "ipython3", 508 | "version": "3.10.13" 509 | } 510 | }, 511 | "nbformat": 4, 512 | "nbformat_minor": 5 513 | } 514 | -------------------------------------------------------------------------------- 
/TransferLearningGAN/TransferLearningCycleGAN.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "code",
5 | "execution_count": null,
6 | "id": "819164d2-1de5-43d7-b782-4f3cb029c6f0",
7 | "metadata": {},
8 | "outputs": [],
9 | "source": [
10 | "\"\"\"\n",
11 | "For more detail:\n",
12 | "https://seachaos.com/transfer-learning-with-gan-cyclegan-from-scratch-1afc9ab7c7d1\n",
13 | "\"\"\""
14 | ]
15 | },
16 | {
17 | "cell_type": "code",
18 | "execution_count": null,
19 | "id": "0753e7ea-c484-42e7-9454-c5e7889eb813",
20 | "metadata": {},
21 | "outputs": [],
22 | "source": [
23 | "import tensorflow as tf\n",
24 | "from tensorflow.keras import layers\n",
25 | "import tensorflow_datasets as tfds\n",
26 | "\n",
27 | "import numpy as np\n",
28 | "import matplotlib.pyplot as plt\n",
29 | "from tqdm.auto import trange, tqdm\n",
30 | "\n",
31 | "import random"
32 | ]
33 | },
34 | {
35 | "cell_type": "code",
36 | "execution_count": null,
37 | "id": "94ab0077-af4e-4744-837b-84e0106a28f7",
38 | "metadata": {},
39 | "outputs": [],
40 | "source": [
41 | "dataset, dataset_info = tfds.load('cycle_gan/horse2zebra', with_info=True, as_supervised=True)\n",
42 | "\n",
43 | "train_a, train_b = dataset['trainA'], dataset['trainB']\n",
44 | "test_a, test_b = dataset['testA'], dataset['testB']"
45 | ]
46 | },
47 | {
48 | "cell_type": "code",
49 | "execution_count": null,
50 | "id": "cdb1ea8d-e254-4c46-8267-ae9a2a061871",
51 | "metadata": {},
52 | "outputs": [],
53 | "source": [
54 | "batch_size = 32 # set to 16 or less if you don't have enough VRAM\n",
55 | "\n",
56 | "img_size = 128\n",
57 | "big_img_size = 192\n",
58 | "\n",
59 | "LR = 0.00012"
60 | ]
61 | },
62 | {
63 | "cell_type": "code",
64 | "execution_count": null,
65 | "id": "7199f5ed-3c1d-41e1-93e9-6d1245b2fbf5",
66 | "metadata": {},
67 | "outputs": [],
68 | "source": [
69 | "def _process_img(image, label):\n",
70 | "    image = tf.image.resize(image, (big_img_size, big_img_size))\n",
71 | "    image = (image / 127.5) - 1.0 # scale pixels to [-1, 1]\n",
72 | "    return image, label\n",
73 | "\n",
74 | "def prepare_data(data, b=batch_size):\n",
75 | "    return data \\\n",
76 | "        .cache() \\\n",
77 | "        .map(_process_img, num_parallel_calls=tf.data.AUTOTUNE) \\\n",
78 | "        .shuffle(b) \\\n",
79 | "        .batch(b)\n",
80 | "\n",
81 | "ds_train_a, ds_train_b = prepare_data(train_a), prepare_data(train_b)\n",
82 | "ds_test_a, ds_test_b = prepare_data(test_a), prepare_data(test_b)\n",
83 | "\n",
84 | "\n",
85 | "x_train_sets = [\n",
86 | "    tf.concat([a[0] for a in ds_train_a], axis=0),\n",
87 | "    tf.concat([b[0] for b in ds_train_b], axis=0),\n",
88 | "]\n",
89 | "\n",
90 | "x_test_sets = [\n",
91 | "    tf.concat([a[0] for a in ds_test_a], axis=0),\n",
92 | "    tf.concat([b[0] for b in ds_test_b], axis=0),\n",
93 | "]\n",
94 | "\n",
95 | "print('x_train_all: ', sum([s.shape[0] for s in x_train_sets]), x_train_sets[0].numpy().min(), x_train_sets[0].numpy().max())\n",
96 | "print('x_test_all: ', sum([s.shape[0] for s in x_test_sets]), x_test_sets[0].numpy().min(), x_test_sets[0].numpy().max())\n"
97 | ]
98 | },
99 | {
100 | "cell_type": "code",
101 | "execution_count": null,
102 | "id": "3510e0e9-881c-4100-9e32-2a3d756334bf",
103 | "metadata": {},
104 | "outputs": [],
105 | "source": [
106 | "def _rand_pick(data, augment=True):\n",
107 | "    idx = np.random.choice(range(len(data)), size=batch_size, replace=False)\n",
108 | "    x = tf.gather(data, idx, axis=0)\n",
109 | "    if augment:\n",
110 | "        cx = random.uniform(1.0, 1.5)\n",
111 | "        cy = random.uniform(1.0, 1.5)\n",
112 | "        x = tf.image.random_crop(x, size=(batch_size, int(img_size * cx), int(img_size * cy), 3))\n",
113 | "        x = tf.image.random_flip_left_right(x)\n",
114 | "    x = tf.image.resize(x, (img_size, img_size))\n",
115 | "    return x\n",
116 | "\n",
117 | "def get_x_train():\n",
118 | "    xa = _rand_pick(x_train_sets[0])\n",
119 | "    xb = _rand_pick(x_train_sets[1])\n",
120 | "    return xa, xb\n",
121 | "\n",
122 | "def get_x_test():\n",
123 | "    xa = _rand_pick(x_test_sets[0], augment=False)\n",
124 | "    xb = _rand_pick(x_test_sets[1], augment=False)\n",
125 | "    return xa, xb"
126 | ]
127 | },
128 | {
129 | "cell_type": "code",
130 | "execution_count": null,
131 | "id": "523f1340-a64c-4b29-b38a-1adebd6d8034",
132 | "metadata": {},
133 | "outputs": [],
134 | "source": [
135 | "# Verify \"get_x_train\" output\n",
136 | "def cvtImg(x):\n",
137 | "    return (x + 1.0) / 2.0\n",
138 | "\n",
139 | "def show(x, S=12):\n",
140 | "    x = cvtImg(x)\n",
141 | "    plt.figure(figsize=(15, 3))\n",
142 | "    for i in range(min(len(x), S)):\n",
143 | "        plt.subplot(1, S, i + 1)\n",
144 | "        plt.imshow(x[i])\n",
145 | "        plt.axis('off')\n",
146 | "    plt.show()\n",
147 | "\n",
148 | "for _ in range(1):\n",
149 | "    xa, xb = get_x_train()\n",
150 | "    xa = xa.numpy()\n",
151 | "    print(xa.min(), xa.max(), xa.shape)\n",
152 | "    show(xa)\n",
153 | "    show(xb.numpy())"
154 | ]
155 | },
156 | {
157 | "cell_type": "code",
158 | "execution_count": null,
159 | "id": "afeefba2-eeb5-4664-b0d0-da7b71236951",
160 | "metadata": {},
161 | "outputs": [],
162 | "source": []
163 | },
164 | {
165 | "cell_type": "code",
166 | "execution_count": null,
167 | "id": "759efadc-7519-4686-bc37-faddc22e373d",
168 | "metadata": {},
169 | "outputs": [],
170 | "source": [
171 | "base_model = tf.keras.applications.VGG16(input_shape=(img_size, img_size, 3), include_top=False)\n",
172 | "\n",
173 | "x = x_input = base_model.input\n",
174 | "\n",
175 | "outputs = [\n",
176 | "    'block2_conv2',\n",
177 | "    'block3_conv3',\n",
178 | "    'block4_conv3',\n",
179 | "    'block5_conv1',\n",
180 | "    'block5_pool',\n",
181 | "]\n",
182 | "\n",
183 | "x_output = [base_model.get_layer(n).output for n in outputs]\n",
184 | "base_model = tf.keras.models.Model(x_input, x_output)\n",
185 | "\n",
186 | "base_model.trainable = False\n",
187 | "\n",
188 | "# base_model.summary() # if you want to see more detail about VGG16"
189 | ]
190 | },
191 | {
192 | "cell_type": "code",
193 | "execution_count": null,
194 | "id": "fa0f165c-8035-4c91-9a8a-1cf211974687",
195 | "metadata": {},
196 | "outputs": [],
197 | "source": [
198 | "act_name = 'gelu'\n",
199 | "\n",
200 | "def act(x):\n",
201 | "    x = layers.LayerNormalization()(x)\n",
202 | "    x = layers.Activation(act_name)(x)\n",
203 | "    return x"
204 | ]
205 | },
206 | {
207 | "cell_type": "code",
208 | "execution_count": null,
209 | "id": "431e3e76-21ca-4367-bd64-331f5f795d8b",
210 | "metadata": {},
211 | "outputs": [],
212 | "source": [
213 | "def conv_with_cmd(x_img_input, x_cmd, f=64, sp=4):\n",
214 | "    x = layers.Dense(128)(x_cmd)\n",
215 | "    x = layers.BatchNormalization()(x)\n",
216 | "    x = layers.Activation(act_name)(x)\n",
217 | "    \n",
218 | "    x = layers.Dense(f)(x)\n",
219 | "    x = layers.BatchNormalization()(x)\n",
220 | "    x = layers.Activation('sigmoid')(x)\n",
221 | "\n",
222 | "    x_g = layers.Reshape((1, 1, f))(x) # per-channel gate derived from the command vector\n",
223 | "\n",
224 | "    # ---\n",
225 | "\n",
226 | "    x = layers.Conv2D(f, kernel_size=3, padding='same')(x_img_input)\n",
227 | "    x = layers.BatchNormalization()(x)\n",
228 | "    x = layers.Activation(act_name)(x)\n",
229 | "    x = x * x_g\n",
230 | "\n",
231 | "\n",
232 | "    return x"
233 | ]
234 | },
235 | {
236 | "cell_type": "code",
237 | "execution_count": null,
238 | "id": "1265c2df-f766-4fa6-a13b-ef5a5405da73",
239 | "metadata": {},
240 | "outputs": [],
241 | "source": [
242 | "def create_gen_model():\n",
243 | "    # img input\n",
244 | "    x_input = layers.Input(shape=(img_size, img_size, 3))\n",
245 | "\n",
246 | "    # load base model\n",
247 | "    x_base_out = base_model(x_input)\n",
248 | "    [x64, x32, x16, x8, x4] = x_base_out\n",
249 | "\n",
250 | "\n",
251 | "    # x_cmd\n",
252 | "    x = x4\n",
253 | "    x = layers.Conv2D(256, kernel_size=3, padding='same')(x)\n",
254 | "    x = act(x)\n",
255 | "\n",
256 | "    x = layers.GlobalMaxPool2D()(x)\n",
257 | "\n",
258 | "\n",
259 | "    x = layers.Dense(128)(x)\n",
260 | "    x = layers.BatchNormalization()(x)\n",
261 | "    x = layers.Activation(act_name)(x)\n",
262 | "    x_cmd = x\n",
263 | "\n",
264 | "    \n",
265 | "    # GAN up\n",
266 | "    x = conv_with_cmd(x4, x_cmd, f=512)\n",
267 | "\n",
268 | "    # if you don't have enough VRAM, try reducing the filters\n",
269 | "    for i, (x_cat, f) in enumerate([\n",
270 | "        (x8, 512),\n",
271 | "        (x16, 384),\n",
272 | "        (x32, 256),\n",
273 | "        (x64, 256),\n",
274 | "        (x_input, 256),\n",
275 | "    ]):\n",
276 | "        x = layers.UpSampling2D(2)(x)\n",
277 | "        x = layers.Concatenate()([x, x_cat])\n",
278 | "        x = conv_with_cmd(x, x_cmd, f=f)\n",
279 | "    \n",
280 | "    # final output\n",
281 | "    x = layers.Conv2D(3, kernel_size=3, padding='same')(x)\n",
282 | "    x = layers.BatchNormalization()(x)\n",
283 | "    x = layers.Activation('tanh')(x)\n",
284 | "\n",
285 | "    return tf.keras.models.Model(x_input, x)\n",
286 | "\n",
287 | "gen = create_gen_model()\n",
288 | "# gen.summary() # if you want to see more detail about the model"
289 | ]
290 | },
291 | {
292 | "cell_type": "code",
293 | "execution_count": null,
294 | "id": "a77e77e0-025e-4e82-8007-1d374e039c2c",
295 | "metadata": {},
296 | "outputs": [],
297 | "source": []
298 | },
299 | {
300 | "cell_type": "code",
301 | "execution_count": null,
302 | "id": "b50751b1-97d4-4dd6-8a22-b3ddb04b02ef",
303 | "metadata": {},
304 | "outputs": [],
305 | "source": [
306 | "def create_dis_model():\n",
307 | "    x = x_input = layers.Input(shape=(img_size, img_size, 3))\n",
308 | "\n",
309 | "    [x64, x32, x16, x8, x4] = base_model(x_input)\n",
310 | "\n",
311 | "    x = x8\n",
312 | "    x = layers.Conv2D(512, kernel_size=3, padding='same')(x)\n",
313 | "    x = act(x)\n",
314 | "    x = layers.MaxPool2D()(x)\n",
315 | "    \n",
316 | "    x = layers.Concatenate()([x, x4])\n",
317 | "    x = layers.Conv2D(512, kernel_size=3, padding='same')(x)\n",
318 | "    x = act(x)\n",
319 | "    \n",
320 | "    x = layers.GlobalMaxPool2D()(x)\n",
321 | "\n",
322 | "    x = layers.Dense(384)(x)\n",
323 | "    x = layers.BatchNormalization()(x)\n",
324 | "    x = layers.Activation(act_name)(x)\n",
325 | "    \n",
326 | "    x = layers.Dense(128)(x)\n",
327 | "    x = layers.BatchNormalization()(x)\n",
328 | "    x = layers.Activation(act_name)(x)\n",
329 | "    \n",
330 | "    x = layers.Dense(4)(x) # 4 classes: fake A, fake B, real A, real B\n",
331 | "    x = layers.BatchNormalization()(x)\n",
332 | "    x = layers.Activation('softmax')(x)\n",
333 | "    \n",
334 | "    return tf.keras.models.Model(x_input, x)\n",
335 | "\n",
336 | "dis = create_dis_model()\n",
337 | "# dis.summary() # if you want to see more detail about the model"
338 | ]
339 | },
340 | {
341 | "cell_type": "code",
342 | "execution_count": null,
343 | "id": "70df3a9e-c304-49e7-aab9-b9ce9556a559",
344 | "metadata": {},
345 | "outputs": [],
346 | "source": [
347 | "y_false_a = np.zeros(batch_size) # discriminator class ids: 0/1 fake A/B, 2/3 real A/B\n",
348 | "y_false_b = np.full_like(y_false_a, 1)\n",
349 | "y_true_a = np.full_like(y_false_a, 2)\n",
350 | "y_true_b = np.full_like(y_false_a, 3)"
351 | ]
352 | },
353 | {
354 | "cell_type": "code",
355 | "execution_count": null,
356 | "id": "c835790e-3e10-47a9-9b33-94b5e72d3332",
357 | "metadata": {},
358 | "outputs": [],
359 | "source": [
360 | "opt_gen = tf.keras.optimizers.AdamW(learning_rate=LR)\n",
361 | "opt_dis = tf.keras.optimizers.AdamW(learning_rate=LR)"
362 | ]
363 | },
364 | {
365 | "cell_type": "code",
366 | "execution_count": null,
367 | "id": "1228f961-db1d-47d6-b4f6-d7c74da58cdb",
368 | "metadata": {},
369 | "outputs": [],
370 | "source": []
371 | },
372 | {
373 | "cell_type": "code",
374 | "execution_count": null,
375 | "id": "c9af8d3a-0ba7-4e2e-a1eb-9089972fdb2a",
376 | "metadata": {},
377 | "outputs": [],
378 | "source": [
379 | "@tf.function\n",
380 | "def _train_dis(x, y_t):\n",
381 | "    with tf.GradientTape(persistent=True) as tape:\n",
382 | "        y_p = dis(x)\n",
383 | "        loss = tf.losses.sparse_categorical_crossentropy(y_t, y_p)\n",
384 | "        loss = tf.reduce_mean(loss)\n",
385 | "\n",
386 | "    g = tape.gradient(loss, dis.trainable_variables)\n",
387 | "    g = zip(g, dis.trainable_variables)\n",
388 | "    opt_dis.apply_gradients(g)\n",
389 | "    \n",
390 | "    return float(loss)\n",
391 | "\n",
392 | "def train_dis():\n",
393 | "    dis.trainable = True\n",
394 | "    gen.trainable = False\n",
395 | "    base_model.trainable = False\n",
396 | "\n",
397 | "    xa, xb = get_x_train()\n",
398 | "\n",
399 | "    # train dis A\n",
400 | "    xa_fake = gen.predict(xb, verbose=False)\n",
401 | "    loss_a = \\\n",
402 | "        _train_dis(xa, y_true_a) + \\\n",
403 | "        _train_dis(xa_fake, y_false_a)\n",
404 | "\n",
405 | "    # train dis B\n",
406 | "    xb_fake = gen.predict(xa, verbose=False)\n",
407 | "    loss_b = \\\n",
408 | "        _train_dis(xb, y_true_b) + \\\n",
409 | "        _train_dis(xb_fake, y_false_b)\n",
410 | "    \n",
411 | "    return float(loss_a), float(loss_b)\n",
412 | "\n",
413 | "train_dis()"
414 | ]
415 | },
416 | {
417 | "cell_type": "code",
418 | "execution_count": null,
419 | "id": "cd45ade5-a3f5-4023-b7b0-cb655a8bcaf8",
420 | "metadata": {},
421 | "outputs": [],
422 | "source": []
423 | },
424 | {
425 | "cell_type": "code",
426 | "execution_count": null,
427 | "id": "b4febb4c-0207-4019-8655-2912f1657243",
428 | "metadata": {},
429 | "outputs": [],
430 | "source": [
431 | "@tf.function\n",
432 | "def _train_gen_cycle(x_real, y_t, y_f): # note: y_f is currently unused\n",
433 | "    with tf.GradientTape(persistent=True) as tape:\n",
434 | "        x_fake = gen(x_real) # forward\n",
435 | "        \n",
436 | "        # discriminator\n",
437 | "        y_p = dis(x_fake)\n",
438 | "        loss_dis = tf.losses.sparse_categorical_crossentropy(y_t, y_p)\n",
439 | "\n",
440 | "        # revert\n",
441 | "        x_revert = gen(x_fake)\n",
442 | "        loss_revert = tf.losses.mse(x_real, x_revert) # cycle-consistency term\n",
443 | "\n",
444 | "        loss = tf.reduce_mean(loss_dis) + tf.reduce_mean(loss_revert)\n",
445 | "\n",
446 | "\n",
447 | "    g = tape.gradient(loss, gen.trainable_variables)\n",
448 | "    g = zip(g, gen.trainable_variables)\n",
449 | "    opt_gen.apply_gradients(g)\n",
450 | "\n",
451 | "    return float(loss)\n",
452 | "\n",
453 | "def train_gen():\n",
454 | "    gen.trainable = True\n",
455 | "    dis.trainable = False\n",
456 | "    base_model.trainable = False\n",
457 | "\n",
458 | "    xa, xb = get_x_train()\n",
459 | "\n",
460 | "    loss_a = \\\n",
461 | "        _train_gen_cycle(xa, y_true_b, y_true_a)\n",
462 | "    \n",
463 | "    loss_b = \\\n",
464 | "        _train_gen_cycle(xb, y_true_a, y_true_b)\n",
"\n", 466 | " return float(loss_a), float(loss_b)\n", 467 | "\n", 468 | "train_gen()" 469 | ] 470 | }, 471 | { 472 | "cell_type": "code", 473 | "execution_count": null, 474 | "id": "be130f5c-e6e0-409f-9843-8c9ba4428056", 475 | "metadata": {}, 476 | "outputs": [], 477 | "source": [] 478 | }, 479 | { 480 | "cell_type": "code", 481 | "execution_count": null, 482 | "id": "e8a5cb7f-9a1b-4af1-90f8-04ffa75be7a6", 483 | "metadata": {}, 484 | "outputs": [], 485 | "source": [ 486 | "def _preview(x_real, title=None):\n", 487 | " x_fake = gen.predict(x_real, verbose=0)\n", 488 | " x_real = cvtImg(x_real.numpy())\n", 489 | " x_fake = cvtImg(x_fake)\n", 490 | "\n", 491 | "\n", 492 | " plt.figure(figsize=(25, 5))\n", 493 | " if title:\n", 494 | " plt.suptitle(title)\n", 495 | " s = min(batch_size, 9)\n", 496 | " for i in range(s):\n", 497 | " plt.subplot(2, s, i + 1)\n", 498 | " plt.axis('off')\n", 499 | " plt.imshow(x_real[i])\n", 500 | " plt.subplot(2, s, i + 1 + s)\n", 501 | " plt.axis('off')\n", 502 | " plt.imshow(x_fake[i])\n", 503 | " plt.show()\n", 504 | "\n", 505 | "def preview(useTest=True):\n", 506 | " if useTest:\n", 507 | " xa, xb = get_x_test()\n", 508 | " else:\n", 509 | " xa, xb = get_x_train()\n", 510 | " _preview(xa[:9], 'A -> B')\n", 511 | " _preview(xb[:9], 'B -> A')\n", 512 | "\n", 513 | "preview()" 514 | ] 515 | }, 516 | { 517 | "cell_type": "code", 518 | "execution_count": null, 519 | "id": "11d7a0bb-9d42-4342-86e0-6a12d2f4e734", 520 | "metadata": {}, 521 | "outputs": [], 522 | "source": [] 523 | }, 524 | { 525 | "cell_type": "code", 526 | "execution_count": null, 527 | "id": "30476330-5921-47f0-8a80-0a604c42c8d2", 528 | "metadata": {}, 529 | "outputs": [], 530 | "source": [ 531 | "def train():\n", 532 | " bar = trange(200)\n", 533 | " for _ in bar:\n", 534 | " lda, ldb = train_dis()\n", 535 | " lga, lgb = train_gen()\n", 536 | " msg = f'gen: {lga:.5f}, {lgb:.5f} | dis: {lda:.5f}, {ldb:.5f}'\n", 537 | " bar.set_description(msg)\n", 538 | "\n", 539 | "def go():\n", 540 | " for i in trange(50):\n", 541 | " train()\n", 542 | " if i % 5 == 0:\n", 543 | " preview()\n", 544 | " \n", 545 | " opt_dis.learning_rate = opt_dis.learning_rate * 0.98\n", 546 | " opt_gen.learning_rate = opt_gen.learning_rate * 0.98\n", 547 | " lg = opt_gen.learning_rate.numpy()\n", 548 | " ld = opt_dis.learning_rate.numpy()\n", 549 | " print(f'run: {i}')\n", 550 | " print(f'LR gen: {lg:.7f}')\n", 551 | " print(f'LR dis: {ld:.7f}')\n", 552 | "\n", 553 | "\n", 554 | "go()\n", 555 | "preview()" 556 | ] 557 | }, 558 | { 559 | "cell_type": "code", 560 | "execution_count": null, 561 | "id": "d4c1c063-4cde-4241-bc28-3133bc821029", 562 | "metadata": {}, 563 | "outputs": [], 564 | "source": [] 565 | }, 566 | { 567 | "cell_type": "code", 568 | "execution_count": null, 569 | "id": "36eadd51-6ea7-4a4b-9128-3084d903f852", 570 | "metadata": {}, 571 | "outputs": [], 572 | "source": [] 573 | }, 574 | { 575 | "cell_type": "code", 576 | "execution_count": null, 577 | "id": "96e8406e-3255-4884-8231-d20b16d09d84", 578 | "metadata": {}, 579 | "outputs": [], 580 | "source": [] 581 | } 582 | ], 583 | "metadata": { 584 | "kernelspec": { 585 | "display_name": "Python 3", 586 | "language": "python", 587 | "name": "python3" 588 | }, 589 | "language_info": { 590 | "codemirror_mode": { 591 | "name": "ipython", 592 | "version": 3 593 | }, 594 | "file_extension": ".py", 595 | "mimetype": "text/x-python", 596 | "name": "python", 597 | "nbconvert_exporter": "python", 598 | "pygments_lexer": "ipython3", 599 | "version": "3.10.10" 600 | } 601 | }, 
602 | "nbformat": 4, 603 | "nbformat_minor": 5 604 | } 605 | --------------------------------------------------------------------------------