├── .gitignore
├── LICENSE
├── MAML
│   ├── __pycache__
│   │   ├── model.cpython-36.pyc
│   │   ├── sine_task_generator.cpython-36.pyc
│   │   └── test.cpython-36.pyc
│   ├── maml.py
│   ├── model.py
│   ├── reptile.py
│   ├── sine_task_generator.py
│   ├── test.py
│   └── training
│       ├── checkpoint
│       ├── maml.ckpt.data-00000-of-00001
│       ├── maml.ckpt.index
│       ├── normal_maml.ckpt.data-00000-of-00001
│       ├── normal_maml.ckpt.index
│       ├── normal_reptile.ckpt.data-00000-of-00001
│       ├── normal_reptile.ckpt.index
│       ├── reptile.ckpt.data-00000-of-00001
│       └── reptile.ckpt.index
├── Metric-Based
│   ├── prototypical-net.py
│   └── siamese.py
└── t.py

/.gitignore:
--------------------------------------------------------------------------------
.idea
data/
MAML/training
Metric-Based/pnet
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2017

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/MAML/__pycache__/model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MorvanZhou/Meta-Learning/c716cdea8676c9e18a90d4373503bbc1071a2629/MAML/__pycache__/model.cpython-36.pyc
--------------------------------------------------------------------------------
/MAML/__pycache__/sine_task_generator.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MorvanZhou/Meta-Learning/c716cdea8676c9e18a90d4373503bbc1071a2629/MAML/__pycache__/sine_task_generator.cpython-36.pyc
--------------------------------------------------------------------------------
/MAML/__pycache__/test.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MorvanZhou/Meta-Learning/c716cdea8676c9e18a90d4373503bbc1071a2629/MAML/__pycache__/test.cpython-36.pyc
--------------------------------------------------------------------------------
/MAML/maml.py:
--------------------------------------------------------------------------------
import tensorflow as tf
from tensorflow import keras
import numpy as np
from MAML.sine_task_generator import SineData, sample_shot
from MAML.model import Model
from MAML.test import eval
import os


EPOCH_IN = 8
EPOCH_OUT = 10000
TEST_EPOCH = 32
N_CURVE = 1000
N_POINT = 50
N_TASK = 4
K_SHOT = 10
LR_INNER = 0.01
LR_META = 0.01

loss_func = keras.losses.MeanSquaredError()
data = SineData(N_CURVE, N_POINT)
test_tasks = data.get_test_tasks()
meta_path = "training/maml.ckpt"
normal_path = "training/normal_maml.ckpt"


def train():
    model = Model()
    model.build((None, 1))
    meta_w = model.get_weights()
    model.compile(optimizer=keras.optimizers.SGD(LR_INNER), loss=loss_func)

    for ep in range(EPOCH_OUT):
        tasks = data.sample_tasks(N_TASK)
        # snapshot the meta weights so every task starts from the same point
        # (np.copy on the ragged weight list only copies the outer container,
        # so the snapshot would drift as meta_w is updated in place below)
        weights = [w.copy() for w in meta_w]

        for task in tasks:
            # inner update: adapt to the task from the meta weights
            model.set_weights(weights)
            k_shot = sample_shot(task, K_SHOT)
            model.fit(k_shot.x, k_shot.y, batch_size=10, epochs=EPOCH_IN, verbose=0)

            # accumulate meta gradients on a fresh K-shot sample,
            # evaluated at the adapted weights
            k_shot = sample_shot(task, K_SHOT)
            with tf.GradientTape() as tape:
                y_ = model(k_shot.x)
                loss = loss_func(k_shot.y, y_)
            grads = tape.gradient(loss, model.trainable_variables)
            for g, w in zip(grads, meta_w):
                w -= LR_META * g.numpy() / N_TASK

        # update meta weights
        model.set_weights(meta_w)

        # test loss
        if ep % 100 == 0:
            losses = []
            for task in test_tasks:
                y_ = model(task.x)
                losses.append(loss_func(task.y, y_))
            print("ep={} | test loss={:.4f}".format(ep, np.mean(losses)))

    os.makedirs("training", exist_ok=True)
    model.save_weights(meta_path)


def normal_train():
    model = Model()
    model.compile(optimizer=keras.optimizers.SGD(LR_INNER), loss=loss_func)
    for ep in range(EPOCH_OUT):
        tasks = data.sample_tasks(N_TASK)
        for task in tasks:
            k_shot = sample_shot(task, K_SHOT)
            model.fit(k_shot.x, k_shot.y, batch_size=10, verbose=0)
    model.save_weights(normal_path)


# train()
# normal_train()
eval(
    meta_path=meta_path,
    normal_path=normal_path,
    k_shot=K_SHOT, lr_meta=LR_META, test_tasks=test_tasks,
    test_epoch=TEST_EPOCH, loss_func=loss_func
)
--------------------------------------------------------------------------------
/MAML/model.py:
--------------------------------------------------------------------------------
from tensorflow import keras


class Model(keras.Model):
    def __init__(self):
        super(Model, self).__init__()
        self.l1 = keras.layers.Dense(64, activation=keras.activations.tanh, input_shape=(1,))
        self.l2 = keras.layers.Dense(64, activation=keras.activations.tanh)
        self.out = keras.layers.Dense(1)

    def call(self, x, training=None, mask=None):
        x = self.l1(x)
        x = self.l2(x)
        y = self.out(x)
        return y
--------------------------------------------------------------------------------
/MAML/reptile.py:
--------------------------------------------------------------------------------
from tensorflow import keras
import numpy as np
from MAML.sine_task_generator import SineData, sample_shot
from MAML.model import Model
from MAML.test import eval
import os

EPOCH_IN = 8    # not too small
EPOCH_OUT = 10000
TEST_EPOCH = 32
N_CURVE = 1000
N_POINT = 50
K_SHOT = 10
LR_INNER = 0.01
LR_META = 0.01

loss_func = keras.losses.MeanSquaredError()
data = SineData(N_CURVE, N_POINT)
test_tasks = data.get_test_tasks()
meta_path = "training/reptile.ckpt"
normal_path = "training/normal_reptile.ckpt"


def train():
    model = Model()
    model.build((None, 1))
    model.compile(
        optimizer=keras.optimizers.SGD(LR_INNER),
        loss=loss_func,
    )

    meta_w = model.get_weights()
    for ep in range(EPOCH_OUT):
        task = data.sample_tasks(1)[0]
        k_shot = sample_shot(task, K_SHOT)
        # inner update
        model.fit(k_shot.x, k_shot.y, batch_size=10, epochs=EPOCH_IN, verbose=0)

        # meta update
        for tw, mw in zip(model.get_weights(), meta_w):
            mw += LR_META * (tw - mw)

        model.set_weights(meta_w)

        # test loss
        if ep % 100 == 0:
            losses = []
            for task in test_tasks:
                y_ = model(task.x)
                losses.append(loss_func(task.y, y_))
            print("ep={} | test loss={:.4f}".format(ep, np.mean(losses)))

    os.makedirs("training", exist_ok=True)
    model.save_weights(meta_path)


def normal_train():
    model = Model()
    model.compile(optimizer=keras.optimizers.SGD(LR_INNER), loss=loss_func)
    for ep in range(EPOCH_OUT):
        task = data.sample_tasks(1)[0]
        k_shot = sample_shot(task, K_SHOT)
        model.fit(k_shot.x, k_shot.y, batch_size=10, verbose=0)
    model.save_weights(normal_path)


train()
normal_train()
eval(
    meta_path=meta_path,
    normal_path=normal_path,
    k_shot=K_SHOT, lr_meta=LR_META, test_tasks=test_tasks,
    test_epoch=TEST_EPOCH, loss_func=loss_func
)
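
maml.py and reptile.py above share the same model, task generator, and inner loop; they differ only in how the meta weights are updated after task adaptation. The sketch below isolates those two outer updates on a flat NumPy vector with a toy quadratic per-task loss; the helper names (task_grad, inner_adapt), the 6-dimensional parameter vector, and the loss itself are illustrative assumptions, not code from this repo, and only the update rules mirror the scripts above.

import numpy as np

LR_INNER, LR_META, N_TASK, EPOCH_IN = 0.01, 0.01, 4, 8   # mirror the constants used above


def task_grad(w, target):
    # gradient of the toy per-task loss 0.5 * ||w - target||^2
    return w - target


def inner_adapt(w, target):
    # stand-in for the inner loop (model.fit on a K-shot sample): a few SGD steps
    w = w.copy()
    for _ in range(EPOCH_IN):
        w -= LR_INNER * task_grad(w, target)
    return w


rng = np.random.default_rng(0)
targets = rng.normal(size=(N_TASK, 6))      # one toy "task" per row
meta_w_maml = np.zeros(6)
meta_w_reptile = np.zeros(6)

# maml.py-style outer step (first order): adapt per task, evaluate the task gradient
# at the adapted weights, and subtract the averaged gradient from the meta weights.
for t in targets:
    adapted = inner_adapt(meta_w_maml, t)
    meta_w_maml -= LR_META * task_grad(adapted, t) / N_TASK

# reptile.py-style outer step: adapt to a single sampled task, then move the meta
# weights a small step toward the adapted weights.
adapted = inner_adapt(meta_w_reptile, targets[0])
meta_w_reptile += LR_META * (adapted - meta_w_reptile)

print(meta_w_maml)
print(meta_w_reptile)

Note that maml.py applies the task gradient taken at the adapted weights directly to the meta weights and never differentiates through the inner model.fit, so it is effectively the first-order variant of MAML rather than the full second-order algorithm.
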
--------------------------------------------------------------------------------
/MAML/sine_task_generator.py:
--------------------------------------------------------------------------------
import numpy as np
import matplotlib.pyplot as plt


class TaskData:
    def __init__(self, x, y):
        self.x = x
        self.y = y


class SineData:
    def __init__(self, n_curves, n_point, test_ratio=0.2):
        # each curve is one task: x in [-5, 5), random amplitude in [0.1, 5.1) and phase in [0, 2*pi)
        x = np.random.rand(n_curves, n_point) * 5 * 2 - 5
        amplitude, phase = np.random.rand(n_curves, 1) * 5 + 0.1, np.random.rand(n_curves, 1) * np.pi * 2
        y = amplitude * np.sin(x + phase)

        p = int(x.shape[0] * test_ratio)
        self.test_x, self.train_x = x[:p], x[p:]
        self.test_y, self.train_y = y[:p], y[p:]

    def sample_tasks(self, n):
        assert n <= self.train_x.shape[0]
        tasks = []
        for i in np.random.randint(0, self.train_x.shape[0], n):
            tasks.append(TaskData(self.train_x[i][:, None], self.train_y[i][:, None]))
        return tasks

    def get_test_tasks(self):
        tasks = []
        for i in range(len(self.test_x)):
            tasks.append(TaskData(self.test_x[i][:, None], self.test_y[i][:, None]))
        return tasks


def sample_shot(task, k):
    assert isinstance(task, TaskData)
    indices = np.random.randint(0, task.x.shape[0], k)
    return TaskData(task.x[indices], task.y[indices])


if __name__ == "__main__":
    gen = SineData(20, 100)
    tasks = gen.sample_tasks(3)
    for t in tasks:
        index = np.argsort(t.x.ravel())
        x_, y_ = t.x.ravel()[index], t.y.ravel()[index]
        plt.plot(x_, y_)
        plt.scatter(x_, y_)
    plt.show()
--------------------------------------------------------------------------------
/MAML/test.py:
--------------------------------------------------------------------------------
from MAML.model import Model
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt


def eval(meta_path, normal_path, k_shot, lr_meta, test_tasks, test_epoch, loss_func):
    model = Model()
    model.build((None, 1))
    normal = Model()
    normal.build((None, 1))

    # new task
    for i in range(4):
        model.load_weights(meta_path)
        normal.load_weights(normal_path)
        new_task = test_tasks[i]
        index = np.argsort(new_task.x.ravel())
        initial_y = model(new_task.x).numpy().ravel()[index]
        normal_initial_y = normal(new_task.x).numpy().ravel()[index]

        new_task_x, new_task_y = new_task.x[:k_shot], new_task.y[:k_shot]
        # adapt the meta-trained model to the new task
        model.compile(
            optimizer=keras.optimizers.SGD(lr_meta),
            loss=loss_func,
        )
        model.fit(new_task_x, new_task_y, epochs=test_epoch, verbose=0)

        # fine-tune the normally pre-trained model
        normal.compile(optimizer=keras.optimizers.SGD(lr_meta), loss=loss_func)
        normal.fit(new_task_x, new_task_y, epochs=test_epoch, verbose=0)

        # train from scratch
        model_scratch = Model()
        model_scratch.compile(
            optimizer=keras.optimizers.SGD(lr_meta),
            loss=loss_func,
        )
        model_scratch.fit(new_task_x, new_task_y, epochs=test_epoch, verbose=0)

        x_, y_ = new_task.x.ravel()[index], model(new_task.x).numpy().ravel()[index]
        y_scratch = model_scratch(new_task.x).numpy().ravel()[index]
        y_normal = normal(new_task.x).numpy().ravel()[index]
        plt.subplot(2, 2, i + 1)
        plt.plot(x_, new_task.y.ravel()[index], label="target", c="k", alpha=0.3)
        plt.scatter(new_task_x.ravel(), new_task_y.ravel(), c="k", s=20, alpha=0.4)
        plt.plot(x_, y_, label="meta {} epoch".format(test_epoch), c="r", alpha=0.5)
        plt.plot(x_, initial_y, label="meta start point", ls="--", c="r", alpha=0.5)
        plt.plot(x_, y_scratch, label="train from scratch", c="y", alpha=0.5)
        plt.plot(x_, y_normal, label="normal {} epoch".format(test_epoch), c="b", alpha=0.5)
        plt.plot(x_, normal_initial_y, label="normal start point", c="b", ls="--", alpha=0.5)
        plt.ylim(-5.5, 5.5)
        plt.legend(prop={'size': 6})
    plt.show()
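
The tasks that flow through all of the scripts above are plain TaskData holders of column vectors from one random sine curve, and sample_shot just draws K rows from a curve (with replacement). A quick shape check makes the data format concrete; it assumes the scripts are run from the repository root, as their MAML.* imports do, and the inline numbers reflect the default constants used above.

from MAML.sine_task_generator import SineData, sample_shot

data = SineData(n_curves=1000, n_point=50)   # with test_ratio=0.2: 200 test curves, 800 train curves
task = data.sample_tasks(1)[0]
print(task.x.shape, task.y.shape)            # (50, 1) (50, 1): one full curve as column vectors
k_shot = sample_shot(task, 10)               # K-shot sample drawn with replacement from that curve
print(k_shot.x.shape, k_shot.y.shape)        # (10, 1) (10, 1): what model.fit sees in the inner loop
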
--------------------------------------------------------------------------------
/MAML/training/checkpoint:
--------------------------------------------------------------------------------
model_checkpoint_path: "normal_maml.ckpt"
all_model_checkpoint_paths: "normal_maml.ckpt"
--------------------------------------------------------------------------------
/MAML/training/maml.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MorvanZhou/Meta-Learning/c716cdea8676c9e18a90d4373503bbc1071a2629/MAML/training/maml.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/MAML/training/maml.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MorvanZhou/Meta-Learning/c716cdea8676c9e18a90d4373503bbc1071a2629/MAML/training/maml.ckpt.index
--------------------------------------------------------------------------------
/MAML/training/normal_maml.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MorvanZhou/Meta-Learning/c716cdea8676c9e18a90d4373503bbc1071a2629/MAML/training/normal_maml.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/MAML/training/normal_maml.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MorvanZhou/Meta-Learning/c716cdea8676c9e18a90d4373503bbc1071a2629/MAML/training/normal_maml.ckpt.index
--------------------------------------------------------------------------------
/MAML/training/normal_reptile.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MorvanZhou/Meta-Learning/c716cdea8676c9e18a90d4373503bbc1071a2629/MAML/training/normal_reptile.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/MAML/training/normal_reptile.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MorvanZhou/Meta-Learning/c716cdea8676c9e18a90d4373503bbc1071a2629/MAML/training/normal_reptile.ckpt.index
--------------------------------------------------------------------------------
/MAML/training/reptile.ckpt.data-00000-of-00001:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MorvanZhou/Meta-Learning/c716cdea8676c9e18a90d4373503bbc1071a2629/MAML/training/reptile.ckpt.data-00000-of-00001
--------------------------------------------------------------------------------
/MAML/training/reptile.ckpt.index:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MorvanZhou/Meta-Learning/c716cdea8676c9e18a90d4373503bbc1071a2629/MAML/training/reptile.ckpt.index
--------------------------------------------------------------------------------
/Metric-Based/prototypical-net.py:
--------------------------------------------------------------------------------
# omniglot data: https://github.com/brendenlake/omniglot/tree/master/python

from PIL import Image
import numpy as np
import tensorflow as tf
from tensorflow import keras
import os
import matplotlib.pyplot as plt

tf.random.set_seed(2)
np.random.seed(1)

EPOCH = 20
STEP = 1000
N_WAY = 60
N_SUPPORT = 5
N_QUERY = 5
N_EXAMPLE = 20
IMG_WIDTH, IMG_HEIGHT, IMG_CHANNEL = 28, 28, 1

DATA_DIR = "../data/omniglot/images_background"


def get_train_data():
    # load train dataset
    data = []
    for root, dirs, files in os.walk(DATA_DIR, topdown=True):
        if not files:
            # skip the root and alphabet-level directories, which contain no image files
            continue
        class_data = np.zeros([1, N_EXAMPLE, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNEL], dtype=np.float32)
        for i, file in enumerate(files):
            if not file.endswith(".png"):
                continue
            if i >= N_EXAMPLE:
                break
            img_path = os.path.join(root, file)
            img = 1. - np.array(
                Image.open(img_path).resize((IMG_HEIGHT, IMG_WIDTH)),
                np.float32, copy=False)
            class_data[0, i, :, :, 0] = img
        data.append(class_data)
        if len(data) == 100:
            break
    return np.concatenate(data, axis=0)


class ConvBlock(keras.Model):
    def __init__(self, filters, kernel_size):
        super().__init__()
        self.c = keras.layers.Conv2D(filters, kernel_size, padding="same")
        self.bn = keras.layers.BatchNormalization()
        self.p = keras.layers.MaxPool2D(2)

    def call(self, _x, training=None, mask=None):
        _x = self.c(_x)
        _x = self.bn(_x, training=training)
        _x = keras.activations.relu(_x)
        return self.p(_x)


c1 = ConvBlock(64, 3)
c2 = ConvBlock(64, 3)
c3 = ConvBlock(64, 3)
o = ConvBlock(64, 3)


class ProtoNet(keras.Model):
    def __init__(self):
        super().__init__()
        self.c1 = c1
        self.c2 = c2
        self.c3 = c3
        self.o = o
        self.flat = keras.layers.Flatten()

    def call(self, _x, training=None, mask=None):
        _x = self.c1(_x, training=training)
        _x = self.c2(_x, training=training)
        _x = self.c3(_x, training=training)
        _x = self.o(_x, training=training)
        return self.flat(_x)


def euclidean_distance(qy, sy):
    # this calculation is inspired by:
    # https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/5e18a5e5b369903092f683d434efb12c7c40a83c/src/prototypical_loss.py
    # and https://github.com/abdulfatir/prototypical-networks-tensorflow/blob/master/ProtoNet-Omniglot.ipynb

    # qy [n, d]
    # sy [m, d]
    n, d = tf.shape(qy)[0], tf.shape(qy)[1]
    m = tf.shape(sy)[0]
    qy = tf.tile(tf.expand_dims(qy, axis=1), (1, m, 1))    # -> [n, m, d]
    sy = tf.tile(tf.expand_dims(sy, axis=0), (n, 1, 1))    # -> [n, m, d]
    return tf.reduce_mean(tf.square(qy - sy), axis=2)      # -> [n, m]


def train():
    @tf.function
    def train_step(sx, qx):
        with tf.GradientTape() as tape:
            # sx: [way * support_shot, height, width, 1] support set, provides the prototypes
            support_y = pnet(sx)            # -> [way * support_shot, d]

            # qx: [way * query_shot, height, width, 1] query set
            query_y = pnet_reuse(qx)        # -> [way * query_shot, d]

            # find the prototype c of each class from the support set -> [way, d]
            support_c = tf.reduce_mean(tf.reshape(support_y, [N_WAY, N_SUPPORT, -1]), axis=1)
            _loss, _acc = loss_func(query_y, support_c)

        grads = tape.gradient(_loss, pnet.trainable_variables)
        opt.apply_gradients(zip(grads, pnet.trainable_variables))
        return _loss, _acc

    def loss_func(qy, sy):
        dists = euclidean_distance(qy, sy)  # -> [way * query_shot, way]
        log_p_y = tf.reshape(
            tf.nn.log_softmax(-dists),
            [N_WAY, N_QUERY, -1]
        )                                   # -> [way, query_shot, way]
        cross_entropy = -tf.reduce_mean(
            tf.reshape(
                tf.reduce_sum(tf.multiply(y_one_hot, log_p_y), axis=-1),
                [-1]
            )
        )
        _acc = tf.reduce_mean(tf.dtypes.cast(tf.equal(tf.argmax(log_p_y, axis=-1), labels), tf.float32))
        return cross_entropy, _acc

    train_data = get_train_data()           # [class, n_example, img_height, img_width, img_channel]

    pnet = ProtoNet()
    pnet_reuse = ProtoNet()                 # built from the same global ConvBlocks, so it shares weights with pnet
    opt = keras.optimizers.Adam(learning_rate=0.001)

    fixed_range = np.arange(N_EXAMPLE)
    labels = np.tile(np.arange(N_WAY)[:, None], (1, N_QUERY)).astype(np.uint8)
    y_one_hot = tf.one_hot(labels, depth=N_WAY)

    for ep in range(EPOCH):
        for step in range(STEP):
            # sample N_WAY distinct classes per episode (randint could repeat a class within an episode)
            class_idx = np.random.choice(len(train_data), size=N_WAY, replace=False)
            perm = np.random.permutation(fixed_range)
            support_idx = perm[:N_SUPPORT]
            query_idx = perm[N_SUPPORT: N_SUPPORT + N_QUERY]

            train_class = train_data[class_idx]
            support_x = train_class[:, support_idx]     # [way, support_shot, height, width, 1]
            query_x = train_class[:, query_idx]         # [way, query_shot, height, width, 1]
            support_x_reshape = support_x.reshape([-1, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNEL])
            query_x_reshape = query_x.reshape([-1, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNEL])

            loss, acc = train_step(support_x_reshape, query_x_reshape)

            if step % 10 == 0:
                print("ep {} | step {} | loss {:.2f} | acc {:.2f}".format(ep, step, loss.numpy(), acc.numpy()))

    os.makedirs("./pnet", exist_ok=True)
    pnet.save_weights("./pnet/model.ckpt")


def eval_compare(src, tgts):
    src_data = np.zeros([1, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNEL], dtype=np.float32)
    tgt_data = np.zeros([3, IMG_HEIGHT, IMG_WIDTH, IMG_CHANNEL], dtype=np.float32)
    for i, file in enumerate(tgts):
        img = 1. - np.array(
            Image.open(file).resize((IMG_HEIGHT, IMG_WIDTH)),
            np.float32, copy=False)
        tgt_data[i, :, :, 0] = img
    src_data[0, :, :, 0] = 1. - np.array(
        Image.open(src).resize((IMG_HEIGHT, IMG_WIDTH)),
        np.float32, copy=False)

    latest_model = tf.train.latest_checkpoint("./pnet")
    model = ProtoNet()
    model.load_weights(latest_model)
    model.compile(optimizer=keras.optimizers.SGD(), loss=keras.losses.MeanSquaredError())
    src_y = model.predict(src_data)
    tgt_y = model.predict(tgt_data)
    idx = np.mean(np.power(src_y - tgt_y, 2), axis=1).argsort()

    tgt_data = tgt_data[idx]
    tgts = [tgts[i] for i in idx]
    plt.subplot(221)
    plt.imshow(src_data[0, :, :, 0])
    plt.title(os.path.basename(src))
    plt.subplot(222)
    plt.imshow(tgt_data[0, :, :, 0])
    plt.title(os.path.basename(tgts[0]))
    plt.subplot(223)
    plt.imshow(tgt_data[1, :, :, 0])
    plt.title(os.path.basename(tgts[1]))
    plt.subplot(224)
    plt.imshow(tgt_data[2, :, :, 0])
    plt.title(os.path.basename(tgts[2]))
    plt.show()


# train()
eval_compare(
    "../data/omniglot/images_evaluation/Atemayar_Qelisayer/character07/0991_06.png",
    [
        "../data/omniglot/images_evaluation/Atemayar_Qelisayer/character01/0985_04.png",
        "../data/omniglot/images_evaluation/Atemayar_Qelisayer/character03/0987_10.png",
        "../data/omniglot/images_evaluation/Atemayar_Qelisayer/character07/0991_03.png",
    ]
)
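
The core of the episode loss above is easier to see in isolation: support embeddings are averaged per class into prototypes, and each query is scored by a softmax over its negative squared distances to those prototypes. The sketch below uses hand-made 2-D "embeddings"; all values are illustrative, and only the reshape/reduce pattern mirrors train_step and loss_func above.

import tensorflow as tf

N_WAY, N_SUPPORT = 3, 2                      # a tiny episode: 3 classes, 2 support shots, 2-D embeddings

# pretend these came out of the embedding network: [way * support_shot, d] and [n_query, d]
support_y = tf.constant([[0., 0.], [0., 2.],     # class 0
                         [5., 5.], [5., 7.],     # class 1
                         [9., 0.], [9., 2.]])    # class 2
query_y = tf.constant([[5., 6.5], [0.5, 1.], [8., 1.]])

# prototypes: mean support embedding per class -> [way, d], as in train_step
prototypes = tf.reduce_mean(tf.reshape(support_y, [N_WAY, N_SUPPORT, -1]), axis=1)

# squared distance from every query to every prototype -> [n_query, way]
dists = tf.reduce_sum(tf.square(query_y[:, None, :] - prototypes[None, :, :]), axis=2)

# classification is a softmax over negative distances: the nearest prototype wins
log_p = tf.nn.log_softmax(-dists)
print(prototypes.numpy())                    # class means: (0, 1), (5, 6), (9, 1)
print(tf.argmax(log_p, axis=1).numpy())      # [1 0 2]: each query picks its nearest class

The repo's euclidean_distance uses reduce_mean over the feature axis rather than reduce_sum; that only rescales the logits and leaves the nearest-prototype ordering unchanged.
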
--------------------------------------------------------------------------------
/Metric-Based/siamese.py:
--------------------------------------------------------------------------------
# omniglot data: https://github.com/brendenlake/omniglot/tree/master/python

import os
from PIL import Image
import numpy as np

DATA_DIR = "../data/omniglot/images_background"
im_height, im_width = 28, 28
n_examples = 20


def get_train_data():
    # load train dataset
    train_data = []
    for root, dirs, files in os.walk(DATA_DIR, topdown=True):
        if not files:
            # skip the root and alphabet-level directories, which contain no image files
            continue
        class_data = np.empty([1, n_examples, im_height, im_width], dtype=np.float32)
        for i, file in enumerate(files):
            if i >= n_examples:
                break
            img_path = os.path.join(root, file)
            img = 1. - np.array(
                Image.open(img_path).resize((im_width, im_height)),
                np.float32, copy=False)
            class_data[0, i] = img
        train_data.append(class_data)
    return np.concatenate(train_data, axis=0)


get_train_data()
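
siamese.py currently stops at the data loader: it walks images_background and stacks up to 20 drawings per character class, so the returned array is indexed as [class, example, height, width]. A quick check of what that gives, assuming the full Omniglot background set is unpacked where DATA_DIR expects it; the import path, the expected shape, and the pair-indexing lines are assumptions for illustration (importing siamese also runs its module-level get_train_data() call once).

import numpy as np
from siamese import get_train_data   # run from inside Metric-Based/ so DATA_DIR resolves

data = get_train_data()
print(data.shape, data.dtype)        # expected: (964, 20, 28, 28) float32 for the full background set

# the [class, example] layout makes pair sampling for a siamese setup straightforward, e.g.:
a, b = np.random.choice(len(data), size=2, replace=False)
same_pair = data[a, :2]                          # two drawings of the same character
diff_pair = np.stack([data[a, 0], data[b, 0]])   # one drawing each from two different characters
print(same_pair.shape, diff_pair.shape)          # (2, 28, 28) (2, 28, 28)
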
--------------------------------------------------------------------------------
/t.py:
--------------------------------------------------------------------------------
from tensorflow import keras
import tensorflow as tf
import numpy as np

data_x = np.random.normal(size=[1000, 1])
noise = np.random.normal(size=[1000, 1]) * 0.2
data_y = data_x * 3. + 2. + noise

train_x, train_y = data_x[:900], data_y[:900]
test_x, test_y = data_x[900:], data_y[900:]

# the two Dense layers are module-level globals, so both Model instances below reuse the same layer objects
l1 = keras.layers.Dense(10, activation=keras.activations.relu, name="l1")
l2 = keras.layers.Dense(1, name="l2")


class Model(keras.Model):
    def __init__(self):
        super(Model, self).__init__(name="m")
        self.l1 = l1
        self.l2 = l2

    def call(self, x, training=None, mask=None):
        x = self.l1(x)
        x = self.l2(x)
        return x


model = Model()
model2 = Model()

model.build((None, 1))
model2.build((None, 1))

model.compile(
    optimizer=keras.optimizers.SGD(0.01),
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.MeanSquaredError()],
)
model.fit(train_x, train_y, batch_size=32, epochs=3, validation_split=0.2, shuffle=True)
# True if training `model` also changed `model2`, i.e. the weights are shared between the two models
print(np.all(model.get_weights()[0] == model2.get_weights()[0]))
--------------------------------------------------------------------------------