├── README.md
├── data
├── DTD_Texture
│ └── proc_images.py
├── cub_bird
│ └── proc_images.py
├── fgvc_aircraft
│ └── proc_images.py
├── fgvcx_fungi
│ └── proc_images.py
├── miniImagenet
│ └── proc_images.py
└── omniglot_resized
│ └── resize_images.py
├── data_generator.py
├── image_embedding.py
├── lstm_tree.py
├── main.py
├── maml.py
├── multidataset_bash
├── HSML_multidataset_1shot.sh
└── HSML_multidataset_5shot.sh
├── special_grads.py
├── task_embedding.py
├── toygroup_bash
├── HSML_toygroup_10shot.sh
└── HSML_toygroup_5shot.sh
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # HSML (Hierarchically Structured Meta-learning)
2 |
3 | ## About
4 | Source code<sup>1</sup> for the paper [Hierarchically Structured Meta-learning](https://arxiv.org/abs/1905.05301).
5 |
6 | For the continual version of this algorithm, please refer to this [repo](https://github.com/huaxiuyao/HSML_Dynamic).
7 |
8 | If you find this repository useful in your research, please cite the following paper:
9 | ```
10 | @inproceedings{yao2019hierarchically,
11 | title={Hierarchically Structured Meta-learning},
12 | author={Yao, Huaxiu and Wei, Ying and Huang, Junzhou and Li, Zhenhui},
13 | booktitle={Proceedings of the 36th International Conference on Machine Learning},
14 | year={2019}
15 | }
16 | ```
17 |
18 | ## Data
19 | We release our Multi-Datasets including bird, texture, aircraft and fungi in this [link](https://drive.google.com/file/d/1IJk93N48X0rSL69nQ1Wr-49o8u0e75HM/view?usp=sharing).
20 |
21 | ## Usage
22 |
23 | ### Dependencies
24 | * python 3.*
25 | * TensorFlow 1.0+
26 | * Numpy 1.15+
27 |
28 | ### Toy Group Data
29 | Please see the bash files in /toygroup_bash for parameter settings.
30 |
31 | ### Multi-datasets Data
32 | Please see the bash files in /multidataset_bash for parameter settings.
33 |
34 |
35 | <sup>1</sup> This code is built on the [MAML](https://github.com/cbfinn/maml) codebase.
36 |
37 |
--------------------------------------------------------------------------------
/data/DTD_Texture/proc_images.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage instructions:
3 | First download the DTD texture dataset and point the hard-coded paths in this
4 | script at its images/ directory (one sub-folder per texture class).
5 |
6 | Process() resizes every image to 84x84 in place (Lanczos filter).
7 | select_image() randomly moves 47 class folders into train/ (30), val/ (7) and test/ (10).
8 |
9 | Then, run the following (only select_image() is called by default):
10 | python proc_images.py
11 | """
12 | from PIL import Image
13 | import glob
14 | import os
15 | import numpy as np
16 | import random
17 | import shutil
18 |
19 | np.random.seed(0)
20 | random.seed(1)
21 |
def Process(image_path='/home/huaxiuyao/Data/meta-dataset/DTD_Texture/dtd/images/*/',
            size=(84, 84)):
    """Resize every file matched by ``image_path + '*'`` in place.

    Args:
        image_path: glob prefix selecting the class folders; ``'*'`` is
            appended to match the files inside them.
        size: target size handed to ``Image.resize``.
    """
    all_images = glob.glob(image_path + '*')

    for i, image_file in enumerate(all_images, start=1):
        # Close the source file before the same path is overwritten below;
        # resize() returns a new, fully loaded image.
        with Image.open(image_file) as im:
            resized = im.resize(size, resample=Image.LANCZOS)
        resized.save(image_file)

        # Progress indicator.
        if i % 200 == 0:
            print(i)
37 |
def select_image(path='/home/huaxiuyao/Data/meta-dataset/DTD_Texture/images/',
                 out_root='/home/huaxiuyao/Data/meta-dataset/DTD_Texture/',
                 split=(30, 7, 10)):
    """Randomly move class folders from ``path`` into train/val/test folders.

    Args:
        path: directory holding one sub-folder per class.
        out_root: directory under which ``train``, ``val`` and ``test`` are
            created (if missing) and filled.
        split: ``(n_train, n_val, n_test)`` number of class folders per split.

    Raises:
        ValueError: if ``path`` holds fewer class folders than ``sum(split)``.
    """
    # Only directories are classes; stray files (e.g. .DS_Store) are ignored.
    class_dirs = [name for name in os.listdir(path)
                  if os.path.isdir(os.path.join(path, name))]
    n_train, n_val, n_test = split
    total = n_train + n_val + n_test
    if len(class_dirs) < total:
        raise ValueError('need {} class folders in {} but found {}'.format(
            total, path, len(class_dirs)))

    # Sample positions, then shuffle: keeps the RNG call sequence of the
    # module-level seeding meaningful.
    picked = [class_dirs[idx] for idx in random.sample(range(len(class_dirs)), total)]
    random.shuffle(picked)

    groups = (
        ('train', picked[:n_train]),
        ('val', picked[n_train:n_train + n_val]),
        ('test', picked[n_train + n_val:]),
    )
    for split_name, members in groups:
        dest = os.path.join(out_root, split_name)
        # If dest did not exist, shutil.move would rename the first class
        # folder to dest instead of moving it inside.
        os.makedirs(dest, exist_ok=True)
        for name in members:
            shutil.move(os.path.join(path, name), dest)
53 | # num_images = sorted(num_images, key=lambda x: x[1], reverse=True)
54 |
55 | if __name__=='__main__':
56 | select_image()
--------------------------------------------------------------------------------
/data/cub_bird/proc_images.py:
--------------------------------------------------------------------------------
1 | """The Caltech-UCSD bird dataset
2 | """
3 |
4 | import numpy as np
5 | import os
6 | from scipy import misc
7 | from skimage import io
8 | import ipdb
9 | import shutil
10 | import random
11 |
12 | np.random.seed(0)
13 | random.seed(1)
14 |
class CUBDataLayer():
    """Crop and resize the Caltech-UCSD bird (CUB_200_2011) images in place."""

    def __init__(self, **kwargs):
        """Process every image listed in ``images.txt`` under the root folder.

        kwargs (all optional):
            root: the root folder of the CUB_200_2011 dataset.
            crop: if falsy, does not crop to the bounding box. Otherwise the
                factor applied to the bounding-box width and height around
                its center. The default ``True`` behaves as a factor of 1.
            target_size: all images are resized to the size specified,
                e.g. [84, 84].
        """
        root = kwargs.get('root', '/home/huaxiuyao/Data/meta-dataset/CUB_Bird/CUB_200_2011/')
        crop = kwargs.get('crop', True)
        target_size = kwargs.get('target_size', [84, 84])

        # Each line is "<id> <value...>"; the leading id column is dropped.
        with open(os.path.join(root, 'images.txt'), 'r') as image_list:
            images = [line.split()[1] for line in image_list]
        with open(os.path.join(root, 'bounding_boxes.txt'), 'r') as box_list:
            boxes = [line.split()[1:] for line in box_list]

        # for the boxes, we store them as a numpy array
        boxes = np.array(boxes, dtype=np.float32)
        # Shift to python (0-based) indexing.
        # NOTE(review): this also subtracts 1 from the width/height columns
        # that _crop_image unpacks -- confirm that is intended.
        boxes -= 1
        # load the data
        self._load_data(root, images, boxes, crop, target_size)

    def _load_data(self, root, images, boxes, crop, target_size):
        """Read, optionally crop, resize and overwrite each listed image."""
        for i, rel_path in enumerate(images):
            image_file = os.path.join(root, 'images', rel_path)
            image = io.imread(image_file)
            if image.ndim == 2:
                # Grayscale: replicate the single channel three times.
                image = np.tile(image[:, :, np.newaxis], (1, 1, 3))
            if image.shape[2] == 4:
                # Drop the fourth (alpha) channel.
                image = image[:, :, :3]
            if crop:
                image = self._crop_image(image, crop, boxes[i])
            # NOTE(review): scipy.misc.imresize/imsave are not available in
            # recent SciPy releases -- confirm the SciPy version this needs.
            data_img = misc.imresize(image, target_size)
            misc.imsave(image_file, data_img)

            if i % 500 == 0:
                print(i)

        return

    def _crop_image(self, image, crop, box):
        """Return the part of ``image`` covered by ``box`` scaled by ``crop``.

        ``box`` is unpacked as (x, y, width, height). The window is clipped
        to the image bounds.

        Raises:
            ValueError: if the clipped window has no width or no height.
        """
        imheight, imwidth = image.shape[:2]
        x, y, width, height = box
        centerx = x + width / 2.
        centery = y + height / 2.
        xoffset = width * crop / 2.
        yoffset = height * crop / 2.
        # "+ 0.5" rounds to the nearest pixel before truncation.
        xmin = max(int(centerx - xoffset + 0.5), 0)
        ymin = max(int(centery - yoffset + 0.5), 0)
        xmax = min(int(centerx + xoffset + 0.5), imwidth - 1)
        ymax = min(int(centery + yoffset + 0.5), imheight - 1)
        if xmax - xmin <= 0 or ymax - ymin <= 0:
            raise ValueError("The cropped bounding box has size 0.")
        return image[ymin:ymax, xmin:xmax]
82 |
def select_image(path='/home/huaxiuyao/Data/meta-dataset/CUB_Bird/images/',
                 out_root='/home/huaxiuyao/Data/meta-dataset/CUB_Bird/',
                 split=(64, 16, 20),
                 images_per_class=60):
    """Randomly move class folders from ``path`` into train/val/test folders.

    Only folders holding exactly ``images_per_class`` entries are candidates.
    Hidden entries (names starting with '.') are printed as they are found.

    Args:
        path: directory holding one sub-folder per class.
        out_root: directory under which ``train``, ``val`` and ``test`` are
            created (if missing) and filled.
        split: ``(n_train, n_val, n_test)`` number of class folders per split.
        images_per_class: required number of entries in a candidate folder.

    Raises:
        ValueError: if fewer than ``sum(split)`` folders qualify.
    """
    candidates = []
    for eachdir in os.listdir(path):
        class_dir = os.path.join(path, eachdir)
        if not os.path.isdir(class_dir):
            continue
        entries = os.listdir(class_dir)
        for each in entries:
            # Report hidden files, which would count towards the folder size.
            if each[0] == '.':
                print(eachdir, each)
        if len(entries) == images_per_class:
            candidates.append(eachdir)

    n_train, n_val, n_test = split
    total = n_train + n_val + n_test
    if len(candidates) < total:
        raise ValueError('need {} class folders with {} images in {} but found {}'.format(
            total, images_per_class, path, len(candidates)))

    # Sample positions, then shuffle: keeps the RNG call sequence of the
    # module-level seeding meaningful.
    picked = [candidates[idx] for idx in random.sample(range(len(candidates)), total)]
    random.shuffle(picked)

    groups = (
        ('train', picked[:n_train]),
        ('val', picked[n_train:n_train + n_val]),
        ('test', picked[n_train + n_val:]),
    )
    for split_name, members in groups:
        dest = os.path.join(out_root, split_name)
        # If dest did not exist, shutil.move would rename the first class
        # folder to dest instead of moving it inside.
        os.makedirs(dest, exist_ok=True)
        for name in members:
            shutil.move(os.path.join(path, name), dest)
103 | # num_images=sorted(num_images, key=lambda x:x[1], reverse=True)
104 |
105 | if __name__=='__main__':
106 | select_image()
--------------------------------------------------------------------------------
/data/fgvc_aircraft/proc_images.py:
--------------------------------------------------------------------------------
1 | """
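2 | Usage instructions:
3 | First download the FGVC-Aircraft dataset and point the hard-coded paths in this
4 | script at its data/ directory.
5 |
6 | FGVC_Aircraft() crops each image with its box from images_box.txt and resizes it to 84x84 in place.
7 | reorganize() re-saves images into one folder per variant; select_image() randomly moves 100 variant folders into train/ (64), val/ (16) and test/ (20).
8 |
9 | Then, run the following (only select_image() is called by default):
10 | python proc_images.py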
11 | """
12 | from PIL import Image
13 | import glob
14 | import os
15 | import numpy as np
16 | import scipy.io as scio
17 | import os
18 | from scipy import misc
19 | from skimage import io
20 | import ipdb
21 | import shutil
22 | import random
23 |
24 | np.random.seed(1)
25 | random.seed(2)
26 |
class FGVC_Aircraft():
    """Crop and resize the FGVC-Aircraft images in place."""

    def __init__(self, **kwargs):
        """Process every file found in ``<root>/images``.

        kwargs (all optional):
            root: the folder holding ``images/`` and ``images_box.txt``.
            crop: if falsy, does not crop to the bounding box. Otherwise the
                factor applied to the box width and height around its
                center. The default ``True`` behaves as a factor of 1.
            target_size: all images are resized to the size specified,
                e.g. [84, 84].
        """
        root = kwargs.get('root', '/home/huaxiuyao/Data/meta-dataset/FGVC_Aircraft/data/')
        crop = kwargs.get('crop', True)
        target_size = kwargs.get('target_size', [84, 84])

        # Image ids are the file names up to the first '.'.
        images = [imageid.split('.')[0] for imageid in os.listdir(os.path.join(root, 'images'))]

        # Each line is "<image id> <four box values>". Values are stored as
        # float32 arrays shifted by -1 (python indexing).
        # NOTE(review): _crop_image unpacks the values as x, y, width, height;
        # the FGVC-Aircraft docs appear to define images_box.txt as
        # xmin ymin xmax ymax -- confirm against the dataset README.
        with open(os.path.join(root, 'images_box.txt'), 'r') as box_list:
            boxes = {}
            for line in box_list:
                fields = line.split()
                boxes[fields[0]] = np.array(fields[1:], dtype=np.float32) - 1
        # load the data
        self._load_data(root, images, boxes, crop, target_size)

    def _load_data(self, root, images, boxes, crop, target_size):
        """Read, optionally crop, resize and overwrite each image."""
        for i, image_id in enumerate(images):
            image_file = os.path.join(root, 'images', '{}.jpg'.format(image_id))
            image = io.imread(image_file)
            if image.ndim == 2:
                # Grayscale: replicate the single channel three times.
                image = np.tile(image[:, :, np.newaxis], (1, 1, 3))
            if image.shape[2] == 4:
                # Drop the fourth (alpha) channel.
                image = image[:, :, :3]
            if crop:
                image = self._crop_image(image, crop, boxes[image_id])
            # NOTE(review): scipy.misc.imresize/imsave are not available in
            # recent SciPy releases -- confirm the SciPy version this needs.
            data_img = misc.imresize(image, target_size)
            misc.imsave(image_file, data_img)

            if i % 500 == 0:
                print(i)

        return

    def _crop_image(self, image, crop, box):
        """Return the part of ``image`` covered by ``box`` scaled by ``crop``.

        ``box`` is unpacked as (x, y, width, height). The window is clipped
        to the image bounds.

        Raises:
            ValueError: if the clipped window has no width or no height.
        """
        imheight, imwidth = image.shape[:2]
        x, y, width, height = box
        centerx = x + width / 2.
        centery = y + height / 2.
        xoffset = width * crop / 2.
        yoffset = height * crop / 2.
        # "+ 0.5" rounds to the nearest pixel before truncation.
        xmin = max(int(centerx - xoffset + 0.5), 0)
        ymin = max(int(centery - yoffset + 0.5), 0)
        xmax = min(int(centerx + xoffset + 0.5), imwidth - 1)
        ymax = min(int(centery + yoffset + 0.5), imheight - 1)
        if xmax - xmin <= 0 or ymax - ymin <= 0:
            raise ValueError("The cropped bounding box has size 0.")
        return image[ymin:ymax, xmin:xmax]
93 |
def reorganize():
    """Re-save every aircraft image into a folder named after its variant.

    Variant labels are read from the four ``images_variant_*.txt`` lists.
    The first label seen for an image id wins. Output goes to
    ``organized_images/<variant>/`` next to the ``images/`` folder.
    """
    root = '/home/huaxiuyao/Data/meta-dataset/FGVC_Aircraft/data/'
    image_dir = root + 'images/'
    newpath = root + 'organized_images/'

    labelall = {}
    for subset in ('train', 'trainval', 'val', 'test'):
        list_file = os.path.join(root, 'images_variant_{}.txt'.format(subset))
        with open(list_file, 'r') as handle:
            for line in handle:
                fields = line.strip().split(' ')
                # Multi-word variant names are joined with '-'.
                labelall.setdefault(fields[0], '-'.join(fields[1:]))

    for eachfile in os.listdir(image_dir):
        folder_id = labelall[eachfile.split('.')[0]]
        print(folder_id)
        # '/' in a variant name (e.g. 'F-16A/B', 'F/A-18') would act as a
        # path separator, so it is replaced with '-'.
        folder_id = folder_id.replace('/', '-')
        target_dir = newpath + folder_id
        # Also creates organized_images/ itself when it is missing.
        os.makedirs(target_dir, exist_ok=True)

        with Image.open(image_dir + eachfile) as im:
            im.save(target_dir + '/' + eachfile)
120 |
def select_image(path='/home/huaxiuyao/Data/meta-dataset/FGVC_Aircraft/images/',
                 out_root='/home/huaxiuyao/Data/meta-dataset/FGVC_Aircraft/',
                 split=(64, 16, 20)):
    """Randomly move class folders from ``path`` into train/val/test folders.

    Args:
        path: directory holding one sub-folder per class.
        out_root: directory under which ``train``, ``val`` and ``test`` are
            created (if missing) and filled.
        split: ``(n_train, n_val, n_test)`` number of class folders per split.

    Raises:
        ValueError: if ``path`` holds fewer class folders than ``sum(split)``.
    """
    # Only directories are classes; stray files are ignored.
    class_dirs = [name for name in os.listdir(path)
                  if os.path.isdir(os.path.join(path, name))]
    n_train, n_val, n_test = split
    total = n_train + n_val + n_test
    if len(class_dirs) < total:
        raise ValueError('need {} class folders in {} but found {}'.format(
            total, path, len(class_dirs)))

    # Sample positions, then shuffle: keeps the RNG call sequence of the
    # module-level seeding meaningful.
    picked = [class_dirs[idx] for idx in random.sample(range(len(class_dirs)), total)]
    random.shuffle(picked)

    groups = (
        ('train', picked[:n_train]),
        ('val', picked[n_train:n_train + n_val]),
        ('test', picked[n_train + n_val:]),
    )
    for split_name, members in groups:
        dest = os.path.join(out_root, split_name)
        # If dest did not exist, shutil.move would rename the first class
        # folder to dest instead of moving it inside.
        os.makedirs(dest, exist_ok=True)
        for name in members:
            shutil.move(os.path.join(path, name), dest)
136 | # num_images=sorted(num_images, key=lambda x:x[1], reverse=True)
137 |
138 |
139 | if __name__=='__main__':
140 | select_image()
--------------------------------------------------------------------------------
/data/fgvcx_fungi/proc_images.py:
--------------------------------------------------------------------------------
1 | from PIL import Image
2 | import glob
3 | import os
4 | import shutil
5 | import random
6 | import numpy as np
7 | import ipdb
8 |
9 | np.random.seed(1)
10 | random.seed(2)
11 |
12 | image_path = '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/images/*/'
13 |
14 |
def process():
    """Resize every file matched by ``image_path + '*'`` to 84x84 in place."""
    all_images = glob.glob(image_path + '*')

    for i, image_file in enumerate(all_images, start=1):
        # Close the source file before the same path is overwritten below;
        # resize() returns a new, fully loaded image.
        with Image.open(image_file) as im:
            resized = im.resize((84, 84), resample=Image.LANCZOS)
        resized.save(image_file)

        # Progress indicator.
        if i % 200 == 0:
            print(i)
28 |
29 |
def select_folder(path='/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/images/',
                  out_root='/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/',
                  split=(64, 16, 20),
                  min_images=150):
    """Randomly move class folders from ``path`` into train/val/test folders.

    Only folders holding at least ``min_images`` entries are candidates.

    Args:
        path: directory holding one sub-folder per class.
        out_root: directory under which ``train``, ``val`` and ``test`` are
            created (if missing) and filled.
        split: ``(n_train, n_val, n_test)`` number of class folders per split.
        min_images: minimum number of entries a candidate folder must hold.

    Raises:
        ValueError: if fewer than ``sum(split)`` folders qualify.
    """
    candidates = []
    for eachdir in os.listdir(path):
        class_dir = os.path.join(path, eachdir)
        # List each folder once; stray files are ignored.
        if os.path.isdir(class_dir) and len(os.listdir(class_dir)) >= min_images:
            candidates.append(eachdir)

    n_train, n_val, n_test = split
    total = n_train + n_val + n_test
    if len(candidates) < total:
        raise ValueError('need {} class folders with >= {} images in {} but found {}'.format(
            total, min_images, path, len(candidates)))

    # Sample positions, then shuffle: keeps the RNG call sequence of the
    # module-level seeding meaningful.
    picked = [candidates[idx] for idx in random.sample(range(len(candidates)), total)]
    random.shuffle(picked)

    groups = (
        ('train', picked[:n_train]),
        ('val', picked[n_train:n_train + n_val]),
        ('test', picked[n_train + n_val:]),
    )
    for split_name, members in groups:
        dest = os.path.join(out_root, split_name)
        # If dest did not exist, shutil.move would rename the first class
        # folder to dest instead of moving it inside.
        os.makedirs(dest, exist_ok=True)
        for name in members:
            shutil.move(os.path.join(path, name), dest)
46 | # num_images = sorted(num_images, key=lambda x: x[1], reverse=True)
47 | # print(len(num_images))
48 |
49 |
def select_image(root='/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/', num_per_class=150):
    """Move a random subset of each class's images into ``<split>_new/<class>/``.

    For every class folder under ``train``, ``test`` and ``val`` in ``root``,
    ``num_per_class`` randomly chosen files are moved to the matching folder
    under ``train_new``, ``test_new`` or ``val_new``.

    Args:
        root: directory holding the ``train``, ``test`` and ``val`` folders.
        num_per_class: number of images to move per class.

    Raises:
        ValueError: if a class folder holds fewer than ``num_per_class``
            entries (raised by ``random.sample``).
    """
    for eachfolder in ['train', 'test', 'val']:
        split_dir = os.path.join(root, eachfolder)
        for eachtype in os.listdir(split_dir):
            class_dir = os.path.join(split_dir, eachtype)
            images = os.listdir(class_dir)
            # Shuffle, then sample positions: keeps the RNG call sequence of
            # the module-level seeding meaningful.
            random.shuffle(images)
            chosen = [images[idx] for idx in random.sample(range(len(images)), num_per_class)]

            new_dir = os.path.join(root, eachfolder + '_new', eachtype)
            # makedirs also creates the '<split>_new' parent when missing.
            os.makedirs(new_dir, exist_ok=True)
            for name in chosen:
                shutil.move(os.path.join(class_dir, name), new_dir)
64 | # path = '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/images/'
65 | # dirlist = os.listdir(path)
66 | # num_images = []
67 | # for eachdir in dirlist:
68 | # if len(os.listdir(path + eachdir)) >= 150:
69 | # num_images.append([eachdir, len(os.listdir(path + eachdir))])
70 | # all_folder_id = random.sample(range(len(num_images)), 100)
71 | # all_folder = [num_images[id] for id in all_folder_id]
72 | # random.shuffle(all_folder)
73 | # for i in range(64):
74 | # shutil.move(path + all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/train/')
75 | # for i in range(64, 80):
76 | # shutil.move(path + all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/val/')
77 | # for i in range(80, 100):
78 | # shutil.move(path + all_folder[i][0], '/home/huaxiuyao/Data/meta-dataset/FGVCx_Fungi/test/')
79 | # num_images = sorted(num_images, key=lambda x: x[1], reverse=True)
80 | # print(len(num_images))
81 |
82 |
83 | if __name__ == '__main__':
84 | select_folder()
85 |
--------------------------------------------------------------------------------
/data/miniImagenet/proc_images.py:
--------------------------------------------------------------------------------
1 | """
2 | Script for converting from csv file datafiles to a directory for each image (which is how it is loaded by MAML code)
3 |
4 | Acquire miniImagenet from Ravi & Larochelle '17, along with the train, val, and test csv files. Put the
5 | csv files in the miniImagenet directory and put the images in the directory 'miniImagenet/images/'.
6 | Then run this script from the miniImagenet directory:
7 | cd data/miniImagenet/
8 | python proc_images.py
9 | """
10 |
11 | from __future__ import print_function
12 | import csv
13 | import glob
14 | import os
15 |
16 | from PIL import Image
17 |
import shutil

data_root = '/home/huaxiuyao/Data/miniimagenet/'
path_to_images = data_root + 'images/'

all_images = glob.glob(path_to_images + '*')

# Resize images
for i, image_file in enumerate(all_images):
    # Close the source file before the same path is overwritten below;
    # resize() returns a new, fully loaded image.
    with Image.open(image_file) as im:
        resized = im.resize((84, 84), resample=Image.LANCZOS)
    resized.save(image_file)
    if i % 500 == 0:
        print(i)

# Put in correct directory
for datatype in ['train', 'val', 'test']:
    split_dir = os.path.join(data_root, datatype)
    # Filesystem calls replace the former 'mkdir'/'mv' shell strings, which
    # broke on names containing spaces or shell metacharacters.
    if not os.path.isdir(split_dir):
        os.makedirs(split_dir)

    with open(os.path.join(data_root, datatype + '.csv'), 'r') as f:
        reader = csv.reader(f, delimiter=',')
        next(reader, None)  # skip the headers
        for row in reader:
            image_name, label = row[0], row[1]
            cur_dir = os.path.join(split_dir, label)
            if not os.path.isdir(cur_dir):
                os.makedirs(cur_dir)
            source = os.path.join(path_to_images, image_name)
            if not os.path.exists(source):
                # The former 'mv' call did not stop the script on a missing
                # file; keep going but make the problem visible.
                print('missing image: ' + source)
                continue
            # An explicit target file path replaces an existing file, as 'mv' does.
            shutil.move(source, os.path.join(cur_dir, image_name))
47 |
--------------------------------------------------------------------------------
/data/omniglot_resized/resize_images.py:
--------------------------------------------------------------------------------
1 | """
2 | Usage instructions:
3 | First download the omniglot dataset
4 | and put the contents of both images_background and images_evaluation in data/omniglot/ (without the root folder)
5 |
6 | Then, run the following:
7 | cd data/
8 | cp -r omniglot/* omniglot_resized/
9 | cd omniglot_resized/
10 | python resize_images.py
11 | """
12 | from PIL import Image
13 | import glob
14 |
# Matches <alphabet>/<character>/ folders relative to the working directory.
image_path = '*/*/'

all_images = glob.glob(image_path + '*')

for i, image_file in enumerate(all_images, start=1):
    # Close the source file before the same path is overwritten below;
    # resize() returns a new, fully loaded image.
    with Image.open(image_file) as im:
        resized = im.resize((28, 28), resample=Image.LANCZOS)
    resized.save(image_file)

    # Progress indicator.
    if i % 200 == 0:
        print(i)
29 |
30 |
--------------------------------------------------------------------------------
/data_generator.py:
--------------------------------------------------------------------------------
1 | """ Code for loading data. """
2 | import numpy as np
3 | import os
4 | import random
5 | import tensorflow as tf
6 | import ipdb
7 |
8 | from tensorflow.python.platform import flags
9 | from utils import get_images
10 | import matplotlib.pyplot as plt
11 |
12 | FLAGS = flags.FLAGS
13 |
14 |
15 | class DataGenerator(object):
def __init__(self, num_samples_per_class, batch_size, config=None):
    """Configure the generator for the data source named by FLAGS.datasource.

    Args:
        num_samples_per_class: num samples to generate per class in one batch
        batch_size: size of meta batch size (e.g. number of functions)
        config: optional dict of overrides. Keys read here are 'amp_range',
            'freq_range', 'phase_range', 'input_range', 'num_classes',
            'img_size', 'data_folder', 'num_train', 'metatrain_folder',
            'metaval_folder' and 'rotations'.

    Raises:
        ValueError: if FLAGS.datasource is not one of the handled sources.
    """
    # A fresh dict per call avoids sharing one mutable default between calls.
    config = {} if config is None else config

    def class_folders(parent):
        # Full paths of the immediate sub-directories of `parent`.
        return [os.path.join(parent, label)
                for label in os.listdir(parent)
                if os.path.isdir(os.path.join(parent, label))]

    self.batch_size = batch_size
    self.num_samples_per_class = num_samples_per_class
    self.num_classes = 1  # by default 1 (only relevant for classification problems)

    if FLAGS.datasource == 'sinusoid':
        self.generate = self.generate_sinusoid_batch
        self.amp_range = config.get('amp_range', [0.1, 5.0])
        self.freq_range = config.get('freq_range', [0.8, 1.2])
        self.phase_range = config.get('phase_range', [0, np.pi])
        self.input_range = config.get('input_range', [-5.0, 5.0])
        self.dim_input = 1
        self.dim_output = 1

    elif FLAGS.datasource == 'mixture':
        self.generate = self.generate_mixture_batch
        self.dim_input = 1
        self.dim_output = 1
        self.input_range = config.get('input_range', [-5.0, 5.0])

    elif 'omniglot' in FLAGS.datasource:
        self.num_classes = config.get('num_classes', FLAGS.num_classes)
        self.img_size = config.get('img_size', (28, 28))
        self.dim_input = np.prod(self.img_size)
        self.dim_output = self.num_classes
        # data that is pre-resized using PIL with lanczos filter
        data_folder = config.get('data_folder', '{}/omniglot_resized'.format(FLAGS.datadir))

        # One entry per <family>/<character>; only the family level is
        # checked for being a directory.
        character_folders = [os.path.join(data_folder, family, character)
                             for family in os.listdir(data_folder)
                             if os.path.isdir(os.path.join(data_folder, family))
                             for character in os.listdir(os.path.join(data_folder, family))]
        # Fixed seed so the shuffle below (and hence the split) is repeatable
        # for a given folder listing.
        random.seed(1)
        random.shuffle(character_folders)
        num_val = 0 if FLAGS.no_val else 100
        num_train = config.get('num_train', 1200) - num_val
        self.metatrain_character_folders = character_folders[:num_train]
        if FLAGS.test_set:
            self.metaval_character_folders = character_folders[num_train + num_val:]
        else:
            self.metaval_character_folders = character_folders[num_train:num_train + num_val]
        self.rotations = config.get('rotations', [0, 90, 180, 270])

    elif FLAGS.datasource == 'miniimagenet':
        self.num_classes = config.get('num_classes', FLAGS.num_classes)
        self.img_size = config.get('img_size', (84, 84))
        self.dim_input = np.prod(self.img_size) * 3
        self.dim_output = self.num_classes
        metatrain_folder = config.get('metatrain_folder', '{}/miniImagenet/train'.format(FLAGS.datadir))
        if FLAGS.test_set:
            metaval_folder = config.get('metaval_folder', '{}/miniImagenet/test'.format(FLAGS.datadir))
        else:
            metaval_folder = config.get('metaval_folder', '{}/miniImagenet/val'.format(FLAGS.datadir))

        self.metatrain_character_folders = class_folders(metatrain_folder)
        self.metaval_character_folders = class_folders(metaval_folder)
        self.rotations = config.get('rotations', [0])

    elif FLAGS.datasource == 'multidataset':
        self.num_classes = config.get('num_classes', FLAGS.num_classes)
        self.img_size = config.get('img_size', (84, 84))
        self.dim_input = np.prod(self.img_size) * 3
        self.dim_output = self.num_classes
        self.multidataset = ['CUB_Bird', 'DTD_Texture', 'FGVC_Aircraft', 'FGVCx_Fungi']
        # One list of class folders per dataset, in self.multidataset order.
        metatrain_folders, metaval_folders = [], []
        eval_split = 'test' if FLAGS.test_set else 'val'
        for eachdataset in self.multidataset:
            dataset_root = '{0}/meta-dataset/{1}'.format(FLAGS.datadir, eachdataset)
            metatrain_folders.append(class_folders(dataset_root + '/train'))
            metaval_folders.append(class_folders(dataset_root + '/' + eval_split))
        self.metatrain_character_folders = metatrain_folders
        self.metaval_character_folders = metaval_folders
        self.rotations = config.get('rotations', [0])

    elif FLAGS.datasource == 'multidataset_leave_one_out':
        self.num_classes = config.get('num_classes', FLAGS.num_classes)
        self.img_size = config.get('img_size', (84, 84))
        self.dim_input = np.prod(self.img_size) * 3
        self.dim_output = self.num_classes
        self.multidataset = ['CUB_Bird', 'DTD_Texture', 'FGVC_Aircraft', 'FGVCx_Fungi']
        # Training folders: one list per dataset, skipping the held-out one.
        metatrain_folders = []
        for idx_data, eachdataset in enumerate(self.multidataset):
            if idx_data == FLAGS.leave_one_out_id:
                continue
            metatrain_folders.append(class_folders(
                '{0}/meta-dataset-leave-one-out/{1}/train'.format(FLAGS.datadir, eachdataset)))

        # Evaluation folders: a flat list taken from the held-out dataset.
        # NOTE(review): with FLAGS.test_set the held-out dataset's 'train'
        # folder is used for evaluation -- confirm this is intended.
        held_out = self.multidataset[FLAGS.leave_one_out_id]
        eval_split = 'train' if FLAGS.test_set else 'val'
        metaval_folders = class_folders(
            '{0}/meta-dataset-leave-one-out/{1}/{2}'.format(FLAGS.datadir, held_out, eval_split))
        self.metatrain_character_folders = metatrain_folders
        self.metaval_character_folders = metaval_folders
        self.rotations = config.get('rotations', [0])

    else:
        raise ValueError('Unrecognized data source')
174 |
def make_data_tensor(self, train=True):
    """Build the queue-based image pipeline and return batched tensors.

    Args:
        train: if True, sample tasks from the meta-train folders, otherwise
            from the meta-validation folders.

    Returns:
        (all_image_batches, all_label_batches): images and one-hot labels,
        each stacked over the ``batch_size`` tasks of a meta-batch.
    """
    # number of tasks, not number of meta-iterations. (divide by metabatch size to measure)
    if train:
        folders, num_total_batches = self.metatrain_character_folders, 200000
    else:
        folders, num_total_batches = self.metaval_character_folders, 600

    # make list of files
    print('Generating filenames')
    all_filenames = []
    for _ in range(num_total_batches):
        task_folders = random.sample(folders, self.num_classes)
        random.shuffle(task_folders)
        labels_and_images = get_images(task_folders, range(self.num_classes),
                                       nb_samples=self.num_samples_per_class, shuffle=False)
        # make sure the above isn't randomized order
        labels = [pair[0] for pair in labels_and_images]
        all_filenames.extend([pair[1] for pair in labels_and_images])

    # make queue for tensorflow to read from
    filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames), shuffle=False)
    print('Generating image processing ops')
    image_reader = tf.WholeFileReader()
    _, image_file = image_reader.read(filename_queue)
    height, width = self.img_size[0], self.img_size[1]
    is_color = FLAGS.datasource == 'miniimagenet'
    if is_color:
        image = tf.image.decode_jpeg(image_file, channels=3)
        image.set_shape((height, width, 3))
    else:
        image = tf.image.decode_png(image_file)
        image.set_shape((height, width, 1))
    image = tf.reshape(image, [self.dim_input])
    image = tf.cast(image, tf.float32) / 255.0
    if not is_color:
        # invert the single-channel images
        image = 1.0 - image

    num_preprocess_threads = 1
    min_queue_examples = 256
    examples_per_batch = self.num_classes * self.num_samples_per_class
    batch_image_size = self.batch_size * examples_per_batch
    print('Batching images')
    images = tf.train.batch(
        [image],
        batch_size=batch_image_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_image_size,
    )

    all_image_batches, all_label_batches = [], []
    print('Manipulating image data to be right shape')
    for task_idx in range(self.batch_size):
        start = task_idx * examples_per_batch
        image_batch = images[start:start + examples_per_batch]

        if FLAGS.datasource == 'omniglot':
            # omniglot augments the dataset by rotating digits to create new classes
            # get rotation per class (e.g. 0,1,2,0,0 if there are 5 classes)
            rotations = tf.multinomial(tf.log([[1., 1., 1., 1.]]), self.num_classes)
        # `labels` comes from the last task sampled in the filename loop above.
        label_batch = tf.convert_to_tensor(labels)

        task_images, task_labels = [], []
        for k in range(self.num_samples_per_class):
            class_idxs = tf.random_shuffle(tf.range(0, self.num_classes))

            # position of the k-th sample of each (shuffled) class
            true_idxs = class_idxs * self.num_samples_per_class + k

            gathered = tf.gather(image_batch, true_idxs)
            if FLAGS.datasource == 'omniglot':  # and FLAGS.train:
                rotated = []
                for ind in range(self.num_classes):
                    as_image = tf.reshape(gathered[ind], [self.img_size[0], self.img_size[1], 1])
                    turned = tf.image.rot90(as_image, k=tf.cast(rotations[0, class_idxs[ind]], tf.int32))
                    rotated.append(tf.reshape(turned, (self.dim_input,)))
                gathered = tf.stack(rotated)
            task_images.append(gathered)
            task_labels.append(tf.gather(label_batch, true_idxs))

        # has shape [self.num_classes*self.num_samples_per_class, self.dim_input]
        all_image_batches.append(tf.concat(task_images, 0))
        all_label_batches.append(tf.concat(task_labels, 0))

    all_image_batches = tf.stack(all_image_batches)
    all_label_batches = tf.one_hot(tf.stack(all_label_batches), self.num_classes)
    return all_image_batches, all_label_batches
256 |
def make_data_tensor_multidataset(self, train=True):
    """Build queue-fed image and label tensors for multi-dataset few-shot tasks.

    For every task one dataset is selected, ``self.num_classes`` class folders
    are sampled from it and ``self.num_samples_per_class`` image files per
    class are listed. The files are then read through a TensorFlow filename
    queue, decoded, batched and re-ordered per task.

    Args:
        train: when True, sample from ``self.metatrain_character_folders``;
            otherwise sample ``FLAGS.num_test_task`` tasks from
            ``self.metaval_character_folders``.

    Returns:
        Tuple ``(all_image_batches, all_label_batches)``. Images are stacked
        over ``self.batch_size`` tasks, each task holding
        ``num_classes * num_samples_per_class`` flattened images; labels are
        one-hot encoded with depth ``self.num_classes``.
    """
    if train:
        dataset_folders = self.metatrain_character_folders
        # This counts tasks, not meta-iterations (divide by the meta-batch size for those).
        task_count = 140000 if FLAGS.update_batch_size == 10 else 200000
    else:
        dataset_folders = self.metaval_character_folders
        task_count = FLAGS.num_test_task

    print('Generating filenames')
    # At test time a single dataset can be pinned through FLAGS.test_dataset.
    pin_dataset = (not FLAGS.train) and FLAGS.test_dataset != -1
    all_filenames = []
    for _ in range(task_count):
        # Drawn unconditionally, so the NumPy RNG stream does not depend on pinning.
        dataset_id = np.random.randint(4)
        if pin_dataset:
            dataset_id = FLAGS.test_dataset
        task_folders = random.sample(dataset_folders[dataset_id], self.num_classes)
        random.shuffle(task_folders)
        # shuffle=False: the index arithmetic further down presumably relies on
        # get_images returning the samples grouped by class -- verify against get_images.
        task_items = get_images(task_folders, range(self.num_classes),
                                nb_samples=self.num_samples_per_class, shuffle=False)
        # `labels` deliberately survives the loop: the last task's label layout
        # is reused for every task when the label tensors are built below.
        labels = [item[0] for item in task_items]
        all_filenames.extend(item[1] for item in task_items)

    # Queue that TensorFlow reads the image files from, in listing order.
    filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames), shuffle=False)
    print('Generating image processing ops')
    _, image_file = tf.WholeFileReader().read(filename_queue)
    is_rgb = FLAGS.datasource in ['miniimagenet', 'multidataset']
    if is_rgb:
        image = tf.image.decode_jpeg(image_file, channels=3)
    else:
        image = tf.image.decode_png(image_file)
    image.set_shape((self.img_size[0], self.img_size[1], 3 if is_rgb else 1))
    # Flatten and rescale pixel values by 1/255.
    image = tf.cast(tf.reshape(image, [self.dim_input]), tf.float32) / 255.0
    if not is_rgb:
        image = 1.0 - image  # invert

    num_preprocess_threads = 1  # TODO - enable this to be set to >1
    min_queue_examples = 256
    examples_per_task = self.num_classes * self.num_samples_per_class
    images_per_meta_batch = self.batch_size * examples_per_task
    print('Batching images')
    images = tf.train.batch(
        [image],
        batch_size=images_per_meta_batch,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * images_per_meta_batch,
    )

    print('Manipulating image data to be right shape')
    all_image_batches, all_label_batches = [], []
    for task in range(self.batch_size):
        first = task * examples_per_task
        image_batch = images[first:first + examples_per_task]
        label_batch = tf.convert_to_tensor(labels)
        task_images, task_labels = [], []
        for k in range(self.num_samples_per_class):
            # Pick the k-th sample of every class, visiting the classes in a random order.
            class_order = tf.random_shuffle(tf.range(0, self.num_classes))
            picked = class_order * self.num_samples_per_class + k
            task_images.append(tf.gather(image_batch, picked))
            task_labels.append(tf.gather(label_batch, picked))
        # Each concat has shape [self.num_classes*self.num_samples_per_class, self.dim_input]
        # for images (and the matching flat layout for labels).
        all_image_batches.append(tf.concat(task_images, 0))
        all_label_batches.append(tf.concat(task_labels, 0))
    all_image_batches = tf.stack(all_image_batches)
    all_label_batches = tf.one_hot(tf.stack(all_label_batches), self.num_classes)
    return all_image_batches, all_label_batches
333 |
def make_data_tensor_multidataset_leave_one_out(self, train=True):
    """Build queue-fed image and label tensors for the leave-one-out setting.

    Training tasks are drawn from one of three datasets held in
    ``self.metatrain_character_folders``; evaluation tasks are sampled
    directly from ``self.metaval_character_folders``.

    Args:
        train: when True, build 200000 training tasks; otherwise build
            ``FLAGS.num_test_task`` evaluation tasks.

    Returns:
        Tuple ``(all_image_batches, all_label_batches)``. Images are stacked
        over ``self.batch_size`` tasks, each task holding
        ``num_classes * num_samples_per_class`` flattened images; labels are
        one-hot encoded with depth ``self.num_classes``.
    """
    if train:
        folders = self.metatrain_character_folders
        # This counts tasks, not meta-iterations (divide by the meta-batch size for those).
        task_count = 200000
    else:
        folders = self.metaval_character_folders
        task_count = FLAGS.num_test_task

    print('Generating filenames')
    all_filenames = []
    for _ in range(task_count):
        # Training indexes one of three datasets; evaluation samples from `folders` itself.
        candidates = folders[np.random.randint(3)] if train else folders
        task_folders = random.sample(candidates, self.num_classes)
        random.shuffle(task_folders)
        # shuffle=False: the index arithmetic further down presumably relies on
        # get_images returning the samples grouped by class -- verify against get_images.
        task_items = get_images(task_folders, range(self.num_classes),
                                nb_samples=self.num_samples_per_class, shuffle=False)
        # `labels` deliberately survives the loop: the last task's label layout
        # is reused for every task when the label tensors are built below.
        labels = [item[0] for item in task_items]
        all_filenames.extend(item[1] for item in task_items)

    # Queue that TensorFlow reads the image files from, in listing order.
    filename_queue = tf.train.string_input_producer(tf.convert_to_tensor(all_filenames), shuffle=False)
    print('Generating image processing ops')
    _, image_file = tf.WholeFileReader().read(filename_queue)
    is_rgb = FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']
    if is_rgb:
        image = tf.image.decode_jpeg(image_file, channels=3)
    else:
        image = tf.image.decode_png(image_file)
    image.set_shape((self.img_size[0], self.img_size[1], 3 if is_rgb else 1))
    # Flatten and rescale pixel values by 1/255.
    image = tf.cast(tf.reshape(image, [self.dim_input]), tf.float32) / 255.0
    if not is_rgb:
        image = 1.0 - image  # invert

    num_preprocess_threads = 1
    min_queue_examples = 256
    examples_per_task = self.num_classes * self.num_samples_per_class
    images_per_meta_batch = self.batch_size * examples_per_task
    print('Batching images')
    images = tf.train.batch(
        [image],
        batch_size=images_per_meta_batch,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * images_per_meta_batch,
    )

    print('Manipulating image data to be right shape')
    all_image_batches, all_label_batches = [], []
    for task in range(self.batch_size):
        first = task * examples_per_task
        image_batch = images[first:first + examples_per_task]
        label_batch = tf.convert_to_tensor(labels)
        task_images, task_labels = [], []
        for k in range(self.num_samples_per_class):
            # Pick the k-th sample of every class, visiting the classes in a random order.
            class_order = tf.random_shuffle(tf.range(0, self.num_classes))
            picked = class_order * self.num_samples_per_class + k
            task_images.append(tf.gather(image_batch, picked))
            task_labels.append(tf.gather(label_batch, picked))
        # Each concat has shape [self.num_classes*self.num_samples_per_class, self.dim_input]
        # for images (and the matching flat layout for labels).
        all_image_batches.append(tf.concat(task_images, 0))
        all_label_batches.append(tf.concat(task_labels, 0))
    all_image_batches = tf.stack(all_image_batches)
    all_label_batches = tf.one_hot(tf.stack(all_label_batches), self.num_classes)
    return all_image_batches, all_label_batches
407 |
def generate_sinusoid_batch(self, train=True, input_idx=None):
    """Sample a batch of sinusoid regression tasks ``amp * sin(freq * x - phase)``.

    Args:
        train: not used by this method.
        input_idx: when given, the inputs from position ``input_idx`` onward
            (first input dimension only) are replaced by an evenly spaced
            grid over the input range; used during qualitative testing,
            where it is the number of examples used for the gradient update.

    Returns:
        Tuple ``(init_inputs, outputs, amp, phase)`` with inputs of shape
        ``[batch_size, num_samples_per_class, dim_input]``, outputs of shape
        ``[batch_size, num_samples_per_class, dim_output]`` and the per-task
        amplitudes and phases.
    """
    n_tasks = self.batch_size
    n_points = self.num_samples_per_class
    x_lo, x_hi = self.input_range[0], self.input_range[1]

    # The draw order (amp, freq, phase, then per-task inputs) defines the RNG stream.
    amp = np.random.uniform(self.amp_range[0], self.amp_range[1], [n_tasks])
    freq = np.random.uniform(self.freq_range[0], self.freq_range[1], [n_tasks])
    phase = np.random.uniform(self.phase_range[0], self.phase_range[1], [n_tasks])

    init_inputs = np.zeros([n_tasks, n_points, self.dim_input])
    outputs = np.zeros([n_tasks, n_points, self.dim_output])
    for task in range(n_tasks):
        init_inputs[task] = np.random.uniform(x_lo, x_hi, [n_points, 1])
        if input_idx is not None:
            # Overwrites the tail of every task's inputs with the same grid.
            init_inputs[:, input_idx:, 0] = np.linspace(x_lo, x_hi,
                                                        num=n_points - input_idx, retstep=False)
        outputs[task] = amp[task] * np.sin(freq[task] * init_inputs[task] - phase[task])
    return init_inputs, outputs, amp, phase
424 |
def generate_mixture_batch(self, train=True, input_idx=None, DRAW_PLOTS=True):
    """Sample a batch of regression tasks from four function families.

    Every task picks one family at random: 0 = sinusoid, 1 = linear,
    2 = quadratic, 3 = cubic. When ``FLAGS.train`` is False and
    ``FLAGS.test_dataset`` is not -1, that flag value is used as the family
    for all tasks instead.

    Args:
        train: not used by this method.
        input_idx: not used by this method.
        DRAW_PLOTS: not used by this method.

    Returns:
        Tuple ``(init_inputs, outputs, funcs_params, sel_set)``: inputs of
        shape ``[batch_size, num_samples_per_class, dim_input]``, outputs of
        shape ``[batch_size, num_samples_per_class, dim_output]``, a dict of
        the sampled coefficients of every family, and the family index
        chosen for each task.
    """
    n_tasks = self.batch_size
    n_points = self.num_samples_per_class

    # Coefficients of all four families are drawn for every task, in this
    # fixed order, regardless of which family ends up being used.
    # sin
    amp = np.random.uniform(0.1, 5.0, size=n_tasks)
    phase = np.random.uniform(0., 2 * np.pi, size=n_tasks)
    freq = np.random.uniform(0.8, 1.2, size=n_tasks)
    # linear
    A = np.random.uniform(-3.0, 3.0, size=n_tasks)
    b = np.random.uniform(-3.0, 3.0, size=n_tasks)
    # quadratic
    A_q = np.random.uniform(-0.2, 0.2, size=n_tasks)
    b_q = np.random.uniform(-2.0, 2.0, size=n_tasks)
    c_q = np.random.uniform(-3.0, 3.0, size=n_tasks)
    # cubic
    A_c = np.random.uniform(-0.1, 0.1, size=n_tasks)
    b_c = np.random.uniform(-0.2, 0.2, size=n_tasks)
    c_c = np.random.uniform(-2.0, 2.0, size=n_tasks)
    d_c = np.random.uniform(-3.0, 3.0, size=n_tasks)

    # Family index -> function of (task index, inputs). Note that the sinusoid
    # adds `phase` as a vertical offset outside the sine.
    families = {
        0: lambda t, x: amp[t] * np.sin(freq[t] * x) + phase[t],
        1: lambda t, x: A[t] * x + b[t],
        2: lambda t, x: A_q[t] * np.square(x) + b_q[t] * x + c_q[t],
        3: lambda t, x: A_c[t] * np.power(x, 3) + b_c[t] * np.square(x) + c_c[t] * x + d_c[t],
    }

    sel_set = np.zeros(n_tasks)
    init_inputs = np.zeros([n_tasks, n_points, self.dim_input])
    outputs = np.zeros([n_tasks, n_points, self.dim_output])

    pin_family = (not FLAGS.train) and FLAGS.test_dataset != -1
    for task in range(n_tasks):
        xs = np.random.uniform(self.input_range[0], self.input_range[1],
                               size=(n_points, self.dim_input))
        init_inputs[task] = xs
        # Drawn unconditionally, so the RNG stream does not depend on pinning.
        family = np.random.randint(4)
        if pin_family:
            family = FLAGS.test_dataset
        curve = families.get(family)
        if curve is not None:
            # An unknown family index leaves this task's outputs at zero.
            outputs[task] = curve(task, xs)
        sel_set[task] = family

    funcs_params = {'amp': amp, 'phase': phase, 'freq': freq, 'A': A, 'b': b, 'A_q': A_q, 'c_q': c_q, 'b_q': b_q,
                    'A_c': A_c, 'b_c': b_c, 'c_c': c_c, 'd_c': d_c}
    return init_inputs, outputs, funcs_params, sel_set
475 |
--------------------------------------------------------------------------------
/image_embedding.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.platform import flags
3 |
4 | FLAGS = flags.FLAGS
5 |
6 |
class ImageEmbedding(object):
    """Small conv net (two conv blocks plus two dense layers) that embeds images.

    The final dense layer has 64 units, so ``model`` returns one 64-dimensional
    vector per input image.
    """

    def __init__(self, hidden_num, channels, conv_initializer, k=5):
        """Create the two convolution kernels.

        Args:
            hidden_num: number of output filters of both convolutions.
            channels: number of channels of the input images.
            conv_initializer: initializer used for both kernels.
            k: spatial size of the (square) convolution kernels.
        """
        self.hidden_num = hidden_num
        self.channels = channels
        self.activation = tf.nn.relu
        with tf.variable_scope('image_embedding', reuse=tf.AUTO_REUSE):
            self.conv1_kernel = tf.get_variable(
                'conv1_kernel', [k, k, self.channels, self.hidden_num], initializer=conv_initializer)
            self.conv2_kernel = tf.get_variable(
                'conv2_kernel', [k, k, self.hidden_num, self.hidden_num], initializer=conv_initializer)

    def model(self, images):
        """Embed a batch of images.

        Args:
            images: image tensor fed to ``tf.nn.conv2d``; its first (batch)
                dimension must be statically known because it is used to
                flatten the feature maps.

        Returns:
            Tensor with one 64-dimensional embedding per image.
        """
        unit_stride = [1, 1, 1, 1]
        pool_window = [1, 3, 3, 1]
        pool_stride = [1, 2, 2, 1]
        lrn_params = dict(bias=1.0, alpha=0.001 / 9.0, beta=0.75)

        # Block 1: conv -> relu -> max-pool -> local response normalisation.
        conv1 = tf.nn.relu(tf.nn.conv2d(images, self.conv1_kernel, unit_stride, padding='SAME'),
                           name='conv1_post_activation')
        pool1 = tf.nn.max_pool(conv1, ksize=pool_window, strides=pool_stride,
                               padding='SAME', name='pool1')
        norm1 = tf.nn.lrn(pool1, 4, name='norm1', **lrn_params)

        # Block 2: conv -> relu -> local response normalisation -> max-pool.
        conv2_act = tf.nn.relu(tf.nn.conv2d(norm1, self.conv2_kernel, unit_stride, padding='SAME'),
                               name='conv2_post_activation')
        norm2 = tf.nn.lrn(conv2_act, 4, name='norm2', **lrn_params)
        pool2 = tf.nn.max_pool(norm2, ksize=pool_window, strides=pool_stride,
                               padding='SAME', name='pool2')

        with tf.variable_scope('local3', reuse=tf.AUTO_REUSE):
            # Flatten everything except the (static) batch dimension.
            batch_dim = images.get_shape().as_list()[0]
            flat = tf.reshape(pool2, [batch_dim, -1])
            dim = flat.get_shape()[1].value
            weight3 = tf.get_variable(name='weight', shape=[dim, 384],
                                      initializer=tf.truncated_normal_initializer(stddev=0.04))
            biases3 = tf.get_variable(name='biases', shape=[384],
                                      initializer=tf.constant_initializer(0.1))
            local3 = tf.nn.relu(tf.matmul(flat, weight3) + biases3, name='local3_dense')

        with tf.variable_scope('local4', reuse=tf.AUTO_REUSE):
            weight4 = tf.get_variable(name='weight', shape=[384, 64],
                                      initializer=tf.truncated_normal_initializer(stddev=0.04))
            biases4 = tf.get_variable(name='biases', shape=[64],
                                      initializer=tf.constant_initializer(0.1))
            local4 = tf.nn.relu(tf.matmul(local3, weight4) + biases4, name='local4_dense')
        return local4
50 |
--------------------------------------------------------------------------------
/lstm_tree.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 |
3 | from tensorflow.python.platform import flags
4 |
5 | FLAGS = flags.FLAGS
6 |
7 |
class TreeLSTM(object):
    """Hierarchical soft-clustering tree applied to a task embedding.

    The tree has ``FLAGS.cluster_layer_0`` leaf clusters,
    ``FLAGS.cluster_layer_1`` second-layer clusters, an optional third layer
    of ``FLAGS.cluster_layer_2`` clusters (disabled when that flag is -1) and
    one root transform. ``FLAGS.tree_type`` selects how the gates between
    layers are scored: 1 = linear score, 2 = negative squared distance.
    """

    def __init__(self, tree_hidden_dim, input_dim):
        """Create all tree variables with ``tf.get_variable``.

        Args:
            tree_hidden_dim: width of every hidden representation in the tree.
            input_dim: width of the input embedding.
        """
        self.input_dim = input_dim
        self.tree_hidden_dim = tree_hidden_dim
        # Only the "u" leaf parameters are populated; the i/o lists stay empty.
        self.leaf_weight_i, self.leaf_weight_o, self.leaf_weight_u = [], [], []
        self.leaf_bias_i, self.leaf_bias_o, self.leaf_bias_u = [], [], []
        for i in range(FLAGS.cluster_layer_0):
            self.leaf_weight_u.append(
                tf.get_variable(name='{}_leaf_weight_u'.format(i), shape=(input_dim, tree_hidden_dim)))
            self.leaf_bias_u.append(tf.get_variable(name='{}_leaf_bias_u'.format(i), shape=(1, tree_hidden_dim)))

        self.no_leaf_weight_i, self.no_leaf_weight_o, self.no_leaf_weight_u, self.no_leaf_weight_f = [], [], [], []
        self.no_leaf_bias_i, self.no_leaf_bias_o, self.no_leaf_bias_u, self.no_leaf_bias_f = [], [], [], []
        for i in range(FLAGS.cluster_layer_1):
            # The gate weight is a projection vector (type 1) or a centre (type 2).
            if FLAGS.tree_type == 1:
                self.no_leaf_weight_i.append(
                    tf.get_variable(name='{}_no_leaf_weight_i'.format(i), shape=(tree_hidden_dim, 1)))
            elif FLAGS.tree_type == 2:
                self.no_leaf_weight_i.append(
                    tf.get_variable(name='{}_no_leaf_weight_i'.format(i), shape=(1, tree_hidden_dim)))
            self.no_leaf_weight_u.append(
                tf.get_variable(name='{}_no_leaf_weight_u'.format(i), shape=(tree_hidden_dim, tree_hidden_dim)))

            self.no_leaf_bias_i.append(tf.get_variable(name='{}_no_leaf_bias_i'.format(i), shape=(1, 1)))
            self.no_leaf_bias_u.append(tf.get_variable(name='{}_no_leaf_bias_u'.format(i), shape=(1, tree_hidden_dim)))

        if FLAGS.cluster_layer_2 != -1:
            self.no_leaf_weight_i_l2, self.no_leaf_weight_u_l2 = [], []
            self.no_leaf_bias_i_l2, self.no_leaf_bias_u_l2 = [], []
            for i in range(FLAGS.cluster_layer_2):
                if FLAGS.tree_type == 1:
                    self.no_leaf_weight_i_l2.append(
                        tf.get_variable(name='{}_no_leaf_weight_i_l2'.format(i), shape=(tree_hidden_dim, 1)))
                elif FLAGS.tree_type == 2:
                    self.no_leaf_weight_i_l2.append(
                        tf.get_variable(name='{}_no_leaf_weight_i_l2'.format(i), shape=(1, tree_hidden_dim)))
                self.no_leaf_weight_u_l2.append(
                    tf.get_variable(name='{}_no_leaf_weight_u_l2'.format(i), shape=(tree_hidden_dim, tree_hidden_dim)))

                self.no_leaf_bias_i_l2.append(tf.get_variable(name='{}_no_leaf_bias_i_l2'.format(i), shape=(1, 1)))
                self.no_leaf_bias_u_l2.append(
                    tf.get_variable(name='{}_no_leaf_bias_u_l2'.format(i), shape=(1, tree_hidden_dim)))

        # NOTE: the root variable names embed whatever value `i` was left at by
        # the last loop above. This is kept on purpose: renaming the variables
        # would make existing checkpoints unloadable.
        self.root_weight_u = tf.get_variable(name='{}_root_weight_u'.format(i),
                                             shape=(tree_hidden_dim, tree_hidden_dim))

        self.root_bias_u = tf.get_variable(name='{}_root_bias_u'.format(i), shape=(1, tree_hidden_dim))

        self.cluster_center = []
        for i in range(FLAGS.cluster_layer_0):
            self.cluster_center.append(tf.get_variable(name='{}_cluster_center'.format(i),
                                                       shape=(1, input_dim)))

        self.cluster_layer_0 = FLAGS.cluster_layer_0
        self.cluster_layer_1 = FLAGS.cluster_layer_1
        self.cluster_layer_2 = FLAGS.cluster_layer_2

    def model(self, inputs):
        """Propagate an embedding from the leaves to the root of the tree.

        Args:
            inputs: embedding tensor; the unreduced ``tf.reduce_sum`` calls
                below suggest a single row of shape ``(1, input_dim)`` --
                TODO confirm against the caller.

        Returns:
            Tuple ``(root_c, root_c)``: the same root representation twice.

        Raises:
            ValueError: if ``FLAGS.datasource`` has no kernel bandwidth
                defined here.
        """
        if FLAGS.datasource == 'multidataset' or FLAGS.datasource == 'miniimagenet' or FLAGS.datasource == 'multidataset_leave_one_out':
            sigma = 10.0
        elif FLAGS.datasource in ['sinusoid', 'mixture']:
            sigma = 2.0
        else:
            # Previously `sigma` was simply left unbound and failed on first use.
            raise ValueError('TreeLSTM has no sigma defined for datasource: {}'.format(FLAGS.datasource))

        # Normaliser of the soft assignment: sum of the Gaussian kernels
        # between the input and every leaf cluster centre.
        for idx in range(self.cluster_layer_0):
            if idx == 0:
                all_value = tf.exp(-tf.reduce_sum(tf.square(inputs - self.cluster_center[idx])) / (2.0 * sigma))
            else:
                all_value += tf.exp(-tf.reduce_sum(tf.square(inputs - self.cluster_center[idx])) / (2.0 * sigma))

        # Leaf layer: each cluster's transformed input, weighted by its soft assignment.
        c_leaf = []
        for idx in range(self.cluster_layer_0):
            assignment_idx = tf.exp(
                -tf.reduce_sum(tf.square(inputs - self.cluster_center[idx])) / (2.0 * sigma)) / all_value
            value_u = tf.tanh(tf.matmul(inputs, self.leaf_weight_u[idx]) + self.leaf_bias_u[idx])
            value_c = assignment_idx * value_u
            c_leaf.append(value_c)

        # Second layer: every leaf is softly routed to the layer-1 clusters.
        c_no_leaf = []
        for idx in range(self.cluster_layer_0):
            input_gate = []
            for idx_layer_1 in range(self.cluster_layer_1):
                if FLAGS.tree_type == 1:
                    input_gate.append(
                        tf.matmul(c_leaf[idx], self.no_leaf_weight_i[idx_layer_1]) + self.no_leaf_bias_i[idx_layer_1])
                elif FLAGS.tree_type == 2:
                    input_gate.append(
                        -(tf.reduce_sum(tf.square(c_leaf[idx] - self.no_leaf_weight_i[idx_layer_1]), keepdims=True) +
                          self.no_leaf_bias_i[idx_layer_1]) / (
                            2.0))
            input_gate = tf.nn.softmax(tf.concat(input_gate, axis=0), axis=0)

            c_no_leaf_temp = []
            for idx_layer_1 in range(self.cluster_layer_1):
                no_leaf_value_u = tf.tanh(
                    tf.matmul(c_leaf[idx], self.no_leaf_weight_u[idx_layer_1]) + self.no_leaf_bias_u[idx_layer_1])
                c_no_leaf_temp.append(input_gate[idx_layer_1] * no_leaf_value_u)
            c_no_leaf.append(tf.concat(c_no_leaf_temp, axis=0))

        # Stack over leaves, move the layer-1 axis first, then sum the leaf contributions.
        c_no_leaf = tf.stack(c_no_leaf, axis=0)
        c_no_leaf = tf.transpose(c_no_leaf, perm=[1, 0, 2])
        c_no_leaf = tf.reduce_sum(c_no_leaf, axis=1, keepdims=True)

        if FLAGS.cluster_layer_2 != -1:
            # Optional third layer: same routing, from layer-1 to layer-2 clusters.
            c_no_leaf_l2 = []

            for idx_l2 in range(self.cluster_layer_1):
                input_gate_l2 = []
                for idx_layer_2 in range(self.cluster_layer_2):
                    if FLAGS.tree_type == 1:
                        input_gate_l2.append(
                            tf.matmul(c_no_leaf[idx_l2], self.no_leaf_weight_i_l2[idx_layer_2]) +
                            self.no_leaf_bias_i_l2[
                                idx_layer_2])
                    elif FLAGS.tree_type == 2:
                        # Uses this layer's own bias, indexed by the current
                        # layer-2 cluster (was a layer-1 bias picked with a
                        # stale loop index).
                        input_gate_l2.append(
                            -(tf.reduce_sum(tf.square(c_no_leaf[idx_l2] - self.no_leaf_weight_i_l2[idx_layer_2]),
                                            keepdims=True) + self.no_leaf_bias_i_l2[idx_layer_2]) / (2.0))
                input_gate_l2 = tf.nn.softmax(tf.concat(input_gate_l2, axis=0), axis=0)

                c_no_leaf_temp_l2 = []
                for idx_layer_2 in range(self.cluster_layer_2):
                    no_leaf_value_u_l2 = tf.tanh(
                        tf.matmul(c_no_leaf[idx_l2], self.no_leaf_weight_u_l2[idx_layer_2]) + self.no_leaf_bias_u_l2[
                            idx_layer_2])
                    c_no_leaf_temp_l2.append(input_gate_l2[idx_layer_2] * no_leaf_value_u_l2)
                c_no_leaf_l2.append(tf.concat(c_no_leaf_temp_l2, axis=0))

            c_no_leaf_l2 = tf.stack(c_no_leaf_l2, axis=0)
            c_no_leaf_l2 = tf.transpose(c_no_leaf_l2, perm=[1, 0, 2])
            c_no_leaf_l2 = tf.reduce_sum(c_no_leaf_l2, axis=1, keepdims=True)

        # Root: transform every top-layer cluster and sum the results.
        root_c = []

        if FLAGS.cluster_layer_2 != -1:
            for idx in range(self.cluster_layer_2):
                root_c.append(tf.tanh(tf.matmul(c_no_leaf_l2[idx], self.root_weight_u) + self.root_bias_u))
        else:
            for idx in range(self.cluster_layer_1):
                root_c.append(tf.tanh(tf.matmul(c_no_leaf[idx], self.root_weight_u) + self.root_bias_u))

        root_c = tf.reduce_sum(tf.concat(root_c, axis=0), axis=0, keepdims=True)

        return root_c, root_c
152 |
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import csv
2 | import numpy as np
3 | import pickle
4 | import random
5 | import matplotlib.pyplot as plt
6 | import tensorflow as tf
7 |
8 | tf.set_random_seed(1234)
9 | from data_generator import DataGenerator
10 | from maml import MAML
11 | from tensorflow.python.platform import flags
12 |
# Command-line configuration shared by every module through flags.FLAGS.
FLAGS = flags.FLAGS

## Dataset/method options
flags.DEFINE_string('datasource', 'sinusoid', 'sinusoid or omniglot or miniimagenet or mixture or multidataset or multidataset_leave_one_out')
# main() asserts that this is > -1 when datasource is 'multidataset_leave_one_out'.
flags.DEFINE_integer('leave_one_out_id',-1,'id of leave one out')
# When not training and this is not -1, the data generators use it as the
# dataset / function-family index instead of drawing one at random.
flags.DEFINE_integer('test_dataset', -1,
                     'which dataset to be test: 0: bird, 1: texture, 2: aircraft, 3: fungi, -1 is test all')
flags.DEFINE_integer('num_classes', 5, 'number of classes used in classification (e.g. 5-way classification).')
# Also the number of evaluation tasks generated by the multidataset tensor builders.
flags.DEFINE_integer('num_test_task', 1000, 'number of test tasks.')
flags.DEFINE_integer('test_epoch', -1, 'test epoch, only work when test start')

## Training options
flags.DEFINE_integer('pretrain_iterations', 0, 'number of pre-training iterations.')
flags.DEFINE_integer('metatrain_iterations', 15000,
                     'number of metatraining iterations.')  # 15k for omniglot, 50k for sinusoid
# main() forces this to 1 when FLAGS.train is False.
flags.DEFINE_integer('meta_batch_size', 25, 'number of tasks sampled per meta-update')
flags.DEFINE_float('meta_lr', 0.001, 'the base learning rate of the generator')
flags.DEFINE_integer('update_batch_size', 5,
                     'number of examples used for inner gradient update (K for K-shot learning).')
flags.DEFINE_integer('update_batch_size_eval', 10,
                     'number of examples used for inner gradient test (K for K-shot learning).')
flags.DEFINE_float('update_lr', 1e-3, 'step size alpha for inner gradient update.')  # 0.1 for omniglot
flags.DEFINE_integer('num_updates', 1, 'number of inner gradient updates during training.')
flags.DEFINE_integer('num_groups', 1, 'number of groups.')
flags.DEFINE_integer('fix_embedding_sample', -1,
                     'if the fix_embedding sample is -1, all samples are used for embedding. Otherwise, specific samples are used')
## Model options
flags.DEFINE_string('norm', 'batch_norm', 'batch_norm, layer_norm, or None')
flags.DEFINE_integer('hidden_dim', 40, 'output dimension of task embedding')
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omiglot.')
flags.DEFINE_bool('conv', True, 'whether or not to use a convolutional network, only applicable in some cases')
flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
flags.DEFINE_bool('stop_grad', False, 'if True, do not use second derivatives in meta-optimization (for speed)')
flags.DEFINE_float('emb_loss_weight', 0.0, 'the weight of autoencoder')
flags.DEFINE_string('emb_type', 'sigmoid', 'sigmoid')
flags.DEFINE_bool('no_val', False, 'if true, there are no validation set of Omniglot dataset')
# TreeLSTM scores its gates with a linear projection for type 1 and with a
# negative squared distance for type 2.
flags.DEFINE_integer('tree_type', 1, 'select the tree type: 1 or 2')
flags.DEFINE_integer('task_embedding_num_filters', 32, 'number of filters for task embedding')
flags.DEFINE_string('task_embedding_type', 'rnn', 'rnn or mean')

## clustering information
flags.DEFINE_integer('cluster_layer_0', 4, 'number of clusters in the first layer')
flags.DEFINE_integer('cluster_layer_1', 2, 'number of clusters in the second layer')
# -1 disables the third layer in TreeLSTM.
flags.DEFINE_integer('cluster_layer_2', -1, 'number of clusters in the third layer')

## Logging, saving, and testing options
flags.DEFINE_bool('log', True, 'if false, do not log summaries, for debugging code.')
flags.DEFINE_string('logdir', '/tmp/data', 'directory for summaries and checkpoints.')
# NOTE(review): the default is a machine-specific path; pass --datadir explicitly.
flags.DEFINE_string('datadir', '/home/huaxiuyao/Data/', 'directory for datasets.')
flags.DEFINE_bool('resume', True, 'resume training if there is a model available')
flags.DEFINE_bool('train', True, 'True to train, False to test.')
flags.DEFINE_bool('test_set', False, 'Set to true to test on the the test set, False for the validation set.')
flags.DEFINE_integer('train_update_batch_size', -1,
                     'number of examples used for gradient update during training (use if you want to test with a different number).')
flags.DEFINE_float('train_update_lr', -1,
                   'value of inner gradient step step during training. (use if you want to test with a different value)')  # 0.1 for omniglot
69 |
70 |
def train(model, saver, sess, exp_string, data_generator, resume_itr=0):
    """Run pre-training and meta-training with periodic logging and checkpoints.

    Args:
        model: model object exposing the training ops (``pretrain_op``,
            ``metatrain_op``), ``summ_op``, the loss/accuracy tensors read
            below, its input placeholders and a ``classification`` flag.
        saver: object whose ``save(sess, path)`` writes a checkpoint.
        sess: session used to run the ops; its ``graph`` is passed to the
            summary writer.
        exp_string: sub-directory of ``FLAGS.logdir`` for summaries and
            checkpoints.
        data_generator: provides ``num_classes`` and, for generated data,
            a ``generate()`` method; when ``generate`` is absent no feed
            dict is supplied to the session.
        resume_itr: iteration to start from.
    """
    SUMMARY_INTERVAL = 100
    SAVE_INTERVAL = 1000
    if FLAGS.datasource in ['sinusoid', 'mixture']:
        PRINT_INTERVAL = 1000
        TEST_PRINT_INTERVAL = PRINT_INTERVAL * 5
    else:
        PRINT_INTERVAL = 100
        TEST_PRINT_INTERVAL = PRINT_INTERVAL * 10

    if FLAGS.log:
        train_writer = tf.summary.FileWriter(FLAGS.logdir + '/' + exp_string, sess.graph)
    print('Done initializing, starting training.')

    prelosses, postlosses, embedlosses = [], [], []

    num_classes = data_generator.num_classes  # for classification, 1 otherwise

    # Keeps `itr` bound for the final checkpoint when the range below is empty
    # (e.g. resuming a run that has already finished).
    itr = resume_itr
    for itr in range(resume_itr, FLAGS.pretrain_iterations + FLAGS.metatrain_iterations):
        feed_dict = {}
        if 'generate' in dir(data_generator):
            if FLAGS.datasource == 'sinusoid':
                batch_x, batch_y, amp, phase = data_generator.generate()
            elif FLAGS.datasource == 'mixture':
                batch_x, batch_y, para_func, sel_set = data_generator.generate()

            # The first num_classes * update_batch_size samples feed the inner
            # update (a); the remaining samples feed the meta-objective (b).
            inputa = batch_x[:, :num_classes * FLAGS.update_batch_size, :]
            labela = batch_y[:, :num_classes * FLAGS.update_batch_size, :]
            inputb = batch_x[:, num_classes * FLAGS.update_batch_size:, :]
            labelb = batch_y[:, num_classes * FLAGS.update_batch_size:, :]
            feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb}

        if itr < FLAGS.pretrain_iterations:
            input_tensors = [model.pretrain_op]
        else:
            input_tensors = [model.metatrain_op]

        input_tensors.extend(
            [model.summ_op, model.total_embed_loss, model.total_loss1, model.total_losses2[FLAGS.num_updates - 1]])
        if model.classification:
            input_tensors.extend([model.total_accuracy1, model.total_accuracies2[FLAGS.num_updates - 1]])

        result = sess.run(input_tensors, feed_dict)

        # result[1] is the summary and result[2] the embedding loss;
        # result[-2] / result[-1] are the pre- and post-update metrics
        # (accuracies for classification, losses otherwise). A step is only
        # recorded when all three logged values are finite.
        if not (np.isnan(result[-2]) or np.isnan(result[-1]) or np.isnan(result[2])):
            prelosses.append(result[-2])
            postlosses.append(result[-1])
            embedlosses.append(result[2])

        if itr % SUMMARY_INTERVAL == 0:
            if FLAGS.log:
                train_writer.add_summary(result[1], itr)

        if (itr != 0) and itr % PRINT_INTERVAL == 0:
            if itr < FLAGS.pretrain_iterations:
                print_str = 'Pretrain Iteration ' + str(itr)
            else:
                print_str = 'Iteration ' + str(itr - FLAGS.pretrain_iterations)
            std = np.std(postlosses, 0)
            ci95 = 1.96 * std / np.sqrt(PRINT_INTERVAL)
            print_str += ': preloss: ' + str(np.mean(prelosses)) + ', postloss: ' + str(
                np.mean(postlosses)) + ', embedding loss: ' + str(np.mean(embedlosses)) + ', confidence: ' + str(ci95)
            print(print_str)
            prelosses, postlosses, embedlosses = [], [], []

        if (itr != 0) and itr % SAVE_INTERVAL == 0:
            saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr))

        if (itr != 0) and itr % TEST_PRINT_INTERVAL == 0 and (
                FLAGS.datasource not in ['sinusoid', 'mixture']):
            if 'generate' not in dir(data_generator):
                # Validation data comes from the model's own metaval tensors.
                feed_dict = {}
                if model.classification:
                    input_tensors = [model.metaval_total_accuracy1,
                                     model.metaval_total_accuracies2[FLAGS.num_updates - 1], model.summ_op]
                else:
                    input_tensors = [model.metaval_total_loss1, model.metaval_total_losses2[FLAGS.num_updates - 1],
                                     model.summ_op]
            else:
                if FLAGS.datasource == 'sinusoid':
                    batch_x, batch_y, amp, phase = data_generator.generate(train=False)
                elif FLAGS.datasource == 'mixture':
                    # The mixture generator returns four values, as in the training branch above.
                    batch_x, batch_y, para_func, sel_set = data_generator.generate(train=False)
                inputa = batch_x[:, :num_classes * FLAGS.update_batch_size, :]
                inputb = batch_x[:, num_classes * FLAGS.update_batch_size:, :]
                labela = batch_y[:, :num_classes * FLAGS.update_batch_size, :]
                labelb = batch_y[:, num_classes * FLAGS.update_batch_size:, :]

                feed_dict = {model.inputa: inputa, model.inputb: inputb, model.labela: labela, model.labelb: labelb,
                             model.meta_lr: 0.0}
                if model.classification:
                    input_tensors = [model.total_accuracy1, model.total_accuracies2[FLAGS.num_updates - 1]]
                else:
                    input_tensors = [model.total_loss1, model.total_losses2[FLAGS.num_updates - 1]]

            result = sess.run(input_tensors, feed_dict)
            print('Validation results: ' + str(result[0]) + ', ' + str(result[1]))

    # Final checkpoint, named after the last iteration that ran.
    saver.save(sess, FLAGS.logdir + '/' + exp_string + '/model' + str(itr))
170 |
171 |
# Number of meta-test tasks evaluated by test(). Defined for every datasource:
# it used to be set only for 'multidataset', 'multidataset_leave_one_out' and
# 'mixture', which left test() failing with a NameError for the others.
NUM_TEST_POINTS = FLAGS.num_test_task
174 |
175 |
def test(model, saver, sess, exp_string, data_generator, test_num_updates=None):
    """Evaluate the model on ``NUM_TEST_POINTS`` tasks and print summary statistics.

    Prints the mean, standard deviation and 95% confidence interval of the
    pre-update metric followed by the metric after each inner update.

    Args:
        model: model exposing the evaluated tensors, its input placeholders,
            ``meta_lr`` and a ``classification`` flag.
        saver: not used by this function.
        sess: session used to run the evaluation.
        exp_string: not used by this function.
        data_generator: provides ``num_classes`` and, for generated data,
            a ``generate(train=False)`` method.
        test_num_updates: not used by this function.
    """
    num_classes = data_generator.num_classes

    # Fixed seeds so that the sampled evaluation tasks are reproducible.
    np.random.seed(1)
    random.seed(1)

    print(NUM_TEST_POINTS)

    # Pre-update metric followed by the metric after every inner update.
    if model.classification:
        fetches = [model.metaval_total_accuracy1] + model.metaval_total_accuracies2
    else:  # this is for sinusoid
        fetches = [model.total_loss1] + model.total_losses2

    reads_from_graph = 'generate' not in dir(data_generator)
    metaval_accuracies = []
    for _ in range(NUM_TEST_POINTS):
        if reads_from_graph:
            # Inputs come from the graph itself; only the meta learning rate is fed.
            feed_dict = {model.meta_lr: 0.0}
        else:
            if FLAGS.datasource == 'sinusoid':
                batch_x, batch_y, amp, phase = data_generator.generate(train=False)
            elif FLAGS.datasource == 'mixture':
                batch_x, batch_y, para_func, sel_set = data_generator.generate(train=False)

            # The first num_classes * update_batch_size samples feed the inner
            # update (a); the remaining samples are evaluated on (b).
            split = num_classes * FLAGS.update_batch_size
            feed_dict = {model.inputa: batch_x[:, :split, :], model.inputb: batch_x[:, split:, :],
                         model.labela: batch_y[:, :split, :], model.labelb: batch_y[:, split:, :],
                         model.meta_lr: 0.0}

        metaval_accuracies.append(sess.run(fetches, feed_dict))

    metaval_accuracies = np.array(metaval_accuracies)
    means = np.mean(metaval_accuracies, 0)
    stds = np.std(metaval_accuracies, 0)
    ci95 = 1.96 * stds / np.sqrt(NUM_TEST_POINTS)

    print('Mean validation accuracy/loss, stddev, and confidence intervals')
    print((means, stds, ci95))
217 |
218 |
def _make_data_tensors(data_generator, train):
    """Return ``(image_tensor, label_tensor)`` from the generator method matching FLAGS.datasource."""
    # The training call sites rely on the generator's default argument and only
    # validation passes train=False, so forward exactly that.
    kwargs = {} if train else {'train': False}
    if FLAGS.datasource in ['miniimagenet', 'omniglot']:
        return data_generator.make_data_tensor(**kwargs)
    if FLAGS.datasource == 'multidataset':
        return data_generator.make_data_tensor_multidataset(**kwargs)
    if FLAGS.datasource == 'multidataset_leave_one_out':
        return data_generator.make_data_tensor_multidataset_leave_one_out(**kwargs)
    raise ValueError('Unrecognized data source.')


def _split_support_query(image_tensor, label_tensor, num_support):
    """Split axis 1 of both tensors into the first ``num_support`` entries (a) and the rest (b)."""
    inputa = tf.slice(image_tensor, [0, 0, 0], [-1, num_support, -1])
    inputb = tf.slice(image_tensor, [0, num_support, 0], [-1, -1, -1])
    labela = tf.slice(label_tensor, [0, 0, 0], [-1, num_support, -1])
    labelb = tf.slice(label_tensor, [0, num_support, 0], [-1, -1, -1])
    return {'inputa': inputa, 'inputb': inputb, 'labela': labela, 'labelb': labelb}


def main():
    """Build the data generator and MAML graph from FLAGS, then train or test.

    Steps, in order: pick the number of evaluation updates, create the
    DataGenerator, build input tensors for the queue-fed data sources,
    construct the model, initialise variables, optionally restore a
    checkpoint, and finally dispatch to ``train`` or ``test``.

    FLAGS.meta_batch_size is temporarily forced to 1 while the test graph is
    built and restored afterwards so the experiment string (and therefore the
    checkpoint directory) matches the one used during training.
    """
    if FLAGS.datasource == 'multidataset_leave_one_out':
        assert FLAGS.leave_one_out_id > -1

    sess = tf.InteractiveSession()

    rgb_sources = ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']

    # These sources evaluate a single update while training and 10 at test
    # time; every other source (e.g. omniglot) always uses 10.
    if FLAGS.train and FLAGS.datasource in ['sinusoid', 'mixture'] + rgb_sources:
        test_num_updates = 1  # eval on at least one update during training
    else:
        test_num_updates = 10

    if not FLAGS.train:
        orig_meta_batch_size = FLAGS.meta_batch_size
        # always use meta batch size of 1 when testing.
        FLAGS.meta_batch_size = 1

    if FLAGS.datasource in ['sinusoid', 'mixture']:
        data_generator = DataGenerator(FLAGS.update_batch_size + FLAGS.update_batch_size_eval, FLAGS.meta_batch_size)
    elif FLAGS.metatrain_iterations == 0 and FLAGS.datasource in rgb_sources:
        assert FLAGS.meta_batch_size == 1
        assert FLAGS.update_batch_size == 1
        data_generator = DataGenerator(1, FLAGS.meta_batch_size)  # only use one datapoint,
    elif FLAGS.train and FLAGS.datasource in rgb_sources:
        # Training on the RGB image sources draws update_batch_size + 15
        # samples (presumably per class -- confirm against DataGenerator).
        data_generator = DataGenerator(FLAGS.update_batch_size + 15, FLAGS.meta_batch_size)
    else:
        # only use one datapoint for testing to save memory
        data_generator = DataGenerator(FLAGS.update_batch_size * 2, FLAGS.meta_batch_size)

    dim_output = data_generator.dim_output
    dim_input = data_generator.dim_input

    if FLAGS.datasource in ['miniimagenet', 'omniglot', 'multidataset', 'multidataset_leave_one_out']:
        tf_data_load = True
        num_support = data_generator.num_classes * FLAGS.update_batch_size

        if FLAGS.train:  # only construct training model if needed
            random.seed(5)
            image_tensor, label_tensor = _make_data_tensors(data_generator, train=True)
            input_tensors = _split_support_query(image_tensor, label_tensor, num_support)

        random.seed(6)
        image_tensor, label_tensor = _make_data_tensors(data_generator, train=False)
        metaval_input_tensors = _split_support_query(image_tensor, label_tensor, num_support)
    else:
        tf_data_load = False
        input_tensors = None

    model = MAML(sess, dim_input, dim_output, test_num_updates=test_num_updates)

    # input_tensors is only defined when training or when placeholders are
    # used, which is exactly the condition guarding its use here.
    if FLAGS.train or not tf_data_load:
        model.construct_model(input_tensors=input_tensors, prefix='metatrain_')
    if tf_data_load:
        model.construct_model(input_tensors=metaval_input_tensors, prefix='metaval_')
    model.summ_op = tf.summary.merge_all()
    saver = tf.train.Saver(tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES), max_to_keep=10)

    if not FLAGS.train:
        # change to original meta batch size when loading model.
        FLAGS.meta_batch_size = orig_meta_batch_size

    if FLAGS.train_update_batch_size == -1:
        FLAGS.train_update_batch_size = FLAGS.update_batch_size
    if FLAGS.train_update_lr == -1:
        FLAGS.train_update_lr = FLAGS.update_lr

    # The experiment string names the checkpoint sub-directory, so training
    # and testing must be launched with the same values for these flags.
    exp_string = 'cls_' + str(FLAGS.num_classes) + '.mbs_' + str(FLAGS.meta_batch_size) + '.ubs_' + str(
        FLAGS.train_update_batch_size) + '.numstep' + str(FLAGS.num_updates) + '.updatelr' + str(
        FLAGS.train_update_lr) + '.metalr' + str(FLAGS.meta_lr) + '.emb_loss_weight' + str(
        FLAGS.emb_loss_weight) + '.num_groups' + str(FLAGS.num_groups) + '.emb_type' + str(
        FLAGS.emb_type) + '.hidden_dim' + str(FLAGS.hidden_dim)

    if FLAGS.num_filters != 64:
        exp_string += 'hidden' + str(FLAGS.num_filters)
    if FLAGS.max_pool:
        exp_string += 'maxpool'
    if FLAGS.stop_grad:
        exp_string += 'stopgrad'
    if FLAGS.norm == 'batch_norm':
        exp_string += 'batchnorm'
    elif FLAGS.norm == 'layer_norm':
        exp_string += 'layernorm'
    elif FLAGS.norm == 'None':
        exp_string += 'nonorm'
    else:
        print('Norm setting not recognized.')

    resume_itr = 0
    model_file = None

    tf.global_variables_initializer().run()
    tf.train.start_queue_runners()

    print(exp_string)

    if FLAGS.resume or not FLAGS.train:
        if FLAGS.train:
            model_file = tf.train.latest_checkpoint(FLAGS.logdir + '/' + exp_string)
        else:
            print(FLAGS.test_epoch)
            model_file = '{0}/{2}/model{1}'.format(FLAGS.logdir, FLAGS.test_epoch, exp_string)
        if model_file:
            # Search from the right: the iteration number follows the LAST
            # 'model' in the path, and FLAGS.logdir may itself contain the
            # substring 'model' (e.g. '/data/models/run1').
            ind1 = model_file.rindex('model')
            resume_itr = int(model_file[ind1 + 5:])
            print("Restoring model weights from " + model_file)
            saver.restore(sess, model_file)

    if FLAGS.train:
        train(model, saver, sess, exp_string, data_generator, resume_itr)
    else:
        test(model, saver, sess, exp_string, data_generator, test_num_updates)
364 |
# Run the experiment only when this file is executed as a script (not on import).
if __name__ == "__main__":
    main()
367 |
--------------------------------------------------------------------------------
/maml.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import sys
4 |
5 | import numpy as np
6 | import tensorflow as tf
7 |
8 | from image_embedding import ImageEmbedding
9 | from lstm_tree import TreeLSTM
10 |
11 | try:
12 | import special_grads
13 | except KeyError as e:
14 | print('WARN: Cannot define MaxPoolGrad, likely already defined for this version of tensorflow: %s' % e,
15 | file=sys.stderr)
16 |
17 | from tensorflow.python.platform import flags
18 | from utils import mse, xent, conv_block, normalize
19 | from task_embedding import LSTMAutoencoder, MeanAutoencoder
20 |
21 | FLAGS = flags.FLAGS
22 |
23 |
class MAML:
    """MAML-style meta-learner whose initial weights are gated per task.

    For every task a task embedding (``self.lstmae``) and a tree module
    (``self.tree``) produce a vector from which a sigmoid gate ``eta`` is
    computed for each weight tensor; the gated weights are then adapted with
    gradient steps on the task's "a" data and evaluated on its "b" data.
    """

    def __init__(self, sess, dim_input=1, dim_output=1, test_num_updates=5):
        """Store settings and select embedding, loss and network functions from FLAGS.

        must call construct_model() after initializing MAML!

        Args:
            sess: session object, stored as ``self.sess``.
            dim_input: input dimensionality used to size the first layer and,
                for image sources, to derive ``self.img_size``.
            dim_output: output dimensionality of the last layer.
            test_num_updates: number of inner updates to unroll for evaluation.

        Raises:
            ValueError: if FLAGS.datasource is not one of the handled sources.
        """
        self.dim_input = dim_input
        self.dim_output = dim_output
        self.update_lr = FLAGS.update_lr
        # Placeholder so the meta learning rate can be overridden through feed_dict.
        self.meta_lr = tf.placeholder_with_default(FLAGS.meta_lr, ())
        self.classification = False
        self.test_num_updates = test_num_updates
        self.sess = sess
        # NOTE(review): any other FLAGS.task_embedding_type leaves self.lstmae
        # unset, which only fails later inside construct_model().
        if FLAGS.task_embedding_type == 'rnn':
            self.lstmae = LSTMAutoencoder(hidden_num=FLAGS.hidden_dim)
        elif FLAGS.task_embedding_type == 'mean':
            self.lstmae = MeanAutoencoder(hidden_num=FLAGS.hidden_dim)
        self.tree = TreeLSTM(input_dim=FLAGS.hidden_dim, tree_hidden_dim=FLAGS.hidden_dim)
        if FLAGS.datasource in ['sinusoid', 'mixture']:
            # Regression sources: two hidden layers of 40 units with MSE loss.
            self.dim_hidden = [40, 40]
            self.loss_func = mse
            self.forward = self.forward_fc
            self.construct_weights = self.construct_fc_weights
        elif FLAGS.datasource in ['omniglot', 'miniimagenet', 'multidataset', 'multidataset_leave_one_out']:
            self.loss_func = xent
            self.classification = True
            if FLAGS.conv:
                self.dim_hidden = FLAGS.num_filters
                self.forward = self.forward_conv
                self.construct_weights = self.construct_conv_weights
            else:
                self.dim_hidden = [256, 128, 64, 64]
                self.forward = self.forward_fc
                self.construct_weights = self.construct_fc_weights
            if FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']:
                self.channels = 3
            else:
                self.channels = 1
            # Side length derived from the flattened input size (square images).
            self.img_size = int(np.sqrt(self.dim_input / self.channels))
            self.image_embed = ImageEmbedding(hidden_num=FLAGS.task_embedding_num_filters, channels=self.channels,
                                              conv_initializer=tf.truncated_normal_initializer(stddev=0.04))
        else:
            raise ValueError('Unrecognized data source.')

    def construct_model(self, input_tensors=None, prefix='metatrain_'):
        """Build the meta-learning graph and attach losses, ops and summaries to ``self``.

        Args:
            input_tensors: dict with keys 'inputa', 'inputb', 'labela',
                'labelb', or None to create placeholders instead.
            prefix: summary-name prefix; when it contains 'train' the training
                losses and optimizer ops are created, otherwise the
                ``metaval_*`` attributes are set.
        """
        # a: training data for inner gradient, b: test data for meta gradient
        if input_tensors is None:
            if FLAGS.datasource in ['sinusoid', 'mixture']:
                self.inputa = tf.placeholder(tf.float32, shape=(FLAGS.meta_batch_size, FLAGS.update_batch_size, 1))
                self.inputb = tf.placeholder(tf.float32,
                                             shape=(FLAGS.meta_batch_size, FLAGS.update_batch_size_eval, 1))
                self.labela = tf.placeholder(tf.float32, shape=(FLAGS.meta_batch_size, FLAGS.update_batch_size, 1))
                self.labelb = tf.placeholder(tf.float32,
                                             shape=(FLAGS.meta_batch_size, FLAGS.update_batch_size_eval, 1))
            else:
                # Shape-less placeholders for the remaining data sources.
                self.inputa = tf.placeholder(tf.float32)
                self.inputb = tf.placeholder(tf.float32)
                self.labela = tf.placeholder(tf.float32)
                self.labelb = tf.placeholder(tf.float32)
        else:
            self.inputa = input_tensors['inputa']
            self.inputb = input_tensors['inputb']
            self.labela = input_tensors['labela']
            self.labelb = input_tensors['labelb']
        # tf.summary.scalar('lr', self.update_lr)

        with tf.variable_scope('model', reuse=None) as training_scope:
            # On a second call (e.g. metaval after metatrain) reuse the
            # weights created by the first call instead of building new ones.
            if 'weights' in dir(self):
                training_scope.reuse_variables()
                weights = self.weights
            else:
                # Define the weights
                self.weights = weights = self.construct_weights()

            # outputbs[i] and lossesb[i] is the output and loss after i+1 gradient updates
            lossesa, outputas, lossesb, outputbs, emb_loss = [], [], [], [], []
            accuraciesa, accuraciesb = [], []
            num_updates = max(self.test_num_updates, FLAGS.num_updates)
            # These initial lists are overwritten when the map_fn result is unpacked below.
            outputbs = [[]] * num_updates
            lossesb = [[]] * num_updates
            accuraciesb = [[]] * num_updates

            def task_metalearn(inp, reuse=True):
                """ Perform gradient descent for one task in the meta-batch.

                Returns a list [emb_loss, outputa, outputbs, lossa, lossesb]
                extended with [accuracya, accuraciesb] for classification.
                """
                inputa, inputb, labela, labelb = inp
                # Build the per-sample input of the task-embedding module.
                if FLAGS.datasource in ['sinusoid', 'mixture']:
                    input_task_emb = tf.concat((inputa, labela), axis=-1)
                elif FLAGS.datasource in ['miniimagenet', 'omniglot', 'multidataset', 'multidataset_leave_one_out']:
                    if FLAGS.fix_embedding_sample != -1:
                        # Embed only the first FLAGS.fix_embedding_sample support samples.
                        input_task_emb = self.image_embed.model(tf.reshape(inputa[:FLAGS.fix_embedding_sample],
                                                                           [-1, self.img_size, self.img_size,
                                                                            self.channels]))
                        one_hot_labela = tf.squeeze(
                            tf.one_hot(tf.to_int32(labela[:FLAGS.fix_embedding_sample]), depth=1, axis=-1))
                    else:
                        input_task_emb = self.image_embed.model(tf.reshape(inputa,
                                                                           [-1, self.img_size, self.img_size,
                                                                            self.channels]))
                        one_hot_labela = tf.squeeze(
                            tf.one_hot(tf.to_int32(labela), depth=1, axis=-1))
                    input_task_emb = tf.concat((input_task_emb, one_hot_labela), axis=-1)

                task_embed_vec, task_emb_loss = self.lstmae.model(input_task_emb)

                _, meta_knowledge_h = self.tree.model(task_embed_vec)

                task_enhanced_emb_vec = tf.concat([task_embed_vec, meta_knowledge_h], axis=1)

                with tf.variable_scope('task_specific_mapping', reuse=tf.AUTO_REUSE):
                    # One sigmoid-activated dense layer per weight tensor; its
                    # output is reshaped to that tensor's shape and used as a gate.
                    eta = []
                    for key in weights.keys():
                        weight_size = np.prod(weights[key].get_shape().as_list())
                        eta.append(tf.reshape(
                            tf.layers.dense(task_enhanced_emb_vec, weight_size, activation=tf.nn.sigmoid,
                                            name='eta_{}'.format(key)), tf.shape(weights[key])))
                    eta = dict(zip(weights.keys(), eta))

                    # Task-specific initialization: shared weights scaled element-wise by eta.
                    task_weights = dict(zip(weights.keys(), [weights[key] * eta[key] for key in weights.keys()]))

                task_outputbs, task_lossesb = [], []

                if self.classification:
                    task_accuraciesb = []

                task_outputa = self.forward(inputa, task_weights, reuse=reuse)
                task_lossa = self.loss_func(task_outputa, labela)

                # First inner update, taken from the gated weights.
                grads = tf.gradients(task_lossa, list(task_weights.values()))
                if FLAGS.stop_grad:
                    # Block gradients from flowing through the inner-loop gradients.
                    grads = [tf.stop_gradient(grad) for grad in grads]
                gradients = dict(zip(task_weights.keys(), grads))
                fast_weights = dict(
                    zip(task_weights.keys(),
                        [task_weights[key] - self.update_lr * gradients[key] for key in task_weights.keys()]))
                output = self.forward(inputb, fast_weights, reuse=True)
                task_outputbs.append(output)
                task_lossesb.append(self.loss_func(output, labelb))
                # Remaining inner updates; record the "b" output and loss after each.
                for j in range(num_updates - 1):
                    loss = self.loss_func(self.forward(inputa, fast_weights, reuse=True), labela)
                    grads = tf.gradients(loss, list(fast_weights.values()))
                    if FLAGS.stop_grad:
                        grads = [tf.stop_gradient(grad) for grad in grads]
                    gradients = dict(zip(fast_weights.keys(), grads))
                    fast_weights = dict(zip(fast_weights.keys(),
                                            [fast_weights[key] - self.update_lr * gradients[key] for key in
                                             fast_weights.keys()]))
                    output = self.forward(inputb, fast_weights, reuse=True)
                    task_outputbs.append(output)
                    task_lossesb.append(self.loss_func(output, labelb))

                task_output = [task_emb_loss, task_outputa, task_outputbs, task_lossa, task_lossesb]

                if self.classification:
                    task_accuracya = tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputa), 1),
                                                                 tf.argmax(labela, 1))
                    for j in range(num_updates):
                        task_accuraciesb.append(
                            tf.contrib.metrics.accuracy(tf.argmax(tf.nn.softmax(task_outputbs[j]), 1),
                                                        tf.argmax(labelb, 1)))
                    task_output.extend([task_accuracya, task_accuraciesb])

                return task_output

            if FLAGS.norm != 'None':
                # to initialize the batch norm vars, might want to combine this, and not run idx 0 twice.
                unused = task_metalearn((self.inputa[0], self.inputb[0], self.labela[0], self.labelb[0]), False)

            # Output structure of task_metalearn, required by tf.map_fn.
            out_dtype = [tf.float32, tf.float32, [tf.float32] * num_updates, tf.float32, [tf.float32] * num_updates]
            if self.classification:
                out_dtype.extend([tf.float32, [tf.float32] * num_updates])
            # Apply task_metalearn to every task along the leading axis of the inputs.
            result = tf.map_fn(task_metalearn, elems=(self.inputa, self.inputb, self.labela, self.labelb),
                               dtype=out_dtype, parallel_iterations=FLAGS.meta_batch_size)
            if self.classification:
                emb_loss, outputas, outputbs, lossesa, lossesb, accuraciesa, accuraciesb = result
            else:
                emb_loss, outputas, outputbs, lossesa, lossesb = result

        ## Performance & Optimization
        if 'train' in prefix:
            # Sum over tasks divided by the meta batch size.
            self.total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
            self.total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j
                                                  in range(num_updates)]
            self.total_embed_loss = total_embed_loss = tf.reduce_sum(emb_loss) / tf.to_float(FLAGS.meta_batch_size)
            # after the map_fn
            self.outputas, self.outputbs = outputas, outputbs
            if self.classification:
                self.total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(FLAGS.meta_batch_size)
                self.total_accuracies2 = total_accuracies2 = [
                    tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]
            self.pretrain_op = tf.train.AdamOptimizer(self.meta_lr).minimize(total_loss1)

            if FLAGS.metatrain_iterations > 0:
                optimizer = tf.train.AdamOptimizer(self.meta_lr)
                # Meta objective: "b" loss after FLAGS.num_updates inner steps
                # plus the weighted task-embedding loss.
                self.gvs = gvs = optimizer.compute_gradients(
                    self.total_losses2[FLAGS.num_updates - 1] + FLAGS.emb_loss_weight * self.total_embed_loss)
                # Both clips apply (in sequence) when both conditions hold.
                if FLAGS.task_embedding_type == 'mean':
                    gvs = [(tf.clip_by_value(grad, -3, 3), var) for grad, var in gvs]
                if FLAGS.datasource == 'miniimagenet':
                    gvs = [(tf.clip_by_value(grad, -10, 10), var) for grad, var in gvs]
                self.metatrain_op = optimizer.apply_gradients(gvs)
        else:
            self.metaval_total_loss1 = total_loss1 = tf.reduce_sum(lossesa) / tf.to_float(FLAGS.meta_batch_size)
            self.metaval_total_losses2 = total_losses2 = [tf.reduce_sum(lossesb[j]) / tf.to_float(FLAGS.meta_batch_size)
                                                          for j in range(num_updates)]
            if self.classification:
                self.metaval_total_accuracy1 = total_accuracy1 = tf.reduce_sum(accuraciesa) / tf.to_float(
                    FLAGS.meta_batch_size)
                self.metaval_total_accuracies2 = total_accuracies2 = [
                    tf.reduce_sum(accuraciesb[j]) / tf.to_float(FLAGS.meta_batch_size) for j in range(num_updates)]

        ## Summaries
        tf.summary.scalar(prefix + 'Pre-update loss', total_loss1)
        if self.classification:
            tf.summary.scalar(prefix + 'Pre-update accuracy', total_accuracy1)

        for j in range(num_updates):
            tf.summary.scalar(prefix + 'Post-update loss, step ' + str(j + 1), total_losses2[j])
            if self.classification:
                tf.summary.scalar(prefix + 'Post-update accuracy, step ' + str(j + 1), total_accuracies2[j])

    ### Network construction functions (fc networks and conv networks)
    def construct_fc_weights(self):
        """Create weights 'w1'/'b1' ... 'wK'/'bK' for a fully connected net (K = len(dim_hidden) + 1)."""
        weights = {}
        weights['w1'] = tf.Variable(tf.truncated_normal([self.dim_input, self.dim_hidden[0]], stddev=0.01))
        weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden[0]]))
        for i in range(1, len(self.dim_hidden)):
            weights['w' + str(i + 1)] = tf.Variable(
                tf.truncated_normal([self.dim_hidden[i - 1], self.dim_hidden[i]], stddev=0.01))
            weights['b' + str(i + 1)] = tf.Variable(tf.zeros([self.dim_hidden[i]]))
        # Output layer.
        weights['w' + str(len(self.dim_hidden) + 1)] = tf.Variable(
            tf.truncated_normal([self.dim_hidden[-1], self.dim_output], stddev=0.01))
        weights['b' + str(len(self.dim_hidden) + 1)] = tf.Variable(tf.zeros([self.dim_output]))
        return weights

    def forward_fc(self, inp, weights, reuse=False):
        """Run the fully connected net with the given weights; the last layer is linear."""
        hidden = normalize(tf.matmul(inp, weights['w1']) + weights['b1'], activation=tf.nn.relu, reuse=reuse, scope='0')
        for i in range(1, len(self.dim_hidden)):
            hidden = normalize(tf.matmul(hidden, weights['w' + str(i + 1)]) + weights['b' + str(i + 1)],
                               activation=tf.nn.relu, reuse=reuse, scope=str(i + 1))
        return tf.matmul(hidden, weights['w' + str(len(self.dim_hidden) + 1)]) + weights[
            'b' + str(len(self.dim_hidden) + 1)]

    def construct_conv_weights(self):
        """Create weights for four 3x3 conv layers ('conv1'..'conv4', 'b1'..'b4') and a final layer 'w5'/'b5'."""
        weights = {}

        dtype = tf.float32
        conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
        fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
        k = 3  # convolution kernel size

        weights['conv1'] = tf.get_variable('conv1', [k, k, self.channels, self.dim_hidden],
                                           initializer=conv_initializer, dtype=dtype)
        weights['b1'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv2'] = tf.get_variable('conv2', [k, k, self.dim_hidden, self.dim_hidden],
                                           initializer=conv_initializer, dtype=dtype)
        weights['b2'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv3'] = tf.get_variable('conv3', [k, k, self.dim_hidden, self.dim_hidden],
                                           initializer=conv_initializer, dtype=dtype)
        weights['b3'] = tf.Variable(tf.zeros([self.dim_hidden]))
        weights['conv4'] = tf.get_variable('conv4', [k, k, self.dim_hidden, self.dim_hidden],
                                           initializer=conv_initializer, dtype=dtype)
        weights['b4'] = tf.Variable(tf.zeros([self.dim_hidden]))
        if FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']:
            # assumes max pooling
            # The final layer consumes the flattened conv output, hard-coded as dim_hidden * 5 * 5.
            weights['w5'] = tf.get_variable('w5', [self.dim_hidden * 5 * 5, self.dim_output],
                                            initializer=fc_initializer)
            weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
        else:
            weights['w5'] = tf.Variable(tf.random_normal([self.dim_hidden, self.dim_output]), name='w5')
            weights['b5'] = tf.Variable(tf.zeros([self.dim_output]), name='b5')
        return weights

    def forward_conv(self, inp, weights, reuse=False, scope=''):
        """Run the four-block conv net with the given weights and return the final linear output."""
        channels = self.channels
        inp = tf.reshape(inp, [-1, self.img_size, self.img_size, channels])

        hidden1 = conv_block(inp, weights['conv1'], weights['b1'], reuse, scope + '0')
        hidden2 = conv_block(hidden1, weights['conv2'], weights['b2'], reuse, scope + '1')
        hidden3 = conv_block(hidden2, weights['conv3'], weights['b3'], reuse, scope + '2')
        hidden4 = conv_block(hidden3, weights['conv4'], weights['b4'], reuse, scope + '3')
        if FLAGS.datasource in ['miniimagenet', 'multidataset', 'multidataset_leave_one_out']:
            # Flatten all non-batch dimensions.
            hidden4 = tf.reshape(hidden4, [-1, np.prod([int(dim) for dim in hidden4.get_shape()[1:]])])
        else:
            # Average over axes 1 and 2 (the spatial axes after the reshape above).
            hidden4 = tf.reduce_mean(hidden4, [1, 2])

        return tf.matmul(hidden4, weights['w5']) + weights['b5']
307 |
--------------------------------------------------------------------------------
/multidataset_bash/HSML_multidataset_1shot.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --datasource=multidataset --metatrain_iterations=60000 --meta_batch_size=4 --update_batch_size=1 --update_lr=0.01 --num_updates=5 --num_classes=5 --logdir=../Check_point/multidataset_1shot/ --num_filters=32 --max_pool=True --hidden_dim=128 --emb_loss_weight=0.01
3 |
4 | python main.py --datasource=multidataset --metatrain_iterations=60000 --meta_batch_size=4 --update_batch_size=1 --update_lr=0.01 --num_updates=5 --num_classes=5 --logdir=../Check_point/multidataset_1shot/ --num_filters=32 --max_pool=True --hidden_dim=128 --emb_loss_weight=0.01 --test_set=True --test_epoch=59000 --train=False --test_dataset=1
5 |
--------------------------------------------------------------------------------
/multidataset_bash/HSML_multidataset_5shot.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --datasource=multidataset --metatrain_iterations=50000 --meta_batch_size=4 --update_batch_size=5 --update_lr=0.01 --num_updates=5 --num_classes=5 --logdir=../Check_point/multidataset_5shot/ --num_filters=32 --max_pool=True --hidden_dim=128 --emb_loss_weight=0.01 --fix_embedding_sample=10
3 |
4 | python main.py --datasource=multidataset --metatrain_iterations=50000 --meta_batch_size=4 --update_batch_size=5 --update_lr=0.01 --num_updates=5 --num_classes=5 --logdir=../Check_point/multidataset_5shot/ --num_filters=32 --max_pool=True --hidden_dim=128 --emb_loss_weight=0.01 --test_set=True --test_epoch=49000 --train=False --test_dataset=1
5 |
--------------------------------------------------------------------------------
/special_grads.py:
--------------------------------------------------------------------------------
1 | from tensorflow.python.framework import ops
2 | from tensorflow.python.ops import array_ops
3 | from tensorflow.python.ops import gen_nn_ops
4 |
@ops.RegisterGradient("MaxPoolGrad")
def _MaxPoolGradGrad(op, grad):
    """Gradient function registered for the "MaxPoolGrad" op.

    Returns a 3-tuple, one entry per input of ``op``: the max-pool gradient of
    ``grad`` for the first input, and zero tensors shaped like the second and
    third inputs.
    """
    first_input_grad = gen_nn_ops._max_pool_grad(
        op.inputs[0],
        op.outputs[0],
        grad,
        op.get_attr("ksize"),
        op.get_attr("strides"),
        padding=op.get_attr("padding"),
        data_format=op.get_attr("data_format"),
    )
    # The remaining two inputs receive all-zero gradients of matching shape/dtype.
    zero_grads = [
        array_ops.zeros(shape=array_ops.shape(op.inputs[idx]), dtype=first_input_grad.dtype)
        for idx in (1, 2)
    ]
    return (first_input_grad, zero_grads[0], zero_grads[1])
13 |
--------------------------------------------------------------------------------
/task_embedding.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.ops.rnn_cell import GRUCell
3 |
4 | from tensorflow.python.platform import flags
5 |
6 |
7 | FLAGS = flags.FLAGS
8 | import ipdb
9 |
10 |
class LSTMAutoencoder(object):
    """Recurrent (GRU) autoencoder that turns a sequence of samples into a task embedding.

    ``model`` encodes the sequence, decodes it back, and returns the mean of
    the encoder outputs together with the mean squared reconstruction error.
    """

    def __init__(self, hidden_num, cell=None, reverse=True, decode_without_input=False):
        """Create the encoder/decoder cells and the decoder projection variables.

        Args:
            hidden_num: number of units of the GRU cells.
            cell: optional cell used for BOTH encoder and decoder; when None
                two separate GRU cells are created.
            reverse: if True the decoded sequence is reversed before being
                compared with the input.
            decode_without_input: if True the decoder is fed zeros at every
                step; otherwise it is fed its own previous projected output.

        Raises:
            ValueError: if FLAGS.datasource is not one of the handled sources.
        """
        if cell is None:
            self._enc_cell = GRUCell(hidden_num, name='encoder_cell')
            self._dec_cell = GRUCell(hidden_num, name='decoder_cell')
        else:
            self._enc_cell = cell
            self._dec_cell = cell
        self.reverse = reverse
        self.decode_without_input = decode_without_input
        self.hidden_num = hidden_num

        if FLAGS.datasource in ['sinusoid', 'mixture']:
            self.elem_num_init = 2
            self.elem_num = 20
        elif FLAGS.datasource in ['miniimagenet', 'omniglot', 'multidataset', 'multidataset_leave_one_out']:
            self.elem_num = FLAGS.num_classes + 64
        else:
            # Without this the code below failed with an AttributeError on
            # self.elem_num; fail with the same message MAML.__init__ uses.
            raise ValueError('Unrecognized data source.')

        # Linear projection from the decoder state back to the input element size.
        self.dec_weight = tf.Variable(tf.truncated_normal([self.hidden_num,
                                                           self.elem_num], dtype=tf.float32), name='dec_weight')
        self.dec_bias = tf.Variable(tf.constant(0.1, shape=[self.elem_num],
                                                dtype=tf.float32), name='dec_bias')

    def model(self, inputs):
        """Encode and reconstruct ``inputs``.

        Args:
            inputs: 2-D tensor with one row per sample; rows are treated as
                the time steps of a single sequence (a batch axis of size 1
                is added here).

        Returns:
            ``(emb_all, loss)``: the mean over time steps of the encoder
            outputs, and the mean squared reconstruction error.
        """
        if FLAGS.datasource in ['sinusoid', 'mixture']:
            # Lift the raw samples to elem_num features first.
            with tf.variable_scope('first_embedding_sync', reuse=tf.AUTO_REUSE):
                inputs = tf.layers.dense(inputs, units=self.elem_num, name='first_embedding_sync_dense')

        inputs = tf.expand_dims(inputs, 0)

        # List with one [1, elem_num] tensor per time step.
        inputs = tf.unstack(inputs, axis=1)

        self.batch_num = FLAGS.meta_batch_size

        with tf.variable_scope('encoder'):
            (self.z_codes, self.enc_state) = tf.contrib.rnn.static_rnn(self._enc_cell, inputs, dtype=tf.float32)

        with tf.variable_scope('decoder') as vs:

            if self.decode_without_input:
                dec_inputs = [tf.zeros(tf.shape(inputs[0]), dtype=tf.float32) for _ in range(len(inputs))]
                (dec_outputs, dec_state) = tf.contrib.rnn.static_rnn(self._dec_cell, dec_inputs,
                                                                     initial_state=self.enc_state,
                                                                     dtype=tf.float32)
                if self.reverse:
                    dec_outputs = dec_outputs[::-1]
                # [batch, steps, hidden]
                dec_output_ = tf.transpose(tf.stack(dec_outputs), [1, 0, 2])
                # Tile the projection to the ACTUAL batch size of the decoder
                # output (1 here, not FLAGS.meta_batch_size) so the batched
                # matmul operands agree.
                dec_weight_ = tf.tile(tf.expand_dims(self.dec_weight, 0), [tf.shape(dec_output_)[0], 1, 1])
                # [batch, steps, hidden] x [batch, hidden, elem] -> [batch, steps, elem];
                # the operands were previously swapped, which is shape-invalid.
                self.output_ = tf.matmul(dec_output_, dec_weight_) + self.dec_bias
            else:
                dec_state = self.enc_state
                dec_input_ = tf.zeros(tf.shape(inputs[0]),
                                      dtype=tf.float32)

                dec_outputs = []
                for step in range(len(inputs)):
                    if step > 0:
                        vs.reuse_variables()
                    (dec_input_, dec_state) = \
                        self._dec_cell(dec_input_, dec_state)
                    # Project the state and feed it back as the next decoder input.
                    dec_input_ = tf.matmul(dec_input_, self.dec_weight) + self.dec_bias
                    dec_outputs.append(dec_input_)
                if self.reverse:
                    dec_outputs = dec_outputs[::-1]
                self.output_ = tf.transpose(tf.stack(dec_outputs), [1, 0, 2])

        self.input_ = tf.transpose(tf.stack(inputs), [1, 0, 2])
        self.loss = tf.reduce_mean(tf.square(self.input_ - self.output_))
        self.emb_all = tf.reduce_mean(self.z_codes, axis=0)

        return self.emb_all, self.loss
84 |
class MeanAutoencoder(object):
    """Feed-forward autoencoder whose task embedding is a mean-pooled code."""

    def __init__(self, hidden_num):
        """Record the code size and choose layer sizes from FLAGS.datasource."""
        self.hidden_num = hidden_num

        source = FLAGS.datasource
        if source in ['sinusoid', 'mixture']:
            self.elem_num, self.hidden_num_mid = 2, 20
        elif source in ['miniimagenet', 'omniglot','multidataset', 'multidataset_leave_one_out']:
            self.elem_num, self.hidden_num_mid = FLAGS.num_classes + 64, 96

    def model(self, inputs):
        """Return ``(emb_all, loss)`` for ``inputs``.

        ``emb_all`` is a dense layer applied to the code averaged over axis 0;
        ``loss`` is half the mean squared reconstruction error.
        """
        relu = tf.nn.relu

        with tf.variable_scope('encoder', reuse=tf.AUTO_REUSE):
            enc_hidden = tf.layers.dense(inputs, units=self.hidden_num_mid, activation=relu, name='encoder_dense1')
            code = tf.layers.dense(enc_hidden, units=self.hidden_num, activation=relu, name='encoder_dense2')

        with tf.variable_scope('decoder', reuse=tf.AUTO_REUSE):
            dec_hidden = tf.layers.dense(code, units=self.hidden_num_mid, activation=relu, name='decoder_dense1')
            reconstruction = tf.layers.dense(dec_hidden, units=self.elem_num, activation=None,
                                             name='decoder_dense2')

        # Average the per-sample codes, keeping the reduced axis.
        pooled_code = tf.reduce_mean(code, axis=0, keepdims=True)
        with tf.variable_scope('last_fc', reuse=tf.AUTO_REUSE):
            self.emb_all = tf.layers.dense(pooled_code, units=self.hidden_num, activation=relu, name='mean_pool')

        squared_error = tf.square(inputs - reconstruction)
        self.loss = 0.5 * tf.reduce_mean(squared_error)

        return self.emb_all, self.loss
111 |
112 |
113 |
114 |
115 |
116 |
117 |
--------------------------------------------------------------------------------
/toygroup_bash/HSML_toygroup_10shot.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --datasource=mixture --metatrain_iterations=70000 --norm=None --update_batch_size=10 --update_batch_size_eval=10 --resume=False --num_updates=5 --logdir=../Check_point/syncgroup_10shot --emb_loss_weight=0.01 --hidden_dim=40
3 |
4 | python main.py --datasource=mixture --metatrain_iterations=70000 --norm=None --update_batch_size=10 --update_batch_size_eval=10 --resume=False --num_updates=5 --logdir=../Check_point/syncgroup_10shot --emb_loss_weight=0.01 --hidden_dim=40 --test_set=True --test_epoch=69000 --train=False --num_test_task=4000
5 |
--------------------------------------------------------------------------------
/toygroup_bash/HSML_toygroup_5shot.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | python main.py --datasource=mixture --metatrain_iterations=70000 --norm=None --update_batch_size=5 --update_batch_size_eval=10 --resume=False --num_updates=5 --logdir=../Check_point/syncgroup_5shot --emb_loss_weight=0.01 --hidden_dim=40
3 |
4 | python main.py --datasource=mixture --metatrain_iterations=70000 --norm=None --update_batch_size=5 --update_batch_size_eval=10 --resume=False --num_updates=5 --logdir=../Check_point/syncgroup_5shot --emb_loss_weight=0.01 --hidden_dim=40 --test_set=True --test_epoch=69000 --train=False --num_test_task=4000
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | """ Utility functions. """
2 | import numpy as np
3 | import os
4 | import random
5 | import tensorflow as tf
6 |
7 | from tensorflow.contrib.layers.python import layers as tf_layers
8 | from tensorflow.python.platform import flags
9 |
10 | FLAGS = flags.FLAGS
11 |
12 | ## Image helper
def get_images(paths, labels, nb_samples=None, shuffle=True):
    """Return ``(label, file_path)`` pairs for the files found in each folder.

    Args:
        paths: folders to list; ``paths[i]`` is paired with ``labels[i]``.
        labels: label attached to every file of the corresponding folder.
        nb_samples: if given, draw this many files at random from each
            folder; otherwise keep every file in listing order.
        shuffle: if True, shuffle the resulting list in place before returning.
    """
    def choose(filenames):
        # Keep everything unless a per-folder sample size was requested.
        if nb_samples is None:
            return filenames
        return random.sample(filenames, nb_samples)

    labelled_files = []
    for label, folder in zip(labels, paths):
        for filename in choose(os.listdir(folder)):
            labelled_files.append((label, os.path.join(folder, filename)))

    if shuffle:
        random.shuffle(labelled_files)
    return labelled_files
24 |
25 | ## Network helpers
def conv_block(inp, cweight, bweight, reuse, scope, activation=tf.nn.relu, max_pool_pad='VALID', residual=False):
    """ Perform, conv, batch norm, nonlinearity, and max pool

    With FLAGS.max_pool the convolution uses stride 1 and a 2x2 max pool with
    stride 2 follows; otherwise the convolution itself uses stride 2 and no
    pooling is applied. ``residual`` is accepted but not used.
    """
    pool_stride = [1, 2, 2, 1]
    unit_stride = [1, 1, 1, 1]
    use_pool = FLAGS.max_pool

    # Down-sampling happens either in the pool or in the convolution, never both.
    conv_stride = unit_stride if use_pool else pool_stride
    pre_norm = tf.nn.conv2d(inp, cweight, conv_stride, 'SAME') + bweight
    activated = normalize(pre_norm, activation, reuse, scope)
    if not use_pool:
        return activated
    return tf.nn.max_pool(activated, pool_stride, pool_stride, max_pool_pad)
38 |
def normalize(inp, activation, reuse, scope):
    """Apply the normalization selected by FLAGS.norm, followed by ``activation``.

    'batch_norm' and 'layer_norm' delegate to the matching tf.contrib layer;
    'None' applies only the activation (or returns ``inp`` unchanged when
    ``activation`` is None). Any other setting returns None.
    """
    norm_type = FLAGS.norm
    if norm_type == 'None':
        return inp if activation is None else activation(inp)
    if norm_type == 'batch_norm':
        return tf_layers.batch_norm(inp, activation_fn=activation, reuse=reuse, scope=scope)
    if norm_type == 'layer_norm':
        return tf_layers.layer_norm(inp, activation_fn=activation, reuse=reuse, scope=scope)
    return None
49 |
50 | ## Loss functions
def mse(pred, label):
    """Mean squared error between ``pred`` and ``label`` after flattening both."""
    residual = tf.reshape(pred, [-1]) - tf.reshape(label, [-1])
    return tf.reduce_mean(tf.square(residual))
55 |
def xent(pred, label):
    """Softmax cross-entropy of logits ``pred`` against ``label``, divided by FLAGS.update_batch_size."""
    cross_entropy = tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=label)
    return cross_entropy / FLAGS.update_batch_size
58 |
--------------------------------------------------------------------------------