├── EBMs
│   ├── README.md
│   ├── ais.py
│   ├── custom_adam.py
│   ├── data.py
│   ├── ebm_combine.py
│   ├── ebm_sandbox.py
│   ├── fid.py
│   ├── hmc.py
│   ├── imagenet_demo.py
│   ├── imagenet_preprocessing.py
│   ├── inception.py
│   ├── models.py
│   ├── requirements.txt
│   ├── test_inception.py
│   ├── train.py
│   └── utils.py
├── README.md
├── books
│   └── advances_in_financial_machine_learning.pdf
├── crypto_agents
│   ├── README.md
│   ├── strategy_workflow
│   │   ├── README.md
│   │   ├── backtesting.md
│   │   ├── data_analysis.md
│   │   ├── defi_glossary.md
│   │   ├── live_trading.md
│   │   ├── optimization.md
│   │   ├── paper_trading.md
│   │   ├── policy.md
│   │   ├── strategy_metrics.md
│   │   └── supervised_learning.md
│   └── trading_on_gmx.md
├── deep_learning
│   ├── README.md
│   ├── deep_learning.md
│   └── reinforcement_learning.md
└── llms
    ├── README.md
    ├── claude
    │   └── README.md
    ├── deepseek
    │   └── README.md
    ├── eliza
    │   └── README.md
    └── gpt
        └── README.md
/EBMs/README.md:
--------------------------------------------------------------------------------
1 | ## quantum ai: training energy-based models using openai's code
2 |
3 |
4 |
5 |
6 | #### ⚛️ this repository contains my adapted code from [openai's implicit generation and generalization in energy-based models](https://arxiv.org/pdf/1903.08689.pdf)
7 |
8 |
9 |
10 | ### installing
11 |
12 |
13 |
14 | ```bash
15 | brew install gcc@6
16 | brew install open-mpi
17 | brew install pkg-config
18 | ```
19 |
20 |
21 |
22 | * there is a known **[bug](https://github.com/open-mpi/ompi/issues/7516)** in open-mpi, affecting the library versions used here, that produces `PMIX ERROR: ERROR`; it can be worked around with:
23 |
24 |
25 |
26 | ```
27 | export PMIX_MCA_gds=^ds12
28 | ```
29 |
30 |
31 |
32 | * then install the python requirements:
33 |
34 |
35 |
36 | ```bash
37 | virtualenv venv
38 | source venv/bin/activate
39 | pip install -r requirements.txt
40 | ```
41 |
42 |
43 | * note that this is an adapted requirements file, since **[openai's original](https://github.com/openai/ebm_code_release/blob/master/requirements.txt)** is incomplete or incorrect
44 | * finally, download and install **[mujoco](https://www.roboti.us/index.html)**
45 | * you will also need to register for a license, which asks for a machine ID
46 | * the documentation on the website is incomplete, so download the suggested `getid_osx` script and run:
47 |
48 |
49 |
50 | ```bash
51 | mv getid_osx getid_osx.dms
52 | ./getid_osx.dms
53 | ```
54 |
55 |
56 |
57 | ---
58 |
59 | ### running
60 |
61 |
62 |
63 | #### download pre-trained models (examples)
64 |
65 |
66 |
67 | * download all **[pre-trained models](https://sites.google.com/view/igebm/home)** and unzip them into a local folder `cachedir`:
68 |
69 |
70 |
71 | ```bash
72 | mkdir cachedir
73 | ```
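`ais.py` builds the checkpoint path it restores from as `osp.join(FLAGS.logdir, FLAGS.exp, 'model_{}'.format(FLAGS.resume_iter))`, with `--logdir` defaulting to `cachedir`. Assuming the other scripts follow the same convention (not verified here), each unzipped model should end up in `cachedir/<exp>/`. A small sketch to check that a download landed where a given `--exp`/`--resume_iter` pair will look for it; the helper names are hypothetical, and the example pair is taken from the `crossclass` command further down:

```python
import glob
import os.path as osp


def checkpoint_prefix(exp, resume_iter, logdir="cachedir"):
    """Prefix that ais.py passes to optimistic_restore: <logdir>/<exp>/model_<resume_iter>."""
    return osp.join(logdir, exp, "model_{}".format(resume_iter))


def checkpoint_files(prefix):
    """List files stored under a TensorFlow 1.x checkpoint prefix (.index, .meta, .data-*)."""
    return sorted(glob.glob(glob.escape(prefix) + ".*"))


if __name__ == "__main__":
    prefix = checkpoint_prefix("cifar10_cond", 74700)
    print(prefix, checkpoint_files(prefix) or "no checkpoint files found")
```

An empty list means the `--exp` name and the unzipped folder name do not match.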
74 |
75 |
76 |
77 | #### setting results directory
78 |
79 |
80 |
81 | * openai's original code contains **[hardcoded constants that only work on Linux](https://github.com/openai/ebm_code_release/blob/master/data.py#L218)**
82 | * i replaced them with a single constant (`ROOT_DIR = "./results"`) at the top of `data.py`
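If the results directory needs to move between machines, one option is to build paths with `pathlib` and allow an override from the environment. A minimal sketch; `EBM_RESULTS_DIR` and `results_path` are hypothetical names that the repository does not read or define:

```python
import os
from pathlib import Path

# Hypothetical override: the repository itself only defines ROOT_DIR = "./results".
ROOT_DIR = Path(os.environ.get("EBM_RESULTS_DIR", "./results"))


def results_path(*parts, root=None, create=False):
    """Join path parts under the results directory in an OS-independent way.

    With create=True the parent directory is created, like `mkdir -p`.
    """
    path = Path(root if root is not None else ROOT_DIR).joinpath(*parts)
    if create:
        path.parent.mkdir(parents=True, exist_ok=True)
    return path
```

For example, `results_path("cifar10_uncond", "samples.png")` resolves to `results/cifar10_uncond/samples.png` on Linux and macOS, and uses backslashes on Windows.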
83 |
84 |
85 |
86 | #### running (parallelization with `mpiexec`)
87 |
88 |
89 |
90 | * all code supports **[`horovod` execution](https://github.com/horovod/horovod)**, so training can be sped up substantially by running each command across multiple workers:
91 |
92 |
93 |
94 | ```
95 | mpiexec -n <num_workers> <command>
96 | ```
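Open MPI's `mpiexec` exports `OMPI_COMM_WORLD_RANK` and `OMPI_COMM_WORLD_SIZE` to every process it starts; inside the training scripts, horovod exposes the same information through `hvd.rank()` and `hvd.size()`. A quick sanity check that `mpiexec -n` really launched several workers could look like this (a sketch: the variable names are Open MPI-specific, and other MPI implementations use different ones):

```python
import os


def mpi_rank_and_size(env=None):
    """Return (rank, world_size) as exported by Open MPI's mpiexec; (0, 1) when run directly."""
    env = os.environ if env is None else env
    rank = int(env.get("OMPI_COMM_WORLD_RANK", 0))
    size = int(env.get("OMPI_COMM_WORLD_SIZE", 1))
    return rank, size


if __name__ == "__main__":
    rank, size = mpi_rank_and_size()
    print("worker {} of {}".format(rank, size))
```

Saved as a file (say, a hypothetical `check_workers.py`) and run with `mpiexec -n 2 python check_workers.py`, it should print one line per worker, in no guaranteed order.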
97 |
98 |
99 |
100 | ##### cifar-10 unconditional
101 |
102 |
103 |
104 | ```
105 | python train.py --exp=cifar10_uncond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --large_model
106 | ```
107 |
108 | * this should produce output similar to the following:
109 |
110 |
111 |
112 | ```bash
113 | Instructions for updating:
114 | Use tf.gfile.GFile.
115 | 2020-05-10 22:12:32.471415: W tensorflow/core/framework/op_def_util.cc:355] Op BatchNormWithGlobalNormalization is deprecated. It will cease to work in GraphDef version 9. Use tf.nn.batch_normalization().
116 | 64 batch size
117 | Local rank: 0 1
118 | Loading data...
119 | Files already downloaded and verified
120 | Files already downloaded and verified
121 | Files already downloaded and verified
122 | Files already downloaded and verified
123 | Done loading...
124 | WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
125 | Instructions for updating:
126 | Colocations handled automatically by placer.
127 | Building graph...
128 | WARNING:tensorflow:From /Users/mia/dev/ebm_code_release/venv/lib/python3.7/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
129 | Instructions for updating:
130 | Use tf.cast instead.
131 | Finished processing loop construction ...
132 | Started gradient computation...
133 | Applying gradients...
134 | Finished applying gradients.
135 | Model has a total of 7567880 parameters
136 | Initializing variables...
137 | Start broadcast
138 | End broadcast
139 | Obtained a total of e_pos: -0.0025530937127768993, e_pos_std: 0.09564747661352158, e_neg: -0.22276005148887634, e_diff: 0.22020696103572845, e_neg_std: 0.016306934878230095, temp: 1, loss_e: -0.22276005148887634, eps: 0.0, label_ent: 2.272536277770996,
140 | loss_ml: 0.22020693123340607, loss_total: 0.2792498469352722, x_grad: 0.0009156676824204624, x_grad_first: 0.0009156676824204624, x_off: 0.31731340289115906, iter: 0, gamma: [0.], context_0/c1_pre/cweight:0: 0.0731438547372818,
141 | context_0/res_optim_res_c1/cweight:0: 4.732660444095593e-11, context_0/res_optim_res_c1/gb:0: 3.4007335836250263e-10, context_0/res_optim_res_c2/cweight:0: 0.9494612216949463, context_0/res_optim_res_c2/g:0: 1.8536269741353806e-10,
142 | context_0/res_optim_res_c2/gb:0: 6.272356523062683e-10, context_0/res_optim_res_c2/cb:0: 1.1606662297936055e-09, context_0/res_1_res_c1/cweight:0: 6.714453298917178e-11, context_0/res_1_res_c1/gb:0: 3.6198691266697836e-10,
143 | context_0/res_1_res_c2/cweight:0: 0.6582950353622437, context_0/res_1_res_c2/g:0: 1.669797633496728e-10, context_0/res_1_res_c2/gb:0: 5.911696687732615e-10, context_0/res_1_res_c2/cb:0: 1.1932842491901852e-09,
144 | context_0/res_2_res_c1/cweight:0: 8.567072745657711e-11, context_0/res_2_res_c1/gb:0: 6.868505764145993e-10, context_0/res_2_res_c2/cweight:0: 0.46929678320884705, context_0/res_2_res_c2/g:0: 1.655784120924153e-10,
145 | context_0/res_2_res_c2/gb:0: 8.058526068666083e-10, context_0/res_2_res_c2/cb:0: 1.0161046448686761e-09, context_0/res_2_res_adaptive/cweight:0: 0.01942753791809082, context_0/res_3_res_c1/cweight:0: 4.011655244107182e-11,
146 | context_0/res_3_res_c1/gb:0: 5.064903496609929e-10, context_0/res_3_res_c2/cweight:0: 0.32239994406700134, context_0/res_3_res_c2/g:0: 9.758494012857e-11, context_0/res_3_res_c2/gb:0: 7.756124631441708e-10,
147 | context_0/res_3_res_c2/cb:0: 6.362700366580043e-10, context_0/res_4_res_c1/cweight:0: 4.090133440270982e-11, context_0/res_4_res_c1/gb:0: 6.013010089844784e-10, context_0/res_4_res_c2/cweight:0: 0.34806951880455017,
148 | context_0/res_4_res_c2/g:0: 8.414659247168998e-11, context_0/res_4_res_c2/gb:0: 6.443054978433338e-10, context_0/res_4_res_c2/cb:0: 5.496815780325903e-10, context_0/res_5_res_c1/cweight:0: 3.990113794927197e-11,
149 | context_0/res_5_res_c1/gb:0: 3.807749116013781e-10, context_0/res_5_res_c2/cweight:0: 0.22841960191726685, context_0/res_5_res_c2/g:0: 4.942361797599659e-11, context_0/res_5_res_c2/gb:0: 7.697342763179904e-10, context_0/res_5_res_c2/cb:0: 3.1796060229183354e-10, context_0/fc5/wweight:0: 3.081033706665039, context_0/fc5/b:0: 0.4506262540817261,
150 |
151 | ................................................................................................................................
152 | Inception score of 1.2397289276123047 with std of 0.0
153 | ```
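The log ends with the Inception score computed during training (in the run above, at `iter: 0`). When comparing runs it can help to pull these numbers out of saved logs. A minimal sketch, assuming the line keeps exactly the format shown above (`Inception score of ... with std of ...`); the function name is hypothetical:

```python
import re

_NUMBER = r"[-+]?\d*\.?\d+(?:[eE][-+]?\d+)?"
_INCEPTION_RE = re.compile(
    r"Inception score of\s+(" + _NUMBER + r")\s+with std of\s+(" + _NUMBER + r")"
)


def parse_inception_scores(log_text):
    """Return every (score, std) pair found in a training log, in order of appearance."""
    return [(float(score), float(std)) for score, std in _INCEPTION_RE.findall(log_text)]
```

Reading the whole log with `open(path).read()` and taking the last element gives the most recent score.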
154 |
155 |
156 |
157 | ##### cifar-10 conditional
158 |
159 |
160 |
161 | ```
162 | python train.py --exp=cifar10_cond --dataset=cifar10 --num_steps=60 --batch_size=128 --step_lr=10.0 --proj_norm=0.01 --zero_kl --replay_batch --cclass
163 | ```
164 |
165 |
166 |
167 | ##### imagenet 32x32 conditional
168 |
169 |
170 |
171 | ```
172 | python train.py --exp=imagenet_cond --num_steps=60 --wider_model --batch_size=32 --step_lr=10.0 --proj_norm=0.01 --replay_batch --cclass --zero_kl --dataset=imagenet --imagenet_path=
173 | ```
174 |
175 |
176 |
177 | ##### imagenet 128x128 conditional
178 |
179 |
180 |
181 | ```
182 | python train.py --exp=imagenet_cond --num_steps=50 --batch_size=16 --step_lr=100.0 --replay_batch --swish_act --cclass --zero_kl --dataset=imagenetfull --imagenet_datadir=
183 | ```
184 |
185 |
186 |
187 | ##### imagenet demo
188 |
189 |
190 |
191 | * `imagenet_demo.py` contains code for experiments with ebms on conditional imagenet 128x128
192 | * to generate a gif of the sampling process, run:
193 |
194 |
195 |
196 | ```
197 | python imagenet_demo.py --exp=imagenet128_cond --resume_iter=2238000 --swish_act
198 | ```
199 |
200 | * `ebm_sandbox.py` contains several tasks for evaluating ebms, selected with the `--task` flag
201 | * for example, to visualize cross-class mappings on cifar-10, run:
202 |
203 |
204 |
205 | ```
206 | python ebm_sandbox.py --task=crossclass --num_steps=40 --exp=cifar10_cond --resume_iter=74700
207 | ```
208 |
209 |
210 |
211 | ##### generalization
212 |
213 |
214 |
215 | * to test out-of-distribution generalization on SVHN (similar commands work for other datasets):
216 |
217 |
218 |
219 | ```
220 | python ebm_sandbox.py --task=mixenergy --num_steps=40 --exp=cifar10_large_model_uncond --resume_iter=121200 --large_model --svhnmix --cclass=False
221 | ```
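A common way to score out-of-distribution detection with an EBM is to treat lower energy as "more in-distribution" and report the area under the ROC curve (AUROC) between in-distribution and OOD energies. The sketch below computes that AUROC from two arrays of energies; whether `--task=mixenergy` reports exactly this metric is an assumption, so check `ebm_sandbox.py` before comparing numbers:

```python
import numpy as np


def energy_auroc(energy_in, energy_out):
    """AUROC for separating in-distribution from OOD samples, scoring each sample by -energy.

    Equals the probability that a random in-distribution sample scores higher than a random
    OOD sample, with ties counted as one half. Pairwise O(n_in * n_out) comparison, which is
    fine for a few thousand samples per set.
    """
    score_in = -np.asarray(energy_in, dtype=float)[:, None]
    score_out = -np.asarray(energy_out, dtype=float)[None, :]
    wins = np.mean(score_in > score_out)
    ties = np.mean(score_in == score_out)
    return float(wins + 0.5 * ties)
```

A value of 1.0 means every in-distribution sample has lower energy than every OOD sample; 0.5 is chance.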
222 |
223 |
224 |
225 | * to test classification on cifar-10 with a conditional model under either L2 or L∞ perturbations:
226 |
227 |
228 |
229 | ```
230 | python ebm_sandbox.py --task=label --exp=cifar10_wider_model_cond --resume_iter=21600 --lnorm=-1 --pgd= --num_steps=10 --lival= --wider_model
231 | ```
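The `--pgd` flag suggests the robustness evaluation uses projected gradient descent, where each gradient step on the input perturbation is followed by a projection back onto a norm ball. A minimal NumPy sketch of that projection step, assuming (as `--lnorm=-1` suggests) that `-1` denotes L∞; the function name and signature are hypothetical, not taken from `ebm_sandbox.py`:

```python
import numpy as np


def project_perturbation(delta, eps, lnorm):
    """Project a batch of perturbations (first axis = batch) onto a norm ball of radius eps.

    lnorm == 2: rescale each example's perturbation so its L2 norm is at most eps.
    lnorm == -1: clip every coordinate to [-eps, eps] (assumed L-infinity convention).
    """
    delta = np.asarray(delta, dtype=float)
    if lnorm == -1:
        return np.clip(delta, -eps, eps)
    if lnorm == 2:
        flat = delta.reshape(delta.shape[0], -1)
        norms = np.linalg.norm(flat, axis=1, keepdims=True)
        scale = np.minimum(1.0, eps / np.maximum(norms, 1e-12))
        return (flat * scale).reshape(delta.shape)
    raise ValueError("unsupported lnorm: {}".format(lnorm))
```

Inside a PGD loop this would be applied after every step, e.g. `delta = project_perturbation(delta + step * np.sign(grad), eps, lnorm)` for the L∞ case.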
232 |
233 |
234 |
235 | ##### concept combination
236 |
237 |
238 |
239 | * to train ebms on the conditional dsprites dataset, train a separate model for each conditioned latent (`cond_pos`, `cond_rot`, `cond_shape`, `cond_scale`); an example command is:
240 |
241 |
242 |
243 | ```
244 | python train.py --dataset=dsprites --exp=dsprites_cond_pos --zero_kl --num_steps=20 --step_lr=500.0 --swish_act --cond_pos --replay_batch --cclass
245 | ```
246 |
247 |
248 |
249 | * once the models are trained, sample from them jointly by running:
250 |
251 | ```
252 | python ebm_combine.py --task=conceptcombine --exp_size= --exp_shape= --exp_pos= --exp_rot= --resume_size= --resume_shape= --resume_rot= --resume_pos=
253 | ```
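The paper combines concepts by summing the energies of separately trained models (a product of experts) and sampling from the sum with Langevin dynamics; `ebm_combine.py` presumably does this with the four dsprites models. The sketch below shows the idea with two hypothetical quadratic "concept" energies in 2-D, each constraining one coordinate. The energies, step size and unadjusted Langevin update are illustrative assumptions, not the repository's sampler or hyperparameters:

```python
import numpy as np


def quadratic_concept_grad(dim, target):
    """Gradient of a hypothetical concept energy E(x) = 0.5 * (x[dim] - target) ** 2."""
    def grad(x):
        g = np.zeros_like(x)
        g[:, dim] = x[:, dim] - target
        return g
    return grad


def sample_combined(grad_fns, n=1000, dim=2, steps=500, step_size=0.1, seed=0):
    """Unadjusted Langevin dynamics on E_total(x) = sum_i E_i(x), i.e. a product of experts."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n, dim))
    for _ in range(steps):
        grad = sum(g(x) for g in grad_fns)  # energies add, so their gradients add
        x = x - 0.5 * step_size * grad + np.sqrt(step_size) * rng.standard_normal(x.shape)
    return x


if __name__ == "__main__":
    # "first coordinate near 1" AND "second coordinate near -2"
    samples = sample_combined([quadratic_concept_grad(0, 1.0), quadratic_concept_grad(1, -2.0)])
    print(samples.mean(axis=0))
```

Neither energy alone pins down both coordinates; only their sum concentrates samples around `(1, -2)`, which is the conjunction behaviour the concept-combination experiment relies on.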
254 |
--------------------------------------------------------------------------------
/EBMs/ais.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import math
3 | from hmc import hmc
4 | from tensorflow.python.platform import flags
5 | from torch.utils.data import DataLoader
6 | from models import DspritesNet, ResNet32, ResNet32Large, ResNet32Wider, MnistNet
7 | from data import Cifar10, Mnist, DSprites
8 | from scipy.misc import logsumexp
9 | from scipy.misc import imsave
10 | from utils import optimistic_restore
11 | import os.path as osp
12 | import numpy as np
13 | from tqdm import tqdm
14 |
15 | flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
16 | flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or mnist or dsprites or 2d or toy Gauss')
17 | flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
18 | flags.DEFINE_string('exp', 'default', 'name of experiments')
19 | flags.DEFINE_integer('data_workers', 5, 'Number of different data workers to load data in parallel')
20 | flags.DEFINE_integer('batch_size', 16, 'Size of inputs')
21 | flags.DEFINE_string('resume_iter', '-1', 'iteration to resume training from')
22 |
23 | flags.DEFINE_bool('max_pool', False, 'Whether or not to use max pooling rather than strided convolutions')
24 | flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
25 | flags.DEFINE_integer('pdist', 10, 'number of intermediate distributions for ais')
26 | flags.DEFINE_integer('gauss_dim', 500, 'dimensions for modeling Gaussian')
27 | flags.DEFINE_integer('rescale', 1, 'factor to rescale input outside of normal (0, 1) box')
28 | flags.DEFINE_float('temperature', 1, 'temperature at which to compute likelihood of model')
29 | flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
30 | flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
31 | flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
32 | flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
33 | flags.DEFINE_bool('cclass', False, 'Whether to evaluate the log likelihood of conditional model or not')
34 | flags.DEFINE_bool('single', False, 'Whether to evaluate the log likelihood of conditional model or not')
35 | flags.DEFINE_bool('large_model', False, 'Use large model to evaluate')
36 | flags.DEFINE_bool('wider_model', False, 'Use wider model to evaluate')
37 | flags.DEFINE_float('alr', 0.0045, 'Learning rate to use for HMC steps')
38 |
39 | FLAGS = flags.FLAGS
40 |
41 | label_default = np.eye(10)[0:1, :]
42 | label_default = tf.Variable(tf.convert_to_tensor(label_default, np.float32))
43 |
44 |
45 | def unscale_im(im):
46 |     return (255 * np.clip(im, 0, 1)).astype(np.uint8)
47 |
48 | def gauss_prob_log(x, prec=1.0):
49 |
50 |     nh = float(np.prod([s.value for s in x.get_shape()[1:]]))
51 |     norm_constant_log = -0.5 * (tf.log(2 * math.pi) * nh - nh * tf.log(prec))
52 |     prob_density_log = -tf.reduce_sum(tf.square(x - 0.5), axis=[1]) / 2. * prec
53 |
54 |     return norm_constant_log + prob_density_log
55 |
56 |
57 | def uniform_prob_log(x):
58 |
59 |     return tf.zeros(1)
60 |
61 |
62 | def model_prob_log(x, e_func, weights, temp):
63 |     if FLAGS.cclass:
64 |         batch_size = tf.shape(x)[0]
65 |         label_tiled = tf.tile(label_default, (batch_size, 1))
66 |         e_raw = e_func.forward(x, weights, label=label_tiled)
67 |     else:
68 |         e_raw = e_func.forward(x, weights)
69 |     energy = tf.reduce_sum(e_raw, axis=[1])
70 |     return -temp * energy
71 |
72 |
73 | def bridge_prob_neg_log(alpha, x, e_func, weights, temp):
74 |
75 |     if FLAGS.dataset == "gauss":
76 |         norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * gauss_prob_log(x, prec=FLAGS.temperature)
77 |     else:
78 |         norm_prob = (1-alpha) * uniform_prob_log(x) + alpha * model_prob_log(x, e_func, weights, temp)
79 |     # Add an additional log likelihood penalty so that points outside of (0, 1) box are *highly* unlikely
80 |
81 |
82 |     if FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
83 |         oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1])
84 |     elif FLAGS.dataset == 'mnist':
85 |         oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0, FLAGS.rescale))), axis = [1, 2])
86 |     else:
87 |         oob_prob = tf.reduce_sum(tf.square(100 * (x - tf.clip_by_value(x, 0., FLAGS.rescale))), axis = [1, 2, 3])
88 |
89 |     return -norm_prob + oob_prob
90 |
91 |
92 | def ancestral_sample(e_func, weights, batch_size=128, prop_dist=10, temp=1, hmc_step=10):
93 |     if FLAGS.dataset == "2d":
94 |         x = tf.placeholder(tf.float32, shape=(None, 2))
95 |     elif FLAGS.dataset == "gauss":
96 |         x = tf.placeholder(tf.float32, shape=(None, FLAGS.gauss_dim))
97 |     elif FLAGS.dataset == "mnist":
98 |         x = tf.placeholder(tf.float32, shape=(None, 28, 28))
99 |     else:
100 |         x = tf.placeholder(tf.float32, shape=(None, 32, 32, 3))
101 |
102 |     x_init = x
103 |
104 |     alpha_prev = tf.placeholder(tf.float32, shape=())
105 |     alpha_new = tf.placeholder(tf.float32, shape=())
106 |     approx_lr = tf.placeholder(tf.float32, shape=())
107 |
108 |     chain_weights = tf.zeros(batch_size)
109 |     # for i in range(1, prop_dist+1):
110 |     #     print("processing loop {}".format(i))
111 |     #     alpha_prev = (i-1) / prop_dist
112 |     #     alpha_new = i / prop_dist
113 |
114 |     prob_log_old_neg = bridge_prob_neg_log(alpha_prev, x, e_func, weights, temp)
115 |     prob_log_new_neg = bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)
116 |
117 |     chain_weights = -prob_log_new_neg + prob_log_old_neg
118 |     # chain_weights = tf.Print(chain_weights, [chain_weights])
119 |
120 |     # Sample new x using HMC
121 |     def unorm_prob(x):
122 |         return bridge_prob_neg_log(alpha_new, x, e_func, weights, temp)
123 |
124 |     for j in range(1):
125 |         x = hmc(x, approx_lr, hmc_step, unorm_prob)
126 |
127 |     return chain_weights, alpha_prev, alpha_new, x, x_init, approx_lr
128 |
129 |
130 | def main():
131 |
132 |     # Initialize dataset
133 |     if FLAGS.dataset == 'cifar10':
134 |         dataset = Cifar10(train=False, rescale=FLAGS.rescale)
135 |         channel_num = 3
136 |         dim_input = 32 * 32 * 3
137 |     elif FLAGS.dataset == 'imagenet':
138 |         dataset = ImagenetClass()
139 |         channel_num = 3
140 |         dim_input = 64 * 64 * 3
141 |     elif FLAGS.dataset == 'mnist':
142 |         dataset = Mnist(train=False, rescale=FLAGS.rescale)
143 |         channel_num = 1
144 |         dim_input = 28 * 28 * 1
145 |     elif FLAGS.dataset == 'dsprites':
146 |         dataset = DSprites()
147 |         channel_num = 1
148 |         dim_input = 64 * 64 * 1
149 |     elif FLAGS.dataset == '2d' or FLAGS.dataset == 'gauss':
150 |         dataset = Box2D()
151 |
152 |     dim_output = 1
153 |     data_loader = DataLoader(dataset, batch_size=FLAGS.batch_size, num_workers=FLAGS.data_workers, drop_last=False, shuffle=True)
154 |
155 |     if FLAGS.dataset == 'mnist':
156 |         model = MnistNet(num_channels=channel_num)
157 |     elif FLAGS.dataset == 'cifar10':
158 |         if FLAGS.large_model:
159 |             model = ResNet32Large(num_filters=128)
160 |         elif FLAGS.wider_model:
161 |             model = ResNet32Wider(num_filters=192)
162 |         else:
163 |             model = ResNet32(num_channels=channel_num, num_filters=128)
164 |     elif FLAGS.dataset == 'dsprites':
165 |         model = DspritesNet(num_channels=channel_num, num_filters=FLAGS.num_filters)
166 |
167 |     weights = model.construct_weights('context_{}'.format(0))
168 |
169 |     config = tf.ConfigProto()
170 |     sess = tf.Session(config=config)
171 |     saver = loader = tf.train.Saver(max_to_keep=10)
172 |
173 |     sess.run(tf.global_variables_initializer())
174 |     logdir = osp.join(FLAGS.logdir, FLAGS.exp)
175 |
176 |     model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
177 |     resume_itr = FLAGS.resume_iter
178 |
179 |     if FLAGS.resume_iter != "-1":
180 |         optimistic_restore(sess, model_file)
181 |     else:
182 |         print("WARNING, YOU ARE NOT LOADING A SAVE FILE")
183 |     # saver.restore(sess, model_file)
184 |
185 |     chain_weights, a_prev, a_new, x, x_init, approx_lr = ancestral_sample(model, weights, FLAGS.batch_size, temp=FLAGS.temperature)
186 |     print("Finished constructing ancestral sample ...................")
187 |
188 |     if FLAGS.dataset != "gauss":
189 |         comb_weights_cum = []
190 |         batch_size = tf.shape(x_init)[0]
191 |         label_tiled = tf.tile(label_default, (batch_size, 1))
192 |         e_compute = -FLAGS.temperature * model.forward(x_init, weights, label=label_tiled)
193 |         e_pos_list = []
194 |
195 |         for data_corrupt, data, label_gt in tqdm(data_loader):
196 |             e_pos = sess.run([e_compute], {x_init: data})[0]
197 |             e_pos_list.extend(list(e_pos))
198 |
199 |         print(len(e_pos_list))
200 |         print("Positive sample probability ", np.mean(e_pos_list), np.std(e_pos_list))
201 |
202 |     if FLAGS.dataset == "2d":
203 |         alr = 0.0045
204 |     elif FLAGS.dataset == "gauss":
205 |         alr = 0.0085
206 |     elif FLAGS.dataset == "mnist":
207 |         alr = 0.0065
208 |         #90 alr = 0.0035
209 |     else:
210 |         # alr = 0.0125
211 |         if FLAGS.rescale == 8:
212 |             alr = 0.0085
213 |         else:
214 |             alr = 0.0045
215 |     #
216 |     for i in range(1):
217 |         tot_weight = 0
218 |         for j in tqdm(range(1, FLAGS.pdist+1)):
219 |             if j == 1:
220 |                 if FLAGS.dataset == "cifar10":
221 |                     x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 32, 32, 3))
222 |                 elif FLAGS.dataset == "gauss":
223 |                     x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, FLAGS.gauss_dim))
224 |                 elif FLAGS.dataset == "mnist":
225 |                     x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 28, 28))
226 |                 else:
227 |                     x_curr = np.random.uniform(0, FLAGS.rescale, size=(FLAGS.batch_size, 2))
228 |
229 |             alpha_prev = (j-1) / FLAGS.pdist
230 |             alpha_new = j / FLAGS.pdist
231 |             cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
232 |             tot_weight = tot_weight + cweight
233 |
234 |         print("Total values of lower value based off forward sampling", np.mean(tot_weight), np.std(tot_weight))
235 |
236 |         tot_weight = 0
237 |
238 |         for j in tqdm(range(FLAGS.pdist, 0, -1)):
239 |             alpha_new = (j-1) / FLAGS.pdist
240 |             alpha_prev = j / FLAGS.pdist
241 |             cweight, x_curr = sess.run([chain_weights, x], {a_prev: alpha_prev, a_new: alpha_new, x_init: x_curr, approx_lr: alr * (5 ** (2.5*-alpha_prev))})
242 |             tot_weight = tot_weight - cweight
243 |
244 |         print("Total values of upper value based off backward sampling", np.mean(tot_weight), np.std(tot_weight))
245 |
246 |
247 |
248 | if __name__ == "__main__":
249 |     main()
250 |
--------------------------------------------------------------------------------
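`ais.py` estimates the model's log-partition function with annealed importance sampling (AIS). It anneals from a uniform distribution to the model through the bridge `(1 - alpha) * log p_uniform + alpha * log p_model`, accumulates the per-step weight increments (`chain_weights = -prob_log_new_neg + prob_log_old_neg`), and moves samples with HMC between steps. The sketch below runs the same forward pass on a 1-D toy target whose normalizer is known in closed form. Assumptions: a random-walk Metropolis kernel stands in for the HMC kernel in `hmc.py`, the target is a Gaussian bump on [0, 1] rather than an EBM, and all names are illustrative:

```python
import numpy as np

SIGMA = 0.1  # width of the toy target; its log-normalizer on [0, 1] is ~log(SIGMA * sqrt(2 * pi))


def log_bridge(x, alpha, sigma=SIGMA):
    """Unnormalized log-density (1 - alpha) * log p_uniform + alpha * log f_target.

    p_uniform is uniform on [0, 1] (log density 0 there); f_target is a Gaussian bump at 0.5.
    Points outside [0, 1] get -inf, playing the role of the out-of-box penalty in ais.py.
    """
    energy = 0.5 * ((x - 0.5) / sigma) ** 2
    return np.where((x >= 0.0) & (x <= 1.0), -alpha * energy, -np.inf)


def ais_log_z(n_chains=2000, n_temps=100, mh_steps=5, step=0.05, seed=0):
    """Forward AIS from uniform to the toy target.

    Returns (log-mean-exp of the weights, mean log weight). The mean log weight is what
    ais.py prints for its forward pass; in expectation it is a lower bound on log Z.
    """
    rng = np.random.default_rng(seed)
    x = rng.uniform(0.0, 1.0, size=n_chains)  # exact samples from the base distribution
    log_w = np.zeros(n_chains)
    alphas = np.linspace(0.0, 1.0, n_temps + 1)
    for a_prev, a_new in zip(alphas[:-1], alphas[1:]):
        # Weight increment: log f_new(x) - log f_prev(x), evaluated before moving x
        log_w += log_bridge(x, a_new) - log_bridge(x, a_prev)
        # Metropolis moves that leave the new bridge distribution invariant
        for _ in range(mh_steps):
            proposal = x + step * rng.standard_normal(n_chains)
            log_accept = log_bridge(proposal, a_new) - log_bridge(x, a_new)
            accept = np.log(rng.uniform(size=n_chains)) < log_accept
            x = np.where(accept, proposal, x)
    shift = log_w.max()
    log_mean_exp = shift + np.log(np.mean(np.exp(log_w - shift)))
    return log_mean_exp, log_w.mean()


if __name__ == "__main__":
    estimate, lower = ais_log_z()
    print(estimate, lower, np.log(SIGMA * np.sqrt(2 * np.pi)))
```

Because the uniform base has normalizer 1, the log-mean-exp estimates log Z of the target directly; the gap between it and the mean log weight shrinks as more intermediate distributions (`--pdist` in `ais.py`) are used.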
/EBMs/custom_adam.py:
--------------------------------------------------------------------------------
1 | # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Adam for TensorFlow."""
17 | from __future__ import absolute_import
18 | from __future__ import division
19 | from __future__ import print_function
20 |
21 | from tensorflow.python.eager import context
22 | from tensorflow.python.framework import ops
23 | from tensorflow.python.ops import control_flow_ops
24 | from tensorflow.python.ops import math_ops
25 | from tensorflow.python.ops import resource_variable_ops
26 | from tensorflow.python.ops import state_ops
27 | from tensorflow.python.training import optimizer
28 | from tensorflow.python.training import training_ops
29 | from tensorflow.python.util.tf_export import tf_export
30 | import tensorflow as tf
31 |
32 |
33 | @tf_export("train.AdamOptimizer")
34 | class AdamOptimizer(optimizer.Optimizer):
35 |   """Optimizer that implements the Adam algorithm.
36 |
37 |   See [Kingma et al., 2014](http://arxiv.org/abs/1412.6980)
38 |   ([pdf](http://arxiv.org/pdf/1412.6980.pdf)).
39 |   """
40 |
41 |   def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-8,
42 |                use_locking=False, name="Adam"):
43 |     r"""Construct a new Adam optimizer.
44 |
45 |     Initialization:
46 |
47 |     $$m_0 := 0 \text{(Initialize initial 1st moment vector)}$$
48 |     $$v_0 := 0 \text{(Initialize initial 2nd moment vector)}$$
49 |     $$t := 0 \text{(Initialize timestep)}$$
50 |
51 |     The update rule for `variable` with gradient `g` uses an optimization
52 |     described at the end of section 2 of the paper:
53 |
54 |     $$t := t + 1$$
55 |     $$lr_t := \text{learning\_rate} * \sqrt{1 - beta_2^t} / (1 - beta_1^t)$$
56 |
57 |     $$m_t := beta_1 * m_{t-1} + (1 - beta_1) * g$$
58 |     $$v_t := beta_2 * v_{t-1} + (1 - beta_2) * g * g$$
59 |     $$variable := variable - lr_t * m_t / (\sqrt{v_t} + \epsilon)$$
60 |
61 |     The default value of 1e-8 for epsilon might not be a good default in
62 |     general. For example, when training an Inception network on ImageNet a
63 |     current good choice is 1.0 or 0.1. Note that since AdamOptimizer uses the
64 |     formulation just before Section 2.1 of the Kingma and Ba paper rather than
65 |     the formulation in Algorithm 1, the "epsilon" referred to here is "epsilon
66 |     hat" in the paper.
67 |
68 |     The sparse implementation of this algorithm (used when the gradient is an
69 |     IndexedSlices object, typically because of `tf.gather` or an embedding
70 |     lookup in the forward pass) does apply momentum to variable slices even if
71 |     they were not used in the forward pass (meaning they have a gradient equal
72 |     to zero). Momentum decay (beta1) is also applied to the entire momentum
73 |     accumulator. This means that the sparse behavior is equivalent to the dense
74 |     behavior (in contrast to some momentum implementations which ignore momentum
75 |     unless a variable slice was actually used).
76 |
77 |     Args:
78 |       learning_rate: A Tensor or a floating point value. The learning rate.
79 |       beta1: A float value or a constant float tensor.
80 |         The exponential decay rate for the 1st moment estimates.
81 |       beta2: A float value or a constant float tensor.
82 |         The exponential decay rate for the 2nd moment estimates.
83 |       epsilon: A small constant for numerical stability. This epsilon is
84 |         "epsilon hat" in the Kingma and Ba paper (in the formula just before
85 |         Section 2.1), not the epsilon in Algorithm 1 of the paper.
86 |       use_locking: If True use locks for update operations.
87 |       name: Optional name for the operations created when applying gradients.
88 |         Defaults to "Adam".
89 |
90 |     @compatibility(eager)
91 |     When eager execution is enabled, `learning_rate`, `beta1`, `beta2`, and
92 |     `epsilon` can each be a callable that takes no arguments and returns the
93 |     actual value to use. This can be useful for changing these values across
94 |     different invocations of optimizer functions.
95 |     @end_compatibility
96 |     """
97 |     super(AdamOptimizer, self).__init__(use_locking, name)
98 |     self._lr = learning_rate
99 |     self._beta1 = beta1
100 |     self._beta2 = beta2
101 |     self._epsilon = epsilon
102 |
103 |     # Tensor versions of the constructor arguments, created in _prepare().
104 |     self._lr_t = None
105 |     self._beta1_t = None
106 |     self._beta2_t = None
107 |     self._epsilon_t = None
108 |
109 |     # Created in SparseApply if needed.
110 |     self._updated_lr = None
111 |
112 |   def _get_beta_accumulators(self):
113 |     with ops.init_scope():
114 |       if context.executing_eagerly():
115 |         graph = None
116 |       else:
117 |         graph = ops.get_default_graph()
118 |       return (self._get_non_slot_variable("beta1_power", graph=graph),
119 |               self._get_non_slot_variable("beta2_power", graph=graph))
120 |
121 |   def _create_slots(self, var_list):
122 |     # Create the beta1 and beta2 accumulators on the same device as the first
123 |     # variable. Sort the var_list to make sure this device is consistent across
124 |     # workers (these need to go on the same PS, otherwise some updates are
125 |     # silently ignored).
126 |     first_var = min(var_list, key=lambda x: x.name)
127 |     self._create_non_slot_variable(initial_value=self._beta1,
128 |                                    name="beta1_power",
129 |                                    colocate_with=first_var)
130 |     self._create_non_slot_variable(initial_value=self._beta2,
131 |                                    name="beta2_power",
132 |                                    colocate_with=first_var)
133 |
134 |     # Create slots for the first and second moments.
135 |     for v in var_list:
136 |       self._zeros_slot(v, "m", self._name)
137 |       self._zeros_slot(v, "v", self._name)
138 |
139 |   def _prepare(self):
140 |     lr = self._call_if_callable(self._lr)
141 |     beta1 = self._call_if_callable(self._beta1)
142 |     beta2 = self._call_if_callable(self._beta2)
143 |     epsilon = self._call_if_callable(self._epsilon)
144 |
145 |     self._lr_t = ops.convert_to_tensor(lr, name="learning_rate")
146 |     self._beta1_t = ops.convert_to_tensor(beta1, name="beta1")
147 |     self._beta2_t = ops.convert_to_tensor(beta2, name="beta2")
148 |     self._epsilon_t = ops.convert_to_tensor(epsilon, name="epsilon")
149 |
150 |   def _apply_dense(self, grad, var):
151 |     m = self.get_slot(var, "m")
152 |     v = self.get_slot(var, "v")
153 |     beta1_power, beta2_power = self._get_beta_accumulators()
154 |
155 |     # Clip gradients to ~3 std of the bias-corrected 2nd moment (+0.1 floor); dense path only
156 |     clip_bounds = 3 * tf.sqrt(v / (1 - beta2_power)) + 0.1
157 |     grad = tf.clip_by_value(grad, -clip_bounds, clip_bounds)
158 |     return training_ops.apply_adam(
159 |         var, m, v,
160 |         math_ops.cast(beta1_power, var.dtype.base_dtype),
161 |         math_ops.cast(beta2_power, var.dtype.base_dtype),
162 |         math_ops.cast(self._lr_t, var.dtype.base_dtype),
163 |         math_ops.cast(self._beta1_t, var.dtype.base_dtype),
164 |         math_ops.cast(self._beta2_t, var.dtype.base_dtype),
165 |         math_ops.cast(self._epsilon_t, var.dtype.base_dtype),
166 |         grad, use_locking=self._use_locking).op
167 |
168 |   def _resource_apply_dense(self, grad, var):
169 |     m = self.get_slot(var, "m")
170 |     v = self.get_slot(var, "v")
171 |     beta1_power, beta2_power = self._get_beta_accumulators()
172 |     return training_ops.resource_apply_adam(
173 |         var.handle, m.handle, v.handle,
174 |         math_ops.cast(beta1_power, grad.dtype.base_dtype),
175 |         math_ops.cast(beta2_power, grad.dtype.base_dtype),
176 |         math_ops.cast(self._lr_t, grad.dtype.base_dtype),
177 |         math_ops.cast(self._beta1_t, grad.dtype.base_dtype),
178 |         math_ops.cast(self._beta2_t, grad.dtype.base_dtype),
179 |         math_ops.cast(self._epsilon_t, grad.dtype.base_dtype),
180 |         grad, use_locking=self._use_locking)
181 |
182 |   def _apply_sparse_shared(self, grad, var, indices, scatter_add):
183 |     beta1_power, beta2_power = self._get_beta_accumulators()
184 |     beta1_power = math_ops.cast(beta1_power, var.dtype.base_dtype)
185 |     beta2_power = math_ops.cast(beta2_power, var.dtype.base_dtype)
186 |     lr_t = math_ops.cast(self._lr_t, var.dtype.base_dtype)
187 |     beta1_t = math_ops.cast(self._beta1_t, var.dtype.base_dtype)
188 |     beta2_t = math_ops.cast(self._beta2_t, var.dtype.base_dtype)
189 |     epsilon_t = math_ops.cast(self._epsilon_t, var.dtype.base_dtype)
190 |     lr = (lr_t * math_ops.sqrt(1 - beta2_power) / (1 - beta1_power))
191 |     # m_t = beta1 * m + (1 - beta1) * g_t
192 |     m = self.get_slot(var, "m")
193 |     m_scaled_g_values = grad * (1 - beta1_t)
194 |     m_t = state_ops.assign(m, m * beta1_t,
195 |                            use_locking=self._use_locking)
196 |     with ops.control_dependencies([m_t]):
197 |       m_t = scatter_add(m, indices, m_scaled_g_values)
198 |     # v_t = beta2 * v + (1 - beta2) * (g_t * g_t)
199 |     v = self.get_slot(var, "v")
200 |     v_scaled_g_values = (grad * grad) * (1 - beta2_t)
201 |     v_t = state_ops.assign(v, v * beta2_t, use_locking=self._use_locking)
202 |     with ops.control_dependencies([v_t]):
203 |       v_t = scatter_add(v, indices, v_scaled_g_values)
204 |     v_sqrt = math_ops.sqrt(v_t)
205 |     var_update = state_ops.assign_sub(var,
206 |                                       lr * m_t / (v_sqrt + epsilon_t),
207 |                                       use_locking=self._use_locking)
208 |     return control_flow_ops.group(*[var_update, m_t, v_t])
209 |
210 |   def _apply_sparse(self, grad, var):
211 |     return self._apply_sparse_shared(
212 |         grad.values, var, grad.indices,
213 |         lambda x, i, v: state_ops.scatter_add(  # pylint: disable=g-long-lambda
214 |             x, i, v, use_locking=self._use_locking))
215 |
216 |   def _resource_scatter_add(self, x, i, v):
217 |     with ops.control_dependencies(
218 |         [resource_variable_ops.resource_scatter_add(
219 |             x.handle, i, v)]):
220 |       return x.value()
221 |
222 |   def _resource_apply_sparse(self, grad, var, indices):
223 |     return self._apply_sparse_shared(
224 |         grad, var, indices, self._resource_scatter_add)
225 |
226 |   def _finish(self, update_ops, name_scope):
227 |     # Update the power accumulators.
228 |     with ops.control_dependencies(update_ops):
229 |       beta1_power, beta2_power = self._get_beta_accumulators()
230 |       with ops.colocate_with(beta1_power):
231 |         update_beta1 = beta1_power.assign(
232 |             beta1_power * self._beta1_t, use_locking=self._use_locking)
233 |         update_beta2 = beta2_power.assign(
234 |             beta2_power * self._beta2_t, use_locking=self._use_locking)
235 |     return control_flow_ops.group(*update_ops + [update_beta1, update_beta2],
236 |                                   name=name_scope)
237 |
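For reference, the sparse path above first decays the entire `m` and `v` slots. It then scatter-adds new gradient contributions only at `indices`. The final `assign_sub` still updates every row of `var`, so rows that receive no gradient in this step keep moving on their stored momentum. The sketch below reproduces one such step in NumPy. It assumes dense arrays and a 1-based step counter `t`, and the function name and defaults are illustrative, not part of this repository:

```python
import numpy as np

def sparse_adam_step(var, m, v, grad_values, indices, t,
                     lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step with a sparse gradient; updates var, m, v in place (t is 1-based)."""
    # Bias-corrected step size, matching `lr_t * sqrt(1 - beta2^t) / (1 - beta1^t)`.
    lr_t = lr * np.sqrt(1 - beta2 ** t) / (1 - beta1 ** t)
    # Decay every row of the first moment, then scatter-add the new contributions.
    m *= beta1
    np.add.at(m, indices, (1 - beta1) * grad_values)  # add.at accumulates duplicate indices
    # Same pattern for the second moment, using squared gradients.
    v *= beta2
    np.add.at(v, indices, (1 - beta2) * grad_values ** 2)
    # Dense update of the whole variable, as in assign_sub(var, lr * m_t / (sqrt(v_t) + eps)).
    var -= lr_t * m / (np.sqrt(v) + eps)
    return var, m, v
```

On the first step with a unit gradient, the touched row moves by roughly `lr`. This is the usual Adam behaviour once bias correction is applied.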
--------------------------------------------------------------------------------
/EBMs/data.py:
--------------------------------------------------------------------------------
1 | from tensorflow.python.platform import flags
2 | from tensorflow.contrib.data.python.ops import batching, threadpool
3 | import tensorflow as tf
4 | import json
5 | from torch.utils.data import Dataset
6 | import pickle
7 | import os.path as osp
8 | import os
9 | import numpy as np
10 | import time
11 | from scipy.misc import imread, imresize
12 | from skimage.color import rgb2grey
13 | from torchvision.datasets import CIFAR10, MNIST, SVHN, CIFAR100, ImageFolder
14 | from torchvision import transforms
15 | from imagenet_preprocessing import ImagenetPreprocessor
16 | import torch
17 | import torchvision
18 |
19 | FLAGS = flags.FLAGS
20 | ROOT_DIR = "./results"
21 |
22 | # Dataset Options
23 | flags.DEFINE_string('dsprites_path',
24 | '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz',
25 | 'path to dsprites characters')
26 | flags.DEFINE_string('imagenet_datadir', '/root/imagenet_big', 'directory containing the ImageNet TFRecord shards and index.json')
27 | flags.DEFINE_bool('dshape_only', False, 'fix all factors except for shapes')
28 | flags.DEFINE_bool('dpos_only', False, 'fix all factors except for positions of shapes')
29 | flags.DEFINE_bool('dsize_only', False,'fix all factors except for size of objects')
30 | flags.DEFINE_bool('drot_only', False, 'fix all factors except for rotation of objects')
31 | flags.DEFINE_bool('dsprites_restrict', False, 'restrict data to shape index 1 with zero rotation')
32 | flags.DEFINE_string('imagenet_path', '/root/imagenet', 'path to imagenet images')
33 |
34 |
35 | # Data augmentation options
36 | flags.DEFINE_bool('cutout_inside', False,'whether the cutout mask must lie entirely inside the image')
37 | flags.DEFINE_float('cutout_prob', 1.0, 'probability of using cutout')
38 | flags.DEFINE_integer('cutout_mask_size', 16, 'size of cutout')
39 | flags.DEFINE_bool('cutout', False,'whether to add cutout regularizer to data')
40 |
41 |
42 | def cutout(mask_color=(0, 0, 0)):
43 | mask_size_half = FLAGS.cutout_mask_size // 2
44 | offset = 1 if FLAGS.cutout_mask_size % 2 == 0 else 0
45 |
46 | def _cutout(image):
47 | image = np.asarray(image).copy()
48 |
49 | if np.random.random() > FLAGS.cutout_prob:
50 | return torch.from_numpy(image)
51 |
52 | h, w = image.shape[-2:]  # image is (C, H, W) because cutout runs after ToTensor
53 |
54 | if FLAGS.cutout_inside:
55 | cxmin, cxmax = mask_size_half, w + offset - mask_size_half
56 | cymin, cymax = mask_size_half, h + offset - mask_size_half
57 | else:
58 | cxmin, cxmax = 0, w + offset
59 | cymin, cymax = 0, h + offset
60 |
61 | cx = np.random.randint(cxmin, cxmax)
62 | cy = np.random.randint(cymin, cymax)
63 | xmin = cx - mask_size_half
64 | ymin = cy - mask_size_half
65 | xmax = xmin + FLAGS.cutout_mask_size
66 | ymax = ymin + FLAGS.cutout_mask_size
67 | xmin = max(0, xmin)
68 | ymin = max(0, ymin)
69 | xmax = min(w, xmax)
70 | ymax = min(h, ymax)
71 | image[:, ymin:ymax, xmin:xmax] = np.array(mask_color)[:, None, None]
72 | return torch.from_numpy(image)  # return a tensor so the caller's .numpy() call still works
73 |
74 | return _cutout
75 |
76 |
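The `cutout` transform above zeroes a square patch of a channel-first image. The center of the patch is drawn either anywhere in the image or only where the whole square fits. The following is a standalone NumPy sketch of the same idea. It assumes a `(C, H, W)` array and a zero mask color, and `cutout_chw` is a hypothetical helper rather than a function from this repository:

```python
import numpy as np

def cutout_chw(image, mask_size=16, inside=False, rng=None):
    """Return a copy of a (C, H, W) array with one mask_size x mask_size square zeroed."""
    rng = np.random.default_rng() if rng is None else rng
    image = np.array(image, copy=True)
    _, h, w = image.shape
    half = mask_size // 2
    offset = 1 if mask_size % 2 == 0 else 0
    if inside:
        # Restrict the center so the whole square stays inside the image.
        cx = rng.integers(half, w + offset - half)
        cy = rng.integers(half, h + offset - half)
    else:
        # Any center is allowed; the square is clipped at the borders below.
        cx = rng.integers(0, w + offset)
        cy = rng.integers(0, h + offset)
    xmin, ymin = max(0, cx - half), max(0, cy - half)
    xmax, ymax = min(w, cx - half + mask_size), min(h, cy - half + mask_size)
    image[:, ymin:ymax, xmin:xmax] = 0.0
    return image
```

When `inside=True` and the mask size is even, exactly `mask_size * mask_size` pixels per channel are zeroed.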
77 | class TFImagenetLoader(Dataset):
78 |
79 | def __init__(self, split, batchsize, idx, num_workers, rescale=1):
80 | IMAGENET_NUM_TRAIN_IMAGES = 1281167
81 | IMAGENET_NUM_VAL_IMAGES = 50000
82 |
83 | self.rescale = rescale
84 |
85 | if split == "train":
86 | im_length = IMAGENET_NUM_TRAIN_IMAGES
87 | records_to_skip = im_length * idx // num_workers
88 | records_to_read = im_length * (idx + 1) // num_workers - records_to_skip
89 | else:
90 | im_length = IMAGENET_NUM_VAL_IMAGES
91 |
92 | self.curr_sample = 0
93 |
94 | index_path = osp.join(FLAGS.imagenet_datadir, 'index.json')
95 | with open(index_path) as f:
96 | metadata = json.load(f)
97 | counts = metadata['record_counts']
98 |
99 | if split == 'train':
100 | file_names = list(sorted([x for x in counts.keys() if x.startswith('train')]))
101 |
102 | result_records_to_skip = None
103 | files = []
104 | for filename in file_names:
105 | records_in_file = counts[filename]
106 | if records_to_skip >= records_in_file:
107 | records_to_skip -= records_in_file
108 | continue
109 | elif records_to_read > 0:
110 | if result_records_to_skip is None:
111 | # Record the number to skip in the first file
112 | result_records_to_skip = records_to_skip
113 | files.append(filename)
114 | records_to_read -= (records_in_file - records_to_skip)
115 | records_to_skip = 0
116 | else:
117 | break
118 | else:
119 | files = list(sorted([x for x in counts.keys() if x.startswith('validation')]))
120 |
121 | files = [osp.join(FLAGS.imagenet_datadir, x) for x in files]
122 | preprocess_function = ImagenetPreprocessor(128, dtype=tf.float32, train=False).parse_and_preprocess
123 |
124 | ds = tf.data.TFRecordDataset.from_generator(lambda: files, output_types=tf.string)
125 | ds = ds.apply(tf.data.TFRecordDataset)
126 | ds = ds.take(im_length)
127 | ds = ds.prefetch(buffer_size=FLAGS.batch_size)
128 | ds = ds.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=10000))
129 | ds = ds.apply(batching.map_and_batch(map_func=preprocess_function, batch_size=FLAGS.batch_size, num_parallel_batches=4))
130 | ds = ds.prefetch(buffer_size=2)
131 |
132 | ds_iterator = ds.make_initializable_iterator()
133 | labels, images = ds_iterator.get_next()
134 | self.images = tf.clip_by_value(images / 256 + tf.random_uniform(tf.shape(images), 0, 1. / 256), 0.0, 1.0)
135 | self.labels = labels
136 |
137 | config = tf.ConfigProto(device_count = {'GPU': 0})
138 | sess = tf.Session(config=config)
139 | sess.run(ds_iterator.initializer)
140 |
141 | self.im_length = im_length // batchsize
142 |
143 | self.sess = sess
144 |
145 | def __next__(self):
146 | self.curr_sample += 1
147 |
148 | sess = self.sess
149 |
150 | im_corrupt = np.random.uniform(0, self.rescale, size=(FLAGS.batch_size, 128, 128, 3))
151 | label, im = sess.run([self.labels, self.images])
152 | im = im * self.rescale
153 | label = np.eye(1000)[label.squeeze() - 1]
154 | im, im_corrupt, label = torch.from_numpy(im), torch.from_numpy(im_corrupt), torch.from_numpy(label)
155 | return im_corrupt, im, label
156 |
157 | def __iter__(self):
158 | return self
159 |
160 | def __len__(self):
161 | return self.im_length
162 |
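In the training branch of `TFImagenetLoader`, worker `idx` of `num_workers` claims a contiguous range of records. The loader walks the sorted shard files to find the ones that overlap that range. In the lines shown, `result_records_to_skip` is computed but never passed to a `skip` on the dataset. Each worker therefore starts reading at the beginning of its first file. The following is a pure-Python sketch of the file-selection logic. It assumes a `counts` dict shaped like `index.json['record_counts']`, and `shard_files` is an illustrative name:

```python
def shard_files(counts, idx, num_workers, total):
    """Return (files, records_to_skip_in_first_file) for worker idx."""
    records_to_skip = total * idx // num_workers
    records_to_read = total * (idx + 1) // num_workers - records_to_skip
    first_skip = None
    files = []
    for name in sorted(k for k in counts if k.startswith("train")):
        n = counts[name]
        if records_to_skip >= n:
            # This whole file lies before the worker's range.
            records_to_skip -= n
            continue
        if records_to_read <= 0:
            break
        if first_skip is None:
            # Offset inside the first file this worker touches.
            first_skip = records_to_skip
        files.append(name)
        records_to_read -= n - records_to_skip
        records_to_skip = 0
    return files, first_skip
```

With three files of 10 records and two workers, the second worker gets the last two files. It should skip 5 records in the first of them.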
163 | class CelebA(Dataset):
164 |
165 | def __init__(self):
166 | self.path = "/root/data/img_align_celeba"
167 | self.ims = os.listdir(self.path)
168 | self.ims = [osp.join(self.path, im) for im in self.ims]
169 |
170 | def __len__(self):
171 | return len(self.ims)
172 |
173 | def __getitem__(self, index):
174 | label = 1
175 |
176 | if FLAGS.single:
177 | index = 0
178 |
179 | path = self.ims[index]
180 | im = imread(path)
181 | im = imresize(im, (32, 32))
182 | image_size = 32
183 | im = im / 255.
184 |
185 | if FLAGS.datasource == 'default':
186 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
187 | elif FLAGS.datasource == 'random':
188 | im_corrupt = np.random.uniform(
189 | 0, 1, size=(image_size, image_size, 3))
190 |
191 | return im_corrupt, im, label
192 |
193 |
194 | class Cifar10(Dataset):
195 | def __init__(
196 | self,
197 | train=True,
198 | full=False,
199 | augment=False,
200 | noise=True,
201 | rescale=1.0):
202 |
203 | if augment:
204 | transform_list = [
205 | torchvision.transforms.RandomCrop(32, padding=4),
206 | torchvision.transforms.RandomHorizontalFlip(),
207 | torchvision.transforms.ToTensor(),
208 | ]
209 |
210 | if FLAGS.cutout:
211 | transform_list.append(cutout())
212 |
213 | transform = transforms.Compose(transform_list)
214 | else:
215 | transform = transforms.ToTensor()
216 |
217 | self.full = full
218 | self.data = CIFAR10(
219 | ROOT_DIR,
220 | transform=transform,
221 | train=train,
222 | download=True)
223 | self.test_data = CIFAR10(
224 | ROOT_DIR,
225 | transform=transform,
226 | train=False,
227 | download=True)
228 | self.one_hot_map = np.eye(10)
229 | self.noise = noise
230 | self.rescale = rescale
231 |
232 | def __len__(self):
233 |
234 | if self.full:
235 | return len(self.data) + len(self.test_data)
236 | else:
237 | return len(self.data)
238 |
239 | def __getitem__(self, index):
240 | if not FLAGS.single:
241 | if self.full:
242 | if index >= len(self.data):
243 | im, label = self.test_data[index - len(self.data)]
244 | else:
245 | im, label = self.data[index]
246 | else:
247 | im, label = self.data[index]
248 | else:
249 | im, label = self.data[0]
250 |
251 | im = np.transpose(im, (1, 2, 0)).numpy()
252 | image_size = 32
253 | label = self.one_hot_map[label]
254 |
255 | im = im * 255 / 256
256 |
257 | if self.noise:
258 | im = im * self.rescale + \
259 | np.random.uniform(0, self.rescale * 1 / 256., im.shape)
260 |
261 | np.random.seed((index + int(time.time() * 1e7)) % 2**32)
262 |
263 | if FLAGS.datasource == 'default':
264 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
265 | elif FLAGS.datasource == 'random':
266 | im_corrupt = np.random.uniform(
267 | 0.0, self.rescale, (image_size, image_size, 3))
268 |
269 | return im_corrupt, im, label
270 |
271 |
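`Cifar10.__getitem__` maps each 8-bit pixel level `k` from `k/255` to `k/256`. It then adds uniform noise of width `1/256`, so every level occupies its own bin in `[0, 1)` and the data becomes continuous. The corrupted starting point for sampling is either Gaussian noise around the image (`default`) or pure uniform noise (`random`). A minimal sketch of that pipeline follows. It assumes the `noise=True` path, and `dequantize_and_corrupt` is a hypothetical name:

```python
import numpy as np

def dequantize_and_corrupt(im, rescale=1.0, datasource="default", rng=None):
    """im holds values k/255 (as produced by ToTensor); returns (im_corrupt, im)."""
    rng = np.random.default_rng() if rng is None else rng
    # k/255 -> k/256, then jitter within the level's bin of width 1/256.
    im = im * 255 / 256
    im = im * rescale + rng.uniform(0, rescale / 256, im.shape)
    if datasource == "default":
        # Start sampling from a noisy copy of the real image.
        im_corrupt = im + 0.3 * rng.standard_normal(im.shape)
    elif datasource == "random":
        # Start sampling from uniform noise.
        im_corrupt = rng.uniform(0.0, rescale, im.shape)
    else:
        raise ValueError("unknown datasource: {}".format(datasource))
    return im_corrupt, im
```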
272 | class Cifar100(Dataset):
273 | def __init__(self, train=True, augment=False):
274 |
275 | if augment:
276 | transform_list = [
277 | torchvision.transforms.RandomCrop(32, padding=4),
278 | torchvision.transforms.RandomHorizontalFlip(),
279 | torchvision.transforms.ToTensor(),
280 | ]
281 |
282 | if FLAGS.cutout:
283 | transform_list.append(cutout())
284 |
285 | transform = transforms.Compose(transform_list)
286 | else:
287 | transform = transforms.ToTensor()
288 |
289 | self.data = CIFAR100(
290 | "/root/cifar100",
291 | transform=transform,
292 | train=train,
293 | download=True)
294 | self.one_hot_map = np.eye(100)
295 |
296 | def __len__(self):
297 | return len(self.data)
298 |
299 | def __getitem__(self, index):
300 | if not FLAGS.single:
301 | im, label = self.data[index]
302 | else:
303 | im, label = self.data[0]
304 |
305 | im = np.transpose(im, (1, 2, 0)).numpy()
306 | image_size = 32
307 | label = self.one_hot_map[label]
308 | im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
309 | np.random.seed((index + int(time.time() * 1e7)) % 2**32)
310 |
311 | if FLAGS.datasource == 'default':
312 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
313 | elif FLAGS.datasource == 'random':
314 | im_corrupt = np.random.uniform(
315 | 0.0, 1.0, (image_size, image_size, 3))
316 |
317 | return im_corrupt, im, label
318 |
319 |
320 | class Svhn(Dataset):
321 | def __init__(self, train=True, augment=False):
322 |
323 | transform = transforms.ToTensor()
324 |
325 | self.data = SVHN("/root/svhn", split="train" if train else "test", transform=transform, download=True)
326 | self.one_hot_map = np.eye(10)
327 |
328 | def __len__(self):
329 | return len(self.data)
330 |
331 | def __getitem__(self, index):
332 | if not FLAGS.single:
333 | im, label = self.data[index]
334 | else:
335 | im, label = self.data[0]
336 |
337 | im = np.transpose(im, (1, 2, 0)).numpy()
338 | image_size = 32
339 | label = self.one_hot_map[label]
340 | im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
341 | np.random.seed((index + int(time.time() * 1e7)) % 2**32)
342 |
343 | if FLAGS.datasource == 'default':
344 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
345 | elif FLAGS.datasource == 'random':
346 | im_corrupt = np.random.uniform(
347 | 0.0, 1.0, (image_size, image_size, 3))
348 |
349 | return im_corrupt, im, label
350 |
351 |
352 | class Mnist(Dataset):
353 | def __init__(self, train=True, rescale=1.0):
354 | self.data = MNIST(
355 | "/root/mnist",
356 | transform=transforms.ToTensor(),
357 | download=True, train=train)
358 | self.labels = np.eye(10)
359 | self.rescale = rescale
360 |
361 | def __len__(self):
362 | return len(self.data)
363 |
364 | def __getitem__(self, index):
365 | im, label = self.data[index]
366 | label = self.labels[label]
367 | im = im.squeeze()
368 | # im = im.numpy() / 2 + np.random.uniform(0, 0.5, (28, 28))
369 | # im = im.numpy() / 2 + 0.2
370 | im = im.numpy() / 256 * 255 + np.random.uniform(0, 1. / 256, (28, 28))
371 | im = im * self.rescale
372 | image_size = 28
373 |
374 | if FLAGS.datasource == 'default':
375 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
376 | elif FLAGS.datasource == 'random':
377 | im_corrupt = np.random.uniform(0, self.rescale, (28, 28))
378 |
379 | return im_corrupt, im, label
380 |
381 |
382 | class DSprites(Dataset):
383 | def __init__(
384 | self,
385 | cond_size=False,
386 | cond_shape=False,
387 | cond_pos=False,
388 | cond_rot=False):
389 | dat = np.load(FLAGS.dsprites_path)
390 |
391 | if FLAGS.dshape_only:
392 | l = dat['latents_values']
393 | mask = (l[:, 4] == 16 / 31) & (l[:, 5] == 16 /
394 | 31) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
395 | self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
396 | self.label = np.tile(dat['latents_values'][mask], (10000, 1))
397 | self.label = self.label[:, 1:2]
398 | elif FLAGS.dpos_only:
399 | l = dat['latents_values']
400 | # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
401 | mask = (l[:, 1] == 1) & (
402 | l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
403 | self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
404 | self.label = np.tile(dat['latents_values'][mask], (100, 1))
405 | self.label = self.label[:, 4:] + 0.5
406 | elif FLAGS.dsize_only:
407 | l = dat['latents_values']
408 | # mask = (l[:, 1] == 1) & (l[:, 2] == 0.5) & (l[:, 3] == 30 * np.pi / 39)
409 | mask = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16 /
410 | 31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
411 | self.data = np.tile(dat['imgs'][mask], (10000, 1, 1))
412 | self.label = np.tile(dat['latents_values'][mask], (10000, 1))
413 | self.label = (self.label[:, 2:3])
414 | elif FLAGS.drot_only:
415 | l = dat['latents_values']
416 | mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 /
417 | 31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
418 | self.data = np.tile(dat['imgs'][mask], (100, 1, 1))
419 | self.label = np.tile(dat['latents_values'][mask], (100, 1))
420 | self.label = (self.label[:, 3:4])
421 | self.label = np.concatenate(
422 | [np.cos(self.label), np.sin(self.label)], axis=1)
423 | elif FLAGS.dsprites_restrict:
424 | l = dat['latents_values']
425 | mask = (l[:, 1] == 1) & (l[:, 3] == 0 * np.pi / 39)
426 |
427 | self.data = dat['imgs'][mask]
428 | self.label = dat['latents_values'][mask]
429 | else:
430 | self.data = dat['imgs']
431 | self.label = dat['latents_values']
432 |
433 | if cond_size:
434 | self.label = self.label[:, 2:3]
435 | elif cond_shape:
436 | self.label = self.label[:, 1:2]
437 | elif cond_pos:
438 | self.label = self.label[:, 4:]
439 | elif cond_rot:
440 | self.label = self.label[:, 3:4]
441 | self.label = np.concatenate(
442 | [np.cos(self.label), np.sin(self.label)], axis=1)
443 | else:
444 | self.label = self.label[:, 1:2]
445 |
446 | self.identity = np.eye(3)
447 |
448 | def __len__(self):
449 | return self.data.shape[0]
450 |
451 | def __getitem__(self, index):
452 | im = self.data[index]
453 | image_size = 64
454 |
455 | if not (
456 | FLAGS.dpos_only or FLAGS.dsize_only) and (
457 | not FLAGS.cond_size) and (
458 | not FLAGS.cond_pos) and (
459 | not FLAGS.cond_rot) and (
460 | not FLAGS.drot_only):
461 | label = self.identity[self.label[index].astype(
462 | np.int32) - 1].squeeze()
463 | else:
464 | label = self.label[index]
465 |
466 | if FLAGS.datasource == 'default':
467 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size)
468 | elif FLAGS.datasource == 'random':
469 | im_corrupt = 0.5 + 0.5 * np.random.randn(image_size, image_size)
470 |
471 | return im_corrupt, im, label
472 |
473 |
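The `DSprites` masks select rows of `latents_values` by exact latent values. The columns read here are shape (1), scale (2), orientation (3), and x/y position (4, 5). Rotation labels are encoded as `[cos, sin]` so that angles of 0 and 2π map to the same label. The following sketch mirrors the `drot_only` branch on a small hand-built latent array. It assumes that column layout, and `encode_rotation` and `select_rotation_only` are illustrative names:

```python
import numpy as np

def encode_rotation(theta):
    """Map angles of shape (N, 1) to unit vectors [cos, sin] of shape (N, 2)."""
    return np.concatenate([np.cos(theta), np.sin(theta)], axis=1)

def select_rotation_only(latents):
    """Keep rows with scale 0.5, centered position and shape index 1, as drot_only does."""
    l = latents
    mask = (l[:, 2] == 0.5) & (l[:, 4] == 16 / 31) & (l[:, 5] == 16 / 31) & (l[:, 1] == 1)
    return encode_rotation(l[mask][:, 3:4])
```

The exact `==` comparisons work only because the stored latent values are the same floats as `16 / 31` and `0.5`. Recomputed values may differ in the last bit.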
474 | class Imagenet(Dataset):
475 | def __init__(self, train=True, augment=False):
476 |
477 | if train:
478 | for i in range(1, 11):
479 | f = pickle.load(
480 | open(
481 | osp.join(
482 | FLAGS.imagenet_path,
483 | 'train_data_batch_{}'.format(i)),
484 | 'rb'))
485 | if i == 1:
486 | labels = f['labels']
487 | data = f['data']
488 | else:
489 | labels.extend(f['labels'])
490 | data = np.vstack((data, f['data']))
491 | else:
492 | f = pickle.load(
493 | open(
494 | osp.join(
495 | FLAGS.imagenet_path,
496 | 'val_data'),
497 | 'rb'))
498 | labels = f['labels']
499 | data = f['data']
500 |
501 | self.labels = labels
502 | self.data = data
503 | self.one_hot_map = np.eye(1000)
504 |
505 | def __len__(self):
506 | return self.data.shape[0]
507 |
508 | def __getitem__(self, index):
509 | if not FLAGS.single:
510 | im, label = self.data[index], self.labels[index]
511 | else:
512 | im, label = self.data[0], self.labels[0]
513 |
514 | label -= 1
515 |
516 | im = im.reshape((3, 32, 32)) / 255
517 | im = im.transpose((1, 2, 0))
518 | image_size = 32
519 | label = self.one_hot_map[label]
520 | im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
521 | np.random.seed((index + int(time.time() * 1e7)) % 2**32)
522 |
523 | if FLAGS.datasource == 'default':
524 | im_corrupt = im + 0.3 * np.random.randn(image_size, image_size, 3)
525 | elif FLAGS.datasource == 'random':
526 | im_corrupt = np.random.uniform(
527 | 0.0, 1.0, (image_size, image_size, 3))
528 |
529 | return im_corrupt, im, label
530 |
531 |
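The downsampled ImageNet batches read by `Imagenet` store each image as one flat row of 3072 bytes in channel-major order. The 1024 red values come first, then the green, then the blue. Labels are 1-indexed, which is why the class reshapes to `(3, 32, 32)`, moves channels last, and subtracts 1 before one-hot encoding. A minimal sketch of that decoding follows. `decode_row` is a hypothetical helper:

```python
import numpy as np

def decode_row(row, label, num_classes=1000):
    """Turn a flat 3072-byte row and a 1-indexed label into (HWC image in [0, 1], one-hot)."""
    im = row.reshape(3, 32, 32).transpose(1, 2, 0) / 255.0  # (C, H, W) -> (H, W, C)
    one_hot = np.eye(num_classes)[label - 1]                # labels start at 1
    return im, one_hot
```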
532 | class Textures(Dataset):
533 | def __init__(self, train=True, augment=False):
534 | self.dataset = ImageFolder("/mnt/nfs/yilundu/data/dtd/images")
535 |
536 | def __len__(self):
537 | return 2 * len(self.dataset)
538 |
539 | def __getitem__(self, index):
540 | idx = index % (len(self.dataset))
541 | im, label = self.dataset[idx]
542 |
543 | im = np.array(im)[:32, :32] / 255
544 | im = im + np.random.uniform(-1 / 512, 1 / 512, im.shape)
545 |
546 | return im, im, label
547 |
--------------------------------------------------------------------------------
/EBMs/ebm_combine.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import math
3 | from tqdm import tqdm
4 | from hmc import hmc
5 | from tensorflow.python.platform import flags
6 | from torch.utils.data import DataLoader, Dataset
7 | from models import DspritesNet, DspritesNetGen
8 | from utils import optimistic_restore, ReplayBuffer
9 | import os.path as osp
10 | import numpy as np
11 | from rl_algs.logger import TensorBoardOutputFormat
12 | from scipy.misc import imsave
13 | import os
14 | from custom_adam import AdamOptimizer
15 |
16 | flags.DEFINE_integer('batch_size', 256, 'Size of inputs')
17 | flags.DEFINE_integer('data_workers', 4, 'Number of workers to do things')
18 | flags.DEFINE_string('logdir', 'cachedir', 'directory for logging')
19 | flags.DEFINE_string('savedir', 'cachedir', 'location where log of experiments will be stored')
20 | flags.DEFINE_integer('num_filters', 64, 'number of filters for conv nets -- 32 for miniimagenet, 64 for omniglot.')
21 | flags.DEFINE_float('step_lr', 500, 'step size for gradient descent on the images')
22 | flags.DEFINE_string('dsprites_path', '/root/data/dsprites-dataset/dsprites_ndarray_co1sh3sc6or40x32y32_64x64.npz', 'path to dsprites characters')
23 | flags.DEFINE_bool('cclass', True, 'whether the models are conditioned on a label')
24 | flags.DEFINE_bool('proj_cclass', False, 'use for backwards compatibility reasons')
25 | flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
26 | flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
27 | flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
28 | flags.DEFINE_bool('plot_curve', False, 'Generate a curve of results')
29 | flags.DEFINE_integer('num_steps', 20, 'number of steps to optimize the label')
30 | flags.DEFINE_string('task', 'conceptcombine', 'conceptcombine, labeldiscover, gentest, genbaseline, etc.')
31 | flags.DEFINE_bool('joint_shape', False, 'whether to combine position with shape (instead of size)')
32 | flags.DEFINE_bool('joint_rot', False, 'whether to combine position with rotation (instead of size)')
33 |
34 | # Conditions on which models to use
35 | flags.DEFINE_bool('cond_pos', True, 'whether to condition on position')
36 | flags.DEFINE_bool('cond_rot', True, 'whether to condition on rotation')
37 | flags.DEFINE_bool('cond_shape', True, 'whether to condition on shape')
38 | flags.DEFINE_bool('cond_scale', True, 'whether to condition on scale')
39 |
40 | flags.DEFINE_string('exp_size', 'dsprites_2018_cond_size', 'name of experiments')
41 | flags.DEFINE_string('exp_shape', 'dsprites_2018_cond_shape', 'name of experiments')
42 | flags.DEFINE_string('exp_pos', 'dsprites_2018_cond_pos_cert', 'name of experiments')
43 | flags.DEFINE_string('exp_rot', 'dsprites_cond_rot_119_00', 'name of experiments')
44 | flags.DEFINE_integer('resume_size', 169000, 'checkpoint iteration to restore for the size model')
45 | flags.DEFINE_integer('resume_shape', 477000, 'checkpoint iteration to restore for the shape model')
46 | flags.DEFINE_integer('resume_pos', 8000, 'checkpoint iteration to restore for the position model')
47 | flags.DEFINE_integer('resume_rot', 690000, 'checkpoint iteration to restore for the rotation model')
48 | flags.DEFINE_integer('break_steps', 300, 'steps to break')
49 |
50 | # Whether to train for gentest
51 | flags.DEFINE_bool('train', False, 'whether to train on generalization into multiple different predictions')
52 |
53 | FLAGS = flags.FLAGS
54 |
55 | class DSpritesGen(Dataset):
56 | def __init__(self, data, latents, frac=0.0):
57 |
58 | l = latents
59 |
60 | if FLAGS.joint_shape:
61 | mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
62 | elif FLAGS.joint_rot:
63 | mask_size = (l[:, 1] == 1) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 2] == 0.5)
64 | else:
65 | mask_size = (l[:, 3] == 30 * np.pi / 39) & (l[:, 4] == 16/31) & (l[:, 5] == 16/31) & (l[:, 1] == 1)
66 |
67 | mask_pos = (l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39) & (l[:, 2] == 0.5)
68 |
69 | data_pos = data[mask_pos]
70 | l_pos = l[mask_pos]
71 |
72 | data_size = data[mask_size]
73 | l_size = l[mask_size]
74 |
75 | n = data_pos.shape[0] // data_size.shape[0]
76 |
77 | data_pos = np.tile(data_pos, (n, 1, 1))
78 | l_pos = np.tile(l_pos, (n, 1))
79 |
80 | self.data = np.concatenate((data_pos, data_size), axis=0)
81 | self.label = np.concatenate((l_pos, l_size), axis=0)
82 |
83 | mask_neg = (~(mask_size & mask_pos)) & ((l[:, 1] == 1) & (l[:, 3] == 30 * np.pi / 39))
84 | data_add = data[mask_neg]
85 | l_add = l[mask_neg]
86 |
87 | perm_idx = np.random.permutation(data_add.shape[0])
88 | select_idx = perm_idx[:int(frac*perm_idx.shape[0])]
89 | data_add = data_add[select_idx]
90 | l_add = l_add[select_idx]
91 |
92 | self.data = np.concatenate((self.data, data_add), axis=0)
93 | self.label = np.concatenate((self.label, l_add), axis=0)
94 |
95 | self.identity = np.eye(3)
96 |
97 | def __len__(self):
98 | return self.data.shape[0]
99 |
100 | def __getitem__(self, index):
101 | im = self.data[index]
102 | im_corrupt = 0.5 + 0.5 * np.random.randn(64, 64)
103 |
104 | if FLAGS.joint_shape:
105 | label_size = np.eye(3)[self.label[index, 1].astype(np.int32) - 1]
106 | elif FLAGS.joint_rot:
107 | label_size = np.array([np.cos(self.label[index, 3]), np.sin(self.label[index, 3])])
108 | else:
109 | label_size = self.label[index, 2:3]
110 |
111 | label_pos = self.label[index, 4:]
112 |
113 | return (im_corrupt, im, label_size, label_pos)
114 |
115 |
116 | def labeldiscover(sess, kvs, data, latents, save_exp_dir):
117 | LABEL_SIZE = kvs['LABEL_SIZE']
118 | model_size = kvs['model_size']
119 | weight_size = kvs['weight_size']
120 | x_mod = kvs['X_NOISE']
121 |
122 | label_output = LABEL_SIZE
123 | for i in range(FLAGS.num_steps):
124 | label_output = label_output + tf.random_normal(tf.shape(label_output), mean=0.0, stddev=0.03)
125 | e_noise = model_size.forward(x_mod, weight_size, label=label_output)
126 | label_grad = tf.gradients(e_noise, [label_output])[0]
127 | # label_grad = tf.Print(label_grad, [label_grad])
128 | label_output = label_output - 1.0 * label_grad
129 | label_output = tf.clip_by_value(label_output, 0.5, 1.0)
130 |
131 | diffs = []
132 | for i in range(30):
133 | s = i*FLAGS.batch_size
134 | d = (i+1)*FLAGS.batch_size
135 | data_i = data[s:d]
136 | latent_i = latents[s:d]
137 | latent_init = np.random.uniform(0.5, 1, (FLAGS.batch_size, 1))
138 |
139 | feed_dict = {x_mod: data_i, LABEL_SIZE:latent_init}
140 | size_pred = sess.run([label_output], feed_dict)[0]
141 | size_gt = latent_i[:, 2:3]
142 |
143 | diffs.append(np.abs(size_pred - size_gt).mean())
144 |
145 | print(np.array(diffs).mean())
146 |
147 |
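`labeldiscover` infers the size latent by treating the label as the variable being optimized. At each step it perturbs the label with small Gaussian noise, takes a gradient step down the energy with respect to the label, and clips the result to the valid scale range `[0.5, 1.0]`. The NumPy sketch below performs the same loop. The energy gradient is passed in as a function, so the quadratic toy energy in the usage example is purely hypothetical and stands in for `model_size.forward`:

```python
import numpy as np

def infer_label(grad_fn, y0, num_steps=20, step=1.0, noise_std=0.03,
                lo=0.5, hi=1.0, rng=None):
    """Noisy projected gradient descent on a label; grad_fn(y) returns dE/dy."""
    rng = np.random.default_rng() if rng is None else rng
    y = np.array(y0, dtype=float)
    for _ in range(num_steps):
        y = y + rng.normal(0.0, noise_std, y.shape)  # perturb the current label
        y = y - step * grad_fn(y)                     # descend the energy w.r.t. the label
        y = np.clip(y, lo, hi)                        # stay inside the valid latent range
    return y
```

With the toy energy `0.25 * (y - 0.8) ** 2`, labels initialized at 0.5 converge to about 0.8.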
148 | def genbaseline(sess, kvs, data, latents, save_exp_dir, frac=0.0):
149 | # tf.reset_default_graph()
150 |
151 | if FLAGS.joint_shape:
152 | model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=5)
153 | LABEL = tf.placeholder(shape=(None, 5), dtype=tf.float32)
154 | else:
155 | model_baseline = DspritesNetGen(num_filters=FLAGS.num_filters, label_size=3)
156 | LABEL = tf.placeholder(shape=(None, 3), dtype=tf.float32)
157 |
158 | weights_baseline = model_baseline.construct_weights('context_baseline_{}'.format(frac))
159 |
160 | X_feed = tf.placeholder(shape=(None, 2*FLAGS.num_filters), dtype=tf.float32)
161 | X_label = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
162 |
163 | X_out = model_baseline.forward(X_feed, weights_baseline, label=LABEL)
164 | loss_sq = tf.reduce_mean(tf.square(X_out - X_label))
165 |
166 | optimizer = AdamOptimizer(1e-3)
167 | gvs = optimizer.compute_gradients(loss_sq)
168 | gvs = [(k, v) for (k, v) in gvs if k is not None]
169 | train_op = optimizer.apply_gradients(gvs)
170 |
171 | dataloader = DataLoader(DSpritesGen(data, latents, frac=frac), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
172 |
173 | datafull = data
174 |
175 | itr = 0
176 | saver = tf.train.Saver()
177 |
178 | vs = optimizer.variables()
179 | sess.run(tf.global_variables_initializer())
180 |
181 | if FLAGS.train:
182 | for _ in range(5):
183 | for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
184 |
185 | data_corrupt = data_corrupt.numpy()
186 | label_size, label_pos = label_size.numpy(), label_pos.numpy()
187 |
188 | data_corrupt = np.random.randn(data_corrupt.shape[0], 2*FLAGS.num_filters)
189 | label_comb = np.concatenate([label_size, label_pos], axis=1)
190 |
191 | feed_dict = {X_feed: data_corrupt, X_label: data, LABEL: label_comb}
192 |
193 | output = [loss_sq, train_op]
194 |
195 | loss, _ = sess.run(output, feed_dict=feed_dict)
196 |
197 | itr += 1
198 |
199 | saver.save(sess, osp.join(save_exp_dir, 'model_genbaseline'))
200 |
201 | saver.restore(sess, osp.join(save_exp_dir, 'model_genbaseline'))
202 |
203 | l = latents
204 |
205 | if FLAGS.joint_shape:
206 | mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
207 | else:
208 | mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
209 |
210 | data_gen = datafull[mask_gen]
211 | latents_gen = latents[mask_gen]
212 | losses = []
213 |
214 | for dat, latent in zip(np.array_split(data_gen, 10), np.array_split(latents_gen, 10)):
215 | data_init = np.random.randn(dat.shape[0], 2*FLAGS.num_filters)
216 |
217 | if FLAGS.joint_shape:
218 | latent_size = np.eye(3)[latent[:, 1].astype(np.int32) - 1]
219 | latent_pos = latent[:, 4:6]
220 | latent = np.concatenate([latent_size, latent_pos], axis=1)
221 | feed_dict = {X_feed: data_init, LABEL: latent, X_label: dat}
222 | else:
223 | feed_dict = {X_feed: data_init, LABEL: latent[:, [2,4,5]], X_label: dat}
224 | loss = sess.run([loss_sq], feed_dict=feed_dict)[0]
225 | # print(loss)
226 | losses.append(loss)
227 |
228 | print("Overall MSE for generalization of {} for fraction of {}".format(np.mean(losses), frac))
229 |
230 |
231 | data_try = data_gen[:10]
232 | data_init = np.random.randn(10, 2*FLAGS.num_filters)
233 |
234 | if FLAGS.joint_shape:
235 | latent_scale = np.eye(3)[latents_gen[:10, 1].astype(np.int32) - 1]
236 | latent_pos = latents_gen[:10, 4:]
237 | else:
238 | latent_scale = latents_gen[:10, 2:3]
239 | latent_pos = latents_gen[:10, 4:]
240 |
241 | latent_tot = np.concatenate([latent_scale, latent_pos], axis=1)
242 |
243 | feed_dict = {X_feed: data_init, LABEL: latent_tot}
244 | x_output = sess.run([X_out], feed_dict=feed_dict)[0]
245 | x_output = np.clip(x_output, 0, 1)
246 |
247 | im_name = "size_scale_combine_genbaseline.png"
248 |
249 | x_output_wrap = np.ones((10, 66, 66))
250 | data_try_wrap = np.ones((10, 66, 66))
251 |
252 | x_output_wrap[:, 1:-1, 1:-1] = x_output
253 | data_try_wrap[:, 1:-1, 1:-1] = data_try
254 |
255 | im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
256 | impath = osp.join(save_exp_dir, im_name)
257 | imsave(impath, im_output)
258 | print("Successfully saved images at {}".format(impath))
259 |
260 | return np.mean(losses)
261 |
262 |
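The `gentest` function that follows combines two conditional energy models. It alternates gradient steps on each model's energy, with small Gaussian noise and clipping to `[0, 1]`. For small steps this roughly descends the summed energy, which corresponds to a product of the two conditional distributions. The sketch below simplifies that loop. It adds noise before every model's step, whereas `gentest` adds it only around the position step. The toy gradients in the usage example are hypothetical stand-ins for the position and size models:

```python
import numpy as np

def compose_descent(grad_fns, x0, num_steps=20, step_lr=0.5, noise_std=0.005, rng=None):
    """Alternate noisy gradient steps over several energies; each grad_fn returns dE/dx."""
    rng = np.random.default_rng() if rng is None else rng
    x = np.array(x0, dtype=float)
    for _ in range(num_steps):
        for grad_fn in grad_fns:
            x = x + rng.normal(0.0, noise_std, x.shape)  # Langevin-style perturbation
            x = x - step_lr * grad_fn(x)                 # step on this model's energy
            x = np.clip(x, 0.0, 1.0)                     # keep pixels in range
    return x
```

If one toy energy only constrains coordinate 0 and the other only coordinate 1, the composed descent satisfies both at once. This is the behavior `gentest` relies on when it pairs a position model with a size model.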
263 | def gentest(sess, kvs, data, latents, save_exp_dir):
264 | X_NOISE = kvs['X_NOISE']
265 | LABEL_SIZE = kvs['LABEL_SIZE']
266 | LABEL_SHAPE = kvs['LABEL_SHAPE']
267 | LABEL_POS = kvs['LABEL_POS']
268 | LABEL_ROT = kvs['LABEL_ROT']
269 | model_size = kvs['model_size']
270 | model_shape = kvs['model_shape']
271 | model_pos = kvs['model_pos']
272 | model_rot = kvs['model_rot']
273 | weight_size = kvs['weight_size']
274 | weight_shape = kvs['weight_shape']
275 | weight_pos = kvs['weight_pos']
276 | weight_rot = kvs['weight_rot']
277 | X = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
278 |
279 | datafull = data
280 | # Test compositional generalization by combining models trained on different slices of the training data
281 | x_final = X_NOISE
282 | x_mod_size = X_NOISE
283 | x_mod_pos = X_NOISE
284 |
285 | for i in range(FLAGS.num_steps):
286 |
287 | # use cond_pos
288 |
289 | energies = []
290 | x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
291 | e_noise = model_pos.forward(x_final, weight_pos, label=LABEL_POS)
292 |
293 | # energies.append(e_noise)
294 | x_grad = tf.gradients(e_noise, [x_final])[0]
295 | x_mod_pos = x_mod_pos + tf.random_normal(tf.shape(x_mod_pos), mean=0.0, stddev=0.005)
296 | x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
297 | x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)
298 |
299 | if FLAGS.joint_shape:
300 | # use cond_shape
301 | e_noise = model_shape.forward(x_mod_pos, weight_shape, label=LABEL_SHAPE)
302 | elif FLAGS.joint_rot:
303 | e_noise = model_rot.forward(x_mod_pos, weight_rot, label=LABEL_ROT)
304 | else:
305 | # use cond_size
306 | e_noise = model_size.forward(x_mod_pos, weight_size, label=LABEL_SIZE)
307 |
308 | # energies.append(e_noise)
309 | # energy_stack = tf.concat(energies, axis=1)
310 | # energy_stack = tf.reduce_logsumexp(-1*energy_stack, axis=1)
311 | # energy_stack = tf.reduce_sum(energy_stack, axis=1)
312 |
313 | x_grad = tf.gradients(e_noise, [x_mod_pos])[0]
314 | x_mod_pos = x_mod_pos - FLAGS.step_lr * x_grad
315 | x_mod_pos = tf.clip_by_value(x_mod_pos, 0, 1)
316 |
317 | # for x_mod_size
318 | # use cond_size
319 | # e_noise = model_size.forward(x_mod_size, weight_size, label=LABEL_SIZE)
320 | # x_grad = tf.gradients(e_noise, [x_mod_size])[0]
321 | # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
322 | # x_mod_size = x_mod_size - FLAGS.step_lr * x_grad
323 | # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)
324 |
325 | # # use cond_pos
326 | # e_noise = model_pos.forward(x_mod_size, weight_pos, label=LABEL_POS)
327 | # x_grad = tf.gradients(e_noise, [x_mod_size])[0]
328 | # x_mod_size = x_mod_size + tf.random_normal(tf.shape(x_mod_size), mean=0.0, stddev=0.005)
329 | # x_mod_size = x_mod_size - FLAGS.step_lr * tf.stop_gradient(x_grad)
330 | # x_mod_size = tf.clip_by_value(x_mod_size, 0, 1)
331 |
332 | x_mod = x_mod_pos
333 | x_final = x_mod
334 |
335 |
336 | if FLAGS.joint_shape:
337 | loss_kl = model_shape.forward(x_final, weight_shape, reuse=True, label=LABEL_SHAPE, stop_grad=True) + \
338 | model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
339 |
340 | energy_pos = model_shape.forward(X, weight_shape, reuse=True, label=LABEL_SHAPE) + \
341 | model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
342 |
343 | energy_neg = model_shape.forward(tf.stop_gradient(x_mod), weight_shape, reuse=True, label=LABEL_SHAPE) + \
344 | model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
345 | elif FLAGS.joint_rot:
346 | loss_kl = model_rot.forward(x_final, weight_rot, reuse=True, label=LABEL_ROT, stop_grad=True) + \
347 | model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
348 |
349 | energy_pos = model_rot.forward(X, weight_rot, reuse=True, label=LABEL_ROT) + \
350 | model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
351 |
352 | energy_neg = model_rot.forward(tf.stop_gradient(x_mod), weight_rot, reuse=True, label=LABEL_ROT) + \
353 | model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
354 | else:
355 | loss_kl = model_size.forward(x_final, weight_size, reuse=True, label=LABEL_SIZE, stop_grad=True) + \
356 | model_pos.forward(x_final, weight_pos, reuse=True, label=LABEL_POS, stop_grad=True)
357 |
358 | energy_pos = model_size.forward(X, weight_size, reuse=True, label=LABEL_SIZE) + \
359 | model_pos.forward(X, weight_pos, reuse=True, label=LABEL_POS)
360 |
361 | energy_neg = model_size.forward(tf.stop_gradient(x_mod), weight_size, reuse=True, label=LABEL_SIZE) + \
362 | model_pos.forward(tf.stop_gradient(x_mod), weight_pos, reuse=True, label=LABEL_POS)
363 |
364 | energy_neg_reduced = (energy_neg - tf.reduce_min(energy_neg))
365 | coeff = tf.stop_gradient(tf.exp(-energy_neg_reduced))
366 | norm_constant = tf.stop_gradient(tf.reduce_sum(coeff)) + 1e-4
367 | neg_loss = coeff * (-1*energy_neg) / norm_constant
368 |
369 | loss_ml = tf.reduce_mean(energy_pos) - tf.reduce_mean(energy_neg)
370 | loss_total = loss_ml + tf.reduce_mean(loss_kl) + 1 * (tf.reduce_mean(tf.square(energy_pos)) + tf.reduce_mean(tf.square(energy_neg)))
371 |
372 | optimizer = AdamOptimizer(1e-3, beta1=0.0, beta2=0.999)
373 | gvs = optimizer.compute_gradients(loss_total)
374 | gvs = [(k, v) for (k, v) in gvs if k is not None]
375 | train_op = optimizer.apply_gradients(gvs)
376 |
377 | vs = optimizer.variables()
378 | sess.run(tf.variables_initializer(vs))
379 |
380 | dataloader = DataLoader(DSpritesGen(data, latents), batch_size=FLAGS.batch_size, num_workers=6, drop_last=True, shuffle=True)
381 |
382 | x_off = tf.reduce_mean(tf.square(x_mod - X))
383 |
384 | itr = 0
385 | saver = tf.train.Saver()
386 | x_mod = None
387 |
388 |
389 | if FLAGS.train:
390 | replay_buffer = ReplayBuffer(10000)
391 | for _ in range(1):
392 |
393 |
394 | for data_corrupt, data, label_size, label_pos in tqdm(dataloader):
395 | data_corrupt = data_corrupt.numpy()[:, :, :]
396 | data = data.numpy()[:, :, :]
397 |
398 | if x_mod is not None:
399 | replay_buffer.add(x_mod)
400 | replay_batch = replay_buffer.sample(FLAGS.batch_size)
401 | replay_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.95)
402 | data_corrupt[replay_mask] = replay_batch[replay_mask]
403 |
404 | if FLAGS.joint_shape:
405 | feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SHAPE: label_size, LABEL_POS: label_pos}
406 | elif FLAGS.joint_rot:
407 | feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_ROT: label_size, LABEL_POS: label_pos}
408 | else:
409 | feed_dict = {X_NOISE: data_corrupt, X: data, LABEL_SIZE: label_size, LABEL_POS: label_pos}
410 |
411 | _, off_value, e_pos, e_neg, x_mod = sess.run([train_op, x_off, energy_pos, energy_neg, x_final], feed_dict=feed_dict)
412 | itr += 1
413 |
414 | if itr % 10 == 0:
415 | print("x_off of {}, e_pos of {}, e_neg of {} itr of {}".format(off_value, e_pos.mean(), e_neg.mean(), itr))
416 |
417 | if itr == FLAGS.break_steps:
418 | break
419 |
420 |
421 | saver.save(sess, osp.join(save_exp_dir, 'model_gentest'))
422 |
423 | saver.restore(sess, osp.join(save_exp_dir, 'model_gentest'))
424 |
425 | l = latents
426 |
427 | if FLAGS.joint_shape:
428 | mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 2] == 0.5)
429 | elif FLAGS.joint_rot:
430 | mask_gen = (l[:, 1] == 1) * (l[:, 2] == 0.5)
431 | else:
432 | mask_gen = (l[:, 3] == 30 * np.pi / 39) * (l[:, 1] == 1) & (~((l[:, 2] == 0.5) | ((l[:, 4] == 16/31) & (l[:, 5] == 16/31))))
433 |
434 | data_gen = datafull[mask_gen]
435 | latents_gen = latents[mask_gen]
436 |
437 | losses = []
438 |
439 | for dat, latent in zip(np.array_split(data_gen, 120), np.array_split(latents_gen, 120)):
440 | x = 0.5 + np.random.randn(*dat.shape)
441 |
442 | if FLAGS.joint_shape:
443 | feed_dict = {LABEL_SHAPE: np.eye(3)[latent[:, 1].astype(np.int32) - 1], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
444 | elif FLAGS.joint_rot:
445 | feed_dict = {LABEL_ROT: np.concatenate([np.cos(latent[:, 3:4]), np.sin(latent[:, 3:4])], axis=1), LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
446 | else:
447 | feed_dict = {LABEL_SIZE: latent[:, 2:3], LABEL_POS: latent[:, 4:], X_NOISE: x, X: dat}
448 |
449 | for i in range(2):
450 | x = sess.run([x_final], feed_dict=feed_dict)[0]
451 | feed_dict[X_NOISE] = x
452 |
453 | loss = sess.run([x_off], feed_dict=feed_dict)[0]
454 | losses.append(loss)
455 |
456 | print("Mean MSE loss of {} ".format(np.mean(losses)))
457 |
458 | data_try = data_gen[:10]
459 | data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
460 | latent_scale = latents_gen[:10, 2:3]
461 | latent_pos = latents_gen[:10, 4:]
462 |
463 | if FLAGS.joint_shape:
464 | feed_dict = {X_NOISE: data_init, LABEL_SHAPE: np.eye(3)[latents_gen[:10, 1].astype(np.int32)-1], LABEL_POS: latent_pos}
465 | elif FLAGS.joint_rot:
466 | feed_dict = {LABEL_ROT: np.concatenate([np.cos(latents_gen[:10, 3:4]), np.sin(latents_gen[:10, 3:4])], axis=1), LABEL_POS: latent_pos, X_NOISE: data_init}
467 | else:
468 | feed_dict = {X_NOISE: data_init, LABEL_SIZE: latent_scale, LABEL_POS: latent_pos}
469 |
470 | x_output = sess.run([x_final], feed_dict=feed_dict)[0]
471 |
472 | if FLAGS.joint_shape:
473 | im_name = "size_shape_combine_gentest.png"
474 | else:
475 | im_name = "size_scale_combine_gentest.png"
476 |
477 | x_output_wrap = np.ones((10, 66, 66))
478 | data_try_wrap = np.ones((10, 66, 66))
479 |
480 | x_output_wrap[:, 1:-1, 1:-1] = x_output
481 | data_try_wrap[:, 1:-1, 1:-1] = data_try
482 |
483 | im_output = np.concatenate([x_output_wrap, data_try_wrap], axis=2).reshape(-1, 66*2)
484 | impath = osp.join(save_exp_dir, im_name)
485 | imsave(impath, im_output)
486 | print("Successfully saved images at {}".format(impath))
487 |
488 |
489 |
490 | def conceptcombine(sess, kvs, data, latents, save_exp_dir):
491 | X_NOISE = kvs['X_NOISE']
492 | LABEL_SIZE = kvs['LABEL_SIZE']
493 | LABEL_SHAPE = kvs['LABEL_SHAPE']
494 | LABEL_POS = kvs['LABEL_POS']
495 | LABEL_ROT = kvs['LABEL_ROT']
496 | model_size = kvs['model_size']
497 | model_shape = kvs['model_shape']
498 | model_pos = kvs['model_pos']
499 | model_rot = kvs['model_rot']
500 | weight_size = kvs['weight_size']
501 | weight_shape = kvs['weight_shape']
502 | weight_pos = kvs['weight_pos']
503 | weight_rot = kvs['weight_rot']
504 |
505 | x_mod = X_NOISE
506 | for i in range(FLAGS.num_steps):
507 |
508 | if FLAGS.cond_scale:
509 | e_noise = model_size.forward(x_mod, weight_size, label=LABEL_SIZE)
510 | x_grad = tf.gradients(e_noise, [x_mod])[0]
511 | x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
512 | x_mod = x_mod - FLAGS.step_lr * x_grad
513 | x_mod = tf.clip_by_value(x_mod, 0, 1)
514 |
515 | if FLAGS.cond_shape:
516 | e_noise = model_shape.forward(x_mod, weight_shape, label=LABEL_SHAPE)
517 | x_grad = tf.gradients(e_noise, [x_mod])[0]
518 | x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
519 | x_mod = x_mod - FLAGS.step_lr * x_grad
520 | x_mod = tf.clip_by_value(x_mod, 0, 1)
521 |
522 | if FLAGS.cond_pos:
523 | e_noise = model_pos.forward(x_mod, weight_pos, label=LABEL_POS)
524 | x_grad = tf.gradients(e_noise, [x_mod])[0]
525 | x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
526 | x_mod = x_mod - FLAGS.step_lr * x_grad
527 | x_mod = tf.clip_by_value(x_mod, 0, 1)
528 |
529 | if FLAGS.cond_rot:
530 | e_noise = model_rot.forward(x_mod, weight_rot, label=LABEL_ROT)
531 | x_grad = tf.gradients(e_noise, [x_mod])[0]
532 | x_mod = x_mod + tf.random_normal(tf.shape(x_mod), mean=0.0, stddev=0.005)
533 | x_mod = x_mod - FLAGS.step_lr * x_grad
534 | x_mod = tf.clip_by_value(x_mod, 0, 1)
535 |
536 | print("Finished constructing loop {}".format(i))
537 |
538 | x_final = x_mod
539 |
540 | data_try = data[:10]
541 | data_init = 0.5 + 0.5 * np.random.randn(10, 64, 64)
542 | label_scale = latents[:10, 2:3]
543 | label_shape = np.eye(3)[(latents[:10, 1]-1).astype(np.uint8)]
544 | label_rot = latents[:10, 3:4]
545 | label_rot = np.concatenate([np.cos(label_rot), np.sin(label_rot)], axis=1)
546 | label_pos = latents[:10, 4:]
547 |
548 | feed_dict = {X_NOISE: data_init, LABEL_SIZE: label_scale, LABEL_SHAPE: label_shape, LABEL_POS: label_pos,
549 | LABEL_ROT: label_rot}
550 | x_out = sess.run([x_final], feed_dict)[0]
551 |
552 | im_name = "im"
553 |
554 | if FLAGS.cond_scale:
555 | im_name += "_condscale"
556 |
557 | if FLAGS.cond_shape:
558 | im_name += "_condshape"
559 |
560 | if FLAGS.cond_pos:
561 | im_name += "_condpos"
562 |
563 | if FLAGS.cond_rot:
564 | im_name += "_condrot"
565 |
566 | im_name += ".png"
567 |
568 | x_out_pad, data_try_pad = np.ones((10, 66, 66)), np.ones((10, 66, 66))
569 | x_out_pad[:, 1:-1, 1:-1] = x_out
570 | data_try_pad[:, 1:-1, 1:-1] = data_try
571 |
572 | im_output = np.concatenate([x_out_pad, data_try_pad], axis=2).reshape(-1, 66*2)
573 | impath = osp.join(save_exp_dir, im_name)
574 | imsave(impath, im_output)
575 | print("Successfully saved images at {}".format(impath))
576 |
577 | def main():
578 | data = np.load(FLAGS.dsprites_path)['imgs']
579 | l = latents = np.load(FLAGS.dsprites_path)['latents_values']
580 |
581 | np.random.seed(1)
582 | idx = np.random.permutation(data.shape[0])
583 |
584 | data = data[idx]
585 | latents = latents[idx]
586 |
587 | config = tf.ConfigProto()
588 | sess = tf.Session(config=config)
589 |
590 | # Model 1 will be conditioned on size
591 | model_size = DspritesNet(num_filters=FLAGS.num_filters, cond_size=True)
592 | weight_size = model_size.construct_weights('context_0')
593 |
594 | # Model 2 will be conditioned on shape
595 | model_shape = DspritesNet(num_filters=FLAGS.num_filters, cond_shape=True)
596 | weight_shape = model_shape.construct_weights('context_1')
597 |
598 | # Model 3 will be conditioned on position
599 | model_pos = DspritesNet(num_filters=FLAGS.num_filters, cond_pos=True)
600 | weight_pos = model_pos.construct_weights('context_2')
601 |
602 | # Model 4 will be conditioned on rotation
603 | model_rot = DspritesNet(num_filters=FLAGS.num_filters, cond_rot=True)
604 | weight_rot = model_rot.construct_weights('context_3')
605 |
606 | sess.run(tf.global_variables_initializer())
607 | save_path_size = osp.join(FLAGS.logdir, FLAGS.exp_size, 'model_{}'.format(FLAGS.resume_size))
608 |
609 | v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(0))
610 | v_map = {(v.name.replace('context_{}'.format(0), 'context_0')[:-2]): v for v in v_list}
611 |
612 | if FLAGS.cond_scale:
613 | saver = tf.train.Saver(v_map)
614 | saver.restore(sess, save_path_size)
615 |
616 | save_path_shape = osp.join(FLAGS.logdir, FLAGS.exp_shape, 'model_{}'.format(FLAGS.resume_shape))
617 |
618 | v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(1))
619 | v_map = {(v.name.replace('context_{}'.format(1), 'context_0')[:-2]): v for v in v_list}
620 |
621 | if FLAGS.cond_shape:
622 | saver = tf.train.Saver(v_map)
623 | saver.restore(sess, save_path_shape)
624 |
625 |
626 | save_path_pos = osp.join(FLAGS.logdir, FLAGS.exp_pos, 'model_{}'.format(FLAGS.resume_pos))
627 | v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(2))
628 | v_map = {(v.name.replace('context_{}'.format(2), 'context_0')[:-2]): v for v in v_list}
629 | saver = tf.train.Saver(v_map)
630 |
631 | if FLAGS.cond_pos:
632 | saver.restore(sess, save_path_pos)
633 |
634 |
635 | save_path_rot = osp.join(FLAGS.logdir, FLAGS.exp_rot, 'model_{}'.format(FLAGS.resume_rot))
636 | v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(3))
637 | v_map = {(v.name.replace('context_{}'.format(3), 'context_0')[:-2]): v for v in v_list}
638 | saver = tf.train.Saver(v_map)
639 |
640 | if FLAGS.cond_rot:
641 | saver.restore(sess, save_path_rot)
642 |
643 | X_NOISE = tf.placeholder(shape=(None, 64, 64), dtype=tf.float32)
644 | LABEL_SIZE = tf.placeholder(shape=(None, 1), dtype=tf.float32)
645 | LABEL_SHAPE = tf.placeholder(shape=(None, 3), dtype=tf.float32)
646 | LABEL_POS = tf.placeholder(shape=(None, 2), dtype=tf.float32)
647 | LABEL_ROT = tf.placeholder(shape=(None, 2), dtype=tf.float32)
648 |
649 | x_mod = X_NOISE
650 |
651 | kvs = {}
652 | kvs['X_NOISE'] = X_NOISE
653 | kvs['LABEL_SIZE'] = LABEL_SIZE
654 | kvs['LABEL_SHAPE'] = LABEL_SHAPE
655 | kvs['LABEL_POS'] = LABEL_POS
656 | kvs['LABEL_ROT'] = LABEL_ROT
657 | kvs['model_size'] = model_size
658 | kvs['model_shape'] = model_shape
659 | kvs['model_pos'] = model_pos
660 | kvs['model_rot'] = model_rot
661 | kvs['weight_size'] = weight_size
662 | kvs['weight_shape'] = weight_shape
663 | kvs['weight_pos'] = weight_pos
664 | kvs['weight_rot'] = weight_rot
665 |
666 | save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_joint'.format(FLAGS.exp_size, FLAGS.exp_shape))
667 | if not osp.exists(save_exp_dir):
668 | os.makedirs(save_exp_dir)
669 |
670 |
671 | if FLAGS.task == 'conceptcombine':
672 | conceptcombine(sess, kvs, data, latents, save_exp_dir)
673 | elif FLAGS.task == 'labeldiscover':
674 | labeldiscover(sess, kvs, data, latents, save_exp_dir)
675 | elif FLAGS.task == 'gentest':
676 | save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen'.format(FLAGS.exp_size, FLAGS.exp_pos))
677 | if not osp.exists(save_exp_dir):
678 | os.makedirs(save_exp_dir)
679 |
680 | gentest(sess, kvs, data, latents, save_exp_dir)
681 | elif FLAGS.task == 'genbaseline':
682 | save_exp_dir = osp.join(FLAGS.savedir, '{}_{}_gen_baseline'.format(FLAGS.exp_size, FLAGS.exp_pos))
683 | if not osp.exists(save_exp_dir):
684 | os.makedirs(save_exp_dir)
685 |
686 | if FLAGS.plot_curve:
687 | mse_losses = []
688 | for frac in [i/10 for i in range(11)]:
689 | mse_loss = genbaseline(sess, kvs, data, latents, save_exp_dir, frac=frac)
690 | mse_losses.append(mse_loss)
691 | np.save("mse_baseline_comb.npy", mse_losses)
692 | else:
693 | genbaseline(sess, kvs, data, latents, save_exp_dir)
694 |
695 |
696 |
697 | if __name__ == "__main__":
698 | main()
699 |
--------------------------------------------------------------------------------
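`conceptcombine` composes concepts by taking a noisy gradient step on each enabled conditional model's energy in turn. Every sample is therefore pushed toward low energy under all of the models at once, which approximately descends the summed energy (equivalently, a product of the individual densities). The NumPy sketch below illustrates that update rule on two hypothetical quadratic energies (`energy_a`, `energy_b`, each constraining one coordinate of a 2-D "image"), not on the trained `DspritesNet` models. The gradient-before-noise ordering, the noise scale of 0.005 and the clipping to [0, 1] follow the TensorFlow loop above; the energies, step size and helper names are assumptions for illustration only.

```python
import numpy as np

# Hypothetical toy energies standing in for two conditional EBMs (e.g. a
# "size" model and a "position" model). Each one constrains one coordinate.
def energy_a(x):
    return 10.0 * (x[:, 0] - 0.3) ** 2

def grad_a(x):
    g = np.zeros_like(x)
    g[:, 0] = 20.0 * (x[:, 0] - 0.3)
    return g

def energy_b(x):
    return 10.0 * (x[:, 1] - 0.7) ** 2

def grad_b(x):
    g = np.zeros_like(x)
    g[:, 1] = 20.0 * (x[:, 1] - 0.7)
    return g

def combine_sample(x, grads, num_steps=100, step_lr=0.01, noise_std=0.005, seed=0):
    """Noisy gradient descent, one model at a time per outer step,
    mirroring the per-model update order used by conceptcombine()."""
    rng = np.random.default_rng(seed)
    for _ in range(num_steps):
        for grad in grads:
            g = grad(x)  # gradient taken before the noise is added
            x = x + rng.normal(0.0, noise_std, size=x.shape)
            x = x - step_lr * g
            x = np.clip(x, 0.0, 1.0)  # keep pixel values in [0, 1]
    return x

# Start from uniform noise and combine both "concepts".
x0 = np.random.default_rng(1).uniform(0.0, 1.0, size=(64, 2))
samples = combine_sample(x0, [grad_a, grad_b])
```

Because each toy energy only constrains its own coordinate, the combined samples should settle near (0.3, 0.7), a point that neither model alone pins down.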
/EBMs/fid.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python3
2 | ''' Calculates the Frechet Inception Distance (FID) to evaluate GANs.
3 |
4 | The FID metric calculates the distance between two distributions of images.
5 | Typically, we have summary statistics (mean & covariance matrix) of one
6 | of these distributions, while the 2nd distribution is given by a GAN.
7 |
8 | When run as a stand-alone program, it compares the distribution of
9 | images that are stored as PNG/JPEG at a specified location with a
10 | distribution given by summary statistics (in pickle format).
11 |
12 | The FID is calculated by assuming that X_1 and X_2 are the activations of
13 | the pool_3 layer of the inception net for generated samples and real world
14 | samples respectively.
15 |
16 | See --help to see further details.
17 | '''
18 |
19 | from __future__ import absolute_import, division, print_function
20 | import numpy as np
21 | import os, sys
22 | import gzip, pickle
23 | import tensorflow as tf
24 | from scipy.misc import imread
25 | from scipy import linalg
26 | import pathlib
27 | import urllib.request
28 | import tarfile
29 | import warnings
30 |
31 | MODEL_DIR = '/tmp/imagenet'
32 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
33 | pool3 = None
34 |
35 | class InvalidFIDException(Exception):
36 | pass
37 |
38 | #-------------------------------------------------------------------------------
39 | def get_fid_score(images, images_gt):
40 | images = np.stack(images, 0)
41 | images_gt = np.stack(images_gt, 0)
42 |
43 | with tf.Session() as sess:
44 | m1, s1 = calculate_activation_statistics(images, sess)
45 | m2, s2 = calculate_activation_statistics(images_gt, sess)
46 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
47 |
48 | print("Obtained fid value of {}".format(fid_value))
49 | return fid_value
50 |
51 |
52 | def create_inception_graph(pth):
53 | """Creates a graph from saved GraphDef file."""
54 | # Creates graph from saved graph_def.pb.
55 | with tf.gfile.FastGFile( pth, 'rb') as f:
56 | graph_def = tf.GraphDef()
57 | graph_def.ParseFromString( f.read())
58 | _ = tf.import_graph_def( graph_def, name='FID_Inception_Net')
59 | #-------------------------------------------------------------------------------
60 |
61 |
62 | # code for handling inception net derived from
63 | # https://github.com/openai/improved-gan/blob/master/inception_score/model.py
64 | def _get_inception_layer(sess):
65 | """Prepares inception net for batched usage and returns pool_3 layer. """
66 | layername = 'FID_Inception_Net/pool_3:0'
67 | pool3 = sess.graph.get_tensor_by_name(layername)
68 | ops = pool3.graph.get_operations()
69 | for op_idx, op in enumerate(ops):
70 | for o in op.outputs:
71 | shape = o.get_shape()
72 | if shape._dims != []:
73 | shape = [s.value for s in shape]
74 | new_shape = []
75 | for j, s in enumerate(shape):
76 | if s == 1 and j == 0:
77 | new_shape.append(None)
78 | else:
79 | new_shape.append(s)
80 | o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
81 | return pool3
82 | #-------------------------------------------------------------------------------
83 |
84 |
85 | def get_activations(images, sess, batch_size=50, verbose=False):
86 | """Calculates the activations of the pool_3 layer for all images.
87 |
88 | Params:
89 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values
90 | must lie between 0 and 255.
91 | -- sess : current session
92 | -- batch_size : the images numpy array is split into batches with batch size
93 | batch_size. A reasonable batch size depends on the available hardware.
94 | -- verbose : If set to True, the progress over the processed batches
95 | is reported.
96 | Returns:
97 | -- A numpy array of dimension (num images, 2048) that contains the
98 | activations of the given tensor when feeding inception with the query tensor.
99 | """
100 | # inception_layer = _get_inception_layer(sess)
101 | d0 = images.shape[0]
102 | if batch_size > d0:
103 | print("warning: batch size is bigger than the data size. setting batch size to data size")
104 | batch_size = d0
105 | n_batches = d0//batch_size
106 | n_used_imgs = n_batches*batch_size
107 | pred_arr = np.empty((n_used_imgs,2048))
108 | for i in range(n_batches):
109 | if verbose:
110 | print("\rPropagating batch %d/%d" % (i+1, n_batches), end="", flush=True)
111 | start = i*batch_size
112 | end = start + batch_size
113 | batch = images[start:end]
114 | pred = sess.run(pool3, {'ExpandDims:0': batch})
115 | pred_arr[start:end] = pred.reshape(batch_size,-1)
116 | if verbose:
117 | print(" done")
118 | return pred_arr
119 | #-------------------------------------------------------------------------------
120 |
121 |
122 | def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
123 | """Numpy implementation of the Frechet Distance.
124 | The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
125 | and X_2 ~ N(mu_2, C_2) is
126 | d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
127 | Stable version by Dougal J. Sutherland.
128 |
129 | Params:
130 | -- mu1 : The sample mean over activations of the pool_3 layer of the
131 | inception net (as returned by 'calculate_activation_statistics')
132 | for generated samples.
133 | -- mu2 : The sample mean over activations of the pool_3 layer, precalculated
134 | on a representative data set.
135 | -- sigma1: The covariance matrix over activations of the pool_3 layer for
136 | generated samples.
137 | -- sigma2: The covariance matrix over activations of the pool_3 layer,
138 | precalculated on a representative data set.
139 |
140 | Returns:
141 | -- : The Frechet Distance.
142 | """
143 |
144 | mu1 = np.atleast_1d(mu1)
145 | mu2 = np.atleast_1d(mu2)
146 |
147 | sigma1 = np.atleast_2d(sigma1)
148 | sigma2 = np.atleast_2d(sigma2)
149 |
150 | assert mu1.shape == mu2.shape, "Training and test mean vectors have different lengths"
151 | assert sigma1.shape == sigma2.shape, "Training and test covariances have different dimensions"
152 |
153 | diff = mu1 - mu2
154 |
155 | # product might be almost singular
156 | covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
157 | if not np.isfinite(covmean).all():
158 | msg = "fid calculation produces singular product; adding %s to diagonal of cov estimates" % eps
159 | warnings.warn(msg)
160 | offset = np.eye(sigma1.shape[0]) * eps
161 | covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
162 |
163 | # numerical error might give slight imaginary component
164 | if np.iscomplexobj(covmean):
165 | if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
166 | m = np.max(np.abs(covmean.imag))
167 | raise ValueError("Imaginary component {}".format(m))
168 | covmean = covmean.real
169 |
170 | tr_covmean = np.trace(covmean)
171 |
172 | return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean
173 | #-------------------------------------------------------------------------------
174 |
175 |
176 | def calculate_activation_statistics(images, sess, batch_size=50, verbose=False):
177 | """Calculation of the statistics used by the FID.
178 | Params:
179 | -- images : Numpy array of dimension (n_images, hi, wi, 3). The values
180 | must lie between 0 and 255.
181 | -- sess : current session
182 | -- batch_size : the images numpy array is split into batches with batch size
183 | batch_size. A reasonable batch size depends on the available hardware.
184 | -- verbose : If set to True, the progress over the processed batches
185 | is reported.
186 | Returns:
187 | -- mu : The mean over samples of the activations of the pool_3 layer of
188 | the inception model.
189 | -- sigma : The covariance matrix of the activations of the pool_3 layer of
190 | the inception model.
191 | """
192 | act = get_activations(images, sess, batch_size, verbose)
193 | mu = np.mean(act, axis=0)
194 | sigma = np.cov(act, rowvar=False)
195 | return mu, sigma
196 | #-------------------------------------------------------------------------------
197 |
198 |
199 | #-------------------------------------------------------------------------------
200 | # The following functions aren't needed for calculating the FID
201 | # they're just here to make this module work as a stand-alone script
202 | # for calculating FID scores
203 | #-------------------------------------------------------------------------------
204 | def check_or_download_inception(inception_path):
205 | ''' Checks if the path to the inception file is valid, or downloads
206 | the file if it is not present. '''
207 | INCEPTION_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
208 | if inception_path is None:
209 | inception_path = '/tmp'
210 | inception_path = pathlib.Path(inception_path)
211 | model_file = inception_path / 'classify_image_graph_def.pb'
212 | if not model_file.exists():
213 | print("Downloading Inception model")
214 | from urllib import request
215 | import tarfile
216 | fn, _ = request.urlretrieve(INCEPTION_URL)
217 | with tarfile.open(fn, mode='r') as f:
218 | f.extract('classify_image_graph_def.pb', str(model_file.parent))
219 | return str(model_file)
220 |
221 |
222 | def _handle_path(path, sess):
223 | if path.endswith('.npz'):
224 | f = np.load(path)
225 | m, s = f['mu'][:], f['sigma'][:]
226 | f.close()
227 | else:
228 | path = pathlib.Path(path)
229 | files = list(path.glob('*.jpg')) + list(path.glob('*.png'))
230 | x = np.array([imread(str(fn)).astype(np.float32) for fn in files])
231 | m, s = calculate_activation_statistics(x, sess)
232 | return m, s
233 |
234 |
235 | def calculate_fid_given_paths(paths, inception_path):
236 | ''' Calculates the FID of two paths. '''
237 | inception_path = check_or_download_inception(inception_path)
238 |
239 | for p in paths:
240 | if not os.path.exists(p):
241 | raise RuntimeError("Invalid path: %s" % p)
242 |
243 | create_inception_graph(str(inception_path))
244 | with tf.Session() as sess:
245 | sess.run(tf.global_variables_initializer())
246 | m1, s1 = _handle_path(paths[0], sess)
247 | m2, s2 = _handle_path(paths[1], sess)
248 | fid_value = calculate_frechet_distance(m1, s1, m2, s2)
249 | return fid_value
250 |
251 |
252 | def _init_inception():
253 | global pool3
254 | if not os.path.exists(MODEL_DIR):
255 | os.makedirs(MODEL_DIR)
256 | filename = DATA_URL.split('/')[-1]
257 | filepath = os.path.join(MODEL_DIR, filename)
258 | if not os.path.exists(filepath):
259 | def _progress(count, block_size, total_size):
260 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (
261 | filename, float(count * block_size) / float(total_size) * 100.0))
262 | sys.stdout.flush()
263 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
264 | print()
265 | statinfo = os.stat(filepath)
266 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
267 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
268 | with tf.gfile.FastGFile(os.path.join(
269 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
270 | graph_def = tf.GraphDef()
271 | graph_def.ParseFromString(f.read())
272 | _ = tf.import_graph_def(graph_def, name='')
273 | # Works with an arbitrary minibatch size.
274 | with tf.Session() as sess:
275 | pool3 = sess.graph.get_tensor_by_name('pool_3:0')
276 | ops = pool3.graph.get_operations()
277 | for op_idx, op in enumerate(ops):
278 | for o in op.outputs:
279 | shape = o.get_shape()
280 | if shape._dims != []:
281 | shape = [s.value for s in shape]
282 | new_shape = []
283 | for j, s in enumerate(shape):
284 | if s == 1 and j == 0:
285 | new_shape.append(None)
286 | else:
287 | new_shape.append(s)
288 | o.__dict__['_shape_val'] = tf.TensorShape(new_shape)
289 |
290 |
291 | if pool3 is None:
292 | _init_inception()
293 |
--------------------------------------------------------------------------------
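`calculate_frechet_distance` evaluates d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)) using `scipy.linalg.sqrtm`. As a cross-check, here is a NumPy-only sketch of the same quantity that swaps in a different technique for the trace term. Because sqrt(C_1)*C_2*sqrt(C_1) has the same eigenvalues as C_1*C_2, Tr(sqrt(C_1*C_2)) can be read off a symmetric eigendecomposition, which avoids the complex round-off that the original code has to strip. The function names `frechet_distance` and `_psd_sqrt` are hypothetical and not part of this repository.

```python
import numpy as np

def _psd_sqrt(mat):
    """Symmetric PSD square root via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative round-off
    return (vecs * np.sqrt(vals)) @ vecs.T

def frechet_distance(mu1, sigma1, mu2, sigma2):
    """d^2 = ||mu1 - mu2||^2 + Tr(C1 + C2 - 2*sqrt(C1*C2))."""
    mu1, mu2 = np.atleast_1d(mu1), np.atleast_1d(mu2)
    sigma1, sigma2 = np.atleast_2d(sigma1), np.atleast_2d(sigma2)
    diff = mu1 - mu2
    # Tr(sqrt(C1 C2)) == Tr(sqrt(sqrt(C1) C2 sqrt(C1))): the two products share
    # eigenvalues, and the second one is symmetric PSD.
    root1 = _psd_sqrt(sigma1)
    inner = root1 @ sigma2 @ root1
    inner = 0.5 * (inner + inner.T)  # re-symmetrize against round-off
    eigvals = np.clip(np.linalg.eigvalsh(inner), 0.0, None)
    tr_covmean = np.sum(np.sqrt(eigvals))
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```

For two 1-D Gaussians N(0, 1) and N(3, 4), `frechet_distance(0.0, 1.0, 3.0, 4.0)` gives 9 + 1 + 4 - 2*2 = 10, matching the closed form (mu_1 - mu_2)^2 + (s_1 - s_2)^2 with standard deviations 1 and 2.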
/EBMs/hmc.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 |
4 | from tensorflow.python.platform import flags
5 | flags.DEFINE_bool('proposal_debug', False, 'Print HMC acceptance rates')
6 |
7 | FLAGS = flags.FLAGS
8 |
9 | def kinetic_energy(velocity):
10 | """Kinetic energy of the current velocity (assuming a standard Gaussian)
11 | computed elementwise as v**2 / 2; summing over dimensions gives (v dot v) / 2
12 |
13 | Parameters
14 | ----------
15 | velocity : tf.Variable
16 | Vector of current velocity
17 |
18 | Returns
19 | -------
20 | kinetic_energy : tf.Tensor with the same shape as velocity
21 | """
22 | return 0.5 * tf.square(velocity)
23 |
24 | def hamiltonian(position, velocity, energy_function):
25 | """Computes the Hamiltonian of the current position, velocity pair
26 |
27 | H = U(x) + K(v)
28 |
29 | U is the potential energy and is = -log_posterior(x)
30 |
31 | Parameters
32 | ----------
33 | position : tf.Variable
34 | Position or state vector x (sample from the target distribution)
35 | velocity : tf.Variable
36 | Auxiliary velocity variable
37 | energy_function
38 | Function mapping a position (state) to its 'energy'
39 | = -log_posterior
40 |
41 | Returns
42 | -------
43 | hamiltonian : tf.Tensor of shape (batch_size,)
44 | """
45 | batch_size = tf.shape(velocity)[0]
46 | kinetic_energy_flat = tf.reshape(kinetic_energy(velocity), (batch_size, -1))
47 | return tf.squeeze(energy_function(position)) + tf.reduce_sum(kinetic_energy_flat, axis=[1])
48 |
49 | def leapfrog_step(x0,
50 | v0,
51 | neg_log_posterior,
52 | step_size,
53 | num_steps):
54 |
55 | # Start by updating the velocity a half-step
56 | v = v0 - 0.5 * step_size * tf.gradients(neg_log_posterior(x0), x0)[0]
57 |
58 | # Initialize x with a first full position step
59 | x = x0 + step_size * v
60 |
61 | for i in range(num_steps):
62 | # Compute gradient of the negative log-posterior with respect to x
63 | gradient = tf.gradients(neg_log_posterior(x), x)[0]
64 |
65 | # Update velocity
66 | v = v - step_size * gradient
67 |
68 | # x_clip = tf.clip_by_value(x, 0.0, 1.0)
69 | # x = x_clip
70 | # v_mask = 1 - 2 * tf.abs(tf.sign(x - x_clip))
71 | # v = v * v_mask
72 |
73 | # Update x
74 | x = x + step_size * v
75 |
76 | # x = tf.clip_by_value(x, -0.01, 1.01)
77 |
78 | # x = tf.Print(x, [tf.reduce_min(x), tf.reduce_max(x), tf.reduce_mean(x)])
79 |
80 | # Do a final update of the velocity for a half step
81 | v = v - 0.5 * step_size * tf.gradients(neg_log_posterior(x), x)[0]
82 |
83 | # return new proposal state
84 | return x, v
85 |
86 | def hmc(initial_x,
87 | step_size,
88 | num_steps,
89 | neg_log_posterior):
90 | """Draws one HMC sample: leapfrog proposal plus Metropolis accept/reject
91 |
92 | Parameters
93 | ----------
94 | initial_x : tf.Variable
95 | Initial sample x ~ p
96 | step_size : float
97 | Step-size in Hamiltonian simulation
98 | num_steps : int
99 | Number of steps to take in Hamiltonian simulation
100 | neg_log_posterior : callable
101 | Negative log posterior (unnormalized) for the target distribution
102 |
103 | Returns
104 | -------
105 | sample :
106 | Sample ~ target distribution
107 | """
108 |
109 | v0 = tf.random_normal(tf.shape(initial_x))
110 | x, v = leapfrog_step(initial_x,
111 | v0,
112 | step_size=step_size,
113 | num_steps=num_steps,
114 | neg_log_posterior=neg_log_posterior)
115 |
116 | orig = hamiltonian(initial_x, v0, neg_log_posterior)
117 | current = hamiltonian(x, v, neg_log_posterior)
118 |
119 | prob_accept = tf.exp(orig - current)
120 |
121 | if FLAGS.proposal_debug:
122 | prob_accept = tf.Print(prob_accept, [tf.reduce_mean(tf.clip_by_value(prob_accept, 0, 1))])
123 |
124 | uniform = tf.random_uniform(tf.shape(prob_accept))
125 | keep_mask = (prob_accept > uniform)
126 | # print(keep_mask.get_shape())
127 |
128 | x_new = tf.where(keep_mask, x, initial_x)
129 | return x_new
130 |
--------------------------------------------------------------------------------
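`hmc` above builds one Hamiltonian Monte Carlo transition in the TensorFlow graph: draw a Gaussian velocity, integrate with `leapfrog_step`, then keep the proposal with probability exp(H_old - H_new). A NumPy sketch of the same transition makes the accept/reject logic easy to check outside TensorFlow. It assumes a batch of chains with shape (batch, dim) and a hypothetical standard-Gaussian target U(x) = |x|^2 / 2; `hmc_step` and the energy helpers are illustrative names, not functions in this repo.

```python
import numpy as np

def hmc_step(x, energy, grad_energy, step_size, num_steps, rng):
    """One HMC transition for a batch of chains x with shape (batch, dim)."""
    v0 = rng.standard_normal(x.shape)
    # Leapfrog in the same order as leapfrog_step: half velocity step,
    # full position step, alternating full steps, closing half velocity step.
    v = v0 - 0.5 * step_size * grad_energy(x)
    x_new = x + step_size * v
    for _ in range(num_steps):
        v = v - step_size * grad_energy(x_new)
        x_new = x_new + step_size * v
    v = v - 0.5 * step_size * grad_energy(x_new)
    # Hamiltonian H = U(x) + |v|^2 / 2, evaluated per chain.
    h_old = energy(x) + 0.5 * np.sum(v0 ** 2, axis=1)
    h_new = energy(x_new) + 0.5 * np.sum(v ** 2, axis=1)
    # Metropolis test: accept with probability min(1, exp(H_old - H_new)).
    log_ratio = np.minimum(h_old - h_new, 0.0)
    accept = rng.uniform(size=x.shape[0]) < np.exp(log_ratio)
    return np.where(accept[:, None], x_new, x), accept

def std_normal_energy(x):
    return 0.5 * np.sum(x ** 2, axis=1)

def std_normal_grad(x):
    return x

rng = np.random.default_rng(0)
x = np.full((1000, 2), 3.0)  # every chain starts far from the mode
rates = []
for _ in range(100):
    x, acc = hmc_step(x, std_normal_energy, std_normal_grad, 0.2, 10, rng)
    rates.append(acc.mean())
```

After 100 transitions the chains have forgotten their starting point, so the pooled samples should have mean near 0 and variance near 1. Clipping the log acceptance ratio at 0 before exponentiating avoids overflow without changing min(1, exp(...)).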
/EBMs/imagenet_demo.py:
--------------------------------------------------------------------------------
1 | from models import ResNet128
2 | import numpy as np
3 | import os.path as osp
4 | from tensorflow.python.platform import flags
5 | import tensorflow as tf
6 | import imageio
7 | from utils import optimistic_restore
8 |
9 |
10 | flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
11 | flags.DEFINE_integer('num_steps', 200, 'num of steps for conditional imagenet sampling')
12 | flags.DEFINE_float('step_lr', 180., 'step size for Langevin dynamics')
13 | flags.DEFINE_integer('batch_size', 16, 'number of images to sample (the 4x4 output grid assumes 16)')
14 | flags.DEFINE_string('exp', 'default', 'name of experiments')
15 | flags.DEFINE_integer('resume_iter', -1, 'iteration to resume training from')
16 | flags.DEFINE_bool('spec_norm', True, 'whether to use spectral normalization in weights in a model')
17 | flags.DEFINE_bool('cclass', True, 'conditional models')
18 | flags.DEFINE_bool('use_attention', False, 'using attention')
19 |
20 | FLAGS = flags.FLAGS
21 |
22 | def rescale_im(im):
23 | return np.clip(im * 256, 0, 255).astype(np.uint8)
24 |
25 |
26 | if __name__ == "__main__":
27 | model = ResNet128(num_filters=64)
28 | X_NOISE = tf.placeholder(shape=(None, 128, 128, 3), dtype=tf.float32)
29 | LABEL = tf.placeholder(shape=(None, 1000), dtype=tf.float32)
30 |
31 | sess = tf.InteractiveSession()
32 | weights = model.construct_weights("context_0")
33 |
34 | x_mod = X_NOISE
35 | x_mod = x_mod + tf.random_normal(tf.shape(x_mod),
36 | mean=0.0,
37 | stddev=0.005)
38 |
39 | energy_noise = energy_start = model.forward(x_mod, weights, label=LABEL,
40 | reuse=True, stop_at_grad=False, stop_batch=True)
41 |
42 | x_grad = tf.gradients(energy_noise, [x_mod])[0]
43 | energy_noise_old = energy_noise
44 |
45 | lr = FLAGS.step_lr
46 |
47 | x_last = x_mod - (lr) * x_grad
48 |
49 | x_mod = x_last
50 | x_mod = tf.clip_by_value(x_mod, 0, 1)
51 | x_output = x_mod
52 |
53 | sess.run(tf.global_variables_initializer())
54 | saver = loader = tf.train.Saver()
55 |
56 | logdir = osp.join(FLAGS.logdir, FLAGS.exp)
57 | model_file = osp.join(logdir, 'model_{}'.format(FLAGS.resume_iter))
58 | saver.restore(sess, model_file)
59 |
60 | lx = np.random.permutation(1000)[:FLAGS.batch_size]  # one random ImageNet class label per sample
61 | ims = []
62 |
63 | # What to initialize sampling with.
64 | x_mod = np.random.uniform(0, 1, size=(FLAGS.batch_size, 128, 128, 3))
65 | labels = np.eye(1000)[lx]
66 |
67 | for i in range(FLAGS.num_steps):
68 | e, x_mod = sess.run([energy_noise, x_output], {X_NOISE:x_mod, LABEL:labels})
69 | ims.append(rescale_im(x_mod).reshape((4, 4, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((512, 512, 3)))
70 |
71 | imageio.mimwrite('sample.gif', ims)
72 |
73 |
74 |
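The demo samples by jittering the input with small Gaussian noise and then taking a plain gradient step on the energy, x ← clip(x - step_lr * ∇E(x), 0, 1), repeated `num_steps` times. It then tiles the 16 samples into a 4x4 grid for the GIF. Below is a NumPy sketch of both pieces. It uses a hypothetical quadratic energy E(x) = 0.5 * ||x - target||^2 (`toy_energy_grad`) in place of the ResNet128 energy, and the helper names are illustrative only.

```python
import numpy as np

def toy_energy_grad(x, target):
    # Gradient of the hypothetical energy E(x) = 0.5 * ||x - target||^2
    return x - target

def langevin_like_step(x, target, step_lr, noise_std=0.005, rng=None):
    # Mirrors the demo graph: perturb the input, step down the energy gradient at the perturbed point, clip to [0, 1]
    rng = np.random.default_rng() if rng is None else rng
    x_noisy = x + rng.normal(0.0, noise_std, size=x.shape)
    x_new = x_noisy - step_lr * toy_energy_grad(x_noisy, target)
    return np.clip(x_new, 0.0, 1.0)

def tile_grid(batch, rows, cols):
    # (rows*cols, H, W, C) -> (rows*H, cols*W, C); image i lands in grid cell (i // cols, i % cols)
    n, h, w, c = batch.shape
    assert n == rows * cols
    return batch.reshape(rows, cols, h, w, c).transpose(0, 2, 1, 3, 4).reshape(rows * h, cols * w, c)
```

`tile_grid(x, 4, 4)` on a (16, 128, 128, 3) batch performs the same reshape/transpose as the demo and produces a (512, 512, 3) frame.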
--------------------------------------------------------------------------------
/EBMs/imagenet_preprocessing.py:
--------------------------------------------------------------------------------
1 | # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | # ==============================================================================
15 |
16 | """Image pre-processing utilities.
17 | """
18 | import tensorflow as tf
19 |
20 |
21 | IMAGE_DEPTH = 3 # color images
22 |
24 |
25 | # _R_MEAN = 123.68
26 | # _G_MEAN = 116.78
27 | # _B_MEAN = 103.94
28 | # _CHANNEL_MEANS = [_R_MEAN, _G_MEAN, _B_MEAN]
29 | _CHANNEL_MEANS = [0.0, 0.0, 0.0]
30 |
31 | # The lower bound for the smallest side of the image for aspect-preserving
32 | # resizing. For example, if an image is 500 x 1000, it will be resized to
33 | # _RESIZE_MIN x (_RESIZE_MIN * 2).
34 | _RESIZE_MIN = 128
35 |
36 |
37 | def _decode_crop_and_flip(image_buffer, bbox, num_channels):
38 | """Crops the given image to a random part of the image, and randomly flips.
39 |
40 | We use the fused decode_and_crop op, which performs better than the two ops
41 | used separately in series, but note that this requires that the image be
42 | passed in as an un-decoded string Tensor.
43 |
44 | Args:
45 | image_buffer: scalar string Tensor representing the raw JPEG image buffer.
46 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
47 | where each coordinate is [0, 1) and the coordinates are arranged as
48 | [ymin, xmin, ymax, xmax].
49 | num_channels: Integer depth of the image buffer for decoding.
50 |
51 | Returns:
52 | 3-D tensor with cropped image.
53 |
54 | """
55 | # A large fraction of image datasets contain a human-annotated bounding box
56 | # delineating the region of the image containing the object of interest. We
57 | # choose to create a new bounding box for the object which is a randomly
58 | # distorted version of the human-annotated bounding box that obeys an
59 | # allowed range of aspect ratios, sizes and overlap with the human-annotated
60 | # bounding box. If no box is supplied, then we assume the bounding box is
61 | # the entire image.
62 | sample_distorted_bounding_box = tf.image.sample_distorted_bounding_box(
63 | tf.image.extract_jpeg_shape(image_buffer),
64 | bounding_boxes=bbox,
65 | min_object_covered=0.1,
66 | aspect_ratio_range=[0.75, 1.33],
67 | area_range=[0.05, 1.0],
68 | max_attempts=100,
69 | use_image_if_no_bounding_boxes=True)
70 | bbox_begin, bbox_size, _ = sample_distorted_bounding_box
71 |
72 | # Reassemble the bounding box in the format the crop op requires.
73 | offset_y, offset_x, _ = tf.unstack(bbox_begin)
74 | target_height, target_width, _ = tf.unstack(bbox_size)
75 | crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
76 |
77 | # Use the fused decode and crop op here, which is faster than each in series.
78 | cropped = tf.image.decode_and_crop_jpeg(
79 | image_buffer, crop_window, channels=num_channels)
80 |
81 | # Flip to add a little more random distortion in.
82 | cropped = tf.image.random_flip_left_right(cropped)
83 | return cropped
84 |
85 |
86 | def _central_crop(image, crop_height, crop_width):
87 | """Performs central crops of the given image list.
88 |
89 | Args:
90 | image: a 3-D image tensor
91 | crop_height: the height of the image following the crop.
92 | crop_width: the width of the image following the crop.
93 |
94 | Returns:
95 | 3-D tensor with cropped image.
96 | """
97 | shape = tf.shape(input=image)
98 | height, width = shape[0], shape[1]
99 |
100 | amount_to_be_cropped_h = (height - crop_height)
101 | crop_top = amount_to_be_cropped_h // 2
102 | amount_to_be_cropped_w = (width - crop_width)
103 | crop_left = amount_to_be_cropped_w // 2
104 | return tf.slice(
105 | image, [crop_top, crop_left, 0], [crop_height, crop_width, -1])
106 |
107 |
108 | def _mean_image_subtraction(image, means, num_channels):
109 | """Subtracts the given means from each image channel.
110 |
111 | For example:
112 | means = [123.68, 116.779, 103.939]
113 | image = _mean_image_subtraction(image, means, num_channels=3)
114 |
115 | Note that the rank of `image` must be known.
116 |
117 | Args:
118 | image: a tensor of size [height, width, C].
119 | means: a C-vector of values to subtract from each channel.
120 | num_channels: number of color channels in `image`.
121 |
122 | Returns:
123 | the centered image.
124 |
125 | Raises:
126 | ValueError: If the rank of `image` is unknown, if `image` has a rank other
127 | than three or if the number of channels in `image` doesn't match the
128 | number of values in `means`.
129 | """
130 | if image.get_shape().ndims != 3:
131 | raise ValueError('Input must be of size [height, width, C>0]')
132 |
133 | if len(means) != num_channels:
134 | raise ValueError('len(means) must match the number of channels')
135 |
136 | # We have a 1-D tensor of means; convert to 3-D.
137 | means = tf.expand_dims(tf.expand_dims(means, 0), 0)
138 |
139 | return image - means
140 |
141 |
142 | def _smallest_size_at_least(height, width, resize_min):
143 | """Computes new shape with the smallest side equal to `resize_min`.
144 |
145 | Computes new shape with the smallest side equal to `resize_min` while
146 | preserving the original aspect ratio.
147 |
148 | Args:
149 | height: an int32 scalar tensor indicating the current height.
150 | width: an int32 scalar tensor indicating the current width.
151 | resize_min: A python integer or scalar `Tensor` indicating the size of
152 | the smallest side after resize.
153 |
154 | Returns:
155 | new_height: an int32 scalar tensor indicating the new height.
156 | new_width: an int32 scalar tensor indicating the new width.
157 | """
158 | resize_min = tf.cast(resize_min, tf.float32)
159 |
160 | # Convert to floats to make subsequent calculations go smoothly.
161 | height, width = tf.cast(height, tf.float32), tf.cast(width, tf.float32)
162 |
163 | smaller_dim = tf.minimum(height, width)
164 | scale_ratio = resize_min / smaller_dim
165 |
166 | # Convert back to ints to make heights and widths that TF ops will accept.
167 | new_height = tf.cast(tf.ceil(height * scale_ratio), tf.int32)
168 | new_width = tf.cast(tf.ceil(width * scale_ratio), tf.int32)
169 |
170 | return new_height, new_width
171 |
172 |
173 | def _aspect_preserving_resize(image, resize_min):
174 | """Resize images preserving the original aspect ratio.
175 |
176 | Args:
177 | image: A 3-D image `Tensor`.
178 | resize_min: A python integer or scalar `Tensor` indicating the size of
179 | the smallest side after resize.
180 |
181 | Returns:
182 | resized_image: A 3-D tensor containing the resized image.
183 | """
184 | shape = tf.shape(input=image)
185 | height, width = shape[0], shape[1]
186 |
187 | new_height, new_width = _smallest_size_at_least(height, width, resize_min)
188 |
189 | return _resize_image(image, new_height, new_width)
190 |
191 |
192 | def _resize_image(image, height, width):
193 | """Simple wrapper around tf.image.resize_images.
194 |
195 | This is primarily to make sure we use the same `ResizeMethod` and other
196 | details each time.
197 |
198 | Args:
199 | image: A 3-D image `Tensor`.
200 | height: The target height for the resized image.
201 | width: The target width for the resized image.
202 |
203 | Returns:
204 | resized_image: A 3-D tensor containing the resized image. The first two
205 | dimensions have the shape [height, width].
206 | """
207 | return tf.image.resize_images(
208 | image, [height, width], method=tf.image.ResizeMethod.BILINEAR,
209 | align_corners=False)
210 |
211 |
212 | def preprocess_image(image_buffer, bbox, output_height, output_width,
213 | num_channels, is_training=False):
214 | """Preprocesses the given image.
215 |
216 | Preprocessing includes decoding, cropping, and resizing for both training
217 | and eval images. Training preprocessing, however, introduces some random
218 | distortion of the image to improve accuracy.
219 |
220 | Args:
221 | image_buffer: scalar string Tensor representing the raw JPEG image buffer.
222 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
223 | where each coordinate is [0, 1) and the coordinates are arranged as
224 | [ymin, xmin, ymax, xmax].
225 | output_height: The height of the image after preprocessing.
226 | output_width: The width of the image after preprocessing.
227 | num_channels: Integer depth of the image buffer for decoding.
228 | is_training: `True` if we're preprocessing the image for training and
229 | `False` otherwise.
230 |
231 | Returns:
232 | A preprocessed image.
233 | """
234 | if is_training:
235 | # For training, we want to randomize some of the distortions.
236 | image = _decode_crop_and_flip(image_buffer, bbox, num_channels)
237 | image = _resize_image(image, output_height, output_width)
238 | else:
239 | # For validation, we want to decode, resize, then just crop the middle.
240 | image = tf.image.decode_jpeg(image_buffer, channels=num_channels)
241 | image = _aspect_preserving_resize(image, _RESIZE_MIN)
242 | print(image)
243 | image = _central_crop(image, output_height, output_width)
244 |
245 | image.set_shape([output_height, output_width, num_channels])
246 |
247 | return _mean_image_subtraction(image, _CHANNEL_MEANS, num_channels)
248 |
249 |
250 | def parse_example_proto(example_serialized):
251 | """Parses an Example proto containing a training example of an image.
252 |
253 | The output of the build_image_data.py image preprocessing script is a dataset
254 | containing serialized Example protocol buffers. Each Example proto contains
255 | the following fields:
256 |
257 | image/height: 462
258 | image/width: 581
259 | image/colorspace: 'RGB'
260 | image/channels: 3
261 | image/class/label: 615
262 | image/class/synset: 'n03623198'
263 | image/class/text: 'knee pad'
264 | image/object/bbox/xmin: 0.1
265 | image/object/bbox/xmax: 0.9
266 | image/object/bbox/ymin: 0.2
267 | image/object/bbox/ymax: 0.6
268 | image/object/bbox/label: 615
269 | image/format: 'JPEG'
270 | image/filename: 'ILSVRC2012_val_00041207.JPEG'
271 | image/encoded:
272 |
273 | Args:
274 | example_serialized: scalar Tensor tf.string containing a serialized
275 | Example protocol buffer.
276 |
277 | Returns:
278 | image_buffer: Tensor tf.string containing the contents of a JPEG file.
279 | label: Tensor tf.int32 containing the label.
280 | bbox: 3-D float Tensor of bounding boxes arranged [1, num_boxes, coords]
281 | where each coordinate is [0, 1) and the coordinates are arranged as
282 | [ymin, xmin, ymax, xmax].
283 | text: Tensor tf.string containing the human-readable label.
284 | """
285 | # Dense features in Example proto.
286 | feature_map = {
287 | 'image/encoded': tf.FixedLenFeature([], dtype=tf.string,
288 | default_value=''),
289 | 'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64,
290 | default_value=-1),
291 | 'image/class/text': tf.FixedLenFeature([], dtype=tf.string,
292 | default_value=''),
293 | }
294 | sparse_float32 = tf.VarLenFeature(dtype=tf.float32)
295 | # Sparse features in Example proto.
296 | feature_map.update(
297 | {k: sparse_float32 for k in ['image/object/bbox/xmin',
298 | 'image/object/bbox/ymin',
299 | 'image/object/bbox/xmax',
300 | 'image/object/bbox/ymax']})
301 |
302 | features = tf.parse_single_example(example_serialized, feature_map)
303 | label = tf.cast(features['image/class/label'], dtype=tf.int32)
304 |
305 | xmin = tf.expand_dims(features['image/object/bbox/xmin'].values, 0)
306 | ymin = tf.expand_dims(features['image/object/bbox/ymin'].values, 0)
307 | xmax = tf.expand_dims(features['image/object/bbox/xmax'].values, 0)
308 | ymax = tf.expand_dims(features['image/object/bbox/ymax'].values, 0)
309 |
310 | # Note that we impose an ordering of (y, x) just to make life difficult.
311 | bbox = tf.concat([ymin, xmin, ymax, xmax], 0)
312 |
313 | # Force the variable number of bounding boxes into the shape
314 | # [1, num_boxes, coords].
315 | bbox = tf.expand_dims(bbox, 0)
316 | bbox = tf.transpose(bbox, [0, 2, 1])
317 |
318 | return features['image/encoded'], label, bbox, features['image/class/text']
319 |
320 |
321 | class ImagenetPreprocessor:
322 | def __init__(self, image_size, dtype, train):
323 | self.image_size = image_size
324 | self.dtype = dtype
325 | self.train = train
326 |
327 | def preprocess(self, image_buffer, bbox):
328 | # pylint: disable=g-import-not-at-top
329 | image = preprocess_image(image_buffer, bbox, self.image_size, self.image_size, IMAGE_DEPTH, is_training=self.train)
330 | return tf.cast(image, self.dtype)
331 |
332 | def parse_and_preprocess(self, value):
333 | image_buffer, label_index, bbox, _ = parse_example_proto(value)
334 | image = self.preprocess(image_buffer, bbox)
335 | image = tf.reshape(image, [self.image_size, self.image_size, IMAGE_DEPTH])
336 | return label_index, image
337 |
338 |
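At evaluation time, `preprocess_image` first resizes the image so that its shorter side equals `_RESIZE_MIN`, keeping the aspect ratio, and then takes a center crop. Below is the shape arithmetic in plain Python, sketched after `_smallest_size_at_least` and `_central_crop`. The helper names are hypothetical, and the helpers work on Python numbers rather than tensors. TensorFlow does the scaling in float32, so `ceil` can round differently in rare edge cases. The sketch reproduces the 500 x 1000 example from the `_RESIZE_MIN` comment.

```python
import math

def smallest_size_at_least(height, width, resize_min):
    # Scale so the shorter side equals resize_min, keeping the aspect ratio (ceil, as in the TF version)
    scale_ratio = resize_min / min(height, width)
    return math.ceil(height * scale_ratio), math.ceil(width * scale_ratio)

def central_crop_offsets(height, width, crop_height, crop_width):
    # Top-left corner of a centered crop; an odd leftover pixel ends up on the bottom/right edge
    return (height - crop_height) // 2, (width - crop_width) // 2

# A 500 x 1000 image becomes 128 x 256, then a 128 x 128 crop starts 64 pixels from the left
new_h, new_w = smallest_size_at_least(500, 1000, 128)
print((new_h, new_w), central_crop_offsets(new_h, new_w, 128, 128))
```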
--------------------------------------------------------------------------------
/EBMs/inception.py:
--------------------------------------------------------------------------------
1 | # Code derived from tensorflow/tensorflow/models/image/imagenet/classify_image.py
2 | from __future__ import absolute_import
3 | from __future__ import division
4 | from __future__ import print_function
5 |
6 | import os.path
7 | import sys
8 | import tarfile
9 |
10 | import numpy as np
11 | from six.moves import urllib
12 | import tensorflow as tf
13 | import glob
14 | import scipy.misc
15 | import math
17 |
18 | import horovod.tensorflow as hvd
19 |
20 | MODEL_DIR = '/tmp/imagenet'
21 | DATA_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
22 | softmax = None
23 |
24 | config = tf.ConfigProto()
25 | config.gpu_options.visible_device_list = str(hvd.local_rank())
26 | sess = tf.Session(config=config)
27 |
28 | # Call this function with list of images. Each of elements should be a
29 | # numpy array with values ranging from 0 to 255.
30 | def get_inception_score(images, splits=10):
31 | # For convenience
32 | if len(images[0].shape) != 3:
33 | return 0, 0
34 |
35 | # The original input assertions are disabled so that evaluation does not end prematurely:
36 | # assert(type(images) == list)
37 | # assert(type(images[0]) == np.ndarray)
38 | # assert(len(images[0].shape) == 3)
39 | # assert(np.max(images[0]) > 10)
40 | # assert(np.min(images[0]) >= 0.0)
41 | inps = []
42 | for img in images:
43 | img = img.astype(np.float32)
44 | inps.append(np.expand_dims(img, 0))
45 | bs = 1
46 | preds = []
47 | n_batches = int(math.ceil(float(len(inps)) / float(bs)))
48 | for i in range(n_batches):
49 | sys.stdout.write(".")
50 | sys.stdout.flush()
51 | inp = inps[(i * bs):min((i + 1) * bs, len(inps))]
52 | inp = np.concatenate(inp, 0)
53 | pred = sess.run(softmax, {'ExpandDims:0': inp})
54 | preds.append(pred)
55 | preds = np.concatenate(preds, 0)
56 | scores = []
57 | for i in range(splits):
58 | part = preds[(i * preds.shape[0] // splits):((i + 1) * preds.shape[0] // splits), :]
59 | kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
60 | kl = np.mean(np.sum(kl, 1))
61 | scores.append(np.exp(kl))
62 | return np.mean(scores), np.std(scores)
63 |
64 | # Called automatically when this module is imported (see the check at the end of the file).
65 | def _init_inception():
66 | global softmax
67 | if not os.path.exists(MODEL_DIR):
68 | os.makedirs(MODEL_DIR)
69 | filename = DATA_URL.split('/')[-1]
70 | filepath = os.path.join(MODEL_DIR, filename)
71 | if not os.path.exists(filepath):
72 | def _progress(count, block_size, total_size):
73 | sys.stdout.write('\r>> Downloading %s %.1f%%' % (
74 | filename, float(count * block_size) / float(total_size) * 100.0))
75 | sys.stdout.flush()
76 | filepath, _ = urllib.request.urlretrieve(DATA_URL, filepath, _progress)
77 | print()
78 | statinfo = os.stat(filepath)
79 | print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
80 | tarfile.open(filepath, 'r:gz').extractall(MODEL_DIR)
81 | with tf.gfile.FastGFile(os.path.join(
82 | MODEL_DIR, 'classify_image_graph_def.pb'), 'rb') as f:
83 | graph_def = tf.GraphDef()
84 | graph_def.ParseFromString(f.read())
85 | _ = tf.import_graph_def(graph_def, name='')
86 | # Works with an arbitrary minibatch size.
87 | pool3 = sess.graph.get_tensor_by_name('pool_3:0')
88 | ops = pool3.graph.get_operations()
89 | for op_idx, op in enumerate(ops):
90 | for o in op.outputs:
91 | shape = o.get_shape()
92 | shape = [s.value for s in shape]
93 | new_shape = []
94 | for j, s in enumerate(shape):
95 | if s == 1 and j == 0:
96 | new_shape.append(None)
97 | else:
98 | new_shape.append(s)
99 | o.set_shape(tf.TensorShape(new_shape))
100 | w = sess.graph.get_operation_by_name("softmax/logits/MatMul").inputs[1]
101 | logits = tf.matmul(tf.squeeze(pool3, [1, 2]), w)
102 | softmax = tf.nn.softmax(logits)
103 |
104 | if softmax is None:
105 | _init_inception()
106 |
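`get_inception_score` computes the Inception Score, IS = exp(E_x[KL(p(y|x) || p(y))]), from the Inception softmax outputs. It evaluates the score on `splits` chunks and reports the mean and standard deviation across them. Below is the same arithmetic separated from the TensorFlow graph, as a NumPy sketch. The function name and the `eps` clipping are assumptions added here: the original relies on softmax outputs never being exactly zero.

```python
import numpy as np

def inception_score_from_probs(preds, splits=10, eps=1e-12):
    # preds: (N, num_classes) softmax outputs p(y|x), one row per generated image
    preds = np.clip(preds, eps, 1.0)  # avoid log(0) for exactly one-hot rows
    n = preds.shape[0]
    scores = []
    for i in range(splits):
        part = preds[i * n // splits:(i + 1) * n // splits]
        p_y = part.mean(axis=0, keepdims=True)  # marginal p(y) over this split
        kl = np.sum(part * (np.log(part) - np.log(p_y)), axis=1)  # KL(p(y|x) || p(y)) per image
        scores.append(np.exp(kl.mean()))
    return float(np.mean(scores)), float(np.std(scores))
```

Identical predictions for every image give the minimum score of 1. Confident predictions spread evenly over K classes give a score close to K.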
--------------------------------------------------------------------------------
/EBMs/models.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | from tensorflow.python.platform import flags
3 | import numpy as np
4 | from utils import conv_block, get_weight, attention, conv_cond_concat, init_conv_weight, init_attention_weight, init_res_weight, smart_res_block, smart_res_block_optim, init_convt_weight
5 | from utils import init_fc_weight, smart_conv_block, smart_fc_block, smart_atten_block, groupsort, smart_convt_block, swish
6 |
7 | flags.DEFINE_bool('swish_act', False, 'use the swish activation for dsprites')
8 |
9 | FLAGS = flags.FLAGS
10 |
11 |
12 | class MnistNet(object):
13 | def __init__(self, num_channels=1, num_filters=64):
14 |
15 | self.channels = num_channels
16 | self.dim_hidden = num_filters
17 | self.datasource = FLAGS.datasource
18 |
19 | if FLAGS.cclass:
20 | self.label_size = 10
21 | else:
22 | self.label_size = 0
23 |
24 | def construct_weights(self, scope=''):
25 | weights = {}
26 |
27 | dtype = tf.float32
28 | conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
29 | fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
30 |
31 | classes = 1
32 |
33 | with tf.variable_scope(scope):
34 | init_conv_weight(weights, 'c1_pre', 3, 1, 64)
35 | init_conv_weight(weights, 'c1', 4, 64, self.dim_hidden, classes=classes)
36 | init_conv_weight(weights, 'c2', 4, self.dim_hidden, 2*self.dim_hidden, classes=classes)
37 | init_conv_weight(weights, 'c3', 4, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
38 | init_fc_weight(weights, 'fc_dense', 4*4*4*self.dim_hidden, 2*self.dim_hidden, spec_norm=True)
39 | init_fc_weight(weights, 'fc5', 2*self.dim_hidden, 1, spec_norm=False)
40 |
41 | if FLAGS.cclass:
42 | self.label_size = 10
43 | else:
44 | self.label_size = 0
45 | return weights
46 |
47 | def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, **kwargs):
48 | channels = self.channels
49 | weights = weights.copy()
50 | inp = tf.reshape(inp, (tf.shape(inp)[0], 28, 28, 1))
51 |
52 | if FLAGS.swish_act:
53 | act = swish
54 | else:
55 | act = tf.nn.leaky_relu
56 |
57 | if stop_grad:
58 | for k, v in weights.items():
59 | if type(v) == dict:
60 | v = v.copy()
61 | weights[k] = v
62 | for k_sub, v_sub in v.items():
63 | v[k_sub] = tf.stop_gradient(v_sub)
64 | else:
65 | weights[k] = tf.stop_gradient(v)
66 |
67 | if FLAGS.cclass:
68 | label_d = tf.reshape(label, shape=(tf.shape(label)[0], 1, 1, self.label_size))
69 | inp = conv_cond_concat(inp, label_d)
70 |
71 | h1 = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
72 | h2 = smart_conv_block(h1, weights, reuse, 'c1', use_stride=True, downsample=True, label=label, extra_bias=False, activation=act)
73 | h3 = smart_conv_block(h2, weights, reuse, 'c2', use_stride=True, downsample=True, label=label, extra_bias=False, activation=act)
74 | h4 = smart_conv_block(h3, weights, reuse, 'c3', use_stride=True, downsample=True, label=label, use_scale=False, extra_bias=False, activation=act)
75 |
76 | h5 = tf.reshape(h4, [-1, np.prod([int(dim) for dim in h4.get_shape()[1:]])])
77 | h6 = act(smart_fc_block(h5, weights, reuse, 'fc_dense'))
78 | hidden6 = smart_fc_block(h6, weights, reuse, 'fc5')
79 |
80 | return hidden6
81 |
82 |
83 | class DspritesNet(object):
84 | def __init__(self, num_channels=1, num_filters=64, cond_size=False, cond_shape=False, cond_pos=False,
85 | cond_rot=False, label_size=1):
86 |
87 | self.channels = num_channels
88 | self.dim_hidden = num_filters
89 | self.img_size = 64
90 | self.label_size = label_size
91 |
92 | if FLAGS.cclass:
93 | self.label_size = 3
94 |
95 | try:
96 | if FLAGS.dshape_only:
97 | self.label_size = 3
98 |
99 | if FLAGS.dpos_only:
100 | self.label_size = 2
101 |
102 | if FLAGS.dsize_only:
103 | self.label_size = 1
104 |
105 | if FLAGS.drot_only:
106 | self.label_size = 2
107 | except:
108 | pass
109 |
110 | if cond_size:
111 | self.label_size = 1
112 |
113 | if cond_shape:
114 | self.label_size = 3
115 |
116 | if cond_pos:
117 | self.label_size = 2
118 |
119 | if cond_rot:
120 | self.label_size = 2
121 |
122 | self.cond_size = cond_size
123 | self.cond_shape = cond_shape
124 | self.cond_pos = cond_pos
125 |
126 | def construct_weights(self, scope=''):
127 | weights = {}
128 |
129 | dtype = tf.float32
130 | conv_initializer = tf.contrib.layers.xavier_initializer_conv2d(dtype=dtype)
131 | fc_initializer = tf.contrib.layers.xavier_initializer(dtype=dtype)
132 | k = 5
133 | classes = self.label_size
134 |
135 | with tf.variable_scope(scope):
136 | init_conv_weight(weights, 'c1_pre', 3, 1, 32)
137 | init_conv_weight(weights, 'c1', 4, 32, self.dim_hidden, classes=classes)
138 | init_conv_weight(weights, 'c2', 4, self.dim_hidden, 2*self.dim_hidden, classes=classes)
139 | init_conv_weight(weights, 'c3', 4, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
140 | init_conv_weight(weights, 'c4', 4, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
141 | init_fc_weight(weights, 'fc_dense', 2*4*4*self.dim_hidden, 2*self.dim_hidden, spec_norm=True)
142 | init_fc_weight(weights, 'fc5', 2*self.dim_hidden, 1, spec_norm=False)
143 |
144 | return weights
145 |
146 | def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False, return_logit=False):
147 | channels = self.channels
148 | batch_size = tf.shape(inp)[0]
149 |
150 | inp = tf.reshape(inp, (batch_size, 64, 64, 1))
151 |
152 | if FLAGS.swish_act:
153 | act = swish
154 | else:
155 | act = tf.nn.leaky_relu
156 |
157 | if not FLAGS.cclass:
158 | label = None
159 |
160 | weights = weights.copy()
161 |
162 | if stop_grad:
163 | for k, v in weights.items():
164 | if type(v) == dict:
165 | v = v.copy()
166 | weights[k] = v
167 | for k_sub, v_sub in v.items():
168 | v[k_sub] = tf.stop_gradient(v_sub)
169 | else:
170 | weights[k] = tf.stop_gradient(v)
171 |
172 | h1 = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
173 | h2 = smart_conv_block(h1, weights, reuse, 'c1', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act)
174 | h3 = smart_conv_block(h2, weights, reuse, 'c2', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act)
175 | h4 = smart_conv_block(h3, weights, reuse, 'c3', use_stride=True, downsample=True, label=label, use_scale=True, extra_bias=True, activation=act)
176 | h5 = smart_conv_block(h4, weights, reuse, 'c4', use_stride=True, downsample=True, label=label, extra_bias=True, activation=act)
177 |
178 | hidden6 = tf.reshape(h5, (tf.shape(h5)[0], -1))
179 | hidden7 = act(smart_fc_block(hidden6, weights, reuse, 'fc_dense'))
180 | energy = smart_fc_block(hidden7, weights, reuse, 'fc5')
181 |
182 | if return_logit:
183 | return hidden7
184 | else:
185 | return energy
186 |
187 |
188 |
189 | class ResNet32(object):
190 | def __init__(self, num_channels=3, num_filters=128):
191 |
192 | self.channels = num_channels
193 | self.dim_hidden = num_filters
194 | self.groupsort = groupsort()
195 |
196 | def construct_weights(self, scope=''):
197 | weights = {}
198 | dtype = tf.float32
199 |
200 | if FLAGS.cclass:
201 | classes = 10
202 | else:
203 | classes = 1
204 |
205 | with tf.variable_scope(scope):
206 | # First block
207 | init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden)
208 | init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes)
209 | init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
210 | init_res_weight(weights, 'res_2', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
211 | init_res_weight(weights, 'res_3', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
212 | init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
213 | init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
214 | init_fc_weight(weights, 'fc_dense', 4*4*2*self.dim_hidden, 4*self.dim_hidden)
215 | init_fc_weight(weights, 'fc5', 2*self.dim_hidden , 1, spec_norm=False)
216 |
217 | init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden // 2, trainable_gamma=True)
218 |
219 | return weights
220 |
221 | def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
222 | weights = weights.copy()
223 | batch = tf.shape(inp)[0]
224 |
225 | act = tf.nn.leaky_relu
226 |
227 | if not FLAGS.cclass:
228 | label = None
229 |
230 | if stop_grad:
231 | for k, v in weights.items():
232 | if type(v) == dict:
233 | v = v.copy()
234 | weights[k] = v
235 | for k_sub, v_sub in v.items():
236 | v[k_sub] = tf.stop_gradient(v_sub)
237 | else:
238 | weights[k] = tf.stop_gradient(v)
239 |
240 | # Make sure gradients are modified a bit
241 | inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False)
242 |
243 | hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label, act=act)
244 | hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, act=act)
245 | hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, label=label, act=act)
246 |
247 | if FLAGS.use_attention:
248 | hidden4 = smart_atten_block(hidden3, weights, reuse, 'atten', stop_at_grad=stop_at_grad, label=label)
249 | else:
250 | hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, act=act)
251 |
252 | hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', stop_batch=stop_batch, adaptive=False, label=label, act=act)
253 | compact = hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
254 | hidden6 = tf.nn.relu(hidden6)
255 | hidden5 = tf.reduce_sum(hidden6, [1, 2])
256 |
257 | hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
258 |
259 | energy = hidden6
260 |
261 | return energy
262 |
263 |
264 | class ResNet32Large(object):
265 | def __init__(self, num_channels=3, num_filters=128, train=False):
266 |
267 | self.channels = num_channels
268 | self.dim_hidden = num_filters
269 | self.dropout = train
270 | self.train = train
271 |
272 | def construct_weights(self, scope=''):
273 | weights = {}
274 | dtype = tf.float32
275 |
276 | if FLAGS.cclass:
277 | classes = 10
278 | else:
279 | classes = 1
280 |
281 | with tf.variable_scope(scope):
282 | # First block
283 | init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden)
284 | init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes)
285 | init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
286 | init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes)
287 | init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
288 | init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
289 | init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
290 | init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
291 | init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
292 | init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
293 | init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False)
294 |
295 | init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden, trainable_gamma=True)
296 |
297 | return weights
298 |
299 | def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
300 | weights = weights.copy()
301 | batch = tf.shape(inp)[0]
302 |
303 | if not FLAGS.cclass:
304 | label = None
305 |
306 | if stop_grad:
307 | for k, v in weights.items():
308 | if type(v) == dict:
309 | v = v.copy()
310 | weights[k] = v
311 | for k_sub, v_sub in v.items():
312 | v[k_sub] = tf.stop_gradient(v_sub)
313 | else:
314 | weights[k] = tf.stop_gradient(v)
315 |
316 | # Make sure gradients are modified a bit
317 | inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False)
318 |
319 | dropout = self.dropout
320 | train = self.train
321 |
322 | hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label, dropout=dropout, train=train)
323 | hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train)
324 | hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train)
325 | hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label, dropout=dropout, train=train)
326 |
327 | if FLAGS.use_attention:
328 | hidden5 = smart_atten_block(hidden4, weights, reuse, 'atten', stop_at_grad=stop_at_grad)
329 | else:
330 | hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
331 |
332 | hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
333 |
334 | hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label, dropout=dropout, train=train)
335 | hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
336 |
337 | compact = hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train)
338 |
339 | if FLAGS.cclass:
340 | hidden6 = tf.nn.leaky_relu(hidden9)
341 | else:
342 | hidden6 = tf.nn.relu(hidden9)
343 | hidden5 = tf.reduce_sum(hidden6, [1, 2])
344 |
345 | hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
346 |
347 | energy = hidden6
348 |
349 | return energy
350 |
351 |
352 | class ResNet32Wider(object):
353 | def __init__(self, num_channels=3, num_filters=128, train=False):
354 |
355 | self.channels = num_channels
356 | self.dim_hidden = num_filters
357 | self.dropout = train
358 | self.train = train
359 |
360 | def construct_weights(self, scope=''):
361 | weights = {}
362 | dtype = tf.float32
363 |
364 | if FLAGS.cclass and FLAGS.dataset == "cifar10":
365 | classes = 10
366 | elif FLAGS.cclass and FLAGS.dataset == "imagenet":
367 | classes = 1000
368 | else:
369 | classes = 1
370 |
371 | with tf.variable_scope(scope):
372 | # First block
373 | init_conv_weight(weights, 'c1_pre', 3, self.channels, 128)
374 | init_res_weight(weights, 'res_optim', 3, 128, self.dim_hidden, classes=classes)
375 | init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
376 | init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes)
377 | init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
378 | init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
379 | init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
380 | init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
381 | init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
382 | init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
383 | init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False)
384 |
385 | init_attention_weight(weights, 'atten', self.dim_hidden, self.dim_hidden // 2, trainable_gamma=True)
386 |
387 | return weights
388 |
389 | def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
390 | weights = weights.copy()
391 | batch = tf.shape(inp)[0]
392 |
393 | if not FLAGS.cclass:
394 | label = None
395 |
396 | if stop_grad:
397 | for k, v in weights.items():
398 | if type(v) == dict:
399 | v = v.copy()
400 | weights[k] = v
401 | for k_sub, v_sub in v.items():
402 | v[k_sub] = tf.stop_gradient(v_sub)
403 | else:
404 | weights[k] = tf.stop_gradient(v)
405 |
406 | if FLAGS.swish_act:
407 | act = swish
408 | else:
409 | act = tf.nn.leaky_relu
410 |
411 | # Make sure gradients are modified a bit
412 | inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
413 | dropout = self.dropout
414 | train = self.train
415 |
416 | hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=True, label=label, dropout=dropout, train=train)
417 |
418 | if FLAGS.use_attention:
419 | hidden2 = smart_atten_block(hidden1, weights, reuse, 'atten', train=train, dropout=dropout, stop_at_grad=stop_at_grad)
420 | else:
421 | hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train, act=act)
422 |
423 | hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label, dropout=dropout, train=train, act=act)
424 | hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
425 |
426 | hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
427 |
428 | hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
429 |
430 | hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
431 | hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
432 |
433 | hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act)
434 |
435 | if FLAGS.swish_act:
436 | hidden6 = act(hidden9)
437 | else:
438 | hidden6 = tf.nn.relu(hidden9)
439 |
440 | hidden5 = tf.reduce_sum(hidden6, [1, 2])
441 | hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
442 | energy = hidden6
443 |
444 | return energy
445 |
446 |
447 | class ResNet32Larger(object):
448 | def __init__(self, num_channels=3, num_filters=128):
449 |
450 | self.channels = num_channels
451 | self.dim_hidden = num_filters
452 |
453 | def construct_weights(self, scope=''):
454 | weights = {}
455 | dtype = tf.float32
456 |
457 | if FLAGS.cclass:
458 | classes = 10
459 | else:
460 | classes = 1
461 |
462 | with tf.variable_scope(scope):
463 | # First block
464 | init_conv_weight(weights, 'c1_pre', 3, self.channels, self.dim_hidden)
465 | init_res_weight(weights, 'res_optim', 3, self.dim_hidden, self.dim_hidden, classes=classes)
466 | init_res_weight(weights, 'res_1', 3, self.dim_hidden, self.dim_hidden, classes=classes)
467 | init_res_weight(weights, 'res_2', 3, self.dim_hidden, self.dim_hidden, classes=classes)
468 | init_res_weight(weights, 'res_2a', 3, self.dim_hidden, self.dim_hidden, classes=classes)
469 | init_res_weight(weights, 'res_2b', 3, self.dim_hidden, self.dim_hidden, classes=classes)
470 | init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
471 | init_res_weight(weights, 'res_4', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
472 | init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
473 | init_res_weight(weights, 'res_5a', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
474 | init_res_weight(weights, 'res_5b', 3, 2*self.dim_hidden, 2*self.dim_hidden, classes=classes)
475 | init_res_weight(weights, 'res_6', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
476 | init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
477 | init_res_weight(weights, 'res_8', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
478 | init_res_weight(weights, 'res_8a', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
479 | init_res_weight(weights, 'res_8b', 3, 4*self.dim_hidden, 4*self.dim_hidden, classes=classes)
480 | init_fc_weight(weights, 'fc_dense', 4*4*2*self.dim_hidden, 4*self.dim_hidden)
481 | init_fc_weight(weights, 'fc5', 4*self.dim_hidden , 1, spec_norm=False)
482 |
483 | init_attention_weight(weights, 'atten', 2*self.dim_hidden, self.dim_hidden // 2, trainable_gamma=True)
484 |
485 | return weights
486 |
487 | def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
488 | weights = weights.copy()
489 | batch = tf.shape(inp)[0]
490 |
491 | if not FLAGS.cclass:
492 | label = None
493 |
494 | if stop_grad:
495 | for k, v in weights.items():
496 | if type(v) == dict:
497 | v = v.copy()
498 | weights[k] = v
499 | for k_sub, v_sub in v.items():
500 | v[k_sub] = tf.stop_gradient(v_sub)
501 | else:
502 | weights[k] = tf.stop_gradient(v)
503 |
504 | # Make sure gradients are modified a bit
505 | inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False)
506 |
507 | hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', adaptive=False, label=label)
508 | hidden2 = smart_res_block(hidden1, weights, reuse, 'res_1', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
509 | hidden3 = smart_res_block(hidden2, weights, reuse, 'res_2', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
510 | hidden3 = smart_res_block(hidden3, weights, reuse, 'res_2a', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
511 | hidden3 = smart_res_block(hidden3, weights, reuse, 'res_2b', stop_batch=stop_batch, downsample=False, adaptive=False, label=label)
512 | hidden4 = smart_res_block(hidden3, weights, reuse, 'res_3', stop_batch=stop_batch, label=label)
513 |
514 | if FLAGS.use_attention:
515 | hidden5 = smart_atten_block(hidden4, weights, reuse, 'atten', stop_at_grad=stop_at_grad)
516 | else:
517 | hidden5 = smart_res_block(hidden4, weights, reuse, 'res_4', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
518 |
519 | hidden6 = smart_res_block(hidden5, weights, reuse, 'res_5', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
520 |
521 | hidden6 = smart_res_block(hidden6, weights, reuse, 'res_5a', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
522 | hidden6 = smart_res_block(hidden6, weights, reuse, 'res_5b', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
523 | hidden7 = smart_res_block(hidden6, weights, reuse, 'res_6', stop_batch=stop_batch, label=label)
524 | hidden8 = smart_res_block(hidden7, weights, reuse, 'res_7', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
525 | hidden9 = smart_res_block(hidden8, weights, reuse, 'res_8', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
526 | hidden9 = smart_res_block(hidden9, weights, reuse, 'res_8a', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
527 | compact = hidden9 = smart_res_block(hidden9, weights, reuse, 'res_8b', adaptive=False, downsample=False, stop_batch=stop_batch, label=label)
528 |
529 | if FLAGS.cclass:
530 | hidden6 = tf.nn.leaky_relu(hidden9)
531 | else:
532 | hidden6 = tf.nn.relu(hidden9)
533 | hidden5 = tf.reduce_sum(hidden6, [1, 2])
534 |
535 | hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
536 |
537 | energy = hidden6
538 |
539 | return energy
540 |
541 |
542 | class ResNet128(object):
543 | """Residual energy network for 128x128 inputs (used for full-resolution ImageNet)."""
544 |
545 | def __init__(self, num_channels=3, num_filters=64, train=False):
546 |
547 | self.channels = num_channels
548 | self.dim_hidden = num_filters
549 | self.dropout = train
550 | self.train = train
551 |
552 | def construct_weights(self, scope=''):
553 | weights = {}
554 | dtype = tf.float32
555 |
556 | classes = 1000
557 |
558 | with tf.variable_scope(scope):
559 | # First block
560 | init_conv_weight(weights, 'c1_pre', 3, self.channels, 64)
561 | init_res_weight(weights, 'res_optim', 3, 64, self.dim_hidden, classes=classes)
562 | init_res_weight(weights, 'res_3', 3, self.dim_hidden, 2*self.dim_hidden, classes=classes)
563 | init_res_weight(weights, 'res_5', 3, 2*self.dim_hidden, 4*self.dim_hidden, classes=classes)
564 | init_res_weight(weights, 'res_7', 3, 4*self.dim_hidden, 8*self.dim_hidden, classes=classes)
565 | init_res_weight(weights, 'res_9', 3, 8*self.dim_hidden, 8*self.dim_hidden, classes=classes)
566 | init_res_weight(weights, 'res_10', 3, 8*self.dim_hidden, 8*self.dim_hidden, classes=classes)
567 | init_fc_weight(weights, 'fc5', 8*self.dim_hidden , 1, spec_norm=False)
568 |
569 |
570 | init_attention_weight(weights, 'atten', self.dim_hidden, self.dim_hidden // 2, trainable_gamma=True)
571 |
572 | return weights
573 |
574 | def forward(self, inp, weights, reuse=False, scope='', stop_grad=False, label=None, stop_at_grad=False, stop_batch=False):
575 | weights = weights.copy()
576 | batch = tf.shape(inp)[0]
577 |
578 | if not FLAGS.cclass:
579 | label = None
580 |
581 |
582 | if stop_grad:
583 | for k, v in weights.items():
584 | if type(v) == dict:
585 | v = v.copy()
586 | weights[k] = v
587 | for k_sub, v_sub in v.items():
588 | v[k_sub] = tf.stop_gradient(v_sub)
589 | else:
590 | weights[k] = tf.stop_gradient(v)
591 |
592 | if FLAGS.swish_act:
593 | act = swish
594 | else:
595 | act = tf.nn.leaky_relu
596 |
597 | dropout = self.dropout
598 | train = self.train
599 |
600 | # Make sure gradients are modified a bit
601 | inp = smart_conv_block(inp, weights, reuse, 'c1_pre', use_stride=False, activation=act)
602 | hidden1 = smart_res_block(inp, weights, reuse, 'res_optim', label=label, dropout=dropout, train=train, downsample=True, adaptive=False)
603 |
604 | if FLAGS.use_attention:
605 | hidden1 = smart_atten_block(hidden1, weights, reuse, 'atten', stop_at_grad=stop_at_grad)
606 |
607 | hidden2 = smart_res_block(hidden1, weights, reuse, 'res_3', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act)
608 | hidden3 = smart_res_block(hidden2, weights, reuse, 'res_5', stop_batch=stop_batch, downsample=True, adaptive=True, label=label, dropout=dropout, train=train, act=act)
609 | hidden4 = smart_res_block(hidden3, weights, reuse, 'res_7', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=True, adaptive=True)
610 | hidden5 = smart_res_block(hidden4, weights, reuse, 'res_9', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=True, adaptive=False)
611 | hidden6 = smart_res_block(hidden5, weights, reuse, 'res_10', stop_batch=stop_batch, label=label, dropout=dropout, train=train, act=act, downsample=False, adaptive=False)
612 |
613 | if FLAGS.swish_act:
614 | hidden6 = act(hidden6)
615 | else:
616 | hidden6 = tf.nn.relu(hidden6)
617 |
618 | hidden5 = tf.reduce_sum(hidden6, [1, 2])
619 | hidden6 = smart_fc_block(hidden5, weights, reuse, 'fc5')
620 | energy = hidden6
621 |
622 | return energy
623 |
--------------------------------------------------------------------------------
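All three `forward` methods above end with the same energy head: the last residual features pass through an activation (ReLU, leaky ReLU, or swish depending on the flags), are summed over the two spatial axes with `tf.reduce_sum(..., [1, 2])`, and are projected by `fc5` to a single scalar energy per image. The NumPy sketch below illustrates that head only. It assumes NHWC features and uses a plain ReLU plus a dense layer in place of the repo's `smart_fc_block`; the names `energy_head`, `w_fc`, and `b_fc` are hypothetical.

```python
import numpy as np


def energy_head(features: np.ndarray, w_fc: np.ndarray, b_fc: float) -> np.ndarray:
    """Map NHWC features to one scalar energy per example.

    Sketch of the pattern at the end of forward(): activation, sum-pool over
    height and width (axes 1 and 2), then a dense layer with one output.
    A plain ReLU and matrix product stand in for the repo's helpers.
    """
    activated = np.maximum(features, 0.0)  # ReLU, as in the non-cclass branch
    pooled = activated.sum(axis=(1, 2))    # (batch, channels), like reduce_sum(..., [1, 2])
    return pooled @ w_fc + b_fc            # (batch, 1) energies


rng = np.random.default_rng(0)
feats = rng.normal(size=(4, 8, 8, 16))  # batch of 4, 8x8 spatial grid, 16 channels
w = rng.normal(size=(16, 1))
energies = energy_head(feats, w, 0.0)
print(energies.shape)  # (4, 1)
```

Because the pooling is a sum rather than a mean, the energy scale grows with the spatial size of the final feature map. This is one reason each architecture keeps its own `fc5` weights.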
/EBMs/requirements.txt:
--------------------------------------------------------------------------------
1 | scipy==1.10.0
2 | horovod==0.24.0
3 | torch==1.13.1
4 | torchvision==0.14.1
5 | six==1.11.0
6 | imageio==2.8.0
7 | tqdm==4.46.0
8 | matplotlib==3.2.1
9 | mpi4py==3.0.3
10 | numpy==1.22.0
11 | Pillow==10.0.1
12 | baselines==0.1.5
13 | scikit-image==0.14.2
14 | scikit_learn
15 | tensorflow==2.11.1
16 | cloudpickle==1.3.0
17 | Cython==0.29.17
18 | mujoco-py==1.50.1.68
19 |
--------------------------------------------------------------------------------
/EBMs/test_inception.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | import numpy as np
3 | from tensorflow.python.platform import flags
4 | from models import ResNet32, ResNet32Large, ResNet32Larger, ResNet32Wider, ResNet128
5 | import os.path as osp
6 | import os
7 | from utils import optimistic_restore, remap_restore, optimistic_remap_restore
8 | from tqdm import tqdm
9 | import random
10 | from imageio import imwrite as imsave
11 | from data import Cifar10, Svhn, Cifar100, Textures, Imagenet, TFImagenetLoader
12 | from torch.utils.data import DataLoader
13 | from baselines.common.tf_util import initialize
14 |
15 | import horovod.tensorflow as hvd
16 | hvd.init()
17 |
18 | from inception import get_inception_score
19 | from fid import get_fid_score
20 |
21 | flags.DEFINE_string('logdir', 'cachedir', 'location where log of experiments will be stored')
22 | flags.DEFINE_string('exp', 'default', 'name of experiments')
23 | flags.DEFINE_bool('cclass', False, 'whether to condition on class')
24 |
25 | # Architecture settings
26 | flags.DEFINE_bool('bn', False, 'Whether to use batch normalization or not')
27 | flags.DEFINE_bool('spec_norm', True, 'Whether to use spectral normalization on weights')
28 | flags.DEFINE_bool('use_bias', True, 'Whether to use bias in convolution')
29 | flags.DEFINE_bool('use_attention', False, 'Whether to use self attention in network')
30 | flags.DEFINE_float('step_lr', 10.0, 'Size of steps for gradient descent')
31 | flags.DEFINE_integer('num_steps', 20, 'number of Langevin steps used to refine each image')
32 | flags.DEFINE_float('proj_norm', 0.05, 'Maximum absolute value of each gradient entry (gradient clipping)')
33 | flags.DEFINE_integer('batch_size', 512, 'batch size')
34 | flags.DEFINE_integer('resume_iter', -1, 'resume iteration')
35 | flags.DEFINE_integer('ensemble', 10, 'number of checkpoints to ensemble')
36 | flags.DEFINE_integer('im_number', 50000, 'number of images to generate')
37 | flags.DEFINE_integer('repeat_scale', 100, 'number of repeat iterations')
38 | flags.DEFINE_float('noise_scale', 0.005, 'amount of noise to output')
39 | flags.DEFINE_integer('idx', 0, 'save index')
40 | flags.DEFINE_integer('nomix', 10, 'number of intervals to stop mixing')
41 | flags.DEFINE_bool('scaled', True, 'whether to scale noise added')
42 | flags.DEFINE_bool('large_model', False, 'whether to use a small or large model')
43 | flags.DEFINE_bool('larger_model', False, 'Whether to use a larger model')
44 | flags.DEFINE_bool('wider_model', False, 'Whether to use a wider model')
45 | flags.DEFINE_bool('single', False, 'single ')
46 | flags.DEFINE_string('datasource', 'random', 'default or noise or negative or single')
47 | flags.DEFINE_string('dataset', 'cifar10', 'cifar10 or imagenet or imagenetfull')
48 |
49 | FLAGS = flags.FLAGS
50 |
51 | class InceptionReplayBuffer(object):
52 | def __init__(self, size):
53 | """Create Replay buffer.
54 | Parameters
55 | ----------
56 | size: int
57 | Max number of images to store in the buffer. When the buffer
58 | overflows, the oldest entries are overwritten.
59 | """
60 | self._storage = []
61 | self._label_storage = []
62 | self._maxsize = size
63 | self._next_idx = 0
64 |
65 | def __len__(self):
66 | return len(self._storage)
67 |
68 | def add(self, ims, labels):
69 | batch_size = ims.shape[0]
70 | if self._next_idx >= len(self._storage):
71 | self._storage.extend(list(ims))
72 | self._label_storage.extend(list(labels))
73 | else:
74 | if batch_size + self._next_idx < self._maxsize:
75 | self._storage[self._next_idx:self._next_idx+batch_size] = list(ims)
76 | self._label_storage[self._next_idx:self._next_idx+batch_size] = list(labels)
77 | else:
78 | split_idx = self._maxsize - self._next_idx
79 | self._storage[self._next_idx:] = list(ims)[:split_idx]
80 | self._storage[:batch_size-split_idx] = list(ims)[split_idx:]
81 | self._label_storage[self._next_idx:] = list(labels)[:split_idx]
82 | self._label_storage[:batch_size-split_idx] = list(labels)[split_idx:]
83 |
84 | self._next_idx = (self._next_idx + ims.shape[0]) % self._maxsize
85 |
86 | def _encode_sample(self, idxes):
87 | ims = []
88 | labels = []
89 | for i in idxes:
90 | ims.append(self._storage[i])
91 | labels.append(self._label_storage[i])
92 | return np.array(ims), np.array(labels)
93 |
94 | def sample(self, batch_size):
95 | """Sample a batch of stored images and their labels.
96 | Parameters
97 | ----------
98 | batch_size: int
99 | How many entries to sample, uniformly at random
100 | and with replacement.
101 | Returns
102 | -------
103 | ims: np.array
104 | batch of stored images, which the caller refines
105 | further with the Langevin sampling loop
106 | labels: np.array
107 | one-hot class labels matching ims
108 | idxes: list of int
109 | buffer positions of the sampled entries, used by
110 | set_elms to write the refined samples back.
111 | The return value is ((ims, labels), idxes).
112 | The buffer must be non-empty.
113 | """
114 | idxes = [random.randint(0, len(self._storage) - 1) for _ in range(batch_size)]
115 | return self._encode_sample(idxes), idxes
116 |
117 | def set_elms(self, idxes, data, labels):
118 | for i, ix in enumerate(idxes):
119 | self._storage[ix] = data[i]
120 | self._label_storage[ix] = labels[i]
121 |
122 |
123 | def rescale_im(im):
124 | return np.clip(im * 256, 0, 255).astype(np.uint8)
125 |
126 | def compute_inception(sess, target_vars):
127 | X_START = target_vars['X_START']
128 | Y_GT = target_vars['Y_GT']
129 | X_finals = target_vars['X_finals']
130 | NOISE_SCALE = target_vars['NOISE_SCALE']
131 | energy_noise = target_vars['energy_noise']
132 |
133 | size = FLAGS.im_number
134 | num_steps = size // 1000
135 |
136 | images = []
137 | test_ims = []
138 |
139 |
140 | if FLAGS.dataset == "cifar10":
141 | test_dataset = Cifar10(full=True, noise=False)
142 | elif FLAGS.dataset == "imagenet" or FLAGS.dataset == "imagenetfull":
143 | test_dataset = Imagenet(train=False)
144 |
145 | if FLAGS.dataset != "imagenetfull":
146 | test_dataloader = DataLoader(test_dataset, batch_size=FLAGS.batch_size, num_workers=4, shuffle=True, drop_last=False)
147 | else:
148 | test_dataloader = TFImagenetLoader('test', FLAGS.batch_size, 0, 1)
149 |
150 | for data_corrupt, data, label_gt in tqdm(test_dataloader):
151 | data = data.numpy()
152 | test_ims.extend(list(rescale_im(data)))
153 |
154 | if FLAGS.dataset == "imagenetfull" and len(test_ims) > 60000:
155 | test_ims = test_ims[:60000]
156 | break
157 |
158 |
159 | # n = min(len(images), len(test_ims))
160 | print(len(test_ims))
161 | # fid = get_fid_score(test_ims[:30000], test_ims[-30000:])
162 | # print("Base FID of score {}".format(fid))
163 |
164 | if FLAGS.dataset == "cifar10":
165 | classes = 10
166 | else:
167 | classes = 1000
168 |
169 | if FLAGS.dataset == "imagenetfull":
170 | n = 128
171 | else:
172 | n = 32
173 |
174 | for j in range(num_steps):
175 | itr = int(1000 / 500 * FLAGS.repeat_scale)
176 | data_buffer = InceptionReplayBuffer(1000)
177 | curr_index = 0
178 |
179 | identity = np.eye(classes)
180 |
181 | for i in tqdm(range(itr)):
182 | model_index = curr_index % len(X_finals)
183 | x_final = X_finals[model_index]
184 |
185 | noise_scale = [1]
186 | if len(data_buffer) < 1000:
187 | x_init = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
188 | label = np.random.randint(0, classes, (FLAGS.batch_size))
189 | label = identity[label]
190 | x_new = sess.run([x_final], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})[0]
191 | data_buffer.add(x_new, label)
192 | else:
193 | (x_init, label), idx = data_buffer.sample(FLAGS.batch_size)
194 | keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.99)
195 | label_keep_mask = (np.random.uniform(0, 1, (FLAGS.batch_size)) > 0.9)
196 | label_corrupt = np.random.randint(0, classes, (FLAGS.batch_size))
197 | label_corrupt = identity[label_corrupt]
198 | x_init_corrupt = np.random.uniform(0, 1, (FLAGS.batch_size, n, n, 3))
199 |
200 | if i < itr - FLAGS.nomix:
201 | x_init[keep_mask] = x_init_corrupt[keep_mask]
202 | label[label_keep_mask] = label_corrupt[label_keep_mask]
203 | # else:
204 | # noise_scale = [0.7]
205 |
206 | x_new, e_noise = sess.run([x_final, energy_noise], {X_START:x_init, Y_GT:label, NOISE_SCALE: noise_scale})
207 | data_buffer.set_elms(idx, x_new, label)
208 |
209 | if FLAGS.im_number != 50000:
210 | print(np.mean(e_noise), np.std(e_noise))
211 |
212 | curr_index += 1
213 |
214 | ims = np.array(data_buffer._storage[:1000])
215 | ims = rescale_im(ims)
216 |
217 | images.extend(list(ims))
218 |
219 | saveim = osp.join('sandbox_cachedir', FLAGS.exp, "test{}.png".format(FLAGS.idx))
220 |
221 | ims = ims[:100]
222 |
223 | if FLAGS.dataset != "imagenetfull":
224 | im_panel = ims.reshape((10, 10, 32, 32, 3)).transpose((0, 2, 1, 3, 4)).reshape((320, 320, 3))
225 | else:
226 | im_panel = ims.reshape((10, 10, 128, 128, 3)).transpose((0, 2, 1, 3, 4)).reshape((1280, 1280, 3))
227 | imsave(saveim, im_panel)
228 |
229 | print("Saved image!!!!")
230 | splits = max(1, len(images) // 5000)
231 | score, std = get_inception_score(images, splits=splits)
232 | print("Inception score of {} with std of {}".format(score, std))
233 |
234 | # FID score
235 | # n = min(len(images), len(test_ims))
236 | fid = get_fid_score(images, test_ims)
237 | print("FID of score {}".format(fid))
238 |
239 |
240 |
241 |
242 | def main(model_list):
243 |
244 | if FLAGS.dataset == "imagenetfull":
245 | model = ResNet128(num_filters=64)
246 | elif FLAGS.large_model:
247 | model = ResNet32Large(num_filters=128)
248 | elif FLAGS.larger_model:
249 | model = ResNet32Larger(num_filters=128)
250 | elif FLAGS.wider_model:
251 | model = ResNet32Wider(num_filters=256, train=False)
252 | else:
253 | model = ResNet32(num_filters=128)
254 |
255 | # config = tf.ConfigProto()
256 | sess = tf.InteractiveSession()
257 |
258 | logdir = osp.join(FLAGS.logdir, FLAGS.exp)
259 | weights = []
260 |
261 | for i, model_num in enumerate(model_list):
262 | weight = model.construct_weights('context_{}'.format(i))
263 | initialize()
264 | save_file = osp.join(logdir, 'model_{}'.format(model_num))
265 |
266 | v_list = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='context_{}'.format(i))
267 | v_map = {(v.name.replace('context_{}'.format(i), 'context_0')[:-2]): v for v in v_list}
268 | saver = tf.train.Saver(v_map)
269 | try:
270 | saver.restore(sess, save_file)
271 | except Exception:
272 | optimistic_remap_restore(sess, save_file, i)
273 | weights.append(weight)
274 |
275 |
276 | if FLAGS.dataset == "imagenetfull":
277 | X_START = tf.placeholder(shape=(None, 128, 128, 3), dtype = tf.float32)
278 | else:
279 | X_START = tf.placeholder(shape=(None, 32, 32, 3), dtype = tf.float32)
280 |
281 | if FLAGS.dataset == "cifar10":
282 | Y_GT = tf.placeholder(shape=(None, 10), dtype = tf.float32)
283 | else:
284 | Y_GT = tf.placeholder(shape=(None, 1000), dtype = tf.float32)
285 |
286 | NOISE_SCALE = tf.placeholder(shape=1, dtype=tf.float32)
287 |
288 | X_finals = []
289 |
290 |
291 | # Separate Langevin sampling loop for each model in the ensemble
292 | for weight in weights:
293 | X = X_START
294 |
295 | steps = tf.constant(0)
296 | c = lambda i, x: tf.less(i, FLAGS.num_steps)
297 | def langevin_step(counter, X):
298 | scale_rate = 1
299 |
300 | X = X + tf.random_normal(tf.shape(X), mean=0.0, stddev=scale_rate * FLAGS.noise_scale * NOISE_SCALE)
301 |
302 | energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
303 | x_grad = tf.gradients(energy_noise, [X])[0]
304 |
305 | if FLAGS.proj_norm != 0.0:
306 | x_grad = tf.clip_by_value(x_grad, -FLAGS.proj_norm, FLAGS.proj_norm)
307 |
308 | X = X - FLAGS.step_lr * x_grad * scale_rate
309 | X = tf.clip_by_value(X, 0, 1)
310 |
311 | counter = counter + 1
312 |
313 | return counter, X
314 |
315 | steps, X = tf.while_loop(c, langevin_step, (steps, X))
316 | energy_noise = model.forward(X, weight, label=Y_GT, reuse=True)
317 | X_final = X
318 | X_finals.append(X_final)
319 |
320 | target_vars = {}
321 | target_vars['X_START'] = X_START
322 | target_vars['Y_GT'] = Y_GT
323 | target_vars['X_finals'] = X_finals
324 | target_vars['NOISE_SCALE'] = NOISE_SCALE
325 | target_vars['energy_noise'] = energy_noise
326 |
327 | compute_inception(sess, target_vars)
328 |
329 |
330 | if __name__ == "__main__":
331 | # model_list = [117000, 116700]
332 | model_list = [FLAGS.resume_iter - 300*i for i in range(FLAGS.ensemble)]
333 | main(model_list)
334 |
--------------------------------------------------------------------------------
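The `langevin_step` in `test_inception.py` above performs noisy gradient descent on the energy. Each step adds Gaussian noise, computes dE/dX, clips every gradient entry to `[-proj_norm, proj_norm]`, steps downhill by `step_lr`, and clips the image back to `[0, 1]`. The NumPy sketch below mirrors that update outside TensorFlow. It is only an illustration: `langevin_sample` is a hypothetical name, the energy is a toy quadratic rather than a trained model, and the extra `NOISE_SCALE` multiplier fed at run time is left out. Defaults follow the script's flags, but the usage call passes `step_lr=1.0` so that the toy problem converges instead of oscillating.

```python
import numpy as np


def langevin_sample(x_init, grad_energy, num_steps=20, step_lr=10.0,
                    proj_norm=0.05, noise_scale=0.005, rng=None):
    """Noisy gradient descent on an energy, mirroring langevin_step.

    Each step adds Gaussian noise, evaluates dE/dx, clips the gradient to
    [-proj_norm, proj_norm], moves step_lr * grad downhill, and clips the
    sample back to the [0, 1] image range.
    """
    rng = np.random.default_rng() if rng is None else rng
    x = np.array(x_init, dtype=np.float64)
    for _ in range(num_steps):
        x = x + rng.normal(0.0, noise_scale, size=x.shape)
        grad = grad_energy(x)
        if proj_norm != 0.0:
            grad = np.clip(grad, -proj_norm, proj_norm)
        x = x - step_lr * grad
        x = np.clip(x, 0.0, 1.0)
    return x


# Toy energy E(x) = 0.5 * ||x - 0.5||^2, whose gradient is x - 0.5,
# so low energy (high probability) sits at x = 0.5.
rng = np.random.default_rng(0)
x0 = rng.uniform(0.0, 1.0, size=(2, 4, 4, 3))
samples = langevin_sample(x0, lambda x: x - 0.5, step_lr=1.0, rng=rng)
print(np.abs(samples - 0.5).max())
```

Gradient clipping caps how far a single step can move a pixel at `step_lr * proj_norm`, which is 0.5 with the script's defaults. The final `clip` keeps samples inside the valid image range.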
/README.md:
--------------------------------------------------------------------------------
1 | ## resources and experiments on autonomous agents
2 |
3 |
4 |
5 | * **[⬛ ai && ml tl; dr](deep_learning)**
6 | * **[⬛ large language models](llms)**
7 | * **[⬛ agents on blockchains](crypto_agents)**
8 | * **[⬛ on quantum computing](EBMs)** (my adaptation of openai's implicit generation and generalization in energy-based models)
9 |
10 |
11 |
12 | ### cool resources
13 |
14 |
15 |
16 | * **[mr. vp jd vance at the ai action summit in paris (2025)](https://www.youtube.com/watch?v=MnKsxnP2IVk)**
17 |
--------------------------------------------------------------------------------
/books/advances_in_financial_machine_learning.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/autistic-symposium/ml-ai-agents-py/fdba2218a052e2a23dd7c0d7d2d029168377f11a/books/advances_in_financial_machine_learning.pdf
--------------------------------------------------------------------------------
/crypto_agents/README.md:
--------------------------------------------------------------------------------
1 | ## crypto agents
2 |
3 |
4 |
5 | * **[basic strategy workflow](strategy_workflow)**
6 |
7 |
8 |
9 | ---
10 |
11 | ### cool resources
12 |
13 |
14 |
15 |
16 | ##### projects
17 |
18 | * **[ritual.net](https://ritual.net/)**
19 | * **[sahara labs](https://saharalabs.ai/)**
20 | * **[multi-agent orchestrator](https://github.com/awslabs/multi-agent-orchestrator)**
21 | * **[swarms](https://github.com/kyegomez/swarms)**
22 | * **[hugging face](https://huggingface.co/)**
23 | * **[io.net](https://io.net/)**
24 | * **[exo](https://github.com/exo-explore/exo)**
25 |
26 |
27 |
28 | ##### readings
29 |
30 | * **[the internet's notary public: why verifiability matters, by axal](https://axal.substack.com/p/the-internets-notary-public-why-verifiability)**
31 | * **[crypto's role in the ai revolution, by pantera](https://panteracapital.com/blockchain-letter/cryptos-role-in-the-ai-revolution/)**
32 | * **[the promise and challenges of crypto + ai applications, by vub](https://vitalik.eth.limo/general/2024/01/30/cryptoai.html)**
33 | * **[on training defi agents with markov chains, by bt3gl](https://mirror.xyz/go-outside.eth/DKaWYobU7q3EvZw8x01J7uEmF_E8PfNN27j0VgxQhNQ)**
34 |
35 |
36 |
37 | ##### metas
38 |
39 | * **eth denver 2025 ai + agentic metas**:
40 | * **[a new dawn for decentralized and the convergence of ai x blockchain, by n. ameline](https://www.youtube.com/watch?v=HQuGtN9zidQ)**
41 | * **[upgrading agent infra with onchain perpetual agents, by kd conway](https://www.youtube.com/watch?v=hDayCeDA5fI)**
42 | * **[ai meets blockchain: payment, multi-agent & distributed inference, by o. jaros](https://www.youtube.com/watch?v=aPTosw4hrY0)**
43 | * **[from ai agents to agentic economies, by r. bodkin](https://www.youtube.com/watch?v=Q7eaYJ9aPpI)**
44 | * **[building ai agents and agent economies, by d. minarsch](https://www.youtube.com/watch?v=tDVK2Q5RY0c)**
45 | * **[building ai agents on top of hedera, by j. hall](https://www.youtube.com/watch?v=h8D6vi2m8LQ)**
46 | * **[identity, privacy, and security in the new age of ai agents, by m. csernai](https://www.youtube.com/watch?v=vMcaot04RQo)**
47 | * **[verifiable ai agents and world superintelligence, by k. wong](https://www.youtube.com/watch?v=ngkp7HTj_4A)**
48 | * **[how web3 can compete in ai, by g. narula](https://www.youtube.com/watch?v=oLLM1I-3fDU)**
49 | * **[shade agents, by m. lockyer](https://www.youtube.com/watch?v=PEfJnCtrbMU)**
50 | * **[how ai will enable the next generation of intent, by i. yang](https://www.youtube.com/watch?v=fbc3DpI6jiA)**
51 | * **[2025 is the year of agents, by i. polosukhin](https://www.youtube.com/watch?v=jPyzVNcQMKw)**
52 |
53 |
--------------------------------------------------------------------------------
/crypto_agents/strategy_workflow/README.md:
--------------------------------------------------------------------------------
1 | ## strategy workflow
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | 1. **[data analysis](data_analysis.md)**
12 | 2. **[supervised model training](supervised_learning.md)**
13 | 3. **[policy development](policy.md)**
14 | 4. **[backtesting](backtesting.md)**
15 | 5. **[parameter optimization](optimization.md)**
16 | 6. **[simulation and paper trading](paper_trading.md)**
17 | 7. **[live trading](live_trading.md)**
18 | 8. **[strategy metrics](strategy_metrics.md)**
19 |
--------------------------------------------------------------------------------
/crypto_agents/strategy_workflow/backtesting.md:
--------------------------------------------------------------------------------
1 | ## strategy backtesting
2 |
3 |
4 |
5 | * use a simulator to test an initial version of the strategy against a set of historical data.
6 | * the simulator can take into account factors such as order book liquidity, network latencies, fees, etc.
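a minimal sketch of such a simulator, assuming a single asset, long-only positions, a proportional fee on each fill, and latency modeled as a fixed delay of whole bars (the `backtest` function and its parameters are hypothetical, not from any library):

```python
from dataclasses import dataclass
from typing import List, Sequence


@dataclass
class BacktestResult:
    equity: List[float]  # mark-to-market account value at each bar
    trades: int          # number of fills


def backtest(prices: Sequence[float], signals: Sequence[int],
             fee_rate: float = 0.001, latency_bars: int = 1,
             initial_cash: float = 1000.0) -> BacktestResult:
    """Long-only, single-asset backtest.

    signals[i] is the desired position decided at bar i (1 = long, 0 = flat).
    The order only fills `latency_bars` bars later, at that bar's price,
    and pays `fee_rate` on the traded notional.
    """
    cash, units, trades = initial_cash, 0.0, 0
    equity: List[float] = []
    for i, price in enumerate(prices):
        j = i - latency_bars  # the decision that reaches the exchange now
        if j >= 0:
            if signals[j] == 1 and units == 0.0:
                units = cash * (1 - fee_rate) / price  # buy with all cash
                cash = 0.0
                trades += 1
            elif signals[j] == 0 and units > 0.0:
                cash = units * price * (1 - fee_rate)  # sell everything
                units = 0.0
                trades += 1
        equity.append(cash + units * price)
    return BacktestResult(equity, trades)
```

for example, `backtest([100, 110, 110], [1, 1, 1], fee_rate=0.0)` buys at 110 instead of 100 because of the one-bar delay, so the final equity stays at about 1000 instead of 1100.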
7 |
--------------------------------------------------------------------------------
/crypto_agents/strategy_workflow/data_analysis.md:
--------------------------------------------------------------------------------
1 | ## data analysis
2 |
3 |
4 |
5 | * perform exploratory data analysis to find trading opportunities, such as looking at charts, calculating statistics, etc.
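as a starting point, a few summary statistics can be computed directly from a price series. the sketch below (function name hypothetical, standard library only) reports the mean and standard deviation of log returns and their lag-1 autocorrelation, which hints at momentum (positive) or mean reversion (negative):

```python
import math
import statistics
from typing import Dict, Sequence


def summarize(prices: Sequence[float]) -> Dict[str, float]:
    """Basic statistics on a price series (needs at least 3 prices)."""
    # log returns between consecutive prices
    rets = [math.log(b / a) for a, b in zip(prices, prices[1:])]
    mean = statistics.fmean(rets)
    vol = statistics.stdev(rets)  # sample standard deviation per bar
    # lag-1 autocorrelation: > 0 hints at momentum, < 0 at mean reversion
    num = sum((x - mean) * (y - mean) for x, y in zip(rets, rets[1:]))
    den = sum((x - mean) ** 2 for x in rets)
    return {
        "mean_log_return": mean,
        "volatility": vol,
        "lag1_autocorr": num / den if den else 0.0,
    }
```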
6 |
--------------------------------------------------------------------------------
/crypto_agents/strategy_workflow/defi_glossary.md:
--------------------------------------------------------------------------------
1 | ## DeFi and MEV Glossary
2 |
3 |
4 |
5 |
6 | ### A
7 |
8 | - Arbitrage: the simultaneous buying and selling of assets (e.g., cryptocurrencies) in several markets to take advantage of their price discrepancies.
9 | - Assets under management (AUM): the total market value of the investments that a person or entity manages on behalf of clients.
10 |
11 |
12 |
13 |
14 | ### B
15 |
16 | - Backrunning: when an attacker attempts to have a transaction ordered immediately after a certain unconfirmed target transaction.
17 | - Blocks: a block contains transaction data and the hash of the previous block ensuring immutability in the blockchain network. Each block in a blockchain contains a list of transactions in a particular order. These transactions encode the updates to the blockchain state.
18 | - Block time: the time interval between blocks being added to the blockchain.
19 | - Broadcasting: whenever a user interacts with the blockchain, they broadcast a request to include the transaction to the network. This request is public (anyone can listen to it).
20 | - Builders: actors that take bundles (of pending transactions from the mempool) and create a final block to send to (multiple) relays, setting themselves as the `feeRecipient` to receive the block’s MEV.
21 | - Bundles: one or more transactions that are grouped together and executed in the order they are provided. In addition to the searcher's transaction(s), a bundle can also contain other users' pending transactions from the mempool. Bundles can target specific blocks for inclusion as well.
22 |
23 |
24 |
25 | ### C
26 |
27 | - Central limit order book (CLOB): patient buyers and sellers post limit orders with the price and size that they are willing to buy or sell a given asset. Impatient buyers and sellers place market orders that run through the CLOB until the desired size is reached.
28 | - Contract address: the address at which a smart contract's compiled code is deployed on the Ethereum blockchain; the code is executed when a transaction calls it.
29 | - Crypto copy trading strategy: a trading strategy that uses automation to buy and sell crypto, letting you copy another trader's method.
30 |
31 |
32 |
33 | ### D
34 |
35 | - Derivatives: financial contracts that derive their values from underlying assets.
36 | - Dollar-cost-averaging (DCA) strategy: an automated strategy that invests a fixed amount at regular time intervals, reducing the influence of market volatility on the average entry price. Parameters for DCA can include the currency, the fixed/maximum investment amount, and the investment frequency.
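With a fixed amount per purchase, DCA buys more units when the price is low, so the average cost per unit is the harmonic mean of the purchase prices, which never exceeds their arithmetic mean. A minimal sketch (the `dca` helper is hypothetical and ignores fees):

```python
from typing import Sequence, Tuple


def dca(prices: Sequence[float], amount_per_buy: float) -> Tuple[float, float]:
    """Buy `amount_per_buy` worth of the asset at each price.

    Returns (units_bought, average_cost_per_unit).
    """
    units = sum(amount_per_buy / p for p in prices)  # more units when cheap
    invested = amount_per_buy * len(prices)
    return units, invested / units
```

For prices of 100, 50, and 100 with 100 invested each time, it buys 4 units at an average cost of 75, below the average price of about 83.33.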
37 |
38 |
39 |
40 | ### E
41 |
42 | - Epoch: in Ethereum's proof-of-stake block production, time is divided into slots of 12 seconds, and in each slot a validator is randomly chosen to propose the block. An epoch contains 32 slots (6.4 minutes).
43 | - Externally owned account (EOA): an account controlled by a private key (rather than by contract code), which can be used to send and receive Ether to/from another account. An Ethereum address is a 42-character hexadecimal string: `0x` followed by the last 20 bytes of the Keccak-256 hash of the account's public key.
44 |
45 |
46 |
47 | ### F
48 |
49 | - Frontrunning: the process by which an adversary observes transactions on the network layer and acts on this information to obtain profit.
50 | - Fully diluted valuation (FDV): the maximum total supply of tokens (including tokens not yet in circulation) multiplied by the current price of a single token.
51 | - Futures: contracts used as proxy tools to speculate on the future prices of crypto assets or to hedge against their price changes.
52 | - Futures grid trading bots: bots that automate futures trading based on grid trading strategies (a set of orders is placed both above and below a specific reference market price for the asset).
53 |
54 |
55 |
56 | ### G
57 |
58 | - Gas price: used somewhat like a bid, indicating an amount the user is willing to pay (per unit of execution) to have their transaction processed.
59 | - Gwei: a small unit of the Ethereum network's Ether (ETH) cryptocurrency. A gwei or gigawei is defined as 1,000,000,000 wei, the smallest base unit of Ether. Conversely, 1 ETH represents 1 billion gwei.
60 | - Grid trading strategy: a strategy that involves placing orders above and below a set price, using a price grid of orders (which shows orders at incrementally increasing and decreasing prices). Grid trading is based on the overarching goal of buying low and selling high.
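The fee a transaction pays is the gas it uses multiplied by the gas price, usually quoted in gwei. A small conversion sketch using `Decimal` to avoid floating-point rounding (helper names are hypothetical; it treats the gas price as the effective price per unit of gas and ignores how that price is split into base and priority fees):

```python
from decimal import Decimal

WEI_PER_GWEI = 10**9
WEI_PER_ETH = 10**18


def gwei_to_eth(gwei: Decimal) -> Decimal:
    return gwei * WEI_PER_GWEI / WEI_PER_ETH


def tx_fee_eth(gas_used: int, gas_price_gwei: Decimal) -> Decimal:
    """Fee = gas used x price per unit of gas (in gwei), converted to ETH."""
    return gwei_to_eth(gas_used * gas_price_gwei)
```

A plain ETH transfer uses 21,000 gas, so at 30 gwei `tx_fee_eth(21000, Decimal("30"))` gives 0.00063 ETH.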
61 |
62 |
63 |
64 | ### H
65 |
66 | - Hedging: taking an offsetting position (often a short) to reduce the risk of adverse price moves in an existing position.
67 |
68 |
69 |
70 | ### K
71 |
72 | - Keys: blockchain account keys can be either private keys (for digital signatures), or public keys (for addresses).
73 |
74 |
75 |
76 | ### L
77 |
78 | - Limit orders: when one longs or shorts a contract, several execution options are available (usually with different fees). A limit order is set to trade at a specific price (or better), and there is no guarantee that it will be executed (see market orders and stop-loss orders).
79 | - Liquidity pools: a collection of crypto assets that can be used for decentralized trading. They are essential for automated market makers (AMM), borrow-lend protocols, yield farming, synthetic assets, on-chain insurance, blockchain gaming, etc.
80 | - Liquidation threshold: the percentage of a collateral's value that counts towards the borrowing limit; if the debt rises above the collateral value multiplied by the liquidation threshold, the position becomes undercollateralized and can be liquidated.
81 | - Liquidation: when the value of a borrower's debt exceeds the collateral value weighted by its liquidation threshold, anyone can repay part of the debt, seize the corresponding collateral, and collect the liquidation fee (bonus) for themselves.
82 | - Long: traders maintain long positions, which means that they expect the price of a coin to rise in the future.
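Lending protocols such as Aave summarize these definitions in a health factor: the sum of each collateral's value times its liquidation threshold, divided by the debt value. A position whose health factor drops below 1 can be liquidated. A minimal sketch, assuming all values are already in the same currency (function names are hypothetical, not a protocol API):

```python
from typing import Sequence, Tuple

Collateral = Tuple[float, float]  # (value, liquidation_threshold)


def health_factor(collateral: Sequence[Collateral], debt_value: float) -> float:
    """Sum of collateral value x liquidation threshold, divided by the debt."""
    if debt_value <= 0:
        return float("inf")  # no debt: nothing to liquidate
    return sum(value * threshold for value, threshold in collateral) / debt_value


def is_liquidatable(collateral: Sequence[Collateral], debt_value: float) -> bool:
    return health_factor(collateral, debt_value) < 1.0
```

With $10,000 of collateral at an 80% liquidation threshold and $7,000 of debt, the health factor is about 1.14; if the collateral falls to $8,000, it drops to about 0.91 and the position becomes liquidatable.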
83 |
84 |
85 |
86 | ### M
87 |
88 | - Fully diluted market capitalization: the total token supply, multiplied by the price of a single token.
89 | - Circulating supply market capitalization: the number of tokens that are available in the market, multiplied by the price of a single token.
90 | - Margin trading: buying or selling assets with leverage.
91 | - Marginal seller: the seller who would be the first to leave the market if prices were any lower.
92 | - Market orders: orders executed immediately at the best available market price (see limit orders).
93 | - Mean reversion strategy: a trading range (or mean reversion) strategy is based on the concept that an asset's highs and lows are temporary and that its price tends to revert to its mean (average) value.
94 | - Mempool: a cryptocurrency node’s mechanism for storing information on unconfirmed transactions.
95 | - Merkle tree: a type of binary tree, composed of: 1) a set of nodes with a large number of leaf nodes at the bottom, containing the underlying data, 2) a set of intermediate nodes where each node is the hash of its two children, and 3) a single root node, also formed from the hash of its two children, representing the top of the tree.
96 | - Minting: the process of validating information, creating a new block, and recording that information into the blockchain.
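The root of a Merkle tree can be computed bottom-up. This sketch uses SHA-256 and, on levels with an odd number of nodes, pairs the last node with itself (as Bitcoin does, although Bitcoin hashes with double SHA-256). Other chains differ in hash function and tree layout (Ethereum, for example, uses Keccak-256 and Merkle Patricia tries), so treat it as illustrative:

```python
import hashlib
from typing import List


def sha256(data: bytes) -> bytes:
    return hashlib.sha256(data).digest()


def merkle_root(leaves: List[bytes]) -> bytes:
    """Hash the leaves, then hash pairs of children until one root remains."""
    if not leaves:
        raise ValueError("need at least one leaf")
    level = [sha256(leaf) for leaf in leaves]
    while len(level) > 1:
        if len(level) % 2 == 1:
            level.append(level[-1])  # odd count: pair the last node with itself
        level = [sha256(level[i] + level[i + 1]) for i in range(0, len(level), 2)]
    return level[0]
```

Changing any leaf changes the root, which is why a single root hash can commit to every transaction in a block.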
97 |
98 |
99 |
100 | ### P
101 |
102 | - Perpetual contract: a futures contract without an expiration date; funding rates (periodic payments between longs and shorts that keep the contract price close to the spot price) can be calculated using methods such as the time-weighted average price (TWAP).
103 | - Priority gas auctions: bots compete against each other by bidding up transaction fees (gas) to extract revenue from arbitrage opportunities, driving up fees for users.
104 | - Private key: a secret number enabling a blockchain user to prove ownership of an account or contract, via a digital signature.
105 | - Public key: a number derived from the private key by a one-way function (elliptic-curve multiplication on Ethereum), used to verify a digital signature made with the matching private key.
106 | - Provider: an entity that provides an abstraction for a connection to the blockchain network.
107 | - POFPs: private order flow protocols.
108 |
109 |
110 |
111 | ### O
112 |
113 | - Order flow: in the context of Ethereum and EVM-based blockchains, an order is anything that allows changing the state of the blockchain.
114 | - Open interest: total number of futures contracts held by market participants at the end of the trading day. Used as an indicator to determine market sentiment and the strength behind price trends.
115 |
116 |
117 |
118 | ### R
119 |
120 | - RPC endpoints: blockchain nodes exposing a remote procedure call (RPC) interface, which clients use to read chain data and submit transactions.
121 |
122 |
123 |
124 | ### S
125 |
126 | - Slots: in the context of Ethereum's block production, a slot is a time period of 12 seconds in which a randomly chosen validator has time to propose a block.
127 | - Smart contracts: a computer protocol intended to enforce a contract on the blockchain without third parties. They are reliant upon code (the functions) and data (the state), and they can trigger specific actions, such as transferring tokens from A to B.
128 | - Sandwich attack: when a trade's slippage tolerance is loose (or not set), an attacker can frontrun it with a trade that pushes the asset's price to an unfavorable level, let the victim's trade execute at that worse price, and then backrun it to return the price toward its original level, capturing the difference.
129 | - Slippage: the difference between the expected price of a trade when the order is placed and the price at which it is actually executed.
130 | - Short: traders maintain short positions, which means they expect the price of a coin to drop in the future.
131 | - Short squeeze: occurs when a heavily shorted stock experiences an increase in price for some unexpected reason. This situation prompts short sellers to scramble to buy the stock to cover their positions and cap their mounting losses.
132 | - Spot trading: buying or selling assets for immediate delivery.
133 | - Statistical trading: the class of strategies that aim to profit from pricing inefficiencies among financial markets. Statistical arbitrage is a strategy that seeks profit by applying statistics estimated from past data.
134 | - Stop-loss orders: this type of order execution places a market/limit order to close a position to restrict an investor's loss on a crypto asset.
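To see slippage, and why a slippage limit blunts a sandwich attack, consider a constant-product AMM (x * y = k, the Uniswap v2 design) with a 0.3% fee. The sketch below (function names hypothetical) quotes a swap and refuses to execute it if the output falls below a minimum; a frontrunning swap in the same direction moves the reserves so that the victim's swap fails its check:

```python
def amm_output(amount_in: float, reserve_in: float, reserve_out: float,
               fee: float = 0.003) -> float:
    """Output of a constant-product (x * y = k) swap; the fee is taken from the input."""
    effective_in = amount_in * (1 - fee)
    return effective_in * reserve_out / (reserve_in + effective_in)


def swap(amount_in: float, reserve_in: float, reserve_out: float,
         min_out: float, fee: float = 0.003) -> float:
    """Execute only if the output is at least `min_out` (the slippage guard)."""
    out = amm_output(amount_in, reserve_in, reserve_out, fee)
    if out < min_out:
        raise ValueError(f"slippage too high: got {out:.4f}, need {min_out:.4f}")
    return out


# Quote 10 tokens into a 1000/1000 pool and accept at most 0.5% slippage.
quote = amm_output(10, 1000, 1000)
min_out = quote * (1 - 0.005)

# A frontrunner swaps 100 first; the pool keeps the full input (fee included).
attacker_out = amm_output(100, 1000, 1000)
reserve_in, reserve_out = 1000 + 100, 1000 - attacker_out
try:
    swap(10, reserve_in, reserve_out, min_out)
except ValueError as err:
    print(err)  # the victim's swap is rejected instead of filling at a bad price
```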
135 |
136 |
137 |
138 | ### T
139 |
140 | - Total value locked (TVL): the value of all tokens locked in various DeFi protocols such as lending platforms, DEXes, or derivatives protocols.
141 | - Trading volume: the total amount of cryptocurrency traded (in US dollar terms) during a given timeframe.
142 | - Transaction: on EVM-based blockchains, the two types of transactions are normal transactions and contract interactions.
143 | - Transaction hash: a unique 66-character identifier generated with each new transaction.
144 | - Transaction ordering: blockchains usually have loose requirements for how transactions are ordered within a block, allowing attacks that benefit from certain ordering.
145 | - Time-weighted average price strategy: TWAP strategy breaks up a large order and releases dynamically determined smaller chunks of the order to the market, using evenly divided time slots between a start and end time.
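The simplest, static version of a TWAP schedule splits the parent order into equal child orders sent at evenly spaced times. Production algorithms often adapt or randomize sizes and timing, which this sketch (function name hypothetical) does not attempt:

```python
from typing import List, Tuple


def twap_schedule(total_qty: float, start: float, end: float,
                  n_slices: int) -> List[Tuple[float, float]]:
    """Return (send_time, child_qty) pairs spread evenly over [start, end)."""
    if n_slices <= 0 or end <= start:
        raise ValueError("need n_slices > 0 and end > start")
    step = (end - start) / n_slices
    child_qty = total_qty / n_slices
    return [(start + i * step, child_qty) for i in range(n_slices)]
```

For example, `twap_schedule(100, 0, 3600, 4)` sends 25 units at 0, 900, 1800, and 2700 seconds.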
146 |
147 |
148 |
149 | ### V
150 |
151 | - Validation: a mathematical proof that the state change in the blockchain is consistent. To be included into a block in the blockchain, a list of transactions needs to be validated.
152 | - VTRPs: validator transaction reordering protocols.
153 | - Volume-weighted average price strategy: VWAP breaks up a large order and releases dynamically determined smaller chunks of the order to the market, using historical volume profiles.
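A minimal sketch of the VWAP idea, assuming a historical volume profile per time bucket is already available (function names are hypothetical): child orders are sized in proportion to expected volume, and execution quality is usually judged against the market's VWAP, the volume-weighted mean price:

```python
from typing import List, Sequence


def vwap_schedule(total_qty: float, volume_profile: Sequence[float]) -> List[float]:
    """Split an order across time buckets in proportion to historical volume."""
    total_volume = sum(volume_profile)
    if total_volume <= 0:
        raise ValueError("volume profile must have positive total volume")
    return [total_qty * v / total_volume for v in volume_profile]


def vwap(prices: Sequence[float], volumes: Sequence[float]) -> float:
    """Volume-weighted average price: sum(price * volume) / sum(volume)."""
    return sum(p * v for p, v in zip(prices, volumes)) / sum(volumes)
```

With a profile of `[1, 2, 1]`, a 100-unit order is split into 25, 50, and 25 units.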
154 |
155 |
156 |
157 | ### W
158 |
159 | - Whales: individuals or institutions who hold large amounts of coins of a certain cryptocurrency, and can become powerful enough to manipulate the valuation.
160 |
--------------------------------------------------------------------------------
/crypto_agents/strategy_workflow/live_trading.md:
--------------------------------------------------------------------------------
1 | ## live trading
2 |
3 |
4 |
5 | * the strategy is now running live on an exchange.
6 |
--------------------------------------------------------------------------------
/crypto_agents/strategy_workflow/optimization.md:
--------------------------------------------------------------------------------
1 | ## parameter optimization
2 |
3 |
4 |
5 | * perform a search, for example grid search, over possible values of strategy parameters like thresholds or coefficients (using the simulator and a set of historical data)
6 | * overfitting to historical data is a big risk (be careful with validation and test sets).
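a minimal sketch of grid search with a chronological split, assuming a user-supplied `score(prices, params)` function (hypothetical) that backtests the strategy on a slice of data and returns a metric to maximize. parameters are chosen on the earlier slice only, and the later slice estimates how they perform on data they were not tuned on:

```python
import itertools
from typing import Callable, Dict, Optional, Sequence, Tuple

Params = Dict[str, float]
ScoreFn = Callable[[Sequence[float], Params], float]


def grid_search(prices: Sequence[float], grid: Dict[str, Sequence[float]],
                score: ScoreFn, train_frac: float = 0.7
                ) -> Tuple[Optional[Params], float, float]:
    split = int(len(prices) * train_frac)
    train, valid = prices[:split], prices[split:]  # no shuffling: keep time order
    names = list(grid)
    best_params: Optional[Params] = None
    best_train = float("-inf")
    for values in itertools.product(*(grid[name] for name in names)):
        params = dict(zip(names, values))
        s = score(train, params)
        if s > best_train:
            best_params, best_train = params, s
    # evaluate the chosen parameters once on the held-out, later data
    valid_score = score(valid, best_params) if best_params is not None else float("nan")
    return best_params, best_train, valid_score
```

a large gap between the train and validation scores is a typical sign of overfitting.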
7 |
--------------------------------------------------------------------------------
/crypto_agents/strategy_workflow/paper_trading.md:
--------------------------------------------------------------------------------
1 | ## paper trading
2 |
3 |
4 |
5 | * before the strategy goes live, it is simulated on new market data in real time (paper trading), which helps detect overfitting to the historical data used for backtesting and optimization.
6 |
--------------------------------------------------------------------------------
/crypto_agents/strategy_workflow/policy.md:
--------------------------------------------------------------------------------
1 | ## policy development
2 |
3 |
4 |
5 | * come up with a rule-based policy that determines what actions to take based on the current state of the market and the outputs of the supervised models.
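a minimal sketch of such a policy, assuming the supervised model outputs a predicted next-period return; the thresholds, the spread filter, and all names are hypothetical choices for illustration:

```python
from dataclasses import dataclass
from enum import Enum


class Action(Enum):
    BUY = "buy"
    SELL = "sell"
    HOLD = "hold"


@dataclass
class MarketState:
    predicted_return: float  # supervised model's forecast of the next return
    spread: float            # (best ask - best bid) / mid price
    position: float          # current position in units (negative = short)


def rule_policy(state: MarketState, entry: float = 0.002,
                max_spread: float = 0.001) -> Action:
    """Hand-written rules mapping market state and model output to an action."""
    if state.spread > max_spread:
        return Action.HOLD  # crossing a wide spread would eat the expected edge
    if state.predicted_return > entry and state.position <= 0:
        return Action.BUY   # strong up-forecast and not already long
    if state.predicted_return < -entry and state.position >= 0:
        return Action.SELL  # strong down-forecast and not already short
    return Action.HOLD
```

note that `SELL` from a flat position opens a short; a long-only strategy would add a `position > 0` check.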
6 |
--------------------------------------------------------------------------------
/crypto_agents/strategy_workflow/strategy_metrics.md:
--------------------------------------------------------------------------------
1 | ## trading strategy metrics
2 |
3 |
4 |
5 | * **net pnl (net profit and loss):** how much money an algorithm makes (positive) or loses (negative) over some period, minus trading fees
6 | * **alpha and beta:** beta measures how sensitive the strategy's returns are to a benchmark (e.g. the overall market), and alpha is the return in excess of what that beta exposure would explain.
7 | * **sharpe ratio:** the excess return per unit of risk you are taking (the mean return above the risk-free rate divided by the standard deviation of returns; the higher the better).
8 | * **maximum drawdown:** the largest drop from a peak in equity to a subsequent trough, another measure of risk.
9 | * **value at risk (var):** how much capital you may lose over a given time frame with some probability (e.g. 95%), assuming normal market conditions.
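a minimal sketch of three of these metrics with the standard library, assuming per-period (e.g. daily) returns and an equity curve. the 365-periods-per-year annualization (crypto trades every day) and the historical-simulation method for var are assumptions, not the only conventions:

```python
import math
import statistics
from typing import Sequence


def sharpe_ratio(returns: Sequence[float], risk_free: float = 0.0,
                 periods_per_year: int = 365) -> float:
    """Annualized Sharpe ratio; `risk_free` is the per-period risk-free return."""
    excess = [r - risk_free for r in returns]
    return statistics.fmean(excess) / statistics.stdev(excess) * math.sqrt(periods_per_year)


def max_drawdown(equity: Sequence[float]) -> float:
    """Largest peak-to-trough drop, as a fraction of the running peak."""
    peak, worst = equity[0], 0.0
    for value in equity:
        peak = max(peak, value)
        worst = max(worst, (peak - value) / peak)
    return worst


def historical_var(returns: Sequence[float], confidence: float = 0.95) -> float:
    """Loss (as a positive fraction) not exceeded in `confidence` of past periods."""
    ordered = sorted(returns)
    index = int((1 - confidence) * len(ordered))  # the (1 - confidence) quantile
    return -ordered[index]
```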
10 |
--------------------------------------------------------------------------------
/crypto_agents/strategy_workflow/supervised_learning.md:
--------------------------------------------------------------------------------
1 | ## supervised learning
2 |
3 |
4 |
5 | * train one or more supervised learning models to predict quantities of interest that the strategy needs, for example, price prediction, quantity prediction, etc.
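as a deliberately tiny example, the sketch below fits an ordinary-least-squares model that predicts the next return from the previous one (an ar(1) model). the function names are hypothetical, and a real model would use richer features and be evaluated on held-out data:

```python
from typing import Sequence, Tuple


def fit_ar1(returns: Sequence[float]) -> Tuple[float, float]:
    """Ordinary least squares for r[t+1] = a + b * r[t]; returns (a, b)."""
    x, y = returns[:-1], returns[1:]
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    sxx = sum((xi - mx) ** 2 for xi in x)
    sxy = sum((xi - mx) * (yi - my) for xi, yi in zip(x, y))
    b = sxy / sxx  # slope: how much of the last return carries over
    return my - b * mx, b


def predict_next(a: float, b: float, last_return: float) -> float:
    return a + b * last_return
```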
6 |
--------------------------------------------------------------------------------
/crypto_agents/trading_on_gmx.md:
--------------------------------------------------------------------------------
1 | ## basics on trading, illustrated on gmx
2 |
3 |
4 |
5 | #### price chart
6 |
7 |
8 | * the current price is the price of the most recent trade.
9 | * it moves depending on whether that trade was a buy (typically filled at the best ask) or a sell (typically filled at the best bid).
10 | * high volume means the price movement is more reliable (consensus of a large number of market participants).
11 | * a candlestick chart shows the open/start (O), high (H), low (L), and close/end (C) prices for each time window.
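a minimal sketch of how candles are built from raw trades, assuming trades arrive sorted by timestamp as (timestamp, price, size) tuples; the function name is hypothetical, and windows with no trades are simply skipped:

```python
from typing import Dict, List, Sequence, Tuple

Trade = Tuple[float, float, float]  # (timestamp in seconds, price, size)


def to_candles(trades: Sequence[Trade], window: float) -> List[Dict[str, float]]:
    candles: List[Dict[str, float]] = []
    for ts, price, size in trades:
        start = ts - ts % window  # start of the window this trade falls into
        if candles and candles[-1]["start"] == start:
            c = candles[-1]
            c["high"] = max(c["high"], price)
            c["low"] = min(c["low"], price)
            c["close"] = price  # last trade in the window so far
            c["volume"] += size
        else:
            candles.append({"start": start, "open": price, "high": price,
                            "low": price, "close": price, "volume": size})
    return candles
```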
12 |
13 |
14 |
15 |
16 |
17 |
18 |
19 |
20 |
21 | ----
22 |
23 | #### order book
24 |
25 |
26 |
27 | * the order book is made of two sides, asks (sell, offers) and bids (buy).
28 | * the best ask (the lowest price someone is willing to sell at) is above the best bid (the highest price someone is willing to buy at).
29 | * the difference between the best ask and the best bid is called spread.
30 | * **market order**: best price possible, right now. it takes liquidity from the market and usually has higher fees.
31 | * **limit order (passive order)**: specify the price and qty you are willing to buy or sell at, and then wait for the match.
32 | * **stop orders**: become a market (or limit) order once the price reaches a specified stop price, e.g. to cap losses on an open position.
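a minimal sketch of these ideas, assuming an order book given as lists of (price, quantity) levels sorted best-first (names are hypothetical): the spread is best ask minus best bid, and a market buy walks up the ask side, so a large order's average fill price is worse than the best ask:

```python
from typing import List, Tuple

Level = Tuple[float, float]  # (price, quantity)


def spread(bids: List[Level], asks: List[Level]) -> float:
    """Best ask minus best bid; bids sorted high to low, asks low to high."""
    return asks[0][0] - bids[0][0]


def market_buy(asks: List[Level], qty: float) -> Tuple[float, float]:
    """Fill a market buy against the asks; returns (filled_qty, avg_price)."""
    remaining, cost = qty, 0.0
    for price, available in asks:
        take = min(remaining, available)  # consume liquidity level by level
        cost += take * price
        remaining -= take
        if remaining <= 0:
            break
    filled = qty - remaining  # less than qty if the book is too thin
    return filled, (cost / filled if filled > 0 else 0.0)
```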
33 |
34 |
35 |
--------------------------------------------------------------------------------
/deep_learning/README.md:
--------------------------------------------------------------------------------
1 | ## ai agents
2 |
3 |
4 |
5 | * **[deep learning](deep_learning.md)**
6 | * **[reinforcement learning](reinforcement_learning.md)**
7 |
8 |
9 |
10 | ----
11 |
12 | ### cool resources
13 |
14 |
15 |
16 | * **[cursor ai editor](https://www.cursor.com/)**
17 | * **[microsoft notes on ai agents](https://github.com/microsoft/generative-ai-for-beginners/tree/main/17-ai-agents)**
18 | * **[google's jax (composable transformations of numpy programs)](https://github.com/google/jax)**
19 | * **[machine learning engineering open book](https://github.com/stas00/ml-engineering)**
20 | * **[advances in financial machine learning](../books/advances_in_financial_machine_learning.pdf)**
21 |
--------------------------------------------------------------------------------
/deep_learning/deep_learning.md:
--------------------------------------------------------------------------------
1 | ## deep learning
2 |
3 |
4 |
5 | ### timeline tl; dr
6 |
7 |
8 |
9 | * **[2012: imagenet and alexnet](https://github.com/tensorflow/models/blob/master/research/slim/nets/alexnet.py)**
10 |
11 | * **[2013: atari with deep reinforcement learning](https://www.tensorflow.org/agents/tutorials/1_dqn_tutorial)**
12 | * **[2014: seq2seq](https://www.tensorflow.org/addons/tutorials/networks_seq2seq_nmt)**
13 | * **[2014: adam optimizer](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/optimizer_v2/adam.py#L32-L281)**
14 | * **[2015: gans](https://www.tensorflow.org/tutorials/generative/dcgan)**
15 | * **[2015: resnets](https://github.com/tensorflow/tensorflow/blob/v2.2.0/tensorflow/python/keras/applications/resnet.py)**
16 | * **[2017: transformers](https://github.com/huggingface/transformers)**
17 | * **[2018: bert](https://arxiv.org/abs/1810.04805)**
18 |
19 |
20 |
21 | ---
22 |
23 | ### deep reinforcement learning for trading
24 |
25 |
26 |
27 | * an mdp (markov decision process) consists of a set of states, a set of actions, a transition function that describes the probability of moving from one state to another after taking an action, and a reward function that assigns a numerical reward to each state-action pair.
28 |
29 | * the goal in an mdp is to find a policy that maximizes the expected cumulative reward over a sequence of actions.
30 |
31 | * a policy is a function that maps each state to a probability distribution over actions. The optimal policy is the one that maximizes the expected cumulative rewards.
32 |
33 | * the problem of reinforcement learning can be formalized using ideas from dynamical systems theory, specifically, as the optimal control of incompletely-known Markov decision processes.
34 |
35 | * unlike supervised learning, an agent must be able to learn from its own experience; and unlike unsupervised learning, reinforcement learning tries to maximize a reward signal instead of trying to find hidden structure.
36 |
37 | * the agent has to exploit what it has already experienced in order to obtain reward, but it also has to explore in order to make better action selections in the future. on a stochastic task, each action must be tried many times to gain a reliable estimate of its expected reward.
38 |
39 | * beyond the agent and the environment, one can identify four main subelements of a reinforcement learning system: a policy, a reward signal, a value function, and, optionally, a model of the environment.
40 |
41 | * traditional reinforcement learning problems can be formulated as a markov decision process (MDP):
42 | * we have an agent acting in an environment
43 |    * at each step *t*, the agent receives as input the current state S_t, takes an action A_t, and receives a reward R_{t+1} and the next state S_{t+1}
44 |    * the agent chooses the action based on some policy pi: A_t = pi(S_t)
45 | * it's our goal to find a policy that maximizes the cumulative reward Sum R_t over some finite or infinite time horizon
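the loop above can be sketched as follows. `ToyMarketEnv` is a hypothetical toy environment (a gaussian random-walk price) in which the action is the target position (-1 short, 0 flat, 1 long) and the reward R_{t+1} is the pnl of holding that position over the next step:

```python
import random
from typing import Callable, Tuple

State = Tuple[float, int]  # (last price change, current position)
ACTIONS = (-1, 0, 1)       # target position: short, flat, long


class ToyMarketEnv:
    """Hypothetical toy environment with a Gaussian random-walk price."""

    def __init__(self, seed: int = 0, steps: int = 100):
        self.rng = random.Random(seed)
        self.steps = steps

    def reset(self) -> State:
        self.t, self.price, self.position = 0, 100.0, 0
        return (0.0, self.position)

    def step(self, action: int) -> Tuple[State, float, bool]:
        if action not in ACTIONS:
            raise ValueError(f"unknown action {action}")
        self.position = action
        change = self.rng.gauss(0.0, 1.0)
        self.price += change
        self.t += 1
        reward = self.position * change  # R_{t+1}: pnl of the held position
        return (change, self.position), reward, self.t >= self.steps


def run_episode(env: ToyMarketEnv, policy: Callable[[State], int]) -> float:
    """Agent-environment loop: observe S_t, act A_t = pi(S_t), get R_{t+1}, S_{t+1}."""
    state, done, total = env.reset(), False, 0.0
    while not done:
        action = policy(state)
        state, reward, done = env.step(action)
        total += reward
    return total
```

for example, `run_episode(ToyMarketEnv(seed=1), lambda s: 1 if s[0] > 0 else -1)` evaluates a naive momentum policy that follows the last price change.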
46 |
47 |
48 |
49 |
50 |
51 |
52 |
53 |
54 | #### agent
55 |
56 |
57 |
58 | * the agent is the trading agent (e.g. the human trader who opens the gui of an exchange and makes trading decisions based on the current state of the exchange and their account)
59 |
60 |
61 |
62 | #### environment
63 |
64 |
65 |
66 | * the exchange and other agents are the environment, and they are not something we can control
67 | * by putting other agents together into some big complex environment, we lose the ability to explicitly model them
68 | * if we try to reverse-engineer the algorithms and strategies that other traders are running, we are put into a multi-agent reinforcement learning (MARL) problem setting
69 |
70 |
71 |
72 | #### state
73 |
74 |
75 |
76 | * in the case of trading on an exchange, we don't observe the complete state of the environment (e.g. other agents), so we are dealing with a partially observable markov decision process (pomdp).
77 | * what the agents observe is not the actual state S_t of the environment, but some derivation of that.
78 | * we can call the observation X_t, which is calculated using some function of the full state X_t ~ O(S_t)
79 | * the observation at each timestep t is simply the history of all exchange events received up to time t.
80 | * this event history can be used to build up the current exchange state; however, for our agent to make decisions, extra info such as account balance and open limit orders needs to be included.
81 |
82 |
83 |
84 | #### time scale
85 |
86 |
87 |
88 | * hft techniques: decisions are based almost entirely on market microstructure signals. decisions are made on nanosecond timescales, and trading strategies use dedicated connections to exchanges and extremely fast but simple algorithms running on fpga hardware.
89 | * neural networks are slow, they can't make predictions on nanoseconds time scales, so they can't compete with the speed of hft algorithms.
90 | * guess: the optimal time scale is between a few milliseconds and a few minutes.
91 | * can deep rl algorithms pick up hidden patterns?
92 |
93 |
94 |
95 | #### action space
96 |
97 |
98 |
99 | * the simplest approach has 3 actions: buy, hold, and sell. this works but limits us to placing market orders and to investing a fixed amount of money at each step.
100 | * in the next level we would let our agents learn how much money to invest, based on the uncertainty of our model, putting us into a continuous action space.
101 | * in the next level, we would introduce limit orders: the agent needs to decide the level (price) and quantity of the order, and be able to cancel orders that have not yet been matched.
102 |
103 |
104 |
105 | #### reward function
106 |
107 |
108 |
109 | * there are several possible reward functions; an obvious one is realized pnl (profit and loss): the agent receives a reward whenever it closes a position.
110 | * the net profit is either negative or positive, and this is the reward signal.
111 | * as the agent maximizes the total cumulative reward, it learns to trade profitably. the reward function leads to the optimal policy in the limit.
112 | * however, buy and sell actions are rare compared to doing nothing; the agent needs to learn without receiving frequent feedback.
113 | * an alternative is unrealized pnl, which is the net profit the agent would get if it were to close all of its positions immediately.
114 | * because the unrealized pnl may change at each time step, it gives the agent more frequent feedback signals. however, the direct feedback may bias the agent towards short-term actions.
115 | * both naively optimize for profit, but a trader may want to minimize risk (lower volatility).
116 | * using the sharpe ratio is one simple way to take risk into account; another is penalizing the maximum drawdown.
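the difference between the two pnl-based rewards can be sketched for a one-unit trader, assuming positions[t] is the position held from prices[t] to prices[t+1] and that any open position is closed at the end of the data (the function name is hypothetical). the unrealized-pnl reward is the mark-to-market change at every step, while the realized reward is paid only when a position is closed; both sum to the same total profit and differ only in how often feedback arrives:

```python
from typing import List, Sequence, Tuple


def pnl_rewards(prices: Sequence[float],
                positions: Sequence[int]) -> Tuple[List[float], List[float]]:
    """Return (realized, unrealized) reward sequences, one entry per step."""
    realized: List[float] = []
    unrealized: List[float] = []
    open_pnl = 0.0  # pnl accumulated by the position currently open
    last = len(prices) - 2  # index of the final step
    for t in range(len(prices) - 1):
        step_pnl = positions[t] * (prices[t + 1] - prices[t])
        unrealized.append(step_pnl)  # dense: changes every step
        open_pnl += step_pnl
        # the position is closed if it changes next step or the data ends
        closes = t == last or positions[t + 1] != positions[t]
        if closes and positions[t] != 0:
            realized.append(open_pnl)  # sparse: paid only on close
            open_pnl = 0.0
        else:
            realized.append(0.0)
    return realized, unrealized
```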
117 |
118 |
119 |
120 |
121 |
122 |
123 |
124 | #### learned policies
125 |
126 |
127 |
128 | * instead of needing to hand-code a rule-based policy, rl directly learns a policy
129 |
130 |
131 |
132 |
133 | #### trained directly in simulation environments
134 |
135 |
136 |
137 | * in the rule-based workflow, we need separate backtesting and parameter optimization steps because it is difficult for hand-written strategies to take environmental factors into account: order book liquidity, fee structures, latencies.
138 | * with rl, getting around environmental limitations becomes part of the optimization process: if we simulate latency in the environment and it causes the agent to make a mistake, the agent gets a negative reward, forcing it to learn to work around the latency.
139 | * by learning a model of the environment and performing rollouts using techniques like a monte carlo tree search (mcts), we could take into account potential reactions of the market (other agents)
140 | * by being smart about the data we collect from the live environment, we can continuously improve our model
141 | * do we act optimally in the live environment to generate profits, or do we act suboptimally to gather interesting information that we can use to improve the model of our environment and other agents?
142 |
143 |
144 |
145 | #### learning to adapt to market conditions
146 |
147 |
148 |
149 | * a strategy may work well in a bearish environment but lose money in a bullish environment.
150 | * because rl agents learn powerful policies parameterized by neural networks, they can also learn to adapt to market conditions by seeing them in historical data, given that they are trained over long time horizons and have sufficient memory.
151 |
152 |
153 |
154 | #### trading as research
155 |
156 |
157 |
158 | * the trading environment is a multiplayer game with thousands of agents acting simultaneously
159 | * building models of other agents is only one possibility; we can also choose to perform actions in a live environment with the goal of maximizing the information gain with respect to the kinds of policies the other agents may be following.
160 | * trading agents receive sparse rewards from the market. naively applying reward-hungry rl algorithms will fail.
161 | * this opens up the possibility for new algorithms and techniques, that can efficiently deal with sparse rewards.
162 | * many of today's standard algorithms, such as dqn or a3c, use a very naive approach to exploration: basically adding random noise to the policy. however, in trading, most states in the environment are bad and only a few are good, so naive random exploration will almost never stumble upon good state-action pairs.
163 | * the trading environment is inherently nonstationary. market conditions change, and other agents join, leave, and constantly change their strategies.
164 | * can we train an agent that can transit from bear to bull and then back to bear, without needing to be re-trained?
165 |
--------------------------------------------------------------------------------
/deep_learning/reinforcement_learning.md:
--------------------------------------------------------------------------------
1 | ## reinforcement learning
2 |
3 |
4 |
5 | ### tl; dr
6 |
7 |
8 |
9 | * reinforcement learning is learning what to do (how to map situations to actions) so as to maximize a numerical reward signal
10 | * an autonomous agent is a software program or system that can operate independently and make decisions on its own, without direct intervention from a human
11 |
12 |
13 |
14 | ---
15 |
16 | ### overview
17 |
18 |
19 |
20 | * we formalize the problem of reinforcement learning using ideas from dynamical systems theory, as the optimal control of incompletely-known Markov decision processes.
21 | * a learning agent must be able to sense the state of its environment to some extent and must be able to take actions that affect the state.
22 | * markov decision processes are intended to include just these three aspects: sensation, action, and goal.
23 | * the agent has to exploit what it has already experienced in order to obtain reward, but it also has to explore in order to make better action selections in the future.
24 | * on a stochastic task, each action must be tried many times to gain a reliable estimate of its expected reward.
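a standard way to balance the two is epsilon-greedy action selection, shown here on a multi-armed bandit with gaussian rewards (a sketch; the function name and parameters are hypothetical). the agent usually exploits its current best estimate but explores a random action with probability epsilon, and each estimate is a running average that becomes reliable only after many tries:

```python
import random
from typing import List, Sequence


def epsilon_greedy_bandit(true_means: Sequence[float], steps: int = 1000,
                          epsilon: float = 0.1, seed: int = 0) -> List[float]:
    """Sample-average action-value estimates on a Gaussian k-armed bandit."""
    rng = random.Random(seed)
    k = len(true_means)
    estimates, counts = [0.0] * k, [0] * k
    for _ in range(steps):
        if rng.random() < epsilon:
            arm = rng.randrange(k)                              # explore
        else:
            arm = max(range(k), key=lambda a: estimates[a])     # exploit
        reward = rng.gauss(true_means[arm], 1.0)                # noisy reward
        counts[arm] += 1
        estimates[arm] += (reward - estimates[arm]) / counts[arm]  # running mean
    return estimates
```

with `epsilon=0` the agent never explores and can lock onto whichever arm happened to look best early on.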
25 |
26 |
27 |
28 | ---
29 |
30 | ### elements of reinforcement learning
31 |
32 |
33 |
34 | * beyond the agent and the environment, four more elements belong to a reinforcement learning system: a policy, a reward signal, a value function, and, optionally, a model of the environment.
35 | * a policy defines the learning agent's way of behaving at a given time. it's a mapping from perceived states of the environment to actions to be taken when in those states. in general, policies may be stochastic (specifying probabilities for each action).
36 | * a reward signal defines the goal of a reinforcement learning problem: on each time step, the environment sends to the reinforcement learning agent a single number called the reward. the agent's sole objective is to maximize the total reward over the run.
37 | * a value function specifies what is good in the long run: the value of a state is the total amount of reward an agent can expect to accumulate over the future, starting from that state.
38 | * a model of the environment mimics its behavior: given a state and an action, it predicts the next state and reward, which allows planning (deciding on a course of action by considering possible future situations before they are experienced).
39 | * the most important feature distinguishing reinforcement learning from other types of learning is that it uses training information that evaluates the actions taken rather than instructs by giving correct actions.
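since the value of a state is the expected (usually discounted) return from that state, the return of a single episode is the basic quantity being estimated. a minimal sketch computing it with the recursion G_t = R_{t+1} + gamma * G_{t+1} (the function name is hypothetical; gamma = 1 gives the plain total reward). averaging such returns over many episodes that start from the same state gives a monte carlo estimate of its value:

```python
from typing import Sequence


def discounted_return(rewards: Sequence[float], gamma: float) -> float:
    """G_t = R_{t+1} + gamma * R_{t+2} + gamma^2 * R_{t+3} + ..."""
    g = 0.0
    for r in reversed(rewards):
        g = r + gamma * g  # G_t = R_{t+1} + gamma * G_{t+1}
    return g
```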
40 |
41 |
42 |
43 | ---
44 |
45 | ### finite markov decision processes (mdps)
46 |
47 |
48 |
49 | * the problem involves evaluative feedback and choosing different actions in different situations.
50 | * mdps are a classical formalization of sequential decision making, where actions influence not just immediate rewards, but also subsequent situations.
51 | * mdps involve delayed reward and the need to trade off immediate and delayed reward.
52 |
53 |
54 |
55 | ##### the agent-environment interface
56 |
57 | * mdps are meant to be a straightforward framing of the problem of learning from interaction to achieve a goal.
58 | * the learner and decision maker is called the agent.
59 | * the thing it interacts with, comprising everything outside the agent, is called the environment.
60 | * the environment gives rise to rewards, numerical values that the agent seeks to maximize over time through its choice of actions.
61 |
62 |
63 |
64 |
65 |
66 |
67 |
68 | * the agent and the environment interact at each of a sequence of discrete steps, t = 0, 1, 2, 3...
69 | * at each time step t, the agent receives some representation of the environment's state S_t.
70 | * on that basis, the agent selects an action A_t.
71 | * one time step later, in part as a consequence of its action, the agent receives a numerical reward R_{t+1} and finds itself in a new state S_{t+1}.
72 | * the mdp and the agent together give rise to a sequence (trajectory): S_0, A_0, R_1, S_1, A_1, R_2, S_2, A_2, R_3, ...
73 | * in a finite mdp, the sets of states, actions, and rewards all have a finite number of elements. in this case, the random variables R_t and S_t have well-defined discrete probability distributions dependent only on the preceding state and action.
74 | * in a markov decision process, the probabilities given by p(s', r | s, a) = Pr{S_t = s', R_t = r | S_{t-1} = s, A_{t-1} = a} completely characterize the environment's dynamics.
75 | * the state must include information about all aspects of the past agent-environment interaction that make a difference for the future.
76 | * anything that cannot be changed arbitrarily by the agent is considered to be outside of it and thus part of its environment.
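the agent-environment loop above can be sketched in python. this is a minimal illustration, not code from the book or any library: the two-state mdp, its state/action names, and the numbers in `P` are made up, and `P`, `check_dynamics`, `step`, and `rollout` are hypothetical names. the dynamics $p(s', r \mid s, a)$ are stored as a list of (probability, next state, reward) triples per state-action pair, so sampling the next step depends only on the current state and action (the markov property).

```python
import random

# hypothetical two-state mdp with made-up numbers.
# P[(s, a)] lists (probability, next_state, reward) triples, i.e. p(s', r | s, a)
P = {
    ("low", "wait"): [(1.0, "low", 0.0)],
    ("low", "work"): [(0.7, "high", 1.0), (0.3, "low", 0.0)],
    ("high", "wait"): [(0.9, "high", 1.0), (0.1, "low", 0.0)],
    ("high", "work"): [(1.0, "high", 2.0)],
}


def check_dynamics(P):
    """for every (s, a), p(s', r | s, a) must sum to 1 over all (s', r)."""
    for (s, a), outcomes in P.items():
        total = sum(prob for prob, _, _ in outcomes)
        assert abs(total - 1.0) < 1e-9, (s, a, total)


def step(P, s, a, rng):
    """environment side: sample (S_{t+1}, R_{t+1}) from p(., . | s, a)."""
    outcomes = P[(s, a)]
    weights = [prob for prob, _, _ in outcomes]
    _, s_next, r = rng.choices(outcomes, weights=weights, k=1)[0]
    return s_next, r


def rollout(P, policy, s0, n_steps, seed=0):
    """agent-environment loop; returns the trajectory [S0, A0, R1, S1, A1, R2, S2, ...]."""
    rng = random.Random(seed)
    s = s0
    trajectory = [s0]
    for _ in range(n_steps):
        a = policy(s)              # agent picks A_t from S_t
        s, r = step(P, s, a, rng)  # environment answers with S_{t+1}, R_{t+1}
        trajectory += [a, r, s]
    return trajectory


check_dynamics(P)
print(rollout(P, policy=lambda s: "work", s0="low", n_steps=3))
```

the policy here is just a function from state to action; a stochastic policy would sample from $\pi(a \mid s)$ instead.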
77 |
78 |
79 |
80 | ##### goals and rewards
81 |
82 |
83 | * each episode ends in a special state called the terminal state, followed by a reset to a standard starting state or to a sample from a standard distribution of starting states.
84 | * almost all reinforcement learning algorithms involve estimating value functions: functions of states (or of state-action pairs) that estimate how good it is for the agent to be in a given state (or how good it is to perform a given action in a given state).
85 | * the Bellman equation averages over all the possibilities, weighting each by its probability of occurring. it states that the value of the start state must equal the (discounted) value of the expected next state, plus the reward expected along the way.
87 | * solving a reinforcement learning task means finding a policy that achieves a lot of reward over the long run.
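the "discounted value of the next state plus the reward along the way" idea comes from the recursion on the discounted return, $G_t = R_{t+1} + \gamma G_{t+1}$, where $\gamma \in [0, 1]$ is the discount rate. for a policy $\pi$, the Bellman equation takes the expectation of that recursion over actions and dynamics: $v_\pi(s) = \sum_a \pi(a \mid s) \sum_{s', r} p(s', r \mid s, a)\,[r + \gamma v_\pi(s')]$. a minimal sketch of the recursion on a single finished episode (the function name `discounted_returns` is hypothetical):

```python
def discounted_returns(rewards, gamma):
    """compute G_t for every t from the rewards [R_1, R_2, ..., R_T] of one episode.

    rewards[t] holds R_{t+1}; the result's entry t holds G_t.
    uses G_t = R_{t+1} + gamma * G_{t+1}, with G_T = 0 at the terminal state.
    """
    returns = [0.0] * len(rewards)
    g = 0.0
    for t in reversed(range(len(rewards))):
        g = rewards[t] + gamma * g
        returns[t] = g
    return returns


print(discounted_returns([1.0, 0.0, 2.0], gamma=0.5))  # [1.5, 1.0, 2.0]
```

working backwards means each return is computed from the one after it, exactly the one-step relationship the Bellman equation averages over.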
88 |
89 |
90 |
91 | ---
92 |
93 | ### dynamic programming
94 |
95 | * dynamic programming (DP) is a collection of algorithms that can be used to compute optimal policies given a perfect model of the environment as an mdp.
96 | * a common way of obtaining approximate solutions for tasks with continuous states and actions is to quantize the state and action spaces and then apply finite-state DP methods.
97 | * the reason for computing the value function for a policy is to help find better policies.
98 | * asynchronous DP algorithms are in-place iterative DP algorithms that are not organized in terms of systematic sweeps of the state set. these algorithms update the values of states in any order whatsoever, using whatever values of other states happen to be available. the values of some states may be updated several times before the values of others are updated once.
99 | * policy evaluation refers to the (typically) iterative computation of the value function for a given policy.
100 | * policy improvement refers to the computation of an improved policy given the value function for that policy.
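a minimal sketch of iterative policy evaluation, assuming the dynamics are stored as (probability, next state, reward) triples per state-action pair and the policy as a table of $\pi(a \mid s)$. values are updated in place, so each sweep already uses values computed earlier in the same sweep. the function name and the test case (a 5-state random walk with terminals 0 and 6, reward +1 only on stepping into 6, an equiprobable left/right policy, and no discounting) are hypothetical; for that walk the exact values are $v_\pi(s) = s/6$.

```python
def policy_evaluation(states, actions, P, policy, gamma=1.0, theta=1e-10):
    """in-place iterative policy evaluation.

    P[(s, a)]: list of (prob, next_state, reward), i.e. p(s', r | s, a)
    policy[s][a]: probability pi(a | s)
    states: non-terminal states only; terminal states are never keys in V,
    so V.get(terminal, 0.0) gives them value 0
    """
    V = {}
    while True:
        delta = 0.0
        for s in states:
            # Bellman expectation backup for state s
            v_new = sum(
                policy[s][a] * prob * (r + gamma * V.get(s_next, 0.0))
                for a in actions
                for prob, s_next, r in P[(s, a)]
            )
            delta = max(delta, abs(v_new - V.get(s, 0.0)))
            V[s] = v_new  # in place: later states in this sweep see the new value
        if delta < theta:
            return V


# hypothetical test case: 5-state random walk, terminals 0 and 6
states = [1, 2, 3, 4, 5]
actions = ["left", "right"]
P = {}
for s in states:
    P[(s, "left")] = [(1.0, s - 1, 0.0)]
    P[(s, "right")] = [(1.0, s + 1, 1.0 if s + 1 == 6 else 0.0)]
policy = {s: {"left": 0.5, "right": 0.5} for s in states}

V = policy_evaluation(states, actions, P, policy)
print({s: round(v, 3) for s, v in V.items()})  # {1: 0.167, 2: 0.333, 3: 0.5, 4: 0.667, 5: 0.833}
```

the stopping rule checks the largest change in one sweep; it bounds how much the values are still moving, not the exact distance to $v_\pi$.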
101 |
102 |
103 |
##### generalized policy iteration
105 |
106 | * policy iteration consists of two simultaneous, interacting processes, one making the value function consistent with the current policy (policy evaluation), and the other making the policy greedy with respect to the current value function (policy improvement).
107 | * generalized policy iteration (GPI) refers to the general idea of letting policy-evaluation and policy-improvement processes interact, independent of the granularity and other details of the two processes.
108 | * DP is sometimes thought to be of limited applicability because of the curse of dimensionality, the fact that the number of states often grows exponentially with the number of state variables.
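policy iteration, as described above, alternates full policy evaluation with greedy policy improvement until the policy stops changing. a minimal sketch, assuming deterministic policies and the same (probability, next state, reward) encoding of the dynamics; `q_value`, `policy_iteration`, and the corridor below (cells 0 to 3, cell 4 terminal, every move costs -1, $\gamma = 0.9$) are hypothetical. the policy only switches when another action is better by more than a small tolerance, so near-ties caused by approximate values cannot make the loop cycle.

```python
def q_value(P, V, s, a, gamma):
    """expected one-step return: sum over (s', r) of p(s', r | s, a) * (r + gamma * V(s'))."""
    return sum(prob * (r + gamma * V.get(s_next, 0.0)) for prob, s_next, r in P[(s, a)])


def policy_iteration(states, actions, P, gamma=0.9, theta=1e-10, tie_tol=1e-6):
    """returns a (deterministic policy, value function) pair once the policy is stable."""
    policy = {s: actions[0] for s in states}  # arbitrary initial policy
    while True:
        # policy evaluation: make V consistent with the current policy (in-place sweeps)
        V = {}  # terminal states are never keys, so they keep value 0
        while True:
            delta = 0.0
            for s in states:
                v_new = q_value(P, V, s, policy[s], gamma)
                delta = max(delta, abs(v_new - V.get(s, 0.0)))
                V[s] = v_new
            if delta < theta:
                break
        # policy improvement: make the policy greedy with respect to V
        stable = True
        for s in states:
            best = max(actions, key=lambda a: q_value(P, V, s, a, gamma))
            # only switch on a clear improvement, so ties cannot make the loop cycle
            if q_value(P, V, s, best, gamma) > q_value(P, V, s, policy[s], gamma) + tie_tol:
                policy[s] = best
                stable = False
        if stable:
            return policy, V


# hypothetical corridor: cells 0..3, cell 4 is terminal, every move costs -1,
# and "left" in cell 0 bumps into the wall and stays put
states = [0, 1, 2, 3]
actions = ["left", "right"]
P = {}
for s in states:
    P[(s, "left")] = [(1.0, max(s - 1, 0), -1.0)]
    P[(s, "right")] = [(1.0, s + 1, -1.0)]

policy, V = policy_iteration(states, actions, P)
print(policy)  # {0: 'right', 1: 'right', 2: 'right', 3: 'right'}
```

starting from "always left", each improvement step flips one more cell to "right", working back from the goal; this is the evaluation/improvement interplay that GPI generalizes.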
109 |
110 |
111 |
112 | ---
113 |
114 | ### cool resources
115 |
116 |
117 |
118 | * **[gymnasium api](https://gymnasium.farama.org/)**
119 | * **[reinforcement learning with unsupervised auxiliary tasks, by jaderberg et al.](https://arxiv.org/abs/1611.05397)**
--------------------------------------------------------------------------------
/llms/README.md:
--------------------------------------------------------------------------------
1 | ## large language models
2 |
3 |
4 |
5 | * **[gpt-x](gpt)**
6 | * **[claude](claude)**
7 | * **[eliza](eliza)**
8 | * **[deepseek](deepseek)**
9 |
10 |
11 |
12 | ---
13 |
14 | ### cool resources
15 |
16 |
17 |
18 | #### articles
19 |
20 | * **[people cannot distinguish gpt-4 from a human in a turing test, by c. jones et al. (2024)](https://arxiv.org/pdf/2405.08007)**
21 |
22 |
23 |
24 | #### code
25 |
26 | * **[awesome chatgpt prompts](https://github.com/f/awesome-chatgpt-prompts)**
27 |
--------------------------------------------------------------------------------
/llms/claude/README.md:
--------------------------------------------------------------------------------
1 | ## claude
2 |
3 |
4 |
5 | ### cool resources
6 |
7 |
8 |
9 | * **[dario amodei on agi & the future of humanity, by lex fridman](https://www.youtube.com/watch?v=ugvHCXCOmm4)**
10 |
--------------------------------------------------------------------------------
/llms/deepseek/README.md:
--------------------------------------------------------------------------------
1 | ## deepseek stuff
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
--------------------------------------------------------------------------------
/llms/eliza/README.md:
--------------------------------------------------------------------------------
1 | ## eliza
--------------------------------------------------------------------------------
/llms/gpt/README.md:
--------------------------------------------------------------------------------
1 | ## gpt
2 |
3 |
4 |
5 | ### cool resources
6 |
7 |
8 |
9 | * **[vscode chatgpt plugin](https://github.com/mpociot/chatgpt-vscode) (and [here](https://marketplace.visualstudio.com/items?itemName=timkmecl.chatgpt))**
10 | * **[scispace extension (paper explainer)](https://chrome.google.com/webstore/detail/scispace-copilot/cipccbpjpemcnijhjcdjmkjhmhniiick/related)**
11 | * **[fix python bugs](https://platform.openai.com/playground/p/default-fix-python-bugs?model=code-davinci-002)**
12 | * **[explain code](https://platform.openai.com/playground/p/default-explain-code?model=code-davinci-002)**
13 | * **[translate code](https://platform.openai.com/playground/p/default-translate-code?model=code-davinci-002)**
14 | * **[translate sql](https://platform.openai.com/playground/p/default-sql-translate?model=code-davinci-002)**
15 | * **[calculate time complexity](https://platform.openai.com/playground/p/default-time-complexity?model=text-davinci-003)**
16 | * **[text to programmatic command](https://platform.openai.com/playground/p/default-text-to-command?model=text-davinci-003)**
17 |
--------------------------------------------------------------------------------