├── EBMs
│   ├── README.md
│   ├── ais.py
│   ├── custom_adam.py
│   ├── data.py
│   ├── ebm_combine.py
│   ├── ebm_sandbox.py
│   ├── fid.py
│   ├── hmc.py
│   ├── imagenet_demo.py
│   ├── imagenet_preprocessing.py
│   ├── inception.py
│   ├── models.py
│   ├── requirements.txt
│   ├── test_inception.py
│   ├── train.py
│   └── utils.py
├── README.md
├── books
│   └── advances_in_financial_machine_learning.pdf
├── crypto_agents
│   ├── README.md
│   ├── strategy_workflow
│   │   ├── README.md
│   │   ├── backtesting.md
│   │   ├── data_analysis.md
│   │   ├── defi_glossary.md
│   │   ├── live_trading.md
│   │   ├── optimization.md
│   │   ├── paper_trading.md
│   │   ├── policy.md
│   │   ├── strategy_metrics.md
│   │   └── supervised_learning.md
│   └── trading_on_gmx.md
├── deep_learning
│   ├── README.md
│   ├── deep_learning.md
│   └── reinforcement_learning.md
└── llms
    ├── README.md
    ├── claude
    │   └── README.md
    ├── deepseek
    │   └── README.md
    ├── eliza
    │   └── README.md
    └── gpt
        └── README.md


/EBMs/README.md:
--------------------------------------------------------------------------------
## quantum ai: training energy-based models using openai

#### ⚛️ this repository contains my adapted code from [openai's implicit generation and generalization in energy-based models](https://arxiv.org/pdf/1903.08689.pdf)

### installing

```bash
brew install gcc@6
brew install open-mpi
brew install pkg-config
```

* there is a **[bug](https://github.com/open-mpi/ompi/issues/7516)** in open-mpi affecting the library versions used in this project (`PMIX ERROR: ERROR`), which can be worked around with:

```
export PMIX_MCA_gds=^ds12
```

* then install the python requirements:

```bash
virtualenv venv
source venv/bin/activate
pip install -r requirements.txt
```

* note that this is an adapted requirements file, since **[openai's original](https://github.com/openai/ebm_code_release/blob/master/requirements.txt)** is incomplete/incorrect
* finally, download and install **[mujoco](https://www.roboti.us/index.html)**
    * you will also need to register for a license, which asks for a machine ID
    * the documentation on the website is incomplete, so just download the suggested script and run:

```bash
mv getid_osx getid_osx.dms
./getid_osx.dms
```

---

### running

#### download pre-trained models (examples)

* download all **[pre-trained models](https://sites.google.com/view/igebm/home)** and unzip them into a local folder `cachedir`:

```bash
mkdir cachedir
```

#### setting results directory

* openai's original code contains **[hardcoded constants that only work on Linux](https://github.com/openai/ebm_code_release/blob/master/data.py#L218)**
* i replaced them with a single constant (`ROOT_DIR = "./results"`) at the top of `data.py`

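a minimal sketch of the same idea: build every output path from one relative `ROOT_DIR` with `os.path.join` instead of hardcoding absolute linux paths. the helper `results_path` is hypothetical (it is not part of `data.py`), and the experiment names are only examples.

```python
import os.path as osp
import tempfile

ROOT_DIR = "./results"  # same relative default as the top of data.py


def results_path(*parts, root=ROOT_DIR):
    """Hypothetical helper: join path components under the results root.

    osp.join uses the separator of the current OS, so no Linux-only
    absolute prefix is baked into the code.
    """
    return osp.normpath(osp.join(root, *parts))


if __name__ == "__main__":
    print(results_path("cifar10_uncond", "model_100"))
    # a different root can be passed in, e.g. the system temp directory
    print(results_path("dsprites", root=tempfile.gettempdir()))
```

keeping the root relative means results land next to wherever the scripts are launched, on macOS or linux alike.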
#### running (parallelization with `mpiexec`)

* all code supports **[`horovod` execution](https://github.com/horovod/horovod)**, so training can be sped up substantially by running each command below on multiple workers through `mpiexec` (`-n` sets the number of worker processes):

```
mpiexec -n
```

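conceptually, each horovod worker computes gradients on its own shard of data, and the gradients are then averaged across workers with an allreduce before the update (averaging is horovod's default reduction). the numpy sketch below only simulates that averaging in a single process, for a toy least-squares loss and assuming equal-sized shards; it does not use horovod or mpi.

```python
import numpy as np


def mse_grad(w, X, y):
    """Gradient of 0.5 * mean((X @ w - y) ** 2) with respect to w."""
    return X.T @ (X @ w - y) / len(y)


def simulated_allreduce_grad(w, X, y, num_workers):
    """Split the batch across workers, compute each local gradient,
    then average them, as an averaging allreduce would."""
    X_shards = np.array_split(X, num_workers)
    y_shards = np.array_split(y, num_workers)
    local = [mse_grad(w, Xs, ys) for Xs, ys in zip(X_shards, y_shards)]
    return np.mean(local, axis=0)


rng = np.random.default_rng(0)
X = rng.normal(size=(128, 5))
y = rng.normal(size=128)
w = np.zeros(5)

full = mse_grad(w, X, y)
averaged = simulated_allreduce_grad(w, X, y, num_workers=4)
print(np.allclose(full, averaged))  # equal-sized shards give the full-batch gradient
```

with equal shard sizes the averaged gradient matches the full-batch gradient, so n workers each using `--batch_size` examples behave roughly like one worker with an n-times larger batch.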
##### cifar-10 unconditional

```
python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --large_model
```

* this should generate output similar to the following:

```bash
Instructions for updating:
Use tf.gfile.GFile.
2020-05-10 22:12:32.471415: W tensorflow/core/framework/op_def_util.cc:355] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
64 batch size
Local rank: 0 1
Loading data...
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Done loading...
WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
Building graph...
WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
Finished processing loop construction ...
Started gradient computation...
Applying gradients...
Finished applying gradients.
Model has a total of 7567880 parameters
Initializing variables...
Start broadcast
End broadcast
Obtained a total of e_pos: -0.0025530937127768993, e_pos_std: 0.09564747661352158, e_neg: -0.22276005148887634, e_diff: 0.22020696103572845, e_neg_std: 0.016306934878230095, temp: 1, loss_e: -0.22276005148887634, eps: 0.0, label_ent: 2.272536277770996, loss_ml: 0.22020693123340607, loss_total: 0.2792498469352722, x_grad: 0.0009156676824204624, x_grad_first: 0.0009156676824204624, x_off: 0.31731340289115906, iter: 0, gamma: [0.], context_0/c1_pre/cweight:0: 0.0731438547372818, context_0/res_optim_res_c1/cweight:0: 4.732660444095593e-11, context_0/res_optim_res_c1/gb:0: 3.4007335836250263e-10, context_0/res_optim_res_c2/cweight:0: 0.9494612216949463, context_0/res_optim_res_c2/g:0: 1.8536269741353806e-10, context_0/res_optim_res_c2/gb:0: 6.272356523062683e-10, context_0/res_optim_res_c2/cb:0: 1.1606662297936055e-09, context_0/res_1_res_c1/cweight:0: 6.714453298917178e-11, context_0/res_1_res_c1/gb:0: 3.6198691266697836e-10, context_0/res_1_res_c2/cweight:0: 0.6582950353622437, context_0/res_1_res_c2/g:0: 1.669797633496728e-10, context_0/res_1_res_c2/gb:0: 5.911696687732615e-10, context_0/res_1_res_c2/cb:0: 1.1932842491901852e-09, context_0/res_2_res_c1/cweight:0: 8.567072745657711e-11, context_0/res_2_res_c1/gb:0: 6.868505764145993e-10, context_0/res_2_res_c2/cweight:0: 0.46929678320884705, context_0/res_2_res_c2/g:0: 1.655784120924153e-10, context_0/res_2_res_c2/gb:0: 8.058526068666083e-10, context_0/res_2_res_c2/cb:0: 1.0161046448686761e-09, context_0/res_2_res_adaptive/cweight:0: 0.01942753791809082, context_0/res_3_res_c1/cweight:0: 4.011655244107182e-11, context_0/res_3_res_c1/gb:0: 5.064903496609929e-10, context_0/res_3_res_c2/cweight:0: 0.32239994406700134, context_0/res_3_res_c2/g:0: 9.758494012857e-11, context_0/res_3_res_c2/gb:0: 7.756124631441708e-10, context_0/res_3_res_c2/cb:0: 6.362700366580043e-10, context_0/res_4_res_c1/cweight:0: 4.090133440270982e-11, context_0/res_4_res_c1/gb:0: 6.013010089844784e-10, 
context_0/res_4_res_c2/cweight:0: 0.34806951880455017, context_0/res_4_res_c2/g:0: 8.414659247168998e-11, context_0/res_4_res_c2/gb:0: 6.443054978433338e-10, context_0/res_4_res_c2/cb:0: 5.496815780325903e-10, context_0/res_5_res_c1/cweight:0: 3.990113794927197e-11, context_0/res_5_res_c1/gb:0: 3.807749116013781e-10, context_0/res_5_res_c2/cweight:0: 0.22841960191726685, context_0/res_5_res_c2/g:0: 4.942361797599659e-11, context_0/res_5_res_c2/gb:0: 7.697342763179904e-10, context_0/res_5_res_c2/cb:0: 3.1796060229183354e-10, context_0/fc5/wweight:0: 3.081033706665039, context_0/fc5/b:0: 0.4506262540817261,

................................................................................................................................
Inception score of 1.2397289276123047 with std of 0.0
```

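the last line of the log reports an inception score. as a sketch of the standard definition (not the code in `inception.py`), the score is $\exp\big(\mathbb{E}_x\,\mathrm{KL}(p(y \mid x)\,\|\,p(y))\big)$, computed from a classifier's predicted class probabilities for each generated image; the toy `probs` arrays below are made up.

```python
import numpy as np


def inception_score(probs, eps=1e-12):
    """Inception score from an (N, K) array of class probabilities p(y|x).

    IS = exp(mean over samples of KL(p(y|x) || p(y))), where p(y) is the
    marginal class distribution averaged over the N samples.
    """
    probs = np.asarray(probs, dtype=np.float64)
    p_y = probs.mean(axis=0, keepdims=True)  # marginal p(y)
    kl = np.sum(probs * (np.log(probs + eps) - np.log(p_y + eps)), axis=1)
    return float(np.exp(kl.mean()))


# confident and diverse predictions: score is (numerically) the number of classes
print(inception_score(np.eye(4)))
# identical predictions for every sample: score of 1
print(inception_score(np.full((4, 4), 0.25)))
```

the score ranges from 1 (no confident or diverse predictions) up to the number of classes, which gives a sense of scale for the value logged at iteration 0.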
##### cifar-10 conditional

```
python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --cclass
```

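both cifar-10 commands draw model samples with `--num_steps` gradient steps of size `--step_lr`. the paper describes this sampler as langevin dynamics on the energy, $\tilde{x}^k = \tilde{x}^{k-1} - \frac{\lambda}{2}\nabla_x E(\tilde{x}^{k-1}) + \omega^k$ with gaussian noise $\omega^k$. below is a minimal numpy sketch of that update on a toy quadratic energy; the energy, the step size (chosen for the toy, not the readme's value), the separate noise scale and the clipping to the [0, 1] box are illustrative assumptions and may differ from what `train.py` does (which also reuses past samples, presumably via `--replay_batch`).

```python
import numpy as np


def energy_grad(x, target=0.5):
    """Gradient of a toy energy E(x) = 0.5 * ||x - target||^2 (illustrative only)."""
    return x - target


def langevin_sample(x, num_steps=60, step_lr=0.1, noise_scale=0.005, seed=0):
    """Run num_steps of noisy gradient descent on the energy.

    Each step moves x along -grad E (step_lr plays the role of lambda / 2)
    and adds Gaussian noise; samples are clipped to the [0, 1] image box.
    """
    rng = np.random.default_rng(seed)
    for _ in range(num_steps):
        x = x - step_lr * energy_grad(x) + noise_scale * rng.normal(size=x.shape)
        x = np.clip(x, 0.0, 1.0)
    return x


# start from uniform noise, a common initialization for EBM samples
x0 = np.random.default_rng(1).uniform(0.0, 1.0, size=(16, 8))
samples = langevin_sample(x0)
print(np.abs(samples - 0.5).mean())  # small: samples settle near the energy minimum
```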
##### imagenet 32x32 conditional

```
python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 --step_lr=10.0 --proj_norm=0.01 --replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=
```

##### imagenet 128x128 conditional

```
python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 --step_lr=100.0 --replay_batch --swish_act --cclass --zero_kl --dataset=imagenetfull --imagenet_datadir=
```

##### imagenet demo

* `imagenet_demo.py` contains code for experiments with ebms on conditional imagenet 128x128
* to generate a gif of the sampling process, run:

```
python imagenet_demo.py --exp=imagenet128_cond --resume_iter=2238000 --swish_act
```

* `ebm_sandbox.py` contains several tasks for evaluating ebms, selected with the `--task` flag
* for example, to visualize cross-class mappings on cifar-10, run:

```
python ebm_sandbox.py --task=crossclass --num_steps=40 --exp=cifar10_cond --resume_iter=74700
```

##### generalization

* to test out-of-distribution generalization on SVHN (similar commands work for other datasets):

```
python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200 --large_model --svhnmix --cclass=False
```

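out-of-distribution detection with an ebm treats the energy as a score: a model trained on cifar-10 should assign lower energy to cifar-10 images than to, say, SVHN images. one common way to summarize how well two sets of energies separate is the area under the roc curve (auroc), sketched below with the rank-based (mann-whitney) formulation. the energies are synthetic, and this is not necessarily the metric that `--task=mixenergy` prints.

```python
import numpy as np


def auroc_from_energies(e_in, e_out):
    """AUROC for separating in-distribution (low energy) from OOD (high energy).

    Equals the probability that a random OOD sample has higher energy than a
    random in-distribution sample, counting ties as one half.
    """
    e_in = np.asarray(e_in, dtype=float)[:, None]
    e_out = np.asarray(e_out, dtype=float)[None, :]
    wins = (e_out > e_in).mean()
    ties = (e_out == e_in).mean()
    return float(wins + 0.5 * ties)


rng = np.random.default_rng(0)
e_cifar = rng.normal(loc=-1.0, scale=0.5, size=1000)  # synthetic in-distribution energies
e_svhn = rng.normal(loc=0.0, scale=0.5, size=1000)    # synthetic OOD energies
print(auroc_from_energies(e_cifar, e_svhn))
```

an auroc of 1.0 means every ood sample has higher energy than every in-distribution sample, and 0.5 means the energies carry no information.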
* to test classification on cifar-10 using a conditional model under either L2 or L∞ perturbations:

```
python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd= --num_steps=10 --lival= --wider_model
```

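a conditional ebm assigns an energy $E(x, y)$ to every image-label pair, so a natural classifier picks the lowest-energy label, $\hat{y} = \arg\min_y E(x, y)$. the sketch below shows that rule plus a single L∞ sign-gradient (pgd-style) step against it, for a toy linear energy; the energy, `eps` and helper names are hypothetical and only illustrate the idea, not how `ebm_sandbox.py --task=label` implements it.

```python
import numpy as np


def energies(x, W):
    """Toy conditional energy E(x, y) = -W[y] . x for every label y.

    x: (D,) input, W: (K, D) weights -> returns (K,) energies.
    """
    return -W @ x


def classify(x, W):
    """Predict the label with the lowest energy."""
    return int(np.argmin(energies(x, W)))


def linf_attack_step(x, W, y_true, eps):
    """One L-infinity sign-gradient step that raises the energy of y_true.

    For this linear energy, grad_x E(x, y_true) = -W[y_true]. Clipping to the
    [0, 1] box keeps the step within eps of x, since x is already in the box.
    """
    grad = -W[y_true]
    x_adv = x + eps * np.sign(grad)
    return np.clip(x_adv, 0.0, 1.0)


W = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [0.0, 0.0, 1.0]])
x = np.array([0.6, 0.5, 0.1])
print(classify(x, W))  # label 0 has the lowest energy
x_adv = linf_attack_step(x, W, y_true=0, eps=0.2)
print(x_adv, classify(x_adv, W))
```

a multi-step attack would repeat the step and project back onto the eps-ball around the original input after each one.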
##### concept combination

* to train ebms on the conditional dsprites dataset, train a separate model for each conditioned latent (`cond_pos`, `cond_rot`, `cond_shape`, `cond_scale`); an example command is given below:

```
python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act --cond_pos --replay_batch --cclass
```

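each model trained this way defines its own energy. the paper combines separately trained ebms by summing their energies, which corresponds to a product of the individual distributions, $p(x \mid c_1, c_2) \propto \exp(-E_1(x) - E_2(x))$, and samples from the sum with the same langevin procedure. below is a minimal numpy sketch with two hypothetical 2-d quadratic energies, each constraining one coordinate (standing in for, e.g., a position model and a size model); it illustrates the idea rather than the code in `ebm_combine.py`.

```python
import numpy as np


def grad_e_first(x):
    """Gradient of toy energy E1(x) = 2 * (x[0] - 0.2)^2 (constrains coordinate 0)."""
    g = np.zeros_like(x)
    g[..., 0] = 4.0 * (x[..., 0] - 0.2)
    return g


def grad_e_second(x):
    """Gradient of toy energy E2(x) = 2 * (x[1] - 0.8)^2 (constrains coordinate 1)."""
    g = np.zeros_like(x)
    g[..., 1] = 4.0 * (x[..., 1] - 0.8)
    return g


def sample_combined(grad_fns, x, num_steps=100, step_lr=0.05, noise_scale=0.005, seed=0):
    """Langevin-style sampling on the summed energy E(x) = sum_i E_i(x)."""
    rng = np.random.default_rng(seed)
    for _ in range(num_steps):
        grad = sum(fn(x) for fn in grad_fns)  # gradient of a sum = sum of gradients
        x = x - step_lr * grad + noise_scale * rng.normal(size=x.shape)
        x = np.clip(x, 0.0, 1.0)
    return x


x0 = np.random.default_rng(1).uniform(0.0, 1.0, size=(32, 2))
samples = sample_combined([grad_e_first, grad_e_second], x0)
print(samples.mean(axis=0))  # both constraints hold at once: near (0.2, 0.8)
```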
* once the models are trained, they can be sampled from jointly by running:

```
python ebm_combine.py --task=conceptcombine --exp_size= --exp_shape= --exp_pos= --exp_rot= --resume_size= --resume_shape= --resume_rot= --resume_pos=
```


--------------------------------------------------------------------------------
/EBMs/ais.py:
--------------------------------------------------------------------------------
import tensorflow as tf
import math
from hmc import hmc
from tensorflow.python.platform import flags
from torch.utils.data import DataLoader
from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Wider, MnistNet
from data import Cifar10, Mnist, DSprites
from scipy.misc import logsumexp
from scipy.misc import imsave
from utils import optimistic_restore
import os.path as osp
import numpy as np
from tqdm import tqdm

flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or mnist or dsprites or 2d or toy Gauss')
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
flags.DEFINE_string('exp', 'default', 'name of experiments')
flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
flags.DEFINE_integer('batch_size', 16, 'Size of inputs')
flags.DEFINE_string('resume_iter', '-1', 'iteration to resume training from')

flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
flags.DEFINE_integer('pdist', 10, 'number of intermediate distributions for ais')
flags.DEFINE_integer('gauss_dim', 500, 'dimensions for modeling Gaussian')
flags.DEFINE_integer('rescale', 1, 'factor to rescale input outside of normal (0, 1) box')
flags.DEFINE_float('temperature', 1, 'temperature at which to compute likelihood of model')
flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
flags.DEFINE_bool('cclass', False, 'Whether to evaluate the log likelihood of conditional model or not')
flags.DEFINE_bool('single', False, 'Whether to evaluate the log likelihood of conditional model or not')
flags.DEFINE_bool('large_model', False, 'Use large model to evaluate')
flags.DEFINE_bool('wider_model', False, 'Use large model to evaluate')
flags.DEFINE_float('alr', 0.0045, 'Learning rate to use for HMC steps')

FLAGS = flags.FLAGS

label_default = np.eye(10)[0:1, :]
label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32))


def unscale_im(im):
    return (255 * np.clip(im, 0, 1)).astype(np.uint8)

def gauss_prob_log(x, prec=1.0):

    nh = float(np.prod([s.value for s in x.get_shape()[1:]]))
    norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec))
    prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2. * prec

    return norm_constant_log + prob_density_log


def uniform_prob_log(x):

    return tf.zeros(1)


def model_prob_log(x, e_func, weights, temp):
    if FLAGS.cclass:
        batch_size = tf.shape(x)[0]
        label_tiled = tf.tile(label_default, (batch_size, 1))
        e_raw = e_func.forward(x, weights, label=label_tiled)
    else:
        e_raw = e_func.forward(x, weights)
    energy = tf.reduce_sum(e_raw, axis=[1])
    return -temp * energy


def bridge_prob_neg_log(alpha, x, e_func, weights, temp):
    # Negative unnormalized log-density of the AIS bridge distribution:
    # (1 - alpha) * log p_uniform(x) + alpha * log p_model(x), plus an out-of-box penalty.
    if FLAGS.dataset == "gauss":
        norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(x, prec=FLAGS.temperature)
    else:
        norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * model_prob_log(x, e_func, weights, temp)
    # Add an additional log likelihood penalty so that points outside of (0, 1) box are *highly* unlikely


    if FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
        oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1])
    elif FLAGS.dataset == 'mnist':
        oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1, 2])
    else:
        oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0., FLAGS.rescale))), axis = [1, 2, 3])

    return -norm_prob + oob_prob


def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10):
    if FLAGS.dataset == "2d":
        x = tf.placeholder(tf.float32, shape=(None, 2))
    elif FLAGS.dataset == "gauss":
        x = tf.placeholder(tf.float32, shape=(None, FLAGS.gauss_dim))
    elif FLAGS.dataset == "mnist":
        x = tf.placeholder(tf.float32, shape=(None, 28, 28))
    else:
        x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3))

    x_init = x

    alpha_prev = tf.placeholder(tf.float32, shape=())
    alpha_new = tf.placeholder(tf.float32, shape=())
    approx_lr = tf.placeholder(tf.float32, shape=())

    chain_weights = tf.zeros(batch_size)
    # for i in range(1, prop_dist+1):
    #     print("processing loop {}".format(i))
    #     alpha_prev = (i-1) / prop_dist
    #     alpha_new = i / prop_dist

    prob_log_old_neg = bridge_prob_neg_log(alpha_prev, x, e_func, weights, temp)
    prob_log_new_neg = bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)

    # AIS log-weight increment at the current samples: log p_new(x) - log p_old(x)
    chain_weights = -prob_log_new_neg + prob_log_old_neg
    # chain_weights = tf.Print(chain_weights, [chain_weights])

    # Sample new x using HMC
    def unorm_prob(x):
        return bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)

    for j in range(1):
        x = hmc(x, approx_lr, hmc_step, unorm_prob)

    return chain_weights, alpha_prev, alpha_new, x, x_init, approx_lr


def main():

    # Initialize dataset
    if FLAGS.dataset == 'cifar10':
        dataset = Cifar10(train=False, rescale=FLAGS.rescale)
        channel_num = 3
        dim_input = 32 * 32 * 3
    elif FLAGS.dataset == 'imagenet':
        dataset = ImagenetClass()
        channel_num = 3
        dim_input = 64 * 64 * 3
    elif FLAGS.dataset == 'mnist':
        dataset = Mnist(train=False, rescale=FLAGS.rescale)
        channel_num = 1
        dim_input = 28 * 28 * 1
    elif FLAGS.dataset == 'dsprites':
        dataset = DSprites()
        channel_num = 1
        dim_input = 64 * 64 * 1
    elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
        dataset = Box2D()

    dim_output = 1
    data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=False, shuffle=True)

    if FLAGS.dataset == 'mnist':
        model = MnistNet(num_channels=channel_num)
    elif FLAGS.dataset == 'cifar10':
        if FLAGS.large_model:
            model = ResNet32Large(num_filters=128)
        elif FLAGS.wider_model:
            model = ResNet32Wider(num_filters=192)
        else:
            model = ResNet32(num_channels=channel_num, num_filters=128)
    elif FLAGS.dataset == 'dsprites':
        model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)

    weights = model.construct_weights('context_{}'.format(0))

    config = tf.ConfigProto()
    sess = tf.Session(config=config)
    saver = loader = tf.train.Saver(max_to_keep=10)

    sess.run(tf.global_variables_initializer())
    logdir = osp.join(FLAGS.logdir, FLAGS.exp)

    model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
    resume_itr = FLAGS.resume_iter

    if FLAGS.resume_iter != "-1":
        optimistic_restore(sess, model_file)
    else:
        print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
    # saver.restore(sess, model_file)

    chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
    print("Finished constructing ancestral sample ...................")

    if FLAGS.dataset != "gauss":
        comb_weights_cum = []
        batch_size = tf.shape(x_init)[0]
        label_tiled = tf.tile(label_default, (batch_size, 1))
        e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled)
        e_pos_list = []

        for data_corrupt, data, label_gt in tqdm(data_loader):
            e_pos = sess.run([e_compute], {x_init: data})[0]
            e_pos_list.extend(list(e_pos))

        print(len(e_pos_list))
        print("Positive sample probability ", np.mean(e_pos_list), np.std(e_pos_list))

    if FLAGS.dataset == "2d":
        alr = 0.0045
    elif FLAGS.dataset == "gauss":
        alr = 0.0085
    elif FLAGS.dataset == "mnist":
        alr = 0.0065
        #90 alr = 0.0035
    else:
        # alr = 0.0125
        if FLAGS.rescale == 8:
            alr = 0.0085
        else:
            alr = 0.0045
    #
    for i in range(1):
        # Forward pass: anneal alpha from 0 (uniform) to 1 (model), summing AIS log-weights
        tot_weight = 0
        for j in tqdm(range(1, FLAGS.pdist+1)):
            if j == 1:
                if FLAGS.dataset == "cifar10":
                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
                elif FLAGS.dataset == "gauss":
                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
                elif FLAGS.dataset == "mnist":
                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
                else:
                    x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2))

            alpha_prev = (j-1) / FLAGS.pdist
            alpha_new = j / FLAGS.pdist
            cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
            tot_weight = tot_weight + cweight

        print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight))

        # Reverse pass: anneal alpha from 1 back to 0, starting from the forward chain's final samples
        tot_weight = 0

        for j in tqdm(range(FLAGS.pdist, 0, -1)):
            alpha_new = (j-1) / FLAGS.pdist
            alpha_prev = j / FLAGS.pdist
            cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
            tot_weight = tot_weight - cweight

        print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight))



if __name__ == "__main__":
    main()


--------------------------------------------------------------------------------
/EBMs/custom_adam.py:
--------------------------------------------------------------------------------
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Adam for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.training import optimizer
from tensorflow.python.training import training_ops
from tensorflow.python.util.tf_export import tf_export
import tensorflow as tf


@tf_export("train.AdamOptimizer")
class AdamOptimizer(optimizer.Optimizer):
  """Optimizer that implements the Adam algorithm.

  See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
  ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
  """

  def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
               use_locking=False, name="Adam"):
    """Construct a new Adam optimizer.

    Initialization:

    $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
    $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
    $$t := 0 \text{(Initialize timestep)}$$

    The update rule for `variable` with gradient `g` uses an optimization
    described at the end of section2 of the paper:

    $$t := t + 1$$
    $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$

    $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
    $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
    $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$

    The default value of 1e-8 for epsilon might not be a good default in
    general. For example, when training an Inception network on ImageNet a
    current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
    formulation just before Section 2.1 of the Kingma and Ba paper rather than
    the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
    hat" in the paper.

    The sparse implementation of this algorithm (used when the gradient is an
    IndexedSlices object, typically because of `tf.gather` or an embedding
    lookup in the forward pass) does apply momentum to variable slices even if
    they were not used in the forward pass (meaning they have a gradient equal
    to zero). Momentum decay (beta1) is also applied to the entire momentum
    accumulator. This means that the sparse behavior is equivalent to the dense
    behavior (in contrast to some momentum implementations which ignore momentum
    unless a variable slice was actually used).

    Args:
      learning_rate: A Tensor or a floating point value. The learning rate.
      beta1: A float value or a constant float tensor.
        The exponential decay rate for the 1st moment estimates.
      beta2: A float value or a constant float tensor.
        The exponential decay rate for the 2nd moment estimates.
      epsilon: A small constant for numerical stability. This epsilon is
        "epsilon hat" in the Kingma and Ba paper (in the formula just before
        Section 2.1), not the epsilon in Algorithm 1 of the paper.
      use_locking: If True use locks for update operations.
      name: Optional name for the operations created when applying gradients.
        Defaults to "Adam".

    @compatibility(eager)
    When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
    `epsilon` can each be a callable that takes no arguments and returns the
    actual value to use. This can be useful for changing these values across
    different invocations of optimizer functions.
    @end_compatibility
    """
    super(AdamOptimizer, self).__init__(use_locking, name)
    self._lr = learning_rate
    self._beta1 = beta1
    self._beta2 = beta2
    self._epsilon = epsilon

    # Tensor versions of the constructor arguments, created in _prepare().
    self._lr_t = None
    self._beta1_t = None
    self._beta2_t = None
    self._epsilon_t = None

    # Created in SparseApply if needed.
    self._updated_lr = None

  def _get_beta_accumulators(self):
    with ops.init_scope():
      if context.executing_eagerly():
        graph = None
      else:
        graph = ops.get_default_graph()
      return (self._get_non_slot_variable("beta1_power", graph=graph),
              self._get_non_slot_variable("beta2_power", graph=graph))

  def _create_slots(self, var_list):
    # Create the beta1 and beta2 accumulators on the same device as the first
    # variable. Sort the var_list to make sure this device is consistent across
    # workers (these need to go on the same PS, otherwise some updates are
    # silently ignored).
126 | first_var = min(var_list, key=lambda x: x.name) 127 | self._create_non_slot_variable(initial_value=self._beta1, 128 | name="beta1_power", 129 | colocate_with=first_var) 130 | self._create_non_slot_variable(initial_value=self._beta2, 131 | name="beta2_power", 132 | colocate_with=first_var) 133 | 134 | # Create slots for the first and second moments. 135 | for v in var_list: 136 | self._zeros_slot(v, "m", self._name) 137 | self._zeros_slot(v, "v", self._name) 138 | 139 | def _prepare(self): 140 | lr = self._call_if_callable(self._lr) 141 | beta1 = self._call_if_callable(self._beta1) 142 | beta2 = self._call_if_callable(self._beta2) 143 | epsilon = self._call_if_callable(self._epsilon) 144 | 145 | self._lr_t = ops.convert_to_tensor(lr, name="learning_rate") 146 | self._beta1_t = ops.convert_to_tensor(beta1, name="beta1") 147 | self._beta2_t = ops.convert_to_tensor(beta2, name="beta2") 148 | self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon") 149 | 150 | def _apply_dense(self, grad, var): 151 | m = self.get_slot(var, "m") 152 | v = self.get_slot(var, "v") 153 | beta1_power, beta2_power = self._get_beta_accumulators() 154 | 155 | clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1 156 | grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds) 157 | # Clip gradients by 3 std 158 | return training_ops.apply_adam( 159 | var, m, v, 160 | math_ops.cast(beta1_power, var.dtype.base_dtype), 161 | math_ops.cast(beta2_power, var.dtype.base_dtype), 162 | math_ops.cast(self._lr_t, var.dtype.base_dtype), 163 | math_ops.cast(self._beta1_t, var.dtype.base_dtype), 164 | math_ops.cast(self._beta2_t, var.dtype.base_dtype), 165 | math_ops.cast(self._epsilon_t, var.dtype.base_dtype), 166 | grad, use_locking=self._use_locking).op 167 | 168 | def _resource_apply_dense(self, grad, var): 169 | m = self.get_slot(var, "m") 170 | v = self.get_slot(var, "v") 171 | beta1_power, beta2_power = self._get_beta_accumulators() 172 | return training_ops.resource_apply_adam( 173 | 
var.handle, m.handle, v.handle, 174 | math_ops.cast(beta1_power, grad.dtype.base_dtype), 175 | math_ops.cast(beta2_power, grad.dtype.base_dtype), 176 | math_ops.cast(self._lr_t, grad.dtype.base_dtype), 177 | math_ops.cast(self._beta1_t, grad.dtype.base_dtype), 178 | math_ops.cast(self._beta2_t, grad.dtype.base_dtype), 179 | math_ops.cast(self._epsilon_t, grad.dtype.base_dtype), 180 | grad, use_locking=self._use_locking) 181 | 182 | def _apply_sparse_shared(self, grad, var, indices, scatter_add): 183 | beta1_power, beta2_power = self._get_beta_accumulators() 184 | beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype) 185 | beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype) 186 | lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype) 187 | beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype) 188 | beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype) 189 | epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype) 190 | lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power)) 191 | # m_t = beta1 * m + (1 - beta1) * g_t 192 | m = self.get_slot(var, "m") 193 | m_scaled_g_values = grad * (1 - beta1_t) 194 | m_t = state_ops.assign(m, m * beta1_t, 195 | use_locking=self._use_locking) 196 | with ops.control_dependencies([m_t]): 197 | m_t = scatter_add(m, indices, m_scaled_g_values) 198 | # v_t = beta2 * v + (1 - beta2) * (g_t * g_t) 199 | v = self.get_slot(var, "v") 200 | v_scaled_g_values = (grad * grad) * (1 - beta2_t) 201 | v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking) 202 | with ops.control_dependencies([v_t]): 203 | v_t = scatter_add(v, indices, v_scaled_g_values) 204 | v_sqrt = math_ops.sqrt(v_t) 205 | var_update = state_ops.assign_sub(var, 206 | lr * m_t / (v_sqrt + epsilon_t), 207 | use_locking=self._use_locking) 208 | return control_flow_ops.group(*[var_update, m_t, v_t]) 209 | 210 | def _apply_sparse(self, grad, var): 211 | return self._apply_sparse_shared( 212 | grad.values, var, 
grad.indices, 213 | lambda x, i, v: state_ops.scatter_add( # pylint: disable=g-long-lambda 214 | x, i, v, use_locking=self._use_locking)) 215 | 216 | def _resource_scatter_add(self, x, i, v): 217 | with ops.control_dependencies( 218 | [resource_variable_ops.resource_scatter_add( 219 | x.handle, i, v)]): 220 | return x.value() 221 | 222 | def _resource_apply_sparse(self, grad, var, indices): 223 | return self._apply_sparse_shared( 224 | grad, var, indices, self._resource_scatter_add) 225 | 226 | def _finish(self, update_ops, name_scope): 227 | # Update the power accumulators. 228 | with ops.control_dependencies(update_ops): 229 | beta1_power, beta2_power = self._get_beta_accumulators() 230 | with ops.colocate_with(beta1_power): 231 | update_beta1 = beta1_power.assign( 232 | beta1_power * self._beta1_t, use_locking=self._use_locking) 233 | update_beta2 = beta2_power.assign( 234 | beta2_power * self._beta2_t, use_locking=self._use_locking) 235 | return control_flow_ops.group(*update_ops + [update_beta1, update_beta2], 236 | name=name_scope) 237 | -------------------------------------------------------------------------------- /EBMs/data.py: -------------------------------------------------------------------------------- 1 | from tensorflow.python.platform import flags 2 | from tensorflow.contrib.data.python.ops import batching, threadpool 3 | import tensorflow as tf 4 | import json 5 | from torch.utils.data import Dataset 6 | import pickle 7 | import os.path as osp 8 | import os 9 | import numpy as np 10 | import time 11 | from scipy.misc import imread, imresize 12 | from skimage.color import rgb2grey 13 | from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder 14 | from torchvision import transforms 15 | from imagenet_preprocessing import ImagenetPreprocessor 16 | import torch 17 | import torchvision 18 | 19 | FLAGS = flags.FLAGS 20 | ROOT_DIR = "./results" 21 | 22 | # Dataset Options 23 | flags.DEFINE_string('dsprites_path', 24 | 
'/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 25 | 'path to the dsprites .npz file') 26 | flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big', 'directory containing the imagenet tfrecords and index.json') 27 | flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes') 28 | flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes') 29 | flags.DEFINE_bool('dsize_only', False, 'fix all factors except for size of objects') 30 | flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects') 31 | flags.DEFINE_bool('dsprites_restrict', False, 'restrict to a single shape at zero rotation') 32 | flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images') 33 | 34 | 35 | # Data augmentation options 36 | flags.DEFINE_bool('cutout_inside', False, 'whether the cutout mask must lie entirely inside the image') 37 | flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout') 38 | flags.DEFINE_integer('cutout_mask_size', 16, 'side length of the square cutout mask, in pixels') 39 | flags.DEFINE_bool('cutout', False, 'whether to add cutout regularizer to data') 40 | 41 | 42 | def cutout(mask_color=(0, 0, 0)): 43 | mask_size_half = FLAGS.cutout_mask_size // 2 44 | offset = 1 if FLAGS.cutout_mask_size % 2 == 0 else 0 45 | 46 | def _cutout(image): 47 | image = np.asarray(image).copy() 48 | 49 | if np.random.random() > FLAGS.cutout_prob: 50 | return image 51 | 52 | h, w = image.shape[1:3] 53 | 54 | if FLAGS.cutout_inside: 55 | cxmin, cxmax = mask_size_half, w + offset - mask_size_half 56 | cymin, cymax = mask_size_half, h + offset - mask_size_half 57 | else: 58 | cxmin, cxmax = 0, w + offset 59 | cymin, cymax = 0, h + offset 60 | 61 | cx = np.random.randint(cxmin, cxmax) 62 | cy = np.random.randint(cymin, cymax) 63 | xmin = cx - mask_size_half 64 | ymin = cy - mask_size_half 65 | xmax = xmin + FLAGS.cutout_mask_size 66 | ymax = ymin + FLAGS.cutout_mask_size 67 | xmin = max(0, xmin) 68 | ymin = max(0, ymin) 69 | xmax = min(w, 
xmax) 70 | ymax = min(h, ymax) 71 | image[:, ymin:ymax, xmin:xmax] = np.array(mask_color)[:, None, None] 72 | return image 73 | 74 | return _cutout 75 | 76 | 77 | class TFImagenetLoader(Dataset): 78 | 79 | def __init__(self, split, batchsize, idx, num_workers, rescale=1): 80 | IMAGENET_NUM_TRAIN_IMAGES = 1281167 81 | IMAGENET_NUM_VAL_IMAGES = 50000 82 | 83 | self.rescale = rescale 84 | 85 | if split == "train": 86 | im_length = IMAGENET_NUM_TRAIN_IMAGES 87 | records_to_skip = im_length * idx // num_workers 88 | records_to_read = im_length * (idx + 1) // num_workers - records_to_skip 89 | else: 90 | im_length = IMAGENET_NUM_VAL_IMAGES 91 | 92 | self.curr_sample = 0 93 | 94 | index_path = osp.join(FLAGS.imagenet_datadir, 'index.json') 95 | with open(index_path) as f: 96 | metadata = json.load(f) 97 | counts = metadata['record_counts'] 98 | 99 | if split == 'train': 100 | file_names = list(sorted([x for x in counts.keys() if x.startswith('train')])) 101 | 102 | result_records_to_skip = None 103 | files = [] 104 | for filename in file_names: 105 | records_in_file = counts[filename] 106 | if records_to_skip >= records_in_file: 107 | records_to_skip -= records_in_file 108 | continue 109 | elif records_to_read > 0: 110 | if result_records_to_skip is None: 111 | # Record the number to skip in the first file 112 | result_records_to_skip = records_to_skip 113 | files.append(filename) 114 | records_to_read -= (records_in_file - records_to_skip) 115 | records_to_skip = 0 116 | else: 117 | break 118 | else: 119 | files = list(sorted([x for x in counts.keys() if x.startswith('validation')])) 120 | 121 | files = [osp.join(FLAGS.imagenet_datadir, x) for x in files] 122 | preprocess_function = ImagenetPreprocessor(128, dtype=tf.float32, train=False).parse_and_preprocess 123 | 124 | ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string) 125 | ds = ds.apply(tf.data.TFRecordDataset) 126 | ds = ds.take(im_length) 127 | ds = 
ds.prefetch(buffer_size=FLAGS.batch_size) 128 | ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000)) 129 | ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4)) 130 | ds = ds.prefetch(buffer_size=2) 131 | 132 | ds_iterator = ds.make_initializable_iterator() 133 | labels, images = ds_iterator.get_next() 134 | self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0) 135 | self.labels = labels 136 | 137 | config = tf.ConfigProto(device_count = {'GPU': 0}) 138 | sess = tf.Session(config=config) 139 | sess.run(ds_iterator.initializer) 140 | 141 | self.im_length = im_length // batchsize 142 | 143 | self.sess = sess 144 | 145 | def __next__(self): 146 | self.curr_sample += 1 147 | 148 | sess = self.sess 149 | 150 | im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3)) 151 | label, im = sess.run([self.labels, self.images]) 152 | im = im * self.rescale 153 | label = np.eye(1000)[label.squeeze() - 1] 154 | im, im_corrupt, label = torch.from_numpy(im), torch.from_numpy(im_corrupt), torch.from_numpy(label) 155 | return im_corrupt, im, label 156 | 157 | def __iter__(self): 158 | return self 159 | 160 | def __len__(self): 161 | return self.im_length 162 | 163 | class CelebA(Dataset): 164 | 165 | def __init__(self): 166 | self.path = "/root/data/img_align_celeba" 167 | self.ims = os.listdir(self.path) 168 | self.ims = [osp.join(self.path, im) for im in self.ims] 169 | 170 | def __len__(self): 171 | return len(self.ims) 172 | 173 | def __getitem__(self, index): 174 | label = 1 175 | 176 | if FLAGS.single: 177 | index = 0 178 | 179 | path = self.ims[index] 180 | im = imread(path) 181 | im = imresize(im, (32, 32)) 182 | image_size = 32 183 | im = im / 255. 
184 | 185 | if FLAGS.datasource == 'default': 186 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) 187 | elif FLAGS.datasource == 'random': 188 | im_corrupt = np.random.uniform( 189 | 0, 1, size=(image_size, image_size, 3)) 190 | 191 | return im_corrupt, im, label 192 | 193 | 194 | class Cifar10(Dataset): 195 | def __init__( 196 | self, 197 | train=True, 198 | full=False, 199 | augment=False, 200 | noise=True, 201 | rescale=1.0): 202 | 203 | if augment: 204 | transform_list = [ 205 | torchvision.transforms.RandomCrop(32, padding=4), 206 | torchvision.transforms.RandomHorizontalFlip(), 207 | torchvision.transforms.ToTensor(), 208 | ] 209 | 210 | if FLAGS.cutout: 211 | transform_list.append(cutout()) 212 | 213 | transform = transforms.Compose(transform_list) 214 | else: 215 | transform = transforms.ToTensor() 216 | 217 | self.full = full 218 | self.data = CIFAR10( 219 | ROOT_DIR, 220 | transform=transform, 221 | train=train, 222 | download=True) 223 | self.test_data = CIFAR10( 224 | ROOT_DIR, 225 | transform=transform, 226 | train=False, 227 | download=True) 228 | self.one_hot_map = np.eye(10) 229 | self.noise = noise 230 | self.rescale = rescale 231 | 232 | def __len__(self): 233 | 234 | if self.full: 235 | return len(self.data) + len(self.test_data) 236 | else: 237 | return len(self.data) 238 | 239 | def __getitem__(self, index): 240 | if not FLAGS.single: 241 | if self.full: 242 | if index >= len(self.data): 243 | im, label = self.test_data[index - len(self.data)] 244 | else: 245 | im, label = self.data[index] 246 | else: 247 | im, label = self.data[index] 248 | else: 249 | im, label = self.data[0] 250 | 251 | im = np.transpose(im, (1, 2, 0)).numpy() 252 | image_size = 32 253 | label = self.one_hot_map[label] 254 | 255 | im = im * 255 / 256 256 | 257 | if self.noise: 258 | im = im * self.rescale + \ 259 | np.random.uniform(0, self.rescale * 1 / 256., im.shape) 260 | 261 | np.random.seed((index + int(time.time() * 1e7)) % 2**32) 262 | 263 | if 
FLAGS.datasource == 'default': 264 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) 265 | elif FLAGS.datasource == 'random': 266 | im_corrupt = np.random.uniform( 267 | 0.0, self.rescale, (image_size, image_size, 3)) 268 | 269 | return im_corrupt, im, label 270 | 271 | 272 | class Cifar100(Dataset): 273 | def __init__(self, train=True, augment=False): 274 | 275 | if augment: 276 | transform_list = [ 277 | torchvision.transforms.RandomCrop(32, padding=4), 278 | torchvision.transforms.RandomHorizontalFlip(), 279 | torchvision.transforms.ToTensor(), 280 | ] 281 | 282 | if FLAGS.cutout: 283 | transform_list.append(cutout()) 284 | 285 | transform = transforms.Compose(transform_list) 286 | else: 287 | transform = transforms.ToTensor() 288 | 289 | self.data = CIFAR100( 290 | "/root/cifar100", 291 | transform=transform, 292 | train=train, 293 | download=True) 294 | self.one_hot_map = np.eye(100) 295 | 296 | def __len__(self): 297 | return len(self.data) 298 | 299 | def __getitem__(self, index): 300 | if not FLAGS.single: 301 | im, label = self.data[index] 302 | else: 303 | im, label = self.data[0] 304 | 305 | im = np.transpose(im, (1, 2, 0)).numpy() 306 | image_size = 32 307 | label = self.one_hot_map[label] 308 | im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape) 309 | np.random.seed((index + int(time.time() * 1e7)) % 2**32) 310 | 311 | if FLAGS.datasource == 'default': 312 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) 313 | elif FLAGS.datasource == 'random': 314 | im_corrupt = np.random.uniform( 315 | 0.0, 1.0, (image_size, image_size, 3)) 316 | 317 | return im_corrupt, im, label 318 | 319 | 320 | class Svhn(Dataset): 321 | def __init__(self, train=True, augment=False): 322 | 323 | transform = transforms.ToTensor() 324 | 325 | self.data = SVHN("/root/svhn", split='train' if train else 'test', transform=transform, download=True) 326 | self.one_hot_map = np.eye(10) 327 | 328 | def __len__(self): 329 | return len(self.data) 330 | 331 | def __getitem__(self, 
index): 332 | if not FLAGS.single: 333 | im, label = self.data[index] 334 | else: 335 | im, label = self.data[0] 336 | 337 | im = np.transpose(im, (1, 2, 0)).numpy() 338 | image_size = 32 339 | label = self.one_hot_map[label] 340 | im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape) 341 | np.random.seed((index + int(time.time() * 1e7)) % 2**32) 342 | 343 | if FLAGS.datasource == 'default': 344 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) 345 | elif FLAGS.datasource == 'random': 346 | im_corrupt = np.random.uniform( 347 | 0.0, 1.0, (image_size, image_size, 3)) 348 | 349 | return im_corrupt, im, label 350 | 351 | 352 | class Mnist(Dataset): 353 | def __init__(self, train=True, rescale=1.0): 354 | self.data = MNIST( 355 | "/root/mnist", 356 | transform=transforms.ToTensor(), 357 | download=True, train=train) 358 | self.labels = np.eye(10) 359 | self.rescale = rescale 360 | 361 | def __len__(self): 362 | return len(self.data) 363 | 364 | def __getitem__(self, index): 365 | im, label = self.data[index] 366 | label = self.labels[label] 367 | im = im.squeeze() 368 | # im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28)) 369 | # im = im.numpy() / 2 + 0.2 370 | im = im.numpy() / 256 * 255 + np.random.uniform(0, 1. 
/ 256, (28, 28)) 371 | im = im * self.rescale 372 | image_size = 28 373 | 374 | if FLAGS.datasource == 'default': 375 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size) 376 | elif FLAGS.datasource == 'random': 377 | im_corrupt = np.random.uniform(0, self.rescale, (28, 28)) 378 | 379 | return im_corrupt, im, label 380 | 381 | 382 | class DSprites(Dataset): 383 | def __init__( 384 | self, 385 | cond_size=False, 386 | cond_shape=False, 387 | cond_pos=False, 388 | cond_rot=False): 389 | dat = np.load(FLAGS.dsprites_path) 390 | 391 | if FLAGS.dshape_only: 392 | l = dat['latents_values'] 393 | mask = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 394 | 31) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39) 395 | self.data = np.tile(dat['imgs'][mask], (10000, 1, 1)) 396 | self.label = np.tile(dat['latents_values'][mask], (10000, 1)) 397 | self.label = self.label[:, 1:2] 398 | elif FLAGS.dpos_only: 399 | l = dat['latents_values'] 400 | # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39) 401 | mask = (l[:, 1] == 1) & ( 402 | l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5) 403 | self.data = np.tile(dat['imgs'][mask], (100, 1, 1)) 404 | self.label = np.tile(dat['latents_values'][mask], (100, 1)) 405 | self.label = self.label[:, 4:] + 0.5 406 | elif FLAGS.dsize_only: 407 | l = dat['latents_values'] 408 | # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39) 409 | mask = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16 / 410 | 31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1) 411 | self.data = np.tile(dat['imgs'][mask], (10000, 1, 1)) 412 | self.label = np.tile(dat['latents_values'][mask], (10000, 1)) 413 | self.label = (self.label[:, 2:3]) 414 | elif FLAGS.drot_only: 415 | l = dat['latents_values'] 416 | mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 / 417 | 31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1) 418 | self.data = np.tile(dat['imgs'][mask], (100, 1, 1)) 419 | self.label = np.tile(dat['latents_values'][mask], (100, 1)) 420 | self.label = 
(self.label[:, 3:4]) 421 | self.label = np.concatenate( 422 | [np.cos(self.label), np.sin(self.label)], axis=1) 423 | elif FLAGS.dsprites_restrict: 424 | l = dat['latents_values'] 425 | mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39) 426 | 427 | self.data = dat['imgs'][mask] 428 | self.label = dat['latents_values'][mask] 429 | else: 430 | self.data = dat['imgs'] 431 | self.label = dat['latents_values'] 432 | 433 | if cond_size: 434 | self.label = self.label[:, 2:3] 435 | elif cond_shape: 436 | self.label = self.label[:, 1:2] 437 | elif cond_pos: 438 | self.label = self.label[:, 4:] 439 | elif cond_rot: 440 | self.label = self.label[:, 3:4] 441 | self.label = np.concatenate( 442 | [np.cos(self.label), np.sin(self.label)], axis=1) 443 | else: 444 | self.label = self.label[:, 1:2] 445 | 446 | self.identity = np.eye(3) 447 | 448 | def __len__(self): 449 | return self.data.shape[0] 450 | 451 | def __getitem__(self, index): 452 | im = self.data[index] 453 | image_size = 64 454 | 455 | if not ( 456 | FLAGS.dpos_only or FLAGS.dsize_only) and ( 457 | not FLAGS.cond_size) and ( 458 | not FLAGS.cond_pos) and ( 459 | not FLAGS.cond_rot) and ( 460 | not FLAGS.drot_only): 461 | label = self.identity[self.label[index].astype( 462 | np.int32) - 1].squeeze() 463 | else: 464 | label = self.label[index] 465 | 466 | if FLAGS.datasource == 'default': 467 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size) 468 | elif FLAGS.datasource == 'random': 469 | im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size) 470 | 471 | return im_corrupt, im, label 472 | 473 | 474 | class Imagenet(Dataset): 475 | def __init__(self, train=True, augment=False): 476 | 477 | if train: 478 | for i in range(1, 11): 479 | f = pickle.load( 480 | open( 481 | osp.join( 482 | FLAGS.imagenet_path, 483 | 'train_data_batch_{}'.format(i)), 484 | 'rb')) 485 | if i == 1: 486 | labels = f['labels'] 487 | data = f['data'] 488 | else: 489 | labels.extend(f['labels']) 490 | data = np.vstack((data, 
f['data'])) 491 | else: 492 | f = pickle.load( 493 | open( 494 | osp.join( 495 | FLAGS.imagenet_path, 496 | 'val_data'), 497 | 'rb')) 498 | labels = f['labels'] 499 | data = f['data'] 500 | 501 | self.labels = labels 502 | self.data = data 503 | self.one_hot_map = np.eye(1000) 504 | 505 | def __len__(self): 506 | return self.data.shape[0] 507 | 508 | def __getitem__(self, index): 509 | if not FLAGS.single: 510 | im, label = self.data[index], self.labels[index] 511 | else: 512 | im, label = self.data[0], self.labels[0] 513 | 514 | label -= 1 515 | 516 | im = im.reshape((3, 32, 32)) / 255 517 | im = im.transpose((1, 2, 0)) 518 | image_size = 32 519 | label = self.one_hot_map[label] 520 | im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape) 521 | np.random.seed((index + int(time.time() * 1e7)) % 2**32) 522 | 523 | if FLAGS.datasource == 'default': 524 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3) 525 | elif FLAGS.datasource == 'random': 526 | im_corrupt = np.random.uniform( 527 | 0.0, 1.0, (image_size, image_size, 3)) 528 | 529 | return im_corrupt, im, label 530 | 531 | 532 | class Textures(Dataset): 533 | def __init__(self, train=True, augment=False): 534 | self.dataset = ImageFolder("/mnt/nfs/yilundu/data/dtd/images") 535 | 536 | def __len__(self): 537 | return 2 * len(self.dataset) 538 | 539 | def __getitem__(self, index): 540 | idx = index % (len(self.dataset)) 541 | im, label = self.dataset[idx] 542 | 543 | im = np.array(im)[:32, :32] / 255 544 | im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape) 545 | 546 | return im, im, label 547 | -------------------------------------------------------------------------------- /EBMs/ebm_combine.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import math 3 | from tqdm import tqdm 4 | from hmc import hmc 5 | from tensorflow.python.platform import flags 6 | from torch.utils.data import DataLoader, Dataset 7 | from models import 
DspritesNet 8 | from utils import optimistic_restore, ReplayBuffer 9 | import os.path as osp 10 | import numpy as np 11 | from rl_algs.logger import TensorBoardOutputFormat 12 | from scipy.misc import imsave 13 | import os 14 | from custom_adam import AdamOptimizer 15 | 16 | flags.DEFINE_integer('batch_size', 256, 'batch size') 17 | flags.DEFINE_integer('data_workers', 4, 'number of data loading workers') 18 | flags.DEFINE_string('logdir', 'cachedir', 'directory for logging') 19 | flags.DEFINE_string('savedir', 'cachedir', 'location where log of experiments will be stored') 20 | flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.') 21 | flags.DEFINE_float('step_lr', 500, 'step size of the gradient descent updates on the input') 22 | flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to the dsprites .npz file') 23 | flags.DEFINE_bool('cclass', True, 'whether to use label-conditional models') 24 | flags.DEFINE_bool('proj_cclass', False, 'use for backwards compatibility reasons') 25 | flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights') 26 | flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution') 27 | flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network') 28 | flags.DEFINE_bool('plot_curve', False, 'Generate a curve of results') 29 | flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label') 30 | flags.DEFINE_string('task', 'conceptcombine', 'conceptcombine, labeldiscover, gentest, genbaseline, etc.') 31 | flags.DEFINE_bool('joint_shape', False, 'whether to combine the position model with the shape model (instead of size)') 32 | flags.DEFINE_bool('joint_rot', False, 'whether to combine the position model with the rotation model (instead of size)') 33 | 34 | # Conditions on which models to use 35 | flags.DEFINE_bool('cond_pos', True, 'whether to condition on position') 36 | flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation') 37 | 
flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape') 38 | flags.DEFINE_bool('cond_scale', True, 'whether to condition on scale') 39 | 40 | flags.DEFINE_string('exp_size', 'dsprites_2018_cond_size', 'experiment name of the size-conditioned model') 41 | flags.DEFINE_string('exp_shape', 'dsprites_2018_cond_shape', 'experiment name of the shape-conditioned model') 42 | flags.DEFINE_string('exp_pos', 'dsprites_2018_cond_pos_cert', 'experiment name of the position-conditioned model') 43 | flags.DEFINE_string('exp_rot', 'dsprites_cond_rot_119_00', 'experiment name of the rotation-conditioned model') 44 | flags.DEFINE_integer('resume_size', 169000, 'checkpoint iteration of the size model to resume') 45 | flags.DEFINE_integer('resume_shape', 477000, 'checkpoint iteration of the shape model to resume') 46 | flags.DEFINE_integer('resume_pos', 8000, 'checkpoint iteration of the position model to resume') 47 | flags.DEFINE_integer('resume_rot', 690000, 'checkpoint iteration of the rotation model to resume') 48 | flags.DEFINE_integer('break_steps', 300, 'number of training iterations after which to stop') 49 | 50 | # Whether to train for gentest 51 | flags.DEFINE_bool('train', False, 'whether to train the models for gentest/genbaseline (otherwise restore the saved weights)') 52 | 53 | FLAGS = flags.FLAGS 54 | 55 | class DSpritesGen(Dataset): 56 | def __init__(self, data, latents, frac=0.0): 57 | 58 | l = latents 59 | 60 | if FLAGS.joint_shape: 61 | mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5) 62 | elif FLAGS.joint_rot: 63 | mask_size = (l[:, 1] == 1) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5) 64 | else: 65 | mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 1] == 1) 66 | 67 | mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5) 68 | 69 | data_pos = data[mask_pos] 70 | l_pos = l[mask_pos] 71 | 72 | data_size = data[mask_size] 73 | l_size = l[mask_size] 74 | 75 | n = data_pos.shape[0] // data_size.shape[0] 76 | 77 | data_pos = np.tile(data_pos, (n, 1, 1)) 78 | l_pos = np.tile(l_pos, (n, 1)) 79 | 80 | self.data = np.concatenate((data_pos, data_size), axis=0) 81 | self.label = 
np.concatenate((l_pos, l_size), axis=0) 82 | 83 | mask_neg = (~(mask_size & mask_pos)) & ((l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39)) 84 | data_add = data[mask_neg] 85 | l_add = l[mask_neg] 86 | 87 | perm_idx = np.random.permutation(data_add.shape[0]) 88 | select_idx = perm_idx[:int(frac*perm_idx.shape[0])] 89 | data_add = data_add[select_idx] 90 | l_add = l_add[select_idx] 91 | 92 | self.data = np.concatenate((self.data, data_add), axis=0) 93 | self.label = np.concatenate((self.label, l_add), axis=0) 94 | 95 | self.identity = np.eye(3) 96 | 97 | def __len__(self): 98 | return self.data.shape[0] 99 | 100 | def __getitem__(self, index): 101 | im = self.data[index] 102 | im_corrupt = 0.5 + 0.5 * np.random.randn(64, 64) 103 | 104 | if FLAGS.joint_shape: 105 | label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1] 106 | elif FLAGS.joint_rot: 107 | label_size = np.array([np.cos(self.label[index, 3]), np.sin(self.label[index, 3])]) 108 | else: 109 | label_size = self.label[index, 2:3] 110 | 111 | label_pos = self.label[index, 4:] 112 | 113 | return (im_corrupt, im, label_size, label_pos) 114 | 115 | 116 | def labeldiscover(sess, kvs, data, latents, save_exp_dir): 117 | LABEL_SIZE = kvs['LABEL_SIZE'] 118 | model_size = kvs['model_size'] 119 | weight_size = kvs['weight_size'] 120 | x_mod = kvs['X_NOISE'] 121 | 122 | label_output = LABEL_SIZE 123 | for i in range(FLAGS.num_steps): 124 | label_output = label_output + tf.random_normal(tf.shape(label_output), mean=0.0, stddev=0.03) 125 | e_noise = model_size.forward(x_mod, weight_size, label=label_output) 126 | label_grad = tf.gradients(e_noise, [label_output])[0] 127 | # label_grad = tf.Print(label_grad, [label_grad]) 128 | label_output = label_output - 1.0 * label_grad 129 | label_output = tf.clip_by_value(label_output, 0.5, 1.0) 130 | 131 | diffs = [] 132 | for i in range(30): 133 | s = i*FLAGS.batch_size 134 | d = (i+1)*FLAGS.batch_size 135 | data_i = data[s:d] 136 | latent_i = latents[s:d] 137 | latent_init 
= np.random.uniform(0.5, 1, (FLAGS.batch_size, 1)) 138 | 139 | feed_dict = {x_mod: data_i, LABEL_SIZE:latent_init} 140 | size_pred = sess.run([label_output], feed_dict)[0] 141 | size_gt = latent_i[:, 2:3] 142 | 143 | diffs.append(np.abs(size_pred - size_gt).mean()) 144 | 145 | print(np.array(diffs).mean()) 146 | 147 | 148 | def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0): 149 | # tf.reset_default_graph() 150 | 151 | if FLAGS.joint_shape: 152 | model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=5) 153 | LABEL = tf.placeholder(shape=(None, 5), dtype=tf.float32) 154 | else: 155 | model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3) 156 | LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32) 157 | 158 | weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac)) 159 | 160 | X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32) 161 | X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) 162 | 163 | X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL) 164 | loss_sq = tf.reduce_mean(tf.square(X_out - X_label)) 165 | 166 | optimizer = AdamOptimizer(1e-3) 167 | gvs = optimizer.compute_gradients(loss_sq) 168 | gvs = [(k, v) for (k, v) in gvs if k is not None] 169 | train_op = optimizer.apply_gradients(gvs) 170 | 171 | dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True) 172 | 173 | datafull = data 174 | 175 | itr = 0 176 | saver = tf.train.Saver() 177 | 178 | vs = optimizer.variables() 179 | sess.run(tf.global_variables_initializer()) 180 | 181 | if FLAGS.train: 182 | for _ in range(5): 183 | for data_corrupt, data, label_size, label_pos in tqdm(dataloader): 184 | 185 | data_corrupt = data_corrupt.numpy() 186 | label_size, label_pos = label_size.numpy(), label_pos.numpy() 187 | 188 | data_corrupt = np.random.randn(data_corrupt.shape[0], 
2*FLAGS.num_filters) 189 | label_comb = np.concatenate([label_size, label_pos], axis=1) 190 | 191 | feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb} 192 | 193 | output = [loss_sq, train_op] 194 | 195 | loss, _ = sess.run(output, feed_dict=feed_dict) 196 | 197 | itr += 1 198 | 199 | saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline')) 200 | 201 | saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline')) 202 | 203 | l = latents 204 | 205 | if FLAGS.joint_shape: 206 | mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5) 207 | else: 208 | mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31)))) 209 | 210 | data_gen = datafull[mask_gen] 211 | latents_gen = latents[mask_gen] 212 | losses = [] 213 | 214 | for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)): 215 | data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters) 216 | 217 | if FLAGS.joint_shape: 218 | latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1] 219 | latent_pos = latent[:, 4:6] 220 | latent = np.concatenate([latent_size, latent_pos], axis=1) 221 | feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat} 222 | else: 223 | feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat} 224 | loss = sess.run([loss_sq], feed_dict=feed_dict)[0] 225 | # print(loss) 226 | losses.append(loss) 227 | 228 | print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac)) 229 | 230 | 231 | data_try = data_gen[:10] 232 | data_init = np.random.randn(10, 2*FLAGS.num_filters) 233 | 234 | if FLAGS.joint_shape: 235 | latent_scale = np.eye(3)[latents_gen[:10, 1].astype(np.int32) - 1] 236 | latent_pos = latents_gen[:10, 4:] 237 | else: 238 | latent_scale = latents_gen[:10, 2:3] 239 | latent_pos = latents_gen[:10, 4:] 240 | 241 | latent_tot = np.concatenate([latent_scale, latent_pos], axis=1) 242 | 243 | feed_dict = {X_feed: data_init, 
LABEL: latent_tot} 244 | x_output = sess.run([X_out], feed_dict=feed_dict)[0] 245 | x_output = np.clip(x_output, 0, 1) 246 | 247 | im_name = "size_scale_combine_genbaseline.png" 248 | 249 | x_output_wrap = np.ones((10, 66, 66)) 250 | data_try_wrap = np.ones((10, 66, 66)) 251 | 252 | x_output_wrap[:, 1:-1, 1:-1] = x_output 253 | data_try_wrap[:, 1:-1, 1:-1] = data_try 254 | 255 | im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2) 256 | impath = osp.join(save_exp_dir, im_name) 257 | imsave(impath, im_output) 258 | print("Successfully saved images at {}".format(impath)) 259 | 260 | return np.mean(losses) 261 | 262 | 263 | def gentest(sess, kvs, data, latents, save_exp_dir): 264 | X_NOISE = kvs['X_NOISE'] 265 | LABEL_SIZE = kvs['LABEL_SIZE'] 266 | LABEL_SHAPE = kvs['LABEL_SHAPE'] 267 | LABEL_POS = kvs['LABEL_POS'] 268 | LABEL_ROT = kvs['LABEL_ROT'] 269 | model_size = kvs['model_size'] 270 | model_shape = kvs['model_shape'] 271 | model_pos = kvs['model_pos'] 272 | model_rot = kvs['model_rot'] 273 | weight_size = kvs['weight_size'] 274 | weight_shape = kvs['weight_shape'] 275 | weight_pos = kvs['weight_pos'] 276 | weight_rot = kvs['weight_rot'] 277 | X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) 278 | 279 | datafull = data 280 | # Test combination of generalization where we use slices of both training 281 | x_final = X_NOISE 282 | x_mod_size = X_NOISE 283 | x_mod_pos = X_NOISE 284 | 285 | for i in range(FLAGS.num_steps): 286 | 287 | # use cond_pos 288 | 289 | energies = [] 290 | x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005) 291 | e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS) 292 | 293 | # energies.append(e_noise) 294 | x_grad = tf.gradients(e_noise, [x_final])[0] 295 | x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005) 296 | x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad 297 | x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1) 298 | 299 
| if FLAGS.joint_shape: 300 | # use cond_shape 301 | e_noise = model_shape.forward(x_mod_pos, weight_shape, label=LABEL_SHAPE) 302 | elif FLAGS.joint_rot: 303 | e_noise = model_rot.forward(x_mod_pos, weight_rot, label=LABEL_ROT) 304 | else: 305 | # use cond_size 306 | e_noise = model_size.forward(x_mod_pos, weight_size, label=LABEL_SIZE) 307 | 308 | # energies.append(e_noise) 309 | # energy_stack = tf.concat(energies, axis=1) 310 | # energy_stack = tf.reduce_logsumexp(-1*energy_stack, axis=1) 311 | # energy_stack = tf.reduce_sum(energy_stack, axis=1) 312 | 313 | x_grad = tf.gradients(e_noise, [x_mod_pos])[0] 314 | x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad 315 | x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1) 316 | 317 | # for x_mod_size 318 | # use cond_size 319 | # e_noise = model_size.forward(x_mod_size, weight_size, label=LABEL_SIZE) 320 | # x_grad = tf.gradients(e_noise, [x_mod_size])[0] 321 | # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005) 322 | # x_mod_size = x_mod_size - FLAGS.step_lr * x_grad 323 | # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1) 324 | 325 | # # use cond_pos 326 | # e_noise = model_pos.forward(x_mod_size, weight_pos, label=LABEL_POS) 327 | # x_grad = tf.gradients(e_noise, [x_mod_size])[0] 328 | # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005) 329 | # x_mod_size = x_mod_size - FLAGS.step_lr * tf.stop_gradient(x_grad) 330 | # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1) 331 | 332 | x_mod = x_mod_pos 333 | x_final = x_mod 334 | 335 | 336 | if FLAGS.joint_shape: 337 | loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \ 338 | model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True) 339 | 340 | energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \ 341 | model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS) 342 | 343 | energy_neg = 
model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \ 344 | model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS) 345 | elif FLAGS.joint_rot: 346 | loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \ 347 | model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True) 348 | 349 | energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \ 350 | model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS) 351 | 352 | energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \ 353 | model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS) 354 | else: 355 | loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \ 356 | model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True) 357 | 358 | energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \ 359 | model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS) 360 | 361 | energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \ 362 | model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS) 363 | 364 | energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg)) 365 | coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced)) 366 | norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4 367 | neg_loss = coeff * (-1*energy_neg) / norm_constant 368 | 369 | loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg) 370 | loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg))) 371 | 372 | optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999) 373 | gvs = optimizer.compute_gradients(loss_total) 374 | gvs = [(k, v) for (k, v) in gvs if k is not None] 
375 | train_op = optimizer.apply_gradients(gvs) 376 | 377 | vs = optimizer.variables() 378 | sess.run(tf.variables_initializer(vs)) 379 | 380 | dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True) 381 | 382 | x_off = tf.reduce_mean(tf.square(x_mod - X)) 383 | 384 | itr = 0 385 | saver = tf.train.Saver() 386 | x_mod = None 387 | 388 | 389 | if FLAGS.train: 390 | replay_buffer = ReplayBuffer(10000) 391 | for _ in range(1): 392 | 393 | 394 | for data_corrupt, data, label_size, label_pos in tqdm(dataloader): 395 | data_corrupt = data_corrupt.numpy()[:, :, :] 396 | data = data.numpy()[:, :, :] 397 | 398 | if x_mod is not None: 399 | replay_buffer.add(x_mod) 400 | replay_batch = replay_buffer.sample(FLAGS.batch_size) 401 | replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95) 402 | data_corrupt[replay_mask] = replay_batch[replay_mask] 403 | 404 | if FLAGS.joint_shape: 405 | feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos} 406 | elif FLAGS.joint_rot: 407 | feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos} 408 | else: 409 | feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos} 410 | 411 | _, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict) 412 | itr += 1 413 | 414 | if itr % 10 == 0: 415 | print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr)) 416 | 417 | if itr == FLAGS.break_steps: 418 | break 419 | 420 | 421 | saver.save(sess, osp.join(save_exp_dir, 'model_gentest')) 422 | 423 | saver.restore(sess, osp.join(save_exp_dir, 'model_gentest')) 424 | 425 | l = latents 426 | 427 | if FLAGS.joint_shape: 428 | mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5) 429 | elif FLAGS.joint_rot: 430 | mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5) 431 | else: 
432 | mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31)))) 433 | 434 | data_gen = datafull[mask_gen] 435 | latents_gen = latents[mask_gen] 436 | 437 | losses = [] 438 | 439 | for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)): 440 | x = 0.5 + np.random.randn(*dat.shape) 441 | 442 | if FLAGS.joint_shape: 443 | feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat} 444 | elif FLAGS.joint_rot: 445 | feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat} 446 | else: 447 | feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat} 448 | 449 | for i in range(2): 450 | x = sess.run([x_final], feed_dict=feed_dict)[0] 451 | feed_dict[X_NOISE] = x 452 | 453 | loss = sess.run([x_off], feed_dict=feed_dict)[0] 454 | losses.append(loss) 455 | 456 | print("Mean MSE loss of {} ".format(np.mean(losses))) 457 | 458 | data_try = data_gen[:10] 459 | data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64) 460 | latent_scale = latents_gen[:10, 2:3] 461 | latent_pos = latents_gen[:10, 4:] 462 | 463 | # labels must match data_try, so take them from latents_gen[:10], not from the last loop batch 464 | if FLAGS.joint_shape: 464 | feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latents_gen[:10, 1].astype(np.int32)-1], LABEL_POS: latent_pos} 465 | elif FLAGS.joint_rot: 466 | feed_dict = {LABEL_ROT: np.concatenate([np.cos(latents_gen[:10, 3:4]), np.sin(latents_gen[:10, 3:4])], axis=1), LABEL_POS: latent_pos, X_NOISE: data_init} 467 | else: 468 | feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos} 469 | 470 | x_output = sess.run([x_final], feed_dict=feed_dict)[0] 471 | 472 | if FLAGS.joint_shape: 473 | im_name = "size_shape_combine_gentest.png" 474 | else: 475 | im_name = "size_scale_combine_gentest.png" 476 | 477 | x_output_wrap = np.ones((10, 66, 66)) 478 | data_try_wrap =
np.ones((10, 66, 66)) 479 | 480 | x_output_wrap[:, 1:-1, 1:-1] = x_output 481 | data_try_wrap[:, 1:-1, 1:-1] = data_try 482 | 483 | im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2) 484 | impath = osp.join(save_exp_dir, im_name) 485 | imsave(impath, im_output) 486 | print("Successfully saved images at {}".format(impath)) 487 | 488 | 489 | 490 | def conceptcombine(sess, kvs, data, latents, save_exp_dir): 491 | X_NOISE = kvs['X_NOISE'] 492 | LABEL_SIZE = kvs['LABEL_SIZE'] 493 | LABEL_SHAPE = kvs['LABEL_SHAPE'] 494 | LABEL_POS = kvs['LABEL_POS'] 495 | LABEL_ROT = kvs['LABEL_ROT'] 496 | model_size = kvs['model_size'] 497 | model_shape = kvs['model_shape'] 498 | model_pos = kvs['model_pos'] 499 | model_rot = kvs['model_rot'] 500 | weight_size = kvs['weight_size'] 501 | weight_shape = kvs['weight_shape'] 502 | weight_pos = kvs['weight_pos'] 503 | weight_rot = kvs['weight_rot'] 504 | 505 | x_mod = X_NOISE 506 | for i in range(FLAGS.num_steps): 507 | 508 | if FLAGS.cond_scale: 509 | e_noise = model_size.forward(x_mod, weight_size, label=LABEL_SIZE) 510 | x_grad = tf.gradients(e_noise, [x_mod])[0] 511 | x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005) 512 | x_mod = x_mod - FLAGS.step_lr * x_grad 513 | x_mod = tf.clip_by_value(x_mod, 0, 1) 514 | 515 | if FLAGS.cond_shape: 516 | e_noise = model_shape.forward(x_mod, weight_shape, label=LABEL_SHAPE) 517 | x_grad = tf.gradients(e_noise, [x_mod])[0] 518 | x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005) 519 | x_mod = x_mod - FLAGS.step_lr * x_grad 520 | x_mod = tf.clip_by_value(x_mod, 0, 1) 521 | 522 | if FLAGS.cond_pos: 523 | e_noise = model_pos.forward(x_mod, weight_pos, label=LABEL_POS) 524 | x_grad = tf.gradients(e_noise, [x_mod])[0] 525 | x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005) 526 | x_mod = x_mod - FLAGS.step_lr * x_grad 527 | x_mod = tf.clip_by_value(x_mod, 0, 1) 528 | 529 | if FLAGS.cond_rot: 530 | 
e_noise = model_rot.forward(x_mod, weight_rot, label=LABEL_ROT) 531 | x_grad = tf.gradients(e_noise, [x_mod])[0] 532 | x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005) 533 | x_mod = x_mod - FLAGS.step_lr * x_grad 534 | x_mod = tf.clip_by_value(x_mod, 0, 1) 535 | 536 | print("Finished constructing loop {}".format(i)) 537 | 538 | x_final = x_mod 539 | 540 | data_try = data[:10] 541 | data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64) 542 | label_scale = latents[:10, 2:3] 543 | label_shape = np.eye(3)[(latents[:10, 1]-1).astype(np.uint8)] 544 | label_rot = latents[:10, 3:4] 545 | label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1) 546 | label_pos = latents[:10, 4:] 547 | 548 | feed_dict = {X_NOISE: data_init, LABEL_SIZE: label_scale, LABEL_SHAPE: label_shape, LABEL_POS: label_pos, 549 | LABEL_ROT: label_rot} 550 | x_out = sess.run([x_final], feed_dict)[0] 551 | 552 | im_name = "im" 553 | 554 | if FLAGS.cond_scale: 555 | im_name += "_condscale" 556 | 557 | if FLAGS.cond_shape: 558 | im_name += "_condshape" 559 | 560 | if FLAGS.cond_pos: 561 | im_name += "_condpos" 562 | 563 | if FLAGS.cond_rot: 564 | im_name += "_condrot" 565 | 566 | im_name += ".png" 567 | 568 | x_out_pad, data_try_pad = np.ones((10, 66, 66)), np.ones((10, 66, 66)) 569 | x_out_pad[:, 1:-1, 1:-1] = x_out 570 | data_try_pad[:, 1:-1, 1:-1] = data_try 571 | 572 | im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66*2) 573 | impath = osp.join(save_exp_dir, im_name) 574 | imsave(impath, im_output) 575 | print("Successfully saved images at {}".format(impath)) 576 | 577 | def main(): 578 | data = np.load(FLAGS.dsprites_path)['imgs'] 579 | l = latents = np.load(FLAGS.dsprites_path)['latents_values'] 580 | 581 | np.random.seed(1) 582 | idx = np.random.permutation(data.shape[0]) 583 | 584 | data = data[idx] 585 | latents = latents[idx] 586 | 587 | config = tf.ConfigProto() 588 | sess = tf.Session(config=config) 589 | 590 | # Model 1 will be 
conditioned on size 591 | model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True) 592 | weight_size = model_size.construct_weights('context_0') 593 | 594 | # Model 2 will be conditioned on shape 595 | model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True) 596 | weight_shape = model_shape.construct_weights('context_1') 597 | 598 | # Model 3 will be conditioned on position 599 | model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True) 600 | weight_pos = model_pos.construct_weights('context_2') 601 | 602 | # Model 4 will be conditioned on rotation 603 | model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True) 604 | weight_rot = model_rot.construct_weights('context_3') 605 | 606 | sess.run(tf.global_variables_initializer()) 607 | save_path_size = osp.join(FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size)) 608 | 609 | v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(0)) 610 | v_map = {(v.name.replace('context_{}'.format(0), 'context_0')[:-2]): v for v in v_list} 611 | 612 | if FLAGS.cond_scale: 613 | saver = tf.train.Saver(v_map) 614 | saver.restore(sess, save_path_size) 615 | 616 | save_path_shape = osp.join(FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape)) 617 | 618 | v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1)) 619 | v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list} 620 | 621 | if FLAGS.cond_shape: 622 | saver = tf.train.Saver(v_map) 623 | saver.restore(sess, save_path_shape) 624 | 625 | 626 | save_path_pos = osp.join(FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos)) 627 | v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2)) 628 | v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list} 629 | saver = tf.train.Saver(v_map) 630 | 631 | if FLAGS.cond_pos: 632 | 
saver.restore(sess, save_path_pos) 633 | 634 | 635 | save_path_rot = osp.join(FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot)) 636 | v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(3)) 637 | v_map = {(v.name.replace('context_{}'.format(3), 'context_0')[:-2]): v for v in v_list} 638 | saver = tf.train.Saver(v_map) 639 | 640 | if FLAGS.cond_rot: 641 | saver.restore(sess, save_path_rot) 642 | 643 | X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32) 644 | LABEL_SIZE = tf.placeholder(shape=(None, 1), dtype=tf.float32) 645 | LABEL_SHAPE = tf.placeholder(shape=(None, 3), dtype=tf.float32) 646 | LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32) 647 | LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32) 648 | 649 | x_mod = X_NOISE 650 | 651 | kvs = {} 652 | kvs['X_NOISE'] = X_NOISE 653 | kvs['LABEL_SIZE'] = LABEL_SIZE 654 | kvs['LABEL_SHAPE'] = LABEL_SHAPE 655 | kvs['LABEL_POS'] = LABEL_POS 656 | kvs['LABEL_ROT'] = LABEL_ROT 657 | kvs['model_size'] = model_size 658 | kvs['model_shape'] = model_shape 659 | kvs['model_pos'] = model_pos 660 | kvs['model_rot'] = model_rot 661 | kvs['weight_size'] = weight_size 662 | kvs['weight_shape'] = weight_shape 663 | kvs['weight_pos'] = weight_pos 664 | kvs['weight_rot'] = weight_rot 665 | 666 | save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape)) 667 | if not osp.exists(save_exp_dir): 668 | os.makedirs(save_exp_dir) 669 | 670 | 671 | if FLAGS.task == 'conceptcombine': 672 | conceptcombine(sess, kvs, data, latents, save_exp_dir) 673 | elif FLAGS.task == 'labeldiscover': 674 | labeldiscover(sess, kvs, data, latents, save_exp_dir) 675 | elif FLAGS.task == 'gentest': 676 | save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos)) 677 | if not osp.exists(save_exp_dir): 678 | os.makedirs(save_exp_dir) 679 | 680 | gentest(sess, kvs, data, latents, save_exp_dir) 681 | elif FLAGS.task == 
'genbaseline': 682 | save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos)) 683 | if not osp.exists(save_exp_dir): 684 | os.makedirs(save_exp_dir) 685 | 686 | if FLAGS.plot_curve: 687 | mse_losses = [] 688 | for frac in [i/10 for i in range(11)]: 689 | mse_loss = genbaseline(sess, kvs, data, latents, save_exp_dir, frac=frac) 690 | mse_losses.append(mse_loss) 691 | np.save("mse_baseline_comb.npy", mse_losses) 692 | else: 693 | genbaseline(sess, kvs, data, latents, save_exp_dir) 694 | 695 | 696 | 697 | if __name__ == "__main__": 698 | main() 699 | -------------------------------------------------------------------------------- /EBMs/fid.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python3 2 | ''' Calculates the Frechet Inception Distance (FID) to evaluate GANs. 3 | 4 | The FID metric calculates the distance between two distributions of images. 5 | Typically, we have summary statistics (mean & covariance matrix) of one 6 | of these distributions, while the 2nd distribution is given by a GAN. 7 | 8 | When run as a stand-alone program, it compares the distribution of 9 | images that are stored as PNG/JPEG at a specified location with a 10 | distribution given by summary statistics (an .npz file with 'mu' and 'sigma'). 11 | 12 | The FID is calculated by assuming that X_1 and X_2 are the activations of 13 | the pool_3 layer of the inception net for generated samples and real-world 14 | samples, respectively. 15 | 16 | See --help to see further details.
''' 18 | 19 | from __future__ import absolute_import, division, print_function 20 | import numpy as np 21 | import os 22 | import gzip, pickle 23 | import tensorflow as tf 24 | from scipy.misc import imread 25 | from scipy import linalg 26 | import pathlib 27 | import urllib.request 28 | import tarfile 29 | import warnings 30 | import sys 31 | MODEL_DIR = '/tmp/imagenet' 32 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 33 | pool3 = None 34 | 35 | class InvalidFIDException(Exception): 36 | pass 37 | 38 | #------------------------------------------------------------------------------- 39 | def get_fid_score(images, images_gt): 40 | images = np.stack(images, 0) 41 | images_gt = np.stack(images_gt, 0) 42 | 43 | with tf.Session() as sess: 44 | m1, s1 = calculate_activation_statistics(images, sess) 45 | m2, s2 = calculate_activation_statistics(images_gt, sess) 46 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 47 | 48 | print("Obtained fid value of {}".format(fid_value)) 49 | return fid_value 50 | 51 | 52 | def create_inception_graph(pth): 53 | """Creates a graph from saved GraphDef file.""" 54 | # Creates graph from saved graph_def.pb. 55 | with tf.gfile.FastGFile( pth, 'rb') as f: 56 | graph_def = tf.GraphDef() 57 | graph_def.ParseFromString( f.read()) 58 | _ = tf.import_graph_def( graph_def, name='FID_Inception_Net') 59 | #------------------------------------------------------------------------------- 60 | 61 | 62 | # code for handling inception net derived from 63 | # https://github.com/openai/improved-gan/blob/master/inception_score/model.py 64 | def _get_inception_layer(sess): 65 | """Prepares inception net for batched usage and returns pool_3 layer.
""" 66 | layername = 'FID_Inception_Net/pool_3:0' 67 | pool3 = sess.graph.get_tensor_by_name(layername) 68 | ops = pool3.graph.get_operations() 69 | for op_idx, op in enumerate(ops): 70 | for o in op.outputs: 71 | shape = o.get_shape() 72 | if shape._dims != []: 73 | shape = [s.value for s in shape] 74 | new_shape = [] 75 | for j, s in enumerate(shape): 76 | if s == 1 and j == 0: 77 | new_shape.append(None) 78 | else: 79 | new_shape.append(s) 80 | o.__dict__['_shape_val'] = tf.TensorShape(new_shape) 81 | return pool3 82 | #------------------------------------------------------------------------------- 83 | 84 | 85 | def get_activations(images, sess, batch_size=50, verbose=False): 86 | """Calculates the activations of the pool_3 layer for all images. 87 | 88 | Params: 89 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values 90 | must lie between 0 and 255. 91 | -- sess : current session 92 | -- batch_size : the images numpy array is split into batches with batch size 93 | batch_size. A reasonable batch size depends on the available hardware. 94 | -- verbose : If set to True, the number of processed 95 | batches is reported. 96 | Returns: 97 | -- A numpy array of dimension (num images, 2048) that contains the 98 | activations of the given tensor when feeding inception with the query tensor. 99 | """ 100 | # inception_layer = _get_inception_layer(sess) 101 | d0 = images.shape[0] 102 | if batch_size > d0: 103 | print("warning: batch size is bigger than the data size.
setting batch size to data size") 104 | batch_size = d0 105 | n_batches = d0//batch_size 106 | n_used_imgs = n_batches*batch_size 107 | pred_arr = np.empty((n_used_imgs,2048)) 108 | for i in range(n_batches): 109 | if verbose: 110 | print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True) 111 | start = i*batch_size 112 | end = start + batch_size 113 | batch = images[start:end] 114 | pred = sess.run(pool3, {'ExpandDims:0': batch}) 115 | pred_arr[start:end] = pred.reshape(batch_size,-1) 116 | if verbose: 117 | print(" done") 118 | return pred_arr 119 | #------------------------------------------------------------------------------- 120 | 121 | 122 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): 123 | """Numpy implementation of the Frechet Distance. 124 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) 125 | and X_2 ~ N(mu_2, C_2) is 126 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). 127 | Stable version by Dougal J. Sutherland. 128 | 129 | Params: 130 | -- mu1 : The sample mean over activations of the pool_3 layer of the 131 | inception net (as computed from the output of 'get_activations') 132 | for generated samples. 133 | -- mu2 : The sample mean over activations of the pool_3 layer, precalculated 134 | on a representative data set. 135 | -- sigma1: The covariance matrix over activations of the pool_3 layer for 136 | generated samples. 137 | -- sigma2: The covariance matrix over activations of the pool_3 layer, 138 | precalculated on a representative data set. 139 | 140 | Returns: 141 | -- : The Frechet Distance.
""" 143 | 144 | mu1 = np.atleast_1d(mu1) 145 | mu2 = np.atleast_1d(mu2) 146 | 147 | sigma1 = np.atleast_2d(sigma1) 148 | sigma2 = np.atleast_2d(sigma2) 149 | 150 | assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths" 151 | assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions" 152 | 153 | diff = mu1 - mu2 154 | 155 | # product might be almost singular 156 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) 157 | if not np.isfinite(covmean).all(): 158 | msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps 159 | warnings.warn(msg) 160 | offset = np.eye(sigma1.shape[0]) * eps 161 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) 162 | 163 | # numerical error might give slight imaginary component 164 | if np.iscomplexobj(covmean): 165 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): 166 | m = np.max(np.abs(covmean.imag)) 167 | raise ValueError("Imaginary component {}".format(m)) 168 | covmean = covmean.real 169 | 170 | tr_covmean = np.trace(covmean) 171 | 172 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean 173 | #------------------------------------------------------------------------------- 174 | 175 | 176 | def calculate_activation_statistics(images, sess, batch_size=50, verbose=False): 177 | """Calculation of the statistics used by the FID. 178 | Params: 179 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values 180 | must lie between 0 and 255. 181 | -- sess : current session 182 | -- batch_size : the images numpy array is split into batches with batch size 183 | batch_size. A reasonable batch size depends on the available hardware. 184 | -- verbose : If set to True, the number of processed 185 | batches is reported.
186 | Returns: 187 | -- mu : The mean over samples of the activations of the pool_3 layer of 188 | the inception model. 189 | -- sigma : The covariance matrix of the activations of the pool_3 layer of 190 | the inception model. 191 | """ 192 | act = get_activations(images, sess, batch_size, verbose) 193 | mu = np.mean(act, axis=0) 194 | sigma = np.cov(act, rowvar=False) 195 | return mu, sigma 196 | #------------------------------------------------------------------------------- 197 | 198 | 199 | #------------------------------------------------------------------------------- 200 | # The following functions aren't needed for calculating the FID; 201 | # they're just here to make this module work as a stand-alone script 202 | # for calculating FID scores 203 | #------------------------------------------------------------------------------- 204 | def check_or_download_inception(inception_path): 205 | ''' Checks if the path to the inception file is valid, or downloads 206 | the file if it is not present.
''' 207 | INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 208 | if inception_path is None: 209 | inception_path = '/tmp' 210 | inception_path = pathlib.Path(inception_path) 211 | model_file = inception_path / 'classify_image_graph_def.pb' 212 | if not model_file.exists(): 213 | print("Downloading Inception model") 214 | from urllib import request 215 | import tarfile 216 | fn, _ = request.urlretrieve(INCEPTION_URL) 217 | with tarfile.open(fn, mode='r') as f: 218 | f.extract('classify_image_graph_def.pb', str(model_file.parent)) 219 | return str(model_file) 220 | 221 | 222 | def _handle_path(path, sess): 223 | if path.endswith('.npz'): 224 | f = np.load(path) 225 | m, s = f['mu'][:], f['sigma'][:] 226 | f.close() 227 | else: 228 | path = pathlib.Path(path) 229 | files = list(path.glob('*.jpg')) + list(path.glob('*.png')) 230 | x = np.array([imread(str(fn)).astype(np.float32) for fn in files]) 231 | m, s = calculate_activation_statistics(x, sess) 232 | return m, s 233 | 234 | 235 | def calculate_fid_given_paths(paths, inception_path): 236 | ''' Calculates the FID of two paths. 
''' 237 | inception_path = check_or_download_inception(inception_path) 238 | 239 | for p in paths: 240 | if not os.path.exists(p): 241 | raise RuntimeError("Invalid path: %s" % p) 242 | 243 | create_inception_graph(str(inception_path)) 244 | with tf.Session() as sess: 245 | sess.run(tf.global_variables_initializer()) 246 | m1, s1 = _handle_path(paths[0], sess) 247 | m2, s2 = _handle_path(paths[1], sess) 248 | fid_value = calculate_frechet_distance(m1, s1, m2, s2) 249 | return fid_value 250 | 251 | 252 | def _init_inception(): 253 | global pool3 254 | if not os.path.exists(MODEL_DIR): 255 | os.makedirs(MODEL_DIR) 256 | filename = DATA_URL.split('/')[-1] 257 | filepath = os.path.join(MODEL_DIR, filename) 258 | if not os.path.exists(filepath): 259 | def _progress(count, block_size, total_size): 260 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 261 | filename, float(count * block_size) / float(total_size) * 100.0)) 262 | sys.stdout.flush() 263 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 264 | print() 265 | statinfo = os.stat(filepath) 266 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 267 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) 268 | with tf.gfile.FastGFile(os.path.join( 269 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: 270 | graph_def = tf.GraphDef() 271 | graph_def.ParseFromString(f.read()) 272 | _ = tf.import_graph_def(graph_def, name='') 273 | # Works with an arbitrary minibatch size.
274 | with tf.Session() as sess: 275 | pool3 = sess.graph.get_tensor_by_name('pool_3:0') 276 | ops = pool3.graph.get_operations() 277 | for op_idx, op in enumerate(ops): 278 | for o in op.outputs: 279 | shape = o.get_shape() 280 | if shape._dims != []: 281 | shape = [s.value for s in shape] 282 | new_shape = [] 283 | for j, s in enumerate(shape): 284 | if s == 1 and j == 0: 285 | new_shape.append(None) 286 | else: 287 | new_shape.append(s) 288 | o.__dict__['_shape_val'] = tf.TensorShape(new_shape) 289 | 290 | 291 | if pool3 is None: 292 | _init_inception() 293 | -------------------------------------------------------------------------------- /EBMs/hmc.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | 4 | from tensorflow.python.platform import flags 5 | flags.DEFINE_bool('proposal_debug', False, 'Print hmc acceptance rates') 6 | 7 | FLAGS = flags.FLAGS 8 | 9 | def kinetic_energy(velocity): 10 | """Kinetic energy of the current velocity (assuming a standard Gaussian) 11 | (v dot v) / 2, computed element-wise 12 | 13 | Parameters 14 | ---------- 15 | velocity : tf.Variable 16 | Vector of current velocity 17 | 18 | Returns 19 | ------- 20 | kinetic_energy : tf.Tensor (element-wise, same shape as velocity) 21 | """ 22 | return 0.5 * tf.square(velocity) 23 | 24 | def hamiltonian(position, velocity, energy_function): 25 | """Computes the Hamiltonian of the current position, velocity pair 26 | 27 | H = U(x) + K(v) 28 | 29 | U is the potential energy, U(x) = -log_posterior(x) 30 | 31 | Parameters 32 | ---------- 33 | position : tf.Variable 34 | Position or state vector x (sample from the target distribution) 35 | velocity : tf.Variable 36 | Auxiliary velocity variable 37 | energy_function 38 | Function mapping a position (state) to its 'energy' 39 | = -log_posterior 40 | 41 | Returns 42 | ------- 43 | hamiltonian : tf.Tensor (one value per batch element) 44 | """ 45 | batch_size = tf.shape(velocity)[0] 46 | kinetic_energy_flat = tf.reshape(kinetic_energy(velocity), (batch_size, -1))
return tf.squeeze(energy_function(position)) + tf.reduce_sum(kinetic_energy_flat, axis=[1]) 48 | 49 | def leapfrog_step(x0, 50 | v0, 51 | neg_log_posterior, 52 | step_size, 53 | num_steps): 54 | 55 | # Start by updating the velocity a half-step 56 | v = v0 - 0.5 * step_size * tf.gradients(neg_log_posterior(x0), x0)[0] 57 | 58 | # Initialize x with the first full position step 59 | x = x0 + step_size * v 60 | 61 | for i in range(num_steps): 62 | # Compute gradient of the negative log-posterior with respect to x 63 | gradient = tf.gradients(neg_log_posterior(x), x)[0] 64 | 65 | # Update velocity 66 | v = v - step_size * gradient 67 | 68 | # x_clip = tf.clip_by_value(x, 0.0, 1.0) 69 | # x = x_clip 70 | # v_mask = 1 - 2 * tf.abs(tf.sign(x - x_clip)) 71 | # v = v * v_mask 72 | 73 | # Update x 74 | x = x + step_size * v 75 | 76 | # x = tf.clip_by_value(x, -0.01, 1.01) 77 | 78 | # x = tf.Print(x, [tf.reduce_min(x), tf.reduce_max(x), tf.reduce_mean(x)]) 79 | 80 | # Do a final update of the velocity for a half step 81 | v = v - 0.5 * step_size * tf.gradients(neg_log_posterior(x), x)[0] 82 | 83 | # return new proposal state 84 | return x, v 85 | 86 | def hmc(initial_x, 87 | step_size, 88 | num_steps, 89 | neg_log_posterior): 90 | """Draws one Hamiltonian Monte Carlo sample (leapfrog proposal + Metropolis accept/reject) 91 | 92 | Parameters 93 | ---------- 94 | initial_x : tf.Variable 95 | Initial sample x ~ p 96 | step_size : float 97 | Step-size in Hamiltonian simulation 98 | num_steps : int 99 | Number of leapfrog steps to take in Hamiltonian simulation 100 | neg_log_posterior : callable 101 | Negative log posterior (unnormalized) for the target distribution 102 | 103 | Returns 104 | ------- 105 | sample : tf.Tensor 106 | Sample ~ target distribution 107 | """ 108 | 109 | v0 = tf.random_normal(tf.shape(initial_x)) 110 | x, v = leapfrog_step(initial_x, 111 | v0, 112 | step_size=step_size, 113 | num_steps=num_steps, 114 | neg_log_posterior=neg_log_posterior) 115 | 116 | orig = hamiltonian(initial_x, v0, neg_log_posterior) 117 | current = hamiltonian(x, v, neg_log_posterior) 118 | 119 | prob_accept =
tf.exp(orig - current) 120 | 121 | if FLAGS.proposal_debug: 122 | prob_accept = tf.Print(prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))]) 123 | 124 | uniform = tf.random_uniform(tf.shape(prob_accept)) 125 | keep_mask = (prob_accept > uniform) 126 | # print(keep_mask.get_shape()) 127 | 128 | x_new = tf.where(keep_mask, x, initial_x) 129 | return x_new 130 | -------------------------------------------------------------------------------- /EBMs/imagenet_demo.py: -------------------------------------------------------------------------------- 1 | from models import ResNet128 2 | import numpy as np 3 | import os.path as osp 4 | from tensorflow.python.platform import flags 5 | import tensorflow as tf 6 | import imageio 7 | from utils import optimistic_restore 8 | 9 | 10 | flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored') 11 | flags.DEFINE_integer('num_steps', 200, 'number of steps for conditional imagenet sampling') 12 | flags.DEFINE_float('step_lr', 180., 'step size for Langevin dynamics') 13 | flags.DEFINE_integer('batch_size', 16, 'number of images to sample (must be 16 for the 4x4 output grid)') 14 | flags.DEFINE_string('exp', 'default', 'name of experiments') 15 | flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from') 16 | flags.DEFINE_bool('spec_norm', True, 'whether to use spectral normalization in weights in a model') 17 | flags.DEFINE_bool('cclass', True, 'conditional models') 18 | flags.DEFINE_bool('use_attention', False, 'using attention') 19 | 20 | FLAGS = flags.FLAGS 21 | 22 | def rescale_im(im): 23 | return np.clip(im * 256, 0, 255).astype(np.uint8) 24 | 25 | 26 | if __name__ == "__main__": 27 | model = ResNet128(num_filters=64) 28 | X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32) 29 | LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32) 30 | 31 | sess = tf.InteractiveSession() 32 | weights = model.construct_weights("context_0") 33 | 34 | x_mod = X_NOISE 35 | x_mod = x_mod +
tf.random_normal(tf.shape(x_mod), 36 | mean=0.0, 37 | stddev=0.005) 38 | 39 | energy_noise = energy_start = model.forward(x_mod, weights, label=LABEL, 40 | reuse=True, stop_at_grad=False, stop_batch=True) 41 | 42 | x_grad = tf.gradients(energy_noise, [x_mod])[0] 43 | energy_noise_old = energy_noise 44 | 45 | lr = FLAGS.step_lr 46 | 47 | x_last = x_mod - (lr) * x_grad 48 | 49 | x_mod = x_last 50 | x_mod = tf.clip_by_value(x_mod, 0, 1) 51 | x_output = x_mod 52 | 53 | sess.run(tf.global_variables_initializer()) 54 | saver = loader = tf.train.Saver() 55 | 56 | logdir = osp.join(FLAGS.logdir, FLAGS.exp) 57 | model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter)) 58 | saver.restore(sess, model_file) 59 | 60 | lx = np.random.permutation(1000)[:16] 61 | ims = [] 62 | 63 | # What to initialize sampling with. 64 | x_mod = np.random.uniform(0, 1, size=(FLAGS.batch_size, 128, 128, 3)) 65 | labels = np.eye(1000)[lx] 66 | 67 | for i in range(FLAGS.num_steps): 68 | e, x_mod = sess.run([energy_noise, x_output], {X_NOISE:x_mod, LABEL:labels}) 69 | ims.append(rescale_im(x_mod).reshape((4, 4, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((512, 512, 3))) 70 | 71 | imageio.mimwrite('sample.gif', ims) 72 | 73 | 74 | -------------------------------------------------------------------------------- /EBMs/imagenet_preprocessing.py: -------------------------------------------------------------------------------- 1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # ============================================================================== 15 | 16 | """Image pre-processing utilities. 17 | """ 18 | import tensorflow as tf 19 | 20 | 21 | IMAGE_DEPTH = 3 # color images 22 | 23 | import tensorflow as tf 24 | 25 | # _R_MEAN = 123.68 26 | # _G_MEAN = 116.78 27 | # _B_MEAN = 103.94 28 | # _CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN] 29 | _CHANNEL_MEANS = [0.0, 0.0, 0.0] 30 | 31 | # The lower bound for the smallest side of the image for aspect-preserving 32 | # resizing. For example, if an image is 500 x 1000, it will be resized to 33 | # _RESIZE_MIN x (_RESIZE_MIN * 2). 34 | _RESIZE_MIN = 128 35 | 36 | 37 | def _decode_crop_and_flip(image_buffer, bbox, num_channels): 38 | """Crops the given image to a random part of the image, and randomly flips. 39 | 40 | We use the fused decode_and_crop op, which performs better than the two ops 41 | used separately in series, but note that this requires that the image be 42 | passed in as an un-decoded string Tensor. 43 | 44 | Args: 45 | image_buffer: scalar string Tensor representing the raw JPEG image buffer. 46 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 47 | where each coordinate is [0, 1) and the coordinates are arranged as 48 | [ymin, xmin, ymax, xmax]. 49 | num_channels: Integer depth of the image buffer for decoding. 50 | 51 | Returns: 52 | 3-D tensor with cropped image. 53 | 54 | """ 55 | # A large fraction of image datasets contain a human-annotated bounding box 56 | # delineating the region of the image containing the object of interest. We 57 | # choose to create a new bounding box for the object which is a randomly 58 | # distorted version of the human-annotated bounding box that obeys an 59 | # allowed range of aspect ratios, sizes and overlap with the human-annotated 60 | # bounding box. 
If no box is supplied, then we assume the bounding box is 61 | # the entire image. 62 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box( 63 | tf.image.extract_jpeg_shape(image_buffer), 64 | bounding_boxes=bbox, 65 | min_object_covered=0.1, 66 | aspect_ratio_range=[0.75, 1.33], 67 | area_range=[0.05, 1.0], 68 | max_attempts=100, 69 | use_image_if_no_bounding_boxes=True) 70 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box 71 | 72 | # Reassemble the bounding box in the format the crop op requires. 73 | offset_y, offset_x, _ = tf.unstack(bbox_begin) 74 | target_height, target_width, _ = tf.unstack(bbox_size) 75 | crop_window = tf.stack([offset_y, offset_x, target_height, target_width]) 76 | 77 | # Use the fused decode and crop op here, which is faster than each in series. 78 | cropped = tf.image.decode_and_crop_jpeg( 79 | image_buffer, crop_window, channels=num_channels) 80 | 81 | # Flip to add a little more random distortion in. 82 | cropped = tf.image.random_flip_left_right(cropped) 83 | return cropped 84 | 85 | 86 | def _central_crop(image, crop_height, crop_width): 87 | """Performs a central crop of the given image. 88 | 89 | Args: 90 | image: a 3-D image tensor 91 | crop_height: the height of the image following the crop. 92 | crop_width: the width of the image following the crop. 93 | 94 | Returns: 95 | 3-D tensor with cropped image. 96 | """ 97 | shape = tf.shape(input=image) 98 | height, width = shape[0], shape[1] 99 | 100 | amount_to_be_cropped_h = (height - crop_height) 101 | crop_top = amount_to_be_cropped_h // 2 102 | amount_to_be_cropped_w = (width - crop_width) 103 | crop_left = amount_to_be_cropped_w // 2 104 | return tf.slice( 105 | image, [crop_top, crop_left, 0], [crop_height, crop_width, -1]) 106 | 107 | 108 | def _mean_image_subtraction(image, means, num_channels): 109 | """Subtracts the given means from each image channel.
110 | 111 | For example: 112 | means = [123.68, 116.779, 103.939] 113 | image = _mean_image_subtraction(image, means, num_channels=3) 114 | 115 | Note that the rank of `image` must be known. 116 | 117 | Args: 118 | image: a tensor of size [height, width, C]. 119 | means: a C-vector of values to subtract from each channel. 120 | num_channels: number of color channels in the image. 121 | 122 | Returns: 123 | the centered image. 124 | 125 | Raises: 126 | ValueError: If the rank of `image` is unknown, if `image` has a rank other 127 | than three or if the number of channels in `image` doesn't match the 128 | number of values in `means`. 129 | """ 130 | if image.get_shape().ndims != 3: 131 | raise ValueError('Input must be of size [height, width, C>0]') 132 | 133 | if len(means) != num_channels: 134 | raise ValueError('len(means) must match the number of channels') 135 | 136 | # We have a 1-D tensor of means; convert to 3-D. 137 | means = tf.expand_dims(tf.expand_dims(means, 0), 0) 138 | 139 | return image - means 140 | 141 | 142 | def _smallest_size_at_least(height, width, resize_min): 143 | """Computes new shape with the smallest side equal to `resize_min`. 144 | 145 | Computes new shape with the smallest side equal to `resize_min` while 146 | preserving the original aspect ratio. 147 | 148 | Args: 149 | height: an int32 scalar tensor indicating the current height. 150 | width: an int32 scalar tensor indicating the current width. 151 | resize_min: A python integer or scalar `Tensor` indicating the size of 152 | the smallest side after resize. 153 | 154 | Returns: 155 | new_height: an int32 scalar tensor indicating the new height. 156 | new_width: an int32 scalar tensor indicating the new width. 157 | """ 158 | resize_min = tf.cast(resize_min, tf.float32) 159 | 160 | # Convert to floats to make subsequent calculations go smoothly.
161 | height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32) 162 | 163 | smaller_dim = tf.minimum(height, width) 164 | scale_ratio = resize_min / smaller_dim 165 | 166 | # Convert back to ints to make heights and widths that TF ops will accept. 167 | new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32) 168 | new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32) 169 | 170 | return new_height, new_width 171 | 172 | 173 | def _aspect_preserving_resize(image, resize_min): 174 | """Resize images preserving the original aspect ratio. 175 | 176 | Args: 177 | image: A 3-D image `Tensor`. 178 | resize_min: A python integer or scalar `Tensor` indicating the size of 179 | the smallest side after resize. 180 | 181 | Returns: 182 | resized_image: A 3-D tensor containing the resized image. 183 | """ 184 | shape = tf.shape(input=image) 185 | height, width = shape[0], shape[1] 186 | 187 | new_height, new_width = _smallest_size_at_least(height, width, resize_min) 188 | 189 | return _resize_image(image, new_height, new_width) 190 | 191 | 192 | def _resize_image(image, height, width): 193 | """Simple wrapper around tf.image.resize_images. 194 | 195 | This is primarily to make sure we use the same `ResizeMethod` and other 196 | details each time. 197 | 198 | Args: 199 | image: A 3-D image `Tensor`. 200 | height: The target height for the resized image. 201 | width: The target width for the resized image. 202 | 203 | Returns: 204 | resized_image: A 3-D tensor containing the resized image. The first two 205 | dimensions have the shape [height, width]. 206 | """ 207 | return tf.image.resize_images( 208 | image, [height, width], method=tf.image.ResizeMethod.BILINEAR, 209 | align_corners=False) 210 | 211 | 212 | def preprocess_image(image_buffer, bbox, output_height, output_width, 213 | num_channels, is_training=False): 214 | """Preprocesses the given image.
215 | 216 | Preprocessing includes decoding, cropping, and resizing for both training 217 | and eval images. Training preprocessing, however, introduces some random 218 | distortion of the image to improve accuracy. 219 | 220 | Args: 221 | image_buffer: scalar string Tensor representing the raw JPEG image buffer. 222 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 223 | where each coordinate is [0, 1) and the coordinates are arranged as 224 | [ymin, xmin, ymax, xmax]. 225 | output_height: The height of the image after preprocessing. 226 | output_width: The width of the image after preprocessing. 227 | num_channels: Integer depth of the image buffer for decoding. 228 | is_training: `True` if we're preprocessing the image for training and 229 | `False` otherwise. 230 | 231 | Returns: 232 | A preprocessed image. 233 | """ 234 | if is_training: 235 | # For training, we want to randomize some of the distortions. 236 | image = _decode_crop_and_flip(image_buffer, bbox, num_channels) 237 | image = _resize_image(image, output_height, output_width) 238 | else: 239 | # For validation, we want to decode, resize, then just crop the middle. 240 | image = tf.image.decode_jpeg(image_buffer, channels=num_channels) 241 | image = _aspect_preserving_resize(image, _RESIZE_MIN) 242 | print(image) 243 | image = _central_crop(image, output_height, output_width) 244 | 245 | image.set_shape([output_height, output_width, num_channels]) 246 | 247 | return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels) 248 | 249 | 250 | def parse_example_proto(example_serialized): 251 | """Parses an Example proto containing a training example of an image. 252 | 253 | The output of the build_image_data.py image preprocessing script is a dataset 254 | containing serialized Example protocol buffers. 
Each Example proto contains 255 | the following fields: 256 | 257 | image/height: 462 258 | image/width: 581 259 | image/colorspace: 'RGB' 260 | image/channels: 3 261 | image/class/label: 615 262 | image/class/synset: 'n03623198' 263 | image/class/text: 'knee pad' 264 | image/object/bbox/xmin: 0.1 265 | image/object/bbox/xmax: 0.9 266 | image/object/bbox/ymin: 0.2 267 | image/object/bbox/ymax: 0.6 268 | image/object/bbox/label: 615 269 | image/format: 'JPEG' 270 | image/filename: 'ILSVRC2012_val_00041207.JPEG' 271 | image/encoded: <JPEG encoded string> 272 | 273 | Args: 274 | example_serialized: scalar Tensor tf.string containing a serialized 275 | Example protocol buffer. 276 | 277 | Returns: 278 | image_buffer: Tensor tf.string containing the contents of a JPEG file. 279 | label: Tensor tf.int32 containing the label. 280 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords] 281 | where each coordinate is [0, 1) and the coordinates are arranged as 282 | [ymin, xmin, ymax, xmax]. 283 | text: Tensor tf.string containing the human-readable label. 284 | """ 285 | # Dense features in Example proto. 286 | feature_map = { 287 | 'image/encoded': tf.FixedLenFeature([], dtype=tf.string, 288 | default_value=''), 289 | 'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64, 290 | default_value=-1), 291 | 'image/class/text': tf.FixedLenFeature([], dtype=tf.string, 292 | default_value=''), 293 | } 294 | sparse_float32 = tf.VarLenFeature(dtype=tf.float32) 295 | # Sparse features in Example proto.
296 | feature_map.update( 297 | {k: sparse_float32 for k in ['image/object/bbox/xmin', 298 | 'image/object/bbox/ymin', 299 | 'image/object/bbox/xmax', 300 | 'image/object/bbox/ymax']}) 301 | 302 | features = tf.parse_single_example(example_serialized, feature_map) 303 | label = tf.cast(features['image/class/label'], dtype=tf.int32) 304 | 305 | xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0) 306 | ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0) 307 | xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0) 308 | ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0) 309 | 310 | # Note that we impose an ordering of (y, x) just to make life difficult. 311 | bbox = tf.concat([ymin, xmin, ymax, xmax], 0) 312 | 313 | # Force the variable number of bounding boxes into the shape 314 | # [1, num_boxes, coords]. 315 | bbox = tf.expand_dims(bbox, 0) 316 | bbox = tf.transpose(bbox, [0, 2, 1]) 317 | 318 | return features['image/encoded'], label, bbox, features['image/class/text'] 319 | 320 | 321 | class ImagenetPreprocessor: 322 | def __init__(self, image_size, dtype, train): 323 | self.image_size = image_size 324 | self.dtype = dtype 325 | self.train = train 326 | 327 | def preprocess(self, image_buffer, bbox): 328 | # pylint: disable=g-import-not-at-top 329 | image = preprocess_image(image_buffer, bbox, self.image_size, self.image_size, IMAGE_DEPTH, is_training=self.train) 330 | return tf.cast(image, self.dtype) 331 | 332 | def parse_and_preprocess(self, value): 333 | image_buffer, label_index, bbox, _ = parse_example_proto(value) 334 | image = self.preprocess(image_buffer, bbox) 335 | image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH]) 336 | return label_index, image 337 | 338 | -------------------------------------------------------------------------------- /EBMs/inception.py: -------------------------------------------------------------------------------- 1 | # Code derived from 
tensorflow/tensorflow/models/image/imagenet/classify_image.py 2 | from __future__ import absolute_import 3 | from __future__ import division 4 | from __future__ import print_function 5 | 6 | import os.path 7 | import sys 8 | import tarfile 9 | 10 | import numpy as np 11 | from six.moves import urllib 12 | import tensorflow as tf 13 | import glob 14 | import scipy.misc 15 | import math 16 | import sys 17 | 18 | import horovod.tensorflow as hvd 19 | 20 | MODEL_DIR = '/tmp/imagenet' 21 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz' 22 | softmax = None 23 | 24 | config = tf.ConfigProto() 25 | config.gpu_options.visible_device_list = str(hvd.local_rank()) 26 | sess = tf.Session(config=config) 27 | 28 | # Call this function with a list of images. Each element should be a 29 | # numpy array with values ranging from 0 to 255. 30 | def get_inception_score(images, splits=10): 31 | # For convenience 32 | if len(images[0].shape) != 3: 33 | return 0, 0 34 | 35 | # Bypassing all the assertions so that we don't end prematurely 36 | # assert(type(images) == list) 37 | # assert(type(images[0]) == np.ndarray) 38 | # assert(len(images[0].shape) == 3) 39 | # assert(np.max(images[0]) > 10) 40 | # assert(np.min(images[0]) >= 0.0) 41 | inps = [] 42 | for img in images: 43 | img = img.astype(np.float32) 44 | inps.append(np.expand_dims(img, 0)) 45 | bs = 1 46 | preds = [] 47 | n_batches = int(math.ceil(float(len(inps)) / float(bs))) 48 | for i in range(n_batches): 49 | sys.stdout.write(".") 50 | sys.stdout.flush() 51 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))] 52 | inp = np.concatenate(inp, 0) 53 | pred = sess.run(softmax, {'ExpandDims:0': inp}) 54 | preds.append(pred) 55 | preds = np.concatenate(preds, 0) 56 | scores = [] 57 | for i in range(splits): 58 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :] 59 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0))) 60 | kl =
np.mean(np.sum(kl, 1)) 61 | scores.append(np.exp(kl)) 62 | return np.mean(scores), np.std(scores) 63 | 64 | # This function is called automatically when the module is imported. 65 | def _init_inception(): 66 | global softmax 67 | if not os.path.exists(MODEL_DIR): 68 | os.makedirs(MODEL_DIR) 69 | filename = DATA_URL.split('/')[-1] 70 | filepath = os.path.join(MODEL_DIR, filename) 71 | if not os.path.exists(filepath): 72 | def _progress(count, block_size, total_size): 73 | sys.stdout.write('\r>> Downloading %s %.1f%%' % ( 74 | filename, float(count * block_size) / float(total_size) * 100.0)) 75 | sys.stdout.flush() 76 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress) 77 | print() 78 | statinfo = os.stat(filepath) 79 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') 80 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR) 81 | with tf.gfile.FastGFile(os.path.join( 82 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f: 83 | graph_def = tf.GraphDef() 84 | graph_def.ParseFromString(f.read()) 85 | _ = tf.import_graph_def(graph_def, name='') 86 | # Works with an arbitrary minibatch size.
87 | pool3 = sess.graph.get_tensor_by_name('pool_3:0') 88 | ops = pool3.graph.get_operations() 89 | for op_idx, op in enumerate(ops): 90 | for o in op.outputs: 91 | shape = o.get_shape() 92 | shape = [s.value for s in shape] 93 | new_shape = [] 94 | for j, s in enumerate(shape): 95 | if s == 1 and j == 0: 96 | new_shape.append(None) 97 | else: 98 | new_shape.append(s) 99 | o.set_shape(tf.TensorShape(new_shape)) 100 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1] 101 | logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w) 102 | softmax = tf.nn.softmax(logits) 103 | 104 | if softmax is None: 105 | _init_inception() 106 | -------------------------------------------------------------------------------- /EBMs/models.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | from tensorflow.python.platform import flags 3 | import numpy as np 4 | from utils import conv_block, get_weight, attention, conv_cond_concat, init_conv_weight, init_attention_weight, init_res_weight, smart_res_block, smart_res_block_optim, init_convt_weight 5 | from utils import init_fc_weight, smart_conv_block, smart_fc_block, smart_atten_block, groupsort, smart_convt_block, swish 6 | 7 | flags.DEFINE_bool('swish_act', False, 'use the swish activation for dsprites') 8 | 9 | FLAGS = flags.FLAGS 10 | 11 | 12 | class MnistNet(object): 13 | def __init__(self, num_channels=1, num_filters=64): 14 | 15 | self.channels = num_channels 16 | self.dim_hidden = num_filters 17 | self.datasource = FLAGS.datasource 18 | 19 | if FLAGS.cclass: 20 | self.label_size = 10 21 | else: 22 | self.label_size = 0 23 | 24 | def construct_weights(self, scope=''): 25 | weights = {} 26 | 27 | dtype = tf.float32 28 | conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype) 29 | fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype) 30 | 31 | classes = 1 32 | 33 | with tf.variable_scope(scope): 34 | 
init_conv_weight(weights, 'c1_pre', 3, 1, 64) 35 | init_conv_weight(weights, 'c1', 4, 64, self.dim_hidden, classes=classes) 36 | init_conv_weight(weights, 'c2', 4, self.dim_hidden, 2*self.dim_hidden, classes=classes) 37 | init_conv_weight(weights, 'c3', 4, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes) 38 | init_fc_weight(weights, 'fc_dense', 4*4*4*self.dim_hidden, 2*self.dim_hidden, spec_norm=True) 39 | init_fc_weight(weights, 'fc5', 2*self.dim_hidden, 1, spec_norm=False) 40 | 41 | if FLAGS.cclass: 42 | self.label_size = 10 43 | else: 44 | self.label_size = 0 45 | return weights 46 | 47 | def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, **kwargs): 48 | channels = self.channels 49 | weights = weights.copy() 50 | inp = tf.reshape(inp, (tf.shape(inp)[0], 28, 28, 1)) 51 | 52 | if FLAGS.swish_act: 53 | act = swish 54 | else: 55 | act = tf.nn.leaky_relu 56 | 57 | if stop_grad: 58 | for k, v in weights.items(): 59 | if type(v) == dict: 60 | v = v.copy() 61 | weights[k] = v 62 | for k_sub, v_sub in v.items(): 63 | v[k_sub] = tf.stop_gradient(v_sub) 64 | else: 65 | weights[k] = tf.stop_gradient(v) 66 | 67 | if FLAGS.cclass: 68 | label_d = tf.reshape(label, shape=(tf.shape(label)[0], 1, 1, self.label_size)) 69 | inp = conv_cond_concat(inp, label_d) 70 | 71 | h1 = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act) 72 | h2 = smart_conv_block(h1, weights, reuse, 'c1', use_stride=True, downsample=True, label=label, extra_bias=False, activation=act) 73 | h3 = smart_conv_block(h2, weights, reuse, 'c2', use_stride=True, downsample=True, label=label, extra_bias=False, activation=act) 74 | h4 = smart_conv_block(h3, weights, reuse, 'c3', use_stride=True, downsample=True, label=label, use_scale=False, extra_bias=False, activation=act) 75 | 76 | h5 = tf.reshape(h4, [-1, np.prod([int(dim) for dim in h4.get_shape()[1:]])]) 77 | h6 = act(smart_fc_block(h5, weights, reuse, 'fc_dense')) 78 | hidden6 = 
smart_fc_block(h6, weights, reuse, 'fc5') 79 | 80 | return hidden6 81 | 82 | 83 | class DspritesNet(object): 84 | def __init__(self, num_channels=1, num_filters=64, cond_size=False, cond_shape=False, cond_pos=False, 85 | cond_rot=False, label_size=1): 86 | 87 | self.channels = num_channels 88 | self.dim_hidden = num_filters 89 | self.img_size = 64 90 | self.label_size = label_size 91 | 92 | if FLAGS.cclass: 93 | self.label_size = 3 94 | 95 | try: 96 | if FLAGS.dshape_only: 97 | self.label_size = 3 98 | 99 | if FLAGS.dpos_only: 100 | self.label_size = 2 101 | 102 | if FLAGS.dsize_only: 103 | self.label_size = 1 104 | 105 | if FLAGS.drot_only: 106 | self.label_size = 2 107 | except: 108 | pass 109 | 110 | if cond_size: 111 | self.label_size = 1 112 | 113 | if cond_shape: 114 | self.label_size = 3 115 | 116 | if cond_pos: 117 | self.label_size = 2 118 | 119 | if cond_rot: 120 | self.label_size = 2 121 | 122 | self.cond_size = cond_size 123 | self.cond_shape = cond_shape 124 | self.cond_pos = cond_pos 125 | 126 | def construct_weights(self, scope=''): 127 | weights = {} 128 | 129 | dtype = tf.float32 130 | conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype) 131 | fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype) 132 | k = 5 133 | classes = self.label_size 134 | 135 | with tf.variable_scope(scope): 136 | init_conv_weight(weights, 'c1_pre', 3, 1, 32) 137 | init_conv_weight(weights, 'c1', 4, 32, self.dim_hidden, classes=classes) 138 | init_conv_weight(weights, 'c2', 4, self.dim_hidden, 2*self.dim_hidden, classes=classes) 139 | init_conv_weight(weights, 'c3', 4, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) 140 | init_conv_weight(weights, 'c4', 4, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) 141 | init_fc_weight(weights, 'fc_dense', 2*4*4*self.dim_hidden, 2*self.dim_hidden, spec_norm=True) 142 | init_fc_weight(weights, 'fc5', 2*self.dim_hidden, 1, spec_norm=False) 143 | 144 | return weights 145 | 146 | def 
forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False, return_logit=False): 147 | channels = self.channels 148 | batch_size = tf.shape(inp)[0] 149 | 150 | inp = tf.reshape(inp, (batch_size, 64, 64, 1)) 151 | 152 | if FLAGS.swish_act: 153 | act = swish 154 | else: 155 | act = tf.nn.leaky_relu 156 | 157 | if not FLAGS.cclass: 158 | label = None 159 | 160 | weights = weights.copy() 161 | 162 | if stop_grad: 163 | for k, v in weights.items(): 164 | if type(v) == dict: 165 | v = v.copy() 166 | weights[k] = v 167 | for k_sub, v_sub in v.items(): 168 | v[k_sub] = tf.stop_gradient(v_sub) 169 | else: 170 | weights[k] = tf.stop_gradient(v) 171 | 172 | h1 = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act) 173 | h2 = smart_conv_block(h1, weights, reuse, 'c1', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act) 174 | h3 = smart_conv_block(h2, weights, reuse, 'c2', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act) 175 | h4 = smart_conv_block(h3, weights, reuse, 'c3', use_stride=True, downsample=True, label=label, use_scale=True, extra_bias=True, activation=act) 176 | h5 = smart_conv_block(h4, weights, reuse, 'c4', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act) 177 | 178 | hidden6 = tf.reshape(h5, (tf.shape(h5)[0], -1)) 179 | hidden7 = act(smart_fc_block(hidden6, weights, reuse, 'fc_dense')) 180 | energy = smart_fc_block(hidden7, weights, reuse, 'fc5') 181 | 182 | if return_logit: 183 | return hidden7 184 | else: 185 | return energy 186 | 187 | 188 | 189 | class ResNet32(object): 190 | def __init__(self, num_channels=3, num_filters=128): 191 | 192 | self.channels = num_channels 193 | self.dim_hidden = num_filters 194 | self.groupsort = groupsort() 195 | 196 | def construct_weights(self, scope=''): 197 | weights = {} 198 | dtype = tf.float32 199 | 200 | if FLAGS.cclass: 201 | classes = 10 202 | 
else: 203 | classes = 1 204 | 205 | with tf.variable_scope(scope): 206 | # First block 207 | init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden) 208 | init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes) 209 | init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes) 210 | init_res_weight(weights, 'res_2', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes) 211 | init_res_weight(weights, 'res_3', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) 212 | init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) 213 | init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) 214 | init_fc_weight(weights, 'fc_dense', 4*4*2*self.dim_hidden, 4*self.dim_hidden) 215 | init_fc_weight(weights, 'fc5', 2*self.dim_hidden , 1, spec_norm=False) 216 | 217 | init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden / 2, trainable_gamma=True) 218 | 219 | return weights 220 | 221 | def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False): 222 | weights = weights.copy() 223 | batch = tf.shape(inp)[0] 224 | 225 | act = tf.nn.leaky_relu 226 | 227 | if not FLAGS.cclass: 228 | label = None 229 | 230 | if stop_grad: 231 | for k, v in weights.items(): 232 | if type(v) == dict: 233 | v = v.copy() 234 | weights[k] = v 235 | for k_sub, v_sub in v.items(): 236 | v[k_sub] = tf.stop_gradient(v_sub) 237 | else: 238 | weights[k] = tf.stop_gradient(v) 239 | 240 | # Make sure gradients are modified a bit 241 | inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False) 242 | 243 | hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label, act=act) 244 | hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, act=act) 245 | hidden3 = 
smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, label=label, act=act) 246 | 247 | if FLAGS.use_attention: 248 | hidden4 = smart_atten_block(hidden3, weights, reuse, 'atten', stop_at_grad=stop_at_grad, label=label) 249 | else: 250 | hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, act=act) 251 | 252 | hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', stop_batch=stop_batch, adaptive=False, label=label, act=act) 253 | compact = hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) 254 | hidden6 = tf.nn.relu(hidden6) 255 | hidden5 = tf.reduce_sum(hidden6, [1, 2]) 256 | 257 | hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5') 258 | 259 | energy = hidden6 260 | 261 | return energy 262 | 263 | 264 | class ResNet32Large(object): 265 | def __init__(self, num_channels=3, num_filters=128, train=False): 266 | 267 | self.channels = num_channels 268 | self.dim_hidden = num_filters 269 | self.dropout = train 270 | self.train = train 271 | 272 | def construct_weights(self, scope=''): 273 | weights = {} 274 | dtype = tf.float32 275 | 276 | if FLAGS.cclass: 277 | classes = 10 278 | else: 279 | classes = 1 280 | 281 | with tf.variable_scope(scope): 282 | # First block 283 | init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden) 284 | init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes) 285 | init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes) 286 | init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes) 287 | init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes) 288 | init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) 289 | init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 
2*self.dim_hidden, classes=classes) 290 | init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes) 291 | init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) 292 | init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) 293 | init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False) 294 | 295 | init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden, trainable_gamma=True) 296 | 297 | return weights 298 | 299 | def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False): 300 | weights = weights.copy() 301 | batch = tf.shape(inp)[0] 302 | 303 | if not FLAGS.cclass: 304 | label = None 305 | 306 | if stop_grad: 307 | for k, v in weights.items(): 308 | if type(v) == dict: 309 | v = v.copy() 310 | weights[k] = v 311 | for k_sub, v_sub in v.items(): 312 | v[k_sub] = tf.stop_gradient(v_sub) 313 | else: 314 | weights[k] = tf.stop_gradient(v) 315 | 316 | # Make sure gradients are modified a bit 317 | inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False) 318 | 319 | dropout = self.dropout 320 | train = self.train 321 | 322 | hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label, dropout=dropout, train=train) 323 | hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train) 324 | hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train) 325 | hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label, dropout=dropout, train=train) 326 | 327 | if FLAGS.use_attention: 328 | hidden5 = smart_atten_block(hidden4, weights, reuse, 'atten', stop_at_grad=stop_at_grad) 329 | else: 330 | hidden5 = 
smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train) 331 | 332 | hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train) 333 | 334 | hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label, dropout=dropout, train=train) 335 | hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train) 336 | 337 | compact = hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train) 338 | 339 | if FLAGS.cclass: 340 | hidden6 = tf.nn.leaky_relu(hidden9) 341 | else: 342 | hidden6 = tf.nn.relu(hidden9) 343 | hidden5 = tf.reduce_sum(hidden6, [1, 2]) 344 | 345 | hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5') 346 | 347 | energy = hidden6 348 | 349 | return energy 350 | 351 | 352 | class ResNet32Wider(object): 353 | def __init__(self, num_channels=3, num_filters=128, train=False): 354 | 355 | self.channels = num_channels 356 | self.dim_hidden = num_filters 357 | self.dropout = train 358 | self.train = train 359 | 360 | def construct_weights(self, scope=''): 361 | weights = {} 362 | dtype = tf.float32 363 | 364 | if FLAGS.cclass and FLAGS.dataset == "cifar10": 365 | classes = 10 366 | elif FLAGS.cclass and FLAGS.dataset == "imagenet": 367 | classes = 1000 368 | else: 369 | classes = 1 370 | 371 | with tf.variable_scope(scope): 372 | # First block 373 | init_conv_weight(weights, 'c1_pre', 3, self.channels, 128) 374 | init_res_weight(weights, 'res_optim', 3, 128, self.dim_hidden, classes=classes) 375 | init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes) 376 | init_res_weight(weights, 'res_2', 3, self.dim_hidden, 
self.dim_hidden, classes=classes) 377 | init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes) 378 | init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) 379 | init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) 380 | init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes) 381 | init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) 382 | init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) 383 | init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False) 384 | 385 | init_attention_weight(weights, 'atten', self.dim_hidden, self.dim_hidden / 2, trainable_gamma=True) 386 | 387 | return weights 388 | 389 | def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False): 390 | weights = weights.copy() 391 | batch = tf.shape(inp)[0] 392 | 393 | if not FLAGS.cclass: 394 | label = None 395 | 396 | if stop_grad: 397 | for k, v in weights.items(): 398 | if type(v) == dict: 399 | v = v.copy() 400 | weights[k] = v 401 | for k_sub, v_sub in v.items(): 402 | v[k_sub] = tf.stop_gradient(v_sub) 403 | else: 404 | weights[k] = tf.stop_gradient(v) 405 | 406 | if FLAGS.swish_act: 407 | act = swish 408 | else: 409 | act = tf.nn.leaky_relu 410 | 411 | # Make sure gradients are modified a bit 412 | inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act) 413 | dropout = self.dropout 414 | train = self.train 415 | 416 | hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=True, label=label, dropout=dropout, train=train) 417 | 418 | if FLAGS.use_attention: 419 | hidden2 = smart_atten_block(hidden1, weights, reuse, 'atten', train=train, dropout=dropout, stop_at_grad=stop_at_grad) 420 | else: 421 | hidden2 = smart_res_block(hidden1, weights, reuse, 
'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train, act=act) 422 | 423 | hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train, act=act) 424 | hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act) 425 | 426 | hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act) 427 | 428 | hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act) 429 | 430 | hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act) 431 | hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act) 432 | 433 | hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act) 434 | 435 | if FLAGS.swish_act: 436 | hidden6 = act(hidden9) 437 | else: 438 | hidden6 = tf.nn.relu(hidden9) 439 | 440 | hidden5 = tf.reduce_sum(hidden6, [1, 2]) 441 | hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5') 442 | energy = hidden6 443 | 444 | return energy 445 | 446 | 447 | class ResNet32Larger(object): 448 | def __init__(self, num_channels=3, num_filters=128): 449 | 450 | self.channels = num_channels 451 | self.dim_hidden = num_filters 452 | 453 | def construct_weights(self, scope=''): 454 | weights = {} 455 | dtype = tf.float32 456 | 457 | if FLAGS.cclass: 458 | classes = 10 459 | else: 460 | classes = 1 461 | 462 | with tf.variable_scope(scope): 463 | # First block 
464 | init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden) 465 | init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes) 466 | init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes) 467 | init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes) 468 | init_res_weight(weights, 'res_2a', 3, self.dim_hidden, self.dim_hidden, classes=classes) 469 | init_res_weight(weights, 'res_2b', 3, self.dim_hidden, self.dim_hidden, classes=classes) 470 | init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes) 471 | init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) 472 | init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) 473 | init_res_weight(weights, 'res_5a', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) 474 | init_res_weight(weights, 'res_5b', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes) 475 | init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes) 476 | init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) 477 | init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) 478 | init_res_weight(weights, 'res_8a', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) 479 | init_res_weight(weights, 'res_8b', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes) 480 | init_fc_weight(weights, 'fc_dense', 4*4*2*self.dim_hidden, 4*self.dim_hidden) 481 | init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False) 482 | 483 | init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden / 2, trainable_gamma=True) 484 | 485 | return weights 486 | 487 | def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False): 488 | weights = 
weights.copy() 489 | batch = tf.shape(inp)[0] 490 | 491 | if not FLAGS.cclass: 492 | label = None 493 | 494 | if stop_grad: 495 | for k, v in weights.items(): 496 | if type(v) == dict: 497 | v = v.copy() 498 | weights[k] = v 499 | for k_sub, v_sub in v.items(): 500 | v[k_sub] = tf.stop_gradient(v_sub) 501 | else: 502 | weights[k] = tf.stop_gradient(v) 503 | 504 | # Make sure gradients are modified a bit 505 | inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False) 506 | 507 | hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label) 508 | hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label) 509 | hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label) 510 | hidden3 = smart_res_block(hidden3, weights, reuse, 'res_2a', stop_batch=stop_batch, downsample=False, adaptive=False, label=label) 511 | hidden3 = smart_res_block(hidden3, weights, reuse, 'res_2b', stop_batch=stop_batch, downsample=False, adaptive=False, label=label) 512 | hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label) 513 | 514 | if FLAGS.use_attention: 515 | hidden5 = smart_atten_block(hidden4, weights, reuse, 'atten', stop_at_grad=stop_at_grad) 516 | else: 517 | hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) 518 | 519 | hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) 520 | 521 | hidden6 = smart_res_block(hidden6, weights, reuse, 'res_5a', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) 522 | hidden6 = smart_res_block(hidden6, weights, reuse, 'res_5b', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) 523 | hidden7 = smart_res_block(hidden6, weights, reuse, 
'res_6', stop_batch=stop_batch, label=label) 524 | hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) 525 | hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) 526 | hidden9 = smart_res_block(hidden9, weights, reuse, 'res_8a', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) 527 | compact = hidden9 = smart_res_block(hidden9, weights, reuse, 'res_8b', adaptive=False, downsample=False, stop_batch=stop_batch, label=label) 528 | 529 | if FLAGS.cclass: 530 | hidden6 = tf.nn.leaky_relu(hidden9) 531 | else: 532 | hidden6 = tf.nn.relu(hidden9) 533 | hidden5 = tf.reduce_sum(hidden6, [1, 2]) 534 | 535 | hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5') 536 | 537 | energy = hidden6 538 | 539 | return energy 540 | 541 | 542 | class ResNet128(object): 543 | """ResNet energy model for 128x128 inputs (used for the imagenetfull dataset)""" 544 | 545 | def __init__(self, num_channels=3, num_filters=64, train=False): 546 | 547 | self.channels = num_channels 548 | self.dim_hidden = num_filters 549 | self.dropout = train 550 | self.train = train 551 | 552 | def construct_weights(self, scope=''): 553 | weights = {} 554 | dtype = tf.float32 555 | 556 | classes = 1000 557 | 558 | with tf.variable_scope(scope): 559 | # First block 560 | init_conv_weight(weights, 'c1_pre', 3, self.channels, 64) 561 | init_res_weight(weights, 'res_optim', 3, 64, self.dim_hidden, classes=classes) 562 | init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes) 563 | init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes) 564 | init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 8*self.dim_hidden, classes=classes) 565 | init_res_weight(weights, 'res_9', 3, 8*self.dim_hidden, 8*self.dim_hidden, classes=classes) 566 | init_res_weight(weights, 'res_10', 3, 8*self.dim_hidden,
8*self.dim_hidden, classes=classes) 567 | init_fc_weight(weights, 'fc5', 8*self.dim_hidden , 1, spec_norm=False) 568 | 569 | 570 | init_attention_weight(weights, 'atten', self.dim_hidden, self.dim_hidden / 2., trainable_gamma=True) 571 | 572 | return weights 573 | 574 | def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False): 575 | weights = weights.copy() 576 | batch = tf.shape(inp)[0] 577 | 578 | if not FLAGS.cclass: 579 | label = None 580 | 581 | 582 | if stop_grad: 583 | for k, v in weights.items(): 584 | if type(v) == dict: 585 | v = v.copy() 586 | weights[k] = v 587 | for k_sub, v_sub in v.items(): 588 | v[k_sub] = tf.stop_gradient(v_sub) 589 | else: 590 | weights[k] = tf.stop_gradient(v) 591 | 592 | if FLAGS.swish_act: 593 | act = swish 594 | else: 595 | act = tf.nn.leaky_relu 596 | 597 | dropout = self.dropout 598 | train = self.train 599 | 600 | # Make sure gradients are modified a bit 601 | inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act) 602 | hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', label=label, dropout=dropout, train=train, downsample=True, adaptive=False) 603 | 604 | if FLAGS.use_attention: 605 | hidden1 = smart_atten_block(hidden1, weights, reuse, 'atten', stop_at_grad=stop_at_grad) 606 | 607 | hidden2 = smart_res_block(hidden1, weights, reuse, 'res_3', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act) 608 | hidden3 = smart_res_block(hidden2, weights, reuse, 'res_5', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act) 609 | hidden4 = smart_res_block(hidden3, weights, reuse, 'res_7', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=True, adaptive=True) 610 | hidden5 = smart_res_block(hidden4, weights, reuse, 'res_9', stop_batch=stop_batch, label=label, dropout=dropout, train=train, 
act=act, downsample=True, adaptive=False) 611 | hidden6 = smart_res_block(hidden5, weights, reuse, 'res_10', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=False, adaptive=False) 612 | 613 | if FLAGS.swish_act: 614 | hidden6 = act(hidden6) 615 | else: 616 | hidden6 = tf.nn.relu(hidden6) 617 | 618 | hidden5 = tf.reduce_sum(hidden6, [1, 2]) 619 | hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5') 620 | energy = hidden6 621 | 622 | return energy 623 | -------------------------------------------------------------------------------- /EBMs/requirements.txt: -------------------------------------------------------------------------------- 1 | scipy==1.10.0 2 | horovod==0.24.0 3 | torch==1.13.1 4 | torchvision==0.14.1 5 | six==1.11.0 6 | imageio==2.8.0 7 | tqdm==4.46.0 8 | matplotlib==3.2.1 9 | mpi4py==3.0.3 10 | numpy==1.22.0 11 | Pillow==10.0.1 12 | baselines==0.1.5 13 | scikit-image==0.14.2 14 | scikit_learn 15 | tensorflow==2.11.1 16 | cloudpickle==1.3.0 17 | Cython==0.29.17 18 | mujoco-py==1.50.1.68 19 | -------------------------------------------------------------------------------- /EBMs/test_inception.py: -------------------------------------------------------------------------------- 1 | import tensorflow as tf 2 | import numpy as np 3 | from tensorflow.python.platform import flags 4 | from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128 5 | import os.path as osp 6 | import os 7 | from utils import optimistic_restore, remap_restore, optimistic_remap_restore 8 | from tqdm import tqdm 9 | import random 10 | from imageio import imwrite as imsave # scipy.misc.imsave was removed in SciPy 1.2; imageio is already in requirements.txt 11 | from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, TFImagenetLoader 12 | from torch.utils.data import DataLoader 13 | from baselines.common.tf_util import initialize 14 | 15 | import horovod.tensorflow as hvd 16 | hvd.init() 17 | 18 | from inception import get_inception_score 19 | from fid import get_fid_score 20 |
flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored') 22 | flags.DEFINE_string('exp', 'default', 'name of experiment') 23 | flags.DEFINE_bool('cclass', False, 'whether to condition on class') 24 | 25 | # Architecture settings 26 | flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not') 27 | flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights') 28 | flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution') 29 | flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network') 30 | flags.DEFINE_float('step_lr', 10.0, 'Size of steps for gradient descent') 31 | flags.DEFINE_integer('num_steps', 20, 'number of Langevin steps used to generate each sample') 32 | flags.DEFINE_float('proj_norm', 0.05, 'element-wise clip value for the input gradient at each Langevin step (0 disables clipping)') 33 | flags.DEFINE_integer('batch_size', 512, 'batch size') 34 | flags.DEFINE_integer('resume_iter', -1, 'resume iteration') 35 | flags.DEFINE_integer('ensemble', 10, 'number of model checkpoints to ensemble') 36 | flags.DEFINE_integer('im_number', 50000, 'number of images to generate') 37 | flags.DEFINE_integer('repeat_scale', 100, 'number of repeat iterations') 38 | flags.DEFINE_float('noise_scale', 0.005, 'standard deviation of the noise added at each Langevin step') 39 | flags.DEFINE_integer('idx', 0, 'save index') 40 | flags.DEFINE_integer('nomix', 10, 'number of intervals to stop mixing') 41 | flags.DEFINE_bool('scaled', True, 'whether to scale noise added') 42 | flags.DEFINE_bool('large_model', False, 'whether to use a small or large model') 43 | flags.DEFINE_bool('larger_model', False, 'Whether to use the larger model') 44 | flags.DEFINE_bool('wider_model', False, 'Whether to use the wider model') 45 | flags.DEFINE_bool('single', False, 'single ') 46 | flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single') 47 | flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or imagenet or imagenetfull') 48 | 49 | FLAGS = flags.FLAGS 50 | 51 | class
InceptionReplayBuffer(object): 52 | def __init__(self, size): 53 | """Create a replay buffer of (image, label) pairs. 54 | Parameters 55 | ---------- 56 | size: int 57 | Max number of images to store in the buffer. When the buffer 58 | overflows the oldest images are overwritten. 59 | """ 60 | self._storage = [] 61 | self._label_storage = [] 62 | self._maxsize = size 63 | self._next_idx = 0 64 | 65 | def __len__(self): 66 | return len(self._storage) 67 | 68 | def add(self, ims, labels): 69 | batch_size = ims.shape[0] 70 | if self._next_idx >= len(self._storage): 71 | self._storage.extend(list(ims)) 72 | self._label_storage.extend(list(labels)) 73 | else: 74 | if batch_size + self._next_idx < self._maxsize: 75 | self._storage[self._next_idx:self._next_idx+batch_size] = list(ims) 76 | self._label_storage[self._next_idx:self._next_idx+batch_size] = list(labels) 77 | else: 78 | split_idx = self._maxsize - self._next_idx 79 | self._storage[self._next_idx:] = list(ims)[:split_idx] 80 | self._storage[:batch_size-split_idx] = list(ims)[split_idx:] 81 | self._label_storage[self._next_idx:] = list(labels)[:split_idx] 82 | self._label_storage[:batch_size-split_idx] = list(labels)[split_idx:] 83 | 84 | self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize 85 | 86 | def _encode_sample(self, idxes): 87 | ims = [] 88 | labels = [] 89 | for i in idxes: 90 | ims.append(self._storage[i]) 91 | labels.append(self._label_storage[i]) 92 | return np.array(ims), np.array(labels) 93 | 94 | def sample(self, batch_size): 95 | """Sample a batch of images and labels. 96 | Parameters 97 | ---------- 98 | batch_size: int 99 | How many images to sample.
100 | Returns 101 | ------- 102 | ims: np.array 103 | batch of images drawn uniformly at random 104 | (with replacement) from the buffer 105 | labels: np.array 106 | one-hot labels matching ims 107 | idxes: list of int 108 | buffer positions of the sampled images; pass them 109 | to set_elms to write refined images back in place 110 | 111 | The return value is structured as 112 | ((ims, labels), idxes). 113 | """ 114 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)] 115 | return self._encode_sample(idxes), idxes 116 | 117 | def set_elms(self, idxes, data, labels): 118 | for i, ix in enumerate(idxes): 119 | self._storage[ix] = data[i] 120 | self._label_storage[ix] = labels[i] 121 | 122 | 123 | def rescale_im(im): 124 | return np.clip(im * 256, 0, 255).astype(np.uint8) 125 | 126 | def compute_inception(sess, target_vars): 127 | X_START = target_vars['X_START'] 128 | Y_GT = target_vars['Y_GT'] 129 | X_finals = target_vars['X_finals'] 130 | NOISE_SCALE = target_vars['NOISE_SCALE'] 131 | energy_noise = target_vars['energy_noise'] 132 | 133 | size = FLAGS.im_number 134 | num_steps = size // 1000 135 | 136 | images = [] 137 | test_ims = [] 138 | 139 | 140 | if FLAGS.dataset == "cifar10": 141 | test_dataset = Cifar10(full=True, noise=False) 142 | elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull": 143 | test_dataset = Imagenet(train=False) 144 | 145 | if FLAGS.dataset != "imagenetfull": 146 | test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=4, shuffle=True, drop_last=False) 147 | else: 148 | test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1) 149 | 150 | for data_corrupt, data, label_gt in tqdm(test_dataloader): 151 | data = data.numpy() 152 | test_ims.extend(list(rescale_im(data))) 153 | 154 | if FLAGS.dataset == "imagenetfull" and len(test_ims) > 60000: 155 | test_ims = test_ims[:60000] 156 | break 157 | 158 | 159 | #
n = min(len(images), len(test_ims)) 160 | print(len(test_ims)) 161 | # fid = get_fid_score(test_ims[:30000], test_ims[-30000:]) 162 | # print("Base FID of score {}".format(fid)) 163 | 164 | if FLAGS.dataset == "cifar10": 165 | classes = 10 166 | else: 167 | classes = 1000 168 | 169 | if FLAGS.dataset == "imagenetfull": 170 | n = 128 171 | else: 172 | n = 32 173 | 174 | for j in range(num_steps): 175 | itr = int(1000 / 500 * FLAGS.repeat_scale) 176 | data_buffer = InceptionReplayBuffer(1000) 177 | curr_index = 0 178 | 179 | identity = np.eye(classes) 180 | 181 | for i in tqdm(range(itr)): 182 | model_index = curr_index % len(X_finals) 183 | x_final = X_finals[model_index] 184 | 185 | noise_scale = [1] 186 | if len(data_buffer) < 1000: 187 | x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3)) 188 | label = np.random.randint(0, classes, (FLAGS.batch_size)) 189 | label = identity[label] 190 | x_new = sess.run([x_final], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})[0] 191 | data_buffer.add(x_new, label) 192 | else: 193 | (x_init, label), idx = data_buffer.sample(FLAGS.batch_size) 194 | keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99) 195 | label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9) 196 | label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size)) 197 | label_corrupt = identity[label_corrupt] 198 | x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3)) 199 | 200 | if i < itr - FLAGS.nomix: 201 | x_init[keep_mask] = x_init_corrupt[keep_mask] 202 | label[label_keep_mask] = label_corrupt[label_keep_mask] 203 | # else: 204 | # noise_scale = [0.7] 205 | 206 | x_new, e_noise = sess.run([x_final, energy_noise], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale}) 207 | data_buffer.set_elms(idx, x_new, label) 208 | 209 | if FLAGS.im_number != 50000: 210 | print(np.mean(e_noise), np.std(e_noise)) 211 | 212 | curr_index += 1 213 | 214 | ims = np.array(data_buffer._storage[:1000]) 215 | ims = 
rescale_im(ims) 216 | 217 | images.extend(list(ims)) 218 | 219 | saveim = osp.join('sandbox_cachedir', FLAGS.exp, "test{}.png".format(FLAGS.idx)) 220 | 221 | ims = ims[:100] 222 | 223 | if FLAGS.dataset != "imagenetfull": 224 | im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((320, 320, 3)) 225 | else: 226 | im_panel = ims.reshape((10, 10, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((1280, 1280, 3)) 227 | imsave(saveim, im_panel) 228 | 229 | print("Saved image!!!!") 230 | splits = max(1, len(images) // 5000) 231 | score, std = get_inception_score(images, splits=splits) 232 | print("Inception score of {} with std of {}".format(score, std)) 233 | 234 | # FID score 235 | # n = min(len(images), len(test_ims)) 236 | fid = get_fid_score(images, test_ims) 237 | print("FID score of {}".format(fid)) 238 | 239 | 240 | 241 | 242 | def main(model_list): 243 | 244 | if FLAGS.dataset == "imagenetfull": 245 | model = ResNet128(num_filters=64) 246 | elif FLAGS.large_model: 247 | model = ResNet32Large(num_filters=128) 248 | elif FLAGS.larger_model: 249 | model = ResNet32Larger(num_filters=128) # hidden_dim was undefined here; 128 is the class default 250 | elif FLAGS.wider_model: 251 | model = ResNet32Wider(num_filters=256, train=False) 252 | else: 253 | model = ResNet32(num_filters=128) 254 | 255 | # config = tf.ConfigProto() 256 | sess = tf.InteractiveSession() 257 | 258 | logdir = osp.join(FLAGS.logdir, FLAGS.exp) 259 | weights = [] 260 | 261 | for i, model_num in enumerate(model_list): 262 | weight = model.construct_weights('context_{}'.format(i)) 263 | initialize() 264 | save_file = osp.join(logdir, 'model_{}'.format(model_num)) 265 | 266 | v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(i)) 267 | v_map = {(v.name.replace('context_{}'.format(i), 'context_0')[:-2]): v for v in v_list} 268 | saver = tf.train.Saver(v_map) 269 | try: 270 | saver.restore(sess, save_file) 271 | except Exception: 272 | optimistic_remap_restore(sess, save_file, i) 273 |
weights.append(weight) 274 | 275 | 276 | if FLAGS.dataset == "imagenetfull": 277 | X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32) 278 | else: 279 | X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype=tf.float32) 280 | 281 | if FLAGS.dataset == "cifar10": 282 | Y_GT = tf.placeholder(shape=(None, 10), dtype=tf.float32) 283 | else: 284 | Y_GT = tf.placeholder(shape=(None, 1000), dtype=tf.float32) 285 | 286 | NOISE_SCALE = tf.placeholder(shape=1, dtype=tf.float32) 287 | 288 | X_finals = [] 289 | 290 | 291 | # Separate Langevin sampling loop for each model in the ensemble 292 | for weight in weights: 293 | X = X_START 294 | 295 | steps = tf.constant(0) 296 | c = lambda i, x: tf.less(i, FLAGS.num_steps) 297 | def langevin_step(counter, X): 298 | scale_rate = 1 299 | 300 | X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE) 301 | 302 | energy_noise = model.forward(X, weight, label=Y_GT, reuse=True) 303 | x_grad = tf.gradients(energy_noise, [X])[0] 304 | 305 | if FLAGS.proj_norm != 0.0: 306 | x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm) 307 | 308 | X = X - FLAGS.step_lr * x_grad * scale_rate 309 | X = tf.clip_by_value(X, 0, 1) 310 | 311 | counter = counter + 1 312 | 313 | return counter, X 314 | 315 | steps, X = tf.while_loop(c, langevin_step, (steps, X)) 316 | energy_noise = model.forward(X, weight, label=Y_GT, reuse=True) 317 | X_final = X 318 | X_finals.append(X_final) 319 | 320 | target_vars = {} 321 | target_vars['X_START'] = X_START 322 | target_vars['Y_GT'] = Y_GT 323 | target_vars['X_finals'] = X_finals 324 | target_vars['NOISE_SCALE'] = NOISE_SCALE 325 | target_vars['energy_noise'] = energy_noise 326 | 327 | compute_inception(sess, target_vars) 328 | 329 | 330 | if __name__ == "__main__": 331 | # model_list = [117000, 116700] 332 | model_list = [FLAGS.resume_iter - 300*i for i in range(FLAGS.ensemble)] 333 | main(model_list) 334 |
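The `langevin_step` loop in `main()` is where samples are drawn from the energy model. Each step adds Gaussian noise and takes the gradient of the energy with respect to the image. It then clips that gradient element-wise to `[-proj_norm, proj_norm]`, steps downhill by `step_lr`, and clips pixels back to `[0, 1]`.

Below is a minimal NumPy sketch of the same update. It assumes a hypothetical toy quadratic energy (`toy_energy_grad`) in place of `model.forward` plus `tf.gradients`. The defaults mirror the flag defaults. It omits the `NOISE_SCALE` multiplier, which `compute_inception` fixes at 1.

```python
import numpy as np


def toy_energy_grad(x, mu=0.3):
    """Gradient of a hypothetical toy energy E(x) = 0.5 * sum((x - mu) ** 2).

    Stand-in for tf.gradients(model.forward(X, ...), [X]) in test_inception.py.
    """
    return x - mu


def langevin_sample(x, grad_fn, num_steps=20, step_lr=10.0,
                    noise_scale=0.005, proj_norm=0.05, rng=None):
    """Noisy gradient descent on the energy, following the langevin_step loop."""
    rng = np.random.default_rng() if rng is None else rng
    for _ in range(num_steps):
        # 1. perturb the current samples with Gaussian noise
        x = x + rng.normal(0.0, noise_scale, size=x.shape)
        # 2. gradient of the energy with respect to the input
        grad = grad_fn(x)
        # 3. element-wise gradient clipping (skipped when proj_norm == 0)
        if proj_norm != 0.0:
            grad = np.clip(grad, -proj_norm, proj_norm)
        # 4. move downhill in energy and keep pixel values in [0, 1]
        x = np.clip(x - step_lr * grad, 0.0, 1.0)
    return x


rng = np.random.default_rng(0)
# uniform init, like x_init in compute_inception (batch of 4 CIFAR-sized images)
x0 = rng.uniform(0.0, 1.0, size=(4, 32, 32, 3))
samples = langevin_sample(x0, toy_energy_grad, rng=rng)
print(samples.shape, samples.min() >= 0.0, samples.max() <= 1.0)
```

With `step_lr=10` and `proj_norm=0.05`, no element can move more than 0.5 per step, so the gradient clip is what keeps the update stable. On this toy energy, however, the default step overshoots the minimum and oscillates around it. To watch it converge, pass a smaller step such as `step_lr=0.5, proj_norm=0.0`.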
-------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## resources and experiments on autonomous agents 2 | 3 |
    4 | 5 | * **[⬛ ai && ml tl; dr](deep_learning)** 6 | * **[⬛ large language models](llms)** 7 | * **[⬛ agents on blockchains](crypto_agents)** 8 | * **[⬛ on quantum computing](EBMs)** (my adaptation of openai's implicit generation and generalization in energy based models) 9 | 10 |
    11 | 12 | ### cool resources 13 | 14 |
    15 | 16 | * **[mr. vp jd vance at the ai action summit in paris (2025)](https://www.youtube.com/watch?v=MnKsxnP2IVk)** 17 | -------------------------------------------------------------------------------- /books/advances_in_financial_machine_learning.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/autistic-symposium/ml-ai-agents-py/fdba2218a052e2a23dd7c0d7d2d029168377f11a/books/advances_in_financial_machine_learning.pdf -------------------------------------------------------------------------------- /crypto_agents/README.md: -------------------------------------------------------------------------------- 1 | ## crypto agents 2 | 3 |
    4 | 5 | * **[basic strategy workflow](strategy_workflow)** 6 | 7 |
    8 | 9 | --- 10 | 11 | ### cool resources 12 | 13 |
    14 | 15 | 16 | ##### projects 17 | 18 | * **[ritual.net](https://ritual.net/)** 19 | * **[sahara labs](https://saharalabs.ai/)** 20 | * **[multi-agent orchestrator](https://github.com/awslabs/multi-agent-orchestrator)** 21 | * **[swarms](https://github.com/kyegomez/swarms)** 22 | * **[hugging face](https://huggingface.co/)** 23 | * **[io.net](https://io.net/)** 24 | * **[exo](https://github.com/exo-explore/exo)** 25 | 26 |
    27 | 28 | ##### readings 29 | 30 | * **[the internet's notary public: why verifiability matters, by axal](https://axal.substack.com/p/the-internets-notary-public-why-verifiability)** 31 | * **[crypto's role in the ai revolution, by pantera](https://panteracapital.com/blockchain-letter/cryptos-role-in-the-ai-revolution/)** 32 | * **[the promise and challenges of crypto + ai applications, by vub](https://vitalik.eth.limo/general/2024/01/30/cryptoai.html)** 33 | * **[on training defi agents with markov chains, by bt3gl](https://mirror.xyz/go-outside.eth/DKaWYobU7q3EvZw8x01J7uEmF_E8PfNN27j0VgxQhNQ)** 34 | 35 |
    36 | 37 | ##### metas 38 | 39 | * **eth denver 2025 ai + agentic metas**: 40 | * **[a new dawn for decentralized and the convergence of ai x blockchain, by n. ameline](https://www.youtube.com/watch?v=HQuGtN9zidQ)** 41 | * **[upgrading agent infra with onchain perpetual agents, by kd conway](https://www.youtube.com/watch?v=hDayCeDA5fI)** 42 | * **[ai meets blockchain: payment, multi-agent & distributed inference, by o. jaros](https://www.youtube.com/watch?v=aPTosw4hrY0)** 43 | * **[from ai agents to agentic economies, by r. bodkin](https://www.youtube.com/watch?v=Q7eaYJ9aPpI)** 44 | * **[building ai agents and agent economies, by d. minarsch](https://www.youtube.com/watch?v=tDVK2Q5RY0c)** 45 | * **[building ai agents on top of hedera, by j. hall](https://www.youtube.com/watch?v=h8D6vi2m8LQ)** 46 | * **[identity, privacy, and security in the new age of ai agents, by m. csernai](https://www.youtube.com/watch?v=vMcaot04RQo)** 47 | * **[verifiable ai agents and world superintelligence, by k. wong](https://www.youtube.com/watch?v=ngkp7HTj_4A)** 48 | * **[how web3 can compete in ai, by g. narula](https://www.youtube.com/watch?v=oLLM1I-3fDU)** 49 | * **[shade agents, by m. lockyer](https://www.youtube.com/watch?v=PEfJnCtrbMU)** 50 | * **[how ai will enable the next generation of intent, by i. yang](https://www.youtube.com/watch?v=fbc3DpI6jiA)** 51 | * **[2025 is the year of agents, by i. polosukhin](https://www.youtube.com/watch?v=jPyzVNcQMKw)** 52 | 53 | -------------------------------------------------------------------------------- /crypto_agents/strategy_workflow/README.md: -------------------------------------------------------------------------------- 1 | ## strategy workflow 2 | 3 |
    4 | 5 |

    6 | 7 |

    8 | 9 |
    10 | 11 | 1. **[data analysis](data_analysis.md)** 12 | 2. **[supervised model training](supervised_learning.md)** 13 | 3. **[policy development](policy.md)** 14 | 4. **[backtesting](backtesting.md)** 15 | 5. **[parameter optimization](optimization.md)** 16 | 6. **[simulation and paper trading](paper_trading.md)** 17 | 7. **[live trading](live_trading.md)** 18 | 8. **[strategy metrics](strategy_metrics.md)** 19 | -------------------------------------------------------------------------------- /crypto_agents/strategy_workflow/backtesting.md: -------------------------------------------------------------------------------- 1 | ## strategy backtesting 2 | 3 |
    4 | 5 | * use a simulator to test an initial version of the strategy against a set of historical data. 6 | * the simulator can take into account factors such as order book liquidity, network latency, fees, etc. 7 | -------------------------------------------------------------------------------- /crypto_agents/strategy_workflow/data_analysis.md: -------------------------------------------------------------------------------- 1 | ## data analysis 2 | 3 |
    4 | 5 | * perform exploratory data analysis to find trading opportunities, such as looking at charts, calculating statistics, etc. 6 | -------------------------------------------------------------------------------- /crypto_agents/strategy_workflow/defi_glossary.md: -------------------------------------------------------------------------------- 1 | ## DeFi and MEV Glossary 2 | 3 |
    4 | 5 | 6 | ### A 7 | 8 | - Arbitrage: the simultaneous buying and selling of assets (e.g., cryptocurrencies) in several markets to take advantage of their price discrepancies. 9 | - Assets under management (AUM): the total market value of the investments that a person or entity manages on behalf of clients. 10 | 11 |
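Whether an arbitrage is worth taking depends on whether the price discrepancy survives trading costs. Below is a minimal sketch with a hypothetical `arbitrage_profit` helper. It assumes a flat proportional taker fee on both legs and instant fills at the quoted prices, and it ignores slippage, gas, and transfer costs.

```python
def arbitrage_profit(buy_price, sell_price, size, fee_rate):
    """Net profit of buying `size` units on one venue and selling on another.

    Assumes the same proportional fee on both legs and fills at quoted prices.
    """
    cost = size * buy_price * (1 + fee_rate)      # paid on the cheaper venue
    revenue = size * sell_price * (1 - fee_rate)  # received on the pricier venue
    return revenue - cost


print(arbitrage_profit(100.0, 101.0, 10, 0.001))  # 1% spread covers 0.1% fees
print(arbitrage_profit(100.0, 100.1, 10, 0.001))  # 0.1% spread is eaten by fees
```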
    12 | 13 | 14 | ### B 15 | 16 | - Backrunning: when an attacker attempts to have a transaction ordered immediately after a certain unconfirmed target transaction. 17 | - Blocks: a block contains transaction data and the hash of the previous block, ensuring immutability in the blockchain network. Each block in a blockchain contains a list of transactions in a particular order. These transactions encode the updates to the blockchain state. 18 | - Block time: the time interval between blocks being added to the blockchain. 19 | - Broadcasting: whenever a user interacts with the blockchain, they broadcast a request to the network to include their transaction. This request is public (anyone can listen to it). 20 | - Builders: actors that take bundles (of pending transactions from the mempool) and create a final block to send to (multiple) relays, setting themselves as the feeRecipient to receive the block’s MEV. 21 | - Bundles: one or more transactions that are grouped together and executed in the order they are provided. In addition to the searcher's transaction(s), a bundle can also contain other users' pending transactions from the mempool. Bundles can target specific blocks for inclusion as well. 22 | 23 |
    24 | 25 | ### C 26 | 27 | - Central limit order book (CLOB): patient buyers and sellers post limit orders with the price and size at which they are willing to buy or sell a given asset. Impatient buyers and sellers place market orders that run through the CLOB until the desired size is reached. 28 | - Contract address: the address hosting some source code deployed on the Ethereum blockchain, which is executed by a triggering transaction. 29 | - Crypto copy trading strategy: a trading strategy that uses automation to buy and sell crypto, letting you copy another trader's method. 30 | 31 |
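A market order "running through" a CLOB consumes the best-priced resting limit orders first and keeps going until the requested size is filled. Below is a minimal sketch of a market buy with a hypothetical `market_buy` helper. It assumes the ask side is a list of `(price, size)` levels sorted from the lowest price up, and it ignores fees and hidden or iceberg orders.

```python
def market_buy(asks, quantity):
    """Fill a market buy against ask levels sorted from best (lowest) price up.

    Returns (filled_quantity, average_price, remaining_asks).
    """
    filled = 0.0
    cost = 0.0
    remaining = []
    for i, (price, size) in enumerate(asks):
        if filled >= quantity:
            remaining.extend(asks[i:])  # untouched levels stay in the book
            break
        take = min(size, quantity - filled)
        filled += take
        cost += take * price
        if take < size:
            remaining.append((price, size - take))  # partially consumed level
    avg_price = cost / filled if filled else None
    return filled, avg_price, remaining


book = [(100.0, 1.0), (101.0, 2.0), (102.0, 5.0)]
print(market_buy(book, 2.5))  # walks the 100 and 101 levels
```

The average fill price (about 100.6 here) is worse than the best ask. That gap is the price impact a patient limit-order trader avoids by waiting instead.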
    32 | 33 | ### D 34 | 35 | - Derivatives: financial contracts that derive their values from underlying assets. 36 | - Dollar-cost-averaging (DCA) strategy: an automated strategy that buys a fixed amount of an asset at regular time intervals, reducing the influence of market volatility on the average entry price. Parameters for DCA can include the currency, the fixed or maximum investment amount, and the investment frequency. 37 | 38 |
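Investing a fixed amount each period buys more units when the price is low and fewer when it is high. That is why DCA dampens the effect of volatility on the entry price. Below is a minimal sketch with a hypothetical `dca` helper that ignores fees.

```python
def dca(prices, amount_per_period):
    """Invest a fixed amount at each price; return (units, invested, avg_cost)."""
    units = sum(amount_per_period / p for p in prices)
    invested = amount_per_period * len(prices)
    return units, invested, invested / units


units, invested, avg_cost = dca([10.0, 5.0, 10.0], 100.0)
print(units, invested, avg_cost)  # 40.0 300.0 7.5
```

The resulting average cost (7.5) is below the simple mean of the prices (about 8.33). The DCA cost basis is the harmonic mean of the prices, which never exceeds their arithmetic mean.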
    39 | 40 | ### E 41 | 42 | - Epoch: in the context of Ethereum's block production, in each slot (every 12 seconds), a validator is randomly chosen to propose the block in that slot. An epoch contains 32 slots (6.4 minutes). 43 | - Externally owned account (EOA): an account that is a combination of public address and private key, and that can be used to send and receive Ether to/from another account. An Ethereum address is a 42-character hexadecimal string derived from the last 20 bytes of the Keccak-256 hash of the account's public key (with 0x prepended). 44 | 45 |
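The slot and epoch numbers follow directly from the two constants in the entry above (12-second slots, 32 slots per epoch). Below is a minimal sketch using hypothetical helper names. The chain's genesis timestamp is taken as a parameter rather than hardcoded, so plug in the real network's value in practice.

```python
SECONDS_PER_SLOT = 12  # from the glossary entry above
SLOTS_PER_EPOCH = 32


def slot_at(timestamp, genesis_time):
    """Slot number active at a Unix timestamp, given the chain's genesis time."""
    return (timestamp - genesis_time) // SECONDS_PER_SLOT


def epoch_of(slot):
    """Epoch containing a given slot."""
    return slot // SLOTS_PER_EPOCH


def epoch_start_slot(epoch):
    """First slot of a given epoch."""
    return epoch * SLOTS_PER_EPOCH


genesis = 0  # hypothetical genesis timestamp for illustration
print(epoch_of(slot_at(10_000, genesis)))    # slot 833 falls in epoch 26
print(SECONDS_PER_SLOT * SLOTS_PER_EPOCH)    # 384 seconds = 6.4 minutes
```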
    46 | 47 | ### F 48 | 49 | - Frontrunning: the process by which an adversary observes pending transactions on the network layer and acts on this information to obtain profit. 50 | - Fully diluted valuation (FDV): the total number of tokens multiplied by the current price of a single token. 51 | - Futures: contracts used as proxy tools to speculate on the future prices of crypto assets or to hedge against their price changes. 52 | - Futures grid trading bots: bots that automate futures trading activities based on grid trading strategies (a set of orders is placed both above and below a specific reference market price for the asset). 53 | ​ 54 |
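A minimal sketch of how the order prices of a grid could be laid out, assuming an arithmetic grid (a fixed price step between levels) with the same number of levels on each side of the reference price. `grid_levels` and its parameters are hypothetical; real grid bots may use geometric (percentage) spacing instead.

```python
def grid_levels(reference_price, step, levels_per_side):
    """Return (buy_prices, sell_prices) for an arithmetic grid around a reference price."""
    if step <= 0 or levels_per_side <= 0:
        raise ValueError("step and levels_per_side must be positive")
    buys = [reference_price - step * i for i in range(1, levels_per_side + 1)]   # below the reference
    sells = [reference_price + step * i for i in range(1, levels_per_side + 1)]  # above the reference
    if buys[-1] <= 0:
        raise ValueError("grid extends to non-positive prices")
    return buys, sells


buys, sells = grid_levels(100.0, 2.5, 3)
print(buys)   # [97.5, 95.0, 92.5]
print(sells)  # [102.5, 105.0, 107.5]
```

In a typical setup, when a buy level fills, a sell order is placed one step above it (and vice versa), so each completed round trip captures roughly one grid step, which is the "buy low, sell high" goal in mechanical form.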
    55 | 56 | ### G 57 | 58 | - Gas price: used somewhat like a bid, indicating an amount the user is willing to pay (per unit of gas) to have their transaction processed. 59 | - Gwei: a small unit of the Ethereum network's Ether (ETH) cryptocurrency. A gwei or gigawei is defined as 1,000,000,000 wei, the smallest base unit of Ether. 1 ETH therefore equals 1 billion gwei (10^18 wei). 60 | - Grid trading strategy: a strategy that involves placing orders above and below a set price, using a price grid of orders (which shows orders at incrementally increasing and decreasing prices). Grid trading is based on the overarching goal of buying low and selling high. 61 | ​ 62 |
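A quick worked example of the unit conversion, assuming the fee is simply gas used times a gas price quoted in gwei (under EIP-1559 that price is the base fee plus the priority fee). Doing the arithmetic in integer wei avoids floating-point rounding. 21,000 gas is the cost of a plain ETH transfer; the 30 gwei price is an arbitrary example value.

```python
from decimal import Decimal

WEI_PER_GWEI = 10**9
WEI_PER_ETH = 10**18  # so 1 ETH = 10**9 gwei


def fee_in_eth(gas_used: int, gas_price_gwei: int) -> Decimal:
    """Fee = gas used x price per unit of gas, computed in wei and converted to ETH."""
    fee_wei = gas_used * gas_price_gwei * WEI_PER_GWEI  # exact integer arithmetic
    return Decimal(fee_wei) / Decimal(WEI_PER_ETH)


print(fee_in_eth(21_000, 30))  # 0.00063
```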
    63 | 64 | ### H 65 | 66 | - Hedging: taking an offsetting position (often a short position) to reduce exposure to adverse price moves in an existing position. 67 | ​ 68 |
    69 | 70 | ### K 71 | 72 | - Keys: blockchain account keys can be either private keys (for digital signatures), or public keys (for addresses). 73 | ​ 74 |
    75 | 76 | ### L 77 | 78 | - Limit orders: when one longs or shorts a contract, several execution options are available (usually with different fees). Limit orders are set to trade at a specific price, and there is no guarantee that the trade will be executed (see market orders and stop-loss orders). 79 | - Liquidity pools: a collection of crypto assets that can be used for decentralized trading. They are essential for automated market makers (AMM), borrow-lend protocols, yield farming, synthetic assets, on-chain insurance, blockchain gaming, etc. 80 | - Liquidation threshold: the percentage of a collateral's value that counts toward covering a loan; once the debt exceeds the collateral value weighted by this percentage, the position becomes eligible for liquidation. 81 | - Liquidation: happens when the value of the borrowed assets exceeds the collateral value weighted by the liquidation threshold. Anyone can then liquidate the collateral and collect the liquidation fee for themselves. 82 | - Long: traders maintain long positions, which means that they expect the price of a coin to rise in the future. 83 | ​ 84 |
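A sketch of the liquidation check, assuming an Aave-style health factor (collateral value times liquidation threshold, divided by debt value) where a position with a health factor below 1 can be liquidated. The helper names and numbers are hypothetical; real protocols price each asset with oracles and apply per-asset thresholds.

```python
def health_factor(collateral_value, liquidation_threshold, debt_value):
    """Collateral value weighted by the liquidation threshold, divided by the debt value."""
    if debt_value == 0:
        return float("inf")  # nothing borrowed, so nothing to liquidate
    return collateral_value * liquidation_threshold / debt_value


def is_liquidatable(collateral_value, liquidation_threshold, debt_value):
    return health_factor(collateral_value, liquidation_threshold, debt_value) < 1.0


# $10,000 of collateral with an 80% liquidation threshold covers up to $8,000 of debt
print(is_liquidatable(10_000, 0.80, 7_500))  # False
print(is_liquidatable(10_000, 0.80, 8_500))  # True
```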
    85 | 86 | ### M 87 | 88 | - Fully diluted market capitalization: the total token supply, multiplied by the price of a single token. 89 | - Circulating supply market capitalization: the number of tokens that are available in the market, multiplied by the price of a single token. 90 | - Margin trading: buying or selling assets with leverage. 91 | - Marginal seller: the seller who would be the first to leave the market if prices were any lower. 92 | - Market orders: Market orders are executed immediately at the asset's market price (see limit orders). 93 | - Mean reversion strategy: a trading range (or mean reversion) strategy is based on the idea that an asset's price highs and lows are temporary and that the price tends to revert to its mean (average) value. 94 | - Mempool: a cryptocurrency node’s mechanism for storing information on unconfirmed transactions. 95 | - Merkle tree: a type of binary tree, composed of: 1) a large number of leaf nodes at the bottom, containing the underlying data, 2) a set of intermediate nodes where each node is the hash of its two children, and 3) a single root node, also formed from the hash of its two children, representing the top of the tree. 96 | - Minting: the process of validating information, creating a new block, and recording that information into the blockchain. 97 | ​ 98 |
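A minimal sketch of computing a Merkle root from a list of leaves. It assumes SHA-256 as the hash and duplicates the last node when a level has an odd number of nodes (one common convention); actual blockchains differ in these details (Ethereum, for instance, uses Keccak-256 and Merkle Patricia tries).

```python
import hashlib
from typing import List


def sha256(data: bytes) -> bytes:
    """Single SHA-256 digest (an assumption; Bitcoin uses double SHA-256)."""
    return hashlib.sha256(data).digest()


def merkle_root(leaves: List[bytes]) -> bytes:
    """Hash each leaf, then hash pairs of children level by level until one root remains."""
    if not leaves:
        raise ValueError("need at least one leaf")
    level = [sha256(leaf) for leaf in leaves]
    while len(level) > 1:
        if len(level) % 2 == 1:
            level.append(level[-1])  # odd level: pair the last node with itself
        level = [sha256(level[i] + level[i + 1]) for i in range(0, len(level), 2)]
    return level[0]


root = merkle_root([b"tx1", b"tx2", b"tx3"])
print(root.hex())
```

Changing any single leaf changes the root, which is why a block only needs to store the root to commit to all of its transactions.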
    99 | 100 | ### P 101 | 102 | - Perpetual contract: a contract without an expiration date, where funding rates (periodic payments between longs and shorts that keep the contract price close to the spot price) can be calculated by methods such as Time-Weighted-Average-Price (TWAP). 103 | - Priority gas auctions: bots compete against each other by bidding up transaction fees (gas) to extract revenue from arbitrage opportunities, driving up user fees. 104 | - Private key: a secret number enabling a blockchain user to prove ownership of an account or contract, via a digital signature. 105 | - Public key: a number derived from the private key by a one-way function (elliptic-curve multiplication on Ethereum), used to verify a digital signature made with the matching private key. 106 | - Provider: an entity that provides an abstraction for a connection to the blockchain network. 107 | - POFPs: private order flow protocols. 108 | ​ 109 |
    110 | 111 | ### O 112 | 113 | - Order flow: in the context of Ethereum and EVM-based blockchains, an order is anything that allows changing the state of the blockchain. 114 | - Open interest: total number of futures contracts held by market participants at the end of the trading day. Used as an indicator to determine market sentiment and the strength behind price trends. 115 | ​ 116 |
    117 | 118 | ### R 119 | 120 | - RPC endpoints: blockchain nodes that expose RPC (remote procedure call) endpoints, which clients use to read blockchain data and submit transactions. 121 | ​ 122 |
    123 | 124 | ### S 125 | 126 | - Slots: in the context of Ethereum's block production, a slot is a time period of 12 seconds in which a randomly chosen validator has time to propose a block. 127 | - Smart contracts: a computer protocol intended to enforce a contract on the blockchain without third parties. They are reliant upon code (the functions) and data (the state), and they can trigger specific actions, such as transferring tokens from A to B. 128 | - Sandwich attack: when the slippage tolerance is not set (or is set too loosely), an attacker can buy right before the victim's trade to bump the price of the asset to an unfavorable level, let the victim's trade execute at that worse price, and then sell right after, returning the asset to roughly its original price. 129 | - Slippage: the difference between the expected price of an order when it is placed and the price at which it is actually executed. 130 | - Short: traders maintain short positions, which means they expect the price of a coin to drop in the future. 131 | - Short squeeze: occurs when a heavily shorted stock experiences an increase in price for some unexpected reason. This situation prompts short sellers to scramble to buy the stock to cover their positions and cap their mounting losses. 132 | - Spot trading: buying or selling assets for immediate delivery. 133 | - Statistical trading: the class of strategies that aim to profit from pricing inefficiencies among financial markets. Statistical arbitrage is a strategy to obtain profit by applying past statistics. 134 | - Stop-loss orders: this type of order execution places a market/limit order to close a position once the price reaches a specified stop price, limiting an investor's loss on a crypto asset. 135 | ​ 136 |
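A small sketch of how a slippage tolerance protects a swap, which is what limits a sandwich attack: the trade specifies a minimum output and should revert if execution would give less. `min_amount_out` and `slippage` are hypothetical helpers, and the numbers are illustrative.

```python
def min_amount_out(expected_out: float, tolerance: float) -> float:
    """Smallest output the swap should accept; below this the transaction should revert."""
    if not 0 <= tolerance < 1:
        raise ValueError("tolerance must be in [0, 1)")
    return expected_out * (1 - tolerance)


def slippage(expected_price: float, executed_price: float) -> float:
    """Relative difference between the expected price and the executed price."""
    return (executed_price - expected_price) / expected_price


# Expecting 1,000 tokens with a 0.5% tolerance: accept no fewer than about 995 tokens
print(min_amount_out(1_000, 0.005))
print(slippage(100.0, 101.0))  # a buy filled about 1% above the expected price
```

A tighter tolerance caps how far an attacker can move the price before the victim's trade reverts, at the cost of more failed transactions in volatile markets.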
    137 | 138 | ### T 139 | 140 | - Total value locked (TVL): the value of all tokens locked in various DeFi protocols such as lending platforms, DEXes, or derivatives protocols. 141 | - Trading volume: the total amount of traded cryptocurrency (usually expressed in US dollars) during a given timeframe. 142 | - Transaction: on EVM-based blockchains, the two types of transactions are normal transactions and contract interactions. 143 | - Transaction hash: a unique 66-character identifier (0x followed by 64 hexadecimal characters) generated with each new transaction. 144 | - Transaction ordering: blockchains usually have loose requirements for how transactions are ordered within a block, allowing attacks that benefit from certain ordering. 145 | - Time-weighted average price strategy: TWAP strategy breaks up a large order and releases dynamically determined smaller chunks of the order to the market, using evenly divided time slots between a start and end time. 146 | ​ 147 |
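A simplified, static sketch of a TWAP schedule: the parent order is split into equal child orders spread evenly between the start and end time. Real TWAP execution may add the dynamic element mentioned above (for example randomized sizes or timing); `twap_schedule` is a hypothetical helper, and times are in seconds.

```python
def twap_schedule(total_qty, start, end, n_slices):
    """Return (time, qty) child orders: equal sizes at evenly spaced times in [start, end)."""
    if n_slices <= 0 or end <= start:
        raise ValueError("need n_slices > 0 and end > start")
    interval = (end - start) / n_slices
    child_qty = total_qty / n_slices
    return [(start + i * interval, child_qty) for i in range(n_slices)]


# Sell 10 BTC over one hour in 4 slices
for when, qty in twap_schedule(10.0, 0, 3600, 4):
    print(f"t={when:>6.0f}s  qty={qty}")
```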
    148 | 149 | ### V 150 | 151 | - Validation: a mathematical proof that the state change in the blockchain is consistent. To be included in a block in the blockchain, a list of transactions needs to be validated. 152 | - VTRPs: validator transaction reordering protocols. 153 | - Volume-weighted average price strategy: VWAP breaks up a large order and releases dynamically determined smaller chunks of the order to the market, using historical volume profiles. 154 | 155 |
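A minimal sketch of the VWAP idea: allocate the parent order across time buckets in proportion to a historical volume profile, and measure execution against the volume-weighted average price. The helper names and the volume profile are hypothetical.

```python
def vwap_schedule(total_qty, volume_profile):
    """Split total_qty across time buckets in proportion to historical volume per bucket."""
    total_volume = sum(volume_profile)
    if total_volume <= 0:
        raise ValueError("volume profile must have positive total volume")
    return [total_qty * v / total_volume for v in volume_profile]


def vwap(prices, volumes):
    """Volume-weighted average price, the benchmark a VWAP execution tries to match."""
    return sum(p * v for p, v in zip(prices, volumes)) / sum(volumes)


# Hypothetical historical volume per hour: more is traded in the middle of the window
print(vwap_schedule(1_000, [10, 30, 40, 20]))  # [100.0, 300.0, 400.0, 200.0]
print(vwap([100.0, 102.0], [1, 3]))  # 101.5
```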
    156 | 157 | ### W 158 | 159 | - Whales: individuals or institutions who hold large amounts of coins of a certain cryptocurrency, and can become powerful enough to manipulate the valuation. 160 | -------------------------------------------------------------------------------- /crypto_agents/strategy_workflow/live_trading.md: -------------------------------------------------------------------------------- 1 | ## live trading 2 | 3 |
    4 | 5 | * the strategy is now running live on an exchange. 6 | -------------------------------------------------------------------------------- /crypto_agents/strategy_workflow/optimization.md: -------------------------------------------------------------------------------- 1 | ## parameter optimization 2 | 3 |
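a toy sketch of the grid search described in this section. the strategy (`backtest_ma_crossover`, a moving-average crossover that holds one unit or nothing) and the synthetic price series are hypothetical stand-ins for the real simulator and historical data. parameters are tuned on the first part of the series only, and the chosen pair is then scored on a held-out part, which is one simple way to keep an eye on overfitting.

```python
import itertools
import math


def backtest_ma_crossover(prices, fast, slow):
    """Hypothetical toy backtest: hold 1 unit while the fast MA is above the slow MA.

    The position for step t is decided from prices before t only (no lookahead);
    the PnL is the sum of price changes while holding, and fees are ignored.
    """
    pnl = 0.0
    for t in range(slow, len(prices)):
        fast_ma = sum(prices[t - fast:t]) / fast
        slow_ma = sum(prices[t - slow:t]) / slow
        position = 1 if fast_ma > slow_ma else 0
        pnl += position * (prices[t] - prices[t - 1])
    return pnl


def grid_search(prices, fast_values, slow_values):
    """Score every (fast, slow) pair with fast < slow; return the best pair and all scores."""
    scores = {}
    for fast, slow in itertools.product(fast_values, slow_values):
        if fast >= slow:
            continue  # the fast window must be shorter than the slow one
        scores[(fast, slow)] = backtest_ma_crossover(prices, fast, slow)
    best = max(scores, key=scores.get)
    return best, scores


# synthetic prices (an oscillation plus a slow trend), purely for illustration
prices = [100 + 10 * math.sin(i / 10) + 0.05 * i for i in range(400)]
split = 300
best, scores = grid_search(prices[:split], [2, 3, 5], [5, 10, 20])
print("best params on the tuning set:", best)
print("pnl on the held-out set:", backtest_ma_crossover(prices[split:], *best))
```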
    4 | 5 | * perform a search, for example grid search, over possible values of strategy parameters like thresholds or coefficients (using the simulator and a set of historical data) 6 | * overfitting to historical data is a big risk (be careful with validation and test sets). 7 | -------------------------------------------------------------------------------- /crypto_agents/strategy_workflow/paper_trading.md: -------------------------------------------------------------------------------- 1 | ## paper trading 2 | 3 |
    4 | 5 | * before the strategy goes live, it is simulated on new market data in real time (paper trading), which helps detect overfitting to historical data 6 | -------------------------------------------------------------------------------- /crypto_agents/strategy_workflow/policy.md: -------------------------------------------------------------------------------- 1 | ## policy development 2 | 3 |
    4 | 5 | * come up with a rule-based policy that determines what actions to take based on the current state of the market and the outputs of supervised models. 6 | -------------------------------------------------------------------------------- /crypto_agents/strategy_workflow/strategy_metrics.md: -------------------------------------------------------------------------------- 1 | ## trading strategy metrics 2 | 3 |
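a minimal sketch of three of the metrics listed in this section. assumptions not in the original: returns are simple per-period fractions, the sharpe ratio is annualized with 365 periods per year (daily data on a 24/7 crypto market), and value at risk uses the historical method (an empirical quantile of past returns) rather than a parametric model.

```python
import math
import statistics


def sharpe_ratio(returns, risk_free_rate=0.0, periods_per_year=365):
    """Mean excess return divided by its standard deviation, annualized by sqrt(periods)."""
    excess = [r - risk_free_rate for r in returns]  # risk_free_rate is per period
    return statistics.mean(excess) / statistics.stdev(excess) * math.sqrt(periods_per_year)


def max_drawdown(equity):
    """Largest peak-to-trough decline of an equity curve, as a fraction of the peak."""
    peak = equity[0]
    worst = 0.0
    for value in equity:
        peak = max(peak, value)  # running maximum so far
        worst = max(worst, (peak - value) / peak)
    return worst


def historical_var(returns, confidence=0.95):
    """Loss (as a positive fraction) not exceeded in `confidence` of past periods."""
    ordered = sorted(returns)
    index = int((1 - confidence) * len(ordered))  # empirical lower quantile
    return -ordered[index]


print(max_drawdown([100, 120, 90, 130, 110]))  # 0.25: the drop from 120 to 90
```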
    4 | 5 | * **net pnl (net profit and loss):** how much money an algorithm makes (positive) or loses (negative) over some period, minus trading fees 6 | * **alpha and beta:** beta measures how sensitive the strategy's returns are to the overall market (or a benchmark), and alpha is the return left over after accounting for that market exposure 7 | * **sharpe ratio:** the excess return per unit of risk you are taking (the mean return in excess of the risk-free rate divided by the standard deviation of returns; the higher the better). 8 | * **maximum drawdown:** the maximum difference between a local maximum and a subsequent local minimum, as another measure of risk. 9 | * **value at risk (var):** how much capital you may lose over a given time frame with some probability, assuming normal market conditions. 10 | -------------------------------------------------------------------------------- /crypto_agents/strategy_workflow/supervised_learning.md: -------------------------------------------------------------------------------- 1 | ## supervised learning 2 | 3 |
    4 | 5 | * train one or more supervised learning models to predict quantities of interest that are necessary for the strategy to work, for example, price prediction, quantity prediction, etc. 6 | -------------------------------------------------------------------------------- /crypto_agents/trading_on_gmx.md: -------------------------------------------------------------------------------- 1 | ## basics on trading, illustrated on gmx 2 | 3 |
    4 | 5 | #### price chart 6 | 7 | 8 | * the current price is the price of the most recent trade. 9 | * it varies depending on whether that trade was a buy or a sell. 10 | * high volume means the price movement is more reliable (consensus of a large number of market participants). 11 | * a candlestick chart shows the open/start (o), high (h), low (l), and close/end (c) prices for a given time window. 12 | 13 |
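a minimal sketch of how candles are built from raw trades: each trade is assigned to a fixed time window, and each window keeps its open, high, low, and close prices plus the traded volume. the `(timestamp, price, size)` trade format and `to_candles` are hypothetical.

```python
def to_candles(trades, window):
    """Aggregate (timestamp, price, size) trades into OHLCV candles keyed by window start."""
    candles = {}
    for ts, price, size in sorted(trades, key=lambda trade: trade[0]):  # stable: ties keep input order
        start = ts - ts % window  # start time of the window this trade falls into
        if start not in candles:
            candles[start] = {"o": price, "h": price, "l": price, "c": price, "v": 0.0}
        candle = candles[start]
        candle["h"] = max(candle["h"], price)
        candle["l"] = min(candle["l"], price)
        candle["c"] = price  # trades are processed in time order, so the last one is the close
        candle["v"] += size
    return candles


trades = [(0, 100.0, 1.0), (20, 101.5, 0.5), (45, 99.0, 2.0), (70, 100.5, 1.0)]
for start, candle in to_candles(trades, 60).items():
    print(start, candle)
```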
    14 | 15 | 16 | 17 | 18 | 19 |
    20 | 21 | ---- 22 | 23 | #### order book 24 | 25 |
    26 | 27 | * the order book is made of two sides, asks (sell, offers) and bids (buy). 28 | * the best ask (the lowest price someone is willing to sell at) > the best bid (the highest price someone is willing to buy at). 29 | * the difference between the best ask and the best bid is called the spread. 30 | * **market order**: best price possible, right now. it takes liquidity from the market and usually has higher fees. 31 | * **limit order (passive order)**: specify the price and qty you are willing to buy or sell at, and then wait for the match. 32 | * **stop orders**: become a market (or limit) order once the price reaches a specified stop price, e.g. to cap losses on an open position. 33 | 34 |
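a small sketch computing the quantities defined above from the two sides of a book. the `(price, size)` list format and `spread_stats` are hypothetical; expressing the spread in basis points of the mid price (1 bp = 0.01%) is a common convention, added here to make spreads comparable across assets.

```python
def spread_stats(bids, asks):
    """bids/asks: lists of (price, size). Returns best bid, best ask, spread, mid, spread in bps."""
    best_bid = max(price for price, _ in bids)  # highest price a buyer will pay
    best_ask = min(price for price, _ in asks)  # lowest price a seller will accept
    if best_ask <= best_bid:
        raise ValueError("crossed book: the best ask should be above the best bid")
    spread = best_ask - best_bid
    mid = (best_ask + best_bid) / 2
    return best_bid, best_ask, spread, mid, spread / mid * 10_000


bids = [(99.5, 2.0), (99.0, 5.0)]
asks = [(100.5, 1.0), (101.0, 3.0)]
print(spread_stats(bids, asks))  # (99.5, 100.5, 1.0, 100.0, 100.0)
```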
    35 | -------------------------------------------------------------------------------- /deep_learning/README.md: -------------------------------------------------------------------------------- 1 | ## ai agents 2 | 3 |
    4 | 5 | * **[deep learning](deep_learning.md)** 6 | * **[reinforcement learning](reinforcement_learning.md)** 7 | 8 |
    9 | 10 | ---- 11 | 12 | ### cool resources 13 | 14 |
    15 | 16 | * **[cursor ai editor](https://www.cursor.com/)** 17 | * **[microsoft notes on ai agents](https://github.com/microsoft/generative-ai-for-beginners/tree/main/17-ai-agents)** 18 | * **[google's jax (composable transformations of numpy programs)](https://github.com/google/jax)** 19 | * **[machine learning engineering open book](https://github.com/stas00/ml-engineering)** 20 | * **[advances in financial machine learning](books/advances_in_financial_machine_learning.pdf)** 21 | -------------------------------------------------------------------------------- /deep_learning/deep_learning.md: -------------------------------------------------------------------------------- 1 | ## deep learning 2 | 3 |
    4 | 5 | ### timeline tl; dr 6 | 7 |
    8 | 9 | * **[2012: imagenet and alexnet](https://github.com/tensorflow/models/blob/master/research/slim/nets/alexnet.py)** 10 | 11 | * **[2013: atari with deep reinforcement learning](https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial)** 12 | * **[2014: seq2seq](https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt)** 13 | * **[2014: adam optimizer](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/optimizer_v2/adam.py#L32-L281)** 14 | * **[2015: gans](https://www.tensorflow.org/tutorials/generative/dcgan)** 15 | * **[2015: resnets](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/applications/resnet.py)** 16 | * **[2017: transformers](https://github.com/huggingface/transformers)** 17 | * **[2018: bert](https://arxiv.org/abs/1810.04805)** 18 | 19 |
    20 | 21 | --- 22 | 23 | ### deep reinforcement learning for trading 24 | 25 |
    26 | 27 | * an mdp consists of a set of states, a set of actions, a transition function that describes the probability of moving from one state to another after taking an action, and a reward function that assigns a numerical reward to each state-action pair 28 | 29 | * the goal in an mdp is to find a policy, a rule for choosing actions, that maximizes the expected cumulative reward over a sequence of actions. 30 | 31 | * a policy is a function that maps each state to a probability distribution over actions. The optimal policy is the one that maximizes the expected cumulative reward. 32 | 33 | * the problem of reinforcement learning can be formalized using ideas from dynamical systems theory, specifically, as the optimal control of incompletely-known Markov decision processes. 34 | 35 | * unlike supervised learning, an agent must be able to learn from its own experience; and unlike unsupervised learning, reinforcement learning tries to maximize a reward signal instead of trying to find hidden structure. 36 | 37 | * the agent has to exploit what it has already experienced in order to obtain reward, but it also has to explore in order to make better action selections in the future. on a stochastic task, each action must be tried many times to gain a reliable estimate of its expected reward. 38 | 39 | * beyond the agent and the environment, one can identify four main subelements of a reinforcement learning system: a policy, a reward signal, a value function, and, optionally, a model of the environment. 
    40 | 41 | * traditional reinforcement learning problems can be formulated as a markov decision process (MDP): 42 | * we have an agent acting in an environment 43 | * at each step *t* the agent receives as input the current state S_t, takes an action A_t, and receives a reward R_{t+1} and the next state S_{t+1} 44 | * the agent chooses the action based on some policy pi: A_t = pi(S_t) 45 | * it's our goal to find a policy that maximizes the cumulative reward Sum R_t over some finite or infinite time horizon 46 | 47 | 48 |
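a minimal sketch of the agent-environment loop described above, using a hypothetical toy environment (`CoinFlipEnv`: guess a fair coin flip for +1 or -1) and a random policy. the discount factor `gamma` is an addition not in the bullets: it is the standard way to keep the cumulative reward finite over an infinite horizon, and with gamma = 1 it reduces to the plain sum of rewards.

```python
import random


class CoinFlipEnv:
    """Hypothetical toy environment: guess a fair coin flip, reward +1 if right, -1 if wrong."""

    def __init__(self, horizon=10, seed=0):
        self.horizon = horizon
        self.rng = random.Random(seed)
        self.t = 0

    def reset(self):
        self.t = 0
        return self.t  # S_0: here the state is just the time step

    def step(self, action):
        coin = self.rng.randint(0, 1)
        reward = 1.0 if action == coin else -1.0
        self.t += 1
        done = self.t >= self.horizon
        return self.t, reward, done  # S_{t+1}, R_{t+1}, whether the episode ended


def run_episode(env, policy):
    """Roll out one episode and return the reward sequence R_1, R_2, ..."""
    state, done, rewards = env.reset(), False, []
    while not done:
        action = policy(state)  # A_t = pi(S_t)
        state, reward, done = env.step(action)
        rewards.append(reward)
    return rewards


def discounted_return(rewards, gamma=0.99):
    """G = R_1 + gamma * R_2 + gamma^2 * R_3 + ..., accumulated backwards."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g
    return g


policy_rng = random.Random(1)
rewards = run_episode(CoinFlipEnv(), lambda state: policy_rng.randint(0, 1))  # random policy
print(len(rewards), sum(rewards), discounted_return(rewards))
```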
    49 | 50 | 51 | 52 |
    53 | 54 | #### agent 55 | 56 |
    57 | 58 | * agent is the trading agent (e.g. the human trader who opens the gui of an exchange and makes trading decisions based on the current state of the exchange and their account) 59 | 60 |
    61 | 62 | #### environment 63 | 64 |
    65 | 66 | * the exchange and other agents are the environment, and they are not something we can control 67 | * by putting other agents together into some big complex environment, we lose the ability to explicitly model them 68 | * if we try to reverse-engineer the algorithms and strategies that other traders are running, we end up in a multi-agent reinforcement learning (MARL) problem setting 69 | 70 |
    71 | 72 | #### state 73 | 74 |
    75 | 76 | * in the case of trading on an exchange, we don't observe the complete state of the environment (e.g. other agents), so we are dealing with a partially observable markov decision process (pomdp). 77 | * what the agent observes is not the actual state S_t of the environment, but some derivation of it. 78 | * we can call the observation X_t, which is calculated using some function of the full state X_t ~ O(S_t) 79 | * the observation at each timestep t is simply the history of all exchange events received up to time t. 80 | * this event history can be used to build up the current exchange state; however, in order for our agent to make decisions, extra info such as account balance and open limit orders needs to be included. 81 | 82 |
    83 | 84 | #### time scale 85 | 86 |
    87 | 88 | * hft techniques: decisions are based almost entirely on market microstructure signals. decisions are made on nanosecond timescales, and trading strategies use dedicated connections to exchanges and extremely fast but simple algorithms running on fpga hardware. 89 | * neural networks are comparatively slow: they can't make predictions on nanosecond timescales, so they can't compete with the speed of hft algorithms. 90 | * guess: the optimal time scale is between a few milliseconds and a few minutes. 91 | * can deep rl algorithms pick up hidden patterns? 92 | 93 |
    94 | 95 | #### action space 96 | 97 |
    98 | 99 | * the simplest approach has 3 actions: buy, hold, and sell. this works but limits us to placing market orders and to investing a deterministic amount of money at each step. 100 | * in the next level we would let our agents learn how much money to invest, based on the uncertainty of our model, putting us into a continuous action space. 101 | * in the next level, we would introduce limit orders, and the agent needs to decide the level (price) and quantity of the order, and be able to cancel orders that have not yet been matched. 102 | 103 |
    104 | 105 | #### reward function 106 | 107 |
    108 | 109 | * there are several possible reward functions; an obvious one is realized PnL (profit and loss): the agent receives a reward whenever it closes a position. 110 | * the net profit is either negative or positive, and this is the reward signal. 111 | * as the agent maximizes the total cumulative reward, it learns to trade profitably. the reward function leads to the optimal policy in the limit. 112 | * however, buy and sell actions are rare compared to doing nothing; the agent needs to learn without receiving frequent feedback. 113 | * an alternative is unrealized pnl, which is the net profit the agent would get if it were to close all of its positions immediately. 114 | * because the unrealized pnl may change at each time step, it gives the agent more frequent feedback signals. however, the direct feedback may bias the agent towards short-term actions. 115 | * both naively optimize for profit, but a trader may want to minimize risk (lower volatility) 116 | * using the sharpe ratio is one simple way to take risk into account. another is maximum drawdown. 117 | 118 |
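a minimal sketch of the two reward signals, using a hypothetical `Position` tracker for a single long position with average-cost accounting (shorts and fees are left out for brevity). realized pnl changes only when part of the position is closed, while unrealized pnl moves with every new mark price, which is why it gives denser feedback.

```python
class Position:
    """Hypothetical single-asset long position with average-cost accounting (no fees)."""

    def __init__(self):
        self.qty = 0.0
        self.avg_price = 0.0
        self.realized = 0.0  # realized PnL: only changes when (part of) the position is closed

    def buy(self, qty, price):
        total_cost = self.avg_price * self.qty + price * qty
        self.qty += qty
        self.avg_price = total_cost / self.qty

    def sell(self, qty, price):
        if qty > self.qty:
            raise ValueError("cannot sell more than the open position")
        self.realized += (price - self.avg_price) * qty
        self.qty -= qty
        if self.qty == 0:
            self.avg_price = 0.0

    def unrealized(self, mark_price):
        """PnL if the remaining position were closed at mark_price right now."""
        return (mark_price - self.avg_price) * self.qty


pos = Position()
pos.buy(1, 100.0)
pos.buy(1, 110.0)             # average entry price is now 105
print(pos.unrealized(108.0))  # 6.0, nothing realized yet
pos.sell(1, 112.0)            # close half of the position
print(pos.realized, pos.unrealized(108.0))  # 7.0 3.0
```

one possible (assumed, not prescribed) per-step reward in an rl environment is the change in `pos.realized + pos.unrealized(mark)` between consecutive steps.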
    119 | 120 | 121 | 122 |
    123 | 124 | #### learned policies 125 | 126 |
    127 | 128 | * instead of needing to hand-code a rule-based policy, rl directly learns a policy 129 | 130 | 131 |
    132 | 133 | #### trained directly in simulation environments 134 | 135 |
    136 | 137 | * in the traditional workflow, we need separate backtesting and parameter optimization steps because it is difficult for rule-based strategies to take into account environmental factors: order book liquidity, fee structures, latencies. 138 | * getting around environmental limitations is part of the optimization process. if we simulate the latency in the reinforcement learning environment, and this results in the agent making a mistake, the agent will get a negative reward, forcing it to learn to work around the latencies. 139 | * by learning a model of the environment and performing rollouts using techniques like monte carlo tree search (mcts), we could take into account potential reactions of the market (other agents) 140 | * by being smart about the data we collect from the live environment, we can continuously improve our model 141 | * do we act optimally in the live environment to generate profits, or do we act suboptimally to gather interesting information that we can use to improve the model of our environment and other agents? 142 | 143 |
    144 | 145 | #### learning to adapt to market conditions 146 | 147 |
    148 | 149 | * a strategy may work well in a bearish environment but lose money in a bullish environment. 150 | * because rl agents learn powerful policies parameterized by NN, they can also learn to adapt to market conditions by seeing them in historical data, given that they are trained over a long time horizon and have sufficient memory. 151 | 152 |
    153 | 154 | #### trading as research 155 | 156 |
    157 | 158 | * the trading environment is a multiplayer game with thousands of agents acting simultaneously 159 | * building models of other agents is only one possibility: we can also choose to perform actions in a live environment with the goal of maximizing the information gain with respect to the kinds of policies the other agents may be following 160 | * trading agents receive sparse rewards from the market. naively applying reward-hungry rl algorithms will fail. 161 | * this opens up the possibility for new algorithms and techniques that can efficiently deal with sparse rewards. 162 | * many of today's standard algorithms, such as dqn or a3c, use a very naive approach to exploration: basically adding random noise to the policy. however, in the trading case, most states in the environment are bad, and there are only a few good ones. a naive random approach to exploration will almost never stumble upon good state-action pairs. 163 | * the trading environment is inherently nonstationary. market conditions change, and other agents join, leave, and constantly change their strategies. 164 | * can we train an agent that can transition from bear to bull and then back to bear, without needing to be re-trained? 165 | -------------------------------------------------------------------------------- /deep_learning/reinforcement_learning.md: -------------------------------------------------------------------------------- 1 | ## reinforcement learning 2 | 3 |
    4 | 5 | ### tl; dr 6 | 7 |
    8 | 9 | * reinforcement learning is learning what to do (how to map situations to actions) so as to maximize a numerical reward signal 10 | * an autonomous agent is a software program or system that can operate independently and make decisions on its own, without direct intervention from a human 11 | 12 |
    13 | 14 | --- 15 | 16 | ### overview 17 | 18 |
    19 | 20 | * we formalize the problem of reinforcement learning using ideas from dynamical systems theory, as the optimal control of incompletely-known Markov decision processes. 21 | * a learning agent must be able to sense the state of its environment to some extent and must be able to take actions that affect the state. 22 | * markov decision processes are intended to include just these three aspects: sensation, action, and goal. 23 | * the agent has to exploit what it has already experienced in order to obtain reward, but it also has to explore in order to make better action selections in the future. 24 | * on a stochastic task, each action must be tried many times to gain a reliable estimate of its expected reward. 25 | 26 |
    27 | 28 | --- 29 | 30 | ### elements of reinforcement learning 31 | 32 |
    33 | 34 | * beyond the agent and the environment, 4 more elements belong to a reinforcement learning system: a policy, a reward signal, a value function, and a model of the environment. 35 | * a policy defines the learning agent's way of behaving at a given time. It's a mapping from perceived states of the environment to actions to be taken when in those states. in general, policies may be stochastic (specifying probabilities for each action). 36 | * a reward signal defines the goal of a reinforcement learning problem: on each time step, the environment sends to the reinforcement learning agent a single number called the reward. the agent's sole objective is to maximize the total reward over the long run. 37 | * a value function specifies what is good in the long run: the value of a state is the total amount of reward an agent can expect to accumulate over the future, starting from that state 38 | * a model of the environment mimics the behavior of the environment, allowing inferences about how it will behave (e.g. the next state and reward for a given state and action); models are used for planning. 39 | * the most important feature distinguishing reinforcement learning from other types of learning is that it uses training information that evaluates the actions taken rather than instructing by giving correct actions. 40 | 41 |
    42 | 43 | --- 44 | 45 | ### finite markov decision processes (mdps) 46 | 47 |
    48 | 49 | * the problem involves evaluative feedback and choosing different actions in different situations. 50 | * mdps are a classical formalization of sequential decision making, where actions influence not just immediate rewards, but also subsequent situations. 51 | * mdps involve delayed reward and the need to trade off immediate and delayed reward. 52 | 53 |
    54 | 55 | ##### the agent-environment interface 56 | 57 | * mdps are meant to be a straightforward framing of the problem of learning from interaction to achieve a goal. 58 | * the learner and decision maker is called the agent. 59 | * the thing it interacts with, comprising everything outside the agent, is called the environment. 60 | * the environment gives rise to rewards, numerical values that the agent seeks to maximize over time through its choice of actions. 61 | 62 |
    63 | 64 | 65 | 66 |
    67 | 68 | * the agent and the environment interact at each of a sequence of discrete time steps, t = 0, 1, 2, 3... 69 | * at each time step t, the agent receives some representation of the environment's state S_t 70 | * on that basis, the agent selects an action A_t 71 | * one time step later, in part as a consequence of its action, the agent receives a numerical reward R_{t+1} and finds itself in a new state S_{t+1}. 72 | * the mdp and the agent together give rise to a sequence (trajectory) S_0, A_0, R_1, S_1, A_1, R_2, ... 73 | * in a finite mdp, the sets of states, actions, and rewards all have a finite number of elements. in this case, the random variables R_t and S_t have well-defined discrete probability distributions dependent only on the preceding state and action. 74 | * in a markov decision process, the probabilities given by the dynamics function p(s', r | s, a) completely characterize the environment's dynamics. 75 | * the state must include information about all aspects of the past agent-environment interaction that make a difference for the future. 76 | * anything that cannot be changed arbitrarily by the agent is considered to be outside of it and thus part of its environment. 77 | 78 |
    79 | 80 | ##### goals and rewards 81 | 82 | 83 | * each episode ends in a special state called the terminal state, followed by a reset to a standard starting state or to a sample from a standard distribution of starting states. 84 | * almost all reinforcement learning algorithms involve estimating value functions: functions of states (or of state–action pairs) that estimate how good it is for the agent to be in a given state (or how good it is to perform a given action in a given state). 85 | * the Bellman equation averages over all the possibilities, weighting each by its probability of occurring. it states that the value of the start state must equal the 86 | (discounted) value of the expected next state, plus the reward expected along the way. 87 | * solving a reinforcement learning task means finding a policy that achieves a lot of reward over the long run. 88 | 89 |
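written out in standard notation (with π(a | s) the policy, p(s', r | s, a) the dynamics function, and γ the discount factor), the bellman equation for the value of a state s under a policy π is:

```math
v_\pi(s) = \mathbb{E}_\pi\left[R_{t+1} + \gamma\, v_\pi(S_{t+1}) \mid S_t = s\right] = \sum_{a} \pi(a \mid s) \sum_{s',\, r} p(s', r \mid s, a)\,\bigl[r + \gamma\, v_\pi(s')\bigr]
```

the two sums are the "averaging over all the possibilities, weighted by their probability", and the bracket is the reward along the way plus the discounted value of the next state.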
    90 | 91 | --- 92 | 93 | ### dynamic programming 94 | 95 | * a collection of algorithms that can be used to compute optimal policies given a perfect model of the environment as an mdp. 96 | * a common way of obtaining approximate solutions for tasks with continuous states and actions is to quantize the state and action spaces and then apply finite-state DP methods. 97 | * the reason for computing the value function for a policy is to help find better policies. 98 | * asynchronous DP algorithms are in-place iterative DP algorithms that are not organized in terms of systematic sweeps of the state set. these algorithms update the values of states in any order whatsoever, using whatever values of other states happen to be available. the values of some states may be updated several times before the values of others are updated once. 99 | * policy evaluation refers to the (typically) iterative computation of the value function for a given policy. 100 | * policy improvement refers to the computation of an improved policy given the value function for that policy. 101 | 102 |
    103 | 104 | ##### generalized policy iteration 105 | 106 | * policy iteration consists of two simultaneous, interacting processes, one making the value function consistent with the current policy (policy evaluation), and the other making the policy greedy with respect to the current value function (policy improvement). 107 | * generalized policy iteration (GPI) refers to the general idea of letting policy-evaluation and policy-improvement processes interact, independent of the granularity and other details of the two processes. 108 | * DP is sometimes thought to be of limited applicability because of the curse of dimensionality, the fact that the number of states often grows exponentially with the number of state variables 109 | 110 |
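a minimal sketch of policy iteration on a hypothetical three-state mdp (states `A`, `B`, and an absorbing terminal state `T`), alternating the two processes described above: in-place policy evaluation sweeps, then greedy policy improvement, until the policy stops changing. the encoding `P[s][a] = [(probability, next_state, reward), ...]` and all names are assumptions for illustration.

```python
def q_value(P, V, s, a, gamma):
    """Expected reward plus discounted next-state value for taking action a in state s."""
    return sum(p * (r + gamma * V[s2]) for p, s2, r in P[s][a])


def evaluate_policy(P, policy, V, gamma, theta=1e-8):
    """Policy evaluation: in-place sweeps until V is consistent with the deterministic policy."""
    while True:
        delta = 0.0
        for s in P:
            v = q_value(P, V, s, policy[s], gamma)
            delta = max(delta, abs(v - V[s]))
            V[s] = v  # in-place: later states in this sweep already see the new value
        if delta < theta:
            return V


def policy_iteration(P, gamma=0.9):
    policy = {s: next(iter(P[s])) for s in P}  # arbitrary initial policy: first listed action
    V = {s: 0.0 for s in P}
    while True:
        V = evaluate_policy(P, policy, V, gamma)
        stable = True
        for s in P:  # policy improvement: act greedily with respect to V
            best = max(P[s], key=lambda a: q_value(P, V, s, a, gamma))
            if q_value(P, V, s, best, gamma) > q_value(P, V, s, policy[s], gamma) + 1e-12:
                policy[s] = best  # switch only on a strict improvement, so ties cannot cycle
                stable = False
        if stable:
            return policy, V


P = {
    "A": {"left": [(1.0, "T", 0.0)], "right": [(1.0, "B", 1.0)]},
    "B": {"left": [(1.0, "A", 0.0)], "right": [(1.0, "T", 10.0)]},
    "T": {"stay": [(1.0, "T", 0.0)]},  # absorbing terminal state
}
policy, V = policy_iteration(P)
print(policy)  # {'A': 'right', 'B': 'right', 'T': 'stay'}
print(V)       # {'A': 10.0, 'B': 10.0, 'T': 0.0}
```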
    111 | 112 | --- 113 | 114 | ### cool resources 115 | 116 |
    117 | 118 | * **[gymnasium api](https://gymnasium.farama.org/)** 119 | * **[reinforcement learning with unsupervised auxiliary tasks, by jaderberg et al.](https://arxiv.org/abs/1611.05397)** -------------------------------------------------------------------------------- /llms/README.md: -------------------------------------------------------------------------------- 1 | ## large language models 2 | 3 |
    4 | 5 | * **[gpt-x](gpt)** 6 | * **[claude](claude)** 7 | * **[eliza](eliza)** 8 | * **[deepseek](deepseek)** 9 | 10 |
    11 | 12 | --- 13 | 14 | ### cool resources 15 | 16 |
    17 | 18 | #### articles 19 | 20 | * **[people cannot distinguish gpt-4 from a human in a turing test, by c. jones et al (2024)](https://arxiv.org/pdf/2405.08007)** 21 | 22 |
    23 | 24 | #### code 25 | 26 | * **[awesome chatgpt prompts](https://github.com/f/awesome-chatgpt-prompts)** 27 | -------------------------------------------------------------------------------- /llms/claude/README.md: -------------------------------------------------------------------------------- 1 | ## claude 2 | 3 |
    4 | 5 | ### cool resources 6 | 7 |
    8 | 9 | * **[dario amodei on agi & the future of humanity, by lex fridman](https://www.youtube.com/watch?v=ugvHCXCOmm4)** 10 | -------------------------------------------------------------------------------- /llms/deepseek/README.md: -------------------------------------------------------------------------------- 1 | ## deepseek stuff 2 | 3 |
    4 | 5 |
    6 | 7 | 8 |
    9 | -------------------------------------------------------------------------------- /llms/eliza/README.md: -------------------------------------------------------------------------------- 1 | ## eliza -------------------------------------------------------------------------------- /llms/gpt/README.md: -------------------------------------------------------------------------------- 1 | ## gpt 2 | 3 |
    4 | 5 | ### cool resources 6 | 7 |
    8 | 9 | * **[vscode chatgpt plugin](https://github.com/mpociot/chatgpt-vscode) (and [here](https://marketplace.visualstudio.com/items?itemName=timkmecl.chatgpt))** 10 | * **[scispace extension (paper explainer)](https://chrome.google.com/webstore/detail/scispace-copilot/cipccbpjpemcnijhjcdjmkjhmhniiick/related)** 11 | * **[fix python bugs](https://platform.openai.com/playground/p/default-fix-python-bugs?model=code-davinci-002)** 12 | * **[explain code](https://platform.openai.com/playground/p/default-explain-code?model=code-davinci-002)** 13 | * **[translate code](https://platform.openai.com/playground/p/default-translate-code?model=code-davinci-002)** 14 | * **[translate sql](https://platform.openai.com/playground/p/default-sql-translate?model=code-davinci-002)** 15 | * **[calculate time complexity](https://platform.openai.com/playground/p/default-time-complexity?model=text-davinci-003)** 16 | * **[text to programmatic command](https://platform.openai.com/playground/p/default-text-to-command?model=text-davinci-003)** 17 | --------------------------------------------------------------------------------