├── tfops ├── __init__.py ├── train_op.py ├── init_op.py ├── info_op.py ├── loss.py ├── lr_op.py ├── transform_op.py └── nets.py ├── utils ├── __init__.py ├── tqdm_op.py ├── gpu_op.py ├── datasetmanager.py ├── shutil_op.py ├── logger_op.py ├── visual_op.py ├── format_op.py ├── img_tools.py ├── reader_op.py ├── writer_op.py ├── np_op.py ├── datamanager.py ├── general_class.py ├── eval_op.py └── ortools_op.py ├── config ├── parser.py └── path.py ├── LICENSE ├── Dsprites_exp ├── CascadeVAE-C │ ├── local_config.py │ ├── main.py │ └── model.py └── CascadeVAE │ ├── local_config.py │ ├── main.py │ └── model.py ├── README.md ├── ex.yml └── .gitignore /tfops/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | -------------------------------------------------------------------------------- /utils/tqdm_op.py: -------------------------------------------------------------------------------- 1 | from tqdm import tqdm 2 | 3 | def tqdm_range(*args, **kwargs): 4 | return tqdm(range(*args, **kwargs), ascii=True, desc="batch") 5 | 6 | -------------------------------------------------------------------------------- /utils/gpu_op.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | def selectGpuById(id_): 4 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 5 | os.environ["CUDA_VISIBLE_DEVICES"] = "{}".format(id_) 6 | -------------------------------------------------------------------------------- /tfops/train_op.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def get_train_op_v2(optimizer, loss, var_list=tf.trainable_variables()): 4 | return optimizer.minimize(loss=loss, var_list=var_list) 5 | -------------------------------------------------------------------------------- /config/parser.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | def dsprites_parser(): 4 | parser = argparse.ArgumentParser() 5 | 6 | parser.add_argument("--gpu", default = 0, help="Utilize which gpu", type = int) 7 | parser.add_argument("--nbatch", default = 64, help="size of batch", type = int) 8 | parser.add_argument("--dataset", default = 'dsprites', help="dataset to be used", type = str) 9 | 10 | return parser 11 | -------------------------------------------------------------------------------- /utils/datasetmanager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '..')) 4 | 5 | from config.path import DSPRITESPATH 6 | from utils.reader_op import read_npy 7 | from utils.datamanager import DspritesManager 8 | 9 | import numpy as np 10 | 11 | def dsprites_manager(): 12 | dataset_zip = read_npy(DSPRITESPATH+'dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', encoding='latin1') 13 | dm = DspritesManager(dataset_zip) 14 | return dm 15 | 16 | -------------------------------------------------------------------------------- /utils/shutil_op.py: -------------------------------------------------------------------------------- 1 | import shutil 2 | import os 3 | 4 | def remove_file(path): 5 | print("Remove {}".format(path)) 6 | os.remove(path) 7 | 8 | def remove_dir(path): 9 | print("Remove {}".format(path)) 10 | shutil.rmtree(path) 11 | 12 | def copy_file(src_path, dst_path): 13 | print("Copy from {} to {}".format(src_path, dst_path)) 14 | shutil.copy(src_path, dst_path) 15 | 16 | def copy_dir(src_path, dst_path): 17 | print("Copy from {} to {}".format(src_path, dst_path)) 18 | shutil.copytree(src_path, dst_path) 19 | -------------------------------------------------------------------------------- /utils/logger_op.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logging.basicConfig(level=logging.DEBUG, format="[%(asctime)s] %(message)s", datefmt="%m%d %H:%M:%S" ) 3 | 4 | class LoggerManager(object): 5 | def __init__(self, filepath, name="logger"): 6 | self.logger = logging.getLogger(name) 7 | self.logger.addHandler(logging.FileHandler(filepath)) 8 | 9 | def info(self, string): 10 | self.logger.info(string) 11 | 12 | def remove(self): 13 | for handler in self.logger.handlers: 14 | handler.close() 15 | self.logger.removeHandler(handler) 16 | del self.logger 17 | -------------------------------------------------------------------------------- /tfops/init_op.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '..')) 4 | 5 | from tfops.info_op import get_uninit_vars 6 | 7 | import tensorflow as tf 8 | 9 | def rest_initializer(sess): 10 | print("Initialize uninitialized variables") 11 | sess.run(tf.variables_initializer(get_uninit_vars(sess))) 12 | 13 | def full_initializer(sess): 14 | print("Initialize all variables") 15 | sess.run(tf.global_initializer()) 16 | 17 | def local_initializer(sess, var_list, print_option=False): 18 | if print_option: print("Initialize specific variables") 19 | sess.run(tf.variables_initializer(var_list)) 20 | 21 | -------------------------------------------------------------------------------- /tfops/info_op.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def get_shape(t): 4 | return t.get_shape().as_list() 5 | 6 | def vars_info_vl(var_list): 7 | return " "+"\n ".join(["{} : {}".format(v.name, get_shape(v)) for v in var_list]) 8 | 9 | def vars_info(string): 10 | return "Collection name %s\n"%string+vars_info_vl(tf.get_collection(string)) 11 | 12 | def get_init_vars(sess): 13 | init_vars = [] 14 | for var in tf.global_variables(): 15 | try: sess.run(var) 16 | except tf.errors.FailedPreconditionError: continue 17 | init_vars.append(var) 18 | return init_vars 19 | 20 | def get_uninit_vars(sess): 21 | uninit_vars = [] 22 | for var in tf.global_variables(): 23 | try : sess.run(var) 24 | except tf.errors.FailedPreconditionError: uninit_vars.append(var) 25 | return uninit_vars 26 | 27 | -------------------------------------------------------------------------------- /tfops/loss.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | 4 | def sigmoid_cross_entropy(labels, logits): 5 | return tf.reduce_sum(tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=0)) 6 | 7 | def sigmoid_cross_entropy_without_mean(labels, logits): 8 | ndim = len(labels.get_shape().as_list()) 9 | return tf.reduce_sum(tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits), axis=[idx for idx in range(1, ndim)]) 10 | 11 | def vae_kl_cost(mean, stddev, epsilon=1e-8): 12 | return tf.reduce_sum(tf.reduce_mean(0.5 * (tf.square(mean) + tf.square(stddev) - 2.0 * tf.log(stddev + epsilon) - 1.0), axis=0)) 13 | 14 | def vae_kl_cost_weight(mean, stddev, weight, epsilon=1e-8): 15 | return tf.reduce_sum(tf.multiply(tf.reduce_mean(0.5 * (tf.square(mean) + tf.square(stddev) - 2.0 * tf.log(stddev + epsilon) - 1.0), axis=0), weight)) 16 | 17 | -------------------------------------------------------------------------------- /config/path.py: -------------------------------------------------------------------------------- 1 | ROOT = '(Type here)' 2 | DSPRITESPATH = '(Type here)' 3 | 4 | import os 5 | import sys 6 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '../..')) 7 | 8 | from utils.writer_op import create_muldir, create_dir 9 | 10 | def subdirs5resultdir(result_dir, generate_option=False): 11 | save_dir = result_dir+'save/' 12 | log_dir = result_dir+'log/' 13 | asset_dir = result_dir+'asset/' 14 | if generate_option: create_muldir(save_dir, log_dir, asset_dir) 15 | 16 | return save_dir, log_dir, asset_dir 17 | 18 | def dir2subdir(dir_path, file_id, generate_option=False): 19 | subdir_path = dir_path+'%s/'%file_id 20 | if generate_option: create_dir(subdir_path) 21 | return subdir_path 22 | 23 | def muldir2mulsubdir(dir_pathes, file_id, generate_option=False): 24 | subdir_pathes = list() 25 | for dir_path in dir_pathes: subdir_pathes.append(dir2subdir(dir_path=dir_path, file_id=file_id, generate_option=generate_option)) 26 | return subdir_pathes 27 | 28 | -------------------------------------------------------------------------------- /utils/visual_op.py: -------------------------------------------------------------------------------- 1 | import sys 2 | import os 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '..')) 4 | 5 | import numpy as np 6 | import matplotlib.pyplot as plt 7 | 8 | def matrix_image2big_image(matrix_image, row_margin=5, col_margin=5): 9 | nrow, ncol, height, width, nch = matrix_image.shape 10 | big_row = nrow*height + (nrow+1)*row_margin 11 | big_col = ncol*width + (ncol+1)*col_margin 12 | big_image = np.ones([big_row, big_col, nch]) 13 | 14 | for r_idx in range(nrow): 15 | for c_idx in range(ncol): 16 | for h_idx in range(height): 17 | for w_idx in range(width): 18 | big_image_h_idx = r_idx*(height+row_margin)+h_idx+row_margin 19 | big_image_w_idx = c_idx*(width+col_margin)+w_idx+col_margin 20 | for ch_idx in range(nch): big_image[big_image_h_idx][big_image_w_idx][ch_idx] = matrix_image[r_idx][c_idx][h_idx][w_idx][ch_idx] 21 | return np.squeeze(big_image) 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 snu-mllab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /Dsprites_exp/CascadeVAE-C/local_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '../..')) 4 | 5 | from config.path import ROOT 6 | from config.parser import dsprites_parser 7 | 8 | KEY = 'CascadeVAE-C' 9 | RESULT_DIR = ROOT+'{}/'.format(KEY) 10 | TEST_ASSET_DIR = './asset/' 11 | ID_STRUCTURE_DICT = { 12 | 'CascadeVAE-C' : ('nbatch', 'nconti', 'beta_min', 'beta_max', 'dptype', 'rseed'), 13 | } 14 | ID_STRUCTURE = ID_STRUCTURE_DICT[KEY] 15 | 16 | def local_dsprites_parser(): 17 | parser = dsprites_parser() 18 | parser.add_argument("--rseed", default = 0, help="random seed", type = int) 19 | parser.add_argument("--beta_min", default = 0.1, help="min value for +beta*kl_cost", type = float) 20 | parser.add_argument("--beta_max", default = 10.0, help="max value for +beta*kl_cost", type = float) 21 | parser.add_argument("--dtype", default = 'stair', help="decay type", type = str) 22 | parser.add_argument("--dptype", default = 'a3', help="decay parameter type", type = str) 23 | parser.add_argument("--nconti", default = 10, help="the dimension of continuous representation", type = int) 24 | return parser 25 | 26 | -------------------------------------------------------------------------------- /utils/format_op.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '../..')) 4 | 5 | import numpy as np 6 | 7 | def params2id(*args): 8 | nargs = len(args) 9 | id_ = '{}'+'_{}'*(nargs-1) 10 | return id_.format(*args) 11 | 12 | def print_numpy(array): 13 | print(','.join([str(array[v]) for v in range(len(array))])) 14 | 15 | class FileIdManager: 16 | def __init__(self, *attrs): 17 | self.attrs = attrs[0] 18 | self.nattr = len(self.attrs) 19 | 20 | def get_id_from_args(self, args): 21 | tmp = list() 22 | for attr in self.attrs: 23 | if attr == '*': tmp.append('*') 24 | elif type(attr)!=str: tmp.append(attr) 25 | else: tmp.append(getattr(args, attr)) 26 | return params2id(*tuple(tmp)) 27 | 28 | def update_args_with_id(self, args, id_): 29 | id_split = id_.split('_') 30 | assert len(id_split)==self.nattr, "id_ should be composed of the same number of attributes" 31 | 32 | for idx in range(self.nattr): 33 | attr = self.attrs[idx] 34 | type_attr = type(getattr(args, attr)) 35 | setattr(args, attr, type_attr(id_split[idx])) 36 | 37 | -------------------------------------------------------------------------------- /utils/img_tools.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | 4 | def jpg2png(str_): 5 | ''' 6 | Args: 7 | str_ - .jpg 8 | ''' 9 | return os.path.splitext(str_)[0]+'.png' 10 | 11 | def rgb2gray(rgb_img): 12 | ''' 13 | Args: 14 | rgb_img - Numpy 3D array 15 | [nrow, ncol ,3] 16 | Return: 17 | gray_img - Numpy 3D array 18 | [nrow, ncol ,1] 19 | ''' 20 | gray_img = np.mean(rgb_img, axis=-1, keepdims=True) 21 | assert len(gray_img.shape)==3, 'Wrong operations' 22 | return gray_img 23 | 24 | def multirgb2gray(rgb_img): 25 | ''' 26 | Args: 27 | rgb_img - Numpy 4D array 28 | [ndata, nrow, ncol ,3] 29 | Return: 30 | gray_img - Numpy 4D array 31 | [ndata, nrow, ncol ,1] 32 | ''' 33 | gray_img = np.mean(rgb_img, axis=-1, keepdims=True) 34 | assert len(gray_img.shape)==4, 'Wrong operations' 35 | return gray_img 36 | 37 | 38 | def gray2rgb(gray_img): 39 | ''' 40 | Args: 41 | gray_img - Numpy 2D array 42 | [nrow, ncol] 43 | Return: 44 | rgb_img - Numpy 3D array 45 | [nrow, ncol ,3] 46 | 47 | ''' 48 | w, h = gray_img.shape 49 | rgb_img = np.empty((w, h, 3), dtype=np.uint8) 50 | rgb_img[:, :, 0] = rgb_img[:, :, 1] = rgb_img[:, :, 2] = gray_img 51 | return rgb_img 52 | -------------------------------------------------------------------------------- /Dsprites_exp/CascadeVAE-C/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '../..')) 4 | 5 | from model import Model 6 | 7 | from config.path import subdirs5resultdir, muldir2mulsubdir 8 | 9 | from utils.datasetmanager import dsprites_manager 10 | from utils.format_op import FileIdManager 11 | 12 | from local_config import local_dsprites_parser, RESULT_DIR, ID_STRUCTURE 13 | 14 | import tensorflow as tf 15 | import numpy as np 16 | 17 | if __name__ == '__main__': 18 | NITER = 300000 19 | PITER = 20000 20 | SITER = 10000 21 | 22 | parser = local_dsprites_parser() 23 | args = parser.parse_args() # parameter required for model 24 | 25 | fim = FileIdManager(ID_STRUCTURE) 26 | 27 | np.random.seed(args.rseed) 28 | FILE_ID = fim.get_id_from_args(args) 29 | SAVE_DIR, LOG_DIR, ASSET_DIR = subdirs5resultdir(RESULT_DIR, True) 30 | SAVE_SUBDIR, ASSET_SUBDIR = muldir2mulsubdir([SAVE_DIR, ASSET_DIR], FILE_ID, True) 31 | 32 | dm = dsprites_manager() 33 | dm.print_shape() 34 | 35 | model = Model(dm, LOG_DIR+FILE_ID+'.log', args) 36 | model.set_up_train() 37 | model.initialize() 38 | model.train(niter=NITER, siter=SITER, piter=PITER, save_dir=SAVE_SUBDIR, asset_dir=ASSET_SUBDIR) 39 | model.restore(save_dir=SAVE_SUBDIR) 40 | train_idx = model.start_iter//PITER 41 | accuracy = model.evaluate() 42 | 43 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Learning Discrete and Continuous Factors of Data via Alternating Disentanglement 2 | ## Demo 3 | [![](http://img.youtube.com/vi/pRsD0Ot26gw/0.jpg)](http://www.youtube.com/watch?v=pRsD0Ot26gw "Learning Discrete and Continuous Factors of Data via Alternating Disentanglement") 4 | 5 | ## Dependency 6 | - python=3.5 7 | - tensorflow version = 1.4 8 | - CUDA 8.0 9 | - cuDNN 6.0 10 | - Environment detail is listed in `ex.yml' 11 | 12 | ## Citing this work 13 | ``` 14 | @inproceedings{jeongICML19, 15 | title={ 16 | Learning Discrete and Continuous Factors of Data via Alternating Disentanglement 17 | }, 18 | author= {Yeonwoo Jeong and Hyun Oh Song}, 19 | booktitle={International Conference on Machine Learning (ICML)}, 20 | year={2019} 21 | } 22 | ``` 23 | 24 | ## Dataset(dSprites) 25 | - Download from [https://github.com/deepmind/dsprites-dataset](https://github.com/deepmind/dsprites-dataset) 26 | 27 | ## Edit path 28 | - Edit path in 'config/path.py' 29 | - ROOT - (directory for experiment result) 30 | - DSPRITESPATH - (directory for downloaed dsprites) 31 | 32 | ## Run model 33 | - Dsprites_exp/CascadeVAE/main.py 34 | - Dsprites_exp/CascadeVAE-C/main.py 35 | 36 | ## Trained model 37 | - Download from [here](https://drive.google.com/file/d/1GTP2uUCJVaU3nXG1Tk2G-BFiTUlCM5k2/view?usp=sharing). 38 | - Here are trained models from 10 different random seeds. 39 | 40 | ## License 41 | MIT License 42 | -------------------------------------------------------------------------------- /Dsprites_exp/CascadeVAE/local_config.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '../..')) 4 | 5 | from config.path import ROOT 6 | from config.parser import dsprites_parser 7 | 8 | KEY = 'CascadeVAE' 9 | RESULT_DIR = ROOT+'{}/'.format(KEY) 10 | ID_STRUCTURE_DICT = { 11 | 'CascadeVAE' : ('nbatch', 'nconti', 'ncat', 'ntime', 'plamb', 'beta_min', 'beta_max', 'dptype', 'rseed'), 12 | } 13 | ID_STRUCTURE = ID_STRUCTURE_DICT[KEY] 14 | 15 | def local_dsprites_parser(): 16 | parser = dsprites_parser() 17 | parser.add_argument("--rseed", default = 0, help="random seed", type = int) 18 | parser.add_argument("--plamb", default = 0.001, help="pairwise cost", type = float) 19 | parser.add_argument("--beta_min", default = 0.1, help="min value for +beta*kl_cost", type = float) 20 | parser.add_argument("--beta_max", default = 10.0, help="max value for +beta*kl_cost", type = float) 21 | parser.add_argument("--dtype", default = 'stair', help="decay type", type = str) 22 | parser.add_argument("--dptype", default = 'a3', help="decay parameter type", type = str) 23 | parser.add_argument("--nconti", default = 6, help="the dimension of continuous representation", type = int) 24 | parser.add_argument("--ncat", default = 3, help="size of categorical data", type = int) 25 | parser.add_argument("--ntime", default = 4, help="When does discrete variable to be learned", type = int) 26 | return parser 27 | 28 | -------------------------------------------------------------------------------- /utils/reader_op.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | 4 | def read_jpg(jpg_path, plt): 5 | ''' 6 | Dependency : matplotlib.pyplot as plt 7 | Args: 8 | jpg_path - string 9 | ends with jpg 10 | plt - plt object 11 | Return: 12 | numpy 3D image 13 | ''' 14 | return plt.imread(jpg_path) 15 | 16 | def read_pkl(path, encoding='ASCII'): 17 | '''read path(pkl) and return files 18 | Dependency : pickle 19 | Args: 20 | path - string 21 | ends with pkl 22 | Return: 23 | pickle content 24 | ''' 25 | print("Pickle is read from %s"%path) 26 | with open(path, 'rb') as f: return pickle.load(f, encoding=encoding) 27 | 28 | def read_txt(path): 29 | '''read txt files 30 | Args: 31 | path - string 32 | ends with txt 33 | Return: 34 | txt_content - list 35 | line by line 36 | ''' 37 | print("Txt is read from %s"%path) 38 | 39 | txt_content = list() 40 | with open(path, 'r') as lines: 41 | for line in lines: txt_content.append(line) 42 | return txt_content 43 | 44 | def read_npy(path, encoding='ASCII'): 45 | '''read npy files 46 | Args: 47 | path - string 48 | ends with npy 49 | encoding - encoding 50 | Return: 51 | npy_content in path 52 | ''' 53 | print("Npy is read from %s"%path) 54 | npy_content = np.load(path, encoding=encoding) 55 | return npy_content 56 | -------------------------------------------------------------------------------- /Dsprites_exp/CascadeVAE/main.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '../..')) 4 | 5 | from model import Model 6 | 7 | from config.path import subdirs5resultdir, muldir2mulsubdir 8 | 9 | from utils.datasetmanager import dsprites_manager 10 | from utils.format_op import FileIdManager 11 | 12 | from local_config import local_dsprites_parser, RESULT_DIR, ID_STRUCTURE 13 | 14 | import tensorflow as tf 15 | import numpy as np 16 | 17 | if __name__ == '__main__': 18 | NITER = 300000 19 | PITER = 20000 20 | SITER = 10000 21 | 22 | parser = local_dsprites_parser() 23 | args = parser.parse_args() # parameter required for model 24 | 25 | fim = FileIdManager(ID_STRUCTURE) 26 | 27 | np.random.seed(args.rseed) 28 | FILE_ID = fim.get_id_from_args(args) 29 | SAVE_DIR, LOG_DIR, ASSET_DIR = subdirs5resultdir(RESULT_DIR, True) 30 | SAVE_SUBDIR, ASSET_SUBDIR = muldir2mulsubdir([SAVE_DIR, ASSET_DIR], FILE_ID, True) 31 | 32 | dm = dsprites_manager() 33 | dm.print_shape() 34 | 35 | model = Model(dm, LOG_DIR+FILE_ID+'.log', args) 36 | model.set_up_train() 37 | model.initialize() 38 | model.train(niter=NITER, siter=SITER, piter=PITER, save_dir=SAVE_SUBDIR, asset_dir=ASSET_SUBDIR) 39 | model.restore(save_dir=SAVE_SUBDIR) 40 | train_idx = model.start_iter//PITER 41 | include_discrete = False if train_idx < args.ntime else True 42 | accuracy = model.evaluate(include_discrete=include_discrete) 43 | 44 | -------------------------------------------------------------------------------- /utils/writer_op.py: -------------------------------------------------------------------------------- 1 | import pickle 2 | import numpy as np 3 | import os 4 | import imageio 5 | from matplotlib.backends.backend_pdf import PdfPages 6 | 7 | def create_dir(dirname): 8 | if not os.path.exists(dirname): 9 | print("Creating %s"%dirname) 10 | os.makedirs(dirname) 11 | else: 12 | print("Already %s exists"%dirname) 13 | 14 | def create_muldir(*args): 15 | for dirname in args: create_dir(dirname) 16 | 17 | def write_pkl(content, path): 18 | with open(path, 'wb') as f: 19 | print("Pickle is written on %s"%path) 20 | try: pickle.dump(content, f) 21 | except OverflowError: pickle.dump(content, f, protocol=4) 22 | 23 | def write_npy(content, path): 24 | print("Numpy is written on %s"%path) 25 | np.save(path, content) 26 | 27 | class MatplotlibPdfManager: 28 | def __init__(self, path, plt, pad_inches=None): 29 | self.path = path 30 | print("Creating {}".format(self.path)) 31 | self.pdf = PdfPages(self.path) 32 | self.plt = plt 33 | self.ncount = 0 34 | self.pad_inches=pad_inches 35 | 36 | def generate_from(self): 37 | self.ncount+=1 38 | print("From here generating {} pages in {}".format(self.ncount, self.path)) 39 | self.plt.close() 40 | 41 | def generate_to(self): 42 | print("To here generating {} pages in {}".format(self.ncount, self.path)) 43 | if self.pad_inches is None: 44 | self.pdf.savefig(bbox_inches="tight") 45 | else: 46 | self.pdf.savefig(pad_inches = self.pad_inches) 47 | 48 | def write_gif(content, path): 49 | print("Create gif on {}".format(path)) 50 | imageio.mimsave(path, content) 51 | 52 | -------------------------------------------------------------------------------- /tfops/lr_op.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | 3 | def stair_decay(initial_lr, decay_steps, decay_rate, initial_step=0): 4 | #with tf.variable_scope(scope): global_step = tf.Variable(initial_step, name='global_step', trainable=False) 5 | global_step = tf.Variable(initial_step, trainable=False) 6 | update_step_op = tf.assign_add(global_step, 1) 7 | return tf.train.exponential_decay( 8 | learning_rate=initial_lr,\ 9 | global_step=global_step,\ 10 | decay_steps=decay_steps,\ 11 | decay_rate=decay_rate,\ 12 | staircase=True), update_step_op 13 | 14 | def piecewise_decay(boundaries, values, initial_step = 0): 15 | #with tf.variable_scope(scope): global_step = tf.Variable(initial_step, name='global_step', trainable=False) 16 | global_step = tf.Variable(initial_step, name='global_step', trainable=False) 17 | update_step_op = tf.assign_add(global_step, 1) 18 | return tf.train.piecewise_constant(global_step, boundaries, values), update_step_op 19 | 20 | DECAY_DICT = { 21 | 'stair' : stair_decay, 22 | 'piecewise' : piecewise_decay 23 | } 24 | 25 | DECAY_PARAMS_DICT =\ 26 | { 27 | 'stair' : 28 | { 29 | 64 :{ 30 | 'a1' : {'initial_lr' : 1e-5, 'decay_steps' : 50000, 'decay_rate' : 0.3}, 31 | 'a2' : {'initial_lr' : 3e-5, 'decay_steps' : 50000, 'decay_rate' : 0.3}, 32 | 'a3' : {'initial_lr' : 1e-4, 'decay_steps' : 50000, 'decay_rate' : 0.3}, 33 | 'a4' : {'initial_lr' : 3e-4, 'decay_steps' : 50000, 'decay_rate' : 0.3}, 34 | 'a5' : {'initial_lr' : 1e-3, 'decay_steps' : 50000, 'decay_rate' : 0.3} 35 | } 36 | } 37 | } 38 | 39 | -------------------------------------------------------------------------------- /utils/np_op.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '..')) 4 | 5 | from utils.tqdm_op import tqdm_range 6 | 7 | import numpy as np 8 | import copy 9 | 10 | def np_softmax(x): 11 | ''' 12 | Args: 13 | x - Numpy 2D array 14 | ''' 15 | x_softmax = np.zeros_like(x) 16 | 17 | ndata, nfeature = x.shape 18 | 19 | for idx in range(ndata): 20 | tmp_max = np.max(x[idx]) 21 | tmp_exp = np.exp(x[idx] - tmp_max) 22 | x_softmax[idx] = tmp_exp/np.sum(tmp_exp) 23 | return x_softmax 24 | 25 | def get_ginni_variance_conti(array): 26 | ''' FactorVAE https://arxiv.org/pdf/1802.05983.pdf 27 | Args: 28 | array - Numpy 1D array 29 | ''' 30 | ndata = array.shape[0] 31 | return ndata/(ndata-1)*np.var(array) 32 | 33 | def get_ginni_variance_discrete(array): 34 | ''' FactorVAE https://arxiv.org/pdf/1802.05983.pdf 35 | Args: array - Numpy 1D array, argmax index 36 | ''' 37 | array = array.astype(int) 38 | ndata = array.shape[0] 39 | count = np.zeros([np.max(array)+1]) 40 | for idx in range(ndata): count[array[idx]]+=1 41 | count = count.astype(float) 42 | return (ndata*ndata - np.sum(np.square(count)))/(2*ndata*(ndata-1)) 43 | 44 | def zero_padding2nmul(inputs, mul): 45 | '''Add zero padding to inputs to be multiple of mul 46 | Args: 47 | inputs - np array 48 | mul - int 49 | 50 | Return: 51 | np array (inputs + zero_padding) 52 | int original input size 53 | ''' 54 | input_shape = list(inputs.shape) 55 | ndata = input_shape[0] 56 | if ndata%mul==0: return inputs, ndata 57 | input_shape[0] = mul-ndata%mul 58 | return np.concatenate([inputs, np.zeros(input_shape)], axis=0), ndata 59 | 60 | def np_random_crop_4d(imgs, size): 61 | ''' 62 | Args: 63 | imgs - 4d image NHWC 64 | size - list (rh, rw) 65 | ''' 66 | rh, rw = size 67 | on, oh, ow, oc = imgs.shape 68 | 69 | cropped_imgs = np.zeros([on, rh, rw, oc]) 70 | ch = np.random.randint(low=0, high=oh-rh, size=on) 71 | cw = np.random.randint(low=0, high=ow-rw, size=on) 72 | for idx in range(on): cropped_imgs[idx] = imgs[idx,ch[idx]:ch[idx]+rh,cw[idx]:cw[idx]+rw] 73 | return cropped_imgs 74 | -------------------------------------------------------------------------------- /ex.yml: -------------------------------------------------------------------------------- 1 | name: tf1.4_gpu 2 | channels: 3 | - defaults 4 | dependencies: 5 | - ca-certificates=2017.08.26=h1d4fec5_0 6 | - certifi=2017.7.27.1=py35h19f42a1_0 7 | - libedit=3.1=heed3624_0 8 | - libffi=3.2.1=hd88cf55_4 9 | - libgcc-ng=7.2.0=h7cc24e2_2 10 | - libstdcxx-ng=7.2.0=h7a57d05_2 11 | - ncurses=6.0=h9df7e31_2 12 | - openssl=1.0.2m=h26d622b_1 13 | - pip=9.0.1=py35h7e7da9d_4 14 | - python=3.5.4=h56e0582_23 15 | - readline=7.0=ha6073c6_4 16 | - setuptools=36.5.0=py35ha8c1747_0 17 | - sqlite=3.20.1=hb898158_2 18 | - tk=8.6.7=hc745277_3 19 | - wheel=0.29.0=py35h601ca99_1 20 | - xz=5.2.3=h55aa19d_2 21 | - zlib=1.2.11=ha838bed_2 22 | - pip: 23 | - bleach==1.5.0 24 | - chardet==3.0.4 25 | - cycler==0.10.0 26 | - cython==0.27.3 27 | - decorator==4.1.2 28 | - entrypoints==0.2.3 29 | - enum34==1.1.6 30 | - h5py==2.7.1 31 | - html5lib==0.9999999 32 | - idna==2.6 33 | - ipykernel==4.6.1 34 | - ipython==6.2.1 35 | - ipython-genutils==0.2.0 36 | - ipywidgets==7.0.5 37 | - jedi==0.11.0 38 | - jinja2==2.10 39 | - jsonschema==2.6.0 40 | - jupyter==1.0.0 41 | - jupyter-client==5.1.0 42 | - jupyter-console==5.2.0 43 | - jupyter-core==4.4.0 44 | - markdown==2.6.9 45 | - markupsafe==1.0 46 | - matplotlib==2.1.0 47 | - mistune==0.8.1 48 | - munkres==1.0.12 49 | - nbconvert==5.3.1 50 | - nbformat==4.4.0 51 | - networkx==2.0 52 | - notebook==5.2.1 53 | - numpy==1.13.3 54 | - opencv-python==3.4.0.12 55 | - ortools==6.6.4656 56 | - pandas==0.21.0 57 | - pandocfilters==1.4.2 58 | - parso==0.1.0 59 | - path.py==11.1.0 60 | - pexpect==4.3.0 61 | - pickleshare==0.7.4 62 | - pillow==5.0.0 63 | - prompt-toolkit==1.0.15 64 | - protobuf==3.5.1 65 | - psutil==5.4.3 66 | - ptyprocess==0.5.2 67 | - pygco==0.0.1 68 | - pygments==2.2.0 69 | - pyparsing==2.2.0 70 | - python-dateutil==2.6.1 71 | - pytz==2017.3 72 | - pywavelets==0.5.2 73 | - pyzmq==16.0.3 74 | - qtconsole==4.3.1 75 | - requests==2.18.4 76 | - scikit-image==0.13.1 77 | - scikit-learn==0.19.1 78 | - scipy==1.0.0 79 | - simplegeneric==0.8.1 80 | - six==1.11.0 81 | - tensorflow-gpu==1.4.0 82 | - tensorflow-tensorboard==0.4.0rc3 83 | - terminado==0.7 84 | - testpath==0.3.1 85 | - tornado==4.5.2 86 | - tqdm==4.19.4 87 | - traitlets==4.3.2 88 | - urllib3==1.22 89 | - wcwidth==0.1.7 90 | - werkzeug==0.12.2 91 | - widgetsnbextension==3.0.8 92 | prefix: /home/maestro/anaconda3/envs/tf1.4_gpu 93 | 94 | -------------------------------------------------------------------------------- /utils/datamanager.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '..')) 4 | 5 | from utils.general_class import DatamanagerPlugin 6 | 7 | import numpy as np 8 | import random 9 | 10 | class DspritesManager(DatamanagerPlugin): 11 | ''' https://github.com/deepmind/dsprites-dataset/blob/master/dsprites_reloading_example.ipynb''' 12 | def __init__(self, dataset_zip): 13 | self.image = np.expand_dims(dataset_zip['imgs'], axis=-1).astype(float) 14 | self.latents_values = dataset_zip['latents_values'] 15 | self.latents_classes = dataset_zip['latents_classes'] 16 | self.metadata = dataset_zip['metadata'][()] 17 | self.latents_sizes = self.metadata['latents_sizes'] 18 | self.nlatent = len(self.latents_sizes) #6 19 | self.latents_bases = np.concatenate((self.latents_sizes[::-1].cumprod()[::-1][1:], np.array([1,]))) 20 | super().__init__(ndata=self.image.shape[0]) 21 | 22 | def print_shape(self): 23 | print("Image shape : {}({}, max = {}, min = {})".format(self.image.shape, self.image.dtype, np.amax(self.image), np.amin(self.image))) 24 | print("Latent size : {}".format(self.latents_sizes)) 25 | 26 | def normalize(self, nmin, nmax): 27 | cmin = np.amin(self.image) 28 | cmax = np.amax(self.image) 29 | slope = (nmax-nmin)/(cmax-cmin) 30 | 31 | self.image = slope*(self.image-cmin) + nmin 32 | self.print_shape() 33 | 34 | def latent2idx(self, latents): 35 | return np.dot(latents, self.latents_bases).astype(int) 36 | 37 | def next_batch_latent_random(self, batch_size): 38 | samples = np.zeros([batch_size, self.nlatent]) 39 | for lat_i, lat_size in enumerate(self.latents_sizes): 40 | samples[:, lat_i] = np.random.randint(lat_size, size=batch_size) 41 | return samples 42 | 43 | def next_batch_latent_fix(self, batch_size, latent_idx, latent_value): 44 | samples = self.next_batch_latent_random(batch_size) 45 | samples[:, latent_idx] = latent_value 46 | return self.image[self.latent2idx(samples)] 47 | 48 | def next_batch_latent_fix_idx(self, batch_size, latent_idx, latent_value): 49 | samples = self.next_batch_latent_random(batch_size) 50 | samples[:, latent_idx] = latent_value 51 | return self.latent2idx(samples) 52 | 53 | def next_batch(self, batch_size): 54 | subidx = self.sample_idx(batch_size) 55 | return self.image[subidx], self.latents_classes[subidx] 56 | -------------------------------------------------------------------------------- /tfops/transform_op.py: -------------------------------------------------------------------------------- 1 | from utils.np_op import zero_padding2nmul 2 | from utils.tqdm_op import tqdm_range 3 | from tqdm import tqdm 4 | 5 | import tensorflow as tf 6 | import numpy as np 7 | 8 | def apply_tf_op(inputs, session, input_gate, output_gate, batch_size, train_gate=None, print_option=True): 9 | inputs, ndata = zero_padding2nmul(inputs=inputs, mul=batch_size) 10 | nbatch = len(inputs)//batch_size 11 | 12 | outputs = list() 13 | 14 | feed_dict = dict() 15 | if train_gate is not None: feed_dict[train_gate] = False 16 | 17 | if print_option: 18 | for b in tqdm_range(nbatch): 19 | feed_dict[input_gate]=inputs[b*batch_size:(b+1)*batch_size] 20 | outputs.append(session.run(output_gate, feed_dict=feed_dict)) 21 | else: 22 | for b in range(nbatch): 23 | feed_dict[input_gate]=inputs[b*batch_size:(b+1)*batch_size] 24 | outputs.append(session.run(output_gate, feed_dict=feed_dict)) 25 | 26 | outputs = np.concatenate(outputs, axis=0) 27 | outputs = outputs[:ndata] 28 | 29 | return outputs 30 | 31 | def apply_tf_op_multi_output(inputs, session, input_gate, output_gate_list, batch_size, train_gate=None, print_option=True): 32 | inputs, ndata = zero_padding2nmul(inputs=inputs, mul=batch_size) 33 | nbatch = len(inputs)//batch_size 34 | 35 | noutput = len(output_gate_list) 36 | outputs_list = [list() for o_idx in range(noutput)] 37 | 38 | feed_dict = dict() 39 | if train_gate is not None: feed_dict[train_gate] = False 40 | 41 | if print_option: 42 | for b in tqdm_range(nbatch): 43 | feed_dict[input_gate]=inputs[b*batch_size:(b+1)*batch_size] 44 | tmp = session.run(output_gate_list, feed_dict=feed_dict) 45 | for o_idx in range(noutput): 46 | outputs_list[o_idx].append(tmp[o_idx]) 47 | else: 48 | for b in range(nbatch): 49 | feed_dict[input_gate]=inputs[b*batch_size:(b+1)*batch_size] 50 | tmp = session.run(output_gate_list, feed_dict=feed_dict) 51 | for o_idx in range(noutput): 52 | outputs_list[o_idx].append(tmp[o_idx]) 53 | 54 | for o_idx in range(noutput): 55 | outputs_list[o_idx] = np.concatenate(outputs_list[o_idx], axis=0) 56 | outputs_list[o_idx] = outputs_list[o_idx][:ndata] 57 | 58 | return outputs_list 59 | 60 | def apply_tf_op_multi_input(inputs_list, session, input_gate_list, output_gate, batch_size, train_gate=None, print_option=True): 61 | assert len(inputs_list)==len(input_gate_list), "Length of list should be same" 62 | ninput = len(inputs_list) 63 | inputs_pad_list = list() 64 | ndata = len(inputs_list[0]) 65 | 66 | for i_idx in range(ninput): 67 | inputs_pad_list.append(zero_padding2nmul(inputs=inputs_list[i_idx], mul=batch_size)[0]) 68 | 69 | nbatch = len(inputs_pad_list[0])//batch_size 70 | 71 | outputs = list() 72 | feed_dict = dict() 73 | if train_gate is not None: feed_dict[train_gate] = False 74 | 75 | if print_option: 76 | for b in tqdm_range(nbatch): 77 | for i_idx in range(ninput): feed_dict[input_gate_list[i_idx]]=inputs_pad_list[i_idx][b*batch_size:(b+1)*batch_size] 78 | outputs.append(session.run(output_gate, feed_dict=feed_dict)) 79 | else: 80 | for b in range(nbatch): 81 | for i_idx in range(ninput): feed_dict[input_gate_list[i_idx]]=inputs_pad_list[i_idx][b*batch_size:(b+1)*batch_size] 82 | outputs.append(session.run(output_gate, feed_dict=feed_dict)) 83 | 84 | outputs = np.concatenate(outputs, axis=0) 85 | outputs = outputs[:ndata] 86 | return outputs 87 | 88 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | ## Core latex/pdflatex auxiliary files: 2 | *.aux 3 | *.lof 4 | *.log 5 | *.lot 6 | *.fls 7 | *.out 8 | *.toc 9 | *.fmt 10 | *.fot 11 | *.cb 12 | *.cb2 13 | .*.lb 14 | 15 | ## Intermediate documents: 16 | *.dvi 17 | *.xdv 18 | *-converted-to.* 19 | # these rules might exclude image files for figures etc. 20 | # *.ps 21 | # *.eps 22 | # *.pdf 23 | 24 | ## Generated if empty string is given at "Please type another file name for output:" 25 | .pdf 26 | 27 | ## Bibliography auxiliary files (bibtex/biblatex/biber): 28 | *.bbl 29 | *.bcf 30 | *.blg 31 | *-blx.aux 32 | *-blx.bib 33 | *.run.xml 34 | 35 | ## Build tool auxiliary files: 36 | *.fdb_latexmk 37 | *.synctex 38 | *.synctex(busy) 39 | *.synctex.gz 40 | *.synctex.gz(busy) 41 | *.pdfsync 42 | 43 | ## Auxiliary and intermediate files from other packages: 44 | # algorithms 45 | *.alg 46 | *.loa 47 | 48 | # achemso 49 | acs-*.bib 50 | 51 | # amsthm 52 | *.thm 53 | 54 | # beamer 55 | *.nav 56 | *.pre 57 | *.snm 58 | *.vrb 59 | 60 | # changes 61 | *.soc 62 | 63 | # cprotect 64 | *.cpt 65 | 66 | # elsarticle (documentclass of Elsevier journals) 67 | *.spl 68 | 69 | # endnotes 70 | *.ent 71 | 72 | # fixme 73 | *.lox 74 | 75 | # feynmf/feynmp 76 | *.mf 77 | *.mp 78 | *.t[1-9] 79 | *.t[1-9][0-9] 80 | *.tfm 81 | 82 | #(r)(e)ledmac/(r)(e)ledpar 83 | *.end 84 | *.?end 85 | *.[1-9] 86 | *.[1-9][0-9] 87 | *.[1-9][0-9][0-9] 88 | *.[1-9]R 89 | *.[1-9][0-9]R 90 | *.[1-9][0-9][0-9]R 91 | *.eledsec[1-9] 92 | *.eledsec[1-9]R 93 | *.eledsec[1-9][0-9] 94 | *.eledsec[1-9][0-9]R 95 | *.eledsec[1-9][0-9][0-9] 96 | *.eledsec[1-9][0-9][0-9]R 97 | 98 | # glossaries 99 | *.acn 100 | *.acr 101 | *.glg 102 | *.glo 103 | *.gls 104 | *.glsdefs 105 | 106 | # gnuplottex 107 | *-gnuplottex-* 108 | 109 | # gregoriotex 110 | *.gaux 111 | *.gtex 112 | 113 | # htlatex 114 | *.4ct 115 | *.4tc 116 | *.idv 117 | *.lg 118 | *.trc 119 | *.xref 120 | 121 | # hyperref 122 | *.brf 123 | 124 | # knitr 125 | *-concordance.tex 126 | # TODO Comment the next line if you want to keep your tikz graphics files 127 | *.tikz 128 | *-tikzDictionary 129 | 130 | # listings 131 | *.lol 132 | 133 | # makeidx 134 | *.idx 135 | *.ilg 136 | *.ind 137 | *.ist 138 | 139 | # minitoc 140 | *.maf 141 | *.mlf 142 | *.mlt 143 | *.mtc[0-9]* 144 | *.slf[0-9]* 145 | *.slt[0-9]* 146 | *.stc[0-9]* 147 | 148 | # minted 149 | _minted* 150 | *.pyg 151 | 152 | # morewrites 153 | *.mw 154 | 155 | # nomencl 156 | *.nlg 157 | *.nlo 158 | *.nls 159 | 160 | # pax 161 | *.pax 162 | 163 | # pdfpcnotes 164 | *.pdfpc 165 | 166 | # sagetex 167 | *.sagetex.sage 168 | *.sagetex.py 169 | *.sagetex.scmd 170 | 171 | # scrwfile 172 | *.wrt 173 | 174 | # sympy 175 | *.sout 176 | *.sympy 177 | sympy-plots-for-*.tex/ 178 | 179 | # pdfcomment 180 | *.upa 181 | *.upb 182 | 183 | # pythontex 184 | *.pytxcode 185 | pythontex-files-*/ 186 | 187 | # thmtools 188 | *.loe 189 | 190 | # TikZ & PGF 191 | *.dpth 192 | *.md5 193 | *.auxlock 194 | 195 | # todonotes 196 | *.tdo 197 | 198 | # easy-todo 199 | *.lod 200 | 201 | # xmpincl 202 | *.xmpi 203 | 204 | # xindy 205 | *.xdy 206 | 207 | # xypic precompiled matrices 208 | *.xyc 209 | 210 | # endfloat 211 | *.ttt 212 | *.fff 213 | 214 | # Latexian 215 | TSWLatexianTemp* 216 | 217 | ## Editors: 218 | # WinEdt 219 | *.bak 220 | *.sav 221 | 222 | # Texpad 223 | .texpadtmp 224 | 225 | # Kile 226 | *.backup 227 | 228 | # KBibTeX 229 | *~[0-9]* 230 | 231 | # auto folder when using emacs and auctex 232 | ./auto/* 233 | *.el 234 | 235 | # expex forward references with \gathertags 236 | *-tags.tex 237 | 238 | # standalone packages 239 | *.sta 240 | 241 | # generated if using elsarticle.cls 242 | *.spl 243 | 244 | main.pdf 245 | supp.pdf 246 | -------------------------------------------------------------------------------- /utils/general_class.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '../..')) 4 | 5 | from utils.logger_op import LoggerManager 6 | from utils.gpu_op import selectGpuById 7 | from tfops.init_op import rest_initializer 8 | 9 | import tensorflow as tf 10 | import numpy as np 11 | import glob 12 | 13 | class DatamanagerPlugin: 14 | def __init__(self, ndata): 15 | self.ndata = ndata 16 | self.start, self.end = 0, 0 17 | self.fullidx = np.arange(self.ndata) 18 | 19 | def sample_idx(self, batch_size): 20 | if self.start == 0 and self.end ==0: 21 | np.random.shuffle(self.fullidx) # shuffle first 22 | 23 | if self.end + batch_size > self.ndata: 24 | self.start = self.end 25 | self.end = (self.end + batch_size)%self.ndata 26 | subidx = np.append(self.fullidx[self.start:self.ndata], self.fullidx[0:self.end]) 27 | self.start = 0 28 | self.end = 0 29 | else: 30 | self.start = self.end 31 | self.end += batch_size 32 | subidx = self.fullidx[self.start:self.end] 33 | return subidx 34 | 35 | class ModelPlugin: 36 | def __init__(self, dataset, logfilepath, args): 37 | self.args = args 38 | 39 | selectGpuById(self.args.gpu) 40 | self.logfilepath = logfilepath 41 | self.logger = LoggerManager(self.logfilepath, __name__) 42 | self.set_dataset(dataset) 43 | 44 | def set_dataset(self, dataset): 45 | self.logger.info("Setting dataset starts") 46 | self.dataset = dataset 47 | self.image = self.dataset.image 48 | self.ndata, self.height, self.width, self.nchannel = self.image.shape 49 | self.logger.info("Setting dataset ends") 50 | 51 | def build(self, *args, **kwargs): 52 | """Builds the neural networks""" 53 | raise NotImplementedError('`build` is not implemented for model class {}'.format(self.__class__.__name__)) 54 | 55 | def set_up_train(self, *args, **kwargs): 56 | """Builds the neural networks""" 57 | raise NotImplementedError('`set_up_train` is not implemented for model class {}'.format(self.__class__.__name__)) 58 | 59 | def train(self, *args, **kwargs): 60 | """train the neural networks""" 61 | raise NotImplementedError('`train` is not implemented for model class {}'.format(self.__class__.__name__)) 62 | 63 | def generate_sess(self): 64 | try: self.sess 65 | except AttributeError: 66 | config = tf.ConfigProto() 67 | config.gpu_options.allow_growth = True 68 | self.sess=tf.Session(config=config) 69 | 70 | def initialize(self): 71 | '''Initialize uninitialized variables''' 72 | self.logger.info("Model initialization starts") 73 | rest_initializer(self.sess) 74 | self.start_iter = 0 75 | self.logger.info("Model initialization ends") 76 | 77 | def save(self, global_step, save_dir, reset_option=True): 78 | self.logger.info("Model save starts") 79 | if reset_option: 80 | for f in glob.glob(save_dir+'*'): os.remove(f) 81 | saver=tf.train.Saver(max_to_keep = 5) 82 | saver.save(self.sess, os.path.join(save_dir, 'model'), global_step = global_step) 83 | self.logger.info("Model save in %s"%save_dir) 84 | self.logger.info("Model save ends") 85 | 86 | def restore(self, save_dir, restore_iter=-1): 87 | """Restore all variables in graph with the latest version""" 88 | self.logger.info("Restoring model starts...") 89 | saver = tf.train.Saver() 90 | checkpoint = tf.train.latest_checkpoint(save_dir) 91 | 92 | if restore_iter==-1: 93 | self.start_iter = int(os.path.basename(checkpoint)[len('model')+1:]) 94 | else: 95 | self.start_iter = restore_iter 96 | checkpoint = save_dir+'model-%d'%restore_iter 97 | self.logger.info("Restoring from {}".format(checkpoint)) 98 | self.generate_sess() 99 | saver.restore(self.sess, checkpoint) 100 | self.logger.info("Restoring model done.") 101 | 102 | def regen_session(self): 103 | tf.reset_default_graph() 104 | self.sess.close() 105 | config = tf.ConfigProto() 106 | config.gpu_options.allow_growth = True 107 | self.sess=tf.Session(config=config) 108 | 109 | def delete(self): 110 | tf.reset_default_graph() 111 | self.logger.remove() 112 | del self.logger 113 | 114 | 115 | -------------------------------------------------------------------------------- /utils/eval_op.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '..')) 4 | 5 | from utils.np_op import get_ginni_variance_discrete, get_ginni_variance_conti 6 | 7 | import numpy as np 8 | 9 | def DisentanglemetricFactorMask(mean, std, nclasses, sampler, nvote=800, nex_per_vote=100, eps=1e-10, thr=0.1, print_option=False): 10 | kl_cost = np.mean(0.5*(np.square(mean)+np.square(std)-1-2*np.log(std+eps)),axis=0) 11 | return DisentanglemetricFactorMaskcustom(latent_conti=mean, mi_array=kl_cost, nclasses=nclasses, sampler=sampler, nvote=nvote, nex_per_vote=nex_per_vote, eps=eps, thr=thr, print_option=print_option) 12 | 13 | def DisentanglemetricFactorJointMask(mean, std, latent_cat, nclasses, sampler, nvote=800, nex_per_vote=100, eps=1e-10, thr=0.1, print_option=False, ignore_discrete=True): 14 | return DisentanglemetricFactorMultiJointMask(mean=mean, std=std, latent_cat_list=[latent_cat], nclasses=nclasses, sampler=sampler, nvote=nvote, nex_per_vote=nex_per_vote, eps=eps, thr=thr, print_option=print_option, ignore_discrete=ignore_discrete) 15 | 16 | def DisentanglemetricFactorMultiJointMask(mean, std, latent_cat_list, nclasses, sampler, nvote=800, nex_per_vote=100, eps=1e-10, thr=0.1, print_option=False, ignore_discrete=True): 17 | kl_cost = np.mean(0.5*(np.square(mean)+np.square(std)-1-2*np.log(std+eps)),axis=0) 18 | 19 | tmp = list() 20 | for latent_cat in latent_cat_list: 21 | if ignore_discrete: 22 | kl_cat_cost = np.log(latent_cat.shape[1]) + np.mean(np.sum(latent_cat*np.log(latent_cat+eps), axis=1)) 23 | if kl_cat_cost < thr: 24 | tmp.append(np.argmax(latent_cat, axis=-1)) 25 | else: 26 | tmp.append(np.argmax(latent_cat, axis=-1)) 27 | 28 | if len(tmp)==0: 29 | print("Ignore discrete variable") 30 | return DisentanglemetricFactorMaskcustom(latent_conti=mean, mi_array=kl_cost, nclasses=nclasses, sampler=sampler, nvote=nvote, nex_per_vote=nex_per_vote, eps=eps, thr=thr, print_option=print_option) 31 | else: 32 | return DisentanglemetricFactorMultiJointMaskcustom(latent_conti=mean, latent_cat_list=tmp, mi_array=kl_cost, nclasses=nclasses, sampler=sampler, nvote=nvote, nex_per_vote=nex_per_vote, eps=eps, thr=thr, print_option=print_option) 33 | 34 | def DisentanglemetricFactorMaskcustom(latent_conti, mi_array, nclasses, sampler, nvote=800, nex_per_vote=100, eps=1e-10, thr=0.1, print_option=False): 35 | mask = np.where(mi_array>thr)[0] 36 | latent_conti = latent_conti[:, mask] 37 | if print_option: 38 | print(mask) 39 | 40 | k_set = list() 41 | for idx in range(nclasses.shape[0]): 42 | if nclasses[idx]>1: k_set.append(idx) 43 | nfactor = len(k_set) 44 | 45 | if print_option: 46 | print(k_set) 47 | 48 | if latent_conti.shape[1]==0: return 1/nfactor 49 | 50 | var = np.array([get_ginni_variance_conti(latent_conti[:, v]) for v in range(latent_conti.shape[1])]) 51 | 52 | nlatent = var.shape[0] 53 | nvote_per_factor = int(nvote/nfactor) 54 | 55 | count = np.zeros([nlatent, nfactor]) 56 | for idx in range(nfactor): 57 | for iter_ in range(nvote_per_factor): 58 | k_fixed = k_set[idx] 59 | fixed_value = np.random.randint(nclasses[k_fixed]) 60 | batch_latent_conti = latent_conti[sampler(batch_size=nex_per_vote, latent_idx=k_fixed, latent_value=fixed_value)] 61 | batch_var = np.array([get_ginni_variance_conti(batch_latent_conti[:, v]) for v in range(latent_conti.shape[1])]) 62 | batch_var_norm = np.divide(batch_var, var+eps) 63 | count[np.argmin(batch_var_norm)][idx]+=1 64 | if print_option: 65 | print(count) 66 | #print(count) 67 | return get_majority_vote_accuracy(count) 68 | 69 | def DisentanglemetricFactorMultiJointMaskcustom(latent_conti, latent_cat_list, mi_array, nclasses, sampler, nvote=800, nex_per_vote=100, eps=1e-10, thr=0.1, print_option=False): 70 | mask = np.where(mi_array>thr)[0] 71 | latent_conti = latent_conti[:, mask] 72 | if print_option: 73 | print(mask) 74 | 75 | var = np.array([get_ginni_variance_conti(latent_conti[:, v]) for v in range(latent_conti.shape[1])]+[get_ginni_variance_discrete(array=v) for v in latent_cat_list]) 76 | 77 | k_set = list() 78 | for idx in range(nclasses.shape[0]): 79 | if nclasses[idx]>1: k_set.append(idx) 80 | 81 | if print_option: 82 | print(k_set) 83 | 84 | nfactor = len(k_set) 85 | nlatent = var.shape[0] 86 | nvote_per_factor = int(nvote/nfactor) 87 | 88 | count = np.zeros([nlatent, nfactor]) 89 | for idx in range(nfactor): 90 | for iter_ in range(nvote_per_factor): 91 | k_fixed = k_set[idx] 92 | fixed_value = np.random.randint(nclasses[k_fixed]) 93 | sample_idx = sampler(batch_size=nex_per_vote, latent_idx=k_fixed, latent_value=fixed_value) 94 | batch_latent_conti = latent_conti[sample_idx] 95 | batch_var = np.array([get_ginni_variance_conti(batch_latent_conti[:, v]) for v in range(latent_conti.shape[1])]+[get_ginni_variance_discrete(array=v[sample_idx]) for v in latent_cat_list]) 96 | batch_var_norm = np.divide(batch_var, var+eps) 97 | count[np.argmin(batch_var_norm)][idx]+=1 98 | if print_option: 99 | print(count) 100 | 101 | return get_majority_vote_accuracy(count) 102 | 103 | def get_majority_vote_accuracy(count): 104 | ''' 105 | Args: 106 | count - Numpy 2D array [nvoter, ncandidate] 107 | Return: 108 | accuracy 109 | ''' 110 | nvoter = count.shape[0] 111 | classifier = np.argmax(count, axis=-1) 112 | 113 | nvote = np.sum(count) 114 | vote_correct = np.sum([count[idx][classifier[idx]] for idx in range(nvoter)]) 115 | 116 | accuracy = vote_correct/nvote 117 | return accuracy 118 | -------------------------------------------------------------------------------- /tfops/nets.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '..')) 4 | 5 | import tensorflow as tf 6 | 7 | slim = tf.contrib.slim 8 | #=============================================================================================================================================# 9 | def encoder1_32(x, output_dim, output_nonlinearity=None, scope="ENC", reuse=False): 10 | nets_dict = dict() 11 | nets_dict['input'] = x 12 | with tf.variable_scope(scope, reuse=reuse): 13 | with slim.arg_scope([slim.conv2d, slim.fully_connected], weights_regularizer=slim.l2_regularizer(0.00004)): 14 | with slim.arg_scope([slim.conv2d], weights_initializer=tf.contrib.slim.variance_scaling_initializer(), stride=2, padding='SAME', activation_fn=tf.nn.relu) : 15 | with slim.arg_scope([slim.fully_connected], biases_initializer=tf.zeros_initializer()): 16 | nets_dict['conv2d0'] = slim.conv2d(nets_dict['input'], 32, [4, 4], scope='conv2d_0') 17 | nets_dict['conv2d1'] = slim.conv2d(nets_dict['conv2d0'], 32, [4, 4], scope='conv2d_1') 18 | nets_dict['conv2d2'] = slim.conv2d(nets_dict['conv2d1'], 64, [4, 4], scope='conv2d_2') 19 | n = tf.reshape(nets_dict['conv2d2'], [-1, 4*4*64]) 20 | nets_dict['fc0'] = slim.fully_connected(n, 256, activation_fn=tf.nn.relu, scope = "output_fc0") 21 | nets_dict['output'] = slim.fully_connected(nets_dict['fc0'], output_dim, activation_fn=output_nonlinearity, scope = "output_fc1") 22 | return nets_dict 23 | def decoder1_32(z, scope="DEC", reuse=False): 24 | nets_dict = dict() 25 | nets_dict['input'] = z 26 | with tf.variable_scope(scope, reuse=reuse): 27 | with slim.arg_scope([slim.conv2d_transpose, slim.fully_connected], weights_regularizer=slim.l2_regularizer(0.00004)): 28 | with slim.arg_scope([slim.conv2d_transpose], weights_initializer=tf.contrib.slim.variance_scaling_initializer(), 29 | stride=2, padding='SAME', activation_fn=tf.nn.relu): 30 | with slim.arg_scope([slim.fully_connected], biases_initializer=tf.zeros_initializer()): 31 | nets_dict['fc0'] = slim.fully_connected(nets_dict['input'], 256, activation_fn=tf.nn.relu, scope = "fc0") 32 | nets_dict['fc1'] = slim.fully_connected(nets_dict['fc0'], 4*4*64, activation_fn=tf.nn.relu, scope = "fc1") 33 | n = tf.reshape(nets_dict['fc1'], [-1, 4, 4, 64]) 34 | nets_dict['deconv2d0'] = slim.conv2d_transpose(n, 32, [4, 4], scope='deconv2d_0') 35 | nets_dict['deconv2d1'] = slim.conv2d_transpose(nets_dict['deconv2d0'], 32, [4, 4], scope='deconv2d_1') 36 | nets_dict['output'] = slim.conv2d_transpose(nets_dict['deconv2d1'], 1, [4, 4], activation_fn=None, scope='deconv2d_2') 37 | return nets_dict 38 | #=============================================================================================================================================# 39 | def encoder1_64(x, output_dim, output_nonlinearity=None, scope="ENC", reuse=False): 40 | nets_dict = dict() 41 | nets_dict['input'] = x 42 | with tf.variable_scope(scope, reuse=reuse): 43 | with slim.arg_scope([slim.conv2d, slim.fully_connected], weights_regularizer=slim.l2_regularizer(0.00004)): 44 | with slim.arg_scope([slim.conv2d], weights_initializer=tf.contrib.slim.variance_scaling_initializer(), stride=2, padding='SAME', activation_fn=tf.nn.relu) : 45 | with slim.arg_scope([slim.fully_connected], biases_initializer=tf.zeros_initializer()): 46 | nets_dict['conv2d0'] = slim.conv2d(nets_dict['input'], 32, [4, 4], scope='conv2d_0') 47 | nets_dict['conv2d1'] = slim.conv2d(nets_dict['conv2d0'], 32, [4, 4], scope='conv2d_1') 48 | nets_dict['conv2d2'] = slim.conv2d(nets_dict['conv2d1'], 64, [4, 4], scope='conv2d_2') 49 | nets_dict['conv2d3'] = slim.conv2d(nets_dict['conv2d2'], 64, [4, 4], scope='conv2d_3') 50 | n = tf.reshape(nets_dict['conv2d3'], [-1, 4*4*64]) 51 | nets_dict['fc0'] = slim.fully_connected(n, 256, activation_fn=tf.nn.relu, scope = "output_fc0") 52 | nets_dict['output'] = slim.fully_connected(nets_dict['fc0'], output_dim, activation_fn=output_nonlinearity, scope = "output_fc1") 53 | return nets_dict 54 | 55 | def decoder1_64(z, scope="DEC", output_channel=1, reuse=False): 56 | nets_dict = dict() 57 | nets_dict['input'] = z 58 | with tf.variable_scope(scope, reuse=reuse): 59 | with slim.arg_scope([slim.conv2d_transpose, slim.fully_connected], weights_regularizer=slim.l2_regularizer(0.00004)): 60 | with slim.arg_scope([slim.conv2d_transpose], weights_initializer=tf.contrib.slim.variance_scaling_initializer(), 61 | stride=2, padding='SAME', activation_fn=tf.nn.relu): 62 | with slim.arg_scope([slim.fully_connected], biases_initializer=tf.zeros_initializer()): 63 | nets_dict['fc0'] = slim.fully_connected(nets_dict['input'], 256, activation_fn=tf.nn.relu, scope = "fc0") 64 | nets_dict['fc1'] = slim.fully_connected(nets_dict['fc0'], 4*4*64, activation_fn=tf.nn.relu, scope = "fc1") 65 | n = tf.reshape(nets_dict['fc1'], [-1, 4, 4, 64]) 66 | nets_dict['deconv2d0'] = slim.conv2d_transpose(n, 64, [4, 4], scope='deconv2d_0') 67 | nets_dict['deconv2d1'] = slim.conv2d_transpose(nets_dict['deconv2d0'], 32, [4, 4], scope='deconv2d_1') 68 | nets_dict['deconv2d2'] = slim.conv2d_transpose(nets_dict['deconv2d1'], 32, [4, 4], scope='deconv2d_2') 69 | nets_dict['output'] = slim.conv2d_transpose(nets_dict['deconv2d2'], output_channel, [4, 4], activation_fn=None, scope='deconv2d_3') 70 | return nets_dict 71 | 72 | -------------------------------------------------------------------------------- /utils/ortools_op.py: -------------------------------------------------------------------------------- 1 | from ortools.graph import pywrapgraph 2 | 3 | import numpy as np 4 | import copy 5 | 6 | class SolveMaxMatching: 7 | def __init__(self, nworkers, ntasks, k, value=10000, pairwise_lamb=0.1): 8 | self.nworkers = nworkers 9 | self.ntasks = ntasks 10 | self.value = value 11 | self.k = k 12 | 13 | self.source = 0 14 | self.sink = self.nworkers+self.ntasks+1 15 | 16 | self.pairwise_cost = int(pairwise_lamb*value) 17 | 18 | self.supplies = [self.nworkers*self.k]+(self.ntasks+self.nworkers)*[0]+[-self.nworkers*self.k] 19 | self.start_nodes = list() 20 | self.end_nodes = list() 21 | self.capacities = list() 22 | self.common_costs = list() 23 | 24 | for work_idx in range(self.nworkers): 25 | self.start_nodes.append(self.source) 26 | self.end_nodes.append(work_idx+1) 27 | self.capacities.append(self.k) 28 | self.common_costs.append(0) 29 | 30 | for work_idx in range(self.nworkers): 31 | for task_idx in range(self.ntasks): 32 | self.start_nodes.append(self.nworkers+1+task_idx) 33 | self.end_nodes.append(self.sink) 34 | self.capacities.append(1) 35 | self.common_costs.append(work_idx*self.pairwise_cost) 36 | 37 | for work_idx in range(self.nworkers): 38 | for task_idx in range(self.ntasks): 39 | self.start_nodes.append(work_idx+1) 40 | self.end_nodes.append(self.nworkers+1+task_idx) 41 | self.capacities.append(1) 42 | 43 | self.nnodes = len(self.start_nodes) 44 | 45 | def solve(self, array): 46 | assert array.shape == (self.nworkers, self.ntasks), "Wrong array shape, it should be ({}, {})".format(self.nworkers, self.ntasks) 47 | 48 | self.array = self.value*array 49 | self.array = -self.array # potential to cost 50 | self.array = self.array.astype(np.int32) 51 | 52 | costs = copy.copy(self.common_costs) 53 | for work_idx in range(self.nworkers): 54 | for task_idx in range(self.ntasks): 55 | costs.append(self.array[work_idx][task_idx]) 56 | 57 | costs = np.array(costs) 58 | costs = (costs.tolist()) 59 | 60 | assert len(costs)==self.nnodes, "Length of costs should be {} but {}".format(self.nnodes, len(costs)) 61 | 62 | min_cost_flow = pywrapgraph.SimpleMinCostFlow() 63 | for idx in range(self.nnodes): 64 | min_cost_flow.AddArcWithCapacityAndUnitCost(self.start_nodes[idx], self.end_nodes[idx], self.capacities[idx], costs[idx]) 65 | for idx in range(self.ntasks+self.nworkers+2): 66 | min_cost_flow.SetNodeSupply(idx, self.supplies[idx]) 67 | 68 | min_cost_flow.Solve() 69 | results = list() 70 | for arc in range(min_cost_flow.NumArcs()): 71 | if min_cost_flow.Tail(arc)!=self.source and min_cost_flow.Head(arc)!=self.sink: 72 | if min_cost_flow.Flow(arc)>0: 73 | results.append([min_cost_flow.Tail(arc)-1, min_cost_flow.Head(arc)-self.nworkers-1]) 74 | 75 | results_np = np.zeros_like(array) 76 | for i,j in results: results_np[i][j]=1 77 | return results, results_np 78 | 79 | class SimpleHungarianSolver: 80 | def __init__(self, nworkers, ntasks, value=10000): 81 | self.nworkers = nworkers 82 | self.ntasks = ntasks 83 | self.value = value 84 | 85 | self.source = 0 86 | self.sink = self.nworkers+self.ntasks+1 87 | 88 | self.supplies = [self.nworkers]+(self.ntasks+self.nworkers)*[0]+[-self.nworkers] 89 | self.start_nodes = list() 90 | self.end_nodes = list() 91 | self.capacities = list() 92 | self.common_costs = list() 93 | 94 | for work_idx in range(self.nworkers): 95 | self.start_nodes.append(self.source) 96 | self.end_nodes.append(work_idx+1) 97 | self.capacities.append(1) 98 | self.common_costs.append(0) 99 | 100 | for task_idx in range(self.ntasks): 101 | self.start_nodes.append(self.nworkers+1+task_idx) 102 | self.end_nodes.append(self.sink) 103 | self.capacities.append(1) 104 | self.common_costs.append(0) 105 | 106 | for work_idx in range(self.nworkers): 107 | for task_idx in range(self.ntasks): 108 | self.start_nodes.append(work_idx+1) 109 | self.end_nodes.append(self.nworkers+1+task_idx) 110 | self.capacities.append(1) 111 | 112 | self.nnodes = len(self.start_nodes) 113 | 114 | def solve(self, array): 115 | assert array.shape == (self.nworkers, self.ntasks), "Wrong array shape, it should be ({}, {})".format(self.nworkers, self.ntasks) 116 | 117 | self.array = self.value*array 118 | self.array = -self.array # potential to cost 119 | self.array = self.array.astype(np.int32) 120 | 121 | costs = copy.copy(self.common_costs) 122 | for work_idx in range(self.nworkers): 123 | for task_idx in range(self.ntasks): 124 | costs.append(self.array[work_idx][task_idx]) 125 | 126 | costs = np.array(costs) 127 | costs = (costs.tolist()) 128 | 129 | assert len(costs)==self.nnodes, "Length of costs should be {} but {}".format(self.nnodes, len(costs)) 130 | 131 | min_cost_flow = pywrapgraph.SimpleMinCostFlow() 132 | for idx in range(self.nnodes): 133 | min_cost_flow.AddArcWithCapacityAndUnitCost(self.start_nodes[idx], self.end_nodes[idx], self.capacities[idx], costs[idx]) 134 | for idx in range(self.ntasks+self.nworkers+2): 135 | min_cost_flow.SetNodeSupply(idx, self.supplies[idx]) 136 | 137 | min_cost_flow.Solve() 138 | results = list() 139 | for arc in range(min_cost_flow.NumArcs()): 140 | if min_cost_flow.Tail(arc)!=self.source and min_cost_flow.Head(arc)!=self.sink: 141 | if min_cost_flow.Flow(arc)>0: 142 | results.append([min_cost_flow.Tail(arc)-1, min_cost_flow.Head(arc)-self.nworkers-1]) 143 | 144 | results_np = np.zeros_like(array) 145 | for i,j in results: results_np[i][j]=1 146 | return results, results_np 147 | -------------------------------------------------------------------------------- /Dsprites_exp/CascadeVAE-C/model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | sys.path.append(os.path.join(os.path.abspath(os.path.dirname(__file__)), '../..')) 4 | 5 | from utils.general_class import ModelPlugin 6 | from utils.ortools_op import SolveMaxMatching 7 | from utils.visual_op import matrix_image2big_image 8 | from utils.writer_op import write_pkl, write_gif 9 | from utils.tqdm_op import tqdm_range 10 | from utils.eval_op import DisentanglemetricFactorMask, DisentanglemetricFactorJointMask 11 | from utils.np_op import np_softmax 12 | 13 | from tfops.transform_op import apply_tf_op, apply_tf_op_multi_output, apply_tf_op_multi_input 14 | from tfops.train_op import get_train_op_v2 15 | from tfops.lr_op import DECAY_DICT, DECAY_PARAMS_DICT 16 | from tfops.nets import encoder1_64, decoder1_64 17 | from tfops.loss import sigmoid_cross_entropy_without_mean, vae_kl_cost_weight 18 | 19 | import tensorflow as tf 20 | import numpy as np 21 | 22 | class Model(ModelPlugin): 23 | def __init__(self, dataset, logfilepath, args): 24 | super().__init__(dataset, logfilepath, args) 25 | self.build() 26 | 27 | def build(self): 28 | self.logger.info("Model building starts") 29 | tf.reset_default_graph() 30 | tf.set_random_seed(self.args.rseed) 31 | 32 | self.input1 = tf.placeholder(tf.float32, shape = [self.args.nbatch, self.height, self.width, self.nchannel]) 33 | self.istrain = tf.placeholder(tf.bool, shape= []) 34 | 35 | self.generate_sess() 36 | # Encoding 37 | self.encoder_net = encoder1_64 38 | self.decoder_net = decoder1_64 39 | 40 | # Encoder 41 | self.mean_total, self.stddev_total = tf.split(self.encoder_net(self.input1, output_dim=2*self.args.nconti, scope='encoder', reuse=False)['output'], num_or_size_splits=2, axis=1) 42 | self.stddev_total = tf.nn.softplus(self.stddev_total) 43 | 44 | self.z_sample = tf.add(self.mean_total, tf.multiply(self.stddev_total, tf.random_normal([self.args.nbatch, self.args.nconti]))) 45 | 46 | self.dec_output = self.decoder_net(z=self.z_sample, output_channel=self.nchannel, scope="decoder", reuse=False)['output'] 47 | # Unary vector 48 | self.rec_cost_vector = sigmoid_cross_entropy_without_mean(labels=self.input1, logits=self.dec_output) 49 | self.rec_cost = tf.reduce_mean(self.rec_cost_vector) 50 | 51 | self.loss_list = list() 52 | for idx in range(self.args.nconti): 53 | weight = tf.constant(np.array((idx+1)*[self.args.beta_min] + (self.args.nconti-idx-1)*[self.args.beta_max]), dtype=tf.float32) 54 | kl_cost = vae_kl_cost_weight(mean=self.mean_total, stddev=self.stddev_total, weight=weight) 55 | self.loss_list.append(self.rec_cost+kl_cost+tf.losses.get_regularization_loss()) 56 | 57 | # Decode 58 | self.latent_ph = tf.placeholder(tf.float32, shape = [self.args.nbatch, self.args.nconti]) 59 | self.dec_output_ph = tf.nn.sigmoid(self.decoder_net(z=self.latent_ph, output_channel=self.nchannel, scope="decoder", reuse=True)['output']) 60 | 61 | self.logger.info("Model building ends") 62 | 63 | def decode(self, latent_input): 64 | return apply_tf_op(inputs=latent_input, session=self.sess, input_gate=self.latent_ph, output_gate=self.dec_output_ph, batch_size=self.args.nbatch) 65 | 66 | def set_up_train(self): 67 | self.logger.info("Model setting up train starts") 68 | 69 | if not hasattr(self, 'start_iter'): self.start_iter = 0 70 | self.logger.info("Start iter: {}".format(self.start_iter)) 71 | 72 | decay_func = DECAY_DICT[self.args.dtype] 73 | decay_params = DECAY_PARAMS_DICT[self.args.dtype][self.args.nbatch][self.args.dptype].copy() 74 | decay_params['initial_step'] = self.start_iter 75 | 76 | self.lr, update_step_op = decay_func(**decay_params) 77 | self.update_step_op = [update_step_op] 78 | 79 | var_list = [v for v in tf.trainable_variables() if 'encoder' in v.name] + [v for v in tf.trainable_variables() if 'decoder' in v.name] 80 | 81 | with tf.control_dependencies(tf.get_collection("update_ops")): 82 | self.train_op_list = [get_train_op_v2(tf.train.AdamOptimizer(learning_rate=self.lr, beta1=0.9, beta2=0.999), loss=self.loss_list[v], var_list=var_list) for v in range(self.args.nconti)] 83 | self.logger.info("Model setting up train ends") 84 | 85 | def run_batch(self, train_idx): 86 | feed_dict = dict() 87 | feed_dict[self.input1] = self.dataset.next_batch(batch_size=self.args.nbatch)[0] 88 | feed_dict[self.istrain] = True 89 | idx = min(train_idx, self.args.nconti-1) 90 | self.sess.run([self.train_op_list[idx]], feed_dict=feed_dict) 91 | 92 | def train(self, niter, piter, siter, save_dir=None, asset_dir=None): 93 | self.logger.info("Model training starts") 94 | 95 | final_iter = self.start_iter+niter 96 | max_accuracy = -1 97 | 98 | for iter_ in tqdm_range(self.start_iter, final_iter): 99 | train_idx = (iter_ - self.start_iter)//piter 100 | self.run_batch(train_idx) 101 | 102 | if (iter_+1)%siter==0 or iter_+1==final_iter: 103 | accuracy = self.evaluate() 104 | 105 | self.latent_traversal_gif(path=asset_dir+'{}.gif'.format(iter_+1)) 106 | if max_accuracy==-1 or max_accuracy=self.args.ntime: 109 | idx = min(train_idx, self.args.nconti) 110 | else: 111 | idx = min(train_idx+1, self.args.nconti) 112 | self.sess.run(self.train_op_dict[idx], feed_dict=feed_dict) 113 | 114 | def train(self, niter, piter, siter, save_dir=None, asset_dir=None): 115 | self.logger.info("Model training starts") 116 | 117 | final_iter = self.start_iter+niter 118 | max_accuracy = -1 119 | 120 | for iter_ in tqdm_range(self.start_iter, final_iter): 121 | train_idx = (iter_ - self.start_iter)//piter 122 | self.run_batch(train_idx) 123 | 124 | if (iter_+1)%siter==0 or iter_+1==final_iter: 125 | include_discrete = False if train_idx < self.args.ntime else True 126 | accuracy = self.evaluate(include_discrete=include_discrete) 127 | 128 | self.latent_traversal_gif(path=asset_dir+'{}.gif'.format(iter_+1), include_discrete=include_discrete) 129 | if max_accuracy==-1 or max_accuracy