├── README.md ├── data ├── DTD_Texture │ └── proc_images.py ├── cub_bird │ └── proc_images.py ├── fgvc_aircraft │ └── proc_images.py ├── fgvcx_fungi │ └── proc_images.py ├── miniImagenet │ └── proc_images.py └── omniglot_resized │ └── resize_images.py ├── data_generator.py ├── image_embedding.py ├── lstm_tree.py ├── main.py ├── maml.py ├── multidataset_bash ├── HSML_multidataset_1shot.sh └── HSML_multidataset_5shot.sh ├── special_grads.py ├── task_embedding.py ├── toygroup_bash ├── HSML_toygroup_10shot.sh └── HSML_toygroup_5shot.sh └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # HSML (Hierarchically Structured Meta-learning) 2 | 3 | ## About 4 | Source code<sup>1</sup> for the paper [Hierarchically Structured Meta-learning](https://arxiv.org/abs/1905.05301). 5 | 6 | For the continual version of this algorithm, please refer to this [repo](https://github.com/huaxiuyao/HSML_Dynamic). 7 | 8 | If you find this repository useful in your research, please cite the following paper: 9 | ``` 10 | @inproceedings{yao2019hierarchically, 11 | title={Hierarchically Structured Meta-learning}, 12 | author={Yao, Huaxiu and Wei, Ying and Huang, Junzhou and Li, Zhenhui}, 13 | booktitle={Proceedings of the 36th International Conference on Machine Learning}, 14 | year={2019} 15 | } 16 | ``` 17 | 18 | ## Data 19 | We release our Multi-Datasets (bird, texture, aircraft and fungi) at this [link](https://drive.google.com/file/d/1IJk93N48X0rSL69nQ1Wr-49o8u0e75HM/view?usp=sharing). 20 | 21 | ## Usage 22 | 23 | ### Dependencies 24 | * Python 3.* 25 | * TensorFlow 1.0+ 26 | * NumPy 1.15+ 27 | 28 | ### Toy Group Data 29 | Please see the bash files in /toygroup_bash for parameter settings. 30 | 31 | ### Multi-datasets Data 32 | Please see the bash files in /multidataset_bash for parameter settings. 33 | 34 | 35 | <sup>1</sup>This code is built on top of [MAML](https://github.com/cbfinn/maml). 
# ---- /data/DTD_Texture/proc_images.py ----
"""Preprocess the DTD (Describable Textures Dataset) images.

Usage instructions:
    1. ``Process()`` resizes every image inside the DTD class folders to
       84x84 in place, using LANCZOS resampling.
    2. ``select_image()`` (run when this file is executed as a script)
       randomly picks 47 class folders and moves 30 of them into ``train/``,
       7 into ``val/`` and 10 into ``test/``.

The default paths point at the original author's machine; pass explicit
paths to use another location.
"""
import glob
import os
import random
import shutil

import numpy as np

# Fixed seeds so that the class split in select_image is repeatable.
np.random.seed(0)
random.seed(1)

# Root directory that all default paths below are built from.
DATA_ROOT = '/home/huaxiuyao/Data/meta-dataset/DTD_Texture/'


def Process(image_path=DATA_ROOT + 'dtd/images/*/', size=(84, 84)):
    """Resize every file matching ``image_path + '*'`` to ``size`` in place.

    Args:
        image_path: glob prefix. The default matches every file inside each
            class folder of ``dtd/images``.
        size: target ``(width, height)`` passed to ``Image.resize``.

    Prints a running count every 200 images.

    NOTE(review): the default location here (``dtd/images``) differs from the
    default ``path`` of ``select_image`` (``images``). Presumably the resized
    class folders are moved between the two steps -- confirm.
    """
    # Imported lazily so that the folder-splitting step does not need Pillow.
    from PIL import Image

    for count, image_file in enumerate(glob.glob(image_path + '*'), start=1):
        # Release the source file handle before overwriting the same path.
        with Image.open(image_file) as im:
            resized = im.resize(size, resample=Image.LANCZOS)
        resized.save(image_file)
        if count % 200 == 0:
            print(count)


def select_image(path=DATA_ROOT + 'images/',
                 train_dir=DATA_ROOT + 'train/',
                 val_dir=DATA_ROOT + 'val/',
                 test_dir=DATA_ROOT + 'test/',
                 num_train=30, num_val=7, num_test=10):
    """Randomly split the class folders in ``path`` into train/val/test.

    ``num_train + num_val + num_test`` class folders are sampled and
    shuffled. They are then moved (not copied) into ``train_dir``,
    ``val_dir`` and ``test_dir`` in that order.

    Args:
        path: directory whose sub-directories are the classes.
        train_dir, val_dir, test_dir: destination directories; each is
            created if it does not exist yet.
        num_train, num_val, num_test: number of classes per split.

    Raises:
        ValueError: if ``path`` has fewer class folders than requested
            (raised by ``random.sample``).
    """
    # os.listdir returns entries in arbitrary order, so sort to make the
    # seeded RNG reproduce the same split. Plain files (e.g. .DS_Store)
    # are skipped instead of crashing the run.
    class_names = sorted(name for name in os.listdir(path)
                         if os.path.isdir(os.path.join(path, name)))
    num_total = num_train + num_val + num_test
    chosen = [class_names[idx]
              for idx in random.sample(range(len(class_names)), num_total)]
    random.shuffle(chosen)

    val_end = num_train + num_val
    splits = ((train_dir, chosen[:num_train]),
              (val_dir, chosen[num_train:val_end]),
              (test_dir, chosen[val_end:]))
    for dest_dir, names in splits:
        # When the destination does not exist, shutil.move *renames* the
        # source to it. The first class's images would then end up directly
        # in the split directory. Create it first so every class stays in
        # its own sub-folder.
        os.makedirs(dest_dir, exist_ok=True)
        for name in names:
            shutil.move(os.path.join(path, name), dest_dir)


if __name__ == '__main__':
    select_image()
-------------------------------------------------------------------------------- /data/cub_bird/proc_images.py: -------------------------------------------------------------------------------- 1 | """The Caltech-UCSD bird dataset 2 | """ 3 | 4 | import numpy as np 5 | import os 6 | from scipy import misc 7 | from skimage import io 8 | import ipdb 9 | import shutil 10 | import random 11 | 12 | np.random.seed(0) 13 | random.seed(1) 14 | 15 | class CUBDataLayer(): 16 | """ The Caltech-UCSD bird dataset 17 | """ 18 | def __init__(self, **kwargs): 19 | """Load the dataset. 20 | kwargs: 21 | root: the root folder of the CUB_200_2011 dataset. 22 | is_training: if true, load the training data. Otherwise, load the 23 | testing data. 24 | crop: if None, does not crop the bounding box. If a real value, 25 | crop is the ratio of the bounding box that gets cropped. 26 | e.g., if crop = 1.5, the resulting image will be 1.5 * the 27 | bounding box area. 28 | target_size: all images are resized to the size specified. Should 29 | be a tuple of two integers, like [256, 256]. 30 | version: either '2011' or '2010'. 31 | Note that we will use the python indexing (labels start from 0). 
32 | """ 33 | root = '/home/huaxiuyao/Data/meta-dataset/CUB_Bird/CUB_200_2011/' 34 | 35 | crop = True 36 | target_size = [84,84] 37 | images = [line.split()[1] for line in 38 | open(os.path.join(root, 'images.txt'), 'r')] 39 | boxes = [line.split()[1:] for line in 40 | open(os.path.join(root, 'bounding_boxes.txt'),'r')] 41 | 42 | 43 | # for the boxes, we store them as a numpy array 44 | boxes = np.array(boxes, dtype=np.float32) 45 | boxes -= 1 46 | # load the data 47 | self._load_data(root, images, boxes, crop, target_size) 48 | 49 | def _load_data(self, root, images, boxes, crop, target_size): 50 | num_imgs = len(images) 51 | 52 | for i in range(num_imgs): 53 | image = io.imread(os.path.join(root, 'images', images[i])) 54 | if image.ndim == 2: 55 | image = np.tile(image[:,:,np.newaxis], (1, 1, 3)) 56 | if image.shape[2] == 4: 57 | image = image[:, :, :3] 58 | if crop: 59 | image = self._crop_image(image, crop, boxes[i]) 60 | data_img = misc.imresize(image, target_size) 61 | misc.imsave(os.path.join(root, 'images', images[i]), data_img) 62 | 63 | if i%500==0: 64 | print(i) 65 | 66 | return 67 | 68 | def _crop_image(self, image, crop, box): 69 | imheight, imwidth = image.shape[:2] 70 | x, y, width, height = box 71 | centerx = x + width / 2. 72 | centery = y + height / 2. 73 | xoffset = width * crop / 2. 74 | yoffset = height * crop / 2. 
75 | xmin = max(int(centerx - xoffset + 0.5), 0) 76 | ymin = max(int(centery - yoffset + 0.5), 0) 77 | xmax = min(int(centerx + xoffset + 0.5), imwidth - 1) 78 | ymax = min(int(centery + yoffset + 0.5), imheight - 1) 79 | if xmax - xmin <= 0 or ymax - ymin <= 0: 80 | raise ValueError("The cropped bounding box has size 0.") 81 | return image[ymin:ymax, xmin:xmax] 82 | 83 | def select_image(): 84 | path='/home/huaxiuyao/Data/meta-dataset/CUB_Bird/images/' 85 | dirlist=os.listdir(path) 86 | num_images=[] 87 | for eachdir in dirlist: 88 | tmp=os.listdir(path+eachdir) 89 | for each in tmp: 90 | if each[0]=='.': 91 | print(eachdir, each) 92 | if len(os.listdir(path+eachdir))==60: 93 | num_images.append([eachdir, len(os.listdir(path+eachdir))]) 94 | all_folder_id=random.sample(range(len(num_images)), 100) 95 | all_folder=[num_images[id] for id in all_folder_id] 96 | random.shuffle(all_folder) 97 | for i in range(64): 98 | shutil.move(path+all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/CUB_Bird/train/') 99 | for i in range(64,80): 100 | shutil.move(path + all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/CUB_Bird/val/') 101 | for i in range(80,100): 102 | shutil.move(path + all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/CUB_Bird/test/') 103 | # num_images=sorted(num_images, key=lambda x:x[1], reverse=True) 104 | 105 | if __name__=='__main__': 106 | select_image() -------------------------------------------------------------------------------- /data/fgvc_aircraft/proc_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage instructions: 3 | First download the omniglot dataset 4 | and put the contents of both images_background and images_evaluation in data/omniglot/ (without the root folder) 5 | 6 | Then, run the following: 7 | cd data/ 8 | cp -r omniglot/* omniglot_resized/ 9 | cd omniglot_resized/ 10 | python resize_images.py 11 | """ 12 | from PIL import Image 13 | import glob 14 | import os 15 | import 
numpy as np 16 | import scipy.io as scio 17 | import os 18 | from scipy import misc 19 | from skimage import io 20 | import ipdb 21 | import shutil 22 | import random 23 | 24 | np.random.seed(1) 25 | random.seed(2) 26 | 27 | class FGVC_Aircraft(): 28 | """ The Caltech-UCSD bird dataset 29 | """ 30 | def __init__(self, **kwargs): 31 | """Load the dataset. 32 | kwargs: 33 | root: the root folder of the CUB_200_2011 dataset. 34 | is_training: if true, load the training data. Otherwise, load the 35 | testing data. 36 | crop: if None, does not crop the bounding box. If a real value, 37 | crop is the ratio of the bounding box that gets cropped. 38 | e.g., if crop = 1.5, the resulting image will be 1.5 * the 39 | bounding box area. 40 | target_size: all images are resized to the size specified. Should 41 | be a tuple of two integers, like [256, 256]. 42 | version: either '2011' or '2010'. 43 | Note that we will use the python indexing (labels start from 0). 44 | """ 45 | root = '/home/huaxiuyao/Data/meta-dataset/FGVC_Aircraft/data/' 46 | 47 | crop = True 48 | target_size = [84,84] 49 | images = [imageid.split('.')[0] for imageid in os.listdir(root+'images')] 50 | boxes = {line.split()[0]:line.split()[1:] for line in 51 | open(os.path.join(root, 'images_box.txt'),'r')} 52 | 53 | 54 | # for the boxes, we store them as a numpy array 55 | for eachkey in boxes: 56 | boxes[eachkey] = np.array(boxes[eachkey], dtype=np.float32) - 1 57 | # load the data 58 | self._load_data(root, images, boxes, crop, target_size) 59 | 60 | def _load_data(self, root, images, boxes, crop, target_size): 61 | num_imgs = len(images) 62 | 63 | for i in range(num_imgs): 64 | image = io.imread(os.path.join(root, 'images', '{}.jpg'.format(images[i]))) 65 | if image.ndim == 2: 66 | image = np.tile(image[:,:,np.newaxis], (1, 1, 3)) 67 | if image.shape[2] == 4: 68 | image = image[:, :, :3] 69 | if crop: 70 | image = self._crop_image(image, crop, boxes[images[i]]) 71 | data_img = misc.imresize(image, 
target_size) 72 | misc.imsave(os.path.join(root, 'images', '{}.jpg'.format(images[i])), data_img) 73 | 74 | if i%500==0: 75 | print(i) 76 | 77 | return 78 | 79 | def _crop_image(self, image, crop, box): 80 | imheight, imwidth = image.shape[:2] 81 | x, y, width, height = box 82 | centerx = x + width / 2. 83 | centery = y + height / 2. 84 | xoffset = width * crop / 2. 85 | yoffset = height * crop / 2. 86 | xmin = max(int(centerx - xoffset + 0.5), 0) 87 | ymin = max(int(centery - yoffset + 0.5), 0) 88 | xmax = min(int(centerx + xoffset + 0.5), imwidth - 1) 89 | ymax = min(int(centery + yoffset + 0.5), imheight - 1) 90 | if xmax - xmin <= 0 or ymax - ymin <= 0: 91 | raise ValueError("The cropped bounding box has size 0.") 92 | return image[ymin:ymax, xmin:xmax] 93 | 94 | def reorganize(): 95 | root='/home/huaxiuyao/Data/meta-dataset/FGVC_Aircraft/data/' 96 | label=[line.strip().split(' ') for line in open(os.path.join(root, 'images_variant_train.txt'),'r')] 97 | label.extend([line.strip().split(' ') for line in open(os.path.join(root, 'images_variant_trainval.txt'),'r')]) 98 | label.extend([line.strip().split(' ') for line in open(os.path.join(root, 'images_variant_val.txt'), 'r')]) 99 | label.extend([line.strip().split(' ') for line in open(os.path.join(root, 'images_variant_test.txt'), 'r')]) 100 | labelall={} 101 | for eachitem in label: 102 | if eachitem[0] in labelall: 103 | continue 104 | labelall[eachitem[0]]='-'.join(eachitem[1:]) 105 | newpath = '/home/huaxiuyao/Data/meta-dataset/FGVC_Aircraft/data/organized_images/' 106 | for eachfile in os.listdir('/home/huaxiuyao/Data/meta-dataset/FGVC_Aircraft/data/images/'): 107 | tmp_id = eachfile.split('.')[0] 108 | folder_id = labelall[tmp_id] 109 | print(folder_id) 110 | if folder_id == 'F-16A/B': 111 | folder_id='F-16A-B' 112 | if folder_id == 'F/A-18': 113 | folder_id='F-A-18' 114 | if not os.path.isdir(newpath + '{}'.format(folder_id)): 115 | os.mkdir(newpath + '{}'.format(folder_id)) 116 | 117 | image_file = 
'/home/huaxiuyao/Data/meta-dataset/FGVC_Aircraft/data/images/' + eachfile 118 | im = Image.open(image_file) 119 | im.save(newpath + '{}'.format(folder_id) + '/' + eachfile) 120 | 121 | def select_image(): 122 | path='/home/huaxiuyao/Data/meta-dataset/FGVC_Aircraft/images/' 123 | dirlist=os.listdir(path) 124 | num_images=[] 125 | for eachdir in dirlist: 126 | num_images.append([eachdir, len(os.listdir(path+eachdir))]) 127 | all_folder_id=random.sample(range(len(num_images)), 100) 128 | all_folder=[num_images[id] for id in all_folder_id] 129 | random.shuffle(all_folder) 130 | for i in range(64): 131 | shutil.move(path+all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/FGVC_Aircraft/train/') 132 | for i in range(64,80): 133 | shutil.move(path + all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/FGVC_Aircraft/val/') 134 | for i in range(80,100): 135 | shutil.move(path + all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/FGVC_Aircraft/test/') 136 | # num_images=sorted(num_images, key=lambda x:x[1], reverse=True) 137 | 138 | 139 | if __name__=='__main__': 140 | select_image() -------------------------------------------------------------------------------- /data/fgvcx_fungi/proc_images.py: -------------------------------------------------------------------------------- 1 | from PIL import Image 2 | import glob 3 | import os 4 | import shutil 5 | import random 6 | import numpy as np 7 | import ipdb 8 | 9 | np.random.seed(1) 10 | random.seed(2) 11 | 12 | image_path = '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/images/*/' 13 | 14 | 15 | def process(): 16 | all_images = glob.glob(image_path + '*') 17 | 18 | i = 0 19 | 20 | for image_file in all_images: 21 | im = Image.open(image_file) 22 | im = im.resize((84, 84), resample=Image.LANCZOS) 23 | im.save(image_file) 24 | i += 1 25 | 26 | if i % 200 == 0: 27 | print(i) 28 | 29 | 30 | def select_folder(): 31 | path = '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/images/' 32 | dirlist = os.listdir(path) 33 | num_images = 
[] 34 | for eachdir in dirlist: 35 | if len(os.listdir(path + eachdir)) >= 150: 36 | num_images.append([eachdir, len(os.listdir(path + eachdir))]) 37 | all_folder_id = random.sample(range(len(num_images)), 100) 38 | all_folder = [num_images[id] for id in all_folder_id] 39 | random.shuffle(all_folder) 40 | for i in range(64): 41 | shutil.move(path + all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/train/') 42 | for i in range(64, 80): 43 | shutil.move(path + all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/val/') 44 | for i in range(80, 100): 45 | shutil.move(path + all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/test/') 46 | # num_images = sorted(num_images, key=lambda x: x[1], reverse=True) 47 | # print(len(num_images)) 48 | 49 | 50 | def select_image(): 51 | folder = ['train', 'test', 'val'] 52 | for eachfolder in folder: 53 | all_files = os.listdir('/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/{}/'.format(eachfolder)) 54 | for eachtype in all_files: 55 | images = os.listdir('/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/{}/{}/'.format(eachfolder, eachtype)) 56 | random.shuffle(images) 57 | images_id = random.sample(range(len(images)), 150) 58 | new_images = [images[idx] for idx in images_id] 59 | os.mkdir('/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/{}_new/{}/'.format(eachfolder, eachtype)) 60 | for idx_y in range(len(new_images)): 61 | shutil.move('/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/{}/{}/{}'.format(eachfolder, eachtype, 62 | new_images[idx_y]), 63 | '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/{}_new/{}/'.format(eachfolder, eachtype)) 64 | # path = '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/images/' 65 | # dirlist = os.listdir(path) 66 | # num_images = [] 67 | # for eachdir in dirlist: 68 | # if len(os.listdir(path + eachdir)) >= 150: 69 | # num_images.append([eachdir, len(os.listdir(path + eachdir))]) 70 | # all_folder_id = random.sample(range(len(num_images)), 100) 71 | # all_folder = 
[num_images[id] for id in all_folder_id] 72 | # random.shuffle(all_folder) 73 | # for i in range(64): 74 | # shutil.move(path + all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/train/') 75 | # for i in range(64, 80): 76 | # shutil.move(path + all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/val/') 77 | # for i in range(80, 100): 78 | # shutil.move(path + all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/test/') 79 | # num_images = sorted(num_images, key=lambda x: x[1], reverse=True) 80 | # print(len(num_images)) 81 | 82 | 83 | if __name__ == '__main__': 84 | select_folder() 85 | -------------------------------------------------------------------------------- /data/miniImagenet/proc_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script for converting from csv file datafiles to a directory for each image (which is how it is loaded by MAML code) 3 | 4 | Acquire miniImagenet from Ravi & Larochelle '17, along with the train, val, and test csv files. Put the 5 | csv files in the miniImagenet directory and put the images in the directory 'miniImagenet/images/'. 
6 | Then run this script from the miniImagenet directory: 7 | cd data/miniImagenet/ 8 | python proc_images.py 9 | """ 10 | 11 | from __future__ import print_function 12 | import csv 13 | import glob 14 | import os 15 | 16 | from PIL import Image 17 | 18 | path_to_images = '/home/huaxiuyao/Data/miniimagenet/images/' 19 | 20 | all_images = glob.glob(path_to_images + '*') 21 | 22 | # Resize images 23 | for i, image_file in enumerate(all_images): 24 | im = Image.open(image_file) 25 | im = im.resize((84, 84), resample=Image.LANCZOS) 26 | im.save(image_file) 27 | if i % 500 == 0: 28 | print(i) 29 | 30 | # Put in correct directory 31 | for datatype in ['train', 'val', 'test']: 32 | os.system('mkdir ' +'/home/huaxiuyao/Data/miniimagenet/'+datatype) 33 | 34 | with open('/home/huaxiuyao/Data/miniimagenet/'+datatype + '.csv', 'r') as f: 35 | reader = csv.reader(f, delimiter=',') 36 | last_label = '' 37 | for i, row in enumerate(reader): 38 | if i == 0: # skip the headers 39 | continue 40 | label = row[1] 41 | image_name = row[0] 42 | if label != last_label: 43 | cur_dir = '/home/huaxiuyao/Data/miniimagenet/' + datatype + '/' + label + '/' 44 | os.system('mkdir ' + cur_dir) 45 | last_label = label 46 | os.system('mv /home/huaxiuyao/Data/miniimagenet/images/' + image_name + ' ' + cur_dir) 47 | -------------------------------------------------------------------------------- /data/omniglot_resized/resize_images.py: -------------------------------------------------------------------------------- 1 | """ 2 | Usage instructions: 3 | First download the omniglot dataset 4 | and put the contents of both images_background and images_evaluation in data/omniglot/ (without the root folder) 5 | 6 | Then, run the following: 7 | cd data/ 8 | cp -r omniglot/* omniglot_resized/ 9 | cd omniglot_resized/ 10 | python resize_images.py 11 | """ 12 | from PIL import Image 13 | import glob 14 | 15 | image_path = '*/*/' 16 | 17 | all_images = glob.glob(image_path + '*') 18 | 19 | i = 0 20 | 21 | for 
image_file in all_images: 22 | im = Image.open(image_file) 23 | im = im.resize((28,28), resample=Image.LANCZOS) 24 | im.save(image_file) 25 | i += 1 26 | 27 | if i % 200 == 0: 28 | print(i) 29 | 30 | -------------------------------------------------------------------------------- /data_generator.py: -------------------------------------------------------------------------------- 1 | """ Code for loading data. """ 2 | import numpy as np 3 | import os 4 | import random 5 | import tensorflow as tf 6 | import ipdb 7 | 8 | from tensorflow.python.platform import flags 9 | from utils import get_images 10 | import matplotlib.pyplot as plt 11 | 12 | FLAGS = flags.FLAGS 13 | 14 | 15 | class DataGenerator(object): 16 | def __init__(self, num_samples_per_class, batch_size, config={}): 17 | """ 18 | Args: 19 | num_samples_per_class: num samples to generate per class in one batch 20 | batch_size: size of meta batch size (e.g. number of functions) 21 | """ 22 | self.batch_size = batch_size 23 | self.num_samples_per_class = num_samples_per_class 24 | self.num_classes = 1 # by default 1 (only relevant for classification problems) 25 | 26 | if FLAGS.datasource == 'sinusoid': 27 | self.generate = self.generate_sinusoid_batch 28 | # self.amp_range = config.get('amp_range', [0.1, 5.0]) 29 | # self.phase_range = config.get('phase_range', [0, np.pi]) 30 | self.amp_range = config.get('amp_range', [0.1, 5.0]) 31 | self.freq_range = config.get('freq_range', [0.8, 1.2]) 32 | self.phase_range = config.get('phase_range', [0, np.pi]) 33 | self.input_range = config.get('input_range', [-5.0, 5.0]) 34 | self.dim_input = 1 35 | self.dim_output = 1 36 | 37 | elif FLAGS.datasource == 'mixture': 38 | self.generate = self.generate_mixture_batch 39 | self.dim_input = 1 40 | self.dim_output = 1 41 | self.input_range = config.get('input_range', [-5.0, 5.0]) 42 | 43 | elif 'omniglot' in FLAGS.datasource: 44 | self.num_classes = config.get('num_classes', FLAGS.num_classes) 45 | self.img_size = 
config.get('img_size', (28, 28)) 46 | self.dim_input = np.prod(self.img_size) 47 | self.dim_output = self.num_classes 48 | # data that is pre-resized using PIL with lanczos filter 49 | data_folder = config.get('data_folder', '{}/omniglot_resized'.format(FLAGS.datadir)) 50 | 51 | character_folders = [os.path.join(data_folder, family, character) \ 52 | for family in os.listdir(data_folder) \ 53 | if os.path.isdir(os.path.join(data_folder, family)) \ 54 | for character in os.listdir(os.path.join(data_folder, family))] 55 | random.seed(1) 56 | random.shuffle(character_folders) 57 | if FLAGS.no_val: 58 | num_val = 0 59 | else: 60 | num_val = 100 61 | num_train = config.get('num_train', 1200) - num_val 62 | self.metatrain_character_folders = character_folders[:num_train] 63 | if FLAGS.test_set: 64 | self.metaval_character_folders = character_folders[num_train + num_val:] 65 | else: 66 | self.metaval_character_folders = character_folders[num_train:num_train + num_val] 67 | self.rotations = config.get('rotations', [0, 90, 180, 270]) 68 | elif FLAGS.datasource == 'miniimagenet': 69 | self.num_classes = config.get('num_classes', FLAGS.num_classes) 70 | self.img_size = config.get('img_size', (84, 84)) 71 | self.dim_input = np.prod(self.img_size) * 3 72 | self.dim_output = self.num_classes 73 | metatrain_folder = config.get('metatrain_folder', '{}/miniImagenet/train'.format(FLAGS.datadir)) 74 | if FLAGS.test_set: 75 | metaval_folder = config.get('metaval_folder', '{}/miniImagenet/test'.format(FLAGS.datadir)) 76 | else: 77 | metaval_folder = config.get('metaval_folder', '{}/miniImagenet/val'.format(FLAGS.datadir)) 78 | 79 | metatrain_folders = [os.path.join(metatrain_folder, label) \ 80 | for label in os.listdir(metatrain_folder) \ 81 | if os.path.isdir(os.path.join(metatrain_folder, label)) \ 82 | ] 83 | metaval_folders = [os.path.join(metaval_folder, label) \ 84 | for label in os.listdir(metaval_folder) \ 85 | if os.path.isdir(os.path.join(metaval_folder, label)) \ 86 | ] 87 
| self.metatrain_character_folders = metatrain_folders 88 | self.metaval_character_folders = metaval_folders 89 | self.rotations = config.get('rotations', [0]) 90 | 91 | elif FLAGS.datasource == 'multidataset': 92 | self.num_classes = config.get('num_classes', FLAGS.num_classes) 93 | self.img_size = config.get('img_size', (84, 84)) 94 | self.dim_input = np.prod(self.img_size) * 3 95 | self.dim_output = self.num_classes 96 | self.multidataset = ['CUB_Bird', 'DTD_Texture', 'FGVC_Aircraft', 'FGVCx_Fungi'] 97 | metatrain_folders, metaval_folders = [], [] 98 | for eachdataset in self.multidataset: 99 | metatrain_folders.append( 100 | [os.path.join('{0}/meta-dataset/{1}/train'.format(FLAGS.datadir, eachdataset), label) \ 101 | for label in os.listdir('{0}/meta-dataset/{1}/train'.format(FLAGS.datadir, eachdataset)) \ 102 | if 103 | os.path.isdir(os.path.join('{0}/meta-dataset/{1}/train'.format(FLAGS.datadir, eachdataset), label)) \ 104 | ]) 105 | if FLAGS.test_set: 106 | metaval_folders.append( 107 | [os.path.join('{0}/meta-dataset/{1}/test'.format(FLAGS.datadir, eachdataset), label) \ 108 | for label in os.listdir('{0}/meta-dataset/{1}/test'.format(FLAGS.datadir, eachdataset)) \ 109 | if os.path.isdir( 110 | os.path.join('{0}/meta-dataset/{1}/test'.format(FLAGS.datadir, eachdataset), label)) \ 111 | ]) 112 | else: 113 | metaval_folders.append( 114 | [os.path.join('{0}/meta-dataset/{1}/val'.format(FLAGS.datadir, eachdataset), label) \ 115 | for label in os.listdir('{0}/meta-dataset/{1}/val'.format(FLAGS.datadir, eachdataset)) \ 116 | if os.path.isdir( 117 | os.path.join('{0}/meta-dataset/{1}/val'.format(FLAGS.datadir, eachdataset), label)) \ 118 | ]) 119 | self.metatrain_character_folders = metatrain_folders 120 | self.metaval_character_folders = metaval_folders 121 | self.rotations = config.get('rotations', [0]) 122 | 123 | elif FLAGS.datasource == 'multidataset_leave_one_out': 124 | self.num_classes = config.get('num_classes', FLAGS.num_classes) 125 | self.img_size = 
config.get('img_size', (84, 84)) 126 | self.dim_input = np.prod(self.img_size) * 3 127 | self.dim_output = self.num_classes 128 | self.multidataset = ['CUB_Bird', 'DTD_Texture', 'FGVC_Aircraft', 'FGVCx_Fungi'] 129 | metatrain_folders, metaval_folders = [], [] 130 | for idx_data, eachdataset in enumerate(self.multidataset): 131 | if idx_data == FLAGS.leave_one_out_id: 132 | continue 133 | metatrain_folders.append( 134 | [os.path.join('{0}/meta-dataset-leave-one-out/{1}/train'.format(FLAGS.datadir, eachdataset), label) \ 135 | for label in 136 | os.listdir('{0}/meta-dataset-leave-one-out/{1}/train'.format(FLAGS.datadir, eachdataset)) \ 137 | if 138 | os.path.isdir( 139 | os.path.join('{0}/meta-dataset-leave-one-out/{1}/train'.format(FLAGS.datadir, eachdataset), 140 | label)) \ 141 | ]) 142 | if FLAGS.test_set: 143 | metaval_folders = [os.path.join('{0}/meta-dataset-leave-one-out/{1}/train'.format(FLAGS.datadir, 144 | self.multidataset[ 145 | FLAGS.leave_one_out_id]), 146 | label) \ 147 | for label in os.listdir( 148 | '{0}/meta-dataset-leave-one-out/{1}/train'.format(FLAGS.datadir, 149 | self.multidataset[FLAGS.leave_one_out_id])) \ 150 | if os.path.isdir( 151 | os.path.join('{0}/meta-dataset-leave-one-out/{1}/train'.format(FLAGS.datadir, self.multidataset[ 152 | FLAGS.leave_one_out_id]), label)) \ 153 | ] 154 | else: 155 | metaval_folders = [os.path.join('{0}/meta-dataset-leave-one-out/{1}/val'.format(FLAGS.datadir, 156 | self.multidataset[ 157 | FLAGS.leave_one_out_id]), 158 | label) \ 159 | for label in os.listdir( 160 | '{0}/meta-dataset-leave-one-out/{1}/val'.format(FLAGS.datadir, 161 | self.multidataset[FLAGS.leave_one_out_id])) \ 162 | if os.path.isdir( 163 | os.path.join('{0}/meta-dataset-leave-one-out/{1}/val'.format(FLAGS.datadir, 164 | self.multidataset[ 165 | FLAGS.leave_one_out_id]), 166 | label)) \ 167 | ] 168 | self.metatrain_character_folders = metatrain_folders 169 | self.metaval_character_folders = metaval_folders 170 | self.rotations = 
config.get('rotations', [0]) 171 | 172 | else: 173 | raise ValueError('Unrecognized data source') 174 | 175 | def make_data_tensor(self, train=True): 176 | if train: 177 | folders = self.metatrain_character_folders 178 | # number of tasks, not number of meta-iterations. (divide by metabatch size to measure) 179 | num_total_batches = 200000 180 | else: 181 | folders = self.metaval_character_folders 182 | num_total_batches = 600 183 | 184 | # make list of files 185 | print('Generating filenames') 186 | all_filenames = [] 187 | for _ in range(num_total_batches): 188 | sampled_character_folders = random.sample(folders, self.num_classes) 189 | random.shuffle(sampled_character_folders) 190 | labels_and_images = get_images(sampled_character_folders, range(self.num_classes), 191 | nb_samples=self.num_samples_per_class, shuffle=False) 192 | # make sure the above isn't randomized order 193 | labels = [li[0] for li in labels_and_images] 194 | filenames = [li[1] for li in labels_and_images] 195 | all_filenames.extend(filenames) 196 | 197 | # make queue for tensorflow to read from 198 | filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames), shuffle=False) 199 | print('Generating image processing ops') 200 | image_reader = tf.WholeFileReader() 201 | _, image_file = image_reader.read(filename_queue) 202 | if FLAGS.datasource == 'miniimagenet': 203 | image = tf.image.decode_jpeg(image_file, channels=3) 204 | image.set_shape((self.img_size[0], self.img_size[1], 3)) 205 | image = tf.reshape(image, [self.dim_input]) 206 | image = tf.cast(image, tf.float32) / 255.0 207 | else: 208 | image = tf.image.decode_png(image_file) 209 | image.set_shape((self.img_size[0], self.img_size[1], 1)) 210 | image = tf.reshape(image, [self.dim_input]) 211 | image = tf.cast(image, tf.float32) / 255.0 212 | image = 1.0 - image 213 | num_preprocess_threads = 1 214 | min_queue_examples = 256 215 | examples_per_batch = self.num_classes * self.num_samples_per_class 216 | 
batch_image_size = self.batch_size * examples_per_batch 217 | print('Batching images') 218 | images = tf.train.batch( 219 | [image], 220 | batch_size=batch_image_size, 221 | num_threads=num_preprocess_threads, 222 | capacity=min_queue_examples + 3 * batch_image_size, 223 | ) 224 | all_image_batches, all_label_batches = [], [] 225 | print('Manipulating image data to be right shape') 226 | for i in range(self.batch_size): 227 | image_batch = images[i * examples_per_batch:(i + 1) * examples_per_batch] 228 | 229 | if FLAGS.datasource == 'omniglot': 230 | # omniglot augments the dataset by rotating digits to create new classes 231 | # get rotation per class (e.g. 0,1,2,0,0 if there are 5 classes) 232 | rotations = tf.multinomial(tf.log([[1., 1., 1., 1.]]), self.num_classes) 233 | label_batch = tf.convert_to_tensor(labels) 234 | new_list, new_label_list = [], [] 235 | for k in range(self.num_samples_per_class): 236 | class_idxs = tf.range(0, self.num_classes) 237 | class_idxs = tf.random_shuffle(class_idxs) 238 | 239 | true_idxs = class_idxs * self.num_samples_per_class + k 240 | 241 | new_list.append(tf.gather(image_batch, true_idxs)) 242 | if FLAGS.datasource == 'omniglot': # and FLAGS.train: 243 | new_list[-1] = tf.stack([tf.reshape(tf.image.rot90( 244 | tf.reshape(new_list[-1][ind], [self.img_size[0], self.img_size[1], 1]), 245 | k=tf.cast(rotations[0, class_idxs[ind]], tf.int32)), (self.dim_input,)) 246 | for ind in range(self.num_classes)]) 247 | new_label_list.append(tf.gather(label_batch, true_idxs)) 248 | new_list = tf.concat(new_list, 0) # has shape [self.num_classes*self.num_samples_per_class, self.dim_input] 249 | new_label_list = tf.concat(new_label_list, 0) 250 | all_image_batches.append(new_list) 251 | all_label_batches.append(new_label_list) 252 | all_image_batches = tf.stack(all_image_batches) 253 | all_label_batches = tf.stack(all_label_batches) 254 | all_label_batches = tf.one_hot(all_label_batches, self.num_classes) 255 | return all_image_batches, 
all_label_batches 256 | 257 | def make_data_tensor_multidataset(self, train=True): 258 | if train: 259 | folders = self.metatrain_character_folders 260 | # number of tasks, not number of meta-iterations. (divide by metabatch size to measure) 261 | if FLAGS.update_batch_size == 10: 262 | num_total_batches = 140000 263 | else: 264 | num_total_batches = 200000 265 | else: 266 | folders = self.metaval_character_folders 267 | num_total_batches = FLAGS.num_test_task 268 | # make list of files 269 | print('Generating filenames') 270 | all_filenames = [] 271 | # if FLAGS.train == False: 272 | # np.random.seed(4) 273 | for image_itr in range(num_total_batches): 274 | sel = np.random.randint(4) 275 | if FLAGS.train == False and FLAGS.test_dataset != -1: 276 | sel = FLAGS.test_dataset 277 | sampled_character_folders = random.sample(folders[sel], self.num_classes) 278 | random.shuffle(sampled_character_folders) 279 | labels_and_images = get_images(sampled_character_folders, range(self.num_classes), 280 | nb_samples=self.num_samples_per_class, shuffle=False) 281 | # make sure the above isn't randomized order 282 | labels = [li[0] for li in labels_and_images] 283 | filenames = [li[1] for li in labels_and_images] 284 | all_filenames.extend(filenames) 285 | 286 | # make queue for tensorflow to read from 287 | filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames), shuffle=False) 288 | print('Generating image processing ops') 289 | image_reader = tf.WholeFileReader() 290 | _, image_file = image_reader.read(filename_queue) 291 | if FLAGS.datasource in ['miniimagenet', 'multidataset']: 292 | image = tf.image.decode_jpeg(image_file, channels=3) 293 | image.set_shape((self.img_size[0], self.img_size[1], 3)) 294 | image = tf.reshape(image, [self.dim_input]) 295 | image = tf.cast(image, tf.float32) / 255.0 296 | else: 297 | image = tf.image.decode_png(image_file) 298 | image.set_shape((self.img_size[0], self.img_size[1], 1)) 299 | image = tf.reshape(image, 
[self.dim_input]) 300 | image = tf.cast(image, tf.float32) / 255.0 301 | image = 1.0 - image # invert 302 | num_preprocess_threads = 1 # TODO - enable this to be set to >1 303 | min_queue_examples = 256 304 | examples_per_batch = self.num_classes * self.num_samples_per_class 305 | batch_image_size = self.batch_size * examples_per_batch 306 | print('Batching images') 307 | images = tf.train.batch( 308 | [image], 309 | batch_size=batch_image_size, 310 | num_threads=num_preprocess_threads, 311 | capacity=min_queue_examples + 3 * batch_image_size, 312 | ) 313 | all_image_batches, all_label_batches = [], [] 314 | print('Manipulating image data to be right shape') 315 | for i in range(self.batch_size): 316 | image_batch = images[i * examples_per_batch:(i + 1) * examples_per_batch] 317 | label_batch = tf.convert_to_tensor(labels) 318 | new_list, new_label_list = [], [] 319 | for k in range(self.num_samples_per_class): 320 | class_idxs = tf.range(0, self.num_classes) 321 | class_idxs = tf.random_shuffle(class_idxs) 322 | true_idxs = class_idxs * self.num_samples_per_class + k 323 | new_list.append(tf.gather(image_batch, true_idxs)) 324 | new_label_list.append(tf.gather(label_batch, true_idxs)) 325 | new_list = tf.concat(new_list, 0) # has shape [self.num_classes*self.num_samples_per_class, self.dim_input] 326 | new_label_list = tf.concat(new_label_list, 0) 327 | all_image_batches.append(new_list) 328 | all_label_batches.append(new_label_list) 329 | all_image_batches = tf.stack(all_image_batches) 330 | all_label_batches = tf.stack(all_label_batches) 331 | all_label_batches = tf.one_hot(all_label_batches, self.num_classes) 332 | return all_image_batches, all_label_batches 333 | 334 | def make_data_tensor_multidataset_leave_one_out(self, train=True): 335 | if train: 336 | folders = self.metatrain_character_folders 337 | # number of tasks, not number of meta-iterations. 
(divide by metabatch size to measure) 338 | num_total_batches = 200000 339 | else: 340 | folders = self.metaval_character_folders 341 | num_total_batches = FLAGS.num_test_task 342 | # make list of files 343 | print('Generating filenames') 344 | all_filenames = [] 345 | for image_itr in range(num_total_batches): 346 | if train: 347 | sel = np.random.randint(3) 348 | sampled_character_folders = random.sample(folders[sel], self.num_classes) 349 | else: 350 | sampled_character_folders = random.sample(folders, self.num_classes) 351 | random.shuffle(sampled_character_folders) 352 | labels_and_images = get_images(sampled_character_folders, range(self.num_classes), 353 | nb_samples=self.num_samples_per_class, shuffle=False) 354 | # make sure the above isn't randomized order 355 | labels = [li[0] for li in labels_and_images] 356 | filenames = [li[1] for li in labels_and_images] 357 | all_filenames.extend(filenames) 358 | 359 | # make queue for tensorflow to read from 360 | filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames), shuffle=False) 361 | print('Generating image processing ops') 362 | image_reader = tf.WholeFileReader() 363 | _, image_file = image_reader.read(filename_queue) 364 | if FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']: 365 | image = tf.image.decode_jpeg(image_file, channels=3) 366 | image.set_shape((self.img_size[0], self.img_size[1], 3)) 367 | image = tf.reshape(image, [self.dim_input]) 368 | image = tf.cast(image, tf.float32) / 255.0 369 | else: 370 | image = tf.image.decode_png(image_file) 371 | image.set_shape((self.img_size[0], self.img_size[1], 1)) 372 | image = tf.reshape(image, [self.dim_input]) 373 | image = tf.cast(image, tf.float32) / 255.0 374 | image = 1.0 - image # invert 375 | 376 | num_preprocess_threads = 1 377 | min_queue_examples = 256 378 | examples_per_batch = self.num_classes * self.num_samples_per_class 379 | batch_image_size = self.batch_size * examples_per_batch 
380 | print('Batching images') 381 | images = tf.train.batch( 382 | [image], 383 | batch_size=batch_image_size, 384 | num_threads=num_preprocess_threads, 385 | capacity=min_queue_examples + 3 * batch_image_size, 386 | ) 387 | all_image_batches, all_label_batches = [], [] 388 | print('Manipulating image data to be right shape') 389 | for i in range(self.batch_size): 390 | image_batch = images[i * examples_per_batch:(i + 1) * examples_per_batch] 391 | label_batch = tf.convert_to_tensor(labels) 392 | new_list, new_label_list = [], [] 393 | for k in range(self.num_samples_per_class): 394 | class_idxs = tf.range(0, self.num_classes) 395 | class_idxs = tf.random_shuffle(class_idxs) 396 | true_idxs = class_idxs * self.num_samples_per_class + k 397 | new_list.append(tf.gather(image_batch, true_idxs)) 398 | new_label_list.append(tf.gather(label_batch, true_idxs)) 399 | new_list = tf.concat(new_list, 0) # has shape [self.num_classes*self.num_samples_per_class, self.dim_input] 400 | new_label_list = tf.concat(new_label_list, 0) 401 | all_image_batches.append(new_list) 402 | all_label_batches.append(new_label_list) 403 | all_image_batches = tf.stack(all_image_batches) 404 | all_label_batches = tf.stack(all_label_batches) 405 | all_label_batches = tf.one_hot(all_label_batches, self.num_classes) 406 | return all_image_batches, all_label_batches 407 | 408 | def generate_sinusoid_batch(self, train=True, input_idx=None): 409 | # Note train arg is not used (but it is used for omniglot method. 
410 | # input_idx is used during qualitative testing --the number of examples used for the grad update 411 | amp = np.random.uniform(self.amp_range[0], self.amp_range[1], [self.batch_size]) 412 | freq = np.random.uniform(self.freq_range[0], self.freq_range[1], [self.batch_size]) 413 | phase = np.random.uniform(self.phase_range[0], self.phase_range[1], [self.batch_size]) 414 | outputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_output]) 415 | init_inputs = np.zeros([self.batch_size, self.num_samples_per_class, self.dim_input]) 416 | for func in range(self.batch_size): 417 | init_inputs[func] = np.random.uniform(self.input_range[0], self.input_range[1], 418 | [self.num_samples_per_class, 1]) 419 | if input_idx is not None: 420 | init_inputs[:, input_idx:, 0] = np.linspace(self.input_range[0], self.input_range[1], 421 | num=self.num_samples_per_class - input_idx, retstep=False) 422 | outputs[func] = amp[func] * np.sin(freq[func] * init_inputs[func] - phase[func]) 423 | return init_inputs, outputs, amp, phase 424 | 425 | def generate_mixture_batch(self, train=True, input_idx=None, DRAW_PLOTS=True): 426 | dim_input = self.dim_input 427 | dim_output = self.dim_output 428 | batch_size = self.batch_size 429 | num_samples_per_class = self.num_samples_per_class 430 | 431 | # sin 432 | amp = np.random.uniform(0.1, 5.0, size=self.batch_size) 433 | phase = np.random.uniform(0., 2 * np.pi, size=batch_size) 434 | freq = np.random.uniform(0.8, 1.2, size=batch_size) 435 | 436 | # linear 437 | A = np.random.uniform(-3.0, 3.0, size=batch_size) 438 | b = np.random.uniform(-3.0, 3.0, size=batch_size) 439 | 440 | # quadratic 441 | A_q = np.random.uniform(-0.2, 0.2, size=batch_size) 442 | b_q = np.random.uniform(-2.0, 2.0, size=batch_size) 443 | c_q = np.random.uniform(-3.0, 3.0, size=batch_size) 444 | 445 | # cubic 446 | A_c = np.random.uniform(-0.1, 0.1, size=batch_size) 447 | b_c = np.random.uniform(-0.2, 0.2, size=batch_size) 448 | c_c = np.random.uniform(-2.0, 
2.0, size=batch_size) 449 | d_c = np.random.uniform(-3.0, 3.0, size=batch_size) 450 | 451 | sel_set = np.zeros(batch_size) 452 | 453 | init_inputs = np.zeros([batch_size, num_samples_per_class, dim_input]) 454 | outputs = np.zeros([batch_size, num_samples_per_class, dim_output]) 455 | 456 | for func in range(batch_size): 457 | init_inputs[func] = np.random.uniform(self.input_range[0], self.input_range[1], 458 | size=(num_samples_per_class, dim_input)) 459 | sel = np.random.randint(4) 460 | if FLAGS.train == False and FLAGS.test_dataset != -1: 461 | sel = FLAGS.test_dataset 462 | if sel == 0: 463 | outputs[func] = amp[func] * np.sin(freq[func] * init_inputs[func]) + phase[func] 464 | elif sel == 1: 465 | outputs[func] = A[func] * init_inputs[func] + b[func] 466 | elif sel == 2: 467 | outputs[func] = A_q[func] * np.square(init_inputs[func]) + b_q[func] * init_inputs[func] + c_q[func] 468 | elif sel == 3: 469 | outputs[func] = A_c[func] * np.power(init_inputs[func], np.tile([3], init_inputs[func].shape)) + b_c[ 470 | func] * np.square(init_inputs[func]) + c_c[func] * init_inputs[func] + d_c[func] 471 | sel_set[func] = sel 472 | funcs_params = {'amp': amp, 'phase': phase, 'freq': freq, 'A': A, 'b': b, 'A_q': A_q, 'c_q': c_q, 'b_q': b_q, 473 | 'A_c': A_c, 'b_c': b_c, 'c_c': c_c, 'd_c': d_c} 474 | return init_inputs, outputs, funcs_params, sel_set 475 | -------------------------------------------------------------------------------- /image_embedding.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.platform import flags 3 | 4 | FLAGS = flags.FLAGS 5 | 6 | 7 | class ImageEmbedding(object): 8 | def __init__(self, hidden_num, channels, conv_initializer, k=5): 9 | self.hidden_num = hidden_num 10 | self.channels = channels 11 | with tf.variable_scope('image_embedding', reuse=tf.AUTO_REUSE): 12 | self.conv1_kernel = tf.get_variable('conv1_kernel', [k, k, self.channels, self.hidden_num], 
13 | initializer=conv_initializer) 14 | 15 | self.conv2_kernel = tf.get_variable('conv2_kernel', [k, k, self.hidden_num, self.hidden_num], 16 | initializer=conv_initializer) 17 | self.activation = tf.nn.relu 18 | 19 | def model(self, images): 20 | conv = tf.nn.conv2d(images, self.conv1_kernel, [1, 1, 1, 1], padding='SAME') 21 | conv1 = tf.nn.relu(conv, name='conv1_post_activation') 22 | 23 | pool1 = tf.nn.max_pool(conv1, ksize=[1, 3, 3, 1], strides=[1, 2, 2, 1], 24 | padding='SAME', name='pool1') 25 | norm1 = tf.nn.lrn(pool1, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, 26 | name='norm1') 27 | 28 | conv2 = tf.nn.conv2d(norm1, self.conv2_kernel, [1, 1, 1, 1], padding='SAME') 29 | conv2_act = tf.nn.relu(conv2, name='conv2_post_activation') 30 | 31 | norm2 = tf.nn.lrn(conv2_act, 4, bias=1.0, alpha=0.001 / 9.0, beta=0.75, 32 | name='norm2') 33 | pool2 = tf.nn.max_pool(norm2, ksize=[1, 3, 3, 1], 34 | strides=[1, 2, 2, 1], padding='SAME', name='pool2') 35 | 36 | with tf.variable_scope('local3', reuse=tf.AUTO_REUSE): 37 | image_reshape = tf.reshape(pool2, [images.get_shape().as_list()[0], -1]) 38 | dim = image_reshape.get_shape()[1].value 39 | local3_weight = tf.get_variable(name='weight', shape=[dim, 384], 40 | initializer=tf.truncated_normal_initializer(stddev=0.04)) 41 | local3_biases = tf.get_variable(name='biases', shape=[384], initializer=tf.constant_initializer(0.1)) 42 | local3=tf.nn.relu(tf.matmul(image_reshape, local3_weight)+local3_biases, name='local3_dense') 43 | 44 | with tf.variable_scope('local4', reuse=tf.AUTO_REUSE) as scope: 45 | local4_weight = tf.get_variable(name='weight', shape=[384, 64], 46 | initializer=tf.truncated_normal_initializer(stddev=0.04)) 47 | local4_biases = tf.get_variable(name='biases', shape=[64], initializer=tf.constant_initializer(0.1)) 48 | local4 = tf.nn.relu(tf.matmul(local3, local4_weight) + local4_biases, name='local4_dense') 49 | return local4 50 | -------------------------------------------------------------------------------- 
/lstm_tree.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | from tensorflow.python.platform import flags 4 | 5 | FLAGS = flags.FLAGS 6 | 7 | 8 | class TreeLSTM(object): 9 | def __init__(self, tree_hidden_dim, input_dim): 10 | self.input_dim = input_dim 11 | self.tree_hidden_dim = tree_hidden_dim 12 | self.leaf_weight_i, self.leaf_weight_o, self.leaf_weight_u = [], [], [] 13 | self.leaf_bias_i, self.leaf_bias_o, self.leaf_bias_u = [], [], [] 14 | for i in range(FLAGS.cluster_layer_0): 15 | self.leaf_weight_u.append( 16 | tf.get_variable(name='{}_leaf_weight_u'.format(i), shape=(input_dim, tree_hidden_dim))) 17 | self.leaf_bias_u.append(tf.get_variable(name='{}_leaf_bias_u'.format(i), shape=(1, tree_hidden_dim))) 18 | 19 | self.no_leaf_weight_i, self.no_leaf_weight_o, self.no_leaf_weight_u, self.no_leaf_weight_f = [], [], [], [] 20 | self.no_leaf_bias_i, self.no_leaf_bias_o, self.no_leaf_bias_u, self.no_leaf_bias_f = [], [], [], [] 21 | for i in range(FLAGS.cluster_layer_1): 22 | if FLAGS.tree_type == 1: 23 | self.no_leaf_weight_i.append( 24 | tf.get_variable(name='{}_no_leaf_weight_i'.format(i), shape=(tree_hidden_dim, 1))) 25 | elif FLAGS.tree_type == 2: 26 | self.no_leaf_weight_i.append( 27 | tf.get_variable(name='{}_no_leaf_weight_i'.format(i), shape=(1, tree_hidden_dim))) 28 | self.no_leaf_weight_u.append( 29 | tf.get_variable(name='{}_no_leaf_weight_u'.format(i), shape=(tree_hidden_dim, tree_hidden_dim))) 30 | 31 | self.no_leaf_bias_i.append(tf.get_variable(name='{}_no_leaf_bias_i'.format(i), shape=(1, 1))) 32 | self.no_leaf_bias_u.append(tf.get_variable(name='{}_no_leaf_bias_u'.format(i), shape=(1, tree_hidden_dim))) 33 | 34 | if FLAGS.cluster_layer_2 != -1: 35 | self.no_leaf_weight_i_l2, self.no_leaf_weight_u_l2 = [], [] 36 | self.no_leaf_bias_i_l2, self.no_leaf_bias_u_l2 = [], [] 37 | for i in range(FLAGS.cluster_layer_2): 38 | if FLAGS.tree_type == 1: 39 | 
self.no_leaf_weight_i_l2.append( 40 | tf.get_variable(name='{}_no_leaf_weight_i_l2'.format(i), shape=(tree_hidden_dim, 1))) 41 | elif FLAGS.tree_type == 2: 42 | self.no_leaf_weight_i_l2.append( 43 | tf.get_variable(name='{}_no_leaf_weight_i_l2'.format(i), shape=(1, tree_hidden_dim))) 44 | self.no_leaf_weight_u_l2.append( 45 | tf.get_variable(name='{}_no_leaf_weight_u_l2'.format(i), shape=(tree_hidden_dim, tree_hidden_dim))) 46 | 47 | self.no_leaf_bias_i_l2.append(tf.get_variable(name='{}_no_leaf_bias_i_l2'.format(i), shape=(1, 1))) 48 | self.no_leaf_bias_u_l2.append( 49 | tf.get_variable(name='{}_no_leaf_bias_u_l2'.format(i), shape=(1, tree_hidden_dim))) 50 | 51 | self.root_weight_u = tf.get_variable(name='{}_root_weight_u'.format(i), 52 | shape=(tree_hidden_dim, tree_hidden_dim)) 53 | 54 | self.root_bias_u = tf.get_variable(name='{}_root_bias_u'.format(i), shape=(1, tree_hidden_dim)) 55 | 56 | self.cluster_center = [] 57 | for i in range(FLAGS.cluster_layer_0): 58 | self.cluster_center.append(tf.get_variable(name='{}_cluster_center'.format(i), 59 | shape=(1, input_dim))) 60 | 61 | self.cluster_layer_0 = FLAGS.cluster_layer_0 62 | self.cluster_layer_1 = FLAGS.cluster_layer_1 63 | self.cluster_layer_2 = FLAGS.cluster_layer_2 64 | 65 | def model(self, inputs): 66 | 67 | if FLAGS.datasource == 'multidataset' or FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'multidataset_leave_one_out': 68 | sigma = 10.0 69 | elif FLAGS.datasource in ['sinusoid', 'mixture']: 70 | sigma = 2.0 71 | 72 | for idx in range(self.cluster_layer_0): 73 | if idx == 0: 74 | all_value = tf.exp(-tf.reduce_sum(tf.square(inputs - self.cluster_center[idx])) / (2.0 * sigma)) 75 | else: 76 | all_value += tf.exp(-tf.reduce_sum(tf.square(inputs - self.cluster_center[idx])) / (2.0 * sigma)) 77 | 78 | c_leaf = [] 79 | for idx in range(self.cluster_layer_0): 80 | assignment_idx = tf.exp( 81 | -tf.reduce_sum(tf.square(inputs - self.cluster_center[idx])) / (2.0 * sigma)) / all_value 82 | value_u = 
tf.tanh(tf.matmul(inputs, self.leaf_weight_u[idx]) + self.leaf_bias_u[idx]) 83 | value_c = assignment_idx * value_u 84 | c_leaf.append(value_c) 85 | 86 | c_no_leaf = [] 87 | for idx in range(self.cluster_layer_0): 88 | input_gate = [] 89 | for idx_layer_1 in range(self.cluster_layer_1): 90 | if FLAGS.tree_type == 1: 91 | input_gate.append( 92 | tf.matmul(c_leaf[idx], self.no_leaf_weight_i[idx_layer_1]) + self.no_leaf_bias_i[idx_layer_1]) 93 | elif FLAGS.tree_type == 2: 94 | input_gate.append( 95 | -(tf.reduce_sum(tf.square(c_leaf[idx] - self.no_leaf_weight_i[idx_layer_1]), keepdims=True) + 96 | self.no_leaf_bias_i[idx_layer_1]) / ( 97 | 2.0)) 98 | input_gate = tf.nn.softmax(tf.concat(input_gate, axis=0), axis=0) 99 | 100 | c_no_leaf_temp = [] 101 | for idx_layer_1 in range(self.cluster_layer_1): 102 | no_leaf_value_u = tf.tanh( 103 | tf.matmul(c_leaf[idx], self.no_leaf_weight_u[idx_layer_1]) + self.no_leaf_bias_u[idx_layer_1]) 104 | c_no_leaf_temp.append(input_gate[idx_layer_1] * no_leaf_value_u) 105 | c_no_leaf.append(tf.concat(c_no_leaf_temp, axis=0)) 106 | 107 | c_no_leaf = tf.stack(c_no_leaf, axis=0) 108 | c_no_leaf = tf.transpose(c_no_leaf, perm=[1, 0, 2]) 109 | c_no_leaf = tf.reduce_sum(c_no_leaf, axis=1, keepdims=True) 110 | 111 | if FLAGS.cluster_layer_2 != -1: 112 | c_no_leaf_l2 = [] 113 | 114 | for idx_l2 in range(self.cluster_layer_1): 115 | input_gate_l2 = [] 116 | for idx_layer_2 in range(self.cluster_layer_2): 117 | if FLAGS.tree_type == 1: 118 | input_gate_l2.append( 119 | tf.matmul(c_no_leaf[idx_l2], self.no_leaf_weight_i_l2[idx_layer_2]) + 120 | self.no_leaf_bias_i_l2[ 121 | idx_layer_2]) 122 | elif FLAGS.tree_type == 2: 123 | input_gate_l2.append( 124 | -(tf.reduce_sum(tf.square(c_no_leaf[idx_l2] - self.no_leaf_weight_i_l2[idx_layer_2]), 125 | keepdims=True) + self.no_leaf_bias_i[idx_layer_1]) / (2.0)) 126 | input_gate_l2 = tf.nn.softmax(tf.concat(input_gate_l2, axis=0), axis=0) 127 | 128 | c_no_leaf_temp_l2 = [] 129 | for idx_layer_2 in 
range(self.cluster_layer_2): 130 | no_leaf_value_u_l2 = tf.tanh( 131 | tf.matmul(c_no_leaf[idx_l2], self.no_leaf_weight_u_l2[idx_layer_2]) + self.no_leaf_bias_u_l2[ 132 | idx_layer_2]) 133 | c_no_leaf_temp_l2.append(input_gate_l2[idx_layer_2] * no_leaf_value_u_l2) 134 | c_no_leaf_l2.append(tf.concat(c_no_leaf_temp_l2, axis=0)) 135 | 136 | c_no_leaf_l2 = tf.stack(c_no_leaf_l2, axis=0) 137 | c_no_leaf_l2 = tf.transpose(c_no_leaf_l2, perm=[1, 0, 2]) 138 | c_no_leaf_l2 = tf.reduce_sum(c_no_leaf_l2, axis=1, keepdims=True) 139 | 140 | root_c = [] 141 | 142 | if FLAGS.cluster_layer_2 != -1: 143 | for idx in range(self.cluster_layer_2): 144 | root_c.append(tf.tanh(tf.matmul(c_no_leaf_l2[idx], self.root_weight_u) + self.root_bias_u)) 145 | else: 146 | for idx in range(self.cluster_layer_1): 147 | root_c.append(tf.tanh(tf.matmul(c_no_leaf[idx], self.root_weight_u) + self.root_bias_u)) 148 | 149 | root_c = tf.reduce_sum(tf.concat(root_c, axis=0), axis=0, keepdims=True) 150 | 151 | return root_c, root_c 152 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import csv 2 | import numpy as np 3 | import pickle 4 | import random 5 | import matplotlib.pyplot as plt 6 | import tensorflow as tf 7 | 8 | tf.set_random_seed(1234) 9 | from data_generator import DataGenerator 10 | from maml import MAML 11 | from tensorflow.python.platform import flags 12 | 13 | FLAGS = flags.FLAGS 14 | 15 | ## Dataset/method options 16 | flags.DEFINE_string('datasource', 'sinusoid', 'sinusoid or omniglot or miniimagenet or mixture or multidataset or multidataset_leave_one_out') 17 | flags.DEFINE_integer('leave_one_out_id',-1,'id of leave one out') 18 | flags.DEFINE_integer('test_dataset', -1, 19 | 'which dataset to be test: 0: bird, 1: texture, 2: aircraft, 3: fungi, -1 is test all') 20 | flags.DEFINE_integer('num_classes', 5, 'number of classes used in classification (e.g. 
5-way classification).') 21 | flags.DEFINE_integer('num_test_task', 1000, 'number of test tasks.') 22 | flags.DEFINE_integer('test_epoch', -1, 'test epoch, only work when test start') 23 | 24 | ## Training options 25 | flags.DEFINE_integer('pretrain_iterations', 0, 'number of pre-training iterations.') 26 | flags.DEFINE_integer('metatrain_iterations', 15000, 27 | 'number of metatraining iterations.') # 15k for omniglot, 50k for sinusoid 28 | flags.DEFINE_integer('meta_batch_size', 25, 'number of tasks sampled per meta-update') 29 | flags.DEFINE_float('meta_lr', 0.001, 'the base learning rate of the generator') 30 | flags.DEFINE_integer('update_batch_size', 5, 31 | 'number of examples used for inner gradient update (K for K-shot learning).') 32 | flags.DEFINE_integer('update_batch_size_eval', 10, 33 | 'number of examples used for inner gradient test (K for K-shot learning).') 34 | flags.DEFINE_float('update_lr', 1e-3, 'step size alpha for inner gradient update.') # 0.1 for omniglot 35 | flags.DEFINE_integer('num_updates', 1, 'number of inner gradient updates during training.') 36 | flags.DEFINE_integer('num_groups', 1, 'number of groups.') 37 | flags.DEFINE_integer('fix_embedding_sample', -1, 38 | 'if the fix_embedding sample is -1, all samples are used for embedding. 
Otherwise, specific samples are used') 39 | ## Model options 40 | flags.DEFINE_string('norm', 'batch_norm', 'batch_norm, layer_norm, or None') 41 | flags.DEFINE_integer('hidden_dim', 40, 'output dimension of task embedding') 42 | flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omiglot.') 43 | flags.DEFINE_bool('conv', True, 'whether or not to use a convolutional network, only applicable in some cases') 44 | flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions') 45 | flags.DEFINE_bool('stop_grad', False, 'if True, do not use second derivatives in meta-optimization (for speed)') 46 | flags.DEFINE_float('emb_loss_weight', 0.0, 'the weight of autoencoder') 47 | flags.DEFINE_string('emb_type', 'sigmoid', 'sigmoid') 48 | flags.DEFINE_bool('no_val', False, 'if true, there are no validation set of Omniglot dataset') 49 | flags.DEFINE_integer('tree_type', 1, 'select the tree type: 1 or 2') 50 | flags.DEFINE_integer('task_embedding_num_filters', 32, 'number of filters for task embedding') 51 | flags.DEFINE_string('task_embedding_type', 'rnn', 'rnn or mean') 52 | 53 | ## clustering information 54 | flags.DEFINE_integer('cluster_layer_0', 4, 'number of clusters in the first layer') 55 | flags.DEFINE_integer('cluster_layer_1', 2, 'number of clusters in the second layer') 56 | flags.DEFINE_integer('cluster_layer_2', -1, 'number of clusters in the third layer') 57 | 58 | ## Logging, saving, and testing options 59 | flags.DEFINE_bool('log', True, 'if false, do not log summaries, for debugging code.') 60 | flags.DEFINE_string('logdir', '/tmp/data', 'directory for summaries and checkpoints.') 61 | flags.DEFINE_string('datadir', '/home/huaxiuyao/Data/', 'directory for datasets.') 62 | flags.DEFINE_bool('resume', True, 'resume training if there is a model available') 63 | flags.DEFINE_bool('train', True, 'True to train, False to test.') 64 | flags.DEFINE_bool('test_set', False, 
'Set to true to test on the the test set, False for the validation set.') 65 | flags.DEFINE_integer('train_update_batch_size', -1, 66 | 'number of examples used for gradient update during training (use if you want to test with a different number).') 67 | flags.DEFINE_float('train_update_lr', -1, 68 | 'value of inner gradient step step during training. (use if you want to test with a different value)') # 0.1 for omniglot 69 | 70 | 71 | def train(model, saver, sess, exp_string, data_generator, resume_itr=0): 72 | SUMMARY_INTERVAL = 100 73 | SAVE_INTERVAL = 1000 74 | if FLAGS.datasource in ['sinusoid', 'mixture']: 75 | PRINT_INTERVAL = 1000 76 | TEST_PRINT_INTERVAL = PRINT_INTERVAL * 5 77 | else: 78 | PRINT_INTERVAL = 100 79 | TEST_PRINT_INTERVAL = PRINT_INTERVAL * 10 80 | 81 | if FLAGS.log: 82 | train_writer = tf.summary.FileWriter(FLAGS.logdir + '/' + exp_string, sess.graph) 83 | print('Done initializing, starting training.') 84 | 85 | prelosses, postlosses, embedlosses = [], [], [] 86 | 87 | num_classes = data_generator.num_classes # for classification, 1 otherwise 88 | 89 | for itr in range(resume_itr, FLAGS.pretrain_iterations + FLAGS.metatrain_iterations): 90 | feed_dict = {} 91 | if 'generate' in dir(data_generator): 92 | if FLAGS.datasource == 'sinusoid': 93 | batch_x, batch_y, amp, phase = data_generator.generate() 94 | elif FLAGS.datasource == 'mixture': 95 | batch_x, batch_y, para_func, sel_set = data_generator.generate() 96 | 97 | inputa = batch_x[:, :num_classes * FLAGS.update_batch_size, :] 98 | labela = batch_y[:, :num_classes * FLAGS.update_batch_size, :] 99 | inputb = batch_x[:, num_classes * FLAGS.update_batch_size:, :] 100 | labelb = batch_y[:, num_classes * FLAGS.update_batch_size:, :] 101 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb} 102 | 103 | if itr < FLAGS.pretrain_iterations: 104 | input_tensors = [model.pretrain_op] 105 | else: 106 | input_tensors = [model.metatrain_op] 107 | 108 | 
input_tensors.extend( 109 | [model.summ_op, model.total_embed_loss, model.total_loss1, model.total_losses2[FLAGS.num_updates - 1]]) 110 | if model.classification: 111 | input_tensors.extend([model.total_accuracy1, model.total_accuracies2[FLAGS.num_updates - 1]]) 112 | 113 | result = sess.run(input_tensors, feed_dict) 114 | 115 | if np.isnan(result[-2]) == False and np.isnan(result[-2]) == False and np.isnan(result[2]) == False: 116 | prelosses.append(result[-2]) 117 | postlosses.append(result[-1]) 118 | embedlosses.append(result[2]) 119 | 120 | if itr % SUMMARY_INTERVAL == 0: 121 | if FLAGS.log: 122 | train_writer.add_summary(result[1], itr) 123 | 124 | if (itr != 0) and itr % PRINT_INTERVAL == 0: 125 | if itr < FLAGS.pretrain_iterations: 126 | print_str = 'Pretrain Iteration ' + str(itr) 127 | else: 128 | print_str = 'Iteration ' + str(itr - FLAGS.pretrain_iterations) 129 | std = np.std(postlosses, 0) 130 | ci95 = 1.96 * std / np.sqrt(PRINT_INTERVAL) 131 | print_str += ': preloss: ' + str(np.mean(prelosses)) + ', postloss: ' + str( 132 | np.mean(postlosses)) + ', embedding loss: ' + str(np.mean(embedlosses)) + ', confidence: ' + str(ci95) 133 | print(print_str) 134 | prelosses, postlosses, embedlosses = [], [], [] 135 | 136 | if (itr != 0) and itr % SAVE_INTERVAL == 0: 137 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr)) 138 | 139 | if (itr != 0) and itr % TEST_PRINT_INTERVAL == 0 and ( 140 | FLAGS.datasource not in ['sinusoid', 'mixture']): 141 | if 'generate' not in dir(data_generator): 142 | feed_dict = {} 143 | if model.classification: 144 | input_tensors = [model.metaval_total_accuracy1, 145 | model.metaval_total_accuracies2[FLAGS.num_updates - 1], model.summ_op] 146 | else: 147 | input_tensors = [model.metaval_total_loss1, model.metaval_total_losses2[FLAGS.num_updates - 1], 148 | model.summ_op] 149 | else: 150 | if FLAGS.datasource == 'sinusoid': 151 | batch_x, batch_y, amp, phase = data_generator.generate(train=False) 152 | elif 
FLAGS.datasource == 'mixture': 153 | batch_x, batch_y, para_func = data_generator.generate(train=False) 154 | inputa = batch_x[:, :num_classes * FLAGS.update_batch_size, :] 155 | inputb = batch_x[:, num_classes * FLAGS.update_batch_size:, :] 156 | labela = batch_y[:, :num_classes * FLAGS.update_batch_size, :] 157 | labelb = batch_y[:, num_classes * FLAGS.update_batch_size:, :] 158 | 159 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb, 160 | model.meta_lr: 0.0} 161 | if model.classification: 162 | input_tensors = [model.total_accuracy1, model.total_accuracies2[FLAGS.num_updates - 1]] 163 | else: 164 | input_tensors = [model.total_loss1, model.total_losses2[FLAGS.num_updates - 1]] 165 | 166 | result = sess.run(input_tensors, feed_dict) 167 | print('Validation results: ' + str(result[0]) + ', ' + str(result[1])) 168 | 169 | saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr)) 170 | 171 | 172 | if FLAGS.datasource in ['multidataset', 'multidataset_leave_one_out', 'mixture']: 173 | NUM_TEST_POINTS = FLAGS.num_test_task 174 | 175 | 176 | def test(model, saver, sess, exp_string, data_generator, test_num_updates=None): 177 | num_classes = data_generator.num_classes 178 | 179 | np.random.seed(1) 180 | random.seed(1) 181 | 182 | metaval_accuracies = [] 183 | print(NUM_TEST_POINTS) 184 | for test_itr in range(NUM_TEST_POINTS): 185 | 186 | if 'generate' not in dir(data_generator): 187 | feed_dict = {} 188 | feed_dict = {model.meta_lr: 0.0} 189 | else: 190 | if FLAGS.datasource == 'sinusoid': 191 | batch_x, batch_y, amp, phase = data_generator.generate(train=False) 192 | elif FLAGS.datasource == 'mixture': 193 | batch_x, batch_y, para_func, sel_set = data_generator.generate(train=False) 194 | 195 | inputa = batch_x[:, :num_classes * FLAGS.update_batch_size, :] 196 | inputb = batch_x[:, num_classes * FLAGS.update_batch_size:, :] 197 | labela = batch_y[:, :num_classes * FLAGS.update_batch_size, :] 198 | 
labelb = batch_y[:, num_classes * FLAGS.update_batch_size:, :] 199 | 200 | feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb, 201 | model.meta_lr: 0.0} 202 | 203 | if model.classification: 204 | result = sess.run([model.metaval_total_accuracy1] + model.metaval_total_accuracies2, feed_dict) 205 | else: # this is for sinusoid 206 | result = sess.run([model.total_loss1] + model.total_losses2, feed_dict) 207 | 208 | metaval_accuracies.append(result) 209 | 210 | metaval_accuracies = np.array(metaval_accuracies) 211 | means = np.mean(metaval_accuracies, 0) 212 | stds = np.std(metaval_accuracies, 0) 213 | ci95 = 1.96 * stds / np.sqrt(NUM_TEST_POINTS) 214 | 215 | print('Mean validation accuracy/loss, stddev, and confidence intervals') 216 | print((means, stds, ci95)) 217 | 218 | 219 | def main(): 220 | 221 | if FLAGS.datasource == 'multidataset_leave_one_out': 222 | assert FLAGS.leave_one_out_id > -1 223 | 224 | sess = tf.InteractiveSession() 225 | if FLAGS.datasource in ['sinusoid', 'mixture']: 226 | if FLAGS.train: 227 | test_num_updates = 1 228 | else: 229 | test_num_updates = 10 230 | else: 231 | if FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']: 232 | if FLAGS.train == True: 233 | test_num_updates = 1 # eval on at least one update during training 234 | else: 235 | test_num_updates = 10 236 | else: 237 | test_num_updates = 10 238 | 239 | if FLAGS.train == False: 240 | orig_meta_batch_size = FLAGS.meta_batch_size 241 | # always use meta batch size of 1 when testing. 
242 | FLAGS.meta_batch_size = 1 243 | 244 | if FLAGS.datasource in ['sinusoid', 'mixture']: 245 | data_generator = DataGenerator(FLAGS.update_batch_size + FLAGS.update_batch_size_eval, FLAGS.meta_batch_size) 246 | else: 247 | if FLAGS.metatrain_iterations == 0 and FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']: 248 | assert FLAGS.meta_batch_size == 1 249 | assert FLAGS.update_batch_size == 1 250 | data_generator = DataGenerator(1, FLAGS.meta_batch_size) # only use one datapoint, 251 | else: 252 | if FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']: 253 | if FLAGS.train: 254 | data_generator = DataGenerator(FLAGS.update_batch_size + 15, 255 | FLAGS.meta_batch_size) # only use one datapoint for testing to save memory 256 | else: 257 | data_generator = DataGenerator(FLAGS.update_batch_size * 2, 258 | FLAGS.meta_batch_size) # only use one datapoint for testing to save memory 259 | else: 260 | data_generator = DataGenerator(FLAGS.update_batch_size * 2, 261 | FLAGS.meta_batch_size) # only use one datapoint for testing to save memory 262 | 263 | dim_output = data_generator.dim_output 264 | dim_input = data_generator.dim_input 265 | 266 | if FLAGS.datasource in ['miniimagenet', 'omniglot', 'multidataset', 'multidataset_leave_one_out']: 267 | tf_data_load = True 268 | num_classes = data_generator.num_classes 269 | 270 | if FLAGS.train: # only construct training model if needed 271 | random.seed(5) 272 | if FLAGS.datasource in ['miniimagenet', 'omniglot']: 273 | image_tensor, label_tensor = data_generator.make_data_tensor() 274 | elif FLAGS.datasource == 'multidataset': 275 | image_tensor, label_tensor = data_generator.make_data_tensor_multidataset() 276 | elif FLAGS.datasource == 'multidataset_leave_one_out': 277 | image_tensor, label_tensor = data_generator.make_data_tensor_multidataset_leave_one_out() 278 | inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) 279 | 
inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) 280 | labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) 281 | labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) 282 | input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb} 283 | 284 | random.seed(6) 285 | if FLAGS.datasource in ['miniimagenet', 'omniglot']: 286 | image_tensor, label_tensor = data_generator.make_data_tensor(train=False) 287 | elif FLAGS.datasource == 'multidataset': 288 | image_tensor, label_tensor = data_generator.make_data_tensor_multidataset(train=False) 289 | elif FLAGS.datasource == 'multidataset_leave_one_out': 290 | image_tensor, label_tensor = data_generator.make_data_tensor_multidataset_leave_one_out(train=False) 291 | inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) 292 | inputb = tf.slice(image_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) 293 | labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_classes * FLAGS.update_batch_size, -1]) 294 | labelb = tf.slice(label_tensor, [0, num_classes * FLAGS.update_batch_size, 0], [-1, -1, -1]) 295 | metaval_input_tensors = {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb} 296 | else: 297 | tf_data_load = False 298 | input_tensors = None 299 | 300 | model = MAML(sess, dim_input, dim_output, test_num_updates=test_num_updates) 301 | 302 | if FLAGS.train or not tf_data_load: 303 | model.construct_model(input_tensors=input_tensors, prefix='metatrain_') 304 | if tf_data_load: 305 | model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_') 306 | model.summ_op = tf.summary.merge_all() 307 | saver = loader = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10) 308 | 309 | if FLAGS.train == False: 310 | # change to original meta batch size when loading model. 
311 | FLAGS.meta_batch_size = orig_meta_batch_size 312 | 313 | if FLAGS.train_update_batch_size == -1: 314 | FLAGS.train_update_batch_size = FLAGS.update_batch_size 315 | if FLAGS.train_update_lr == -1: 316 | FLAGS.train_update_lr = FLAGS.update_lr 317 | 318 | exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(FLAGS.meta_batch_size) + '.ubs_' + str( 319 | FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str( 320 | FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr) + '.emb_loss_weight' + str( 321 | FLAGS.emb_loss_weight) + '.num_groups' + str(FLAGS.num_groups) + '.emb_type' + str( 322 | FLAGS.emb_type) + '.hidden_dim' + str(FLAGS.hidden_dim) 323 | 324 | if FLAGS.num_filters != 64: 325 | exp_string += 'hidden' + str(FLAGS.num_filters) 326 | if FLAGS.max_pool: 327 | exp_string += 'maxpool' 328 | if FLAGS.stop_grad: 329 | exp_string += 'stopgrad' 330 | if FLAGS.norm == 'batch_norm': 331 | exp_string += 'batchnorm' 332 | elif FLAGS.norm == 'layer_norm': 333 | exp_string += 'layernorm' 334 | elif FLAGS.norm == 'None': 335 | exp_string += 'nonorm' 336 | else: 337 | print('Norm setting not recognized.') 338 | 339 | resume_itr = 0 340 | model_file = None 341 | 342 | tf.global_variables_initializer().run() 343 | tf.train.start_queue_runners() 344 | 345 | print(exp_string) 346 | 347 | if FLAGS.resume or not FLAGS.train: 348 | if FLAGS.train == True: 349 | model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string) 350 | else: 351 | print(FLAGS.test_epoch) 352 | model_file = '{0}/{2}/model{1}'.format(FLAGS.logdir, FLAGS.test_epoch, exp_string) 353 | if model_file: 354 | ind1 = model_file.index('model') 355 | resume_itr = int(model_file[ind1 + 5:]) 356 | print("Restoring model weights from " + model_file) 357 | saver.restore(sess, model_file) 358 | 359 | if FLAGS.train: 360 | train(model, saver, sess, exp_string, data_generator, resume_itr) 361 | else: 362 | test(model, saver, sess, exp_string, data_generator, 
test_num_updates) 363 | 364 | 365 | if __name__ == "__main__": 366 | main() 367 | -------------------------------------------------------------------------------- /maml.py: --------------------------------------------------------------------------------
from __future__ import print_function

import sys

import numpy as np
import tensorflow as tf

from image_embedding import ImageEmbedding
from lstm_tree import TreeLSTM

# special_grads registers a gradient for the MaxPoolGrad op so the graph can be
# differentiated through max pooling a second time (MAML meta-gradient).
# Registering raises KeyError when TensorFlow already defines one; then we only warn.
try:
    import special_grads
except KeyError as e:
    print('WARN: Cannot define MaxPoolGrad, likely already defined for this version of tensorflow: %s' % e,
          file=sys.stderr)

from tensorflow.python.platform import flags
from utils import mse, xent, conv_block, normalize
from task_embedding import LSTMAutoencoder, MeanAutoencoder

FLAGS = flags.FLAGS


class MAML:
    """MAML learner whose shared weights are gated per task (HSML).

    For each task, a task embedding (``self.lstmae``) and the TreeLSTM output
    computed from it (``self.tree``) are concatenated and mapped to sigmoid
    gates. The gates multiply the shared weights element-wise, and the MAML
    inner-loop gradient steps start from these gated weights. All
    configuration is read from ``FLAGS``.
    """

    def __init__(self, sess, dim_input=1, dim_output=1, test_num_updates=5):
        """ must call construct_model() after initializing MAML!

        Args:
            sess: TensorFlow session, stored as ``self.sess``.
            dim_input: size of one flattened input example (for image data
                this is img_size * img_size * channels).
            dim_output: width of the model's final layer.
            test_num_updates: inner-loop steps wanted at evaluation; the graph
                is built with ``max(test_num_updates, FLAGS.num_updates)`` steps.

        Raises:
            ValueError: if ``FLAGS.datasource`` is not recognised.
        """
        self.dim_input = dim_input
        self.dim_output = dim_output
        self.update_lr = FLAGS.update_lr
        # Defaults to FLAGS.meta_lr but can be overridden through feed_dict.
        self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ())
        self.classification = False
        self.test_num_updates = test_num_updates
        self.sess = sess
        # Task-embedding network.
        # NOTE(review): any other task_embedding_type leaves self.lstmae unset,
        # which only fails later inside construct_model().
        if FLAGS.task_embedding_type == 'rnn':
            self.lstmae = LSTMAutoencoder(hidden_num=FLAGS.hidden_dim)
        elif FLAGS.task_embedding_type == 'mean':
            self.lstmae = MeanAutoencoder(hidden_num=FLAGS.hidden_dim)
        self.tree = TreeLSTM(input_dim=FLAGS.hidden_dim, tree_hidden_dim=FLAGS.hidden_dim)
        if FLAGS.datasource in ['sinusoid', 'mixture']:
            # Regression: MLP with two hidden layers of 40 units, MSE loss.
            self.dim_hidden = [40, 40]
            self.loss_func = mse
            self.forward = self.forward_fc
            self.construct_weights = self.construct_fc_weights
        elif FLAGS.datasource in ['omniglot', 'miniimagenet', 'multidataset', 'multidataset_leave_one_out']:
            # Classification: softmax cross-entropy loss.
            self.loss_func = xent
            self.classification = True
            if FLAGS.conv:
                # For the conv net dim_hidden is a single int: filters per layer.
                self.dim_hidden = FLAGS.num_filters
                self.forward = self.forward_conv
                self.construct_weights = self.construct_conv_weights
            else:
                self.dim_hidden = [256, 128, 64, 64]
                self.forward = self.forward_fc
                self.construct_weights = self.construct_fc_weights
            if FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']:
                self.channels = 3
            else:
                self.channels = 1
            # Inputs are flattened square images: recover the side length.
            self.img_size = int(np.sqrt(self.dim_input / self.channels))
            # Separate conv encoder, applied in construct_model() to the support
            # images (inputa) to build the task embedding.
            self.image_embed = ImageEmbedding(hidden_num=FLAGS.task_embedding_num_filters, channels=self.channels,
                                              conv_initializer=tf.truncated_normal_initializer(stddev=0.04))
        else:
            raise ValueError('Unrecognized data source.')

    def construct_model(self, input_tensors=None, prefix='metatrain_'):
        """Build the meta-learning graph for one split.

        Args:
            input_tensors: dict with keys 'inputa', 'inputb', 'labela', 'labelb'.
                Each tensor has a leading meta-batch (task) axis. The 'a'
                tensors feed the inner-loop updates and the 'b' tensors the
                post-update evaluation. If None, float32 placeholders are
                created; their shapes are fixed only for 'sinusoid'/'mixture'.
            prefix: if it contains 'train', the training metrics, ``pretrain_op``
                and (when FLAGS.metatrain_iterations > 0) ``metatrain_op`` are
                built. Otherwise the ``metaval_*`` metrics are built. The
                prefix is also prepended to the summary names.
        """
        # a: training data for inner gradient, b: test data for meta gradient
        if input_tensors is None:
            if FLAGS.datasource in ['sinusoid', 'mixture']:
                self.inputa = tf.placeholder(tf.float32,
                                             shape=(FLAGS.meta_batch_size, FLAGS.update_batch_size, 1))
                self.inputb = tf.placeholder(tf.float32,
                                             shape=(FLAGS.meta_batch_size, FLAGS.update_batch_size_eval, 1))
                self.labela = tf.placeholder(tf.float32, shape=(FLAGS.meta_batch_size, FLAGS.update_batch_size, 1))
                self.labelb = tf.placeholder(tf.float32,
                                             shape=(FLAGS.meta_batch_size, FLAGS.update_batch_size_eval, 1))
            else:
                self.inputa = tf.placeholder(tf.float32)
                self.inputb = tf.placeholder(tf.float32)
                self.labela = tf.placeholder(tf.float32)
                self.labelb = tf.placeholder(tf.float32)
        else:
            self.inputa = input_tensors['inputa']
            self.inputb = input_tensors['inputb']
            self.labela = input_tensors['labela']
            self.labelb = input_tensors['labelb']
        # tf.summary.scalar('lr', self.update_lr)

        with tf.variable_scope('model', reuse=None) as training_scope:
            # A second call (e.g. the metaval graph) reuses the weights created
            # by the first call instead of creating new ones.
            if 'weights' in dir(self):
                training_scope.reuse_variables()
                weights = self.weights
            else:
                # Define the weights
                self.weights = weights = self.construct_weights()

            # outputbs[i] and lossesb[i] is the output and loss after i+1 gradient updates
            # (these initial lists are overwritten by the map_fn result below).
            lossesa, outputas, lossesb, outputbs, emb_loss = [], [], [], [], []
            accuraciesa, accuraciesb = [], []
            num_updates = max(self.test_num_updates, FLAGS.num_updates)
            outputbs = [[]] * num_updates
            lossesb = [[]] * num_updates
            accuraciesb = [[]] * num_updates

            def task_metalearn(inp, reuse=True):
                """ Perform gradient descent for one task in the meta-batch.

                Steps: embed the task, gate the shared weights with
                task-specific sigmoid factors, then run ``num_updates`` inner
                gradient steps on (inputa, labela), evaluating on
                (inputb, labelb) after each step.

                Returns a list ordered as ``out_dtype`` below:
                [emb_loss, outputa, outputbs, lossa, lossesb] and, for
                classification, [accuracya, accuraciesb].
                """
                inputa, inputb, labela, labelb = inp
                if FLAGS.datasource in ['sinusoid', 'mixture']:
                    # Task embedding input: concatenated (x, y) support pairs.
                    input_task_emb = tf.concat((inputa, labela), axis=-1)
                elif FLAGS.datasource in ['miniimagenet', 'omniglot', 'multidataset', 'multidataset_leave_one_out']:
                    # Optionally embed only the first fix_embedding_sample support examples.
                    # NOTE(review): tf.one_hot with depth=1 yields 1 only where the value is 0.
                    # Assuming labela is one-hot (it is argmax-ed below), this encodes the
                    # complement of the label vector rather than the label itself. Confirm
                    # this is intended before changing it, since the embedding input width
                    # depends on it.
                    if FLAGS.fix_embedding_sample != -1:
                        input_task_emb = self.image_embed.model(tf.reshape(inputa[:FLAGS.fix_embedding_sample],
                                                                           [-1, self.img_size, self.img_size,
                                                                            self.channels]))
                        one_hot_labela = tf.squeeze(
                            tf.one_hot(tf.to_int32(labela[:FLAGS.fix_embedding_sample]), depth=1, axis=-1))
                    else:
                        input_task_emb = self.image_embed.model(tf.reshape(inputa,
                                                                           [-1, self.img_size, self.img_size,
                                                                            self.channels]))
                        one_hot_labela = tf.squeeze(
                            tf.one_hot(tf.to_int32(labela), depth=1, axis=-1))
                    input_task_emb = tf.concat((input_task_emb, one_hot_labela), axis=-1)

                # Task embedding plus its auxiliary reconstruction loss.
                task_embed_vec, task_emb_loss = self.lstmae.model(input_task_emb)

                # Second output of the TreeLSTM, concatenated with the task embedding.
                _, meta_knowledge_h = self.tree.model(task_embed_vec)

                task_enhanced_emb_vec = tf.concat([task_embed_vec, meta_knowledge_h], axis=1)

                # One sigmoid gate per weight element: a dense layer per weight
                # (named by its key, reused across calls) whose output is reshaped
                # to that weight's shape. Needs fully defined static weight shapes.
                with tf.variable_scope('task_specific_mapping', reuse=tf.AUTO_REUSE):
                    eta = []
                    for key in weights.keys():
                        weight_size = np.prod(weights[key].get_shape().as_list())
                        eta.append(tf.reshape(
                            tf.layers.dense(task_enhanced_emb_vec, weight_size, activation=tf.nn.sigmoid,
                                            name='eta_{}'.format(key)), tf.shape(weights[key])))
                    eta = dict(zip(weights.keys(), eta))

                # Task-specific initialisation for the inner loop.
                task_weights = dict(zip(weights.keys(), [weights[key] * eta[key] for key in weights.keys()]))

                task_outputbs, task_lossesb = [], []

                if self.classification:
                    task_accuraciesb = []

                task_outputa = self.forward(inputa, task_weights, reuse=reuse)
                task_lossa = self.loss_func(task_outputa, labela)

                # First inner step, starting from the gated weights.
                grads = tf.gradients(task_lossa, list(task_weights.values()))
                if FLAGS.stop_grad:
                    # First-order approximation: no second derivatives.
                    grads = [tf.stop_gradient(grad) for grad in grads]
                gradients = dict(zip(task_weights.keys(), grads))
                fast_weights = dict(
                    zip(task_weights.keys(),
                        [task_weights[key] - self.update_lr * gradients[key] for key in task_weights.keys()]))
                output = self.forward(inputb, fast_weights, reuse=True)
                task_outputbs.append(output)
                task_lossesb.append(self.loss_func(output, labelb))
                # Remaining num_updates - 1 inner steps.
                for j in range(num_updates - 1):
                    loss = self.loss_func(self.forward(inputa, fast_weights, reuse=True), labela)
                    grads = tf.gradients(loss, list(fast_weights.values()))
                    if FLAGS.stop_grad:
                        grads = [tf.stop_gradient(grad) for grad in grads]
                    gradients = dict(zip(fast_weights.keys(), grads))
                    fast_weights = dict(zip(fast_weights.keys(),
                                            [fast_weights[key] - self.update_lr * gradients[key] for key in
                                             fast_weights.keys()]))
                    output = self.forward(inputb, fast_weights, reuse=True)
                    task_outputbs.append(output)
                    task_lossesb.append(self.loss_func(output, labelb))

                task_output = [task_emb_loss, task_outputa, task_outputbs, task_lossa, task_lossesb]

                if self.classification:
                    task_accuracya = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputa), 1),
                                                                 tf.argmax(labela, 1))
                    for j in range(num_updates):
                        task_accuraciesb.append(
                            tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputbs[j]), 1),
                                                        tf.argmax(labelb, 1)))
                    task_output.extend([task_accuracya, task_accuraciesb])

                return task_output

            if FLAGS.norm != 'None':
                # to initialize the batch norm vars, might want to combine this, and not run idx 0 twice.
                unused = task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False)

            # Must mirror the structure of task_metalearn's return value.
            out_dtype = [tf.float32, tf.float32, [tf.float32] * num_updates, tf.float32, [tf.float32] * num_updates]
            if self.classification:
                out_dtype.extend([tf.float32, [tf.float32] * num_updates])
            result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb),
                               dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size)
            if self.classification:
                emb_loss, outputas, outputbs, lossesa, lossesb, accuraciesa, accuraciesb = result
            else:
                emb_loss, outputas, outputbs, lossesa, lossesb = result

        ## Performance & Optimization
        if 'train' in prefix:
            # Metrics are summed over tasks and divided by meta_batch_size.
            self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
            self.total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j
                                                  in range(num_updates)]
            self.total_embed_loss = total_embed_loss = tf.reduce_sum(emb_loss) / tf.to_float(FLAGS.meta_batch_size)
            # after the map_fn
            self.outputas, self.outputbs = outputas, outputbs
            if self.classification:
                self.total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size)
                self.total_accuracies2 = total_accuracies2 = [
                    tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
            self.pretrain_op = tf.train.AdamOptimizer(self.meta_lr).minimize(total_loss1)

            if FLAGS.metatrain_iterations > 0:
                # Meta-objective: query loss after FLAGS.num_updates inner steps,
                # plus the weighted task-embedding reconstruction loss.
                optimizer = tf.train.AdamOptimizer(self.meta_lr)
                self.gvs = gvs = optimizer.compute_gradients(
                    self.total_losses2[FLAGS.num_updates - 1] + FLAGS.emb_loss_weight * self.total_embed_loss)
                # NOTE(review): tf.clip_by_value fails on a None gradient, i.e. a
                # trainable variable not reached by this loss.
                if FLAGS.task_embedding_type == 'mean':
                    gvs = [(tf.clip_by_value(grad, -3, 3), var) for grad, var in gvs]
                if FLAGS.datasource == 'miniimagenet':
                    gvs = [(tf.clip_by_value(grad, -10, 10), var) for grad, var in gvs]
                self.metatrain_op = optimizer.apply_gradients(gvs)
        else:
            self.metaval_total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
            self.metaval_total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size)
                                                          for j in range(num_updates)]
            if self.classification:
                self.metaval_total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(
                    FLAGS.meta_batch_size)
                self.metaval_total_accuracies2 = total_accuracies2 = [
                    tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]

        ## Summaries
        tf.summary.scalar(prefix + 'Pre-update loss', total_loss1)
        if self.classification:
            tf.summary.scalar(prefix + 'Pre-update accuracy', total_accuracy1)

        for j in range(num_updates):
            tf.summary.scalar(prefix + 'Post-update loss, step ' + str(j + 1), total_losses2[j])
            if self.classification:
                tf.summary.scalar(prefix + 'Post-update accuracy, step ' + str(j + 1), total_accuracies2[j])

    ### Network construction functions (fc networks and conv networks)
    def construct_fc_weights(self):
        """Create MLP weights w1..w{L+1} / b1..b{L+1} for hidden sizes self.dim_hidden.

        Weights use truncated normal init (stddev 0.01); biases start at zero.
        """
        weights = {}
        weights['w1'] = tf.Variable(tf.truncated_normal([self.dim_input, self.dim_hidden[0]], stddev=0.01))
        weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden[0]]))
        for i in range(1, len(self.dim_hidden)):
            weights['w' + str(i + 1)] = tf.Variable(
                tf.truncated_normal([self.dim_hidden[i - 1], self.dim_hidden[i]], stddev=0.01))
            weights['b' + str(i + 1)] = tf.Variable(tf.zeros([self.dim_hidden[i]]))
        weights['w' + str(len(self.dim_hidden) + 1)] = tf.Variable(
            tf.truncated_normal([self.dim_hidden[-1], self.dim_output], stddev=0.01))
        weights['b' + str(len(self.dim_hidden) + 1)] = tf.Variable(tf.zeros([self.dim_output]))
        return weights

    def forward_fc(self, inp, weights, reuse=False):
        """Apply the MLP: normalized ReLU hidden layers, then a linear output layer."""
        # NOTE: normalization scopes are '0', then '2', '3', ... ('1' is never used).
        # Renaming them would rename norm variables stored in checkpoints.
        hidden = normalize(tf.matmul(inp, weights['w1']) + weights['b1'], activation=tf.nn.relu, reuse=reuse, scope='0')
        for i in range(1, len(self.dim_hidden)):
            hidden = normalize(tf.matmul(hidden, weights['w' + str(i + 1)]) + weights['b' + str(i + 1)],
                               activation=tf.nn.relu, reuse=reuse, scope=str(i + 1))
        return tf.matmul(hidden, weights['w' + str(len(self.dim_hidden) + 1)]) + weights[
            'b' + str(len(self.dim_hidden) + 1)]

    def construct_conv_weights(self):
        """Create four 3x3 conv layers (self.dim_hidden filters each) and an output layer."""
        weights = {}

        dtype = tf.float32
        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
        fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
        k = 3

        weights['conv1'] = tf.get_variable('conv1', [k, k, self.channels, self.dim_hidden],
                                           initializer=conv_initializer, dtype=dtype)
        weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv2'] = tf.get_variable('conv2', [k, k, self.dim_hidden, self.dim_hidden],
                                           initializer=conv_initializer, dtype=dtype)
        weights['b2'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv3'] = tf.get_variable('conv3', [k, k, self.dim_hidden, self.dim_hidden],
                                           initializer=conv_initializer, dtype=dtype)
        weights['b3'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv4'] = tf.get_variable('conv4', [k, k, self.dim_hidden, self.dim_hidden],
                                           initializer=conv_initializer, dtype=dtype)
        weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden]))
        if FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']:
            # assumes max pooling
            # 5 * 5 matches 84x84 inputs after four 2x2 VALID poolings
            # (84 -> 42 -> 21 -> 10 -> 5); other input sizes would need a change here.
            weights['w5'] = tf.get_variable('w5', [self.dim_hidden * 5 * 5, self.dim_output],
                                            initializer=fc_initializer)
            weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
        else:
            # Output layer applied after spatial mean-pooling (see forward_conv).
            weights['w5'] = tf.Variable(tf.random_normal([self.dim_hidden, self.dim_output]), name='w5')
            weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
        return weights

    def forward_conv(self, inp, weights, reuse=False, scope=''):
        """Apply the conv net to flattened images ``inp``; returns unnormalized logits."""
        channels = self.channels
        inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels])

        hidden1 = conv_block(inp, weights['conv1'], weights['b1'], reuse, scope + '0')
        hidden2 = conv_block(hidden1, weights['conv2'], weights['b2'], reuse, scope + '1')
        hidden3 = conv_block(hidden2, weights['conv3'], weights['b3'], reuse, scope + '2')
        hidden4 = conv_block(hidden3, weights['conv4'], weights['b4'], reuse, scope + '3')
        if FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']:
            # Flatten the final feature map (needs a static spatial shape).
            hidden4 = tf.reshape(hidden4, [-1, np.prod([int(dim) for dim in hidden4.get_shape()[1:]])])
        else:
            # Average over the two spatial axes.
            hidden4 = tf.reduce_mean(hidden4, [1, 2])

        return tf.matmul(hidden4, weights['w5']) + weights['b5']
-------------------------------------------------------------------------------- /multidataset_bash/HSML_multidataset_1shot.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --datasource=multidataset --metatrain_iterations=60000 --meta_batch_size=4 --update_batch_size=1 --update_lr=0.01 --num_updates=5 --num_classes=5 --logdir=../Check_point/multidataset_1shot/ --num_filters=32 --max_pool=True --hidden_dim=128 --emb_loss_weight=0.01 3 | 4 | python main.py --datasource=multidataset --metatrain_iterations=60000 --meta_batch_size=4 --update_batch_size=1 --update_lr=0.01 --num_updates=5 --num_classes=5 --logdir=../Check_point/multidataset_1shot/ --num_filters=32 --max_pool=True --hidden_dim=128 --emb_loss_weight=0.01 --test_set=True --test_epoch=59000 --train=False --test_dataset=1 5 | -------------------------------------------------------------------------------- /multidataset_bash/HSML_multidataset_5shot.sh:
-------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --datasource=multidataset --metatrain_iterations=50000 --meta_batch_size=4 --update_batch_size=5 --update_lr=0.01 --num_updates=5 --num_classes=5 --logdir=../Check_point/multidataset_5shot/ --num_filters=32 --max_pool=True --hidden_dim=128 --emb_loss_weight=0.01 --fix_embedding_sample=10 3 | 4 | python main.py --datasource=multidataset --metatrain_iterations=50000 --meta_batch_size=4 --update_batch_size=5 --update_lr=0.01 --num_updates=5 --num_classes=5 --logdir=../Check_point/multidataset_5shot/ --num_filters=32 --max_pool=True --hidden_dim=128 --emb_loss_weight=0.01 --test_set=True --test_epoch=49000 --train=False --test_dataset=1 5 | -------------------------------------------------------------------------------- /special_grads.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.framework import ops 2 | from tensorflow.python.ops import array_ops 3 | from tensorflow.python.ops import gen_nn_ops 4 | 5 | @ops.RegisterGradient("MaxPoolGrad") 6 | def _MaxPoolGradGrad(op, grad): 7 | gradient = gen_nn_ops._max_pool_grad(op.inputs[0], op.outputs[0], 8 | grad, op.get_attr("ksize"), op.get_attr("strides"), 9 | padding=op.get_attr("padding"), data_format=op.get_attr("data_format")) 10 | gradgrad1 = array_ops.zeros(shape = array_ops.shape(op.inputs[1]), dtype=gradient.dtype) 11 | gradgrad2 = array_ops.zeros(shape = array_ops.shape(op.inputs[2]), dtype=gradient.dtype) 12 | return (gradient, gradgrad1, gradgrad2) 13 | -------------------------------------------------------------------------------- /task_embedding.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.ops.rnn_cell import GRUCell 3 | 4 | from tensorflow.python.platform import flags 5 | 6 | 7 | FLAGS = flags.FLAGS 8 | import ipdb 9 | 10 | 11 | 
class LSTMAutoencoder(object): 12 | def __init__(self, hidden_num, cell=None, reverse=True, decode_without_input=False): 13 | if cell is None: 14 | self._enc_cell = GRUCell(hidden_num, name='encoder_cell') 15 | self._dec_cell = GRUCell(hidden_num, name='decoder_cell') 16 | else: 17 | self._enc_cell = cell 18 | self._dec_cell = cell 19 | self.reverse = reverse 20 | self.decode_without_input = decode_without_input 21 | self.hidden_num = hidden_num 22 | 23 | if FLAGS.datasource in ['sinusoid', 'mixture']: 24 | self.elem_num_init = 2 25 | self.elem_num=20 26 | 27 | elif FLAGS.datasource in ['miniimagenet', 'omniglot','multidataset', 'multidataset_leave_one_out']: 28 | self.elem_num = FLAGS.num_classes + 64 29 | 30 | self.dec_weight = tf.Variable(tf.truncated_normal([self.hidden_num, 31 | self.elem_num], dtype=tf.float32), name='dec_weight') 32 | self.dec_bias = tf.Variable(tf.constant(0.1, shape=[self.elem_num], 33 | dtype=tf.float32), name='dec_bias') 34 | 35 | def model(self, inputs): 36 | 37 | if FLAGS.datasource in ['sinusoid', 'mixture']: 38 | with tf.variable_scope('first_embedding_sync', reuse=tf.AUTO_REUSE): 39 | inputs = tf.layers.dense(inputs, units=self.elem_num, name='first_embedding_sync_dense') 40 | 41 | inputs = tf.expand_dims(inputs, 0) 42 | 43 | inputs = tf.unstack(inputs, axis=1) 44 | 45 | self.batch_num = FLAGS.meta_batch_size 46 | 47 | with tf.variable_scope('encoder'): 48 | (self.z_codes, self.enc_state) = tf.contrib.rnn.static_rnn(self._enc_cell, inputs, dtype=tf.float32) 49 | 50 | with tf.variable_scope('decoder') as vs: 51 | 52 | if self.decode_without_input: 53 | dec_inputs = [tf.zeros(tf.shape(inputs[0]), dtype=tf.float32) for _ in range(len(inputs))] 54 | (dec_outputs, dec_state) = tf.contrib.rnn.static_rnn(self._dec_cell, dec_inputs, 55 | initial_state=self.enc_state, 56 | dtype=tf.float32) 57 | if self.reverse: 58 | dec_outputs = dec_outputs[::-1] 59 | dec_output_ = tf.transpose(tf.stack(dec_outputs), [1, 0, 2]) 60 | dec_weight_ = 
tf.tile(tf.expand_dims(self.dec_weight, 0), [self.batch_num, 1, 1]) 61 | self.output_ = tf.matmul(dec_weight_, dec_output_) + self.dec_bias 62 | else: 63 | dec_state = self.enc_state 64 | dec_input_ = tf.zeros(tf.shape(inputs[0]), 65 | dtype=tf.float32) 66 | 67 | dec_outputs = [] 68 | for step in range(len(inputs)): 69 | if step > 0: 70 | vs.reuse_variables() 71 | (dec_input_, dec_state) = \ 72 | self._dec_cell(dec_input_, dec_state) 73 | dec_input_ = tf.matmul(dec_input_, self.dec_weight) + self.dec_bias 74 | dec_outputs.append(dec_input_) 75 | if self.reverse: 76 | dec_outputs = dec_outputs[::-1] 77 | self.output_ = tf.transpose(tf.stack(dec_outputs), [1, 0, 2]) 78 | 79 | self.input_ = tf.transpose(tf.stack(inputs), [1, 0, 2]) 80 | self.loss = tf.reduce_mean(tf.square(self.input_ - self.output_)) 81 | self.emb_all = tf.reduce_mean(self.z_codes, axis=0) 82 | 83 | return self.emb_all, self.loss 84 | 85 | class MeanAutoencoder(object): 86 | def __init__(self, hidden_num): 87 | self.hidden_num = hidden_num 88 | 89 | if FLAGS.datasource in ['sinusoid', 'mixture']: 90 | self.elem_num = 2 91 | self.hidden_num_mid = 20 92 | elif FLAGS.datasource in ['miniimagenet', 'omniglot','multidataset', 'multidataset_leave_one_out']: 93 | self.elem_num = FLAGS.num_classes + 64 94 | self.hidden_num_mid = 96 95 | 96 | def model(self, inputs): 97 | with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE): 98 | enc_dense1 = tf.layers.dense(inputs, units=self.hidden_num_mid, activation=tf.nn.relu, name='encoder_dense1') 99 | enc_dense2 = tf.layers.dense(enc_dense1, units=self.hidden_num, activation=tf.nn.relu, name='encoder_dense2') 100 | 101 | with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE): 102 | dec_dense1= tf.layers.dense(enc_dense2, units=self.hidden_num_mid, activation=tf.nn.relu, name='decoder_dense1') 103 | dec_dense2 = tf.layers.dense(dec_dense1, units=self.elem_num, activation=None, 104 | name='decoder_dense2') 105 | emb_pool = tf.reduce_mean(enc_dense2, axis=0, 
keepdims=True) 106 | with tf.variable_scope('last_fc', reuse=tf.AUTO_REUSE): 107 | self.emb_all = tf.layers.dense(emb_pool, units=self.hidden_num, activation=tf.nn.relu, name='mean_pool') 108 | self.loss = 0.5*tf.reduce_mean(tf.square(inputs-dec_dense2)) 109 | 110 | return self.emb_all, self.loss 111 | 112 | 113 | 114 | 115 | 116 | 117 | -------------------------------------------------------------------------------- /toygroup_bash/HSML_toygroup_10shot.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --datasource=mixture --metatrain_iterations=70000 --norm=None --update_batch_size=10 --update_batch_size_eval=10 --resume=False --num_updates=5 --logdir=../Check_point/syncgroup_10shot --emb_loss_weight=0.01 --hidden_dim=40 3 | 4 | python main.py --datasource=mixture --metatrain_iterations=70000 --norm=None --update_batch_size=10 --update_batch_size_eval=10 --resume=False --num_updates=5 --logdir=../Check_point/syncgroup_10shot --emb_loss_weight=0.01 --hidden_dim=40 --test_set=True --test_epoch=69000 --train=False --num_test_task=4000 5 | -------------------------------------------------------------------------------- /toygroup_bash/HSML_toygroup_5shot.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | python main.py --datasource=mixture --metatrain_iterations=70000 --norm=None --update_batch_size=5 --update_batch_size_eval=10 --resume=False --num_updates=5 --logdir=../Check_point/syncgroup_5shot --emb_loss_weight=0.01 --hidden_dim=40 3 | 4 | python main.py --datasource=mixture --metatrain_iterations=70000 --norm=None --update_batch_size=5 --update_batch_size_eval=10 --resume=False --num_updates=5 --logdir=../Check_point/syncgroup_5shot --emb_loss_weight=0.01 --hidden_dim=40 --test_set=True --test_epoch=69000 --train=False --num_test_task=4000 -------------------------------------------------------------------------------- 
/utils.py: -------------------------------------------------------------------------------- 1 | """ Utility functions. """ 2 | import numpy as np 3 | import os 4 | import random 5 | import tensorflow as tf 6 | 7 | from tensorflow.contrib.layers.python import layers as tf_layers 8 | from tensorflow.python.platform import flags 9 | 10 | FLAGS = flags.FLAGS 11 | 12 | ## Image helper 13 | def get_images(paths, labels, nb_samples=None, shuffle=True): 14 | if nb_samples is not None: 15 | sampler = lambda x: random.sample(x, nb_samples) 16 | else: 17 | sampler = lambda x: x 18 | images = [(i, os.path.join(path, image)) \ 19 | for i, path in zip(labels, paths) \ 20 | for image in sampler(os.listdir(path))] 21 | if shuffle: 22 | random.shuffle(images) 23 | return images 24 | 25 | ## Network helpers 26 | def conv_block(inp, cweight, bweight, reuse, scope, activation=tf.nn.relu, max_pool_pad='VALID', residual=False): 27 | """ Perform, conv, batch norm, nonlinearity, and max pool """ 28 | stride, no_stride = [1,2,2,1], [1,1,1,1] 29 | 30 | if FLAGS.max_pool: 31 | conv_output = tf.nn.conv2d(inp, cweight, no_stride, 'SAME') + bweight 32 | else: 33 | conv_output = tf.nn.conv2d(inp, cweight, stride, 'SAME') + bweight 34 | normed = normalize(conv_output, activation, reuse, scope) 35 | if FLAGS.max_pool: 36 | normed = tf.nn.max_pool(normed, stride, stride, max_pool_pad) 37 | return normed 38 | 39 | def normalize(inp, activation, reuse, scope): 40 | if FLAGS.norm == 'batch_norm': 41 | return tf_layers.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) 42 | elif FLAGS.norm == 'layer_norm': 43 | return tf_layers.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope) 44 | elif FLAGS.norm == 'None': 45 | if activation is not None: 46 | return activation(inp) 47 | else: 48 | return inp 49 | 50 | ## Loss functions 51 | def mse(pred, label): 52 | pred = tf.reshape(pred, [-1]) 53 | label = tf.reshape(label, [-1]) 54 | return tf.reduce_mean(tf.square(pred-label)) 
55 | 56 | def xent(pred, label): 57 | return tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label) / FLAGS.update_batch_size 58 | --------------------------------------------------------------------------------